please fucking work
@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -151,8 +152,8 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             inference_state = self.predictor.init_state(
                 video_path=dummy_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
+                offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
+                offload_state_to_cpu=False,  # Keep state on GPU for performance
                 async_loading_frames=True
             )
 
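Hardcoding both offload flags to False trades GPU memory for speed: frames and inference state stay resident on the GPU, avoiding host-device copies on every propagation step. If memory pressure becomes a problem again, one option is to derive the flags from free GPU memory at startup. A sketch only, not part of this commit or the SAM2 API; the helper name and threshold are made up:

```python
# Sketch: derive the offload decision from free GPU memory instead of hardcoding.
# `should_offload` and `min_free_gb` are illustrative assumptions.
import torch

def should_offload(device: torch.device, min_free_gb: float = 4.0) -> bool:
    """Offload to CPU only when the GPU is short on free memory."""
    if device.type != "cuda":
        return False  # nothing to offload from
    free_bytes, _total = torch.cuda.mem_get_info(device)
    return free_bytes < min_free_gb * 1024**3

# offload = should_offload(self.device)
# ... init_state(..., offload_video_to_cpu=offload, offload_state_to_cpu=offload, ...)
```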
@@ -212,12 +213,23 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
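This same tensor-preparation block is added a second time in the mask-propagation method further down. A minimal sketch of factoring it into one helper; the name `preprocess_frame` and the free-function form are assumptions, not part of this commit:

```python
# Sketch only: consolidates the duplicated frame-preparation logic above.
import cv2
import numpy as np
import torch


def preprocess_frame(frame, device: torch.device) -> torch.Tensor:
    """Convert an HWC BGR uint8 frame (or a tensor) to a 1x3xHxW float tensor in [0, 1]."""
    if isinstance(frame, np.ndarray):
        if frame.shape[-1] == 3:
            # OpenCV delivers BGR; convert to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(frame).float().to(device)
    else:
        tensor = frame.float().to(device)

    if tensor.ndim == 3:
        tensor = tensor.permute(2, 0, 1)  # HWC -> CHW
    tensor = tensor.unsqueeze(0)          # add batch dimension

    if tensor.max() > 1.0:                # uint8-range input -> [0, 1]
        tensor = tensor / 255.0
    return tensor
```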
@@ -235,12 +247,18 @@ class SAM2StreamingProcessor:
             state['cached_features'][frame_idx] = backbone_out
 
             # Add boxes as prompts for this specific frame
-            _, object_ids, masks = self.predictor.add_new_points_or_box(
-                inference_state=state,
-                frame_idx=frame_idx,
-                obj_id=None,  # Let SAM2 auto-assign
-                box=boxes_tensor
-            )
+            try:
+                _, object_ids, masks = self.predictor.add_new_points_or_box(
+                    inference_state=state,
+                    frame_idx=frame_idx,
+                    obj_id=None,  # Let SAM2 auto-assign
+                    box=boxes_tensor
+                )
+            except Exception as e:
+                print(f" Error in add_new_points_or_box: {e}")
+                print(f" Box tensor device: {boxes_tensor.device}")
+                print(f" Frame tensor device: {frame_tensor.device}")
+                raise
 
         self.object_ids = object_ids
         print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
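The try/except above reports tensor devices only after the call has already failed. If device mismatches are the suspected failure mode, a cheap pre-call assertion catches them earlier. Purely a sketch, not part of this commit:

```python
# Sketch: fail fast on device mismatches before calling into the predictor.
import torch


def assert_same_device(**tensors: torch.Tensor) -> None:
    """Raise with a readable message if the named tensors live on different devices."""
    devices = {name: t.device for name, t in tensors.items()}
    if len(set(devices.values())) > 1:
        raise RuntimeError(f"Tensors live on different devices: {devices}")

# e.g. before add_new_points_or_box:
# assert_same_device(boxes_tensor=boxes_tensor, frame_tensor=frame_tensor)
```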
@@ -262,12 +280,23 @@ class SAM2StreamingProcessor:
         Returns:
             Combined mask for all objects
         """
-        # Convert frame to tensor
-        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        # Convert frame to tensor (ensure proper format and device)
+        if isinstance(frame, np.ndarray):
+            # Convert BGR to RGB if needed (OpenCV uses BGR)
+            if frame.shape[-1] == 3:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame_tensor = torch.from_numpy(frame).float().to(self.device)
+        else:
+            frame_tensor = frame.float().to(self.device)
+
         if frame_tensor.ndim == 3:
             frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
         frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+        # Normalize to [0, 1] range if needed
+        if frame_tensor.max() > 1.0:
+            frame_tensor = frame_tensor / 255.0
 
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
             backbone_out = self.predictor.forward_image(frame_tensor)
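A quick way to sanity-check the new preprocessing path is a standalone smoke test against a fake OpenCV-style frame. Hypothetical, building on the `preprocess_frame` sketch above:

```python
# Hypothetical smoke test for the frame-preparation logic; not part of this commit.
import numpy as np
import torch

frame = np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8)  # fake BGR frame
tensor = preprocess_frame(frame, torch.device("cpu"))  # helper sketched earlier

assert tensor.shape == (1, 3, 720, 1280)  # batched, CHW
assert 0.0 <= float(tensor.min()) and float(tensor.max()) <= 1.0
```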