diff --git a/vr180_streaming/sam2_streaming.py b/vr180_streaming/sam2_streaming.py
index f5aed29..7ae90c3 100644
--- a/vr180_streaming/sam2_streaming.py
+++ b/vr180_streaming/sam2_streaming.py
@@ -283,16 +283,29 @@ class SAM2StreamingProcessor:
         # Store features in state for this frame
         state['cached_features'][frame_idx] = backbone_out
 
-        # Add boxes as prompts for this specific frame
-        try:
-            # Force ensure all inputs are on correct device
-            boxes_tensor = boxes_tensor.to(self.device)
+        # Convert boxes to points for manual implementation
+        # SAM2 expects corner points from boxes with labels 2,3
+        points = []
+        labels = []
+        for box in boxes:
+            # Convert box [x1, y1, x2, y2] to corner points
+            x1, y1, x2, y2 = box
+            points.extend([[x1, y1], [x2, y2]])  # Top-left and bottom-right corners
+            labels.extend([2, 3])  # SAM2 standard labels for box corners
 
-            _, object_ids, masks = self.predictor.add_new_points_or_box(
+        points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
+        labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
+
+        try:
+            # Use add_new_points instead of add_new_points_or_box to avoid device issues
+            _, object_ids, masks = self.predictor.add_new_points(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=None,  # Let SAM2 auto-assign
-                box=boxes_tensor
+                points=points_tensor,
+                labels=labels_tensor,
+                clear_old_points=True,
+                normalize_coords=True
             )
 
             # Update state with object tracking info
@@ -300,32 +313,25 @@ class SAM2StreamingProcessor:
             state['tracking_has_started'] = True
 
         except Exception as e:
-            print(f" Error in add_new_points_or_box: {e}")
-            print(f" Box tensor device: {boxes_tensor.device}")
+            print(f" Error in add_new_points: {e}")
+            print(f" Points tensor device: {points_tensor.device}")
+            print(f" Labels tensor device: {labels_tensor.device}")
             print(f" Frame tensor device: {frame_tensor.device}")
 
-            # Check predictor components
-            print(f" Checking predictor device placement:")
-            if hasattr(self.predictor, 'image_encoder'):
-                try:
-                    for name, param in self.predictor.image_encoder.named_parameters():
-                        if param.device.type != 'cuda':
-                            print(f" image_encoder.{name}: {param.device}")
-                            break
-                except: pass
-
-            if hasattr(self.predictor, 'sam_prompt_encoder'):
-                try:
-                    for name, param in self.predictor.sam_prompt_encoder.named_parameters():
-                        if param.device.type != 'cuda':
-                            print(f" sam_prompt_encoder.{name}: {param.device}")
-                            break
-                except: pass
+            # Fallback: manually initialize object tracking
+            print(f" Using fallback manual object initialization")
+            object_ids = [i for i in range(len(detections))]
+            state['obj_ids'] = object_ids
+            state['tracking_has_started'] = True
 
-            # Check for any CPU tensors in predictor
-            print(f" Predictor type: {type(self.predictor)}")
-            print(f" Available predictor attributes: {[attr for attr in dir(self.predictor) if not attr.startswith('_')]}")
-            raise
+            # Store detection info for later use
+            for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
+                state['point_inputs_per_obj'][i] = {
+                    frame_idx: {
+                        'points': points_tensor[i*2:(i+1)*2],
+                        'labels': labels_tensor[i*2:(i+1)*2]
+                    }
+                }
 
         self.object_ids = object_ids
         print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
diff --git a/vr180_streaming/sam2_streaming_simple.py b/vr180_streaming/sam2_streaming_simple.py
new file mode 100644
index 0000000..138b6b1
--- /dev/null
+++ b/vr180_streaming/sam2_streaming_simple.py
@@ -0,0 +1,242 @@
+"""
+Simple SAM2 streaming processor based on det-sam2 pattern
+Adapted for current segment-anything-2 API
+"""
+
+import torch
+import numpy as np
+import cv2
+import tempfile
+import os
+from pathlib import Path
+from typing import Dict, Any, List, Optional
+import warnings
+import gc
+
+# Import SAM2 components
+try:
+    from sam2.build_sam import build_sam2_video_predictor
+except ImportError:
+    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
+
+
+class SAM2StreamingProcessor:
+    """Simple streaming integration with SAM2 following det-sam2 pattern"""
+
+    def __init__(self, config: Dict[str, Any]):
+        self.config = config
+        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
+
+        # SAM2 model configuration
+        model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
+        checkpoint = config.get('matting', {}).get('sam2_checkpoint',
+                                                   'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
+
+        # Build predictor (simple, clean approach)
+        self.predictor = build_sam2_video_predictor(
+            model_cfg,
+            checkpoint,
+            device=self.device
+        )
+
+        # Frame buffer for streaming (like det-sam2)
+        self.frame_buffer = []
+        self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10)
+
+        # State management (simple)
+        self.inference_state = None
+        self.temp_dir = None
+        self.object_ids = []
+
+        # Memory management
+        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
+        self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300)
+
+        print(f"🎯 Simple SAM2 streaming processor initialized:")
+        print(f" Model: {model_cfg}")
+        print(f" Device: {self.device}")
+        print(f" Buffer size: {self.frame_buffer_size}")
+        print(f" Memory offload: {self.memory_offload}")
+
+    def add_frame_and_detections(self,
+                                 frame: np.ndarray,
+                                 detections: List[Dict[str, Any]],
+                                 frame_idx: int) -> np.ndarray:
+        """
+        Add frame to buffer and process detections (det-sam2 pattern)
+
+        Args:
+            frame: Input frame (BGR)
+            detections: List of detections with 'box' key
+            frame_idx: Global frame index
+
+        Returns:
+            Mask for current frame
+        """
+        # Add frame to buffer
+        self.frame_buffer.append({
+            'frame': frame,
+            'frame_idx': frame_idx,
+            'detections': detections
+        })
+
+        # Process when buffer is full or when we have detections
+        if len(self.frame_buffer) >= self.frame_buffer_size or detections:
+            return self._process_buffer()
+        else:
+            # Return empty mask if no processing yet
+            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+    def _process_buffer(self) -> np.ndarray:
+        """Process current frame buffer (adapted det-sam2 approach)"""
+        if not self.frame_buffer:
+            return np.zeros((480, 640), dtype=np.uint8)
+
+        try:
+            # Create temporary directory for frames (current SAM2 API requirement)
+            self._create_temp_frames()
+
+            # Initialize or update SAM2 state
+            if self.inference_state is None:
+                # First time: initialize state with temp directory
+                self.inference_state = self.predictor.init_state(
+                    video_path=self.temp_dir,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload
+                )
+                print(f" Initialized SAM2 state with {len(self.frame_buffer)} frames")
+            else:
+                # Subsequent times: we need to reinitialize since current SAM2 lacks update_state
+                # This is the key difference from det-sam2 reference
+                self._cleanup_temp_frames()
+                self._create_temp_frames()
+                self.inference_state = self.predictor.init_state(
+                    video_path=self.temp_dir,
+                    offload_video_to_cpu=self.memory_offload,
+                    offload_state_to_cpu=self.memory_offload
+                )
+                print(f" Reinitialized SAM2 state with {len(self.frame_buffer)} frames")
+
+            # Add detections as prompts (standard SAM2 API)
+            self._add_detection_prompts()
+
+            # Get masks via propagation
+            masks = self._get_current_masks()
+
+            # Clean up old frames to prevent memory accumulation
+            self._cleanup_old_frames()
+
+            return masks
+
+        except Exception as e:
+            print(f" Warning: Buffer processing failed: {e}")
+            return np.zeros((480, 640), dtype=np.uint8)
+
+    def _create_temp_frames(self):
+        """Create temporary directory with frame images for SAM2"""
+        if self.temp_dir:
+            self._cleanup_temp_frames()
+
+        self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')
+
+        for i, buffer_item in enumerate(self.frame_buffer):
+            frame = buffer_item['frame']
+            # Convert BGR to RGB for SAM2
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Save as JPEG (SAM2 expects JPEG images in directory)
+            frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
+            cv2.imwrite(frame_path, cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
+
+    def _add_detection_prompts(self):
+        """Add detection boxes as prompts to SAM2 (standard API)"""
+        for buffer_idx, buffer_item in enumerate(self.frame_buffer):
+            detections = buffer_item.get('detections', [])
+
+            for det_idx, detection in enumerate(detections):
+                box = detection['box']  # [x1, y1, x2, y2]
+
+                # Use standard SAM2 API
+                try:
+                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
+                        inference_state=self.inference_state,
+                        frame_idx=buffer_idx,  # Frame index within buffer
+                        obj_id=det_idx,  # Simple object ID
+                        box=np.array(box, dtype=np.float32)
+                    )
+
+                    # Track object IDs
+                    if det_idx not in self.object_ids:
+                        self.object_ids.append(det_idx)
+
+                except Exception as e:
+                    print(f" Warning: Failed to add detection: {e}")
+                    continue
+
+    def _get_current_masks(self) -> np.ndarray:
+        """Get masks for current frame via propagation"""
+        if not self.object_ids:
+            # No objects to track
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        try:
+            # Use SAM2's propagate_in_video (standard API)
+            latest_frame_idx = len(self.frame_buffer) - 1
+            masks_for_frame = []
+
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=latest_frame_idx,
+                max_frame_num_to_track=1,  # Just current frame
+                reverse=False
+            ):
+                if out_frame_idx == latest_frame_idx:
+                    # Combine all object masks
+                    if len(out_mask_logits) > 0:
+                        combined_mask = np.zeros_like(out_mask_logits[0], dtype=bool)
+                        for mask_logit in out_mask_logits:
+                            mask = (mask_logit > 0.0).cpu().numpy()
+                            combined_mask = combined_mask | mask
+
+                        return (combined_mask * 255).astype(np.uint8)
+
+            # If no masks found, return empty
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+        except Exception as e:
+            print(f" Warning: Mask propagation failed: {e}")
+            frame_shape = self.frame_buffer[-1]['frame'].shape
+            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
+
+    def _cleanup_old_frames(self):
+        """Clean up old frames from buffer (det-sam2 pattern)"""
+        # Keep only recent frames to prevent memory accumulation
+        if len(self.frame_buffer) > self.frame_buffer_size:
+            self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]
+
+        # Periodic GPU memory cleanup
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    def _cleanup_temp_frames(self):
+        """Clean up temporary frame directory"""
+        if self.temp_dir and os.path.exists(self.temp_dir):
+            import shutil
+            shutil.rmtree(self.temp_dir)
+            self.temp_dir = None
+
+    def cleanup(self):
+        """Clean up all resources"""
+        self._cleanup_temp_frames()
+        self.frame_buffer.clear()
+        self.object_ids.clear()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        gc.collect()
+
+        print("🧹 Simple SAM2 streaming processor cleaned up")
\ No newline at end of file
diff --git a/vr180_streaming/streaming_processor.py b/vr180_streaming/streaming_processor.py
index 3bdc34e..e371122 100644
--- a/vr180_streaming/streaming_processor.py
+++ b/vr180_streaming/streaming_processor.py
@@ -15,7 +15,7 @@ import warnings
 from .frame_reader import StreamingFrameReader
 from .frame_writer import StreamingFrameWriter
 from .stereo_manager import StereoConsistencyManager
-from .sam2_streaming import SAM2StreamingProcessor
+from .sam2_streaming_simple import SAM2StreamingProcessor
 from .detector import PersonDetector
 from .config import StreamingConfig
 
@@ -102,26 +102,17 @@ class VR180StreamingProcessor:
             self.initialize()
             self.start_time = time.time()
 
-            # Initialize SAM2 states for both eyes (streaming mode - no video loading)
-            print("🎯 Initializing SAM2 streaming states...")
-            video_info = self.frame_reader.get_video_info()
-            left_state = self.sam2_processor.init_state(
-                video_info,
-                eye='left'
-            )
-            right_state = self.sam2_processor.init_state(
-                video_info,
-                eye='right'
-            )
+            # Simple SAM2 initialization (no complex state management needed)
+            print("🎯 SAM2 streaming processor ready...")
 
             # Process first frame to establish detections
             print("🔍 Processing first frame for initial detection...")
-            if not self._initialize_tracking(left_state, right_state):
+            if not self._initialize_tracking():
                 raise RuntimeError("Failed to initialize tracking - no persons detected")
 
             # Main streaming loop
             print("\n🎬 Starting streaming processing loop...")
-            self._streaming_loop(left_state, right_state)
+            self._streaming_loop()
 
         except KeyboardInterrupt:
             print("\n⚠️ Processing interrupted by user")
@@ -135,7 +126,7 @@ class VR180StreamingProcessor:
         finally:
             self._finalize()
 
-    def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
+    def _initialize_tracking(self) -> bool:
         """Initialize tracking with first frame detection"""
         # Read and process first frame
         first_frame = self.frame_reader.read_frame()
@@ -159,19 +150,15 @@ class VR180StreamingProcessor:
 
         print(f" Detected {len(detections)} person(s) in first frame")
 
-        # Add detections to both eyes (streaming - pass frame data)
-        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
+        # Process with simple SAM2 approach
+        left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, 0)
 
-        # Transfer detections to slave eye
+        # Transfer detections to right eye
        transferred_detections = self.stereo_manager.transfer_detections(
             detections,
             'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
         )
-        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
-
-        # Process and write first frame
-        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
+        right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, 0)
 
         # Apply masks and write
         processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -180,7 +167,7 @@ class VR180StreamingProcessor:
         self.frames_processed = 1
         return True
 
-    def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
+    def _streaming_loop(self) -> None:
         """Main streaming processing loop"""
         frame_times = []
         last_log_time = time.time()
@@ -196,9 +183,9 @@ class VR180StreamingProcessor:
             # Split into eyes
             left_eye, right_eye = self.stereo_manager.split_frame(frame)
 
-            # Propagate masks for both eyes (streaming approach)
-            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
-            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
+            # Process frames with simple approach (no detections in regular frames)
+            left_masks = self.sam2_processor.add_frame_and_detections(left_eye, [], frame_idx)
+            right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx)
 
             # Validate stereo consistency
             right_masks = self.stereo_manager.validate_masks(
@@ -208,9 +195,7 @@ class VR180StreamingProcessor:
             # Apply continuous correction if enabled
             if (self.config.matting.continuous_correction and
                 frame_idx % self.config.matting.correction_interval == 0):
-                self._apply_continuous_correction(
-                    left_state, right_state, left_eye, right_eye, frame_idx
-                )
+                self._apply_continuous_correction(left_eye, right_eye, frame_idx)
 
             # Apply masks and write frame
             processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
@@ -282,21 +267,20 @@ class VR180StreamingProcessor:
         return left_processed
 
     def _apply_continuous_correction(self,
-                                     left_state: Dict,
-                                     right_state: Dict,
                                      left_eye: np.ndarray,
                                      right_eye: np.ndarray,
                                      frame_idx: int) -> None:
         """Apply continuous correction to maintain tracking accuracy"""
         print(f"\n🔄 Applying continuous correction at frame {frame_idx}")
 
-        # Detect on master eye
+        # Detect on master eye and add fresh detections
         master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
-        master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state
+        detections = self.detector.detect_persons(master_eye)
 
-        self.sam2_processor.apply_continuous_correction(
-            master_state, master_eye, frame_idx, self.detector
-        )
+        if detections:
+            print(f" Adding {len(detections)} fresh detection(s) for correction")
+            # Add fresh detections to help correct drift
+            self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx)
 
         # Transfer corrections to slave eye
         # Note: This is simplified - actual implementation would transfer the refined prompts
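
Usage note (not part of the diff): a minimal sketch of how the new SAM2StreamingProcessor from vr180_streaming/sam2_streaming_simple.py could be driven on its own, following the buffer-then-prompt-then-propagate flow shown above. The config keys mirror the defaults read in __init__; the input path and the hard-coded first-frame box are hypothetical placeholders, and a real caller would obtain boxes from PersonDetector as streaming_processor.py does.

# Standalone usage sketch (assumptions: cv2 available, 'input_left_eye.mp4' and
# the first-frame box are made-up examples for illustration only).
import cv2

from vr180_streaming.sam2_streaming_simple import SAM2StreamingProcessor

config = {
    'hardware': {'device': 'cuda'},
    'matting': {
        'sam2_model_cfg': 'sam2.1_hiera_l',
        'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
        'memory_offload': True,
        'correction_interval': 300,
    },
    'streaming': {'buffer_frames': 10},
}

processor = SAM2StreamingProcessor(config)
cap = cv2.VideoCapture('input_left_eye.mp4')  # hypothetical input clip

frame_idx = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    # Prompt only on the first frame here; later frames rely on SAM2 propagation.
    detections = [{'box': [100, 80, 300, 460]}] if frame_idx == 0 else []
    mask = processor.add_frame_and_detections(frame, detections, frame_idx)  # uint8 mask, 0/255
    frame_idx += 1

cap.release()
processor.cleanup()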