From 1d15fb5bc829d44bdf47a35f66a98a44c3326565 Mon Sep 17 00:00:00 2001
From: Scott Register
Date: Sun, 27 Jul 2025 09:15:48 -0700
Subject: [PATCH] please fucking work

---
 vr180_streaming/sam2_streaming.py | 53 ++++++++++++++++++++++++-------
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/vr180_streaming/sam2_streaming.py b/vr180_streaming/sam2_streaming.py
index 3f08eed..2b09e28 100644
--- a/vr180_streaming/sam2_streaming.py
+++ b/vr180_streaming/sam2_streaming.py
@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -151,8 +152,8 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             inference_state = self.predictor.init_state(
                 video_path=dummy_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
+                offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
+                offload_state_to_cpu=False,  # Keep state on GPU for performance
                 async_loading_frames=True
             )
 
@@ -212,12 +213,23 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
 
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
+
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
@@ -235,12 +247,18 @@ class SAM2StreamingProcessor:
             state['cached_features'][frame_idx] = backbone_out
 
         # Add boxes as prompts for this specific frame
-        _, object_ids, masks = self.predictor.add_new_points_or_box(
-            inference_state=state,
-            frame_idx=frame_idx,
-            obj_id=None,  # Let SAM2 auto-assign
-            box=boxes_tensor
-        )
+        try:
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
+                inference_state=state,
+                frame_idx=frame_idx,
+                obj_id=None,  # Let SAM2 auto-assign
+                box=boxes_tensor
+            )
+        except Exception as e:
+            print(f"  Error in add_new_points_or_box: {e}")
+            print(f"  Box tensor device: {boxes_tensor.device}")
+            print(f"  Frame tensor device: {frame_tensor.device}")
+            raise
 
         self.object_ids = object_ids
         print(f"  Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
@@ -262,12 +280,23 @@ class SAM2StreamingProcessor:
         Returns:
             Combined mask for all objects
         """
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
 
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
+
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
             backbone_out = self.predictor.forward_image(frame_tensor)
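
Note on the frame-handling hunks: both changed methods now funnel incoming frames through the same path (BGR->RGB for OpenCV arrays, HWC->CHW, a batch dimension, and scaling uint8-range values into [0, 1]). The standalone sketch below mirrors that preprocessing so it can be exercised without SAM2 or a GPU; the helper name preprocess_frame, the default device string, and the dummy 1080x1920 frame are illustrative assumptions and do not come from sam2_streaming.py.

import cv2
import numpy as np
import torch


def preprocess_frame(frame, device="cpu"):
    """Mirror the patch's frame handling: BGR->RGB, HWC->CHW, batch dim, [0, 1] range."""
    if isinstance(frame, np.ndarray):
        # OpenCV decodes to BGR; convert 3-channel frames to RGB
        if frame.shape[-1] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_tensor = torch.from_numpy(frame).float().to(device)
    else:
        frame_tensor = frame.float().to(device)

    if frame_tensor.ndim == 3:
        frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
    frame_tensor = frame_tensor.unsqueeze(0)          # add batch dimension

    # uint8-range input -> scale into [0, 1]; already-normalized floats pass through unchanged
    if frame_tensor.max() > 1.0:
        frame_tensor = frame_tensor / 255.0
    return frame_tensor


if __name__ == "__main__":
    dummy = np.random.randint(0, 256, (1080, 1920, 3), dtype=np.uint8)  # assumed frame size
    tensor = preprocess_frame(dummy)
    print(tensor.shape, tensor.dtype, float(tensor.min()), float(tensor.max()))
    # expected: torch.Size([1, 3, 1080, 1920]) torch.float32, values within [0, 1]

Because normalization only triggers when values exceed 1.0, the same helper is safe for frames that arrive either as uint8 pixels or as already-scaled float tensors, which is the property the patched methods rely on.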