diff --git a/vr180_matting/config.py b/vr180_matting/config.py
index da8d843..cca8d53 100644
--- a/vr180_matting/config.py
+++ b/vr180_matting/config.py
@@ -29,6 +29,11 @@ class MattingConfig:
     fp16: bool = True
     sam2_model_cfg: str = "sam2.1_hiera_l"
     sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
+    # Det-SAM2 optimizations
+    continuous_correction: bool = True
+    correction_interval: int = 60  # Add correction prompts every N frames
+    frame_release_interval: int = 50  # Release old frames every N frames
+    frame_window_size: int = 30  # Keep N frames in memory
 
 
 @dataclass
diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py
index 45f74c1..b406bbe 100644
--- a/vr180_matting/sam2_wrapper.py
+++ b/vr180_matting/sam2_wrapper.py
@@ -152,13 +152,16 @@ class SAM2VideoMatting:
 
         return object_ids
 
-    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
+    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
+                        frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
         """
-        Propagate masks through video
+        Propagate masks through video with Det-SAM2 style memory management
 
         Args:
             start_frame: Starting frame index
             max_frames: Maximum number of frames to process
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory
 
         Returns:
             Dictionary mapping frame_idx -> {obj_id: mask}
@@ -182,9 +185,108 @@ class SAM2VideoMatting:
 
                 video_segments[out_frame_idx] = frame_masks
 
-                # Memory management: release old frames periodically
-                if self.memory_offload and out_frame_idx % 100 == 0:
-                    self._release_old_frames(out_frame_idx - 50)
+                # Det-SAM2 style memory management: more aggressive frame release
+                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                    self._release_old_frames(out_frame_idx - frame_window_size)
+                    # Optional: log frame release for monitoring
+                    if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        return video_segments
+
+    def propagate_masks_with_continuous_correction(self,
+                                                   detector,
+                                                   temp_video_path: str,
+                                                   start_frame: int = 0,
+                                                   max_frames: Optional[int] = None,
+                                                   correction_interval: int = 60,
+                                                   frame_release_interval: int = 50,
+                                                   frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
+        """
+        Det-SAM2 style: propagate masks with continuous prompt correction
+
+        Args:
+            detector: YOLODetector instance for generating correction prompts
+            temp_video_path: Path to video file for frame access
+            start_frame: Starting frame index
+            max_frames: Maximum number of frames to process
+            correction_interval: Add correction prompts every N frames
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory
+
+        Returns:
+            Dictionary mapping frame_idx -> {obj_id: mask}
+        """
+        if self.inference_state is None:
+            raise RuntimeError("Video state not initialized")
+
+        video_segments = {}
+        max_frames = max_frames or 10000  # Default limit
+
+        # Open the video so keyframes can be re-read for detection during propagation
+        cap = cv2.VideoCapture(str(temp_video_path))
+
+        try:
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=start_frame,
+                max_frame_num_to_track=max_frames,
+                reverse=False
+            ):
+                frame_masks = {}
+
+                for i, out_obj_id in enumerate(out_obj_ids):
+                    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
+                    frame_masks[out_obj_id] = mask
+
+                video_segments[out_frame_idx] = frame_masks
+
+                # Det-SAM2 optimization: add correction prompts at keyframes
+                if (out_frame_idx % correction_interval == 0 and
+                        out_frame_idx > start_frame and
+                        out_frame_idx < max_frames - 1):
+
+                    # Read the keyframe for detection
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
+                    ret, correction_frame = cap.read()
+
+                    if ret:
+                        # Run detection on this keyframe
+                        detections = detector.detect_persons(correction_frame)
+
+                        if detections:
+                            # Convert detections to box prompts for SAM2
+                            box_prompts, labels = detector.convert_to_sam_prompts(detections)
+
+                            # Add correction prompts; SAM2 applies them from this keyframe onward
+                            correction_count = 0
+                            try:
+                                for i, (box, label) in enumerate(zip(box_prompts, labels)):
+                                    # Reuse existing object IDs where available, otherwise create new ones
+                                    obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
+
+                                    self.predictor.add_new_points_or_box(
+                                        inference_state=self.inference_state,
+                                        frame_idx=out_frame_idx,
+                                        obj_id=obj_id,
+                                        box=box,
+                                    )
+                                    correction_count += 1
+
+                                print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
+                            except Exception as e:
+                                warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
+
+                # Memory management: more aggressive frame release (Det-SAM2 style)
+                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                    self._release_old_frames(out_frame_idx - frame_window_size)
+                    # Optional: log frame release for monitoring
+                    if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        finally:
+            cap.release()
 
         return video_segments
 
diff --git a/vr180_matting/vr180_processor.py b/vr180_matting/vr180_processor.py
index 37280fc..5fe875e 100644
--- a/vr180_matting/vr180_processor.py
+++ b/vr180_matting/vr180_processor.py
@@ -375,10 +375,27 @@ class VR180Processor(VideoProcessor):
 
         # Propagate masks (most expensive operation)
         self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
-        video_segments = self.sam2_model.propagate_masks(
-            start_frame=0,
-            max_frames=num_frames
-        )
+
+        # Use Det-SAM2 continuous correction if enabled
+        if self.config.matting.continuous_correction:
+            video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
+                detector=self.detector,
+                temp_video_path=str(temp_video_path),
+                start_frame=0,
+                max_frames=num_frames,
+                correction_interval=self.config.matting.correction_interval,
+                frame_release_interval=self.config.matting.frame_release_interval,
+                frame_window_size=self.config.matting.frame_window_size
+            )
+            print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
+        else:
+            video_segments = self.sam2_model.propagate_masks(
+                start_frame=0,
+                max_frames=num_frames,
+                frame_release_interval=self.config.matting.frame_release_interval,
+                frame_window_size=self.config.matting.frame_window_size
+            )
+
         self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
 
         # Apply masks - need to reload frames from temp video since we freed the original frames
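
As a back-of-envelope check on the new memory knobs (a sketch for review, not part of the diff; the helper name `peak_resident_frames` is made up for illustration): between two release points the inference state accumulates `frame_release_interval` new frames on top of the `frame_window_size` frames that survived the last trim, so peak residency is their sum.

```python
def peak_resident_frames(frame_release_interval: int, frame_window_size: int) -> int:
    """Worst-case frame count resident in the inference state between releases.

    _release_old_frames(idx - frame_window_size) trims everything older than
    the window, then frame_release_interval frames accumulate before the next
    trim fires.
    """
    return frame_window_size + frame_release_interval

# New MattingConfig defaults vs. the previous hard-coded schedule
# (release every 100 frames, keep a 50-frame window):
assert peak_resident_frames(50, 30) == 80    # Det-SAM2 defaults in this diff
assert peak_resident_frames(100, 50) == 150  # old behaviour
```

With the defaults above that is nearly a 2x reduction in peak resident frames, which appears to be where the VRAM savings of this change come from; `correction_interval` is independent of the release schedule, since corrections are always added at the current (still-resident) propagation frame.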