# test2/vr180_matting/sam2_wrapper.py
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings
import os
import tempfile
import shutil
import gc


# Check SAM2 availability without importing heavy modules
def _check_sam2_available():
    try:
        import sam2
        return True
    except ImportError:
        return False


SAM2_AVAILABLE = _check_sam2_available()

if not SAM2_AVAILABLE:
    warnings.warn("SAM2 not available. Please install the sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization."""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install the sam2 package.")

        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.model_cfg = model_cfg
        self.checkpoint_path = checkpoint_path
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self.temp_video_path = None

        # Don't load the model during __init__ -- load lazily when first needed
        self._model_loaded = False

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load the SAM2 video predictor lazily."""
        if self._model_loaded:
            return  # Already loaded

        try:
            # Import heavy SAM2 modules only when needed
            from sam2.build_sam import build_sam2_video_predictor

            # Resolve the checkpoint against known locations if the given path is missing
            if not Path(checkpoint_path).exists():
                # Try segment-anything-2/checkpoints/
                sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
                if sam2_path.exists():
                    checkpoint_path = str(sam2_path)
                else:
                    # Try the legacy models/ directory
                    models_path = Path("models") / Path(checkpoint_path).name
                    if models_path.exists():
                        checkpoint_path = str(models_path)
                    else:
                        # Try relative to the package
                        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                        sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name
                        if sam2_repo_path.exists():
                            checkpoint_path = str(sam2_repo_path)

            print(f"🎯 Loading SAM2 model: {model_cfg}")
            # build_sam2_video_predictor returns the predictor directly;
            # the predictor IS the model -- no .model attribute needed
            self.predictor = build_sam2_video_predictor(
                config_file=model_cfg,
                ckpt_path=checkpoint_path,
                device=self.device
            )
            self._model_loaded = True
            print("✅ SAM2 model loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}") from e

    def init_video_state(self, video_frames: Optional[List[np.ndarray]] = None,
                         video_path: Optional[str] = None) -> None:
        """Initialize the video inference state from a video file or a list of frames."""
        # Load the model lazily on first use
        if not self._model_loaded:
            self._load_model(self.model_cfg, self.checkpoint_path)

        if video_path is not None:
            # Use the video path directly (SAM2's preferred method)
            self.inference_state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )
        else:
            # For frame arrays, save them as a temporary video first
            if video_frames is None or len(video_frames) == 0:
                raise ValueError("Either video_path or video_frames must be provided")

            # Create a uniquely named temporary video file in the current directory
            import uuid
            temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
            temp_video_path = Path.cwd() / temp_video_name

            # Write frames to the temporary video (assumes BGR frames, as produced by OpenCV).
            # Note: mp4v is a lossy codec, so results may differ slightly from in-memory frames.
            height, width = video_frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
            for frame in video_frames:
                writer.write(frame)
            writer.release()

            # Initialize with the temporary video
            self.inference_state = self.predictor.init_state(
                video_path=str(temp_video_path),
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True
            )
            # Store the temp path for cleanup
            self.temp_video_path = temp_video_path
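
    # Example (sketch only; the path and frame list are hypothetical):
    #   matting = SAM2VideoMatting()
    #   matting.init_video_state(video_path="input.mp4")       # preferred: SAM2 reads the file itself
    #   # or:
    #   matting.init_video_state(video_frames=bgr_frame_list)  # frames are round-tripped through a temp .mp4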

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index at which to add prompts
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,)

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids = []
        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Object IDs start from 1

            # Add the box prompt (labels are carried along for API symmetry;
            # box prompts are implicitly positive)
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
            object_ids.extend(out_obj_ids)

        return object_ids
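
    # Example prompt layout (illustrative values):
    #   box_prompts = np.array([[100, 50, 300, 400]], dtype=np.float32)  # one (x1, y1, x2, y2) box per person
    #   labels = np.ones(len(box_prompts), dtype=np.int32)
    #   obj_ids = matting.add_person_prompts(frame_idx=0, box_prompts=box_prompts, labels=labels)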

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
                        frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through the video with Det-SAM2 style memory management

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process
            frame_release_interval: Release old frames every N frames
            frame_window_size: Number of frames to keep in memory

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask
            video_segments[out_frame_idx] = frame_masks

            # Det-SAM2 style memory management: aggressively release old frames
            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                self._release_old_frames(out_frame_idx - frame_window_size)
                # Optional: log frame release for monitoring (every 4x the release interval)
                if out_frame_idx % (frame_release_interval * 4) == 0:
                    print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")

        return video_segments
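
    # Example: consume the propagated segments (sketch)
    #   segments = matting.propagate_masks(max_frames=300)
    #   for frame_idx in sorted(segments):
    #       combined = matting.get_combined_mask(segments[frame_idx])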

    def propagate_masks_with_continuous_correction(self,
                                                   detector,
                                                   temp_video_path: str,
                                                   start_frame: int = 0,
                                                   max_frames: Optional[int] = None,
                                                   correction_interval: int = 60,
                                                   frame_release_interval: int = 50,
                                                   frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Det-SAM2 style: Propagate masks with continuous prompt correction

        Args:
            detector: YOLODetector instance for generating correction prompts
            temp_video_path: Path to the video file for frame access
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process
            correction_interval: Add correction prompts every N frames
            frame_release_interval: Release old frames every N frames
            frame_window_size: Number of frames to keep in memory

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}
        max_frames = max_frames or 10000  # Default limit

        # Open the video so keyframes can be re-read during propagation
        cap = cv2.VideoCapture(str(temp_video_path))
        try:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=start_frame,
                max_frame_num_to_track=max_frames,
                reverse=False
            ):
                frame_masks = {}
                for i, out_obj_id in enumerate(out_obj_ids):
                    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                    frame_masks[out_obj_id] = mask
                video_segments[out_frame_idx] = frame_masks

                # Det-SAM2 optimization: add correction prompts at keyframes
                if (out_frame_idx % correction_interval == 0 and
                        out_frame_idx > start_frame and
                        out_frame_idx < max_frames - 1):
                    # Read the frame for detection
                    cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
                    ret, correction_frame = cap.read()
                    if ret:
                        # Run detection on this keyframe
                        detections = detector.detect_persons(correction_frame)
                        if detections:
                            # Convert to prompts and add as corrections
                            box_prompts, labels = detector.convert_to_sam_prompts(detections)

                            # Add correction prompts (SAM2 will propagate backward)
                            correction_count = 0
                            try:
                                for i, (box, label) in enumerate(zip(box_prompts, labels)):
                                    # Reuse existing object IDs if available, otherwise create new ones
                                    obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
                                    self.predictor.add_new_points_or_box(
                                        inference_state=self.inference_state,
                                        frame_idx=out_frame_idx,
                                        obj_id=obj_id,
                                        box=box,
                                    )
                                    correction_count += 1
                                print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
                            except Exception as e:
                                warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")

                # Memory management: more aggressive frame release (Det-SAM2 style)
                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                    self._release_old_frames(out_frame_idx - frame_window_size)
                    # Optional: log frame release for monitoring (every 4x the release interval)
                    if out_frame_idx % (frame_release_interval * 4) == 0:
                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
        finally:
            cap.release()

        return video_segments
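
    # Example (sketch; the detector is assumed to expose the detect_persons() and
    # convert_to_sam_prompts() methods used above):
    #   segments = matting.propagate_masks_with_continuous_correction(
    #       detector=yolo_detector, temp_video_path="input.mp4", correction_interval=60)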

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory (no-op if the predictor doesn't support it)."""
        try:
            # release_old_frames may not exist on a stock SAM2 predictor,
            # so it is guarded with hasattr
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Combine masks from multiple objects into a single mask."""
        if not frame_masks:
            return None

        # Use a squeezed template so the combined mask is 2D even when
        # individual masks arrive as (1, H, W)
        first_mask = next(iter(frame_masks.values()))
        if first_mask.ndim == 3:
            first_mask = first_mask.squeeze()
        combined_mask = np.zeros_like(first_mask, dtype=bool)

        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)

        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame to create the matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: Background color for greenscreen, in the frame's
                channel order (the default green is identical in RGB and BGR)

        Returns:
            Matted frame
        """
        if mask is None:
            return frame

        # Ensure the mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize the mask to match the frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)

        if output_format == "alpha":
            # Create RGBA output: original pixels plus a binary alpha channel
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output
        elif output_format == "greenscreen":
            # Fill with the background color, then paste the masked pixels back
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output
        else:
            raise ValueError(f"Unsupported output format: {output_format}")
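
    # Example (sketch): composite one frame against green
    #   combined = matting.get_combined_mask(segments.get(0, {}))
    #   green = matting.apply_mask_to_frame(frame_bgr, combined, output_format="greenscreen")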

    def cleanup(self):
        """Clean up resources."""
        if self.inference_state is not None:
            try:
                # Reset SAM2 state first (critical for memory cleanup)
                if self.predictor is not None and hasattr(self.predictor, 'reset_state'):
                    self.predictor.reset_state(self.inference_state)
                # Fall back to cleanup_state if available
                elif self.predictor is not None and hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)

                # Explicitly delete the inference state and video segments
                del self.inference_state
                if hasattr(self, 'video_segments') and self.video_segments:
                    del self.video_segments
                    self.video_segments = {}
            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")
            finally:
                self.inference_state = None

        # Clean up the temporary video file
        if self.temp_video_path is not None:
            try:
                if self.temp_video_path.exists():
                    self.temp_video_path.unlink()
                self.temp_video_path = None
            except Exception as e:
                warnings.warn(f"Failed to cleanup temp video: {e}")

        # Clear the CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Explicitly delete the predictor so it is rebuilt fresh next time
        if self.predictor is not None:
            try:
                del self.predictor
            except Exception as e:
                warnings.warn(f"Failed to delete predictor: {e}")
            finally:
                self.predictor = None

        # Force garbage collection (helps prevent memory leaks across clips)
        gc.collect()

    def __del__(self):
        """Destructor to ensure cleanup."""
        try:
            self.cleanup()
        except Exception:
            # Ignore errors during Python shutdown
            pass
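

# Minimal end-to-end sketch of the intended workflow. Illustrative only:
# "input.mp4" and the hand-picked box are hypothetical, and a real pipeline
# would obtain boxes from a person detector rather than hardcoding one.
if __name__ == "__main__":
    matting = SAM2VideoMatting()
    try:
        matting.init_video_state(video_path="input.mp4")  # hypothetical input file
        matting.add_person_prompts(
            frame_idx=0,
            box_prompts=np.array([[100, 50, 300, 400]], dtype=np.float32),  # illustrative box
            labels=np.ones(1, dtype=np.int32),
        )
        segments = matting.propagate_masks(max_frames=100)
        print(f"Propagated masks for {len(segments)} frames")
    finally:
        matting.cleanup()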