fix streaming
This commit is contained in:
@@ -42,12 +42,16 @@ class SAM2StreamingProcessor:
|
|||||||
|
|
||||||
model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)
|
model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)
|
||||||
|
|
||||||
# Build predictor (simple, clean approach)
|
# Build predictor (disable compilation to fix CUDA graph issues)
|
||||||
self.predictor = build_sam2_video_predictor(
|
self.predictor = build_sam2_video_predictor(
|
||||||
model_cfg, # Relative path from sam2 package
|
model_cfg, # Relative path from sam2 package
|
||||||
checkpoint,
|
checkpoint,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
vos_optimized=True # Enable VOS optimizations for speed
|
vos_optimized=False, # Disable to avoid CUDA graph issues
|
||||||
|
hydra_overrides_extra=[
|
||||||
|
"++model.compile_image_encoder=false", # Disable compilation
|
||||||
|
"++model.memory_attention.use_amp=false", # Disable AMP for stability
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Frame buffer for streaming (like det-sam2)
|
# Frame buffer for streaming (like det-sam2)
|
||||||
@@ -94,6 +98,10 @@ class SAM2StreamingProcessor:
|
|||||||
# Process when buffer is full or when we have detections
|
# Process when buffer is full or when we have detections
|
||||||
if len(self.frame_buffer) >= self.frame_buffer_size or detections:
|
if len(self.frame_buffer) >= self.frame_buffer_size or detections:
|
||||||
return self._process_buffer()
|
return self._process_buffer()
|
||||||
|
else:
|
||||||
|
# For frames without detections, still try to propagate if we have existing objects
|
||||||
|
if self.inference_state is not None and self.object_ids:
|
||||||
|
return self._propagate_existing_objects()
|
||||||
else:
|
else:
|
||||||
# Return empty mask if no processing yet
|
# Return empty mask if no processing yet
|
||||||
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||||
@@ -219,6 +227,67 @@ class SAM2StreamingProcessor:
|
|||||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||||
|
|
||||||
|
def _propagate_existing_objects(self) -> np.ndarray:
    """Propagate already-tracked objects to the newest buffered frame.

    Used for frames that arrive without fresh detections. The predictor
    state is rebuilt from the current frame buffer, every detection box
    still stored in the buffer is re-registered as a prompt, and the
    predictor is asked for the masks of the last buffered frame only.

    Returns:
        A 2-D ``uint8`` mask holding 255 wherever any object's mask logit
        is positive and 0 elsewhere. An all-zero mask sized like the last
        buffered frame is returned when nothing is tracked, when the
        predictor yields no masks, or when propagation raises; if the
        buffer is empty the fallback size is 480x640.
    """

    def _empty_mask() -> np.ndarray:
        # Size the mask from the newest buffered frame; fall back to
        # 480x640 only when there is no frame to take a shape from.
        if self.frame_buffer:
            shape = self.frame_buffer[-1]['frame'].shape
        else:
            shape = (480, 640)
        return np.zeros((shape[0], shape[1]), dtype=np.uint8)

    if not self.object_ids or not self.frame_buffer:
        return _empty_mask()

    try:
        # Update temp frames with current buffer
        self._create_temp_frames()

        # Reinitialize state (since we can't incrementally update)
        self.inference_state = self.predictor.init_state(
            video_path=self.temp_dir,
            offload_video_to_cpu=self.memory_offload,
            offload_state_to_cpu=self.memory_offload,
        )

        # Re-add all previous detections from buffer; frames without
        # detections (missing key, None, or empty list) contribute nothing.
        for buffer_idx, buffer_item in enumerate(self.frame_buffer):
            detections = buffer_item.get('detections') or []
            for det_idx, detection in enumerate(detections):
                box = detection['box']
                # NOTE(review): obj_id is the detection's index within its
                # own frame, so index 0 in two different frames is treated
                # as the same object -- confirm this is intended.
                try:
                    self.predictor.add_new_points_or_box(
                        inference_state=self.inference_state,
                        frame_idx=buffer_idx,
                        obj_id=det_idx,
                        box=np.array(box, dtype=np.float32),
                    )
                except Exception as e:
                    # Best effort: one bad box must not abort the rest.
                    print(f" Warning: Failed to re-add detection: {e}")

        # Get masks for latest frame
        latest_frame_idx = len(self.frame_buffer) - 1
        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=latest_frame_idx,
            max_frame_num_to_track=1,
            reverse=False,
        ):
            if out_frame_idx != latest_frame_idx or len(out_mask_logits) == 0:
                continue

            # Union of all per-object binary masks.
            combined_mask = None
            for mask_logit in out_mask_logits:
                mask = (mask_logit > 0.0).cpu().numpy().astype(bool)
                combined_mask = mask if combined_mask is None else combined_mask | mask

            # Drop leading singleton axes (e.g. a per-object channel axis)
            # so the result is 2-D like the empty-mask fallbacks.
            while combined_mask.ndim > 2 and combined_mask.shape[0] == 1:
                combined_mask = combined_mask[0]

            return (combined_mask * 255).astype(np.uint8)

        # If no masks, return empty
        return _empty_mask()

    except Exception as e:
        print(f" Warning: Object propagation failed: {e}")
        return _empty_mask()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Warning: Mask propagation failed: {e}")
|
print(f" Warning: Mask propagation failed: {e}")
|
||||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||||
|
|||||||
Reference in New Issue
Block a user