This commit is contained in:
2025-07-27 09:04:40 -07:00
parent 3a59e87f3e
commit 2e5ded7dbf

View File

@@ -129,9 +129,45 @@ class SAM2StreamingProcessor:
# Create a streaming-compatible inference state
# This mirrors SAM2's internal state structure but without video frames
with torch.inference_mode():
# Initialize empty inference state using SAM2's predictor
# We'll manually provide frames via propagate calls
# Use SAM2's init_state but with a dummy 1-frame video to avoid loading
# We'll override the frame access later
try:
# Create a minimal dummy video file temporarily
import tempfile
import cv2
# Create 1-frame dummy video
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
dummy_path = tmp_file.name
# Write a single frame video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
out.write(dummy_frame)
out.release()
# Initialize with dummy video (SAM2 will load metadata only from 1 frame)
with torch.inference_mode():
inference_state = self.predictor.init_state(
video_path=dummy_path,
offload_video_to_cpu=self.memory_offload,
offload_state_to_cpu=self.memory_offload,
async_loading_frames=True
)
# Clean up dummy file
import os
os.unlink(dummy_path)
# Update state with actual video info
inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
inference_state['video_height'] = video_info['height']
inference_state['video_width'] = video_info['width']
except Exception as e:
print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
# Fallback to minimal state
inference_state = {
'point_inputs_per_obj': {},
'mask_inputs_per_obj': {},
@@ -151,13 +187,8 @@ class SAM2StreamingProcessor:
'storage_device': torch.device('cpu') if self.memory_offload else self.device,
'offload_video_to_cpu': self.memory_offload,
'offload_state_to_cpu': self.memory_offload,
'inference_state': {},
}
# Initialize SAM2 constants that don't depend on video frames
self.predictor._get_image_feature_cache = {}
self.predictor._feature_bank = {}
return inference_state
def add_detections(self,
@@ -198,16 +229,16 @@ class SAM2StreamingProcessor:
# Manually process frame and add prompts (streaming approach)
with torch.inference_mode():
# Process frame through SAM2's image encoder
features = self.predictor._get_image_features(frame_tensor)
backbone_out = self.predictor.forward_image(frame_tensor)
# Store features in state for this frame
state['cached_features'][frame_idx] = features
state['cached_features'][frame_idx] = backbone_out
# Add boxes as prompts for this specific frame
_, object_ids, masks = self.predictor.add_new_points_or_box(
inference_state=state,
frame_idx=frame_idx,
obj_id=0, # SAM2 will auto-increment
obj_id=None, # Let SAM2 auto-assign
box=boxes_tensor
)
@@ -239,33 +270,41 @@ class SAM2StreamingProcessor:
with torch.inference_mode():
# Process frame through SAM2's image encoder
features = self.predictor._get_image_features(frame_tensor)
backbone_out = self.predictor.forward_image(frame_tensor)
# Store features in state for this frame
state['cached_features'][frame_idx] = features
state['cached_features'][frame_idx] = backbone_out
# Get masks for current frame by propagating from previous frames
masks = []
for obj_id in state.get('obj_ids', []):
# Use SAM2's mask propagation for this object
try:
obj_mask = self.predictor._propagate_single_object(
state, obj_id, frame_idx, features
)
if obj_mask is not None:
masks.append(obj_mask)
except Exception as e:
# If propagation fails, use empty mask
print(f" Warning: Propagation failed for object {obj_id}: {e}")
empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
masks.append(empty_mask)
# Use SAM2's single frame inference for propagation
try:
# Run single frame inference for all tracked objects
output_dict = {}
self.predictor._run_single_frame_inference(
inference_state=state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=1,
is_init_cond_frame=False, # Not initialization frame
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True
)
# Combine all object masks
if masks:
combined_mask = torch.stack(masks).max(dim=0)[0]
# Convert to numpy
combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
else:
# Extract masks from output
if output_dict and 'pred_masks' in output_dict:
pred_masks = output_dict['pred_masks']
# Combine all object masks
if pred_masks.shape[0] > 0:
combined_mask = pred_masks.max(dim=0)[0]
combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
except Exception as e:
print(f" Warning: Single frame inference failed: {e}")
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
# Cleanup old features to prevent memory accumulation