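"""Video processing pipeline for VR180 person matting.

The VideoProcessor below reads the input video in memory-bounded chunks,
detects people with YOLO, propagates SAM2 masks through each chunk (optionally
at a reduced "inference" resolution), upscales the masks back to the original
resolution, and streams the merged result to disk with ffmpeg.

This module docstring is a summary added for orientation; the authoritative
behaviour is the code itself, and the exact VR180Config fields live in
config.py.
"""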
import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Generator
from pathlib import Path
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings
import time
import subprocess
import gc
import psutil
import os
import sys
from fractions import Fraction  # used to parse ffprobe's fractional frame rates safely

from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager

class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        # Processing statistics
        self.processing_stats = {
            'start_time': None,
            'end_time': None,
            'total_duration': 0,
            'processing_fps': 0,
            'chunks_processed': 0,
            'frames_processed': 0
        }

        self._initialize_models()

    def _get_process_memory_info(self) -> Dict[str, float]:
        """Get detailed memory usage for current process and children"""
        current_process = psutil.Process(os.getpid())

        # Get memory info for current process
        memory_info = current_process.memory_info()
        current_rss = memory_info.rss / 1024**3  # Convert to GB
        current_vms = memory_info.vms / 1024**3  # Virtual memory

        # Get memory info for all children
        children_rss = 0
        children_vms = 0
        child_count = 0

        try:
            for child in current_process.children(recursive=True):
                try:
                    child_memory = child.memory_info()
                    children_rss += child_memory.rss / 1024**3
                    children_vms += child_memory.vms / 1024**3
                    child_count += 1
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    pass
        except psutil.NoSuchProcess:
            pass

        # System memory info
        system_memory = psutil.virtual_memory()
        system_total = system_memory.total / 1024**3
        system_available = system_memory.available / 1024**3
        system_used = system_memory.used / 1024**3
        system_percent = system_memory.percent

        return {
            'process_rss_gb': current_rss,
            'process_vms_gb': current_vms,
            'children_rss_gb': children_rss,
            'children_vms_gb': children_vms,
            'total_process_gb': current_rss + children_rss,
            'child_count': child_count,
            'system_total_gb': system_total,
            'system_used_gb': system_used,
            'system_available_gb': system_available,
            'system_percent': system_percent
        }

    def _print_memory_step(self, step_name: str):
        """Print memory usage for a specific processing step"""
        memory_info = self._get_process_memory_info()

        print(f"\n📊 MEMORY: {step_name}")
        print(f"  Process RSS: {memory_info['process_rss_gb']:.2f} GB")
        if memory_info['children_rss_gb'] > 0:
            print(f"  Children RSS: {memory_info['children_rss_gb']:.2f} GB ({memory_info['child_count']} processes)")
        print(f"  Total Process: {memory_info['total_process_gb']:.2f} GB")
        print(f"  System: {memory_info['system_used_gb']:.1f}/{memory_info['system_total_gb']:.1f} GB ({memory_info['system_percent']:.1f}%)")
        print(f"  Available: {memory_info['system_available_gb']:.1f} GB")

    def _aggressive_memory_cleanup(self, step_name: str = ""):
        """Perform aggressive memory cleanup and report before/after"""
        if step_name:
            print(f"\n🧹 CLEANUP: Before {step_name}")

        before_info = self._get_process_memory_info()
        before_rss = before_info['total_process_gb']

        # Multiple rounds of garbage collection
        for _ in range(3):
            gc.collect()

        # Clear torch cache if available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except ImportError:
            pass

        # Clear OpenCV internal caches
        try:
            # Toggling optimization forces OpenCV to drop internal buffers
            cv2.setUseOptimized(False)
            cv2.setUseOptimized(True)
        except Exception:
            pass

        # Clear CuPy caches if available (public pool API only)
        try:
            import cupy as cp
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
        except ImportError:
            pass
        except Exception as e:
            print(f"  Warning: Could not clear CuPy cache: {e}")

        # Force Linux to release memory back to OS
        if sys.platform == 'linux':
            try:
                import ctypes
                libc = ctypes.CDLL("libc.so.6")
                libc.malloc_trim(0)
            except Exception as e:
                print(f"  Warning: Could not trim memory: {e}")

        # Brief pause to allow cleanup
        time.sleep(0.1)

        after_info = self._get_process_memory_info()
        after_rss = after_info['total_process_gb']
        freed_memory = before_rss - after_rss

        if step_name:
            print(f"  Before: {before_rss:.2f} GB → After: {after_rss:.2f} GB")
            print(f"  Freed: {freed_memory:.2f} GB")

    def _initialize_models(self):
        """Initialize YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                model_cfg=self.config.matting.sam2_model_cfg,
                checkpoint_path=self.config.matting.sam2_checkpoint,
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream is None:
                raise ValueError("No video stream found")

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                # r_frame_rate is a fraction string such as "30000/1001";
                # parse it with Fraction instead of eval()
                'fps': float(Fraction(video_stream['r_frame_rate'])),
                'duration': float(video_stream.get('duration', 0)),
                'nb_frames': int(video_stream.get('nb_frames', 0)),
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")

    def read_video_frames(self,
                          video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height),
                                       interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
        return frames

    def read_video_frames_dual_resolution(self,
                                          video_path: str,
                                          start_frame: int = 0,
                                          num_frames: Optional[int] = None,
                                          scale_factor: float = 0.5) -> Dict[str, List[np.ndarray]]:
        """
        Read video frames at both original and scaled resolution for dual-resolution processing

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for inference frames

        Returns:
            Dictionary with 'original' and 'scaled' frame lists
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise RuntimeError(f"Could not open video file: {video_path}")

        # Set starting position
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        original_frames = []
        scaled_frames = []
        frame_count = 0

        # Progress tracking
        total_to_read = num_frames if num_frames else self.total_frames - start_frame

        with tqdm(total=total_to_read, desc="Reading dual-resolution frames") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Store original frame
                original_frames.append(frame.copy())

                # Create scaled frame for inference
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    scaled_frame = cv2.resize(frame, (new_width, new_height),
                                              interpolation=cv2.INTER_AREA)
                else:
                    scaled_frame = frame.copy()

                scaled_frames.append(scaled_frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()

        print(f"Loaded {len(original_frames)} frames:")
        print(f"  Original: {original_frames[0].shape} per frame")
        print(f"  Scaled: {scaled_frames[0].shape} per frame (scale_factor={scale_factor})")

        return {
            'original': original_frames,
            'scaled': scaled_frames
        }

    def upscale_mask(self, mask: np.ndarray, target_shape: tuple, method: str = 'cubic') -> np.ndarray:
        """
        Upscale a mask from inference resolution to original resolution

        Args:
            mask: Low-resolution mask (H, W)
            target_shape: Target shape (H, W) for upscaling
            method: Upscaling method ('nearest', 'cubic', 'area')

        Returns:
            Upscaled mask at target resolution
        """
        if mask.shape[:2] == target_shape[:2]:
            return mask  # Already correct size

        # Ensure mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Choose interpolation method
        if method == 'nearest':
            interpolation = cv2.INTER_NEAREST  # Crisp edges, good for sharp subjects
        elif method == 'cubic':
            interpolation = cv2.INTER_CUBIC    # Smooth edges, good for most content
        elif method == 'area':
            interpolation = cv2.INTER_AREA     # Good for downscaling, not upscaling
        else:
            interpolation = cv2.INTER_CUBIC    # Default to cubic

        # Upscale mask
        upscaled_mask = cv2.resize(
            mask.astype(np.uint8),
            (target_shape[1], target_shape[0]),  # (width, height) for cv2.resize
            interpolation=interpolation
        )

        # Convert back to boolean if it was originally boolean
        if mask.dtype == bool:
            upscaled_mask = upscaled_mask.astype(bool)

        return upscaled_mask

    def calculate_optimal_chunking(self) -> Tuple[int, int]:
        """
        Calculate optimal chunk size and overlap based on memory constraints

        Returns:
            Tuple of (chunk_size, overlap_frames)
        """
        if self.config.processing.chunk_size > 0:
            return self.config.processing.chunk_size, self.config.processing.overlap_frames

        # Calculate based on memory constraints
        scaled_height = int(self.frame_height * self.config.processing.scale_factor)
        scaled_width = int(self.frame_width * self.config.processing.scale_factor)

        optimal_chunk = self.memory_manager.get_optimal_chunk_size(
            scaled_height, scaled_width, fp16=self.config.matting.fp16
        )

        overlap = min(60, optimal_chunk // 10)  # 10% overlap, max 60 frames

        print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap")
        return optimal_chunk, overlap

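    # Worked example of the chunking arithmetic above (illustrative numbers only,
    # not measured from a real run): if get_optimal_chunk_size() returns 300
    # frames, then overlap = min(60, 300 // 10) = 30, and process_video() steps
    # through the video in strides of chunk_size - overlap = 270, so chunks
    # cover frames [0, 300), [270, 570), [540, 840), ... with 30 shared frames
    # between neighbours for continuity across chunk boundaries.
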
    def process_chunk(self,
                      frames: List[np.ndarray],
                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk of frames through the matting pipeline

        Args:
            frames: List of frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames
        """
        print(f"Processing chunk {chunk_idx} ({len(frames)} frames)")

        with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"):
            # Initialize SAM2 with frames
            self.sam2_model.init_video_state(frames)

            # Detect persons in first frame
            first_frame = frames[0]
            detections = self.detector.detect_persons(first_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                return self._create_empty_masks(frames)

            print(f"Detected {len(detections)} persons in first frame")

            # Convert detections to SAM2 prompts
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through chunk
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(frames)
            )

            # Apply masks to frames
            matted_frames = []
            for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]
                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)

                    matted_frame = self.sam2_model.apply_mask_to_frame(
                        frame, combined_mask,
                        output_format=self.config.output.format,
                        background_color=self.config.output.background_color
                    )
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

            return matted_frames

    def process_chunk_dual_resolution(self,
                                      frame_data: Dict[str, List[np.ndarray]],
                                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk using dual-resolution approach: inference at low-res, output at full-res

        Args:
            frame_data: Dictionary with 'original' and 'scaled' frame lists
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames at original resolution
        """
        original_frames = frame_data['original']
        scaled_frames = frame_data['scaled']

        print(f"Processing chunk {chunk_idx} with dual-resolution ({len(original_frames)} frames)")
        print(f"  Inference: {scaled_frames[0].shape} → Output: {original_frames[0].shape}")

        with self.memory_manager.memory_monitor(f"dual-res chunk {chunk_idx}"):
            # Initialize SAM2 with scaled frames for inference
            self.sam2_model.init_video_state(scaled_frames)

            # Detect persons in first scaled frame
            first_scaled_frame = scaled_frames[0]
            detections = self.detector.detect_persons(first_scaled_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                return self._create_empty_masks(original_frames)

            print(f"Detected {len(detections)} persons in first frame (at inference resolution)")

            # Convert detections to SAM2 prompts (detections are already at scaled resolution)
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through chunk at inference resolution
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(scaled_frames)
            )

            # Apply upscaled masks to original resolution frames
            matted_frames = []
            original_shape = original_frames[0].shape[:2]  # (H, W)

            for frame_idx, original_frame in enumerate(tqdm(original_frames, desc="Applying upscaled masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]

                    # Get combined mask at inference resolution
                    combined_mask_scaled = self.sam2_model.get_combined_mask(frame_masks)

                    if combined_mask_scaled is not None:
                        # Upscale mask to original resolution
                        combined_mask_full = self.upscale_mask(
                            combined_mask_scaled,
                            target_shape=original_shape,
                            method='cubic'  # Smooth upscaling for masks
                        )

                        # Apply upscaled mask to original resolution frame
                        matted_frame = self.sam2_model.apply_mask_to_frame(
                            original_frame, combined_mask_full,
                            output_format=self.config.output.format,
                            background_color=self.config.output.background_color
                        )
                    else:
                        # No mask for this frame
                        matted_frame = self._create_empty_mask_frame(original_frame)
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(original_frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

            print(f"✅ Dual-resolution processing complete: {len(matted_frames)} frames at full resolution")
            return matted_frames

    def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Create empty masks when no persons detected"""
        empty_frames = []
        for frame in frames:
            empty_frame = self._create_empty_mask_frame(frame)
            empty_frames.append(empty_frame)
        return empty_frames

    def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray:
        """Create frame with empty mask (all background)"""
        if self.config.output.format == "alpha":
            # Transparent output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            return output
        else:
            # Green screen background
            return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

    def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
                               overlap_frames: int = 0, audio_source: str = None) -> None:
        """
        Merge processed chunks using streaming approach (no memory accumulation)

        Args:
            chunk_files: List of chunk result files (.npz)
            output_path: Final output video path
            overlap_frames: Number of overlapping frames
            audio_source: Audio source file for final video
        """
        if not chunk_files:
            raise ValueError("No chunk files to merge")

        print(f"🎬 TRUE Streaming merge: {len(chunk_files)} chunks → {output_path}")

        # Create temporary directory for frame images
        temp_frames_dir = Path(tempfile.mkdtemp(prefix="merge_frames_"))
        frame_counter = 0

        try:
            print(f"📁 Using temp frames dir: {temp_frames_dir}")

            # Process each chunk frame-by-frame (true streaming)
            for i, chunk_file in enumerate(chunk_files):
                print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")

                # Open the .npz archive; the 'frames' array is decompressed on
                # first access, so only one chunk is resident at a time
                chunk_data = np.load(str(chunk_file))
                frames_array = chunk_data['frames']
                total_frames_in_chunk = frames_array.shape[0]

                # Determine which frames to skip for overlap
                start_frame_idx = overlap_frames if i > 0 and overlap_frames > 0 else 0
                frames_to_process = total_frames_in_chunk - start_frame_idx

                if start_frame_idx > 0:
                    print(f"  ✂️ Skipping first {start_frame_idx} overlapping frames")

                print(f"  🔄 Processing {frames_to_process} frames one-by-one...")

                # Write frames to disk one at a time (true streaming)
                for frame_idx in range(start_frame_idx, total_frames_in_chunk):
                    frame = frames_array[frame_idx]  # Single frame

                    # Save frame directly to disk
                    frame_path = temp_frames_dir / f"frame_{frame_counter:06d}.jpg"
                    success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
                    if not success:
                        raise RuntimeError(f"Failed to save frame {frame_counter}")

                    frame_counter += 1

                    # Periodic progress and cleanup
                    if frame_counter % 100 == 0:
                        print(f"  💾 Saved {frame_counter} frames...")
                        gc.collect()  # Periodic cleanup

                print(f"  ✅ Saved {frames_to_process} frames to disk (total: {frame_counter})")

                # Close chunk file and release the decompressed array
                chunk_data.close()
                del chunk_data, frames_array

                # Don't delete checkpoint files - they're needed for potential resume;
                # the checkpoint system manages cleanup separately
                print(f"  📋 Keeping checkpoint file: {chunk_file.name}")

                # Aggressive cleanup and memory monitoring after each chunk
                self._aggressive_memory_cleanup(f"After streaming merge chunk {i}")

                # Memory safety check
                memory_info = self._get_process_memory_info()
                if memory_info['total_process_gb'] > 35:  # Warn well before the ~46 GB system limit
                    print(f"⚠️ High memory usage: {memory_info['total_process_gb']:.1f}GB - forcing cleanup")
                    gc.collect()
                    try:
                        import torch
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                    except ImportError:
                        pass

            # Create final video directly from frame images using ffmpeg
            print(f"📹 Creating final video from {frame_counter} frames...")
            self._create_video_from_frames(temp_frames_dir, Path(output_path), frame_counter)

            # Add audio if provided
            if audio_source:
                self._add_audio_to_video(output_path, audio_source)

        except Exception as e:
            print(f"❌ Streaming merge failed: {e}")
            raise

        finally:
            # Cleanup temporary frames directory
            try:
                if temp_frames_dir.exists():
                    shutil.rmtree(temp_frames_dir)
                    print(f"🗑️ Cleaned up temp frames dir: {temp_frames_dir}")
            except Exception as e:
                print(f"⚠️ Could not cleanup temp frames dir: {e}")

            # Memory cleanup
            gc.collect()

        print(f"✅ TRUE Streaming merge complete: {output_path}")

    def _create_video_from_frames(self, frames_dir: Path, output_path: Path, frame_count: int):
        """Create video directly from frame images using ffmpeg (memory efficient)"""
        frame_pattern = str(frames_dir / "frame_%06d.jpg")
        fps = self.video_info['fps'] if hasattr(self, 'video_info') and self.video_info else 30.0

        print(f"🎬 Creating video with ffmpeg: {frame_count} frames at {fps} fps")

        # Use GPU encoding if available, fall back to CPU
        gpu_cmd = [
            'ffmpeg', '-y',          # -y to overwrite output file
            '-framerate', str(fps),
            '-i', frame_pattern,
            '-c:v', 'h264_nvenc',    # NVIDIA GPU encoder
            '-preset', 'fast',
            '-cq', '18',             # Quality for GPU encoding
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]

        cpu_cmd = [
            'ffmpeg', '-y',          # -y to overwrite output file
            '-framerate', str(fps),
            '-i', frame_pattern,
            '-c:v', 'libx264',       # CPU encoder
            '-preset', 'medium',
            '-crf', '18',            # Quality for CPU encoding
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]

        # Try GPU first
        print("🚀 Trying GPU encoding...")
        result = subprocess.run(gpu_cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print("⚠️ GPU encoding failed, using CPU...")
            print("🔄 CPU encoding...")
            result = subprocess.run(cpu_cmd, capture_output=True, text=True)
        else:
            print("✅ GPU encoding successful!")

        if result.returncode != 0:
            print(f"❌ FFmpeg stdout: {result.stdout}")
            print(f"❌ FFmpeg stderr: {result.stderr}")
            raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")

        print(f"✅ Video created successfully: {output_path}")

    def _add_audio_to_video(self, video_path: str, audio_source: str):
        """Add audio to video using ffmpeg"""
        try:
            # Create temporary file for output with audio
            temp_path = Path(video_path).with_suffix('.temp.mp4')

            cmd = [
                'ffmpeg', '-y',
                '-i', str(video_path),    # Input video (no audio)
                '-i', str(audio_source),  # Input audio source
                '-c:v', 'copy',           # Copy video without re-encoding
                '-c:a', 'aac',            # Encode audio as AAC
                '-map', '0:v:0',          # Map video from first input
                '-map', '1:a:0',          # Map audio from second input
                '-shortest',              # Match shortest stream duration
                str(temp_path)
            ]

            print(f"🎵 Adding audio: {audio_source} → {video_path}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                print(f"⚠️ Audio addition failed: {result.stderr}")
                # Keep original video without audio
                return

            # Replace original with audio version
            Path(video_path).unlink()
            temp_path.rename(video_path)
            print("✅ Audio added successfully")

        except Exception as e:
            print(f"⚠️ Could not add audio: {e}")

    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Legacy merge method - DEPRECATED due to memory accumulation.
        Use merge_chunks_streaming() instead for memory efficiency.
        """
        warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
                      DeprecationWarning, stacklevel=2)

        if len(chunk_results) == 1:
            return chunk_results[0]

        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear blending
                        alpha = i / overlap_frames

                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]

                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)

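    # Illustrative blend (assumed numbers, not from a real run): with
    # overlap_frames = 30, overlap frame i = 6 uses alpha = 6 / 30 = 0.2, i.e.
    # blended = 0.8 * previous_chunk_frame + 0.2 * current_chunk_frame, so the
    # legacy merge ramps linearly from the previous chunk to the current one.
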
    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on output format
        if self.config.output.format == "alpha":
            # Use PNG sequence for alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGR/BGRA channel order, so keep 4-channel
            # frames as BGRA and add an opaque alpha channel to BGR frames
            if frame.shape[2] == 4:  # Already BGRA
                frame_bgra = frame
            else:  # BGR to BGRA
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as MP4 video with audio preservation"""
        if not frames:
            return

        output_path = Path(output_path)
        temp_frames_dir = output_path.parent / f"temp_frames_{output_path.stem}"
        temp_frames_dir.mkdir(exist_ok=True)

        try:
            # Save frames as images
            print("Saving frames as images...")
            for i, frame in enumerate(tqdm(frames, desc="Saving frames")):
                if frame.shape[2] == 4:
                    # Frames are BGRA (OpenCV convention); drop alpha for JPEG
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)

                frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
                cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])

            # Create video with ffmpeg
            self._create_video_with_ffmpeg(temp_frames_dir, output_path, len(frames))

        finally:
            # Cleanup temporary frames
            if temp_frames_dir.exists():
                shutil.rmtree(temp_frames_dir)

    def _create_video_with_ffmpeg(self, frames_dir: Path, output_path: Path, frame_count: int):
        """Create video using ffmpeg with audio preservation"""
        frame_pattern = str(frames_dir / "frame_%06d.jpg")

        if self.config.output.preserve_audio:
            # Create video with audio from input
            cmd = [
                'ffmpeg', '-y',
                '-framerate', str(self.fps),
                '-i', frame_pattern,
                '-i', str(self.config.input.video_path),  # Input video for audio
                '-c:v', 'h264_nvenc',  # Try GPU encoding first
                '-preset', 'fast',
                '-cq', '18',
                '-c:a', 'copy',        # Copy audio without re-encoding
                '-map', '0:v:0',       # Map video from frames
                '-map', '1:a:0',       # Map audio from input video
                '-shortest',           # Match shortest stream duration
                '-pix_fmt', 'yuv420p',
                str(output_path)
            ]
        else:
            # Create video without audio
            cmd = [
                'ffmpeg', '-y',
                '-framerate', str(self.fps),
                '-i', frame_pattern,
                '-c:v', 'h264_nvenc',
                '-preset', 'fast',
                '-cq', '18',
                '-pix_fmt', 'yuv420p',
                str(output_path)
            ]

        print("Creating video with ffmpeg...")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            # Try CPU encoding as fallback
            print("GPU encoding failed, trying CPU encoding...")
            cmd[cmd.index('h264_nvenc')] = 'libx264'
            cmd[cmd.index('-cq')] = '-crf'  # Change quality parameter for CPU

            result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"FFmpeg stdout: {result.stdout}")
            print(f"FFmpeg stderr: {result.stderr}")
            raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")

        # Verify frame count if sync verification is enabled
        if self.config.output.verify_sync:
            self._verify_frame_count(output_path, frame_count)

        print(f"Saved video to {output_path}")

    def _verify_frame_count(self, video_path: Path, expected_frames: int):
        """Verify output video has correct frame count"""
        try:
            probe = ffmpeg.probe(str(video_path))
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream:
                actual_frames = int(video_stream.get('nb_frames', 0))
                if actual_frames != expected_frames:
                    print(f"⚠️ Frame count mismatch: expected {expected_frames}, got {actual_frames}")
                else:
                    print(f"✅ Frame count verified: {actual_frames} frames")
        except Exception as e:
            print(f"⚠️ Could not verify frame count: {e}")

    def process_video(self) -> None:
        """Main video processing pipeline with checkpoint/resume support"""
        self.processing_stats['start_time'] = time.time()
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Initialize checkpoint manager
        from .checkpoint_manager import CheckpointManager
        checkpoint_mgr = CheckpointManager(
            self.config.input.video_path,
            self.config.output.path
        )

        # Check for existing checkpoints
        resume_info = checkpoint_mgr.get_resume_info()
        if resume_info['can_resume']:
            print("\n🔄 RESUME DETECTED:")
            print(f"  Found {resume_info['completed_chunks']} completed chunks")
            print("  Continuing from where we left off (saves time!)")
            checkpoint_mgr.print_status()

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Calculate total chunks
        total_chunks = 0
        for _ in range(0, self.total_frames, chunk_size - overlap_frames):
            total_chunks += 1
        checkpoint_mgr.set_total_chunks(total_chunks)

        # Process video in chunks
        chunk_files = []  # Store file paths instead of frame data
        temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))

        try:
            chunk_idx = 0
            for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
                end_frame = min(start_frame + chunk_size, self.total_frames)
                frames_to_read = end_frame - start_frame

                # Check if this chunk was already processed
                existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
                if existing_chunk:
                    print(f"\n✅ Chunk {chunk_idx} already processed: {existing_chunk.name}")
                    chunk_files.append(existing_chunk)
                    chunk_idx += 1
                    continue

                print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

                # Read chunk frames at dual resolution
                print(f"🔄 Reading frames at dual resolution (scale_factor={self.config.processing.scale_factor})")
                frame_data = self.read_video_frames_dual_resolution(
                    self.config.input.video_path,
                    start_frame=start_frame,
                    num_frames=frames_to_read,
                    scale_factor=self.config.processing.scale_factor
                )

                # Process chunk with dual-resolution approach
                matted_frames = self.process_chunk_dual_resolution(frame_data, chunk_idx)

                # Save chunk to disk immediately to free memory
                chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
                print(f"Saving chunk {chunk_idx} to disk...")
                np.savez_compressed(str(chunk_path), frames=matted_frames)

                # Save to checkpoint
                checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)

                chunk_files.append(chunk_path)

                # Free the frames from memory immediately
                del matted_frames
                del frame_data

                # Update statistics
                self.processing_stats['chunks_processed'] += 1
                self.processing_stats['frames_processed'] += frames_to_read

                # Aggressive memory cleanup after each chunk
                self._aggressive_memory_cleanup(f"chunk {chunk_idx} completion")

                # Also use memory manager cleanup
                self.memory_manager.cleanup_memory()

                if self.memory_manager.should_emergency_cleanup():
                    self.memory_manager.emergency_cleanup()

                chunk_idx += 1

            # Mark chunk processing as complete
            checkpoint_mgr.mark_processing_complete()

            # Check if merge was already done
            if resume_info.get('merge_complete', False):
                print("\n✅ Merge already completed in previous run!")
                print(f"  Output: {self.config.output.path}")
            else:
                # Use streaming merge to avoid memory accumulation (fixes OOM)
                print("\n🎬 Using streaming merge (no memory accumulation)...")

                # For resume scenarios, make sure we have all chunk files
                if resume_info['can_resume']:
                    checkpoint_chunk_files = checkpoint_mgr.get_completed_chunk_files()
                    if len(checkpoint_chunk_files) != len(chunk_files):
                        print(f"⚠️ Using {len(checkpoint_chunk_files)} checkpoint files instead of {len(chunk_files)} temp files")
                        chunk_files = checkpoint_chunk_files

                # Determine audio source for final video
                audio_source = None
                if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
                    audio_source = self.config.input.video_path

                # Stream merge chunks directly to output (no memory accumulation)
                self.merge_chunks_streaming(
                    chunk_files=chunk_files,
                    output_path=self.config.output.path,
                    overlap_frames=overlap_frames,
                    audio_source=audio_source
                )

                # Mark merge as complete
                checkpoint_mgr.mark_merge_complete()

                print("✅ Streaming merge complete - no memory accumulation!")

            # Calculate final statistics
            self.processing_stats['end_time'] = time.time()
            self.processing_stats['total_duration'] = self.processing_stats['end_time'] - self.processing_stats['start_time']
            if self.processing_stats['total_duration'] > 0:
                self.processing_stats['processing_fps'] = self.processing_stats['frames_processed'] / self.processing_stats['total_duration']

            # Print processing statistics
            self._print_processing_statistics()

            # Print final memory report
            self.memory_manager.print_memory_report()

            print("Video processing completed!")

            # Option to clean up checkpoints
            print("\n🗄️ CHECKPOINT CLEANUP OPTIONS:")
            print("  Checkpoints saved successfully and can be cleaned up")
            print("  - Keep checkpoints for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
            print("  - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
            print(f"  - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")

            # For now, keep checkpoints by default (user can manually clean)
            print("\n💡 Checkpoints kept for safety. Delete manually when no longer needed.")

        finally:
            # Clean up temporary chunk files (but not checkpoints)
            if temp_chunk_dir.exists():
                print("Cleaning up temporary chunk files...")
                try:
                    shutil.rmtree(temp_chunk_dir)
                except Exception as e:
                    print(f"⚠️ Could not clean temp directory: {e}")

    def _print_processing_statistics(self):
        """Print detailed processing statistics"""
        stats = self.processing_stats
        video_duration = self.total_frames / self.fps if self.fps > 0 else 0

        print("\n" + "=" * 60)
        print("PROCESSING STATISTICS")
        print("=" * 60)
        print(f"Input video duration: {video_duration:.1f} seconds ({self.total_frames} frames @ {self.fps:.2f} fps)")
        print(f"Total processing time: {stats['total_duration']:.1f} seconds")
        print(f"Processing speed: {stats['processing_fps']:.2f} fps")
        if stats['processing_fps'] > 0:
            print(f"Slowdown factor: {self.fps / stats['processing_fps']:.1f}x slower than realtime")
        print(f"Chunks processed: {stats['chunks_processed']}")
        print(f"Frames processed: {stats['frames_processed']}")

        if video_duration > 0 and stats['total_duration'] > 0:
            efficiency = video_duration / stats['total_duration']
            print(f"Processing efficiency: {efficiency:.3f} (1.0 = realtime)")

            # Estimate time for different video lengths
            print("\nEstimated processing times:")
            print(f"  5 minutes: {(5 * 60) / efficiency / 60:.1f} minutes")
            print(f"  30 minutes: {(30 * 60) / efficiency / 60:.1f} minutes")
            print(f"  1 hour: {(60 * 60) / efficiency / 60:.1f} minutes")

        print("=" * 60 + "\n")