fix streaming: disable compilation (CUDA graph issues), propagate existing objects on detection-less frames

2025-07-27 10:16:39 -07:00
parent 9cc755b5c7
commit 1e9c42adbd


@@ -42,12 +42,16 @@ class SAM2StreamingProcessor:
         model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)

-        # Build predictor (simple, clean approach)
+        # Build predictor (disable compilation to fix CUDA graph issues)
         self.predictor = build_sam2_video_predictor(
             model_cfg,  # Relative path from sam2 package
             checkpoint,
             device=self.device,
-            vos_optimized=True  # Enable VOS optimizations for speed
+            vos_optimized=False,  # Disable to avoid CUDA graph issues
+            hydra_overrides_extra=[
+                "++model.compile_image_encoder=false",  # Disable compilation
+                "++model.memory_attention.use_amp=false",  # Disable AMP for stability
+            ]
         )

         # Frame buffer for streaming (like det-sam2)
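Note: the two hydra_overrides_extra entries use Hydra's "++" prefix, which forces a key to be set whether or not it already exists in the composed config. A minimal sketch of what the first override amounts to, using omegaconf directly (the toy config below is an illustration; the real SAM2 model config has many more keys):

    from omegaconf import OmegaConf

    # Stand-in for the composed SAM2 model config (illustration only)
    cfg = OmegaConf.create({"model": {"compile_image_encoder": True}})

    # "++model.compile_image_encoder=false" behaves roughly like:
    OmegaConf.update(cfg, "model.compile_image_encoder", False, force_add=True)

    print(cfg.model.compile_image_encoder)  # False -> the torch.compile path is skipped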
@@ -95,8 +99,12 @@ class SAM2StreamingProcessor:
         if len(self.frame_buffer) >= self.frame_buffer_size or detections:
             return self._process_buffer()
         else:
-            # Return empty mask if no processing yet
-            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+            # For frames without detections, still try to propagate if we have existing objects
+            if self.inference_state is not None and self.object_ids:
+                return self._propagate_existing_objects()
+            else:
+                # Return empty mask if no processing yet
+                return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

     def _process_buffer(self) -> np.ndarray:
         """Process current frame buffer (adapted det-sam2 approach)"""
@@ -219,6 +227,67 @@ class SAM2StreamingProcessor:
             frame_shape = self.frame_buffer[-1]['frame'].shape
             return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

+    def _propagate_existing_objects(self) -> np.ndarray:
+        """Propagate existing objects without adding new detections"""
+        if not self.object_ids or not self.frame_buffer:
+            frame_shape = self.frame_buffer[-1]['frame'].shape if self.frame_buffer else (480, 640)
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        try:
+            # Update temp frames with current buffer
+            self._create_temp_frames()
+
+            # Reinitialize state (since we can't incrementally update)
+            self.inference_state = self.predictor.init_state(
+                video_path=self.temp_dir,
+                offload_video_to_cpu=self.memory_offload,
+                offload_state_to_cpu=self.memory_offload
+            )
+
+            # Re-add all previous detections from buffer
+            for buffer_idx, buffer_item in enumerate(self.frame_buffer):
+                detections = buffer_item.get('detections', [])
+                if detections:  # Only add frames that had detections
+                    for det_idx, detection in enumerate(detections):
+                        box = detection['box']
+                        try:
+                            self.predictor.add_new_points_or_box(
+                                inference_state=self.inference_state,
+                                frame_idx=buffer_idx,
+                                obj_id=det_idx,
+                                box=np.array(box, dtype=np.float32)
+                            )
+                        except Exception as e:
+                            print(f" Warning: Failed to re-add detection: {e}")
+
+            # Get masks for latest frame
+            latest_frame_idx = len(self.frame_buffer) - 1
+            try:
+                for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                    self.inference_state,
+                    start_frame_idx=latest_frame_idx,
+                    max_frame_num_to_track=1,
+                    reverse=False
+                ):
+                    if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
+                        # OR all per-object masks into one binary mask
+                        combined_mask = None
+                        for mask_logit in out_mask_logits:
+                            mask = (mask_logit > 0.0).cpu().numpy()
+                            if combined_mask is None:
+                                combined_mask = mask.astype(bool)
+                            else:
+                                combined_mask = combined_mask | mask.astype(bool)
+                        return (combined_mask * 255).astype(np.uint8)
+
+                # If no masks, return empty
+                frame_shape = self.frame_buffer[-1]['frame'].shape
+                return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+            except Exception as e:
+                print(f" Warning: Object propagation failed: {e}")
+                frame_shape = self.frame_buffer[-1]['frame'].shape
+                return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+        except Exception as e:
+            print(f" Warning: Mask propagation failed: {e}")
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
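The merge loop above ORs the per-object masks into one uint8 mask. In the upstream SAM2 video-predictor examples each entry of out_mask_logits is a (1, H, W) tensor; assuming the same shape here (an assumption, not verified against this repo), a standalone version of the merge would squeeze that leading axis so the returned mask stays 2-D:

    import numpy as np
    import torch

    def combine_object_masks(mask_logits, frame_hw):
        """OR per-object mask logits into a single HxW uint8 mask (0 or 255)."""
        combined = np.zeros(frame_hw, dtype=bool)
        for logit in mask_logits:  # assumed shape: (1, H, W)
            combined |= (logit > 0.0).squeeze(0).cpu().numpy().astype(bool)
        return combined.astype(np.uint8) * 255

    # Toy check: two "objects" on a 4x4 frame, one all-negative, one all-positive
    fake_logits = [torch.full((1, 4, 4), -1.0), torch.full((1, 4, 4), 1.0)]
    print(combine_object_masks(fake_logits, (4, 4)))  # 4x4 array of 255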