diff --git a/vr180_matting/video_processor.py b/vr180_matting/video_processor.py
index 94759cd..a99e91f 100644
--- a/vr180_matting/video_processor.py
+++ b/vr180_matting/video_processor.py
@@ -281,6 +281,116 @@ class VideoProcessor:
         print(f"Read {len(frames)} frames")
         return frames
 
+    def read_video_frames_dual_resolution(self,
+                                          video_path: str,
+                                          start_frame: int = 0,
+                                          num_frames: Optional[int] = None,
+                                          scale_factor: float = 0.5) -> Dict[str, List[np.ndarray]]:
+        """
+        Read video frames at both original and scaled resolution for dual-resolution processing
+
+        Args:
+            video_path: Path to video file
+            start_frame: Starting frame index
+            num_frames: Number of frames to read (None for all)
+            scale_factor: Scaling factor for inference frames
+
+        Returns:
+            Dictionary with 'original' and 'scaled' frame lists
+        """
+        cap = cv2.VideoCapture(video_path)
+
+        if not cap.isOpened():
+            raise RuntimeError(f"Could not open video file: {video_path}")
+
+        # Set starting position
+        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+
+        original_frames = []
+        scaled_frames = []
+        frame_count = 0
+
+        # Progress tracking
+        total_to_read = num_frames if num_frames else self.total_frames - start_frame
+
+        with tqdm(total=total_to_read, desc="Reading dual-resolution frames") as pbar:
+            while True:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                # Store original frame
+                original_frames.append(frame.copy())
+
+                # Create scaled frame for inference
+                if scale_factor != 1.0:
+                    new_width = int(frame.shape[1] * scale_factor)
+                    new_height = int(frame.shape[0] * scale_factor)
+                    scaled_frame = cv2.resize(frame, (new_width, new_height),
+                                              interpolation=cv2.INTER_AREA)
+                else:
+                    scaled_frame = frame.copy()
+
+                scaled_frames.append(scaled_frame)
+                frame_count += 1
+                pbar.update(1)
+
+                if num_frames is not None and frame_count >= num_frames:
+                    break
+
+        cap.release()
+
+        print(f"Loaded {len(original_frames)} frames:")
+        print(f"  Original: {original_frames[0].shape} per frame")
+        print(f"  Scaled: {scaled_frames[0].shape} per frame (scale_factor={scale_factor})")
+
+        return {
+            'original': original_frames,
+            'scaled': scaled_frames
+        }
+
+    def upscale_mask(self, mask: np.ndarray, target_shape: tuple, method: str = 'cubic') -> np.ndarray:
+        """
+        Upscale a mask from inference resolution to original resolution
+
+        Args:
+            mask: Low-resolution mask (H, W)
+            target_shape: Target shape (H, W) for upscaling
+            method: Upscaling method ('nearest', 'cubic', 'area')
+
+        Returns:
+            Upscaled mask at target resolution
+        """
+        if mask.shape[:2] == target_shape[:2]:
+            return mask  # Already correct size
+
+        # Ensure mask is 2D
+        if mask.ndim == 3:
+            mask = mask.squeeze()
+
+        # Choose interpolation method
+        if method == 'nearest':
+            interpolation = cv2.INTER_NEAREST  # Crisp edges, good for sharp subjects
+        elif method == 'cubic':
+            interpolation = cv2.INTER_CUBIC  # Smooth edges, good for most content
+        elif method == 'area':
+            interpolation = cv2.INTER_AREA  # Good for downscaling, not upscaling
+        else:
+            interpolation = cv2.INTER_CUBIC  # Default to cubic
+
+        # Upscale mask
+        upscaled_mask = cv2.resize(
+            mask.astype(np.uint8),
+            (target_shape[1], target_shape[0]),  # (width, height) for cv2.resize
+            interpolation=interpolation
+        )
+
+        # Convert back to boolean if it was originally boolean
+        if mask.dtype == bool:
+            upscaled_mask = upscaled_mask.astype(bool)
+
+        return upscaled_mask
+
     def calculate_optimal_chunking(self) -> Tuple[int, int]:
         """
         Calculate optimal chunk size and overlap based on memory constraints
@@ -369,6 +479,92 @@ class VideoProcessor:
 
         return matted_frames
 
+    def process_chunk_dual_resolution(self,
+                                      frame_data: Dict[str, List[np.ndarray]],
+                                      chunk_idx: int = 0) -> List[np.ndarray]:
+        """
+        Process a chunk using a dual-resolution approach: inference at low-res, output at full-res
+
+        Args:
+            frame_data: Dictionary with 'original' and 'scaled' frame lists
+            chunk_idx: Chunk index for logging
+
+        Returns:
+            List of matted frames at original resolution
+        """
+        original_frames = frame_data['original']
+        scaled_frames = frame_data['scaled']
+
+        print(f"Processing chunk {chunk_idx} with dual resolution ({len(original_frames)} frames)")
+        print(f"  Inference: {scaled_frames[0].shape} → Output: {original_frames[0].shape}")
+
+        with self.memory_manager.memory_monitor(f"dual-res chunk {chunk_idx}"):
+            # Initialize SAM2 with scaled frames for inference
+            self.sam2_model.init_video_state(scaled_frames)
+
+            # Detect persons in first scaled frame
+            first_scaled_frame = scaled_frames[0]
+            detections = self.detector.detect_persons(first_scaled_frame)
+
+            if not detections:
+                warnings.warn(f"No persons detected in chunk {chunk_idx}")
+                return self._create_empty_masks(original_frames)
+
+            print(f"Detected {len(detections)} persons in first frame (at inference resolution)")
+
+            # Convert detections to SAM2 prompts (detections are already at scaled resolution)
+            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)
+
+            # Add prompts to SAM2
+            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
+            print(f"Added prompts for {len(object_ids)} objects")
+
+            # Propagate masks through chunk at inference resolution
+            video_segments = self.sam2_model.propagate_masks(
+                start_frame=0,
+                max_frames=len(scaled_frames)
+            )
+
+            # Apply upscaled masks to original-resolution frames
+            matted_frames = []
+            original_shape = original_frames[0].shape[:2]  # (H, W)
+
+            for frame_idx, original_frame in enumerate(tqdm(original_frames, desc="Applying upscaled masks")):
+                if frame_idx in video_segments:
+                    frame_masks = video_segments[frame_idx]
+
+                    # Get combined mask at inference resolution
+                    combined_mask_scaled = self.sam2_model.get_combined_mask(frame_masks)
+
+                    if combined_mask_scaled is not None:
+                        # Upscale mask to original resolution
+                        combined_mask_full = self.upscale_mask(
+                            combined_mask_scaled,
+                            target_shape=original_shape,
+                            method='cubic'  # Smooth upscaling for masks
+                        )
+
+                        # Apply upscaled mask to original-resolution frame
+                        matted_frame = self.sam2_model.apply_mask_to_frame(
+                            original_frame, combined_mask_full,
+                            output_format=self.config.output.format,
+                            background_color=self.config.output.background_color
+                        )
+                    else:
+                        # No mask for this frame
+                        matted_frame = self._create_empty_mask_frame(original_frame)
+                else:
+                    # No mask for this frame
+                    matted_frame = self._create_empty_mask_frame(original_frame)
+
+                matted_frames.append(matted_frame)
+
+        # Cleanup SAM2 state
+        self.sam2_model.cleanup()
+
+        print(f"✅ Dual-resolution processing complete: {len(matted_frames)} frames at full resolution")
+        return matted_frames
+
     def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
         """Create empty masks when no persons detected"""
         empty_frames = []
@@ -829,16 +1025,17 @@ class VideoProcessor:
 
             print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
 
-            # Read chunk frames
-            frames = self.read_video_frames(
+            # Read chunk frames at dual resolution
+            print(f"🔄 Reading frames at dual resolution (scale_factor={self.config.processing.scale_factor})")
+            frame_data = self.read_video_frames_dual_resolution(
                 self.config.input.video_path,
                 start_frame=start_frame,
                 num_frames=frames_to_read,
                 scale_factor=self.config.processing.scale_factor
             )
 
-            # Process chunk
-            matted_frames = self.process_chunk(frames, chunk_idx)
+            # Process chunk with dual-resolution approach
+            matted_frames = self.process_chunk_dual_resolution(frame_data, chunk_idx)
 
             # Save chunk to disk immediately to free memory
             chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
@@ -853,7 +1050,7 @@ class VideoProcessor:
 
             # Free the frames from memory immediately
             del matted_frames
-            del frames
+            del frame_data
 
             # Update statistics
             self.processing_stats['chunks_processed'] += 1
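
For reference, a self-contained sanity check of the core step this patch adds: casting a boolean half-resolution mask to uint8, resizing it with cv2.INTER_CUBIC, and casting back to bool, mirroring what upscale_mask(method='cubic') does above. The mask sizes are illustrative only; everything else matches the patch's cast/resize/cast sequence.

    import cv2
    import numpy as np

    # Half-resolution boolean mask (e.g. 0.5x of a 1080p frame)
    low_res_mask = np.zeros((540, 960), dtype=bool)
    low_res_mask[100:300, 200:500] = True

    # Same sequence as upscale_mask(method='cubic'): bool -> uint8,
    # cubic resize (dsize is (width, height)), then back to bool
    full_res_mask = cv2.resize(
        low_res_mask.astype(np.uint8),
        (1920, 1080),
        interpolation=cv2.INTER_CUBIC,
    ).astype(bool)

    print(full_res_mask.shape, full_res_mask.dtype)  # (1080, 1920) bool

Note that because the intermediate values are 0/1 uint8, the final astype(bool) treats any nonzero pixel as foreground, so cubic upscaling can slightly soften or dilate mask edges relative to INTER_NEAREST.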