import torch import numpy as np from typing import List, Dict, Any, Optional, Tuple import cv2 from pathlib import Path import warnings import os import tempfile import shutil import gc try: from sam2.build_sam import build_sam2_video_predictor from sam2.sam2_image_predictor import SAM2ImagePredictor SAM2_AVAILABLE = True except ImportError: SAM2_AVAILABLE = False warnings.warn("SAM2 not available. Please install sam2 package.") class SAM2VideoMatting: """SAM2-based video matting with memory optimization""" def __init__(self, model_cfg: str = "sam2_hiera_l", checkpoint_path: str = "sam2_hiera_large.pt", device: str = "cuda", memory_offload: bool = True, fp16: bool = True): if not SAM2_AVAILABLE: raise ImportError("SAM2 not available. Please install sam2 package.") self.device = device self.memory_offload = memory_offload self.fp16 = fp16 self.predictor = None self.inference_state = None self.video_segments = {} self.temp_video_path = None self._load_model(model_cfg, checkpoint_path) def _load_model(self, model_cfg: str, checkpoint_path: str): """Load SAM2 video predictor with optimizations""" try: # Check for checkpoint in SAM2 repo structure if not Path(checkpoint_path).exists(): # Try in segment-anything-2/checkpoints/ sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name if sam2_path.exists(): checkpoint_path = str(sam2_path) else: # Try legacy models/ directory models_path = Path("models") / Path(checkpoint_path).name if models_path.exists(): checkpoint_path = str(models_path) else: # Try relative to package package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name if sam2_repo_path.exists(): checkpoint_path = str(sam2_repo_path) # Use SAM2's build_sam2_video_predictor which returns the predictor directly # The predictor IS the model - no .model attribute needed self.predictor = build_sam2_video_predictor( config_file=model_cfg, ckpt_path=checkpoint_path, device=self.device ) except Exception as e: raise RuntimeError(f"Failed to load SAM2 model: {e}") def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None: """Initialize video inference state""" if self.predictor is None: raise RuntimeError("SAM2 model not loaded") if video_path is not None: # Use video path directly (SAM2's preferred method) self.inference_state = self.predictor.init_state( video_path=video_path, offload_video_to_cpu=self.memory_offload, async_loading_frames=True ) else: # For frame arrays, we need to save them as a temporary video first if video_frames is None or len(video_frames) == 0: raise ValueError("Either video_path or video_frames must be provided") # Create temporary video file in current directory import uuid temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4" temp_video_path = Path.cwd() / temp_video_name # Write frames to temporary video height, width = video_frames[0].shape[:2] fourcc = cv2.VideoWriter_fourcc(*'mp4v') writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height)) for frame in video_frames: writer.write(frame) writer.release() # Initialize with temporary video self.inference_state = self.predictor.init_state( video_path=str(temp_video_path), offload_video_to_cpu=self.memory_offload, async_loading_frames=True ) # Store temp path for cleanup self.temp_video_path = temp_video_path def add_person_prompts(self, frame_idx: int, box_prompts: np.ndarray, labels: np.ndarray) -> List[int]: """ Add person detection prompts to SAM2 Args: frame_idx: Frame index to add prompts box_prompts: Bounding boxes (N, 4) labels: Prompt labels (N,) Returns: List of object IDs """ if self.inference_state is None: raise RuntimeError("Video state not initialized") object_ids = [] for i, (box, label) in enumerate(zip(box_prompts, labels)): obj_id = i + 1 # Start from 1 # Add box prompt _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box( inference_state=self.inference_state, frame_idx=frame_idx, obj_id=obj_id, box=box, ) object_ids.extend(out_obj_ids) return object_ids def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]: """ Propagate masks through video Args: start_frame: Starting frame index max_frames: Maximum number of frames to process Returns: Dictionary mapping frame_idx -> {obj_id: mask} """ if self.inference_state is None: raise RuntimeError("Video state not initialized") video_segments = {} for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video( self.inference_state, start_frame_idx=start_frame, max_frame_num_to_track=max_frames, reverse=False ): frame_masks = {} for i, out_obj_id in enumerate(out_obj_ids): mask = (out_mask_logits[i] > 0.0).cpu().numpy() frame_masks[out_obj_id] = mask video_segments[out_frame_idx] = frame_masks # Memory management: release old frames periodically if self.memory_offload and out_frame_idx % 100 == 0: self._release_old_frames(out_frame_idx - 50) return video_segments def _release_old_frames(self, before_frame_idx: int): """Release old frames from memory""" try: if hasattr(self.predictor, 'release_old_frames'): self.predictor.release_old_frames(self.inference_state, before_frame_idx) except Exception as e: warnings.warn(f"Failed to release old frames: {e}") def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> np.ndarray: """Combine masks from multiple objects into single mask""" if not frame_masks: return None combined_mask = np.zeros_like(next(iter(frame_masks.values())), dtype=bool) for obj_id, mask in frame_masks.items(): if mask.ndim == 3: mask = mask.squeeze() combined_mask = np.logical_or(combined_mask, mask) return combined_mask def apply_mask_to_frame(self, frame: np.ndarray, mask: np.ndarray, output_format: str = "alpha", background_color: List[int] = [0, 255, 0]) -> np.ndarray: """ Apply mask to frame to create matted output Args: frame: Input frame (H, W, 3) mask: Binary mask (H, W) output_format: "alpha" or "greenscreen" background_color: RGB background color for greenscreen Returns: Matted frame """ if mask is None: return frame # Ensure mask is 2D if mask.ndim == 3: mask = mask.squeeze() # Resize mask to match frame if needed if mask.shape[:2] != frame.shape[:2]: mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST).astype(bool) if output_format == "alpha": # Create RGBA output output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8) output[:, :, :3] = frame output[:, :, 3] = mask.astype(np.uint8) * 255 return output elif output_format == "greenscreen": # Create RGB output with background output = np.full_like(frame, background_color, dtype=np.uint8) output[mask] = frame[mask] return output else: raise ValueError(f"Unsupported output format: {output_format}") def cleanup(self): """Clean up resources""" if self.inference_state is not None: try: # Reset SAM2 state first (critical for memory cleanup) if self.predictor is not None and hasattr(self.predictor, 'reset_state'): self.predictor.reset_state(self.inference_state) # Fallback to cleanup_state if available elif self.predictor is not None and hasattr(self.predictor, 'cleanup_state'): self.predictor.cleanup_state(self.inference_state) # Explicitly delete inference state and video segments del self.inference_state if hasattr(self, 'video_segments') and self.video_segments: del self.video_segments self.video_segments = {} except Exception as e: warnings.warn(f"Failed to cleanup SAM2 state: {e}") finally: self.inference_state = None # Clean up temporary video file if self.temp_video_path is not None: try: if self.temp_video_path.exists(): # Remove the temporary video file self.temp_video_path.unlink() self.temp_video_path = None except Exception as e: warnings.warn(f"Failed to cleanup temp video: {e}") # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() # Force garbage collection (critical for memory leak prevention) gc.collect() # Clear predictor reference (but don't delete the object itself) self.predictor = None def __del__(self): """Destructor to ensure cleanup""" try: self.cleanup() except Exception: # Ignore errors during Python shutdown pass