commit 2e5ded7dbf
parent 3a59e87f3e
Date: 2025-07-27 09:04:40 -07:00


@@ -129,9 +129,45 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
-        with torch.inference_mode():
-            # Initialize empty inference state using SAM2's predictor
-            # We'll manually provide frames via propagate calls
+        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
+        # We'll override the frame access later
+        try:
+            # Create a minimal dummy video file temporarily
+            import tempfile
+            import cv2
+
+            # Create 1-frame dummy video
+            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+                dummy_path = tmp_file.name
+
+            # Write a single frame video
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
+            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
+            out.write(dummy_frame)
+            out.release()
+
+            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
+            with torch.inference_mode():
+                inference_state = self.predictor.init_state(
+                    video_path=dummy_path,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload,
+                    async_loading_frames=True
+                )
+
+            # Clean up dummy file
+            import os
+            os.unlink(dummy_path)
+
+            # Update state with actual video info
+            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
+            inference_state['video_height'] = video_info['height']
+            inference_state['video_width'] = video_info['width']
+        except Exception as e:
+            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
+            # Fallback to minimal state
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
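
The dummy-video trick above can be reproduced in isolation. The sketch below is a minimal, hedged version of it: make_dummy_video is a hypothetical helper (not part of this commit) that writes a single black frame at the target resolution with OpenCV, which is all init_state needs in order to read video metadata; the actual predictor.init_state call is left as a comment because it assumes a loaded SAM2 video predictor.

import os
import tempfile

import cv2
import numpy as np


def make_dummy_video(width: int, height: int, fps: float = 1.0) -> str:
    """Write one black frame to a temporary .mp4 and return its path."""
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        dummy_path = tmp_file.name
    writer = cv2.VideoWriter(dummy_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    writer.write(np.zeros((height, width, 3), dtype=np.uint8))
    writer.release()
    return dummy_path


# Usage sketch: hand the path to SAM2, then remove the temp file.
path = make_dummy_video(1280, 720)
# inference_state = predictor.init_state(video_path=path, ...)  # predictor assumed to exist
os.unlink(path)
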
@@ -151,13 +187,8 @@ class SAM2StreamingProcessor:
                 'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                 'offload_video_to_cpu': self.memory_offload,
                 'offload_state_to_cpu': self.memory_offload,
-                'inference_state': {},
             }
-
-            # Initialize SAM2 constants that don't depend on video frames
-            self.predictor._get_image_feature_cache = {}
-            self.predictor._feature_bank = {}

         return inference_state

     def add_detections(self,
@@ -198,16 +229,16 @@ class SAM2StreamingProcessor:
         # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)

             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out

             # Add boxes as prompts for this specific frame
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
-                obj_id=0,  # SAM2 will auto-increment
+                obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
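
The call above passes obj_id=None with a batch of boxes in one shot. SAM2's own video-predictor examples instead prompt once per object, with an explicit integer obj_id and a single XYXY box; if auto-assignment does not behave as hoped, a per-object loop along these lines is the conventional fallback. This is a hedged sketch: predictor, state, and frame_idx are assumed from the surrounding code, and detections is a hypothetical list of pixel-space boxes.

import numpy as np

detections = [(40.0, 60.0, 200.0, 320.0), (300.0, 80.0, 420.0, 260.0)]  # hypothetical XYXY boxes

for obj_id, (x1, y1, x2, y2) in enumerate(detections):
    box = np.array([x1, y1, x2, y2], dtype=np.float32)
    # One prompt per object, with an explicit integer id used for later tracking.
    _, object_ids, mask_logits = predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=frame_idx,
        obj_id=obj_id,
        box=box,
    )
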
@@ -239,33 +270,41 @@ class SAM2StreamingProcessor:
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)

             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out

-            # Get masks for current frame by propagating from previous frames
-            masks = []
-            for obj_id in state.get('obj_ids', []):
-                # Use SAM2's mask propagation for this object
-                try:
-                    obj_mask = self.predictor._propagate_single_object(
-                        state, obj_id, frame_idx, features
-                    )
-                    if obj_mask is not None:
-                        masks.append(obj_mask)
-                except Exception as e:
-                    # If propagation fails, use empty mask
-                    print(f" Warning: Propagation failed for object {obj_id}: {e}")
-                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
-                    masks.append(empty_mask)
-
-            # Combine all object masks
-            if masks:
-                combined_mask = torch.stack(masks).max(dim=0)[0]
-                # Convert to numpy
-                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
-            else:
-                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+            # Use SAM2's single frame inference for propagation
+            try:
+                # Run single frame inference for all tracked objects
+                output_dict = {}
+                self.predictor._run_single_frame_inference(
+                    inference_state=state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=1,
+                    is_init_cond_frame=False,  # Not initialization frame
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=False,
+                    run_mem_encoder=True
+                )
+
+                # Extract masks from output
+                if output_dict and 'pred_masks' in output_dict:
+                    pred_masks = output_dict['pred_masks']
+                    # Combine all object masks
+                    if pred_masks.shape[0] > 0:
+                        combined_mask = pred_masks.max(dim=0)[0]
+                        combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
+                    else:
+                        combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+                else:
+                    combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+            except Exception as e:
+                print(f" Warning: Single frame inference failed: {e}")
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

             # Cleanup old features to prevent memory accumulation
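
The mask-combination step introduced above reduces per-object mask logits to one binary frame mask: take the element-wise max across objects, threshold the logits at 0.0, and scale to a 0/255 uint8 image. A self-contained sketch with hypothetical shapes (in the commit the tensor comes from output_dict['pred_masks']):

import numpy as np
import torch

# Hypothetical per-object mask logits shaped (num_objects, H, W);
# positive logits mean foreground, matching the > 0.0 threshold above.
pred_masks = torch.randn(3, 480, 640)

if pred_masks.shape[0] > 0:
    combined = pred_masks.max(dim=0)[0]                       # union across objects
    mask_np = (combined > 0.0).cpu().numpy().astype(np.uint8) * 255
else:
    mask_np = np.zeros((480, 640), dtype=np.uint8)            # no tracked objects

print(mask_np.shape, mask_np.dtype, mask_np.max())
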