""" Main VR180 streaming processor - orchestrates all components for true streaming """ import time import gc import json import psutil import torch import numpy as np from pathlib import Path from typing import Dict, Any, Optional, Tuple import warnings from .frame_reader import StreamingFrameReader from .frame_writer import StreamingFrameWriter from .stereo_manager import StereoConsistencyManager from .sam2_streaming_simple import SAM2StreamingProcessor from .detector import PersonDetector from .config import StreamingConfig class VR180StreamingProcessor: """Main processor for streaming VR180 human matting""" def __init__(self, config: StreamingConfig): self.config = config # Initialize components self.frame_reader = None self.frame_writer = None self.stereo_manager = None self.sam2_processor = None self.detector = None # Processing state self.start_time = None self.frames_processed = 0 self.checkpoint_state = {} # Performance monitoring self.process = psutil.Process() self.performance_stats = { 'fps': 0.0, 'avg_frame_time': 0.0, 'peak_memory_gb': 0.0, 'gpu_utilization': 0.0 } def initialize(self) -> None: """Initialize all components""" print("\nšŸš€ Initializing VR180 Streaming Processor") print("=" * 60) # Initialize frame reader start_frame = self._load_checkpoint() if self.config.recovery.auto_resume else 0 self.frame_reader = StreamingFrameReader( self.config.input.video_path, start_frame=start_frame ) # Get video info video_info = self.frame_reader.get_video_info() # Apply scaling to dimensions scale = self.config.processing.scale_factor output_width = int(video_info['width'] * scale) output_height = int(video_info['height'] * scale) # Initialize frame writer self.frame_writer = StreamingFrameWriter( output_path=self.config.output.path, width=output_width, height=output_height, fps=video_info['fps'], audio_source=self.config.input.video_path if self.config.output.maintain_sbs else None, video_codec=self.config.output.video_codec, quality_preset=self.config.output.quality_preset, crf=self.config.output.crf ) # Initialize stereo manager self.stereo_manager = StereoConsistencyManager(self.config.to_dict()) # Initialize SAM2 processor self.sam2_processor = SAM2StreamingProcessor(self.config.to_dict()) # Initialize detector self.detector = PersonDetector(self.config.to_dict()) self.detector.warmup((output_height // 2, output_width // 2, 3)) # Warmup with single eye dims print("\nāœ… All components initialized successfully!") print(f" Input: {video_info['width']}x{video_info['height']} @ {video_info['fps']}fps") print(f" Output: {output_width}x{output_height} @ {video_info['fps']}fps") print(f" Scale factor: {scale}") print(f" Starting from frame: {start_frame}") print("=" * 60 + "\n") def process_video(self) -> None: """Main processing loop""" try: self.initialize() self.start_time = time.time() # Simple SAM2 initialization (no complex state management needed) print("šŸŽÆ SAM2 streaming processor ready...") # Process first frame to establish detections print("šŸ” Processing first frame for initial detection...") if not self._initialize_tracking(): raise RuntimeError("Failed to initialize tracking - no persons detected") # Main streaming loop print("\nšŸŽ¬ Starting streaming processing loop...") self._streaming_loop() except KeyboardInterrupt: print("\nāš ļø Processing interrupted by user") self._save_checkpoint() except Exception as e: print(f"\nāŒ Error during processing: {e}") self._save_checkpoint() raise finally: self._finalize() def _initialize_tracking(self) -> bool: """Initialize tracking with first frame detection""" # Read and process first frame first_frame = self.frame_reader.read_frame() if first_frame is None: raise RuntimeError("Cannot read first frame") # Scale frame if needed if self.config.processing.scale_factor != 1.0: first_frame = self._scale_frame(first_frame) # Split into eyes left_eye, right_eye = self.stereo_manager.split_frame(first_frame) # Detect on master eye master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye detections = self.detector.detect_persons(master_eye) if not detections: warnings.warn("No persons detected in first frame") return False print(f" Detected {len(detections)} person(s) in first frame") # Process with simple SAM2 approach left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, 0) # Transfer detections to right eye transferred_detections = self.stereo_manager.transfer_detections( detections, 'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left' ) right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, 0) # Apply masks and write processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks) self.frame_writer.write_frame(processed_frame) self.frames_processed = 1 return True def _streaming_loop(self) -> None: """Main streaming processing loop""" frame_times = [] last_log_time = time.time() # Start from frame 1 (already processed frame 0) for frame_idx, frame in enumerate(self.frame_reader, start=1): frame_start_time = time.time() # Scale frame if needed if self.config.processing.scale_factor != 1.0: frame = self._scale_frame(frame) # Split into eyes left_eye, right_eye = self.stereo_manager.split_frame(frame) # Check if we need to run detection for continuous correction detections = [] if (self.config.matting.continuous_correction and frame_idx % self.config.matting.correction_interval == 0): print(f"\nšŸ”„ Running YOLO detection for correction at frame {frame_idx}") master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye detections = self.detector.detect_persons(master_eye) if detections: print(f" Detected {len(detections)} person(s) for correction") else: print(f" No persons detected for correction") # Process frames (with detections if this is a correction frame) left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, frame_idx) # For right eye, transfer detections if we have them if detections: transferred_detections = self.stereo_manager.transfer_detections( detections, 'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left' ) right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, frame_idx) else: right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx) # Validate stereo consistency right_masks = self.stereo_manager.validate_masks( left_masks, right_masks, frame_idx ) # Apply masks and write frame processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks) self.frame_writer.write_frame(processed_frame) # Update stats frame_time = time.time() - frame_start_time frame_times.append(frame_time) self.frames_processed += 1 # Periodic logging and cleanup if frame_idx % self.config.performance.log_interval == 0: self._log_progress(frame_idx, frame_times) frame_times = frame_times[-100:] # Keep only recent times # Checkpoint saving if (self.config.recovery.enable_checkpoints and frame_idx % self.config.recovery.checkpoint_interval == 0): self._save_checkpoint() # Memory monitoring and cleanup if frame_idx % 50 == 0: self._monitor_and_cleanup() # Check max frames limit if (self.config.input.max_frames is not None and self.frames_processed >= self.config.input.max_frames): print(f"\nāœ… Reached max frames limit ({self.config.input.max_frames})") break def _scale_frame(self, frame: np.ndarray) -> np.ndarray: """Scale frame according to configuration""" scale = self.config.processing.scale_factor if scale == 1.0: return frame new_width = int(frame.shape[1] * scale) new_height = int(frame.shape[0] * scale) import cv2 return cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA) def _apply_masks_to_frame(self, frame: np.ndarray, left_masks: np.ndarray, right_masks: np.ndarray) -> np.ndarray: """Apply masks to frame and combine results""" # Split frame left_eye, right_eye = self.stereo_manager.split_frame(frame) # Apply masks to each eye left_processed = self.sam2_processor.apply_mask_to_frame( left_eye, left_masks, output_format=self.config.output.format, background_color=self.config.output.background_color ) right_processed = self.sam2_processor.apply_mask_to_frame( right_eye, right_masks, output_format=self.config.output.format, background_color=self.config.output.background_color ) # Combine back to SBS if self.config.output.maintain_sbs: return self.stereo_manager.combine_frames(left_processed, right_processed) else: # Return just left eye for non-SBS output return left_processed def _apply_continuous_correction(self, left_eye: np.ndarray, right_eye: np.ndarray, frame_idx: int) -> None: """Apply continuous correction to maintain tracking accuracy""" print(f"\nšŸ”„ Applying continuous correction at frame {frame_idx}") # Detect on master eye and add fresh detections master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye detections = self.detector.detect_persons(master_eye) if detections: print(f" Adding {len(detections)} fresh detection(s) for correction") # Add fresh detections to help correct drift self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx) # Transfer corrections to slave eye # Note: This is simplified - actual implementation would transfer the refined prompts def _log_progress(self, frame_idx: int, frame_times: list) -> None: """Log processing progress""" elapsed = time.time() - self.start_time avg_frame_time = np.mean(frame_times) if frame_times else 0 fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0 # Memory stats memory_info = self.process.memory_info() memory_gb = memory_info.rss / (1024**3) # GPU stats if available gpu_stats = self.sam2_processor.get_memory_usage() # Progress percentage progress = self.frame_reader.get_progress() print(f"\nšŸ“Š Progress: Frame {frame_idx} ({progress:.1f}%)") print(f" Speed: {fps:.1f} FPS (avg: {avg_frame_time*1000:.1f}ms/frame)") print(f" Memory: {memory_gb:.1f}GB RAM", end="") if 'cuda_allocated_gb' in gpu_stats: print(f", {gpu_stats['cuda_allocated_gb']:.1f}GB VRAM") else: print() print(f" Time elapsed: {elapsed/60:.1f} minutes") # Update performance stats self.performance_stats['fps'] = fps self.performance_stats['avg_frame_time'] = avg_frame_time self.performance_stats['peak_memory_gb'] = max( self.performance_stats['peak_memory_gb'], memory_gb ) def _monitor_and_cleanup(self) -> None: """Monitor memory and perform cleanup if needed""" memory_info = self.process.memory_info() memory_gb = memory_info.rss / (1024**3) # Check if approaching limits if memory_gb > self.config.hardware.max_ram_gb * 0.8: print(f"\nāš ļø High memory usage ({memory_gb:.1f}GB) - running cleanup") gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() def _save_checkpoint(self) -> None: """Save processing checkpoint""" if not self.config.recovery.enable_checkpoints: return checkpoint_dir = Path(self.config.recovery.checkpoint_dir) checkpoint_dir.mkdir(exist_ok=True) checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json" checkpoint_data = { 'frame_index': self.frames_processed, 'timestamp': time.time(), 'input_video': self.config.input.video_path, 'output_video': self.config.output.path, 'config': self.config.to_dict() } with open(checkpoint_file, 'w') as f: json.dump(checkpoint_data, f, indent=2) print(f"šŸ’¾ Checkpoint saved at frame {self.frames_processed}") def _load_checkpoint(self) -> int: """Load checkpoint if exists""" checkpoint_dir = Path(self.config.recovery.checkpoint_dir) checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json" if checkpoint_file.exists(): with open(checkpoint_file, 'r') as f: checkpoint_data = json.load(f) if checkpoint_data['input_video'] == self.config.input.video_path: start_frame = checkpoint_data['frame_index'] print(f"šŸ“‚ Found checkpoint - resuming from frame {start_frame}") return start_frame return 0 def _finalize(self) -> None: """Finalize processing and cleanup""" print("\nšŸ Finalizing processing...") # Close components if self.frame_writer: self.frame_writer.close() if self.frame_reader: self.frame_reader.close() if self.sam2_processor: self.sam2_processor.cleanup() # Print final statistics if self.start_time: total_time = time.time() - self.start_time print(f"\nšŸ“ˆ Final Statistics:") print(f" Total frames: {self.frames_processed}") print(f" Total time: {total_time/60:.1f} minutes") print(f" Average FPS: {self.frames_processed/total_time:.1f}") print(f" Peak memory: {self.performance_stats['peak_memory_gb']:.1f}GB") # Stereo consistency stats stereo_stats = self.stereo_manager.get_stats() print(f"\nšŸ‘€ Stereo Consistency:") print(f" Corrections applied: {stereo_stats['corrections_applied']}") print(f" Correction rate: {stereo_stats['correction_rate']*100:.1f}%") print("\nāœ… Processing complete!")