2025-07-27 09:20:42 -07:00
parent 1d15fb5bc8
commit 7b3ffb7830


@@ -130,45 +130,11 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
-        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
-        # We'll override the frame access later
-        try:
-            # Create a minimal dummy video file temporarily
-            import tempfile
-            import cv2
-
-            # Create 1-frame dummy video
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
-                dummy_path = tmp_file.name
-
-            # Write a single frame video
-            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
-            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
-            out.write(dummy_frame)
-            out.release()
-
-            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
-            with torch.inference_mode():
-                inference_state = self.predictor.init_state(
-                    video_path=dummy_path,
-                    offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
-                    offload_state_to_cpu=False,  # Keep state on GPU for performance
-                    async_loading_frames=True
-                )
-
-            # Clean up dummy file
-            import os
-            os.unlink(dummy_path)
-
-            # Update state with actual video info
-            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
-            inference_state['video_height'] = video_info['height']
-            inference_state['video_width'] = video_info['width']
-        except Exception as e:
-            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
-            # Fallback to minimal state
+        # Create streaming-compatible state without loading video
+        # This approach avoids the dummy video complexity
+        with torch.inference_mode():
+            # Initialize minimal state that mimics SAM2's structure
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
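
For orientation, here is a minimal, self-contained sketch of the streaming-compatible state this change assembles, using only the fields visible in this hunk and the next; the standalone helper name make_streaming_state and the demo values are illustrative, not part of the commit.

import torch

def make_streaming_state(video_info: dict, device: torch.device) -> dict:
    """Sketch: build a SAM2-style inference state without loading any video frames."""
    with torch.inference_mode():
        state = {
            'point_inputs_per_obj': {},
            'mask_inputs_per_obj': {},
            'num_frames': video_info.get('total_frames', 0),
            'video_height': video_info['height'],
            'video_width': video_info['width'],
            'device': device,
            'storage_device': device,   # keep everything on one device
            'offload_video_to_cpu': False,
            'offload_state_to_cpu': False,
            'output_dict_per_obj': {},
            'temp_output_dict_per_obj': {},
            'frames': None,             # frames are supplied by the caller per step
            'images': None,
        }
        # Constants mirroring those the commit adds (values assumed, as in the diff)
        state['constants'] = {
            'image_size': max(video_info['height'], video_info['width']),
            'backbone_stride': 16,
        }
    return state

demo = make_streaming_state({'height': 720, 'width': 1280, 'total_frames': 300},
                            torch.device('cpu'))
print(demo['num_frames'], demo['constants']['image_size'])  # 300 1280
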
@@ -185,13 +151,50 @@ class SAM2StreamingProcessor:
                 'video_height': video_info['height'],
                 'video_width': video_info['width'],
                 'device': self.device,
-                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
-                'offload_video_to_cpu': self.memory_offload,
-                'offload_state_to_cpu': self.memory_offload,
+                'storage_device': self.device,  # Keep everything on GPU
+                'offload_video_to_cpu': False,
+                'offload_state_to_cpu': False,
+                # Add some required SAM2 internal structures
+                'output_dict_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'frames': None,  # We provide frames manually
+                'images': None,  # We provide images manually
             }
+
+            # Initialize some constants that SAM2 expects
+            inference_state['constants'] = {
+                'image_size': max(video_info['height'], video_info['width']),
+                'backbone_stride': 16,  # Standard SAM2 backbone stride
+                'sam_mask_decoder_extra_args': {},
+                'sam_prompt_embed_dim': 256,
+                'sam_image_embedding_size': video_info['height'] // 16,  # Assuming 16x downsampling
+            }
+
+        print(f" Created streaming-compatible state")
         return inference_state
+
+    def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
+        """Move all tensors in state to the specified device"""
+        def move_to_device(obj):
+            if isinstance(obj, torch.Tensor):
+                return obj.to(device)
+            elif isinstance(obj, dict):
+                return {k: move_to_device(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [move_to_device(item) for item in obj]
+            elif isinstance(obj, tuple):
+                return tuple(move_to_device(item) for item in obj)
+            else:
+                return obj
+
+        # Move all state components to device
+        for key, value in state.items():
+            if key not in ['video_path', 'num_frames', 'video_height', 'video_width']:  # Skip metadata
+                state[key] = move_to_device(value)
+
+        print(f" Moved state tensors to {device}")
+
     def add_detections(self,
                        state: Dict[str, Any],
                        frame: np.ndarray,
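
The _move_state_to_device helper added above relies on a recursive walk over nested containers. Below is a self-contained sketch of that pattern, assuming nothing beyond PyTorch; the function name and demo dict are illustrative.

import torch

def move_to_device(obj, device: torch.device):
    """Recursively move every tensor found in nested dicts/lists/tuples to `device`."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(v, device) for v in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(v, device) for v in obj)
    return obj  # non-tensor leaves (ints, strings, None) pass through unchanged

nested = {'a': torch.zeros(2), 'b': [torch.ones(3), {'c': torch.arange(4)}]}
moved = move_to_device(nested, torch.device('cpu'))
print(moved['b'][1]['c'].device)  # cpu
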
@@ -248,16 +251,29 @@ class SAM2StreamingProcessor:
         # Add boxes as prompts for this specific frame
         try:
+            # Force ensure all inputs are on correct device
+            boxes_tensor = boxes_tensor.to(self.device)
+
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
+
+            # Update state with object tracking info
+            state['obj_ids'] = object_ids
+            state['tracking_has_started'] = True
         except Exception as e:
             print(f" Error in add_new_points_or_box: {e}")
             print(f" Box tensor device: {boxes_tensor.device}")
             print(f" Frame tensor device: {frame_tensor.device}")
+            print(f" State device keys: {[k for k in state.keys() if 'device' in k.lower()]}")
+            # Try to inspect state tensor devices
+            for key, value in state.items():
+                if isinstance(value, torch.Tensor):
+                    print(f" State[{key}] device: {value.device}")
             raise

         self.object_ids = object_ids
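
The call above hands detection boxes to SAM2 as a tensor pinned to the processor's device. A hedged sketch of how such a prompt tensor might be prepared, assuming (x1, y1, x2, y2) pixel boxes; the boxes_to_prompt helper is illustrative and not taken from the commit.

import numpy as np
import torch

def boxes_to_prompt(detections, device: torch.device) -> torch.Tensor:
    """Stack (x1, y1, x2, y2) pixel boxes into a float32 tensor on the target device."""
    boxes = np.asarray(detections, dtype=np.float32).reshape(-1, 4)
    return torch.from_numpy(boxes).to(device)

boxes_tensor = boxes_to_prompt([(10, 20, 110, 220), (50, 60, 200, 240)],
                               torch.device('cpu'))
print(boxes_tensor.shape, boxes_tensor.device)  # torch.Size([2, 4]) cpu
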