det core
This commit is contained in:
@@ -29,6 +29,11 @@ class MattingConfig:
|
|||||||
fp16: bool = True
|
fp16: bool = True
|
||||||
sam2_model_cfg: str = "sam2.1_hiera_l"
|
sam2_model_cfg: str = "sam2.1_hiera_l"
|
||||||
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
# Det-SAM2 optimizations
|
||||||
|
continuous_correction: bool = True
|
||||||
|
correction_interval: int = 60 # Add correction prompts every N frames
|
||||||
|
frame_release_interval: int = 50 # Release old frames every N frames
|
||||||
|
frame_window_size: int = 30 # Keep N frames in memory
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -152,13 +152,16 @@ class SAM2VideoMatting:
|
|||||||
|
|
||||||
return object_ids
|
return object_ids
|
||||||
|
|
||||||
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
|
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
|
||||||
|
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
Propagate masks through video
|
Propagate masks through video with Det-SAM2 style memory management
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_frame: Starting frame index
|
start_frame: Starting frame index
|
||||||
max_frames: Maximum number of frames to process
|
max_frames: Maximum number of frames to process
|
||||||
|
frame_release_interval: Release old frames every N frames
|
||||||
|
frame_window_size: Keep N frames in memory
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping frame_idx -> {obj_id: mask}
|
Dictionary mapping frame_idx -> {obj_id: mask}
|
||||||
@@ -182,9 +185,108 @@ class SAM2VideoMatting:
|
|||||||
|
|
||||||
video_segments[out_frame_idx] = frame_masks
|
video_segments[out_frame_idx] = frame_masks
|
||||||
|
|
||||||
# Memory management: release old frames periodically
|
# Det-SAM2 style memory management: more aggressive frame release
|
||||||
if self.memory_offload and out_frame_idx % 100 == 0:
|
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
|
||||||
self._release_old_frames(out_frame_idx - 50)
|
self._release_old_frames(out_frame_idx - frame_window_size)
|
||||||
|
# Optional: Log frame release for monitoring
|
||||||
|
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
|
||||||
|
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
|
||||||
|
|
||||||
|
return video_segments
|
||||||
|
|
||||||
|
def propagate_masks_with_continuous_correction(self,
                                               detector,
                                               temp_video_path: str,
                                               start_frame: int = 0,
                                               max_frames: Optional[int] = None,
                                               correction_interval: int = 60,
                                               frame_release_interval: int = 50,
                                               frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
    """
    Det-SAM2 style: Propagate masks with continuous prompt correction

    While iterating over the predictor's propagation output, every
    ``correction_interval``-th frame is re-read from ``temp_video_path``,
    passed to ``detector`` and the resulting boxes are added as extra
    box prompts on that frame.

    Args:
        detector: Object providing ``detect_persons(frame)`` and
            ``convert_to_sam_prompts(detections)``
        temp_video_path: Path to video file for frame access
        start_frame: Starting frame index
        max_frames: Maximum number of frames to process (falls back to 10000
            when None or 0)
        correction_interval: Add correction prompts every N frames (must be >= 1)
        frame_release_interval: Release old frames every N frames (must be >= 1
            when memory offload is enabled)
        frame_window_size: Keep N frames in memory

    Returns:
        Dictionary mapping frame_idx -> {obj_id: mask}

    Raises:
        RuntimeError: If the video state has not been initialized.
        ValueError: If an interval that is used as a modulus is not positive.
    """
    if self.inference_state is None:
        raise RuntimeError("Video state not initialized")

    # Both intervals are used as the right-hand side of `%`; a zero would
    # otherwise surface as a ZeroDivisionError in the middle of propagation.
    if correction_interval <= 0:
        raise ValueError(f"correction_interval must be >= 1, got {correction_interval}")
    if self.memory_offload and frame_release_interval <= 0:
        raise ValueError(f"frame_release_interval must be >= 1, got {frame_release_interval}")

    video_segments = {}
    max_frames = max_frames or 10000  # Default limit

    # Frame indices are absolute, so the "do not correct near the end" bound
    # has to be offset by start_frame (identical to the old bound when
    # start_frame == 0).
    last_correction_frame = start_frame + max_frames - 1

    # Open video for accessing frames during propagation
    cap = cv2.VideoCapture(str(temp_video_path))

    try:
        if not cap.isOpened():
            # Propagation itself does not need this capture, so keep going,
            # but make it visible that no corrections can be applied.
            warnings.warn(
                f"Could not open {temp_video_path} for correction frames; "
                "continuing without continuous correction"
            )

        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            # Binarize the logits at 0 to get one boolean mask per object
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

            # Det-SAM2 optimization: Add correction prompts at keyframes
            is_keyframe = (
                start_frame < out_frame_idx < last_correction_frame
                and out_frame_idx % correction_interval == 0
            )
            if is_keyframe:
                # Read frame for detection
                cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
                ret, correction_frame = cap.read()
                if ret:
                    self._add_correction_prompts(
                        detector, correction_frame, out_frame_idx, out_obj_ids
                    )

            # Memory management: More aggressive frame release (Det-SAM2 style)
            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                self._release_old_frames(out_frame_idx - frame_window_size)
                # Optional: Log frame release for monitoring
                if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
                    print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")

    finally:
        cap.release()

    return video_segments

def _add_correction_prompts(self, detector, frame, frame_idx: int, tracked_obj_ids) -> int:
    """
    Run the detector on one keyframe and add its boxes as box prompts.

    Each box is added independently: a failure on one box is reported with
    a warning and the remaining boxes are still attempted.

    Args:
        detector: Object providing ``detect_persons`` and ``convert_to_sam_prompts``
        frame: Frame as returned by ``cv2.VideoCapture.read``
        frame_idx: Index of the frame the prompts are attached to
        tracked_obj_ids: Object ids reported by the predictor for this frame

    Returns:
        Number of prompts that were added successfully
    """
    detections = detector.detect_persons(frame)
    if not detections:
        return 0

    # Convert to prompts and add as corrections
    box_prompts, labels = detector.convert_to_sam_prompts(detections)

    added = 0
    for i, (box, _label) in enumerate(zip(box_prompts, labels)):
        # Use existing object IDs if available, otherwise create new ones.
        # NOTE(review): detections are paired with tracked objects purely by
        # position, and the fallback id may collide with an existing id when
        # ids are not 1..N - confirm against the detector/predictor contract.
        if i < len(tracked_obj_ids):
            obj_id = tracked_obj_ids[i]
        else:
            obj_id = len(tracked_obj_ids) + i + 1

        try:
            self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
        except Exception as e:
            # Best effort: corrections are optional, so never abort propagation
            warnings.warn(f"Failed to add correction prompt at frame {frame_idx}: {e}")
            continue
        added += 1

    print(f"Det-SAM2: Added {added} correction prompts at frame {frame_idx}")
    return added
|
||||||
|
|
||||||
|
|||||||
@@ -375,10 +375,27 @@ class VR180Processor(VideoProcessor):
|
|||||||
|
|
||||||
# Propagate masks (most expensive operation)
|
# Propagate masks (most expensive operation)
|
||||||
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
|
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
|
||||||
video_segments = self.sam2_model.propagate_masks(
|
|
||||||
start_frame=0,
|
# Use Det-SAM2 continuous correction if enabled
|
||||||
max_frames=num_frames
|
if self.config.matting.continuous_correction:
|
||||||
)
|
video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
|
||||||
|
detector=self.detector,
|
||||||
|
temp_video_path=str(temp_video_path),
|
||||||
|
start_frame=0,
|
||||||
|
max_frames=num_frames,
|
||||||
|
correction_interval=self.config.matting.correction_interval,
|
||||||
|
frame_release_interval=self.config.matting.frame_release_interval,
|
||||||
|
frame_window_size=self.config.matting.frame_window_size
|
||||||
|
)
|
||||||
|
print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
|
||||||
|
else:
|
||||||
|
video_segments = self.sam2_model.propagate_masks(
|
||||||
|
start_frame=0,
|
||||||
|
max_frames=num_frames,
|
||||||
|
frame_release_interval=self.config.matting.frame_release_interval,
|
||||||
|
frame_window_size=self.config.matting.frame_window_size
|
||||||
|
)
|
||||||
|
|
||||||
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
|
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
|
||||||
|
|
||||||
# Apply masks - need to reload frames from temp video since we freed the original frames
|
# Apply masks - need to reload frames from temp video since we freed the original frames
|
||||||
|
|||||||
Reference in New Issue
Block a user