diff --git a/vr180_streaming/sam2_streaming.py b/vr180_streaming/sam2_streaming.py
index 2f06f95..3f08eed 100644
--- a/vr180_streaming/sam2_streaming.py
+++ b/vr180_streaming/sam2_streaming.py
@@ -129,9 +129,45 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
 
-        with torch.inference_mode():
-            # Initialize empty inference state using SAM2's predictor
-            # We'll manually provide frames via propagate calls
+        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
+        # We'll override the frame access later
+        try:
+            # Create a minimal dummy video file temporarily
+            import tempfile
+            import cv2
+
+            # Create 1-frame dummy video
+            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+                dummy_path = tmp_file.name
+
+            # Write a single frame video
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
+            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
+            out.write(dummy_frame)
+            out.release()
+
+            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
+            with torch.inference_mode():
+                inference_state = self.predictor.init_state(
+                    video_path=dummy_path,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload,
+                    async_loading_frames=True
+                )
+
+            # Clean up dummy file
+            import os
+            os.unlink(dummy_path)
+
+            # Update state with actual video info
+            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
+            inference_state['video_height'] = video_info['height']
+            inference_state['video_width'] = video_info['width']
+
+        except Exception as e:
+            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
+            # Fallback to minimal state
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
@@ -151,13 +187,8 @@ class SAM2StreamingProcessor:
                 'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                 'offload_video_to_cpu': self.memory_offload,
                 'offload_state_to_cpu': self.memory_offload,
-                'inference_state': {},
             }
 
-        # Initialize SAM2 constants that don't depend on video frames
-        self.predictor._get_image_feature_cache = {}
-        self.predictor._feature_bank = {}
-
         return inference_state
 
     def add_detections(self,
@@ -198,16 +229,16 @@
         # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)
 
             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out
 
             # Add boxes as prompts for this specific frame
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
-                obj_id=0,  # SAM2 will auto-increment
+                obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
 
@@ -239,33 +270,41 @@
 
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)
 
             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out
 
-            # Get masks for current frame by propagating from previous frames
-            masks = []
-            for obj_id in state.get('obj_ids', []):
-                # Use SAM2's mask propagation for this object
-                try:
-                    obj_mask = self.predictor._propagate_single_object(
-                        state, obj_id, frame_idx, features
-                    )
-                    if obj_mask is not None:
-                        masks.append(obj_mask)
-                except Exception as e:
-                    # If propagation fails, use empty mask
-                    print(f" Warning: Propagation failed for object {obj_id}: {e}")
-                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
-                    masks.append(empty_mask)
-
-            # Combine all object masks
-            if masks:
-                combined_mask = torch.stack(masks).max(dim=0)[0]
-                # Convert to numpy
-                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
-            else:
+            # Use SAM2's single frame inference for propagation
+            try:
+                # Run single frame inference for all tracked objects
+                output_dict = {}
+                self.predictor._run_single_frame_inference(
+                    inference_state=state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=1,
+                    is_init_cond_frame=False,  # Not initialization frame
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=False,
+                    run_mem_encoder=True
+                )
+
+                # Extract masks from output
+                if output_dict and 'pred_masks' in output_dict:
+                    pred_masks = output_dict['pred_masks']
+                    # Combine all object masks
+                    if pred_masks.shape[0] > 0:
+                        combined_mask = pred_masks.max(dim=0)[0]
+                        combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
+                    else:
+                        combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+                else:
+                    combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+            except Exception as e:
+                print(f" Warning: Single frame inference failed: {e}")
                 combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
 
         # Cleanup old features to prevent memory accumulation
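
For reference, the "dummy 1-frame video" trick from the first hunk can be reproduced in isolation as a quick sanity check. This is a minimal sketch, assuming only OpenCV and NumPy are installed; the make_dummy_video helper name and the example resolution are ours, and the predictor.init_state(...) call is only indicated in a comment so the sketch stays self-contained.

    # Standalone sketch of the dummy single-frame video used above, so SAM2's
    # init_state() can build its inference state without decoding the real input.
    import os
    import tempfile

    import cv2
    import numpy as np


    def make_dummy_video(width: int, height: int) -> str:
        """Write one black frame to a temporary .mp4 and return its path."""
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
            dummy_path = tmp_file.name
        writer = cv2.VideoWriter(dummy_path, cv2.VideoWriter_fourcc(*'mp4v'), 1.0, (width, height))
        writer.write(np.zeros((height, width, 3), dtype=np.uint8))
        writer.release()
        return dummy_path


    if __name__ == '__main__':
        path = make_dummy_video(1920, 1080)  # example resolution, not taken from the repo
        try:
            # In the patch this is where predictor.init_state(video_path=path, ...)
            # runs; the predictor itself is omitted here to keep the sketch runnable.
            print(f"dummy video written to {path}, size {os.path.getsize(path)} bytes")
        finally:
            os.unlink(path)  # the dummy file is only needed during initialization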