fix again
This commit is contained in:
@@ -5,6 +5,8 @@ import cv2
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sam2.build_sam import build_sam2_video_predictor
|
from sam2.build_sam import build_sam2_video_predictor
|
||||||
@@ -33,6 +35,7 @@ class SAM2VideoMatting:
|
|||||||
self.predictor = None
|
self.predictor = None
|
||||||
self.inference_state = None
|
self.inference_state = None
|
||||||
self.video_segments = {}
|
self.video_segments = {}
|
||||||
|
self.temp_video_path = None
|
||||||
|
|
||||||
self._load_model(model_cfg, checkpoint_path)
|
self._load_model(model_cfg, checkpoint_path)
|
||||||
|
|
||||||
@@ -68,18 +71,46 @@ class SAM2VideoMatting:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
||||||
|
|
||||||
def init_video_state(self, video_frames: List[np.ndarray]) -> None:
|
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
|
||||||
"""Initialize video inference state"""
|
"""Initialize video inference state"""
|
||||||
if self.predictor is None:
|
if self.predictor is None:
|
||||||
raise RuntimeError("SAM2 model not loaded")
|
raise RuntimeError("SAM2 model not loaded")
|
||||||
|
|
||||||
# Create temporary directory for frames if needed
|
if video_path is not None:
|
||||||
|
# Use video path directly (SAM2's preferred method)
|
||||||
self.inference_state = self.predictor.init_state(
|
self.inference_state = self.predictor.init_state(
|
||||||
video_path=None,
|
video_path=video_path,
|
||||||
video_frames=video_frames,
|
|
||||||
offload_video_to_cpu=self.memory_offload,
|
offload_video_to_cpu=self.memory_offload,
|
||||||
async_loading_frames=True
|
async_loading_frames=True
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# For frame arrays, we need to save them as a temporary video first
|
||||||
|
|
||||||
|
if video_frames is None or len(video_frames) == 0:
|
||||||
|
raise ValueError("Either video_path or video_frames must be provided")
|
||||||
|
|
||||||
|
# Create temporary video file
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
temp_video_path = Path(temp_dir) / "temp_video.mp4"
|
||||||
|
|
||||||
|
# Write frames to temporary video
|
||||||
|
height, width = video_frames[0].shape[:2]
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||||
|
writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
|
||||||
|
|
||||||
|
for frame in video_frames:
|
||||||
|
writer.write(frame)
|
||||||
|
writer.release()
|
||||||
|
|
||||||
|
# Initialize with temporary video
|
||||||
|
self.inference_state = self.predictor.init_state(
|
||||||
|
video_path=str(temp_video_path),
|
||||||
|
offload_video_to_cpu=self.memory_offload,
|
||||||
|
async_loading_frames=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store temp path for cleanup
|
||||||
|
self.temp_video_path = temp_video_path
|
||||||
|
|
||||||
def add_person_prompts(self,
|
def add_person_prompts(self,
|
||||||
frame_idx: int,
|
frame_idx: int,
|
||||||
@@ -231,6 +262,16 @@ class SAM2VideoMatting:
|
|||||||
|
|
||||||
self.inference_state = None
|
self.inference_state = None
|
||||||
|
|
||||||
|
# Clean up temporary video file
|
||||||
|
if self.temp_video_path is not None:
|
||||||
|
try:
|
||||||
|
if self.temp_video_path.exists():
|
||||||
|
# Remove the temporary directory
|
||||||
|
shutil.rmtree(self.temp_video_path.parent)
|
||||||
|
self.temp_video_path = None
|
||||||
|
except Exception as e:
|
||||||
|
warnings.warn(f"Failed to cleanup temp video: {e}")
|
||||||
|
|
||||||
# Clear CUDA cache
|
# Clear CUDA cache
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user