diff --git a/vr180_streaming/sam2_streaming.py b/vr180_streaming/sam2_streaming.py
index 2b09e28..3c22938 100644
--- a/vr180_streaming/sam2_streaming.py
+++ b/vr180_streaming/sam2_streaming.py
@@ -130,45 +130,11 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
-        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
-        # We'll override the frame access later
-        try:
-            # Create a minimal dummy video file temporarily
-            import tempfile
-            import cv2
-
-            # Create 1-frame dummy video
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
-                dummy_path = tmp_file.name
-
-            # Write a single frame video
-            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
-            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
-            out.write(dummy_frame)
-            out.release()
-
-            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
-            with torch.inference_mode():
-                inference_state = self.predictor.init_state(
-                    video_path=dummy_path,
-                    offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
-                    offload_state_to_cpu=False,  # Keep state on GPU for performance
-                    async_loading_frames=True
-                )
-
-            # Clean up dummy file
-            import os
-            os.unlink(dummy_path)
-
-            # Update state with actual video info
-            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
-            inference_state['video_height'] = video_info['height']
-            inference_state['video_width'] = video_info['width']
-
-        except Exception as e:
-            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
-            # Fallback to minimal state
+        # Create streaming-compatible state without loading video
+        # This approach avoids the dummy video complexity
+
+        with torch.inference_mode():
+            # Initialize minimal state that mimics SAM2's structure
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
@@ -185,13 +151,50 @@ class SAM2StreamingProcessor:
                 'video_height': video_info['height'],
                 'video_width': video_info['width'],
                 'device': self.device,
-                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
-                'offload_video_to_cpu': self.memory_offload,
-                'offload_state_to_cpu': self.memory_offload,
+                'storage_device': self.device,  # Keep everything on GPU
+                'offload_video_to_cpu': False,
+                'offload_state_to_cpu': False,
+                # Add some required SAM2 internal structures
+                'output_dict_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'frames': None,  # We provide frames manually
+                'images': None,  # We provide images manually
             }
 
+        # Initialize some constants that SAM2 expects
+        inference_state['constants'] = {
+            'image_size': max(video_info['height'], video_info['width']),
+            'backbone_stride': 16,  # Standard SAM2 backbone stride
+            'sam_mask_decoder_extra_args': {},
+            'sam_prompt_embed_dim': 256,
+            'sam_image_embedding_size': video_info['height'] // 16,  # Assuming 16x downsampling
+        }
+
+        print(f" Created streaming-compatible state")
+
         return inference_state
 
+    def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
+        """Move all tensors in state to the specified device"""
+        def move_to_device(obj):
+            if isinstance(obj, torch.Tensor):
+                return obj.to(device)
+            elif isinstance(obj, dict):
+                return {k: move_to_device(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [move_to_device(item) for item in obj]
+            elif isinstance(obj, tuple):
+                return tuple(move_to_device(item) for item in obj)
+            else:
+                return obj
+
+        # Move all state components to device
+        for key, value in state.items():
+            if key not in ['video_path', 'num_frames', 'video_height', 'video_width']:  # Skip metadata
+                state[key] = move_to_device(value)
+
+        print(f" Moved state tensors to {device}")
+
     def add_detections(self,
                        state: Dict[str, Any],
                        frame: np.ndarray,
@@ -248,16 +251,29 @@ class SAM2StreamingProcessor:
         # Add boxes as prompts for this specific frame
         try:
+            # Force ensure all inputs are on correct device
+            boxes_tensor = boxes_tensor.to(self.device)
+
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
+
+            # Update state with object tracking info
+            state['obj_ids'] = object_ids
+            state['tracking_has_started'] = True
+
         except Exception as e:
             print(f" Error in add_new_points_or_box: {e}")
             print(f" Box tensor device: {boxes_tensor.device}")
             print(f" Frame tensor device: {frame_tensor.device}")
+            print(f" State device keys: {[k for k in state.keys() if 'device' in k.lower()]}")
+            # Try to inspect state tensor devices
+            for key, value in state.items():
+                if isinstance(value, torch.Tensor):
+                    print(f" State[{key}] device: {value.device}")
             raise
 
         self.object_ids = object_ids
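
For reference, a standalone sketch of the recursive device-move pattern that _move_state_to_device introduces in this patch. It is not part of the diff: the helper name, the toy state dict, and the CPU target device are illustrative, and the metadata-skipping loop mirrors the one added above.

import torch

def move_to_device(obj, device):
    """Recursively move tensors inside nested containers to the given device."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(item, device) for item in obj)
    return obj  # non-tensor leaves pass through unchanged

# Toy stand-in for the SAM2 inference state built in the patch.
state = {
    'cached_features': {0: torch.zeros(1, 256, 64, 64)},
    'obj_ids': [1, 2],
    'video_height': 1920,  # plain metadata, left untouched below
}

# Move everything except metadata keys, as the new helper does.
metadata_keys = ['video_path', 'num_frames', 'video_height', 'video_width']
for key, value in state.items():
    if key not in metadata_keys:
        state[key] = move_to_device(value, torch.device('cpu'))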