please fucking work

2025-07-27 09:15:48 -07:00
parent 2e5ded7dbf
commit 1d15fb5bc8


@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -151,8 +152,8 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             inference_state = self.predictor.init_state(
                 video_path=dummy_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
+                offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
+                offload_state_to_cpu=False,  # Keep state on GPU for performance
                 async_loading_frames=True
             )
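Note on the hunk above: hard-coding both offload flags to False keeps the video frames and tracking state in GPU memory, which is faster but costs VRAM on smaller cards. A minimal fallback sketch under that assumption (not part of this commit; the init_state arguments are only the ones already shown above):

    try:
        inference_state = self.predictor.init_state(
            video_path=dummy_path,
            offload_video_to_cpu=False,
            offload_state_to_cpu=False,
            async_loading_frames=True
        )
    except torch.cuda.OutOfMemoryError:
        # Hypothetical fallback: retry with CPU offloading if keeping
        # everything on the GPU does not fit in memory
        torch.cuda.empty_cache()
        inference_state = self.predictor.init_state(
            video_path=dummy_path,
            offload_video_to_cpu=True,
            offload_state_to_cpu=True,
            async_loading_frames=True
        )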
@@ -212,12 +213,23 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
@@ -235,12 +247,18 @@ class SAM2StreamingProcessor:
             state['cached_features'][frame_idx] = backbone_out
         # Add boxes as prompts for this specific frame
-        _, object_ids, masks = self.predictor.add_new_points_or_box(
-            inference_state=state,
-            frame_idx=frame_idx,
-            obj_id=None,  # Let SAM2 auto-assign
-            box=boxes_tensor
-        )
+        try:
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
+                inference_state=state,
+                frame_idx=frame_idx,
+                obj_id=None,  # Let SAM2 auto-assign
+                box=boxes_tensor
+            )
+        except Exception as e:
+            print(f" Error in add_new_points_or_box: {e}")
+            print(f" Box tensor device: {boxes_tensor.device}")
+            print(f" Frame tensor device: {frame_tensor.device}")
+            raise
         self.object_ids = object_ids
         print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
@@ -262,12 +280,23 @@ class SAM2StreamingProcessor:
         Returns:
             Combined mask for all objects
         """
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
             backbone_out = self.predictor.forward_image(frame_tensor)
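The frame-to-tensor conversion is now duplicated verbatim in both methods touched by this commit. If it changes again, it could live in one shared helper; a minimal sketch of that refactor (the method name _frame_to_tensor is an assumption, not part of this commit):

    def _frame_to_tensor(self, frame):
        # Mirrors the inline conversion above: BGR -> RGB, HWC -> CHW, batch dim, [0, 1] range
        if isinstance(frame, np.ndarray):
            if frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV frames arrive as BGR
            frame_tensor = torch.from_numpy(frame).float().to(self.device)
        else:
            frame_tensor = frame.float().to(self.device)
        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
        frame_tensor = frame_tensor.unsqueeze(0)  # add batch dimension
        if frame_tensor.max() > 1.0:
            frame_tensor = frame_tensor / 255.0  # scale 0-255 pixel values down to [0, 1]
        return frame_tensor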