This commit is contained in:
2025-07-26 13:51:21 -07:00
parent 6f93abcb08
commit 80f947c91b
3 changed files with 133 additions and 9 deletions

View File

@@ -29,6 +29,11 @@ class MattingConfig:
fp16: bool = True fp16: bool = True
sam2_model_cfg: str = "sam2.1_hiera_l" sam2_model_cfg: str = "sam2.1_hiera_l"
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
# Det-SAM2 optimizations
continuous_correction: bool = True
correction_interval: int = 60 # Add correction prompts every N frames
frame_release_interval: int = 50 # Release old frames every N frames
frame_window_size: int = 30 # Keep N frames in memory
@dataclass @dataclass

View File

@@ -152,13 +152,16 @@ class SAM2VideoMatting:
return object_ids return object_ids
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]: def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
""" """
Propagate masks through video Propagate masks through video with Det-SAM2 style memory management
Args: Args:
start_frame: Starting frame index start_frame: Starting frame index
max_frames: Maximum number of frames to process max_frames: Maximum number of frames to process
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns: Returns:
Dictionary mapping frame_idx -> {obj_id: mask} Dictionary mapping frame_idx -> {obj_id: mask}
@@ -182,9 +185,108 @@ class SAM2VideoMatting:
video_segments[out_frame_idx] = frame_masks video_segments[out_frame_idx] = frame_masks
# Memory management: release old frames periodically # Det-SAM2 style memory management: more aggressive frame release
if self.memory_offload and out_frame_idx % 100 == 0: if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - 50) self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
return video_segments
def propagate_masks_with_continuous_correction(self,
detector,
temp_video_path: str,
start_frame: int = 0,
max_frames: Optional[int] = None,
correction_interval: int = 60,
frame_release_interval: int = 50,
frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Det-SAM2 style: Propagate masks with continuous prompt correction
Args:
detector: YOLODetector instance for generating correction prompts
temp_video_path: Path to video file for frame access
start_frame: Starting frame index
max_frames: Maximum number of frames to process
correction_interval: Add correction prompts every N frames
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
"""
if self.inference_state is None:
raise RuntimeError("Video state not initialized")
video_segments = {}
max_frames = max_frames or 10000 # Default limit
# Open video for accessing frames during propagation
cap = cv2.VideoCapture(str(temp_video_path))
try:
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=start_frame,
max_frame_num_to_track=max_frames,
reverse=False
):
frame_masks = {}
for i, out_obj_id in enumerate(out_obj_ids):
mask = (out_mask_logits[i] > 0.0).cpu().numpy()
frame_masks[out_obj_id] = mask
video_segments[out_frame_idx] = frame_masks
# Det-SAM2 optimization: Add correction prompts at keyframes
if (out_frame_idx % correction_interval == 0 and
out_frame_idx > start_frame and
out_frame_idx < max_frames - 1):
# Read frame for detection
cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
ret, correction_frame = cap.read()
if ret:
# Run detection on this keyframe
detections = detector.detect_persons(correction_frame)
if detections:
# Convert to prompts and add as corrections
box_prompts, labels = detector.convert_to_sam_prompts(detections)
# Add correction prompts (SAM2 will propagate backward)
correction_count = 0
try:
for i, (box, label) in enumerate(zip(box_prompts, labels)):
# Use existing object IDs if available, otherwise create new ones
obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=out_frame_idx,
obj_id=obj_id,
box=box,
)
correction_count += 1
print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
except Exception as e:
warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
# Memory management: More aggressive frame release (Det-SAM2 style)
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
finally:
cap.release()
return video_segments return video_segments

View File

@@ -375,10 +375,27 @@ class VR180Processor(VideoProcessor):
# Propagate masks (most expensive operation) # Propagate masks (most expensive operation)
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)") self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
video_segments = self.sam2_model.propagate_masks(
start_frame=0, # Use Det-SAM2 continuous correction if enabled
max_frames=num_frames if self.config.matting.continuous_correction:
) video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
detector=self.detector,
temp_video_path=str(temp_video_path),
start_frame=0,
max_frames=num_frames,
correction_interval=self.config.matting.correction_interval,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
else:
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=num_frames,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)") self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
# Apply masks - need to reload frames from temp video since we freed the original frames # Apply masks - need to reload frames from temp video since we freed the original frames