idk
This commit is contained in:
@@ -130,45 +130,11 @@ class SAM2StreamingProcessor:
|
||||
# Create a streaming-compatible inference state
|
||||
# This mirrors SAM2's internal state structure but without video frames
|
||||
|
||||
# Use SAM2's init_state but with a dummy 1-frame video to avoid loading
|
||||
# We'll override the frame access later
|
||||
try:
|
||||
# Create a minimal dummy video file temporarily
|
||||
import tempfile
|
||||
import cv2
|
||||
# Create streaming-compatible state without loading video
|
||||
# This approach avoids the dummy video complexity
|
||||
|
||||
# Create 1-frame dummy video
|
||||
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
|
||||
dummy_path = tmp_file.name
|
||||
|
||||
# Write a single frame video
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
|
||||
dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
|
||||
out.write(dummy_frame)
|
||||
out.release()
|
||||
|
||||
# Initialize with dummy video (SAM2 will load metadata only from 1 frame)
|
||||
with torch.inference_mode():
|
||||
inference_state = self.predictor.init_state(
|
||||
video_path=dummy_path,
|
||||
offload_video_to_cpu=False, # Keep video frames on GPU for streaming
|
||||
offload_state_to_cpu=False, # Keep state on GPU for performance
|
||||
async_loading_frames=True
|
||||
)
|
||||
|
||||
# Clean up dummy file
|
||||
import os
|
||||
os.unlink(dummy_path)
|
||||
|
||||
# Update state with actual video info
|
||||
inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
|
||||
inference_state['video_height'] = video_info['height']
|
||||
inference_state['video_width'] = video_info['width']
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
|
||||
# Fallback to minimal state
|
||||
with torch.inference_mode():
|
||||
# Initialize minimal state that mimics SAM2's structure
|
||||
inference_state = {
|
||||
'point_inputs_per_obj': {},
|
||||
'mask_inputs_per_obj': {},
|
||||
@@ -185,13 +151,50 @@ class SAM2StreamingProcessor:
|
||||
'video_height': video_info['height'],
|
||||
'video_width': video_info['width'],
|
||||
'device': self.device,
|
||||
'storage_device': torch.device('cpu') if self.memory_offload else self.device,
|
||||
'offload_video_to_cpu': self.memory_offload,
|
||||
'offload_state_to_cpu': self.memory_offload,
|
||||
'storage_device': self.device, # Keep everything on GPU
|
||||
'offload_video_to_cpu': False,
|
||||
'offload_state_to_cpu': False,
|
||||
# Add some required SAM2 internal structures
|
||||
'output_dict_per_obj': {},
|
||||
'temp_output_dict_per_obj': {},
|
||||
'frames': None, # We provide frames manually
|
||||
'images': None, # We provide images manually
|
||||
}
|
||||
|
||||
# Initialize some constants that SAM2 expects
|
||||
inference_state['constants'] = {
|
||||
'image_size': max(video_info['height'], video_info['width']),
|
||||
'backbone_stride': 16, # Standard SAM2 backbone stride
|
||||
'sam_mask_decoder_extra_args': {},
|
||||
'sam_prompt_embed_dim': 256,
|
||||
'sam_image_embedding_size': video_info['height'] // 16, # Assuming 16x downsampling
|
||||
}
|
||||
|
||||
print(f" Created streaming-compatible state")
|
||||
|
||||
return inference_state
|
||||
|
||||
def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
|
||||
"""Move all tensors in state to the specified device"""
|
||||
def move_to_device(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: move_to_device(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [move_to_device(item) for item in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(move_to_device(item) for item in obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
# Move all state components to device
|
||||
for key, value in state.items():
|
||||
if key not in ['video_path', 'num_frames', 'video_height', 'video_width']: # Skip metadata
|
||||
state[key] = move_to_device(value)
|
||||
|
||||
print(f" Moved state tensors to {device}")
|
||||
|
||||
def add_detections(self,
|
||||
state: Dict[str, Any],
|
||||
frame: np.ndarray,
|
||||
@@ -248,16 +251,29 @@ class SAM2StreamingProcessor:
|
||||
|
||||
# Add boxes as prompts for this specific frame
|
||||
try:
|
||||
# Force ensure all inputs are on correct device
|
||||
boxes_tensor = boxes_tensor.to(self.device)
|
||||
|
||||
_, object_ids, masks = self.predictor.add_new_points_or_box(
|
||||
inference_state=state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=None, # Let SAM2 auto-assign
|
||||
box=boxes_tensor
|
||||
)
|
||||
|
||||
# Update state with object tracking info
|
||||
state['obj_ids'] = object_ids
|
||||
state['tracking_has_started'] = True
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error in add_new_points_or_box: {e}")
|
||||
print(f" Box tensor device: {boxes_tensor.device}")
|
||||
print(f" Frame tensor device: {frame_tensor.device}")
|
||||
print(f" State device keys: {[k for k in state.keys() if 'device' in k.lower()]}")
|
||||
# Try to inspect state tensor devices
|
||||
for key, value in state.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
print(f" State[{key}] device: {value.device}")
|
||||
raise
|
||||
|
||||
self.object_ids = object_ids
|
||||
|
||||
Reference in New Issue
Block a user