import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Generator
from pathlib import Path
from fractions import Fraction
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings

from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager


class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        self._initialize_models()

    def _initialize_models(self):
        """Initialize YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream is None:
                raise ValueError("No video stream found")

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                # r_frame_rate is a string such as "30000/1001"; parse it safely instead of eval()
                'fps': float(Fraction(video_stream['r_frame_rate'])),
                'duration': float(video_stream.get('duration', 0)),
                'nb_frames': int(video_stream.get('nb_frames', 0)),
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            # Some containers do not report nb_frames; fall back to duration * fps
            if self.total_frames <= 0 and self.video_info['duration'] > 0:
                self.total_frames = int(self.video_info['duration'] * self.fps)

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")

    def read_video_frames(self, video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
{len(frames)} frames") return frames def calculate_optimal_chunking(self) -> Tuple[int, int]: """ Calculate optimal chunk size and overlap based on memory constraints Returns: Tuple of (chunk_size, overlap_frames) """ if self.config.processing.chunk_size > 0: return self.config.processing.chunk_size, self.config.processing.overlap_frames # Calculate based on memory constraints scaled_height = int(self.frame_height * self.config.processing.scale_factor) scaled_width = int(self.frame_width * self.config.processing.scale_factor) optimal_chunk = self.memory_manager.get_optimal_chunk_size( scaled_height, scaled_width, fp16=self.config.matting.fp16 ) overlap = min(60, optimal_chunk // 10) # 10% overlap, max 60 frames print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap") return optimal_chunk, overlap def process_chunk(self, frames: List[np.ndarray], chunk_idx: int = 0) -> List[np.ndarray]: """ Process a chunk of frames through the matting pipeline Args: frames: List of frames to process chunk_idx: Chunk index for logging Returns: List of matted frames """ print(f"Processing chunk {chunk_idx} ({len(frames)} frames)") with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"): # Initialize SAM2 with frames self.sam2_model.init_video_state(frames) # Detect persons in first frame first_frame = frames[0] detections = self.detector.detect_persons(first_frame) if not detections: warnings.warn(f"No persons detected in chunk {chunk_idx}") return self._create_empty_masks(frames) print(f"Detected {len(detections)} persons in first frame") # Convert detections to SAM2 prompts box_prompts, labels = self.detector.convert_to_sam_prompts(detections) # Add prompts to SAM2 object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels) print(f"Added prompts for {len(object_ids)} objects") # Propagate masks through chunk video_segments = self.sam2_model.propagate_masks( start_frame=0, max_frames=len(frames) ) # Apply masks to frames matted_frames = [] for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")): if frame_idx in video_segments: frame_masks = video_segments[frame_idx] combined_mask = self.sam2_model.get_combined_mask(frame_masks) matted_frame = self.sam2_model.apply_mask_to_frame( frame, combined_mask, output_format=self.config.output.format, background_color=self.config.output.background_color ) else: # No mask for this frame matted_frame = self._create_empty_mask_frame(frame) matted_frames.append(matted_frame) # Cleanup SAM2 state self.sam2_model.cleanup() return matted_frames def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]: """Create empty masks when no persons detected""" empty_frames = [] for frame in frames: empty_frame = self._create_empty_mask_frame(frame) empty_frames.append(empty_frame) return empty_frames def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray: """Create frame with empty mask (all background)""" if self.config.output.format == "alpha": # Transparent output output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8) return output else: # Green screen background return np.full_like(frame, self.config.output.background_color, dtype=np.uint8) def merge_overlapping_chunks(self, chunk_results: List[List[np.ndarray]], overlap_frames: int) -> List[np.ndarray]: """ Merge overlapping chunks with blending in overlap regions Args: chunk_results: List of chunk results overlap_frames: Number of overlapping frames Returns: Merged frame sequence """ if len(chunk_results) == 1: 
        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear blending
                        alpha = i / overlap_frames
                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]

                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)

    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on output format
        if self.config.output.format == "alpha":
            # Use PNG sequence for alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGR(A) channel order, so keep frames in BGRA
            if frame.shape[2] == 4:
                # Already carries an alpha channel
                frame_bgra = frame
            else:
                # Add an opaque alpha channel
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as MP4 video"""
        if not frames:
            return

        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))

        for frame in tqdm(frames, desc="Writing video"):
            if frame.shape[2] == 4:
                # Drop the alpha channel (frames are BGRA)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

            writer.write(frame)

        writer.release()
        print(f"Saved video to {output_path}")

    def process_video(self) -> None:
        """Main video processing pipeline"""
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Process video in chunks
        chunk_results = []
        # Guard against a non-positive step if the configured overlap exceeds the chunk size
        step = max(1, chunk_size - overlap_frames)

        for start_frame in range(0, self.total_frames, step):
            end_frame = min(start_frame + chunk_size, self.total_frames)
            frames_to_read = end_frame - start_frame

            chunk_idx = len(chunk_results)
            print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

            # Read chunk frames
            frames = self.read_video_frames(
                self.config.input.video_path,
                start_frame=start_frame,
                num_frames=frames_to_read,
                scale_factor=self.config.processing.scale_factor
            )

            # Process chunk
            matted_frames = self.process_chunk(frames, chunk_idx)
            chunk_results.append(matted_frames)

            # Memory cleanup
            self.memory_manager.cleanup_memory()

            if self.memory_manager.should_emergency_cleanup():
                self.memory_manager.emergency_cleanup()

        # Merge chunks if multiple
        print("\nMerging chunks...")
        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)

        # Save results
        print(f"Saving {len(final_frames)} processed frames...")
        self.save_video(final_frames, self.config.output.path)

        # Print final memory report
        self.memory_manager.print_memory_report()

        print("Video processing completed!")
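

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). How a VR180Config instance is
# actually constructed (YAML loader, CLI parser, defaults) is an assumption
# here; substitute the project's real config entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VR180Config()  # assumed default-constructible; replace with the real loader
    processor = VideoProcessor(config)
    processor.process_video()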