det core
This commit is contained in:
@@ -152,13 +152,16 @@ class SAM2VideoMatting:
|
||||
|
||||
return object_ids
|
||||
|
||||
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
|
||||
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
|
||||
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
|
||||
"""
|
||||
Propagate masks through video
|
||||
Propagate masks through video with Det-SAM2 style memory management
|
||||
|
||||
Args:
|
||||
start_frame: Starting frame index
|
||||
max_frames: Maximum number of frames to process
|
||||
frame_release_interval: Release old frames every N frames
|
||||
frame_window_size: Keep N frames in memory
|
||||
|
||||
Returns:
|
||||
Dictionary mapping frame_idx -> {obj_id: mask}
|
||||
@@ -182,9 +185,108 @@ class SAM2VideoMatting:
|
||||
|
||||
video_segments[out_frame_idx] = frame_masks
|
||||
|
||||
# Memory management: release old frames periodically
|
||||
if self.memory_offload and out_frame_idx % 100 == 0:
|
||||
self._release_old_frames(out_frame_idx - 50)
|
||||
# Det-SAM2 style memory management: more aggressive frame release
|
||||
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
|
||||
self._release_old_frames(out_frame_idx - frame_window_size)
|
||||
# Optional: Log frame release for monitoring
|
||||
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
|
||||
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
|
||||
|
||||
return video_segments
|
||||
|
||||
def propagate_masks_with_continuous_correction(self,
                                               detector,
                                               temp_video_path: str,
                                               start_frame: int = 0,
                                               max_frames: Optional[int] = None,
                                               correction_interval: int = 60,
                                               frame_release_interval: int = 50,
                                               frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
    """Propagate masks through the video, injecting detector-based correction
    prompts at periodic keyframes (Det-SAM2 style).

    Args:
        detector: Detector object; must expose ``detect_persons(frame)`` and
            ``convert_to_sam_prompts(detections)`` (used below — e.g. a
            YOLODetector).
        temp_video_path: Path to the video file used to fetch keyframes for
            detection.
        start_frame: Frame index at which propagation begins.
        max_frames: Upper bound on the number of frames to track; defaults
            to 10000 when None/0.
        correction_interval: Inject correction prompts every N frames.
        frame_release_interval: Release cached frames every N frames
            (only when ``self.memory_offload`` is enabled).
        frame_window_size: Number of trailing frames kept in memory when
            releasing.

    Returns:
        Mapping of frame_idx -> {obj_id: binary mask array}.

    Raises:
        RuntimeError: If ``init_state`` has not been called
            (``self.inference_state`` is None).
    """
    if self.inference_state is None:
        raise RuntimeError("Video state not initialized")

    video_segments: Dict[int, Dict[int, np.ndarray]] = {}
    # NOTE(review): `or` also replaces an explicit 0 with 10000 — presumably
    # intended as a simple default, but confirm 0 is never a meaningful cap.
    max_frames = max_frames or 10000  # Default limit

    # Keep a capture handle open so keyframes can be re-read mid-propagation.
    cap = cv2.VideoCapture(str(temp_video_path))

    try:
        propagation = self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False,
        )
        for out_frame_idx, out_obj_ids, out_mask_logits in propagation:
            # Threshold logits at 0.0 to obtain per-object binary masks.
            video_segments[out_frame_idx] = {
                obj: (out_mask_logits[pos] > 0.0).cpu().numpy()
                for pos, obj in enumerate(out_obj_ids)
            }

            # --- Keyframe correction (Det-SAM2) -------------------------
            # NOTE(review): the upper bound compares a frame index against
            # max_frames (a count); with start_frame > 0 these are different
            # scales — confirm intent.
            is_keyframe = (out_frame_idx % correction_interval == 0
                           and start_frame < out_frame_idx < max_frames - 1)
            if is_keyframe:
                # Seek to and decode the keyframe for detection.
                cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
                grabbed, keyframe = cap.read()
                if grabbed:
                    detections = detector.detect_persons(keyframe)
                    if detections:
                        box_prompts, labels = detector.convert_to_sam_prompts(detections)
                        correction_count = 0
                        try:
                            for slot, (box, _label) in enumerate(zip(box_prompts, labels)):
                                # Reuse a tracked id when one exists at this
                                # slot; otherwise mint a new one.
                                # NOTE(review): len(out_obj_ids) + slot + 1
                                # may collide with existing ids — verify.
                                if slot < len(out_obj_ids):
                                    obj_id = out_obj_ids[slot]
                                else:
                                    obj_id = len(out_obj_ids) + slot + 1
                                self.predictor.add_new_points_or_box(
                                    inference_state=self.inference_state,
                                    frame_idx=out_frame_idx,
                                    obj_id=obj_id,
                                    box=box,
                                )
                                correction_count += 1
                            print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
                        except Exception as e:
                            # Best-effort: a failed correction must not
                            # abort propagation.
                            warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")

            # --- Aggressive frame release (Det-SAM2 memory management) --
            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                self._release_old_frames(out_frame_idx - frame_window_size)
                # Log every 4x release interval for monitoring.
                if out_frame_idx % (frame_release_interval * 4) == 0:
                    print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
    finally:
        cap.release()

    return video_segments
|
||||
|
||||
|
||||
Reference in New Issue
Block a user