fix api
@@ -129,9 +129,45 @@ class SAM2StreamingProcessor:
         # Create a streaming-compatible inference state
         # This mirrors SAM2's internal state structure but without video frames
 
+        # Use SAM2's init_state but with a dummy 1-frame video to avoid loading
+        # We'll override the frame access later
+        try:
+            # Create a minimal dummy video file temporarily
+            import tempfile
+            import cv2
+
+            # Create 1-frame dummy video
+            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
+                dummy_path = tmp_file.name
+
+            # Write a single frame video
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
+            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
+            out.write(dummy_frame)
+            out.release()
+
+            # Initialize with dummy video (SAM2 will load metadata only from 1 frame)
             with torch.inference_mode():
-                # Initialize empty inference state using SAM2's predictor
-                # We'll manually provide frames via propagate calls
+                inference_state = self.predictor.init_state(
+                    video_path=dummy_path,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload,
+                    async_loading_frames=True
+                )
+
+            # Clean up dummy file
+            import os
+            os.unlink(dummy_path)
+
+            # Update state with actual video info
+            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
+            inference_state['video_height'] = video_info['height']
+            inference_state['video_width'] = video_info['width']
+
+        except Exception as e:
+            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
+            # Fallback to minimal state
             inference_state = {
                 'point_inputs_per_obj': {},
                 'mask_inputs_per_obj': {},
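Note: the bootstrap trick added above can be read as the standalone sketch below. The helper name bootstrap_streaming_state is illustrative only; it assumes video_info carries 'width', 'height' and a frame count, and that the predictor exposes SAM2's init_state(video_path=...) API exactly as used in the hunk. One difference from the hunk: wrapping init_state in try/finally guarantees the temporary file is removed even if initialization fails.

import os
import tempfile

import cv2
import numpy as np
import torch


def bootstrap_streaming_state(predictor, video_info, memory_offload=False):
    """Build a SAM2 inference state without decoding the real video."""
    # Write a single black frame to a throwaway .mp4 so init_state() only has
    # to read one frame's worth of data for its metadata/setup work.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        dummy_path = tmp.name
    writer = cv2.VideoWriter(
        dummy_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        1.0,
        (video_info['width'], video_info['height']),
    )
    writer.write(np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8))
    writer.release()

    try:
        with torch.inference_mode():
            state = predictor.init_state(
                video_path=dummy_path,
                offload_video_to_cpu=memory_offload,
                offload_state_to_cpu=memory_offload,
                async_loading_frames=True,
            )
    finally:
        os.unlink(dummy_path)  # remove the temp file even if init_state() raises

    # Patch in the real video's metadata; frames are fed in later, per step.
    state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
    state['video_height'] = video_info['height']
    state['video_width'] = video_info['width']
    return state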
@@ -151,13 +187,8 @@ class SAM2StreamingProcessor:
                 'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                 'offload_video_to_cpu': self.memory_offload,
                 'offload_state_to_cpu': self.memory_offload,
-                'inference_state': {},
             }
 
-        # Initialize SAM2 constants that don't depend on video frames
-        self.predictor._get_image_feature_cache = {}
-        self.predictor._feature_bank = {}
-
         return inference_state
 
     def add_detections(self,
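Taken together with the first hunk, the fallback state now looks roughly like the sketch below. Only the keys visible in this diff are shown; the keys elided between the two hunks are left as a placeholder rather than guessed, and the helper name is illustrative.

import torch

def minimal_fallback_state(device, memory_offload=False):
    # Only the keys that appear in this diff; the real state has more fields
    # (elided between the hunks) that are not reproduced here.
    return {
        'point_inputs_per_obj': {},
        'mask_inputs_per_obj': {},
        # ... additional bookkeeping keys not shown in this diff ...
        'storage_device': torch.device('cpu') if memory_offload else device,
        'offload_video_to_cpu': memory_offload,
        'offload_state_to_cpu': memory_offload,
    }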
@@ -198,16 +229,16 @@ class SAM2StreamingProcessor:
         # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)
 
             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out
 
             # Add boxes as prompts for this specific frame
             _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
-                obj_id=0,  # SAM2 will auto-increment
+                obj_id=None,  # Let SAM2 auto-assign
                 box=boxes_tensor
             )
 
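Reviewer note: in the upstream SAM2 video predictor, add_new_points_or_box is normally called once per object with an explicit obj_id and a single box; whether obj_id=None auto-assigns ids here depends on this repo's predictor subclass. A conventional upstream-style sketch for comparison (the helper name and looping are illustrative only):

import numpy as np

def add_boxes_as_objects(predictor, state, frame_idx, boxes_xyxy):
    # boxes_xyxy: iterable of (x_min, y_min, x_max, y_max) in pixel coordinates
    object_ids = []
    for obj_id, box in enumerate(boxes_xyxy):
        _, object_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=frame_idx,
            obj_id=obj_id,  # explicit id per tracked object
            box=np.asarray(box, dtype=np.float32),
        )
    return object_ids  # ids known to the predictor after the last call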
@@ -239,34 +270,42 @@ class SAM2StreamingProcessor:
 
         with torch.inference_mode():
             # Process frame through SAM2's image encoder
-            features = self.predictor._get_image_features(frame_tensor)
+            backbone_out = self.predictor.forward_image(frame_tensor)
 
             # Store features in state for this frame
-            state['cached_features'][frame_idx] = features
+            state['cached_features'][frame_idx] = backbone_out
 
-            # Get masks for current frame by propagating from previous frames
-            masks = []
-            for obj_id in state.get('obj_ids', []):
-                # Use SAM2's mask propagation for this object
+            # Use SAM2's single frame inference for propagation
             try:
-                obj_mask = self.predictor._propagate_single_object(
-                    state, obj_id, frame_idx, features
+                # Run single frame inference for all tracked objects
+                output_dict = {}
+                self.predictor._run_single_frame_inference(
+                    inference_state=state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=1,
+                    is_init_cond_frame=False,  # Not initialization frame
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=False,
+                    run_mem_encoder=True
                 )
-                if obj_mask is not None:
-                    masks.append(obj_mask)
-            except Exception as e:
-                # If propagation fails, use empty mask
-                print(f" Warning: Propagation failed for object {obj_id}: {e}")
-                empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
-                masks.append(empty_mask)
 
+                # Extract masks from output
+                if output_dict and 'pred_masks' in output_dict:
+                    pred_masks = output_dict['pred_masks']
                     # Combine all object masks
-                    if masks:
-                        combined_mask = torch.stack(masks).max(dim=0)[0]
-                        # Convert to numpy
-                        combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
+                    if pred_masks.shape[0] > 0:
+                        combined_mask = pred_masks.max(dim=0)[0]
+                        combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
                     else:
                         combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+                else:
+                    combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+            except Exception as e:
+                print(f" Warning: Single frame inference failed: {e}")
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
 
         # Cleanup old features to prevent memory accumulation
         self._cleanup_old_features(state, frame_idx, keep_frames=10)
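The mask post-processing added above reduces per-object logits with a max, thresholds at 0, and scales to a 0/255 uint8 mask. A small standalone sketch of just that step, assuming pred_masks is a (num_objects, H, W) tensor of logits (the exact shape produced by _run_single_frame_inference depends on the SAM2 version in use; the helper name is illustrative):

import numpy as np
import torch

def combine_object_masks(pred_masks, frame_hw):
    # pred_masks: (num_objects, H, W) mask logits, or None when nothing is tracked
    if pred_masks is None or pred_masks.shape[0] == 0:
        return np.zeros(frame_hw, dtype=np.uint8)
    combined = pred_masks.max(dim=0)[0]  # union across objects
    return (combined > 0.0).cpu().numpy().astype(np.uint8) * 255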