please fucking work
@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -151,8 +152,8 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             inference_state = self.predictor.init_state(
                 video_path=dummy_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
+                offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
+                offload_state_to_cpu=False,  # Keep state on GPU for performance
                 async_loading_frames=True
             )
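Note on this hunk: pinning both offload flags to False keeps everything on the GPU, which is faster but raises peak memory. If that trade-off ever needs to stay configurable, here is a minimal sketch of an adaptive choice (the pick_offload helper and the 8 GB threshold are assumptions, not part of this commit):

    import torch

    def pick_offload(threshold_gb: float = 8.0) -> bool:
        """Offload SAM2 video/state to CPU only when the GPU is small."""
        if not torch.cuda.is_available():
            return True  # no CUDA device: state has to live on CPU anyway
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        return total_gb < threshold_gb

The two kwargs would then read offload_video_to_cpu=pick_offload() instead of a hard-coded False.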
@@ -212,12 +213,23 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+            frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
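Note: this conversion block is now duplicated verbatim in the hunk at line 280 below. A standalone sketch of the same logic factored into one helper (frame_to_tensor is a hypothetical name, not in this commit):

    import cv2
    import numpy as np
    import torch

    def frame_to_tensor(frame, device):
        """HWC BGR frame (numpy or tensor) -> normalized 1x3xHxW float tensor."""
        if isinstance(frame, np.ndarray):
            if frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
            tensor = torch.from_numpy(frame).float().to(device)
        else:
            tensor = frame.float().to(device)
        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1xCxHxW
        if tensor.max() > 1.0:
            tensor = tensor / 255.0  # assume uint8 range, scale to [0, 1]
        return tensor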
@@ -235,12 +247,18 @@ class SAM2StreamingProcessor:
         state['cached_features'][frame_idx] = backbone_out
 
         # Add boxes as prompts for this specific frame
-        _, object_ids, masks = self.predictor.add_new_points_or_box(
-            inference_state=state,
-            frame_idx=frame_idx,
-            obj_id=None,  # Let SAM2 auto-assign
-            box=boxes_tensor
-        )
+        try:
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
+                inference_state=state,
+                frame_idx=frame_idx,
+                obj_id=None,  # Let SAM2 auto-assign
+                box=boxes_tensor
+            )
+        except Exception as e:
+            print(f" Error in add_new_points_or_box: {e}")
+            print(f" Box tensor device: {boxes_tensor.device}")
+            print(f" Frame tensor device: {frame_tensor.device}")
+            raise
 
         self.object_ids = object_ids
         print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
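The except branch prints both tensor devices because a CPU/GPU mismatch between the box prompts and the model is the most common failure mode of this call. One caveat worth flagging: in the upstream facebookresearch/sam2 video predictor, add_new_points_or_box expects an explicit integer obj_id per object rather than auto-assigning on None; if that matches the SAM2 version in use here, multiple boxes would need a per-object loop along these lines (a hedged sketch, not this commit's code):

    # One call per object, each with an explicit integer id.
    for obj_id, box in enumerate(boxes_tensor):
        _, object_ids, masks = self.predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=frame_idx,
            obj_id=obj_id,  # explicit id instead of None
            box=box,        # a single xyxy box
        )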
@@ -262,12 +280,23 @@ class SAM2StreamingProcessor:
         Returns:
             Combined mask for all objects
         """
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+            frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
             backbone_out = self.predictor.forward_image(frame_tensor)
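forward_image is the heavy image-encoder pass, which is why the hunk at line 247 above stores its output in state['cached_features'][frame_idx]. A minimal sketch of that memoization pattern, assuming the cache is a plain dict keyed by frame index (encode_once is a hypothetical helper, not in this commit):

    import torch

    def encode_once(predictor, cache, frame_idx, frame_tensor):
        """Run the SAM2 image encoder at most once per frame index."""
        if frame_idx not in cache:
            with torch.inference_mode():
                cache[frame_idx] = predictor.forward_image(frame_tensor)
        return cache[frame_idx]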