fix api
@@ -129,9 +129,45 @@ class SAM2StreamingProcessor:
        # Create a streaming-compatible inference state
        # This mirrors SAM2's internal state structure but without video frames

        with torch.inference_mode():
            # Initialize empty inference state using SAM2's predictor
            # We'll manually provide frames via propagate calls
            # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
            # We'll override the frame access later
            try:
                # Create a minimal dummy video file temporarily
                import tempfile
                import cv2

                # Create 1-frame dummy video
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                    dummy_path = tmp_file.name

                # Write a single-frame video
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
                dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
                out.write(dummy_frame)
                out.release()

                # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
                with torch.inference_mode():
                    inference_state = self.predictor.init_state(
                        video_path=dummy_path,
                        offload_video_to_cpu=self.memory_offload,
                        offload_state_to_cpu=self.memory_offload,
                        async_loading_frames=True
                    )

                # Clean up dummy file
                import os
                os.unlink(dummy_path)

                # Update state with actual video info
                inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
                inference_state['video_height'] = video_info['height']
                inference_state['video_width'] = video_info['width']

            except Exception as e:
                print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
                # Fallback to minimal state
                inference_state = {
                    'point_inputs_per_obj': {},
                    'mask_inputs_per_obj': {},
@@ -151,13 +187,8 @@ class SAM2StreamingProcessor:
                    'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                    'offload_video_to_cpu': self.memory_offload,
                    'offload_state_to_cpu': self.memory_offload,
                    'inference_state': {},
                }

        # Initialize SAM2 constants that don't depend on video frames
        self.predictor._get_image_feature_cache = {}
        self.predictor._feature_bank = {}

        return inference_state

    def add_detections(self,
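For context, here is a standalone sketch of the dummy-video initialization used above. It assumes the upstream sam2 package (build_sam2_video_predictor / init_state); the config name, checkpoint path, and video_info values are placeholders, and whether init_state accepts an .mp4 path (rather than a directory of JPEG frames) depends on the installed sam2 version.

import os
import tempfile

import cv2
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholder metadata describing the real (full-length) video
video_info = {'width': 1280, 'height': 720, 'total_frames': 900}

# Placeholder config/checkpoint names -- substitute the ones you actually use
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "checkpoints/sam2.1_hiera_large.pt",
)

# Write a one-frame dummy clip so init_state() has something to decode
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
    dummy_path = tmp_file.name
writer = cv2.VideoWriter(dummy_path, cv2.VideoWriter_fourcc(*'mp4v'), 1.0,
                         (video_info['width'], video_info['height']))
writer.write(np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8))
writer.release()

with torch.inference_mode():
    state = predictor.init_state(
        video_path=dummy_path,
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
    )
os.unlink(dummy_path)

# Patch the metadata so downstream code sees the real video's dimensions
state['num_frames'] = video_info['total_frames']
state['video_height'] = video_info['height']
state['video_width'] = video_info['width']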
@@ -198,16 +229,16 @@ class SAM2StreamingProcessor:
        # Manually process frame and add prompts (streaming approach)
        with torch.inference_mode():
            # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)

            # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out

            # Add boxes as prompts for this specific frame
            _, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=state,
                frame_idx=frame_idx,
-                obj_id=0,  # SAM2 will auto-increment
+                obj_id=None,  # Let SAM2 auto-assign
                box=boxes_tensor
            )
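A hedged usage sketch of the prompt-registration call above: in the upstream SAM2VideoPredictor, add_new_points_or_box() registers one object per call under a caller-chosen integer obj_id and returns mask logits at video resolution, so multiple detection boxes are normally added in a loop. Here `predictor` and `state` are the objects from the initialization sketch earlier, and the box coordinates are made up.

import numpy as np
import torch

frame_idx = 0  # frame the detections belong to
boxes = np.array([[100., 150., 300., 400.],
                  [500., 120., 640., 380.]], dtype=np.float32)  # x1, y1, x2, y2

with torch.inference_mode():
    for obj_id, box in enumerate(boxes):
        # One call per object; obj_id is an explicit integer chosen by the caller
        _, object_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=frame_idx,
            obj_id=obj_id,
            box=box,
        )

# mask_logits holds one logit map per registered object; threshold at 0
# to get boolean masks, matching the (mask > 0.0) convention used below.
binary_masks = (mask_logits > 0.0).cpu().numpy()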
@@ -239,33 +270,41 @@ class SAM2StreamingProcessor:
        with torch.inference_mode():
            # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)

            # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out

-            # Get masks for current frame by propagating from previous frames
-            masks = []
-            for obj_id in state.get('obj_ids', []):
-                # Use SAM2's mask propagation for this object
-                try:
-                    obj_mask = self.predictor._propagate_single_object(
-                        state, obj_id, frame_idx, features
-                    )
-                    if obj_mask is not None:
-                        masks.append(obj_mask)
-                except Exception as e:
-                    # If propagation fails, use empty mask
-                    print(f" Warning: Propagation failed for object {obj_id}: {e}")
-                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
-                    masks.append(empty_mask)
-
-            # Combine all object masks
-            if masks:
-                combined_mask = torch.stack(masks).max(dim=0)[0]
-                # Convert to numpy
-                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
-            else:
+            # Use SAM2's single frame inference for propagation
+            try:
+                # Run single frame inference for all tracked objects
+                output_dict = {}
+                self.predictor._run_single_frame_inference(
+                    inference_state=state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=1,
+                    is_init_cond_frame=False,  # Not initialization frame
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=False,
+                    run_mem_encoder=True
+                )
+
+                # Extract masks from output
+                if output_dict and 'pred_masks' in output_dict:
+                    pred_masks = output_dict['pred_masks']
+                    # Combine all object masks
+                    if pred_masks.shape[0] > 0:
+                        combined_mask = pred_masks.max(dim=0)[0]
+                        combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
+                    else:
+                        combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+                else:
+                    combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+            except Exception as e:
+                print(f" Warning: Single frame inference failed: {e}")
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
            # Cleanup old features to prevent memory accumulation
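The diff ends at this cleanup comment, so the pruning itself is not shown. A minimal sketch of one way to do it, assuming state['cached_features'] is a plain dict keyed by frame index (as in the code above) and that only a small trailing window of frames is needed for propagation; the window size is an assumption.

MAX_CACHED_FRAMES = 8  # assumed window size

cached = state['cached_features']
if len(cached) > MAX_CACHED_FRAMES:
    # Drop everything except the most recent MAX_CACHED_FRAMES entries
    for old_idx in sorted(cached)[:-MAX_CACHED_FRAMES]:
        del cached[old_idx]

# Optionally release the freed GPU blocks immediately
torch.cuda.empty_cache()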