test2/vr180_matting/sam2_wrapper.py

import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings
import os

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install the sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization."""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install the sam2 package.")
        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load the SAM2 video predictor with optimizations."""
        try:
            # Resolve the checkpoint path, falling back through the known
            # repo layouts when the given path does not exist
            if not Path(checkpoint_path).exists():
                # Try segment-anything-2/checkpoints/
                sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
                if sam2_path.exists():
                    checkpoint_path = str(sam2_path)
                else:
                    # Try the legacy models/ directory
                    models_path = Path("models") / Path(checkpoint_path).name
                    if models_path.exists():
                        checkpoint_path = str(models_path)
                    else:
                        # Try relative to the package
                        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                        sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name
                        if sam2_repo_path.exists():
                            checkpoint_path = str(sam2_repo_path)
            # Use the config path as-is (it should be relative to the SAM2
            # package), e.g. "configs/sam2.1/sam2.1_hiera_l.yaml"
            self.predictor = build_sam2_video_predictor(
                model_cfg,
                checkpoint_path,
                device=self.device
            )
            # Memory offloading and FP16 are handled internally by SAM2, so
            # no explicit configuration is applied for self.memory_offload
            # or self.fp16 here
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}") from e

    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
        """Initialize the video inference state from in-memory frames."""
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")
        # Initialize state directly from the frame list; offload frames to
        # CPU when memory_offload is enabled
        self.inference_state = self.predictor.init_state(
            video_path=None,
            video_frames=video_frames,
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )
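
    # Usage sketch (hypothetical names): frames are expected as RGB uint8
    # arrays, while OpenCV decodes to BGR, so convert before initializing.
    #
    #   frames = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in bgr_frames]
    #   matting.init_video_state(frames)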

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2.

        Args:
            frame_idx: Frame index at which to add the prompts
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,); not consumed by the box-prompt path below

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")
        object_ids = []
        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Object IDs start from 1
            # Add the box prompt for this object
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
            object_ids.extend(out_obj_ids)
        return object_ids
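
    # Example prompt (hypothetical values): one box per detected person in
    # (x1, y1, x2, y2) pixel coordinates, with a foreground label each.
    #
    #   boxes = np.array([[120.0, 40.0, 480.0, 900.0]], dtype=np.float32)
    #   labels = np.ones(len(boxes), dtype=np.int32)
    #   obj_ids = matting.add_person_prompts(0, boxes, labels)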

    def propagate_masks(self,
                        start_frame: int = 0,
                        max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through the video.

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")
        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                # Threshold logits at 0 to get a boolean mask
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask
            video_segments[out_frame_idx] = frame_masks
            # Memory management: release old frames periodically
            if self.memory_offload and out_frame_idx % 100 == 0:
                self._release_old_frames(out_frame_idx - 50)
        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Release frames before the given index from memory."""
        try:
            # Probe defensively; not every SAM2 build exposes this method
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Combine masks from multiple objects into a single mask."""
        if not frame_masks:
            return None
        # Use the first mask's 2D shape as the template so that squeezed
        # masks align with it
        first_mask = next(iter(frame_masks.values()))
        if first_mask.ndim == 3:
            first_mask = first_mask.squeeze()
        combined_mask = np.zeros_like(first_mask, dtype=bool)
        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)
        return combined_mask
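
    # Example (hypothetical): merging per-object masks for one frame into a
    # single person matte after propagation.
    #
    #   segments = matting.propagate_masks()
    #   person_mask = matting.get_combined_mask(segments[0])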

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: Optional[np.ndarray],
                            output_format: str = "alpha",
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame to create the matted output.

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: RGB background color for greenscreen

        Returns:
            Matted frame
        """
        if mask is None:
            return frame
        # Ensure the mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()
        # Resize the mask to match the frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        if output_format == "alpha":
            # RGBA output: the alpha channel encodes the mask
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output
        elif output_format == "greenscreen":
            # RGB output: masked pixels keep the frame, the rest is background
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output
        else:
            raise ValueError(f"Unsupported output format: {output_format}")
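
    # Example (hypothetical): producing both output styles from one mask.
    #
    #   rgba = matting.apply_mask_to_frame(frame, person_mask, "alpha")
    #   green = matting.apply_mask_to_frame(frame, person_mask, "greenscreen")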

    def cleanup(self):
        """Clean up resources."""
        if self.inference_state is not None:
            try:
                if hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)
            except Exception as e:
                warnings.warn(f"Failed to clean up SAM2 state: {e}")
            self.inference_state = None
        # Clear the CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Destructor to ensure cleanup."""
        self.cleanup()
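

# Minimal end-to-end sketch, assuming SAM2 is installed, the default
# checkpoint/config resolve on this machine, and a local "input.mp4"
# exists; the file names and box values below are placeholders, not
# part of the wrapper's API.
if __name__ == "__main__":
    # Decode a short clip to RGB frames
    cap = cv2.VideoCapture("input.mp4")
    frames = []
    while True:
        ok, bgr = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    cap.release()

    matting = SAM2VideoMatting()
    try:
        matting.init_video_state(frames)
        # One placeholder box around the subject in frame 0
        boxes = np.array([[100.0, 50.0, 500.0, 1000.0]], dtype=np.float32)
        matting.add_person_prompts(0, boxes, np.ones(1, dtype=np.int32))
        segments = matting.propagate_masks()
        # Write greenscreen composites, one PNG per frame
        for idx, frame in enumerate(frames):
            mask = matting.get_combined_mask(segments.get(idx, {}))
            out = matting.apply_mask_to_frame(frame, mask, "greenscreen")
            cv2.imwrite(f"matte_{idx:05d}.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    finally:
        matting.cleanup()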