""" SAM2 streaming processor for frame-by-frame video segmentation NOTE: This is a template implementation. The actual SAM2 integration would need to: 1. Handle the fact that SAM2VideoPredictor loads the entire video internally 2. Potentially modify SAM2 to support frame-by-frame input 3. Or use a custom video loader that provides frames on demand For a true streaming implementation, you may need to: - Extend SAM2VideoPredictor to accept a frame generator instead of video path - Implement a custom video loader that doesn't load all frames at once - Use the memory offloading features more aggressively """ import torch import numpy as np import cv2 from pathlib import Path from typing import Dict, Any, List, Optional, Tuple, Generator import warnings import gc # Import SAM2 components - these will be available after SAM2 installation try: from sam2.build_sam import build_sam2_video_predictor from sam2.utils.misc import load_video_frames except ImportError: warnings.warn("SAM2 not installed. Please install segment-anything-2 first.") class SAM2StreamingProcessor: """Streaming integration with SAM2 video predictor""" def __init__(self, config: Dict[str, Any]): self.config = config self.device = torch.device(config.get('hardware', {}).get('device', 'cuda')) # Processing parameters (set before _init_predictor) self.memory_offload = config.get('matting', {}).get('memory_offload', True) self.fp16 = config.get('matting', {}).get('fp16', True) self.correction_interval = config.get('matting', {}).get('correction_interval', 300) # SAM2 model configuration model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l') checkpoint = config.get('matting', {}).get('sam2_checkpoint', 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt') # Build predictor self.predictor = None self._init_predictor(model_cfg, checkpoint) # State management self.states = {} # eye -> inference state self.object_ids = [] self.frame_count = 0 print(f"๐ŸŽฏ SAM2 streaming processor initialized:") print(f" Model: {model_cfg}") print(f" Device: {self.device}") print(f" Memory offload: {self.memory_offload}") print(f" FP16: {self.fp16}") def _init_predictor(self, model_cfg: str, checkpoint: str) -> None: """Initialize SAM2 video predictor""" try: # Map config string to actual config path config_mapping = { 'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml', 'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml', 'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml', 'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml', } actual_config = config_mapping.get(model_cfg, model_cfg) # Build predictor with VOS optimizations self.predictor = build_sam2_video_predictor( actual_config, checkpoint, device=self.device, vos_optimized=True # Enable full model compilation for speed ) # Set to eval mode and ensure all model components are on GPU self.predictor.eval() # Force all predictor components to GPU self.predictor = self.predictor.to(self.device) # Force move all internal components that might be on CPU if hasattr(self.predictor, 'image_encoder'): self.predictor.image_encoder = self.predictor.image_encoder.to(self.device) if hasattr(self.predictor, 'memory_attention'): self.predictor.memory_attention = self.predictor.memory_attention.to(self.device) if hasattr(self.predictor, 'memory_encoder'): self.predictor.memory_encoder = self.predictor.memory_encoder.to(self.device) if hasattr(self.predictor, 'sam_mask_decoder'): self.predictor.sam_mask_decoder = self.predictor.sam_mask_decoder.to(self.device) if hasattr(self.predictor, 'sam_prompt_encoder'): self.predictor.sam_prompt_encoder = self.predictor.sam_prompt_encoder.to(self.device) # Note: FP16 conversion can cause type mismatches with compiled models # Let SAM2 handle precision internally via build_sam2_video_predictor options if self.fp16 and self.device.type == 'cuda': print(" FP16 enabled via SAM2 internal settings") print(f" All SAM2 components moved to {self.device}") except Exception as e: raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}") def init_state(self, video_info: Dict[str, Any], eye: str = 'full') -> Dict[str, Any]: """ Initialize inference state for streaming (NO VIDEO LOADING) Args: video_info: Video metadata dict with width, height, frame_count eye: Eye identifier ('left', 'right', or 'full') Returns: Inference state dictionary """ print(f" Initializing streaming state for {eye} eye...") # Monitor memory before initialization if torch.cuda.is_available(): before_mem = torch.cuda.memory_allocated() / 1e9 print(f" ๐Ÿ“Š GPU memory before init: {before_mem:.1f}GB") # Create streaming state WITHOUT loading video frames state = self._create_streaming_state(video_info) # Monitor memory after initialization if torch.cuda.is_available(): after_mem = torch.cuda.memory_allocated() / 1e9 print(f" ๐Ÿ“Š GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)") self.states[eye] = state print(f" โœ… Streaming state initialized for {eye} eye") return state def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]: """Create streaming state for frame-by-frame processing""" # Create a streaming-compatible inference state # This mirrors SAM2's internal state structure but without video frames # Create streaming-compatible state without loading video # This approach avoids the dummy video complexity with torch.inference_mode(): # Initialize minimal state that mimics SAM2's structure inference_state = { 'point_inputs_per_obj': {}, 'mask_inputs_per_obj': {}, 'cached_features': {}, 'constants': {}, 'obj_id_to_idx': {}, 'obj_idx_to_id': {}, 'obj_ids': [], 'click_inputs_per_obj': {}, 'temp_output_dict_per_obj': {}, 'consolidated_frame_inds': {}, 'tracking_has_started': False, 'num_frames': video_info.get('total_frames', video_info.get('frame_count', 0)), 'video_height': video_info['height'], 'video_width': video_info['width'], 'device': self.device, 'storage_device': self.device, # Keep everything on GPU 'offload_video_to_cpu': False, 'offload_state_to_cpu': False, # Add required SAM2 internal structures 'output_dict_per_obj': {}, 'temp_output_dict_per_obj': {}, 'frames': None, # We provide frames manually 'images': None, # We provide images manually # Additional SAM2 tracking fields 'frames_tracked_per_obj': {}, 'obj_idx_to_id': {}, 'obj_id_to_idx': {}, 'click_inputs_per_obj': {}, 'point_inputs_per_obj': {}, 'mask_inputs_per_obj': {}, 'output_dict': {}, 'memory_bank': {}, 'num_obj_tokens': 0, 'max_obj_ptr_num': 16, # SAM2 default 'multimask_output_in_sam': False, 'use_multimask_token_for_obj_ptr': True, 'max_inference_state_frames': -1, # No limit for streaming 'image_feature_cache': {}, 'cached_features': {}, 'consolidated_frame_inds': {}, } # Initialize some constants that SAM2 expects inference_state['constants'] = { 'image_size': max(video_info['height'], video_info['width']), 'backbone_stride': 16, # Standard SAM2 backbone stride 'sam_mask_decoder_extra_args': {}, 'sam_prompt_embed_dim': 256, 'sam_image_embedding_size': video_info['height'] // 16, # Assuming 16x downsampling } print(f" Created streaming-compatible state") return inference_state def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None: """Move all tensors in state to the specified device""" def move_to_device(obj): if isinstance(obj, torch.Tensor): return obj.to(device) elif isinstance(obj, dict): return {k: move_to_device(v) for k, v in obj.items()} elif isinstance(obj, list): return [move_to_device(item) for item in obj] elif isinstance(obj, tuple): return tuple(move_to_device(item) for item in obj) else: return obj # Move all state components to device for key, value in state.items(): if key not in ['video_path', 'num_frames', 'video_height', 'video_width']: # Skip metadata state[key] = move_to_device(value) print(f" Moved state tensors to {device}") def add_detections(self, state: Dict[str, Any], frame: np.ndarray, detections: List[Dict[str, Any]], frame_idx: int = 0) -> List[int]: """ Add detection boxes as prompts to SAM2 with frame data Args: state: Inference state frame: Frame image (RGB numpy array) detections: List of detections with 'box' key frame_idx: Frame index to add prompts Returns: List of object IDs """ if not detections: warnings.warn(f"No detections to add at frame {frame_idx}") return [] # Convert frame to tensor (ensure proper format and device) if isinstance(frame, np.ndarray): # Convert BGR to RGB if needed (OpenCV uses BGR) if frame.shape[-1] == 3: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_tensor = torch.from_numpy(frame).float().to(self.device) else: frame_tensor = frame.float().to(self.device) if frame_tensor.ndim == 3: frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension # Normalize to [0, 1] range if needed if frame_tensor.max() > 1.0: frame_tensor = frame_tensor / 255.0 # Convert detections to SAM2 format boxes = [] for det in detections: box = det['box'] # [x1, y1, x2, y2] boxes.append(box) boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device) # Manually process frame and add prompts (streaming approach) with torch.inference_mode(): # Process frame through SAM2's image encoder backbone_out = self.predictor.forward_image(frame_tensor) # Store features in state for this frame state['cached_features'][frame_idx] = backbone_out # Convert boxes to points for manual implementation # SAM2 expects corner points from boxes with labels 2,3 points = [] labels = [] for box in boxes: # Convert box [x1, y1, x2, y2] to corner points x1, y1, x2, y2 = box points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners labels.extend([2, 3]) # SAM2 standard labels for box corners points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device) labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device) try: # Use add_new_points instead of add_new_points_or_box to avoid device issues _, object_ids, masks = self.predictor.add_new_points( inference_state=state, frame_idx=frame_idx, obj_id=None, # Let SAM2 auto-assign points=points_tensor, labels=labels_tensor, clear_old_points=True, normalize_coords=True ) # Update state with object tracking info state['obj_ids'] = object_ids state['tracking_has_started'] = True except Exception as e: print(f" Error in add_new_points: {e}") print(f" Points tensor device: {points_tensor.device}") print(f" Labels tensor device: {labels_tensor.device}") print(f" Frame tensor device: {frame_tensor.device}") # Fallback: manually initialize object tracking print(f" Using fallback manual object initialization") object_ids = [i for i in range(len(detections))] state['obj_ids'] = object_ids state['tracking_has_started'] = True # Store detection info for later use for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)): state['point_inputs_per_obj'][i] = { frame_idx: { 'points': points_tensor[i*2:(i+1)*2], 'labels': labels_tensor[i*2:(i+1)*2] } } self.object_ids = object_ids print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}") return object_ids def propagate_single_frame(self, state: Dict[str, Any], frame: np.ndarray, frame_idx: int) -> np.ndarray: """ Propagate masks for a single frame (true streaming) Args: state: Inference state frame: Frame image (RGB numpy array) frame_idx: Frame index Returns: Combined mask for all objects """ # Convert frame to tensor (ensure proper format and device) if isinstance(frame, np.ndarray): # Convert BGR to RGB if needed (OpenCV uses BGR) if frame.shape[-1] == 3: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_tensor = torch.from_numpy(frame).float().to(self.device) else: frame_tensor = frame.float().to(self.device) if frame_tensor.ndim == 3: frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension # Normalize to [0, 1] range if needed if frame_tensor.max() > 1.0: frame_tensor = frame_tensor / 255.0 with torch.inference_mode(): # Process frame through SAM2's image encoder backbone_out = self.predictor.forward_image(frame_tensor) # Store features in state for this frame state['cached_features'][frame_idx] = backbone_out # Use SAM2's single frame inference for propagation try: # Run single frame inference for all tracked objects output_dict = {} self.predictor._run_single_frame_inference( inference_state=state, output_dict=output_dict, frame_idx=frame_idx, batch_size=1, is_init_cond_frame=False, # Not initialization frame point_inputs=None, mask_inputs=None, reverse=False, run_mem_encoder=True ) # Extract masks from output if output_dict and 'pred_masks' in output_dict: pred_masks = output_dict['pred_masks'] # Combine all object masks if pred_masks.shape[0] > 0: combined_mask = pred_masks.max(dim=0)[0] combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255 else: combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) else: combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) except Exception as e: print(f" Warning: Single frame inference failed: {e}") combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) # Cleanup old features to prevent memory accumulation self._cleanup_old_features(state, frame_idx, keep_frames=10) return combined_mask_np def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10): """Remove old cached features to prevent memory accumulation""" features_to_remove = [] for frame_idx in state.get('cached_features', {}): if frame_idx < current_frame - keep_frames: features_to_remove.append(frame_idx) for frame_idx in features_to_remove: del state['cached_features'][frame_idx] # Periodic GPU memory cleanup if current_frame % 50 == 0: if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() def propagate_frame_pair(self, left_state: Dict[str, Any], right_state: Dict[str, Any], left_frame: np.ndarray, right_frame: np.ndarray, frame_idx: int) -> Tuple[np.ndarray, np.ndarray]: """ Propagate masks for a stereo frame pair Args: left_state: Left eye inference state right_state: Right eye inference state left_frame: Left eye frame right_frame: Right eye frame frame_idx: Current frame index Returns: Tuple of (left_masks, right_masks) """ # For actual implementation, we would need to handle the video frames # being already loaded in the state. This is a simplified version. # In practice, SAM2's propagate_in_video would handle frame loading. # Get masks from the current propagation state # This is pseudo-code as actual integration would depend on # how frames are provided to SAM2VideoPredictor left_masks = np.zeros((left_frame.shape[0], left_frame.shape[1]), dtype=np.uint8) right_masks = np.zeros((right_frame.shape[0], right_frame.shape[1]), dtype=np.uint8) # In actual implementation, you would: # 1. Use predictor.propagate_in_video() generator # 2. Extract masks for current frame_idx # 3. Combine multiple object masks if needed return left_masks, right_masks def _propagate_single_frame(self, state: Dict[str, Any], frame: np.ndarray, frame_idx: int) -> np.ndarray: """ Propagate masks for a single frame Args: state: Inference state frame: Input frame frame_idx: Frame index Returns: Combined mask for all objects """ # This is a simplified version - in practice we'd use the actual # SAM2 propagation API which handles memory updates internally # Get current masks from propagation # Note: This is pseudo-code as the actual API may differ masks = [] # For each tracked object for obj_idx in range(len(self.object_ids)): # Get mask for this object # In reality, SAM2 handles this internally obj_mask = self._get_object_mask(state, obj_idx, frame_idx) masks.append(obj_mask) # Combine all object masks if masks: combined_mask = np.max(masks, axis=0) else: combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8) return combined_mask def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray: """ Get mask for specific object (placeholder - actual implementation uses SAM2 API) """ # In practice, this would extract the mask from SAM2's internal state # For now, return a placeholder h, w = state.get('video_height', 1080), state.get('video_width', 1920) return np.zeros((h, w), dtype=np.uint8) def apply_continuous_correction(self, state: Dict[str, Any], frame: np.ndarray, frame_idx: int, detector: Any) -> None: """ Apply continuous correction by re-detecting and refining masks Args: state: Inference state frame: Current frame frame_idx: Frame index detector: Person detector instance """ if frame_idx % self.correction_interval != 0: return print(f" ๐Ÿ”„ Applying continuous correction at frame {frame_idx}") # Detect persons in current frame new_detections = detector.detect_persons(frame) if new_detections: # Add new prompts to refine tracking with torch.inference_mode(): boxes = [det['box'] for det in new_detections] boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device) # Add refinement prompts self.predictor.add_new_points_or_box( inference_state=state, frame_idx=frame_idx, obj_id=0, # Refine existing objects box=boxes_tensor ) def apply_mask_to_frame(self, frame: np.ndarray, mask: np.ndarray, output_format: str = 'greenscreen', background_color: List[int] = [0, 255, 0]) -> np.ndarray: """ Apply mask to frame with specified output format Args: frame: Input frame (BGR) mask: Binary mask output_format: 'alpha' or 'greenscreen' background_color: Background color for greenscreen Returns: Processed frame """ if output_format == 'alpha': # Add alpha channel if mask.dtype != np.uint8: mask = (mask * 255).astype(np.uint8) # Create BGRA image bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8) bgra[:, :, :3] = frame bgra[:, :, 3] = mask return bgra else: # greenscreen # Create green background background = np.full_like(frame, background_color, dtype=np.uint8) # Expand mask to 3 channels if mask.ndim == 2: mask_3ch = np.expand_dims(mask, axis=2) mask_3ch = np.repeat(mask_3ch, 3, axis=2) else: mask_3ch = mask # Normalize mask to 0-1 if mask_3ch.dtype == np.uint8: mask_float = mask_3ch.astype(np.float32) / 255.0 else: mask_float = mask_3ch.astype(np.float32) # Composite result = (frame * mask_float + background * (1 - mask_float)).astype(np.uint8) return result def cleanup(self) -> None: """Clean up resources""" # Clear states self.states.clear() # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Garbage collection gc.collect() print("๐Ÿงน SAM2 streaming processor cleaned up") def get_memory_usage(self) -> Dict[str, float]: """Get current memory usage""" memory_stats = { 'states_count': len(self.states), 'object_count': len(self.object_ids), } if torch.cuda.is_available(): memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9 memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9 return memory_stats