fix streaming
This commit is contained in:
@@ -42,12 +42,16 @@ class SAM2StreamingProcessor:
|
||||
|
||||
model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)
|
||||
|
||||
# Build predictor (simple, clean approach)
|
||||
# Build predictor (disable compilation to fix CUDA graph issues)
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
model_cfg, # Relative path from sam2 package
|
||||
checkpoint,
|
||||
device=self.device,
|
||||
vos_optimized=True # Enable VOS optimizations for speed
|
||||
vos_optimized=False, # Disable to avoid CUDA graph issues
|
||||
hydra_overrides_extra=[
|
||||
"++model.compile_image_encoder=false", # Disable compilation
|
||||
"++model.memory_attention.use_amp=false", # Disable AMP for stability
|
||||
]
|
||||
)
|
||||
|
||||
# Frame buffer for streaming (like det-sam2)
|
||||
@@ -95,8 +99,12 @@ class SAM2StreamingProcessor:
|
||||
if len(self.frame_buffer) >= self.frame_buffer_size or detections:
|
||||
return self._process_buffer()
|
||||
else:
|
||||
# Return empty mask if no processing yet
|
||||
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
# For frames without detections, still try to propagate if we have existing objects
|
||||
if self.inference_state is not None and self.object_ids:
|
||||
return self._propagate_existing_objects()
|
||||
else:
|
||||
# Return empty mask if no processing yet
|
||||
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
|
||||
def _process_buffer(self) -> np.ndarray:
|
||||
"""Process current frame buffer (adapted det-sam2 approach)"""
|
||||
@@ -219,6 +227,67 @@ class SAM2StreamingProcessor:
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
def _propagate_existing_objects(self) -> np.ndarray:
|
||||
"""Propagate existing objects without adding new detections"""
|
||||
if not self.object_ids or not self.frame_buffer:
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape if self.frame_buffer else (480, 640)
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
try:
|
||||
# Update temp frames with current buffer
|
||||
self._create_temp_frames()
|
||||
|
||||
# Reinitialize state (since we can't incrementally update)
|
||||
self.inference_state = self.predictor.init_state(
|
||||
video_path=self.temp_dir,
|
||||
offload_video_to_cpu=self.memory_offload,
|
||||
offload_state_to_cpu=self.memory_offload
|
||||
)
|
||||
|
||||
# Re-add all previous detections from buffer
|
||||
for buffer_idx, buffer_item in enumerate(self.frame_buffer):
|
||||
detections = buffer_item.get('detections', [])
|
||||
if detections: # Only add frames that had detections
|
||||
for det_idx, detection in enumerate(detections):
|
||||
box = detection['box']
|
||||
try:
|
||||
self.predictor.add_new_points_or_box(
|
||||
inference_state=self.inference_state,
|
||||
frame_idx=buffer_idx,
|
||||
obj_id=det_idx,
|
||||
box=np.array(box, dtype=np.float32)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to re-add detection: {e}")
|
||||
|
||||
# Get masks for latest frame
|
||||
latest_frame_idx = len(self.frame_buffer) - 1
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
|
||||
self.inference_state,
|
||||
start_frame_idx=latest_frame_idx,
|
||||
max_frame_num_to_track=1,
|
||||
reverse=False
|
||||
):
|
||||
if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
|
||||
combined_mask = None
|
||||
for mask_logit in out_mask_logits:
|
||||
mask = (mask_logit > 0.0).cpu().numpy()
|
||||
if combined_mask is None:
|
||||
combined_mask = mask.astype(bool)
|
||||
else:
|
||||
combined_mask = combined_mask | mask.astype(bool)
|
||||
|
||||
return (combined_mask * 255).astype(np.uint8)
|
||||
|
||||
# If no masks, return empty
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Object propagation failed: {e}")
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Mask propagation failed: {e}")
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
|
||||
Reference in New Issue
Block a user