diff --git a/vr180_streaming/sam2_streaming.py b/vr180_streaming/sam2_streaming.py
index 2216c5f..e49d8e7 100644
--- a/vr180_streaming/sam2_streaming.py
+++ b/vr180_streaming/sam2_streaming.py
@@ -18,6 +18,7 @@
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
 import gc
+from .timeout_init import safe_init_state, TimeoutError
 
 # Import SAM2 components - these will be available after SAM2 installation
 try:
@@ -92,41 +93,85 @@ class SAM2StreamingProcessor:
             raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
 
     def init_state(self,
-                   video_path: str,
+                   video_info: Dict[str, Any],
                    eye: str = 'full') -> Dict[str, Any]:
         """
-        Initialize inference state for streaming
+        Initialize inference state for streaming (NO VIDEO LOADING)
 
         Args:
-            video_path: Path to video file
+            video_info: Video metadata dict with width, height, frame_count
            eye: Eye identifier ('left', 'right', or 'full')
 
        Returns:
            Inference state dictionary
        """
-        # Initialize state with memory offloading enabled
-        with torch.inference_mode():
-            state = self.predictor.init_state(
-                video_path=video_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
-                async_loading_frames=False  # We'll provide frames directly
-            )
+        print(f"   Initializing streaming state for {eye} eye...")
+
+        # Monitor memory before initialization
+        if torch.cuda.is_available():
+            before_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"   📊 GPU memory before init: {before_mem:.1f}GB")
+
+        # Create streaming state WITHOUT loading video frames
+        state = self._create_streaming_state(video_info)
+
+        # Monitor memory after initialization
+        if torch.cuda.is_available():
+            after_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"   📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
 
        self.states[eye] = state
-        print(f"   Initialized state for {eye} eye")
+        print(f"   ✅ Streaming state initialized for {eye} eye")
 
        return state
 
+    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
+        """Create streaming state for frame-by-frame processing"""
+        # Create a streaming-compatible inference state
+        # This mirrors SAM2's internal state structure but without video frames
+
+        with torch.inference_mode():
+            # Initialize empty inference state using SAM2's predictor
+            # We'll manually provide frames via propagate calls
+            inference_state = {
+                'point_inputs_per_obj': {},
+                'mask_inputs_per_obj': {},
+                'cached_features': {},
+                'constants': {},
+                'obj_id_to_idx': {},
+                'obj_idx_to_id': {},
+                'obj_ids': [],
+                'click_inputs_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'consolidated_frame_inds': {},
+                'tracking_has_started': False,
+                'num_frames': video_info['frame_count'],
+                'video_height': video_info['height'],
+                'video_width': video_info['width'],
+                'device': self.device,
+                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
+                'offload_video_to_cpu': self.memory_offload,
+                'offload_state_to_cpu': self.memory_offload,
+                'inference_state': {},
+            }
+
+        # Initialize SAM2 constants that don't depend on video frames
+        self.predictor._get_image_feature_cache = {}
+        self.predictor._feature_bank = {}
+
+        return inference_state
+
     def add_detections(self,
                        state: Dict[str, Any],
+                       frame: np.ndarray,
                        detections: List[Dict[str, Any]],
                        frame_idx: int = 0) -> List[int]:
        """
-        Add detection boxes as prompts to SAM2
+        Add detection boxes as prompts to SAM2 with frame data
 
        Args:
            state: Inference state
+            frame: Frame image (RGB numpy array)
            detections: List of detections with 'box' key
            frame_idx: Frame index to add prompts
@@ -137,6 +182,12 @@ class SAM2StreamingProcessor:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []
 
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
        # Convert detections to SAM2 format
        boxes = []
        for det in detections:
@@ -145,9 +196,16 @@ class SAM2StreamingProcessor:
 
        boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
 
-        # Add boxes as prompts
+        # Manually process frame and add prompts (streaming approach)
        with torch.inference_mode():
-            _, object_ids, _ = self.predictor.add_new_points_or_box(
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Add boxes as prompts for this specific frame
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=state,
                frame_idx=frame_idx,
                obj_id=0,  # SAM2 will auto-increment
@@ -159,29 +217,78 @@ class SAM2StreamingProcessor:
 
        return object_ids
 
-    def propagate_in_video_simple(self,
-                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
+    def propagate_single_frame(self,
+                               state: Dict[str, Any],
+                               frame: np.ndarray,
+                               frame_idx: int) -> np.ndarray:
        """
-        Simple propagation for single eye processing
+        Propagate masks for a single frame (true streaming)
 
-        Yields:
-            (frame_idx, object_ids, masks) tuples
+        Args:
+            state: Inference state
+            frame: Frame image (RGB numpy array)
+            frame_idx: Frame index
+
+        Returns:
+            Combined mask for all objects
        """
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
        with torch.inference_mode():
-            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
-                # Convert masks to numpy
-                if isinstance(masks, torch.Tensor):
-                    masks_np = masks.cpu().numpy()
-                else:
-                    masks_np = masks
-
-                yield frame_idx, object_ids, masks_np
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Get masks for current frame by propagating from previous frames
+            masks = []
+            for obj_id in state.get('obj_ids', []):
+                # Use SAM2's mask propagation for this object
+                try:
+                    obj_mask = self.predictor._propagate_single_object(
+                        state, obj_id, frame_idx, features
+                    )
+                    if obj_mask is not None:
+                        masks.append(obj_mask)
+                except Exception as e:
+                    # If propagation fails, use empty mask
+                    print(f"   Warning: Propagation failed for object {obj_id}: {e}")
+                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
+                    masks.append(empty_mask)
+
+            # Combine all object masks
+            if masks:
+                combined_mask = torch.stack(masks).max(dim=0)[0]
+                # Convert to numpy
+                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
+            else:
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+        # Cleanup old features to prevent memory accumulation
+        self._cleanup_old_features(state, frame_idx, keep_frames=10)
+
+        return combined_mask_np
+
+    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
+        """Remove old cached features to prevent memory accumulation"""
+        features_to_remove = []
+        for frame_idx in state.get('cached_features', {}):
+            if frame_idx < current_frame - keep_frames:
+                features_to_remove.append(frame_idx)
-                # Periodic memory cleanup
-                if frame_idx % 100 == 0:
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-                    gc.collect()
+        for frame_idx in features_to_remove:
+            del state['cached_features'][frame_idx]
+
+        # Periodic GPU memory cleanup
+        if current_frame % 50 == 0:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
 
     def propagate_frame_pair(self,
                              left_state: Dict[str, Any],
diff --git a/vr180_streaming/streaming_processor.py b/vr180_streaming/streaming_processor.py
index a15563c..3bdc34e 100644
--- a/vr180_streaming/streaming_processor.py
+++ b/vr180_streaming/streaming_processor.py
@@ -102,14 +102,15 @@ class VR180StreamingProcessor:
        self.initialize()
        self.start_time = time.time()
 
-        # Initialize SAM2 states for both eyes
+        # Initialize SAM2 states for both eyes (streaming mode - no video loading)
        print("🎯 Initializing SAM2 streaming states...")
+        video_info = self.frame_reader.get_video_info()
        left_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
            eye='left'
        )
        right_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
            eye='right'
        )
 
@@ -158,19 +159,19 @@ class VR180StreamingProcessor:
 
        print(f"   Detected {len(detections)} person(s) in first frame")
 
-        # Add detections to both eyes
-        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
+        # Add detections to both eyes (streaming - pass frame data)
+        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
 
        # Transfer detections to slave eye
        transferred_detections = self.stereo_manager.transfer_detections(
            detections,
            'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
        )
-        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)
+        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
 
        # Process and write first frame
-        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
+        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
+        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
 
        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -195,10 +196,9 @@ class VR180StreamingProcessor:
            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)
 
-            # Propagate masks for both eyes
-            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
-                left_state, right_state, left_eye, right_eye, frame_idx
-            )
+            # Propagate masks for both eyes (streaming approach)
+            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
+            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
 
            # Validate stereo consistency
            right_masks = self.stereo_manager.validate_masks(
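Usage sketch (illustrative, not part of the patch): the hunks above replace the old "init_state(video_path) + propagate_in_video" flow with a metadata-only init followed by per-frame propagation. A minimal single-eye version of the new calling convention looks roughly like the code below; the SAM2StreamingProcessor constructor arguments, the dummy frame source, and the example box values are placeholders, not assertions about the real pipeline.

    import numpy as np
    from vr180_streaming.sam2_streaming import SAM2StreamingProcessor

    # Placeholder construction -- real checkpoint/config arguments depend on the class.
    processor = SAM2StreamingProcessor()

    # init_state() now takes only metadata; no frames are decoded or moved to the GPU here.
    video_info = {'width': 2880, 'height': 2880, 'frame_count': 300}
    state = processor.init_state(video_info, eye='left')

    # Stand-in frame source; in the pipeline this role is played by the streaming frame reader.
    def iter_frames(n=300, h=2880, w=2880):
        for _ in range(n):
            yield np.zeros((h, w, 3), dtype=np.uint8)  # RGB frame, HxWx3

    frames = iter_frames()
    first_frame = next(frames)
    detections = [{'box': [100, 200, 400, 600]}]  # detections carry a 'box' key (xyxy)
    processor.add_detections(state, first_frame, detections, frame_idx=0)

    # Each remaining frame is encoded and propagated on arrival; cached features older
    # than keep_frames are dropped by _cleanup_old_features(), keeping memory roughly flat.
    for idx, frame in enumerate(frames, start=1):
        mask = processor.propagate_single_frame(state, frame, idx)  # HxW uint8 mask

The stereo path in streaming_processor.py simply runs this loop twice per frame, once per eye, with the slave eye's initial detections produced by stereo_manager.transfer_detections().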