"""
SAM2 processor module for video segmentation.

Preserves the core SAM2 logic from the original implementation.
"""

import os
import cv2
import numpy as np
import torch
import logging
import subprocess
import gc
from typing import Dict, List, Any, Optional, Sequence, Tuple

from sam2.build_sam import build_sam2_video_predictor

from .eye_processor import EyeProcessor

logger = logging.getLogger(__name__)


class SAM2Processor:
    """Handles SAM2-based video segmentation for human tracking.

    Wraps a SAM2 video predictor and provides helpers to build low-resolution
    inference videos, seed the predictor with YOLO boxes or with masks saved
    from a previous segment, propagate masks through a segment, and persist
    the last frame's masks (as a colored PNG) for continuity. Optionally
    processes VR180 left/right eyes separately through ``EyeProcessor``.
    """

    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False,
                 separate_eye_processing: bool = False, eye_overlap_pixels: int = 0,
                 async_preprocessor=None):
        """
        Initialize SAM2 processor.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_path: Path to SAM2 config file
            vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
            separate_eye_processing: Enable VR180 separate eye processing mode
            eye_overlap_pixels: Pixel overlap between eyes for blending
            async_preprocessor: Optional async preprocessor for background low-res video generation

        Raises:
            Exception: Any error raised while building the SAM2 predictor is
                logged and re-raised.
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.vos_optimized = vos_optimized
        self.separate_eye_processing = separate_eye_processing
        self.async_preprocessor = async_preprocessor
        self.predictor = None

        # Initialize eye processor if separate eye processing is enabled
        if separate_eye_processing:
            self.eye_processor = EyeProcessor(eye_overlap_pixels=eye_overlap_pixels)
        else:
            self.eye_processor = None

        self._initialize_predictor()

    # ------------------------------------------------------------------
    # Predictor construction
    # ------------------------------------------------------------------

    @staticmethod
    def _select_device() -> torch.device:
        """Pick CUDA, then MPS (with CPU fallback for unsupported ops), then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            logger.warning(
                "Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS."
            )
            # Let PyTorch run ops that MPS does not implement on the CPU.
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _resolve_config_name(config_path: str) -> str:
        """Map a config file path to the name SAM2's Hydra-based loader expects.

        SAM2 resolves configs relative to its internal config directory, so only
        the basename (without ``.yaml``) is kept. Files named ``sam2.1_hiera_X.yaml``
        map to ``configs/sam2.1/sam2.1_hiera_X``; files named ``sam2_hiera_X.yaml``
        map to ``configs/sam2/sam2_hiera_X``. Any other name is used as-is.
        """
        config_name = os.path.basename(config_path)
        if config_name.endswith('.yaml'):
            config_name = config_name[:-5]  # Remove .yaml extension

        if config_name.startswith("sam2.1_hiera"):
            config_name = f"configs/sam2.1/{config_name}"
        elif config_name.startswith("sam2_hiera"):
            config_name = f"configs/sam2/{config_name}"
        return config_name

    def _build_standard_predictor(self, config_name: str, device: torch.device):
        """Build the non-VOS-optimized SAM2 video predictor."""
        # NOTE(review): `overrides` is not a documented parameter of
        # build_sam2_video_predictor and may be silently ignored; kept as-is.
        return build_sam2_video_predictor(
            config_name,
            self.checkpoint_path,
            device=device,
            overrides=dict(conf=0.95),
        )

    def _initialize_predictor(self):
        """Initialize SAM2 video predictor with proper device setup."""
        device = self._select_device()
        logger.info(f"Using device: {device}")

        try:
            config_name = self._resolve_config_name(self.config_path)
            logger.info(f"Using SAM2 config: {config_name}")

            if self.vos_optimized:
                # Use VOS optimization if enabled and supported; fall back otherwise.
                try:
                    self.predictor = build_sam2_video_predictor(
                        config_name,  # Use just the config name, not full path
                        self.checkpoint_path,
                        device=device,
                        vos_optimized=True,  # New optimization for major speedup
                    )
                    logger.info("Using optimized SAM2 VOS predictor with full model compilation")
                except Exception as e:
                    logger.warning(f"Failed to use optimized VOS predictor: {e}")
                    logger.info("Falling back to standard SAM2 predictor")
                    self.predictor = self._build_standard_predictor(config_name, device)
            else:
                self.predictor = self._build_standard_predictor(config_name, device)
                logger.info("Using standard SAM2 predictor")

            if device.type == "cuda":
                # Enter a bfloat16 autocast context and deliberately never exit it,
                # so subsequent SAM2 calls run under autocast.
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
                # TF32 matmuls are available on compute capability >= 8 (Ampere+).
                if torch.cuda.get_device_properties(0).major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True

            logger.info("SAM2 predictor initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize SAM2 predictor: {e}")
            raise

    def _reset_inference_state(self, inference_state) -> None:
        """Best-effort reset of a SAM2 inference state; failures are only logged."""
        if inference_state is None:
            return
        try:
            self.predictor.reset_state(inference_state)
        except Exception as e:
            logger.warning(f"Could not reset SAM2 inference state: {e}")

    # ------------------------------------------------------------------
    # Low-resolution inference video
    # ------------------------------------------------------------------

    def create_low_res_video(self, input_video_path: str, output_video_path: str, scale: float):
        """
        Create a low-resolution version of the input video for inference.

        Uses FFmpeg (hardware-accelerated decode, NVENC encode) and falls back
        to OpenCV if FFmpeg is missing or exits with an error.

        Args:
            input_video_path: Path to input video
            output_video_path: Path to output low-res video
            scale: Scale factor for resolution reduction

        Raises:
            ValueError: If the input video cannot be opened.
        """
        try:
            # Get video properties using OpenCV
            cap = cv2.VideoCapture(input_video_path)
            if not cap.isOpened():
                cap.release()
                raise ValueError(f"Could not open video: {input_video_path}")

            original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()

            # Round odd dimensions up to the next even number, as required by many codecs.
            target_width = int(original_width * scale)
            target_height = int(original_height * scale)
            target_width += target_width % 2
            target_height += target_height % 2

            command = [
                'ffmpeg',
                '-y',
                '-hwaccel', 'auto',  # Auto-detect hardware acceleration
                '-i', input_video_path,
                '-vf', f'scale={target_width}:{target_height}',
                '-c:v', 'h264_nvenc',  # Use NVIDIA's hardware encoder
                '-preset', 'fast',
                '-crf', '23',
                output_video_path,
            ]

            logger.info(f"Executing FFmpeg command: {' '.join(command)}")

            # check=True raises CalledProcessError on a non-zero exit status,
            # so no separate return-code check is needed.
            subprocess.run(command, check=True, capture_output=True, text=True)

            logger.info(f"Created low-res video with {frame_count} frames: {output_video_path}")

        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # CalledProcessError carries FFmpeg's stderr (captured above); the
            # exception's own message does not include it.
            ffmpeg_stderr = getattr(e, 'stderr', None)
            if ffmpeg_stderr:
                logger.debug(f"FFmpeg stderr: {ffmpeg_stderr}")
            logger.warning(f"Hardware-accelerated FFmpeg failed: {e}. Falling back to OpenCV.")
            self._create_low_res_video_opencv(input_video_path, output_video_path, scale)

    def _create_low_res_video_opencv(self, input_video_path: str, output_video_path: str, scale: float):
        """OpenCV-based fallback for creating a low-resolution video (mp4v codec).

        Raises:
            ValueError: If the input video cannot be opened.
        """
        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        try:
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                # Writing to an unopened writer is a silent no-op; fail loudly instead.
                logger.error(f"Could not open video writer for: {output_video_path}")
                return

            frame_count = 0
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    low_res_frame = cv2.resize(frame, (frame_width, frame_height),
                                               interpolation=cv2.INTER_LINEAR)
                    out.write(low_res_frame)
                    frame_count += 1
            finally:
                out.release()
        finally:
            cap.release()

        logger.info(f"Created low-res video with {frame_count} frames using OpenCV: {output_video_path}")

    @staticmethod
    def _is_nonempty_file(path: str) -> bool:
        """Return True if *path* exists and has a non-zero size."""
        return os.path.exists(path) and os.path.getsize(path) > 0

    def ensure_low_res_video(self, input_video_path: str, output_video_path: str,
                             scale: float, segment_idx: Optional[int] = None) -> bool:
        """
        Ensure low-resolution video exists, using async preprocessor if available.

        Args:
            input_video_path: Path to input video
            output_video_path: Path to output low-res video
            scale: Scale factor for resolution reduction
            segment_idx: Optional segment index for async coordination

        Returns:
            True if a non-empty low-res video is present at ``output_video_path``.
            Errors during synchronous creation are logged and yield False.
        """
        # Already created (e.g. by a previous run)
        if self._is_nonempty_file(output_video_path):
            return True

        # Use async preprocessor if available and segment index provided
        if self.async_preprocessor and segment_idx is not None:
            if self.async_preprocessor.is_segment_ready(segment_idx):
                if self._is_nonempty_file(output_video_path):
                    logger.debug(f"Async preprocessor provided segment {segment_idx}")
                    return True
            else:
                logger.debug(f"Async preprocessor hasn't completed segment {segment_idx} yet")

        # Fallback to synchronous creation
        try:
            logger.info(f"Creating low-res video synchronously: {input_video_path} -> {output_video_path}")
            self.create_low_res_video(input_video_path, output_video_path, scale)

            if self._is_nonempty_file(output_video_path):
                logger.info(f"Successfully created low-res video: {output_video_path} ({os.path.getsize(output_video_path)} bytes)")
                return True
            logger.error(f"Low-res video creation failed - file doesn't exist or is empty: {output_video_path}")
            return False

        except Exception as e:
            logger.error(f"Failed to create low-res video {output_video_path}: {e}")
            return False

    # ------------------------------------------------------------------
    # Prompting
    # ------------------------------------------------------------------

    def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]],
                                      inference_scale: float = 1.0) -> bool:
        """
        Add YOLO detection box prompts to the SAM2 predictor on frame 0.

        Each prompt is added independently; a failing prompt is logged and the
        remaining prompts are still attempted.

        Args:
            inference_state: SAM2 inference state
            prompts: List of prompt dictionaries with 'obj_id' and 'bbox'
                (optionally 'confidence'). 'bbox' may be an array or a sequence.
            inference_scale: Scale factor applied to bounding boxes to map them
                to the inference video's resolution

        Returns:
            True if at least one prompt was added successfully
        """
        if not prompts:
            logger.warning("SAM2 Debug: No prompts provided to SAM2")
            return False

        logger.info(f"SAM2 Debug: Received {len(prompts)} prompts to add to predictor")

        success_count = 0
        for i, prompt in enumerate(prompts):
            obj_id = prompt['obj_id']
            bbox = prompt['bbox']
            confidence = prompt.get('confidence', 'unknown')

            # np.asarray lets list/tuple boxes work too (arrays pass through unchanged).
            scaled_bbox = np.asarray(bbox) * inference_scale

            logger.info(f"SAM2 Debug: Adding prompt {i+1}/{len(prompts)}: Object {obj_id}")
            logger.info(f"  Original bbox: {bbox}")
            logger.info(f"  Scaled bbox (scale={inference_scale}): {scaled_bbox}")
            logger.info(f"  Confidence: {confidence}")

            try:
                _, out_obj_ids, _ = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=obj_id,
                    box=scaled_bbox.astype(np.float32),
                )
                logger.info(f"SAM2 Debug: ✓ Successfully added Object {obj_id} - returned obj_ids: {out_obj_ids}")
                success_count += 1
            except Exception as e:
                # Continue processing other prompts even if one fails
                logger.error(f"SAM2 Debug: ✗ Error adding Object {obj_id}: {e}")

        if success_count > 0:
            logger.info(f"SAM2 Debug: Final result - {success_count}/{len(prompts)} prompts successfully added")
            return True
        logger.error("SAM2 Debug: FAILED - No prompts were successfully added to SAM2")
        return False

    def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
        """
        Load masks from the previous segment's ``mask.png`` for continuity.

        The PNG is read with OpenCV (BGR channel order): pure green pixels
        (0, 255, 0) become object 1 and pure blue pixels (255, 0, 0) become
        object 2. This matches the colors written by ``save_final_masks``.

        Args:
            prev_segment_dir: Directory of previous segment

        Returns:
            Dictionary mapping object IDs to boolean masks (possibly empty if
            neither color is present), or None if the file is missing or unreadable
        """
        mask_path = os.path.join(prev_segment_dir, "mask.png")
        if not os.path.exists(mask_path):
            logger.warning(f"Previous mask not found: {mask_path}")
            return None

        try:
            mask_image = cv2.imread(mask_path)
            if mask_image is None:
                logger.error(f"Could not read mask image: {mask_path}")
                return None

            if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
                logger.error("Mask image does not have three color channels")
                return None

            mask_image = mask_image.astype(np.uint8)

            # Extract Object A and Object B masks (preserving original logic); BGR order.
            green_bgr = (0, 255, 0)
            blue_bgr = (255, 0, 0)
            mask_a = np.all(mask_image == green_bgr, axis=2)
            mask_b = np.all(mask_image == blue_bgr, axis=2)

            per_obj_input_mask = {}
            if np.any(mask_a):
                per_obj_input_mask[1] = mask_a
            if np.any(mask_b):
                per_obj_input_mask[2] = mask_b

            logger.info(f"Loaded masks for {len(per_obj_input_mask)} objects from {prev_segment_dir}")
            return per_obj_input_mask

        except Exception as e:
            logger.error(f"Error loading previous mask: {e}")
            return None

    def add_previous_masks_to_predictor(self, inference_state, masks: Dict[int, np.ndarray]) -> bool:
        """
        Add previous segment masks to the predictor on frame 0 for continuity.

        Args:
            inference_state: SAM2 inference state
            masks: Dictionary mapping object IDs to masks

        Returns:
            True if every mask was added; False on the first failure
        """
        try:
            for obj_id, mask in masks.items():
                self.predictor.add_new_mask(inference_state, 0, obj_id, mask)
                logger.debug(f"Added previous mask for Object {obj_id}")

            logger.info(f"Successfully added {len(masks)} previous masks to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding previous masks to SAM2: {e}")
            return False

    # ------------------------------------------------------------------
    # Propagation and segment processing
    # ------------------------------------------------------------------

    def propagate_masks(self, inference_state) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks across all frames in the video.

        Best-effort: if propagation raises part-way, the error is logged and
        the frames collected so far are returned.

        Args:
            inference_state: SAM2 inference state

        Returns:
            Dictionary mapping frame indices to {obj_id: boolean mask (logits > 0)}
        """
        video_segments = {}
        frame_count = 0
        last_obj_ids = []

        try:
            logger.info("Starting SAM2 mask propagation...")

            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
                last_obj_ids = out_obj_ids
                frame_count += 1

                # Log progress every 50 frames
                if frame_count % 50 == 0:
                    logger.info(f"SAM2 propagation progress: {frame_count} frames processed")

            logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(last_obj_ids)} objects")

        except Exception as e:
            logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
            logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
            if frame_count == 0:
                logger.error("No frames were processed - propagation failed completely")
            else:
                logger.warning(f"Partial propagation completed: {frame_count} frames before failure")

        return video_segments

    def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                               previous_masks: Optional[Dict[int, np.ndarray]] = None,
                               inference_scale: float = 0.5,
                               multi_frame_prompts: Optional[Dict[int, List[Dict[str, Any]]]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a single video segment with SAM2.

        Args:
            segment_info: Segment information dictionary ('directory', 'video_file', 'index')
            yolo_prompts: Optional YOLO detection prompts for first frame (take
                precedence over ``previous_masks``)
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference
            multi_frame_prompts: Optional prompts for multiple frames (mid-segment detection)

        Returns:
            Video segments dictionary. None is returned both when the segment is
            already marked done (skip) and when processing fails.
        """
        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        # Check if segment is already processed (resume capability)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment_idx} already processed. Skipping.")
            return None  # Indicate skip, not failure

        logger.info(f"Processing segment {segment_idx} with SAM2")

        # Create low-resolution video for inference (async-aware)
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not self.ensure_low_res_video(video_path, low_res_video_path, inference_scale, segment_idx):
            logger.error(f"Failed to create low-res video for segment {segment_idx}")
            return None

        inference_state = None
        try:
            inference_state = self.predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

            # Add prompts or previous masks
            if yolo_prompts:
                if not self.add_yolo_prompts_to_predictor(inference_state, yolo_prompts, inference_scale):
                    return None
            elif previous_masks:
                if not self.add_previous_masks_to_predictor(inference_state, previous_masks):
                    return None
            else:
                logger.error(f"No prompts or previous masks available for segment {segment_idx}")
                return None

            # Add mid-segment prompts if provided.
            # NOTE(review): unlike yolo_prompts, bbox prompts here are NOT scaled by
            # inference_scale; confirm callers pass them in low-res coordinates.
            if multi_frame_prompts:
                logger.info(f"Adding mid-segment prompts for segment {segment_idx}")
                if not self.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
                    # Don't fail the segment - continue with existing prompts
                    logger.warning(f"Failed to add mid-segment prompts for segment {segment_idx}")

            video_segments = self.propagate_masks(inference_state)

        except Exception as e:
            logger.error(f"Error processing segment {segment_idx}: {e}")
            return None
        finally:
            # Runs on every exit path so predictor state is never left populated.
            self._reset_inference_state(inference_state)
            inference_state = None
            gc.collect()

        # Remove low-res video to save space (kept on failure for retry/inspection)
        try:
            os.remove(low_res_video_path)
            logger.debug(f"Removed low-res video: {low_res_video_path}")
        except Exception as e:
            logger.warning(f"Could not remove low-res video: {e}")

        return video_segments

    def save_final_masks(self, video_segments: Dict[int, Dict[int, np.ndarray]], output_path: str,
                         green_color: Sequence[int] = (0, 255, 0),
                         blue_color: Sequence[int] = (255, 0, 0)):
        """
        Save the last frame's masks as a colored image for continuity.

        Object 1 is painted with ``green_color`` and object 2 with ``blue_color``
        (object 2 wins where they overlap) on a black background.

        Args:
            video_segments: Video segments dictionary
            output_path: Path to save the mask image
            green_color: BGR color (as written by cv2.imwrite) for object 1
            blue_color: BGR color (as written by cv2.imwrite) for object 2
        """
        if not video_segments:
            logger.error("No video segments to save")
            return

        try:
            last_frame_idx = max(video_segments.keys())
            masks_dict = video_segments[last_frame_idx]

            mask_a = masks_dict.get(1)
            mask_b = masks_dict.get(2)

            if mask_a is None and mask_b is None:
                logger.error("No masks found for objects")
                return

            # Use the first available mask to determine dimensions
            reference_mask = (mask_a if mask_a is not None else mask_b).squeeze()
            black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

            if mask_a is not None:
                black_frame[mask_a.squeeze().astype(bool)] = green_color
            if mask_b is not None:
                black_frame[mask_b.squeeze().astype(bool)] = blue_color

            cv2.imwrite(output_path, black_frame)
            logger.info(f"Saved final masks to {output_path}")

        except Exception as e:
            logger.error(f"Error saving final masks: {e}")

    # ------------------------------------------------------------------
    # Debug visualization
    # ------------------------------------------------------------------

    @staticmethod
    def _render_first_frame_debug(original_frame: np.ndarray,
                                  frame_masks: Dict[int, np.ndarray]) -> np.ndarray:
        """Draw translucent mask overlays, outlines, a title and a legend on a copy of the frame."""
        debug_frame = original_frame.copy()
        frame_h, frame_w = original_frame.shape[:2]

        # BGR colors (OpenCV convention)
        colors = {
            1: (0, 255, 0),    # Green for Object 1 (Left eye)
            2: (255, 0, 0),    # Blue for Object 2 (Right eye)
            3: (0, 255, 255),  # Yellow for Object 3
            4: (255, 0, 255),  # Magenta for Object 4
        }
        default_color = (128, 128, 128)

        for obj_id, mask in frame_masks.items():
            mask = mask.squeeze()

            # Masks come at inference resolution; upscale to the original frame if needed.
            if mask.shape != (frame_h, frame_w):
                mask = cv2.resize(mask.astype(np.float32), (frame_w, frame_h),
                                  interpolation=cv2.INTER_NEAREST)
                mask = mask > 0.5

            color = colors.get(obj_id, default_color)

            # Blend a colored overlay (30% overlay, 70% original)
            overlay = debug_frame.copy()
            overlay[mask] = color
            cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)

            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(debug_frame, contours, -1, color, 2)

            logger.info(f"SAM2 Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")

        title = f"SAM2 First Frame Masks: {len(frame_masks)} objects detected"
        cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

        source_info = "Mask Source: SAM2 (from YOLO bounding boxes)"
        cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

        # Object legend, one row per object
        eye_labels = {1: 'Left Eye', 2: 'Right Eye'}
        for row, obj_id in enumerate(sorted(frame_masks.keys())):
            color = colors.get(obj_id, default_color)
            label = eye_labels.get(obj_id, f'Object {obj_id}')
            text = f"Object {obj_id}: {label}"
            cv2.putText(debug_frame, text, (10, 90 + 30 * row), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        return debug_frame

    def generate_first_frame_debug_masks(self, video_path: str, prompts: List[Dict[str, Any]],
                                         output_path: str, inference_scale: float = 0.5) -> bool:
        """
        Generate SAM2 masks for just the first frame and save a debug visualization.

        A one-frame temporary video is written, prompted, and propagated; the
        temporary directory and the predictor state are always cleaned up.

        Args:
            video_path: Path to the video file
            prompts: List of SAM2 prompt dictionaries
            output_path: Path to save the debug image
            inference_scale: Scale factor for SAM2 inference

        Returns:
            True if debug masks were generated and the image was written
        """
        import shutil
        import tempfile

        if not prompts:
            logger.warning("No prompts provided for first frame debug")
            return False

        temp_dir = None
        inference_state = None
        try:
            logger.info(f"SAM2 Debug: Generating first frame masks for {len(prompts)} objects")

            cap = cv2.VideoCapture(video_path)
            ret, original_frame = cap.read()
            cap.release()

            if not ret:
                logger.error("Could not read first frame for debug mask generation")
                return False

            # Scale frame for inference if needed
            if inference_scale != 1.0:
                inference_frame = cv2.resize(original_frame, None, fx=inference_scale, fy=inference_scale,
                                             interpolation=cv2.INTER_LINEAR)
            else:
                inference_frame = original_frame.copy()

            # Write the single frame into a temporary low-res video
            temp_dir = tempfile.mkdtemp()
            temp_video_path = os.path.join(temp_dir, "first_frame.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_video_path, fourcc, 1.0,
                                  (inference_frame.shape[1], inference_frame.shape[0]))
            out.write(inference_frame)
            out.release()

            inference_state = self.predictor.init_state(video_path=temp_video_path, async_loading_frames=True)

            if not self.add_yolo_prompts_to_predictor(inference_state, prompts, inference_scale):
                logger.error("Failed to add prompts for first frame debug")
                return False

            # Only frame 0 is needed; stop the propagation generator right after it.
            frame_masks = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                if out_frame_idx == 0:
                    frame_masks = {
                        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                        for i, out_obj_id in enumerate(out_obj_ids)
                    }
                    break

            if not frame_masks:
                logger.error("No masks generated for first frame debug")
                return False

            debug_frame = self._render_first_frame_debug(original_frame, frame_masks)

            if cv2.imwrite(output_path, debug_frame):
                logger.info(f"SAM2 Debug: Saved first frame masks to {output_path}")
                return True
            logger.error(f"Failed to save first frame masks to {output_path}")
            return False

        except Exception as e:
            logger.error(f"Error generating first frame debug masks: {e}")
            return False
        finally:
            # Previously leaked on every early return / exception.
            self._reset_inference_state(inference_state)
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
        """
        Add YOLO prompts at multiple frame indices for mid-segment re-detection.

        Supports both bounding box prompts (detection mode) and mask prompts
        (segmentation mode). Boxes are passed to SAM2 as given (no scaling).

        Args:
            inference_state: SAM2 inference state
            multi_frame_prompts: Dictionary mapping frame_index -> prompts, where
                prompts is either a list of {'obj_id', 'bbox', ['confidence']}
                dicts (detection mode) or a dict with a 'masks' entry mapping
                obj_id -> mask (segmentation mode)

        Returns:
            True if at least one prompt was added successfully
        """
        if not multi_frame_prompts:
            logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
            return False

        success_count = 0
        total_count = 0

        for frame_idx, prompts_data in multi_frame_prompts.items():
            if isinstance(prompts_data, dict) and 'masks' in prompts_data:
                # Segmentation mode: add masks directly
                masks_dict = prompts_data['masks']
                logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")

                for obj_id, mask in masks_dict.items():
                    total_count += 1
                    logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
                    try:
                        self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
                        logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
                        success_count += 1
                    except Exception as e:
                        logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
            else:
                # Detection mode: add bounding box prompts
                prompts = prompts_data
                logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")

                for i, prompt in enumerate(prompts):
                    obj_id = prompt['obj_id']
                    bbox = prompt['bbox']
                    confidence = prompt.get('confidence', 'unknown')
                    total_count += 1

                    logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")

                    try:
                        _, out_obj_ids, _ = self.predictor.add_new_points_or_box(
                            inference_state=inference_state,
                            frame_idx=frame_idx,  # Key: specify the exact frame index
                            obj_id=obj_id,
                            # np.asarray lets list/tuple boxes work too.
                            box=np.asarray(bbox).astype(np.float32),
                        )
                        logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
                        success_count += 1
                    except Exception as e:
                        logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")

        if success_count > 0:
            logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
            return True
        logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
        return False

    # ------------------------------------------------------------------
    # VR180 separate-eye processing
    # ------------------------------------------------------------------

    def process_single_eye_segment(self, segment_info: dict, eye_side: str,
                                   yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                                   previous_masks: Optional[Dict[int, np.ndarray]] = None,
                                   inference_scale: float = 0.5) -> Optional[Dict[int, np.ndarray]]:
        """
        Process a single eye of a VR180 segment with SAM2.

        All prompts are remapped to obj_id=1; when previous masks are used,
        only the first mask in ``previous_masks`` is kept (as obj_id=1).

        Args:
            segment_info: Segment information dictionary ('directory', 'video_file', 'index')
            eye_side: 'left' or 'right' eye
            yolo_prompts: Optional YOLO detection prompts for first frame
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference

        Returns:
            Dictionary mapping frame indices to masks, or None if failed
        """
        if not self.eye_processor:
            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
            return None

        segment_dir = segment_info['directory']
        segment_idx = segment_info['index']

        logger.info(f"Processing {eye_side} eye for segment {segment_idx}")

        # Use the video path directly (it should already be the eye-specific video).
        # NOTE(review): process_segment_with_separate_eyes passes the same
        # segment_info for both eyes, so both read the same 'video_file' -
        # confirm this is intended.
        eye_video_path = segment_info['video_file']

        if not os.path.exists(eye_video_path):
            logger.error(f"Eye video not found: {eye_video_path}")
            return None

        # Create low-resolution eye video for inference (async-aware)
        low_res_eye_video_path = os.path.join(segment_dir, f"low_res_{eye_side}_eye_video.mp4")
        if not self.ensure_low_res_video(eye_video_path, low_res_eye_video_path, inference_scale, segment_idx):
            logger.error(f"Failed to create low-res {eye_side} eye video for segment {segment_idx}")
            return None

        inference_state = None
        try:
            inference_state = self.predictor.init_state(video_path=low_res_eye_video_path, async_loading_frames=True)

            # Single-eye processing always tracks exactly one object, obj_id=1.
            if yolo_prompts:
                eye_prompts = [{**prompt, 'obj_id': 1} for prompt in yolo_prompts]
                if not self.add_yolo_prompts_to_predictor(inference_state, eye_prompts, inference_scale):
                    logger.error(f"Failed to add prompts for {eye_side} eye")
                    return None
            elif previous_masks:
                eye_prompt_masks = {1: next(iter(previous_masks.values()))}
                if not self.add_previous_masks_to_predictor(inference_state, eye_prompt_masks):
                    logger.error(f"Failed to add previous masks for {eye_side} eye")
                    return None
            else:
                logger.error(f"No prompts or previous masks available for {eye_side} eye of segment {segment_idx}")
                return None

            logger.info(f"Propagating masks for {eye_side} eye")
            video_segments = self.propagate_masks(inference_state)

            # Flatten {frame: {1: mask}} to {frame: mask}
            eye_masks = {
                frame_idx: frame_masks[1]
                for frame_idx, frame_masks in video_segments.items()
                if 1 in frame_masks
            }

        except Exception as e:
            logger.error(f"Error processing {eye_side} eye for segment {segment_idx}: {e}")
            return None
        finally:
            self._reset_inference_state(inference_state)
            inference_state = None
            gc.collect()

        # Remove temporary low-res video
        try:
            os.remove(low_res_eye_video_path)
            logger.debug(f"Removed low-res {eye_side} eye video: {low_res_eye_video_path}")
        except Exception as e:
            logger.warning(f"Could not remove low-res {eye_side} eye video: {e}")

        logger.info(f"Successfully processed {eye_side} eye with {len(eye_masks)} frames")
        return eye_masks

    def process_segment_with_separate_eyes(self, segment_info: dict,
                                           left_prompts: Optional[List[Dict[str, Any]]] = None,
                                           right_prompts: Optional[List[Dict[str, Any]]] = None,
                                           previous_left_masks: Optional[Dict[int, np.ndarray]] = None,
                                           previous_right_masks: Optional[Dict[int, np.ndarray]] = None,
                                           inference_scale: float = 0.5,
                                           full_frame_shape: Optional[Tuple[int, int]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a VR180 segment with separate left and right eye processing.

        Each eye is processed only if it has prompts or previous masks; the
        per-eye results are combined back to full-frame masks via
        ``EyeProcessor.combine_eye_masks``.

        Args:
            segment_info: Segment information dictionary
            left_prompts: Optional YOLO prompts for left eye
            right_prompts: Optional YOLO prompts for right eye
            previous_left_masks: Optional previous masks for left eye
            previous_right_masks: Optional previous masks for right eye
            inference_scale: Scale factor for inference
            full_frame_shape: Shape of full VR180 frame (height, width); read from
                ``segment_info['video_file']`` when omitted

        Returns:
            Combined video segments dictionary or None if failed
        """
        if not self.eye_processor:
            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
            return None

        segment_idx = segment_info['index']
        logger.info(f"Processing segment {segment_idx} with separate eye processing")

        if full_frame_shape is None:
            try:
                cap = cv2.VideoCapture(segment_info['video_file'])
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                cap.release()
            except Exception as e:
                logger.error(f"Could not determine frame shape: {e}")
                return None
            # An unopenable video reports 0x0 rather than raising.
            if height <= 0 or width <= 0:
                logger.error(f"Could not determine frame shape: invalid dimensions {width}x{height} for {segment_info['video_file']}")
                return None
            full_frame_shape = (height, width)

        left_masks = None
        if left_prompts or previous_left_masks:
            logger.info(f"Processing left eye for segment {segment_idx}")
            left_masks = self.process_single_eye_segment(
                segment_info, 'left', left_prompts, previous_left_masks, inference_scale
            )

        right_masks = None
        if right_prompts or previous_right_masks:
            logger.info(f"Processing right eye for segment {segment_idx}")
            right_masks = self.process_single_eye_segment(
                segment_info, 'right', right_prompts, previous_right_masks, inference_scale
            )

        if not (left_masks or right_masks):
            logger.warning(f"No masks generated for either eye in segment {segment_idx}")
            return None

        logger.info(f"Combining eye masks for segment {segment_idx}")
        combined_masks = self.eye_processor.combine_eye_masks(
            left_masks, right_masks, full_frame_shape
        )

        # Clean up eye-specific videos to save space
        try:
            for eye_side in ('left', 'right'):
                eye_path = os.path.join(segment_info['directory'], f"{eye_side}_eye_video.mp4")
                if os.path.exists(eye_path):
                    os.remove(eye_path)
                    logger.debug(f"Removed {eye_side} eye video: {eye_path}")
        except Exception as e:
            logger.warning(f"Could not clean up eye videos: {e}")

        logger.info(f"Successfully processed segment {segment_idx} with separate eyes")
        return combined_masks

    # ------------------------------------------------------------------
    # Greenscreen segments
    # ------------------------------------------------------------------

    def _make_greenscreen_frame(self, height: int, width: int, green_color: Sequence[int]) -> np.ndarray:
        """Build one solid-color frame; works with or without an EyeProcessor."""
        if self.eye_processor is not None:
            return self.eye_processor.create_full_greenscreen_frame((height, width, 3), list(green_color))
        # eye_processor is None when separate_eye_processing is disabled.
        return np.full((height, width, 3), green_color, dtype=np.uint8)

    @staticmethod
    def _open_video_writer(path: str, fps: float, width: int, height: int,
                           codecs: Sequence[str] = ('HEVC', 'mp4v')):
        """Open a VideoWriter with the first codec that works, or return None."""
        for codec in codecs:
            writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, (width, height))
            if writer.isOpened():
                return writer
            writer.release()
            logger.warning(f"Video codec '{codec}' unavailable for {path}")
        return None

    def create_greenscreen_segment(self, segment_info: dict,
                                   green_color: Sequence[int] = (0, 255, 0)) -> bool:
        """
        Create a full greenscreen segment when no humans are detected.

        Writes ``output_<index>.mp4`` filled with ``green_color``, an all-black
        ``mask.png``, and the ``output_frames_done`` marker in the segment directory.

        Args:
            segment_info: Segment information dictionary
            green_color: Color values (BGR, as used by OpenCV) for the greenscreen

        Returns:
            True if greenscreen segment was created successfully
        """
        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        logger.info(f"Creating full greenscreen segment {segment_idx}")

        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                logger.error(f"Could not open video: {video_path}")
                return False

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()

            output_video_path = os.path.join(segment_dir, f"output_{segment_idx}.mp4")
            greenscreen_frame = self._make_greenscreen_frame(height, width, green_color)

            # Verify the writer opened; otherwise writes are silently dropped and the
            # segment would be marked done with no video.
            out = self._open_video_writer(output_video_path, fps, width, height)
            if out is None:
                logger.error(f"Could not open video writer for greenscreen segment {segment_idx}: {output_video_path}")
                return False
            try:
                for _ in range(frame_count):
                    out.write(greenscreen_frame)
            finally:
                out.release()

            # Create mask file (empty/black mask since no humans detected)
            mask_output_path = os.path.join(segment_dir, "mask.png")
            black_mask = np.zeros((height, width, 3), dtype=np.uint8)
            cv2.imwrite(mask_output_path, black_mask)

            # Mark segment as completed
            output_done_file = os.path.join(segment_dir, "output_frames_done")
            with open(output_done_file, 'w') as f:
                f.write(f"Greenscreen segment {segment_idx} completed successfully\n")

            logger.info(f"Successfully created greenscreen segment {segment_idx}")
            return True

        except Exception as e:
            logger.error(f"Error creating greenscreen segment {segment_idx}: {e}")
            return False