please fucking work

2025-07-27 09:15:48 -07:00
parent 2e5ded7dbf
commit 1d15fb5bc8


@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -151,8 +152,8 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             inference_state = self.predictor.init_state(
                 video_path=dummy_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
+                offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
+                offload_state_to_cpu=False,  # Keep state on GPU for performance
                 async_loading_frames=True
             )
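
A note on this hunk: `offload_video_to_cpu` and `offload_state_to_cpu` trade GPU memory for throughput, so hardcoding them to False assumes the whole stream fits in VRAM. A minimal sketch of keeping the trade-off configurable (the helper name and the `low_vram` flag are assumptions, not part of this commit):

import torch

def init_streaming_state(predictor, frames_path, low_vram=False):
    # predictor is assumed to be a loaded SAM2 video predictor exposing
    # init_state() with the same kwargs used in the hunk above.
    with torch.inference_mode():
        return predictor.init_state(
            video_path=frames_path,
            # True offloads frames/state to CPU RAM: slower, but long videos fit;
            # False keeps everything on the GPU: faster, but uses more VRAM.
            offload_video_to_cpu=low_vram,
            offload_state_to_cpu=low_vram,
            async_loading_frames=True,
        )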
@@ -212,12 +213,23 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+            frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
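
The frame-conversion block added here is duplicated verbatim in the hunk at line 280 below; a follow-up could factor it into one helper. A minimal sketch (the name `prepare_frame` is hypothetical):

import cv2
import numpy as np
import torch

def prepare_frame(frame, device):
    # BGR uint8 ndarray (or torch tensor) -> normalized 1xCxHxW float tensor.
    if isinstance(frame, np.ndarray):
        if frame.shape[-1] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        tensor = torch.from_numpy(frame).float().to(device)
    else:
        tensor = frame.float().to(device)
    if tensor.ndim == 3:
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1xCxHxW
    if tensor.max() > 1.0:
        tensor = tensor / 255.0  # scale uint8 [0, 255] down to [0, 1]
    return tensor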
@@ -235,12 +247,18 @@ class SAM2StreamingProcessor:
             state['cached_features'][frame_idx] = backbone_out
 
         # Add boxes as prompts for this specific frame
-        _, object_ids, masks = self.predictor.add_new_points_or_box(
-            inference_state=state,
-            frame_idx=frame_idx,
-            obj_id=None,  # Let SAM2 auto-assign
-            box=boxes_tensor
-        )
+        try:
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
+                inference_state=state,
+                frame_idx=frame_idx,
+                obj_id=None,  # Let SAM2 auto-assign
+                box=boxes_tensor
+            )
+        except Exception as e:
+            print(f" Error in add_new_points_or_box: {e}")
+            print(f" Box tensor device: {boxes_tensor.device}")
+            print(f" Frame tensor device: {frame_tensor.device}")
+            raise
         self.object_ids = object_ids
         print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
@@ -262,12 +280,23 @@ class SAM2StreamingProcessor:
         Returns:
             Combined mask for all objects
         """
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+            frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
             backbone_out = self.predictor.forward_image(frame_tensor)
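
For context, the `cached_features` pattern that brackets `forward_image` in both methods can be sketched as a small memoized encode step (the function name is hypothetical; it mirrors the state-dict usage in the diff):

import torch

def encode_frame_cached(predictor, state, frame_tensor, frame_idx):
    # Reuse the image-encoder output if this frame index was already
    # processed; otherwise run forward_image once and cache the result.
    cache = state.setdefault('cached_features', {})
    if frame_idx not in cache:
        with torch.inference_mode():
            cache[frame_idx] = predictor.forward_image(frame_tensor)
    return cache[frame_idx]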