test2/vr180_matting/video_processor.py

import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Generator
from pathlib import Path
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings
from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager


class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        self._initialize_models()

    def _initialize_models(self):
        """Initialize YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                model_cfg=self.config.matting.sam2_model_cfg,
                checkpoint_path=self.config.matting.sam2_checkpoint,
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )
            if video_stream is None:
                raise ValueError("No video stream found")

            # Parse the "num/den" frame rate string instead of calling eval() on it
            fps_num, fps_den = video_stream['r_frame_rate'].split('/')
            fps = float(fps_num) / float(fps_den)

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                'fps': fps,
                'duration': float(video_stream.get('duration', 0)),
                'nb_frames': int(video_stream.get('nb_frames', 0)),
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            # Some containers omit nb_frames; fall back to an estimate from duration
            if self.total_frames <= 0 and self.video_info['duration'] > 0:
                self.total_frames = int(self.video_info['duration'] * self.fps)

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")

    def read_video_frames(self,
                          video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height),
                                       interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
        return frames

    def calculate_optimal_chunking(self) -> Tuple[int, int]:
        """
        Calculate optimal chunk size and overlap based on memory constraints

        Returns:
            Tuple of (chunk_size, overlap_frames)
        """
        if self.config.processing.chunk_size > 0:
            return self.config.processing.chunk_size, self.config.processing.overlap_frames

        # Calculate based on memory constraints
        scaled_height = int(self.frame_height * self.config.processing.scale_factor)
        scaled_width = int(self.frame_width * self.config.processing.scale_factor)

        optimal_chunk = self.memory_manager.get_optimal_chunk_size(
            scaled_height, scaled_width, fp16=self.config.matting.fp16
        )
        overlap = min(60, optimal_chunk // 10)  # 10% overlap, max 60 frames

        print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap")
        return optimal_chunk, overlap
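
    # Worked example with hypothetical numbers: an optimal chunk of 300 frames
    # gives overlap = min(60, 300 // 10) = 30, so process_video() advances by
    # 300 - 30 = 270 frames per chunk and blends the shared 30-frame region.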

    def process_chunk(self,
                      frames: List[np.ndarray],
                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk of frames through the matting pipeline

        Args:
            frames: List of frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames
        """
        print(f"Processing chunk {chunk_idx} ({len(frames)} frames)")

        with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"):
            # Initialize SAM2 with frames
            self.sam2_model.init_video_state(frames)

            # Detect persons in first frame
            first_frame = frames[0]
            detections = self.detector.detect_persons(first_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                self.sam2_model.cleanup()  # release the video state initialized above
                return self._create_empty_masks(frames)

            print(f"Detected {len(detections)} persons in first frame")

            # Convert detections to SAM2 prompts
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through chunk
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(frames)
            )

            # Apply masks to frames
            matted_frames = []
            for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]
                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)
                    matted_frame = self.sam2_model.apply_mask_to_frame(
                        frame, combined_mask,
                        output_format=self.config.output.format,
                        background_color=self.config.output.background_color
                    )
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

            return matted_frames

    def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Create empty masks when no persons detected"""
        empty_frames = []
        for frame in frames:
            empty_frame = self._create_empty_mask_frame(frame)
            empty_frames.append(empty_frame)
        return empty_frames

    def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray:
        """Create frame with empty mask (all background)"""
        if self.config.output.format == "alpha":
            # Transparent output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            return output
        else:
            # Green screen background
            return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Merge overlapping chunks with blending in overlap regions

        Args:
            chunk_results: List of chunk results
            overlap_frames: Number of overlapping frames

        Returns:
            Merged frame sequence
        """
        if len(chunk_results) == 1:
            return chunk_results[0]

        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear blending
                        alpha = i / overlap_frames
                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]
                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)
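
    # The same blend can also be expressed with OpenCV's saturating helper, e.g.
    #   blended = cv2.addWeighted(frame1, 1 - alpha, frame2, alpha, 0)
    # noted here only as an equivalent formulation, not a required change.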

    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on output format
        if self.config.output.format == "alpha":
            # Use PNG sequence for alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGR(A) channel order, so keep 4-channel (BGRA)
            # frames as-is and add an opaque alpha channel to 3-channel BGR frames
            if frame.shape[2] == 4:
                frame_bgra = frame
            else:
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as MP4 video"""
        if not frames:
            return

        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))

        for frame in tqdm(frames, desc="Writing video"):
            if frame.shape[2] == 4:
                # Drop the alpha channel; frames are assumed to be in OpenCV BGRA order
                frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
            writer.write(frame)

        writer.release()
        print(f"Saved video to {output_path}")

    def process_video(self) -> None:
        """Main video processing pipeline"""
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Process video in chunks
        chunk_results = []

        for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
            end_frame = min(start_frame + chunk_size, self.total_frames)
            frames_to_read = end_frame - start_frame

            chunk_idx = len(chunk_results)
            print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

            # Read chunk frames
            frames = self.read_video_frames(
                self.config.input.video_path,
                start_frame=start_frame,
                num_frames=frames_to_read,
                scale_factor=self.config.processing.scale_factor
            )

            # Process chunk
            matted_frames = self.process_chunk(frames, chunk_idx)
            chunk_results.append(matted_frames)

            # Memory cleanup
            self.memory_manager.cleanup_memory()

            if self.memory_manager.should_emergency_cleanup():
                self.memory_manager.emergency_cleanup()

        # Merge chunks if multiple
        print("\nMerging chunks...")
        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)

        # Save results
        print(f"Saving {len(final_frames)} processed frames...")
        self.save_video(final_frames, self.config.output.path)

        # Print final memory report
        self.memory_manager.print_memory_report()

        print("Video processing completed!")