import os
import uuid
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: F401 (availability check)
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install the sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization."""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install the sam2 package.")

        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16  # stored for callers; not used directly in this class
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}
        self.temp_video_path = None

        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load the SAM2 video predictor, searching common checkpoint locations."""
        try:
            if not Path(checkpoint_path).exists():
                # Try the SAM2 repo's checkpoints/ directory first.
                sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
                if sam2_path.exists():
                    checkpoint_path = str(sam2_path)
                else:
                    # Fall back to the legacy models/ directory.
                    models_path = Path("models") / Path(checkpoint_path).name
                    if models_path.exists():
                        checkpoint_path = str(models_path)
                    else:
                        # Finally, look relative to this package.
                        package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                        sam2_repo_path = (Path(package_dir) / "segment-anything-2/checkpoints"
                                          / Path(checkpoint_path).name)
                        if sam2_repo_path.exists():
                            checkpoint_path = str(sam2_repo_path)

            # build_sam2_video_predictor returns the predictor directly;
            # the predictor IS the model, so no .model attribute is needed.
            self.predictor = build_sam2_video_predictor(
                config_file=model_cfg,
                ckpt_path=checkpoint_path,
                device=self.device,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")

    def init_video_state(self,
                         video_frames: Optional[List[np.ndarray]] = None,
                         video_path: Optional[str] = None) -> None:
        """Initialize the video inference state from a path or a list of frames."""
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")

        if video_path is not None:
            # Use the video path directly (SAM2's preferred method).
            self.inference_state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True,
            )
        else:
            # For frame arrays, write them to a temporary video first.
            if video_frames is None or len(video_frames) == 0:
                raise ValueError("Either video_path or video_frames must be provided")

            # Create a uniquely named temporary video file in the current directory.
            temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
            temp_video_path = Path.cwd() / temp_video_name

            # Write frames (expected in BGR order for OpenCV) to the temporary video.
            height, width = video_frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
            for frame in video_frames:
                writer.write(frame)
            writer.release()

            # Initialize from the temporary video.
            self.inference_state = self.predictor.init_state(
                video_path=str(temp_video_path),
                offload_video_to_cpu=self.memory_offload,
                async_loading_frames=True,
            )

            # Store the temp path so cleanup() can remove it.
            self.temp_video_path = temp_video_path
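    # Example prompt layout for add_person_prompts() below (a sketch; the
    # values are illustrative): SAM2 box prompts are XYXY pixel coordinates,
    # one row per detected person, typically produced by a person detector.
    #   box_prompts = np.array([[120, 40, 380, 710]], dtype=np.float32)
    #   labels = np.array([1], dtype=np.int32)  # 1 = positive/foreground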
    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2.

        Args:
            frame_idx: Frame index at which to add prompts
            box_prompts: Bounding boxes (N, 4) in XYXY pixel coordinates
            labels: Prompt labels (N,); currently unused for box prompts

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids = []
        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Object IDs start from 1

            # Add the box prompt; labels apply to point prompts, so the box
            # alone identifies the object here.
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
            object_ids.extend(out_obj_ids)

        return object_ids

    def propagate_masks(self,
                        start_frame: int = 0,
                        max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through the video.

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=start_frame,
                max_frame_num_to_track=max_frames,
                reverse=False):
            frame_masks = {}
            for i, out_obj_id in enumerate(out_obj_ids):
                # Threshold logits at 0.0 to get a binary mask.
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask
            video_segments[out_frame_idx] = frame_masks

            # Memory management: release old frames periodically.
            if self.memory_offload and out_frame_idx % 100 == 0:
                self._release_old_frames(out_frame_idx - 50)

        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory, if the predictor supports it."""
        try:
            if hasattr(self.predictor, "release_old_frames"):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
        """Combine masks from multiple objects into a single 2D mask."""
        if not frame_masks:
            return None

        # Squeeze the template mask so the combined mask stays 2D even when
        # individual masks come back as (1, H, W).
        first_mask = next(iter(frame_masks.values()))
        combined_mask = np.zeros(np.squeeze(first_mask).shape, dtype=bool)
        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)

        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame to create the matted output.

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: Background color in the frame's channel order
                (BGR for OpenCV frames)

        Returns:
            Matted frame
        """
        if mask is None:
            return frame

        # Ensure the mask is 2D.
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize the mask to match the frame if needed.
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)

        if output_format == "alpha":
            # Create an RGBA output with the mask as the alpha channel.
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output
        elif output_format == "greenscreen":
            # Create an output with a solid background color, then paste the
            # foreground pixels back in.
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output
        else:
            raise ValueError(f"Unsupported output format: {output_format}")

    def cleanup(self):
        """Clean up inference state, temporary files, and GPU memory."""
        if self.inference_state is not None:
            try:
                if hasattr(self.predictor, "cleanup_state"):
                    self.predictor.cleanup_state(self.inference_state)
            except Exception as e:
warnings.warn(f"Failed to cleanup SAM2 state: {e}") self.inference_state = None # Clean up temporary video file if self.temp_video_path is not None: try: if self.temp_video_path.exists(): # Remove the temporary video file self.temp_video_path.unlink() self.temp_video_path = None except Exception as e: warnings.warn(f"Failed to cleanup temp video: {e}") # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() def __del__(self): """Destructor to ensure cleanup""" self.cleanup()