diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py
index d2134ad..7cc2c41 100644
--- a/vr180_matting/sam2_wrapper.py
+++ b/vr180_matting/sam2_wrapper.py
@@ -5,6 +5,8 @@ import cv2
 from pathlib import Path
 import warnings
 import os
+import tempfile
+import shutil
 
 try:
     from sam2.build_sam import build_sam2_video_predictor
@@ -33,6 +35,7 @@ class SAM2VideoMatting:
         self.predictor = None
         self.inference_state = None
         self.video_segments = {}
+        self.temp_video_path = None
 
         self._load_model(model_cfg, checkpoint_path)
 
@@ -68,18 +71,46 @@ class SAM2VideoMatting:
         except Exception as e:
             raise RuntimeError(f"Failed to load SAM2 model: {e}")
 
-    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
+    def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
         """Initialize video inference state"""
         if self.predictor is None:
             raise RuntimeError("SAM2 model not loaded")
 
-        # Create temporary directory for frames if needed
-        self.inference_state = self.predictor.init_state(
-            video_path=None,
-            video_frames=video_frames,
-            offload_video_to_cpu=self.memory_offload,
-            async_loading_frames=True
-        )
+        if video_path is not None:
+            # Use video path directly (SAM2's preferred method)
+            self.inference_state = self.predictor.init_state(
+                video_path=video_path,
+                offload_video_to_cpu=self.memory_offload,
+                async_loading_frames=True
+            )
+        else:
+            # For frame arrays, we need to save them as a temporary video first
+
+            if video_frames is None or len(video_frames) == 0:
+                raise ValueError("Either video_path or video_frames must be provided")
+
+            # Create temporary video file
+            temp_dir = tempfile.mkdtemp()
+            temp_video_path = Path(temp_dir) / "temp_video.mp4"
+
+            # Write frames to temporary video
+            height, width = video_frames[0].shape[:2]
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
+
+            for frame in video_frames:
+                writer.write(frame)
+            writer.release()
+
+            # Initialize with temporary video
+            self.inference_state = self.predictor.init_state(
+                video_path=str(temp_video_path),
+                offload_video_to_cpu=self.memory_offload,
+                async_loading_frames=True
+            )
+
+            # Store temp path for cleanup
+            self.temp_video_path = temp_video_path
 
     def add_person_prompts(self,
                            frame_idx: int,
@@ -231,6 +262,16 @@ class SAM2VideoMatting:
 
         self.inference_state = None
 
+        # Clean up temporary video file
+        if self.temp_video_path is not None:
+            try:
+                if self.temp_video_path.exists():
+                    # Remove the temporary directory
+                    shutil.rmtree(self.temp_video_path.parent)
+                self.temp_video_path = None
+            except Exception as e:
+                warnings.warn(f"Failed to cleanup temp video: {e}")
+
         # Clear CUDA cache
         if torch.cuda.is_available():
             torch.cuda.empty_cache()