""" Simple SAM2 streaming processor based on det-sam2 pattern Adapted for current segment-anything-2 API """ import torch import numpy as np import cv2 import tempfile import os from pathlib import Path from typing import Dict, Any, List, Optional import warnings import gc # Import SAM2 components try: from sam2.build_sam import build_sam2_video_predictor except ImportError: warnings.warn("SAM2 not installed. Please install segment-anything-2 first.") class SAM2StreamingProcessor: """Simple streaming integration with SAM2 following det-sam2 pattern""" def __init__(self, config: Dict[str, Any]): self.config = config self.device = torch.device(config.get('hardware', {}).get('device', 'cuda')) # SAM2 model configuration model_cfg_name = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l') checkpoint = config.get('matting', {}).get('sam2_checkpoint', 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt') # Map config name to Hydra path (like the examples show) config_mapping = { 'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml', 'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml', 'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml', 'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml', } model_cfg = config_mapping.get(model_cfg_name, model_cfg_name) # Build predictor (simple, clean approach) self.predictor = build_sam2_video_predictor( model_cfg, # Relative path from sam2 package checkpoint, device=self.device, vos_optimized=True # Enable VOS optimizations for speed ) # Frame buffer for streaming (like det-sam2) self.frame_buffer = [] self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10) # State management (simple) self.inference_state = None self.temp_dir = None self.object_ids = [] # Memory management self.memory_offload = config.get('matting', {}).get('memory_offload', True) self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300) print(f"🎯 Simple SAM2 streaming processor initialized:") print(f" Model: {model_cfg}") print(f" Device: {self.device}") print(f" Buffer size: {self.frame_buffer_size}") print(f" Memory offload: {self.memory_offload}") def add_frame_and_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]], frame_idx: int) -> np.ndarray: """ Add frame to buffer and process detections (det-sam2 pattern) Args: frame: Input frame (BGR) detections: List of detections with 'box' key frame_idx: Global frame index Returns: Mask for current frame """ # Add frame to buffer self.frame_buffer.append({ 'frame': frame, 'frame_idx': frame_idx, 'detections': detections }) # Process when buffer is full or when we have detections if len(self.frame_buffer) >= self.frame_buffer_size or detections: return self._process_buffer() else: # Return empty mask if no processing yet return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) def _process_buffer(self) -> np.ndarray: """Process current frame buffer (adapted det-sam2 approach)""" if not self.frame_buffer: return np.zeros((480, 640), dtype=np.uint8) try: # Create temporary directory for frames (current SAM2 API requirement) self._create_temp_frames() # Initialize or update SAM2 state if self.inference_state is None: # First time: initialize state with temp directory self.inference_state = self.predictor.init_state( video_path=self.temp_dir, offload_video_to_cpu=self.memory_offload, offload_state_to_cpu=self.memory_offload ) print(f" Initialized SAM2 state with {len(self.frame_buffer)} frames") else: # Subsequent times: we need to reinitialize 
                # Subsequent times: reinitialize, since the current SAM2 release
                # has no update_state. This is the key difference from the
                # det-sam2 reference implementation.
                self._cleanup_temp_frames()
                self._create_temp_frames()
                self.inference_state = self.predictor.init_state(
                    video_path=self.temp_dir,
                    offload_video_to_cpu=self.memory_offload,
                    offload_state_to_cpu=self.memory_offload
                )
                print(f"   Reinitialized SAM2 state with {len(self.frame_buffer)} frames")

            # Add detections as prompts (standard SAM2 API)
            self._add_detection_prompts()

            # Get masks via propagation
            masks = self._get_current_masks()

            # Clean up old frames to prevent memory accumulation
            self._cleanup_old_frames()

            return masks

        except Exception as e:
            print(f"   Warning: Buffer processing failed: {e}")
            # The buffer is non-empty here, so return an empty mask matching the
            # latest frame's size rather than a fixed resolution
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

    def _create_temp_frames(self):
        """Create a temporary directory with frame images for SAM2."""
        if self.temp_dir:
            self._cleanup_temp_frames()

        self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')

        for i, buffer_item in enumerate(self.frame_buffer):
            frame = buffer_item['frame']
            # Save as JPEG (SAM2 expects a directory of JPEG frames). Frames are
            # already BGR, which is what cv2.imwrite expects; SAM2 performs the
            # BGR->RGB conversion itself when loading the JPEGs.
            frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
            cv2.imwrite(frame_path, frame)

    def _add_detection_prompts(self):
        """Add detection boxes as prompts to SAM2 (standard API)."""
        for buffer_idx, buffer_item in enumerate(self.frame_buffer):
            detections = buffer_item.get('detections', [])

            for det_idx, detection in enumerate(detections):
                box = detection['box']  # [x1, y1, x2, y2]

                # Use the standard SAM2 API
                try:
                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                        inference_state=self.inference_state,
                        frame_idx=buffer_idx,  # Frame index within the buffer
                        obj_id=det_idx,        # Simple object ID
                        box=np.array(box, dtype=np.float32)
                    )

                    # Track object IDs
                    if det_idx not in self.object_ids:
                        self.object_ids.append(det_idx)

                except Exception as e:
                    print(f"   Warning: Failed to add detection: {e}")
                    continue

    def _get_current_masks(self) -> np.ndarray:
        """Get masks for the current frame via propagation."""
        if not self.object_ids:
            # No objects to track
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

        try:
            # Use SAM2's propagate_in_video (standard API)
            latest_frame_idx = len(self.frame_buffer) - 1

            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=latest_frame_idx,
                max_frame_num_to_track=1,  # Just the current frame
                reverse=False
            ):
                if out_frame_idx == latest_frame_idx:
                    # Combine all object masks
                    if len(out_mask_logits) > 0:
                        combined_mask = None
                        for mask_logit in out_mask_logits:
                            mask = (mask_logit > 0.0).cpu().numpy()
                            if combined_mask is None:
                                combined_mask = mask.astype(bool)
                            else:
                                combined_mask = combined_mask | mask.astype(bool)
                        return (combined_mask * 255).astype(np.uint8)

            # If no masks were found, return an empty mask
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

        except Exception as e:
            print(f"   Warning: Mask propagation failed: {e}")
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

    def _cleanup_old_frames(self):
        """Clean up old frames from the buffer (det-sam2 pattern)."""
        # Keep only recent frames to prevent memory accumulation
        if len(self.frame_buffer) > self.frame_buffer_size:
            self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]

        # Periodic GPU memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def _cleanup_temp_frames(self):
        """Clean up the temporary frame directory."""
        if self.temp_dir and os.path.exists(self.temp_dir):
            import shutil
            shutil.rmtree(self.temp_dir)
            self.temp_dir = None

    def cleanup(self):
        """Clean up all resources."""
        self._cleanup_temp_frames()
        self.frame_buffer.clear()
        self.object_ids.clear()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

        print("🧹 Simple SAM2 streaming processor cleaned up")

    def apply_mask_to_frame(self, frame: np.ndarray, mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: tuple = (0, 255, 0)) -> np.ndarray:
        """
        Apply a mask to a frame with the specified output format (matches the chunked implementation).

        Args:
            frame: Input frame (BGR)
            mask: Binary mask (0-255 or boolean)
            output_format: "alpha" or "greenscreen"
            background_color: Background color for greenscreen mode, in the frame's
                channel order (BGR)

        Returns:
            Processed frame
        """
        if mask is None:
            return frame

        # Ensure the mask is 2D (handle 3D masks properly)
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize the mask to match the frame if needed (INTER_NEAREST for binary masks)
        if mask.shape[:2] != frame.shape[:2]:
            was_bool = mask.dtype == bool
            # Convert to uint8 for resizing, then back to bool if necessary
            mask_uint8 = mask.astype(np.uint8) * 255 if was_bool else mask.astype(np.uint8)
            mask_resized = cv2.resize(mask_uint8,
                                      (frame.shape[1], frame.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
            mask = mask_resized.astype(bool) if was_bool else mask_resized

        if output_format == "alpha":
            # Create a 4-channel output (frame channels + alpha), matching the
            # chunked implementation
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            if mask.dtype == bool:
                output[:, :, 3] = mask.astype(np.uint8) * 255
            else:
                output[:, :, 3] = mask.astype(np.uint8)
            return output

        elif output_format == "greenscreen":
            # Create a 3-channel output over a solid background, matching the
            # chunked implementation
            output = np.full_like(frame, background_color, dtype=np.uint8)
            if mask.dtype == bool:
                output[mask] = frame[mask]
            else:
                mask_bool = mask.astype(bool)
                output[mask_bool] = frame[mask_bool]
            return output

        else:
            raise ValueError(
                f"Unsupported output format: {output_format}. Use 'alpha' or 'greenscreen'"
            )

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Get current memory usage statistics.

        Returns:
            Dictionary with memory usage info
        """
        stats = {}

        if torch.cuda.is_available():
            # GPU memory stats
            stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
            stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
            stats['cuda_max_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3)

        return stats
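

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): shows how the
# processor might be driven from a video file plus an external detector.
# The config keys mirror the ones read in __init__ and the detection format
# matches _add_detection_prompts; the input path, checkpoint location, and
# the hard-coded detection box are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        'hardware': {'device': 'cuda'},
        'matting': {
            'sam2_model_cfg': 'sam2.1_hiera_l',
            'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
            'memory_offload': True,
            'correction_interval': 300,
        },
        'streaming': {'buffer_frames': 10},
    }

    processor = SAM2StreamingProcessor(config)
    cap = cv2.VideoCapture('input.mp4')  # hypothetical input path

    frame_idx = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Detections would normally come from an external detector; here a
            # single box on the first frame is hard-coded purely for illustration.
            detections = [{'box': [100, 100, 300, 400]}] if frame_idx == 0 else []

            mask = processor.add_frame_and_detections(frame, detections, frame_idx)
            composited = processor.apply_mask_to_frame(frame, mask,
                                                       output_format="greenscreen")
            # ... write `composited` to an output video or display it here ...
            frame_idx += 1
    finally:
        cap.release()
        processor.cleanup()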