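"""Chunked VR180 video matting pipeline.

Reads the input video in overlapping chunks, detects persons with YOLO,
propagates SAM2 masks through each chunk, cross-fades the chunk overlaps,
and writes the result as an MP4 or as a PNG sequence with alpha.
"""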

import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Generator
from pathlib import Path
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings

from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager


class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        self._initialize_models()

    def _initialize_models(self):
        """Initialize YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream is None:
                raise ValueError("No video stream found")

            # Parse the rational frame rate (e.g. "30000/1001") without eval()
            fps_num, _, fps_den = video_stream['r_frame_rate'].partition('/')
            fps = float(fps_num) / float(fps_den or 1)

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                'fps': fps,
                'duration': float(video_stream.get('duration', 0)),
                'nb_frames': int(video_stream.get('nb_frames', 0)),
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            # Some containers omit nb_frames; estimate it from duration * fps
            if self.total_frames <= 0 and self.video_info['duration'] > 0:
                self.total_frames = int(self.video_info['duration'] * self.fps)
                self.video_info['nb_frames'] = self.total_frames

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")

    def read_video_frames(self,
                          video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height),
                                       interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
        return frames

    def calculate_optimal_chunking(self) -> Tuple[int, int]:
        """
        Calculate optimal chunk size and overlap based on memory constraints

        Returns:
            Tuple of (chunk_size, overlap_frames)
        """
        if self.config.processing.chunk_size > 0:
            return self.config.processing.chunk_size, self.config.processing.overlap_frames

        # Calculate based on memory constraints
        scaled_height = int(self.frame_height * self.config.processing.scale_factor)
        scaled_width = int(self.frame_width * self.config.processing.scale_factor)

        optimal_chunk = self.memory_manager.get_optimal_chunk_size(
            scaled_height, scaled_width, fp16=self.config.matting.fp16
        )

        overlap = min(60, optimal_chunk // 10)  # ~10% overlap, capped at 60 frames

        print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap")
        return optimal_chunk, overlap
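
    # Illustrative chunk layout (hypothetical numbers): with chunk_size=100 and
    # overlap_frames=10, process_video() starts a new chunk every 90 frames, so
    # neighbouring chunks share 10 frames that merge_overlapping_chunks()
    # cross-fades:
    #   chunk 0: frames   0-99
    #   chunk 1: frames  90-189
    #   chunk 2: frames 180-279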

    def process_chunk(self,
                      frames: List[np.ndarray],
                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk of frames through the matting pipeline

        Args:
            frames: List of frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames
        """
        print(f"Processing chunk {chunk_idx} ({len(frames)} frames)")

        with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"):
            # Initialize SAM2 with frames
            self.sam2_model.init_video_state(frames)

            # Detect persons in first frame
            first_frame = frames[0]
            detections = self.detector.detect_persons(first_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                return self._create_empty_masks(frames)

            print(f"Detected {len(detections)} persons in first frame")

            # Convert detections to SAM2 prompts
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through chunk
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(frames)
            )

            # Apply masks to frames
            matted_frames = []
            for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]
                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)

                    matted_frame = self.sam2_model.apply_mask_to_frame(
                        frame, combined_mask,
                        output_format=self.config.output.format,
                        background_color=self.config.output.background_color
                    )
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

            return matted_frames

    def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Create empty masks when no persons detected"""
        empty_frames = []
        for frame in frames:
            empty_frame = self._create_empty_mask_frame(frame)
            empty_frames.append(empty_frame)
        return empty_frames

    def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray:
        """Create frame with empty mask (all background)"""
        if self.config.output.format == "alpha":
            # Fully transparent output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            return output
        else:
            # Solid background color (e.g. green screen)
            return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Merge overlapping chunks with blending in overlap regions

        Args:
            chunk_results: List of chunk results
            overlap_frames: Number of overlapping frames

        Returns:
            Merged frame sequence
        """
        if len(chunk_results) == 1:
            return chunk_results[0]

        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear cross-fade from the previous chunk to the current one
                        alpha = i / overlap_frames

                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]

                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)

    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on output format
        if self.config.output.format == "alpha":
            # Use PNG sequence for alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGR(A) channel order, so keep BGRA here;
            # converting to RGBA before imwrite would swap the red and blue channels
            if frame.shape[2] == 4:  # Already BGRA
                frame_bgra = frame
            else:  # BGR -> BGRA (opaque alpha)
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as MP4 video"""
        if not frames:
            return

        height, width = frames[0].shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))

        for frame in tqdm(frames, desc="Writing video"):
            if frame.shape[2] == 4:  # Drop alpha: BGRA -> BGR for the video writer
                frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
            writer.write(frame)

        writer.release()
        print(f"Saved video to {output_path}")

    def process_video(self) -> None:
        """Main video processing pipeline"""
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Guard against a non-positive stride if overlap >= chunk size
        stride = max(1, chunk_size - overlap_frames)

        # Process video in chunks
        chunk_results = []

        for start_frame in range(0, self.total_frames, stride):
            end_frame = min(start_frame + chunk_size, self.total_frames)
            frames_to_read = end_frame - start_frame

            chunk_idx = len(chunk_results)
            print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

            # Read chunk frames
            frames = self.read_video_frames(
                self.config.input.video_path,
                start_frame=start_frame,
                num_frames=frames_to_read,
                scale_factor=self.config.processing.scale_factor
            )

            # Process chunk
            matted_frames = self.process_chunk(frames, chunk_idx)
            chunk_results.append(matted_frames)

            # Memory cleanup
            self.memory_manager.cleanup_memory()

            if self.memory_manager.should_emergency_cleanup():
                self.memory_manager.emergency_cleanup()

        # Merge chunks if multiple
        print("\nMerging chunks...")
        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)

        # Save results
        print(f"Saving {len(final_frames)} processed frames...")
        self.save_video(final_frames, self.config.output.path)

        # Print final memory report
        self.memory_manager.print_memory_report()

        print("Video processing completed!")