import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings
import os
import tempfile
import shutil
import gc

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization"""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install sam2 package.")

        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.model_cfg = model_cfg
        self.checkpoint_path = checkpoint_path
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self.temp_video_path = None

        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load SAM2 video predictor with optimizations"""
        try:
            # Check for the checkpoint in the SAM2 repo structure
            if not Path(checkpoint_path).exists():
                # Try in segment-anything-2/checkpoints/
                sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
                if sam2_path.exists():
                    checkpoint_path = str(sam2_path)
                else:
                    # Try the legacy models/ directory
                    models_path = Path("models") / Path(checkpoint_path).name
                    if models_path.exists():
                        checkpoint_path = str(models_path)
                    else:
                        # Try relative to the package
                        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                        sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name
                        if sam2_repo_path.exists():
                            checkpoint_path = str(sam2_repo_path)

            # Use SAM2's build_sam2_video_predictor, which returns the predictor directly
            # (the predictor IS the model - no .model attribute needed)
            self.predictor = build_sam2_video_predictor(
                config_file=model_cfg,
                ckpt_path=checkpoint_path,
                device=self.device
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")

    def init_video_state(self, video_frames: Optional[List[np.ndarray]] = None,
                         video_path: Optional[str] = None) -> None:
        """Initialize video inference state"""
        if self.predictor is None:
            # Recreate the predictor if it was cleaned up
            self._load_model(self.model_cfg, self.checkpoint_path)

        if video_path is not None:
            # Use the video path directly (SAM2's preferred method)
            self.inference_state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )
        else:
            # For frame arrays, we need to save them as a temporary video first
            if video_frames is None or len(video_frames) == 0:
                raise ValueError("Either video_path or video_frames must be provided")

            # Create a temporary video file in the current directory
            import uuid
            temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
            temp_video_path = Path.cwd() / temp_video_name

            # Write frames to the temporary video
            height, width = video_frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))

            for frame in video_frames:
                writer.write(frame)
            writer.release()

            # Initialize with the temporary video
            self.inference_state = self.predictor.init_state(
                video_path=str(temp_video_path),
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )

            # Store the temp path for cleanup
            self.temp_video_path = temp_video_path

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index to add prompts at
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,)

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids = []

        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Object IDs start from 1

            # Add box prompt
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )

            # out_obj_ids contains every object tracked so far, so record only the
            # newly added ID to avoid duplicates in the returned list
            object_ids.append(obj_id)

        return object_ids

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
                        frame_release_interval: int = 50,
                        frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through the video with Det-SAM2 style memory management

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process
            frame_release_interval: Release old frames every N frames
            frame_window_size: Number of frames to keep in memory

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}

        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}

            for i, out_obj_id in enumerate(out_obj_ids):
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask

            video_segments[out_frame_idx] = frame_masks

            # Det-SAM2 style memory management: aggressively release old frames
            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                self._release_old_frames(out_frame_idx - frame_window_size)
                # Optional: log frame releases for monitoring (every 4x release interval)
                if out_frame_idx % (frame_release_interval * 4) == 0:
                    print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")

        return video_segments

    def propagate_masks_with_continuous_correction(self,
                                                   detector,
                                                   temp_video_path: str,
                                                   start_frame: int = 0,
                                                   max_frames: Optional[int] = None,
                                                   correction_interval: int = 60,
                                                   frame_release_interval: int = 50,
                                                   frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Det-SAM2 style: propagate masks with continuous prompt correction

        Args:
            detector: YOLODetector instance for generating correction prompts
            temp_video_path: Path to the video file for frame access
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process
            correction_interval: Add correction prompts every N frames
            frame_release_interval: Release old frames every N frames
            frame_window_size: Number of frames to keep in memory

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}
        max_frames = max_frames or 10000  # Default limit

        # Open the video for accessing frames during propagation
        cap = cv2.VideoCapture(str(temp_video_path))

        try:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=start_frame,
                max_frame_num_to_track=max_frames,
                reverse=False
            ):
                frame_masks = {}

                for i, out_obj_id in enumerate(out_obj_ids):
                    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                    frame_masks[out_obj_id] = mask

                video_segments[out_frame_idx] = frame_masks

                # Det-SAM2 optimization: add correction prompts at keyframes
                if (out_frame_idx % correction_interval == 0 and
                        out_frame_idx > start_frame and
                        out_frame_idx < max_frames - 1):

                    # Read the frame for detection
                    cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
                    ret, correction_frame = cap.read()

                    if ret:
                        # Run detection on this keyframe
                        detections = detector.detect_persons(correction_frame)

                        if detections:
                            # Convert detections to prompts and add them as corrections
                            box_prompts, labels = detector.convert_to_sam_prompts(detections)

                            # Add correction prompts (SAM2 will also propagate backward)
                            correction_count = 0
                            try:
                                for i, (box, label) in enumerate(zip(box_prompts, labels)):
                                    # Reuse existing object IDs if available, otherwise create new ones
                                    obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1

                                    self.predictor.add_new_points_or_box(
                                        inference_state=self.inference_state,
                                        frame_idx=out_frame_idx,
                                        obj_id=obj_id,
                                        box=box,
                                    )
                                    correction_count += 1

                                print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")

                            except Exception as e:
                                warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")

                # Memory management: aggressive frame release (Det-SAM2 style)
                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                    self._release_old_frames(out_frame_idx - frame_window_size)
                    # Optional: log frame releases for monitoring (every 4x release interval)
                    if out_frame_idx % (frame_release_interval * 4) == 0:
                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")

        finally:
            cap.release()

        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory"""
        try:
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Combine masks from multiple objects into a single mask"""
        if not frame_masks:
            return None

        # Use the first mask (squeezed to 2D) as the template so the result is always (H, W)
        first_mask = next(iter(frame_masks.values()))
        if first_mask.ndim == 3:
            first_mask = first_mask.squeeze()
        combined_mask = np.zeros_like(first_mask, dtype=bool)

        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)

        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame to create the matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: Background color for greenscreen, in the frame's
                channel order (BGR for OpenCV frames)

        Returns:
            Matted frame
        """
        if mask is None:
            return frame

        # Ensure mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize mask to match the frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)

        if output_format == "alpha":
            # Create RGBA output with the mask as the alpha channel
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output

        elif output_format == "greenscreen":
            # Create 3-channel output with a solid background color
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output

        else:
            raise ValueError(f"Unsupported output format: {output_format}")

    def cleanup(self):
        """Clean up resources"""
        if self.inference_state is not None:
            try:
                # Reset SAM2 state first (critical for memory cleanup)
                if self.predictor is not None and hasattr(self.predictor, 'reset_state'):
                    self.predictor.reset_state(self.inference_state)

                # Fall back to cleanup_state if available
                elif self.predictor is not None and hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)

                # Explicitly delete the inference state and video segments
                del self.inference_state
                if hasattr(self, 'video_segments') and self.video_segments:
                    del self.video_segments
                    self.video_segments = {}

            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")
            finally:
                self.inference_state = None

        # Clean up the temporary video file
        if self.temp_video_path is not None:
            try:
                if self.temp_video_path.exists():
                    # Remove the temporary video file
                    self.temp_video_path.unlink()
                self.temp_video_path = None
            except Exception as e:
                warnings.warn(f"Failed to cleanup temp video: {e}")

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Explicitly delete the predictor so a fresh one is created next time
        if self.predictor is not None:
            try:
                del self.predictor
            except Exception as e:
                warnings.warn(f"Failed to delete predictor: {e}")
            finally:
                self.predictor = None

        # Force garbage collection (critical for memory leak prevention)
        gc.collect()

    def __del__(self):
        """Destructor to ensure cleanup"""
        try:
            self.cleanup()
        except Exception:
            # Ignore errors during Python shutdown
            pass
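

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the class above).
# The input path "input.mp4" and the example person box are assumptions; in
# practice the boxes would come from a person detector such as YOLO, and the
# checkpoint/config must exist locally. Typical flow: build the matting
# helper, initialize the video state, seed box prompts on the first frame,
# propagate masks, then composite each frame and clean up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    matting = SAM2VideoMatting(
        model_cfg="sam2_hiera_l",
        checkpoint_path="sam2_hiera_large.pt",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    try:
        matting.init_video_state(video_path="input.mp4")

        # One example bounding box (x1, y1, x2, y2) on frame 0
        boxes = np.array([[100, 50, 400, 700]], dtype=np.float32)
        labels = np.ones(len(boxes), dtype=np.int32)
        matting.add_person_prompts(frame_idx=0, box_prompts=boxes, labels=labels)

        segments = matting.propagate_masks()

        # Composite each frame against a green background and save it
        cap = cv2.VideoCapture("input.mp4")
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            mask = matting.get_combined_mask(segments.get(frame_idx, {}))
            matted = matting.apply_mask_to_frame(frame, mask, output_format="greenscreen")
            cv2.imwrite(f"matted_{frame_idx:05d}.png", matted)
            frame_idx += 1
        cap.release()
    finally:
        matting.cleanup()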