diff --git a/vr180_streaming/sam2_streaming_simple.py b/vr180_streaming/sam2_streaming_simple.py
index ae081e4..5c8fdb4 100644
--- a/vr180_streaming/sam2_streaming_simple.py
+++ b/vr180_streaming/sam2_streaming_simple.py
@@ -42,12 +42,16 @@ class SAM2StreamingProcessor:
         model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)

-        # Build predictor (simple, clean approach)
+        # Build predictor (disable compilation to fix CUDA graph issues)
         self.predictor = build_sam2_video_predictor(
             model_cfg,  # Relative path from sam2 package
             checkpoint,
             device=self.device,
-            vos_optimized=True  # Enable VOS optimizations for speed
+            vos_optimized=False,  # Disable to avoid CUDA graph issues
+            hydra_overrides_extra=[
+                "++model.compile_image_encoder=false",  # Disable compilation
+                "++model.memory_attention.use_amp=false",  # Disable AMP for stability
+            ]
         )

         # Frame buffer for streaming (like det-sam2)
@@ -95,8 +99,12 @@ class SAM2StreamingProcessor:
         if len(self.frame_buffer) >= self.frame_buffer_size or detections:
             return self._process_buffer()
         else:
-            # Return empty mask if no processing yet
-            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+            # For frames without detections, still try to propagate if we have existing objects
+            if self.inference_state is not None and self.object_ids:
+                return self._propagate_existing_objects()
+            else:
+                # Return empty mask if no processing yet
+                return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

     def _process_buffer(self) -> np.ndarray:
         """Process current frame buffer (adapted det-sam2 approach)"""
@@ -219,7 +227,68 @@ class SAM2StreamingProcessor:
             frame_shape = self.frame_buffer[-1]['frame'].shape
             return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

         except Exception as e:
             print(f" Warning: Mask propagation failed: {e}")
             frame_shape = self.frame_buffer[-1]['frame'].shape
             return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+    def _propagate_existing_objects(self) -> np.ndarray:
+        """Propagate existing objects without adding new detections"""
+        if not self.object_ids or not self.frame_buffer:
+            frame_shape = self.frame_buffer[-1]['frame'].shape if self.frame_buffer else (480, 640)
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        try:
+            # Update temp frames with current buffer
+            self._create_temp_frames()
+
+            # Reinitialize state (since we can't incrementally update)
+            self.inference_state = self.predictor.init_state(
+                video_path=self.temp_dir,
+                offload_video_to_cpu=self.memory_offload,
+                offload_state_to_cpu=self.memory_offload
+            )
+
+            # Re-add all previous detections from buffer
+            for buffer_idx, buffer_item in enumerate(self.frame_buffer):
+                detections = buffer_item.get('detections', [])
+                if detections:  # Only add frames that had detections
+                    for det_idx, detection in enumerate(detections):
+                        box = detection['box']
+                        try:
+                            self.predictor.add_new_points_or_box(
+                                inference_state=self.inference_state,
+                                frame_idx=buffer_idx,
+                                obj_id=det_idx,
+                                box=np.array(box, dtype=np.float32)
+                            )
+                        except Exception as e:
+                            print(f" Warning: Failed to re-add detection: {e}")
+
+            # Get masks for latest frame
+            latest_frame_idx = len(self.frame_buffer) - 1
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=latest_frame_idx,
+                max_frame_num_to_track=1,
+                reverse=False
+            ):
+                if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
+                    combined_mask = None
+                    for mask_logit in out_mask_logits:
+                        mask = (mask_logit > 0.0).cpu().numpy()
+                        if combined_mask is None:
+                            combined_mask = mask.astype(bool)
+                        else:
+                            combined_mask = combined_mask | mask.astype(bool)
+
+                    return (combined_mask * 255).astype(np.uint8)
+
+            # If no masks, return empty
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        except Exception as e:
+            print(f" Warning: Object propagation failed: {e}")
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
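
A standalone restatement of the mask-union step added in _propagate_existing_objects, useful for sanity-checking the thresholding and combining logic outside the full pipeline. This is a minimal sketch, not part of the patch: combine_mask_logits is our name, and the fake logit tensors stand in for real SAM2 outputs.

import numpy as np
import torch

def combine_mask_logits(mask_logits) -> np.ndarray:
    """Union per-object logits into one 0/255 uint8 mask, mirroring the
    loop in _propagate_existing_objects (foreground = logit > 0.0)."""
    combined = None
    for logit in mask_logits:
        mask = (logit > 0.0).cpu().numpy().astype(bool)
        combined = mask if combined is None else (combined | mask)
    return (combined * 255).astype(np.uint8)

# Two fake 4x4 "objects": one pixel for obj 0, the last row for obj 1.
logits = [torch.full((4, 4), -1.0), torch.full((4, 4), -1.0)]
logits[0][0, 0] = 2.0
logits[1][3, :] = 1.5
print(combine_mask_logits(logits))  # 255 at (0, 0) and in row 3, else 0

Thresholding logits at 0.0 corresponds to probability 0.5 after a sigmoid, and OR-ing the per-object masks yields a single foreground mask covering all tracked objects, which is what the streaming processor returns per frame.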