import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from fractions import Fraction
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings
import time
import subprocess
import gc
import psutil
import os
import sys

from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager


class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        # Processing statistics
        self.processing_stats = {
            'start_time': None,
            'end_time': None,
            'total_duration': 0,
            'processing_fps': 0,
            'chunks_processed': 0,
            'frames_processed': 0
        }

        self._initialize_models()

    def _get_process_memory_info(self) -> Dict[str, float]:
        """Get detailed memory usage for current process and children"""
        current_process = psutil.Process(os.getpid())

        # Get memory info for current process
        memory_info = current_process.memory_info()
        current_rss = memory_info.rss / 1024**3  # Convert to GB
        current_vms = memory_info.vms / 1024**3  # Virtual memory

        # Get memory info for all children
        children_rss = 0
        children_vms = 0
        child_count = 0
        try:
            for child in current_process.children(recursive=True):
                try:
                    child_memory = child.memory_info()
                    children_rss += child_memory.rss / 1024**3
                    children_vms += child_memory.vms / 1024**3
                    child_count += 1
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    pass
        except psutil.NoSuchProcess:
            pass

        # System memory info
        system_memory = psutil.virtual_memory()
        system_total = system_memory.total / 1024**3
        system_available = system_memory.available / 1024**3
        system_used = system_memory.used / 1024**3
        system_percent = system_memory.percent

        return {
            'process_rss_gb': current_rss,
            'process_vms_gb': current_vms,
            'children_rss_gb': children_rss,
            'children_vms_gb': children_vms,
            'total_process_gb': current_rss + children_rss,
            'child_count': child_count,
            'system_total_gb': system_total,
            'system_used_gb': system_used,
            'system_available_gb': system_available,
            'system_percent': system_percent
        }
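    # Note (explanatory): RSS is the resident set size, i.e. physical RAM
    # actually mapped into the process, while VMS is the total virtual
    # address space, which is typically much larger and a poor OOM
    # indicator. The reports below therefore track RSS for the process tree.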
    def _print_memory_step(self, step_name: str):
        """Print memory usage for a specific processing step"""
        memory_info = self._get_process_memory_info()

        print(f"\n📊 MEMORY: {step_name}")
        print(f"   Process RSS: {memory_info['process_rss_gb']:.2f} GB")
        if memory_info['children_rss_gb'] > 0:
            print(f"   Children RSS: {memory_info['children_rss_gb']:.2f} GB ({memory_info['child_count']} processes)")
            print(f"   Total Process: {memory_info['total_process_gb']:.2f} GB")
        print(f"   System: {memory_info['system_used_gb']:.1f}/{memory_info['system_total_gb']:.1f} GB ({memory_info['system_percent']:.1f}%)")
        print(f"   Available: {memory_info['system_available_gb']:.1f} GB")

    def _aggressive_memory_cleanup(self, step_name: str = ""):
        """Perform aggressive memory cleanup and report before/after"""
        if step_name:
            print(f"\n🧹 CLEANUP: Before {step_name}")
        before_info = self._get_process_memory_info()
        before_rss = before_info['total_process_gb']

        # Multiple rounds of garbage collection
        for _ in range(3):
            gc.collect()

        # Clear torch CUDA cache if available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except ImportError:
            pass

        # Clear OpenCV internal caches
        try:
            cv2.setUseOptimized(False)
            cv2.setUseOptimized(True)
        except Exception:
            pass

        # Clear CuPy memory pools if available (public API only)
        try:
            import cupy as cp
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
        except ImportError:
            pass
        except Exception as e:
            print(f"   Warning: Could not clear CuPy cache: {e}")

        # Ask glibc to return freed heap pages back to the OS (Linux only)
        if sys.platform == 'linux':
            try:
                import ctypes
                libc = ctypes.CDLL("libc.so.6")
                libc.malloc_trim(0)
            except Exception as e:
                print(f"   Warning: Could not trim memory: {e}")

        # Brief pause to allow cleanup
        time.sleep(0.1)

        after_info = self._get_process_memory_info()
        after_rss = after_info['total_process_gb']
        freed_memory = before_rss - after_rss

        if step_name:
            print(f"   Before: {before_rss:.2f} GB → After: {after_rss:.2f} GB")
            print(f"   Freed: {freed_memory:.2f} GB")

    def _initialize_models(self):
        """Initialize YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                model_cfg=self.config.matting.sam2_model_cfg,
                checkpoint_path=self.config.matting.sam2_checkpoint,
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream is None:
                raise ValueError("No video stream found")

            # r_frame_rate is a fraction string like "30000/1001";
            # parse it safely instead of eval()-ing untrusted metadata
            fps = float(Fraction(video_stream['r_frame_rate']))
            duration = float(video_stream.get('duration', 0))
            # nb_frames is absent in some containers; estimate from duration
            nb_frames = int(video_stream.get('nb_frames', 0)) or int(duration * fps)

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                'fps': fps,
                'duration': duration,
                'nb_frames': nb_frames,
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")
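    # Illustrative example (values are hypothetical): for a two-minute
    # 5760x2880 VR180 clip at 59.94 fps, load_video_info() would return
    # something like:
    #   {'width': 5760, 'height': 2880, 'fps': 59.94, 'duration': 120.0,
    #    'nb_frames': 7192, 'codec': 'hevc', 'pix_fmt': 'yuv420p'}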
    def read_video_frames(self,
                          video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height),
                                       interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
        return frames

    def calculate_optimal_chunking(self) -> Tuple[int, int]:
        """
        Calculate optimal chunk size and overlap based on memory constraints

        Returns:
            Tuple of (chunk_size, overlap_frames)
        """
        if self.config.processing.chunk_size > 0:
            return self.config.processing.chunk_size, self.config.processing.overlap_frames

        # Calculate based on memory constraints
        scaled_height = int(self.frame_height * self.config.processing.scale_factor)
        scaled_width = int(self.frame_width * self.config.processing.scale_factor)

        optimal_chunk = self.memory_manager.get_optimal_chunk_size(
            scaled_height, scaled_width,
            fp16=self.config.matting.fp16
        )

        overlap = min(60, optimal_chunk // 10)  # 10% overlap, max 60 frames

        print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap")
        return optimal_chunk, overlap
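    # Worked example (hypothetical numbers): if get_optimal_chunk_size()
    # returns 480 frames for the scaled resolution, the overlap becomes
    # min(60, 480 // 10) = 48 frames, so consecutive chunks share 48 frames
    # and the read stride in process_video() is 480 - 48 = 432 frames.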
    def process_chunk(self, frames: List[np.ndarray], chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk of frames through the matting pipeline

        Args:
            frames: List of frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames
        """
        print(f"Processing chunk {chunk_idx} ({len(frames)} frames)")

        with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"):
            # Initialize SAM2 with frames
            self.sam2_model.init_video_state(frames)

            # Detect persons in first frame
            first_frame = frames[0]
            detections = self.detector.detect_persons(first_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                return self._create_empty_masks(frames)

            print(f"Detected {len(detections)} persons in first frame")

            # Convert detections to SAM2 prompts
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through chunk
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(frames)
            )

            # Apply masks to frames
            matted_frames = []
            for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]
                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)

                    matted_frame = self.sam2_model.apply_mask_to_frame(
                        frame, combined_mask,
                        output_format=self.config.output.format,
                        background_color=self.config.output.background_color
                    )
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

            return matted_frames

    def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Create empty masks when no persons detected"""
        empty_frames = []
        for frame in frames:
            empty_frame = self._create_empty_mask_frame(frame)
            empty_frames.append(empty_frame)
        return empty_frames

    def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray:
        """Create frame with empty mask (all background)"""
        if self.config.output.format == "alpha":
            # Transparent output
            return np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
        else:
            # Green screen background
            return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

    def merge_chunks_streaming(self,
                               chunk_files: List[Path],
                               output_path: str,
                               overlap_frames: int = 0,
                               audio_source: Optional[str] = None) -> None:
        """
        Merge processed chunks using streaming approach (no memory accumulation)

        Args:
            chunk_files: List of chunk result files (.npz)
            output_path: Final output video path
            overlap_frames: Number of overlapping frames
            audio_source: Audio source file for final video
        """
        from .streaming_video_writer import StreamingVideoWriter

        if not chunk_files:
            raise ValueError("No chunk files to merge")

        print(f"🎬 Streaming merge: {len(chunk_files)} chunks → {output_path}")

        # Initialize streaming writer
        writer = StreamingVideoWriter(
            output_path=output_path,
            fps=self.video_info['fps'],
            audio_source=audio_source
        )

        try:
            # Process each chunk without accumulation
            for i, chunk_file in enumerate(chunk_files):
                print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")

                # Load chunk (this is the only copy in memory); splitting the
                # (N, H, W, C) array along axis 0 keeps each frame an ndarray
                # (.tolist() would build slow nested Python lists instead)
                chunk_data = np.load(str(chunk_file))
                frames = list(chunk_data['frames'])
                chunk_data.close()

                # Write chunk with streaming writer
                writer.write_chunk(
                    frames=frames,
                    chunk_index=i,
                    overlap_frames=overlap_frames if i > 0 else 0,
                    blend_with_previous=(i > 0 and overlap_frames > 0)
                )

                # Immediately free memory
                del frames, chunk_data

                # Delete chunk file to free disk space
                try:
                    chunk_file.unlink()
                    print(f"   🗑️ Deleted {chunk_file.name}")
                except Exception as e:
                    print(f"   ⚠️ Could not delete {chunk_file.name}: {e}")

                # Aggressive cleanup every chunk
                self._aggressive_memory_cleanup(f"After processing chunk {i}")

            # Finalize the video
            writer.finalize()

        except Exception as e:
            print(f"❌ Streaming merge failed: {e}")
            writer.cleanup()
            raise

        print(f"✅ Streaming merge complete: {output_path}")
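    # Chunk file layout note: each chunk_XXXX.npz written by process_video()
    # holds a single compressed array under the key 'frames' with shape
    # (num_frames, height, width, channels), so merge_chunks_streaming()
    # only ever has one decompressed chunk resident at a time.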
    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Legacy merge method - DEPRECATED due to memory accumulation

        Use merge_chunks_streaming() instead for memory efficiency
        """
        warnings.warn(
            "merge_overlapping_chunks() is deprecated due to memory accumulation. "
            "Use merge_chunks_streaming()",
            DeprecationWarning, stacklevel=2
        )

        if len(chunk_results) == 1:
            return chunk_results[0]

        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear blending
                        alpha = i / overlap_frames
                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]

                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)
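    # Worked example: with overlap_frames = 4, the blend weights for the
    # incoming chunk ramp up as alpha = 0/4, 1/4, 2/4, 3/4 across the
    # overlap region, so the previous chunk fades out linearly while the
    # new chunk fades in, hiding per-chunk mask discontinuities.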
    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on output format
        if self.config.output.format == "alpha":
            # Use PNG sequence for alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGRA channel order for PNGs with alpha,
            # so keep 4-channel frames as-is and add alpha to BGR frames
            if frame.shape[2] == 4:
                frame_bgra = frame
            else:
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as MP4 video with audio preservation"""
        if not frames:
            return

        output_path = Path(output_path)
        temp_frames_dir = output_path.parent / f"temp_frames_{output_path.stem}"
        temp_frames_dir.mkdir(exist_ok=True)

        try:
            # Save frames as images
            print("Saving frames as images...")
            for i, frame in enumerate(tqdm(frames, desc="Saving frames")):
                if frame.shape[2] == 4:
                    # Drop the alpha channel (frames are BGRA in OpenCV order)
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

                frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
                cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])

            # Create video with ffmpeg
            self._create_video_with_ffmpeg(temp_frames_dir, output_path, len(frames))

        finally:
            # Cleanup temporary frames
            if temp_frames_dir.exists():
                shutil.rmtree(temp_frames_dir)

    def _create_video_with_ffmpeg(self, frames_dir: Path, output_path: Path, frame_count: int):
        """Create video using ffmpeg with audio preservation"""
        frame_pattern = str(frames_dir / "frame_%06d.jpg")

        if self.config.output.preserve_audio:
            # Create video with audio from input
            cmd = [
                'ffmpeg', '-y',
                '-framerate', str(self.fps),
                '-i', frame_pattern,
                '-i', str(self.config.input.video_path),  # Input video for audio
                '-c:v', 'h264_nvenc',  # Try GPU encoding first
                '-preset', 'fast',
                '-cq', '18',
                '-c:a', 'copy',  # Copy audio without re-encoding
                '-map', '0:v:0',  # Map video from frames
                '-map', '1:a:0',  # Map audio from input video
                '-shortest',  # Match shortest stream duration
                '-pix_fmt', 'yuv420p',
                str(output_path)
            ]
        else:
            # Create video without audio
            cmd = [
                'ffmpeg', '-y',
                '-framerate', str(self.fps),
                '-i', frame_pattern,
                '-c:v', 'h264_nvenc',
                '-preset', 'fast',
                '-cq', '18',
                '-pix_fmt', 'yuv420p',
                str(output_path)
            ]

        print("Creating video with ffmpeg...")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            # Try CPU encoding as fallback
            print("GPU encoding failed, trying CPU encoding...")
            cmd[cmd.index('h264_nvenc')] = 'libx264'
            cmd[cmd.index('-cq')] = '-crf'  # Change quality parameter for CPU
            result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"FFmpeg stdout: {result.stdout}")
            print(f"FFmpeg stderr: {result.stderr}")
            raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")

        # Verify frame count if sync verification is enabled
        if self.config.output.verify_sync:
            self._verify_frame_count(output_path, frame_count)

        print(f"Saved video to {output_path}")

    def _verify_frame_count(self, video_path: Path, expected_frames: int):
        """Verify output video has correct frame count"""
        try:
            probe = ffmpeg.probe(str(video_path))
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream:
                actual_frames = int(video_stream.get('nb_frames', 0))
                if actual_frames != expected_frames:
                    print(f"⚠️ Frame count mismatch: expected {expected_frames}, got {actual_frames}")
                else:
                    print(f"✅ Frame count verified: {actual_frames} frames")

        except Exception as e:
            print(f"⚠️ Could not verify frame count: {e}")

    def process_video(self) -> None:
        """Main video processing pipeline"""
        self.processing_stats['start_time'] = time.time()

        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()
        # Guard against a non-positive read stride when overlap >= chunk size
        stride = max(1, chunk_size - overlap_frames)

        # Process video in chunks
        chunk_files = []  # Store file paths instead of frame data
        temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))

        try:
            for start_frame in range(0, self.total_frames, stride):
                end_frame = min(start_frame + chunk_size, self.total_frames)
                frames_to_read = end_frame - start_frame
                chunk_idx = len(chunk_files)

                print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

                # Read chunk frames
                frames = self.read_video_frames(
                    self.config.input.video_path,
                    start_frame=start_frame,
                    num_frames=frames_to_read,
                    scale_factor=self.config.processing.scale_factor
                )

                # Process chunk
                matted_frames = self.process_chunk(frames, chunk_idx)

                # Save chunk to disk immediately to free memory
                chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
                print(f"Saving chunk {chunk_idx} to disk...")
                np.savez_compressed(str(chunk_path), frames=matted_frames)
                chunk_files.append(chunk_path)

                # Free the frames from memory immediately
                del matted_frames
                del frames

                # Update statistics
                self.processing_stats['chunks_processed'] += 1
                self.processing_stats['frames_processed'] += frames_to_read

                # Aggressive memory cleanup after each chunk
                self._aggressive_memory_cleanup(f"chunk {chunk_idx} completion")

                # Also use memory manager cleanup
                self.memory_manager.cleanup_memory()
                if self.memory_manager.should_emergency_cleanup():
                    self.memory_manager.emergency_cleanup()
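            # At this point only the list of .npz paths is held in memory;
            # every processed frame lives on disk, which is what keeps peak
            # RSS roughly constant regardless of video length.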
            # Use streaming merge to avoid memory accumulation (fixes OOM)
            print("\n🎬 Using streaming merge (no memory accumulation)...")

            # Determine audio source for final video
            audio_source = None
            if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
                audio_source = self.config.input.video_path

            # Stream merge chunks directly to output (no memory accumulation)
            self.merge_chunks_streaming(
                chunk_files=chunk_files,
                output_path=self.config.output.path,
                overlap_frames=overlap_frames,
                audio_source=audio_source
            )

            print("✅ Streaming merge complete - no memory accumulation!")

            # Calculate final statistics
            self.processing_stats['end_time'] = time.time()
            self.processing_stats['total_duration'] = (
                self.processing_stats['end_time'] - self.processing_stats['start_time']
            )
            if self.processing_stats['total_duration'] > 0:
                self.processing_stats['processing_fps'] = (
                    self.processing_stats['frames_processed'] / self.processing_stats['total_duration']
                )

            # Print processing statistics
            self._print_processing_statistics()

            # Print final memory report
            self.memory_manager.print_memory_report()

            print("Video processing completed!")

        finally:
            # Clean up temporary chunk files
            if temp_chunk_dir.exists():
                print("Cleaning up temporary chunk files...")
                shutil.rmtree(temp_chunk_dir)

    def _print_processing_statistics(self):
        """Print detailed processing statistics"""
        stats = self.processing_stats
        video_duration = self.total_frames / self.fps if self.fps > 0 else 0

        print("\n" + "=" * 60)
        print("PROCESSING STATISTICS")
        print("=" * 60)
        print(f"Input video duration: {video_duration:.1f} seconds ({self.total_frames} frames @ {self.fps:.2f} fps)")
        print(f"Total processing time: {stats['total_duration']:.1f} seconds")
        print(f"Processing speed: {stats['processing_fps']:.2f} fps")
        if stats['processing_fps'] > 0:
            print(f"Slowdown factor: {self.fps / stats['processing_fps']:.1f}x slower than realtime")
        print(f"Chunks processed: {stats['chunks_processed']}")
        print(f"Frames processed: {stats['frames_processed']}")

        if video_duration > 0 and stats['total_duration'] > 0:
            efficiency = video_duration / stats['total_duration']
            print(f"Processing efficiency: {efficiency:.3f} (1.0 = realtime)")

            # Estimate time for different video lengths
            print("\nEstimated processing times:")
            print(f"  5 minutes: {(5 * 60) / efficiency / 60:.1f} minutes")
            print(f"  30 minutes: {(30 * 60) / efficiency / 60:.1f} minutes")
            print(f"  1 hour: {(60 * 60) / efficiency / 60:.1f} minutes")

        print("=" * 60 + "\n")
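

# --- Usage sketch (illustrative) ---
# A minimal driver for this pipeline. The package name `vr180_matting` and
# the `VR180Config.from_yaml` constructor are assumptions inferred from the
# relative imports above; substitute however VR180Config is actually built
# in this package.
#
#     from vr180_matting.config import VR180Config
#     from vr180_matting.video_processor import VideoProcessor
#
#     config = VR180Config.from_yaml("config.yaml")   # hypothetical loader
#     config.input.video_path = "input_vr180.mp4"
#     config.output.path = "output_matted.mp4"
#
#     processor = VideoProcessor(config)
#     processor.process_video()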