idk
@@ -130,45 +130,11 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
 
-        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
-        # We'll override the frame access later
-        try:
-            # Create a minimal dummy video file temporarily
-            import tempfile
-            import cv2
+        # Create streaming-compatible state without loading video
+        # This approach avoids the dummy video complexity
 
-            # Create 1-frame dummy video
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
-                dummy_path = tmp_file.name
-
-            # Write a single frame video
-            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
-            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
-            out.write(dummy_frame)
-            out.release()
-
-            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
-            with torch.inference_mode():
-                inference_state = self.predictor.init_state(
-                    video_path=dummy_path,
-                    offload_video_to_cpu=False,  # Keep video frames on GPU for streaming
-                    offload_state_to_cpu=False,  # Keep state on GPU for performance
-                    async_loading_frames=True
-                )
-
-            # Clean up dummy file
-            import os
-            os.unlink(dummy_path)
-
-            # Update state with actual video info
-            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
-            inference_state['video_height'] = video_info['height']
-            inference_state['video_width'] = video_info['width']
-
-        except Exception as e:
-            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
-            # Fallback to minimal state
+        with torch.inference_mode():
+            # Initialize minimal state that mimics SAM2's structure
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
@@ -185,13 +151,50 @@ class SAM2StreamingProcessor:
                 'video_height': video_info['height'],
                 'video_width': video_info['width'],
                 'device': self.device,
-                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
-                'offload_video_to_cpu': self.memory_offload,
-                'offload_state_to_cpu': self.memory_offload,
+                'storage_device': self.device,  # Keep everything on GPU
+                'offload_video_to_cpu': False,
+                'offload_state_to_cpu': False,
+                # Add some required SAM2 internal structures
+                'output_dict_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'frames': None,  # We provide frames manually
+                'images': None,  # We provide images manually
             }
 
+        # Initialize some constants that SAM2 expects
+        inference_state['constants'] = {
+            'image_size': max(video_info['height'], video_info['width']),
+            'backbone_stride': 16,  # Standard SAM2 backbone stride
+            'sam_mask_decoder_extra_args': {},
+            'sam_prompt_embed_dim': 256,
+            'sam_image_embedding_size': video_info['height'] // 16,  # Assuming 16x downsampling
+        }
+
+        print(f" Created streaming-compatible state")
+
         return inference_state
 
+    def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
+        """Move all tensors in state to the specified device"""
+        def move_to_device(obj):
+            if isinstance(obj, torch.Tensor):
+                return obj.to(device)
+            elif isinstance(obj, dict):
+                return {k: move_to_device(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [move_to_device(item) for item in obj]
+            elif isinstance(obj, tuple):
+                return tuple(move_to_device(item) for item in obj)
+            else:
+                return obj
+
+        # Move all state components to device
+        for key, value in state.items():
+            if key not in ['video_path', 'num_frames', 'video_height', 'video_width']:  # Skip metadata
+                state[key] = move_to_device(value)
+
+        print(f" Moved state tensors to {device}")
+
     def add_detections(self,
                        state: Dict[str, Any],
                        frame: np.ndarray,
@@ -248,16 +251,29 @@ class SAM2StreamingProcessor:
 
         # Add boxes as prompts for this specific frame
         try:
+            # Force ensure all inputs are on correct device
+            boxes_tensor = boxes_tensor.to(self.device)
+
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
+
+            # Update state with object tracking info
+            state['obj_ids'] = object_ids
+            state['tracking_has_started'] = True
+
         except Exception as e:
             print(f" Error in add_new_points_or_box: {e}")
             print(f" Box tensor device: {boxes_tensor.device}")
             print(f" Frame tensor device: {frame_tensor.device}")
+            print(f" State device keys: {[k for k in state.keys() if 'device' in k.lower()]}")
+            # Try to inspect state tensor devices
+            for key, value in state.items():
+                if isinstance(value, torch.Tensor):
+                    print(f" State[{key}] device: {value.device}")
             raise
 
         self.object_ids = object_ids
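Note: taken together, the hunks above replace the dummy-video init_state call with a hand-built inference state and add a helper that recursively pushes every tensor nested inside that state onto one device. The standalone sketch below shows how those two pieces combine; the names build_streaming_state and move_to_device, the example video_info values, and the toy prompt tensor are illustrative assumptions, not code from this commit.

import torch

def build_streaming_state(video_info, device):
    # Toy version of the minimal streaming state assembled in the hunks above
    state = {
        'point_inputs_per_obj': {},
        'mask_inputs_per_obj': {},
        'output_dict_per_obj': {},
        'temp_output_dict_per_obj': {},
        'video_height': video_info['height'],
        'video_width': video_info['width'],
        'device': device,
        'storage_device': device,        # keep everything on one device
        'offload_video_to_cpu': False,
        'offload_state_to_cpu': False,
        'frames': None,                  # frames are supplied per step by the caller
        'images': None,
        'constants': {
            'image_size': max(video_info['height'], video_info['width']),
            'backbone_stride': 16,
            'sam_image_embedding_size': video_info['height'] // 16,
        },
    }
    return state

def move_to_device(obj, device):
    # Same recursive shape as _move_state_to_device above, written as a free function
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [move_to_device(v, device) for v in obj]
        return tuple(items) if isinstance(obj, tuple) else items
    return obj

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state = build_streaming_state({'width': 1280, 'height': 720}, device)  # example metadata only
state['point_inputs_per_obj'][0] = torch.zeros(1, 2)                   # toy prompt for object 0
state = {k: move_to_device(v, device) for k, v in state.items()}
print(state['constants'])   # {'image_size': 1280, 'backbone_stride': 16, 'sam_image_embedding_size': 45}
print(state['point_inputs_per_obj'][0].device)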