Compare commits

37 Commits

Author SHA1 Message Date
277d554ecc fix scaling 1 2025-07-26 18:31:16 -07:00
d6d2b0aa93 full size babyyy 2025-07-26 18:09:48 -07:00
3a547b7c21 please god work 2025-07-26 17:44:23 -07:00
262cb00b69 checkpoints yay 2025-07-26 17:11:07 -07:00
caa4ddb5e0 actually fix streaming save 2025-07-26 17:05:50 -07:00
fa945b9c3e fix concat 2025-07-26 16:29:59 -07:00
4958c503dd please merge 2025-07-26 16:02:07 -07:00
366b132ef5 growth 2025-07-26 15:31:07 -07:00
4d1361df46 bigtime 2025-07-26 15:29:37 -07:00
884cb8dce2 lol 2025-07-26 15:29:28 -07:00
36f58acb8b foo 2025-07-26 15:18:32 -07:00
fb51e82fd4 stuff 2025-07-26 15:18:01 -07:00
9f572d4430 analyze 2025-07-26 15:10:34 -07:00
ba8706b7ae quick check 2025-07-26 14:52:44 -07:00
734445cf48 more memory fixes hopeufly 2025-07-26 14:33:36 -07:00
80f947c91b det core 2025-07-26 13:51:21 -07:00
6f93abcb08 dont use predictor over and over 2025-07-26 13:40:47 -07:00
c368d6dc97 not too hard 2025-07-26 13:30:13 -07:00
e7e9c5597b old sam cleanup 2025-07-26 13:21:39 -07:00
3af16df71e more memleak fixes 2025-07-26 13:03:04 -07:00
df7b009a7b fix gpu memory issue 2025-07-26 12:42:16 -07:00
725a781456 cupy 2025-07-26 12:29:32 -07:00
ccc68a3895 memleak fix hopefully 2025-07-26 12:25:55 -07:00
463f881eaf catagory A round 2 2025-07-26 11:56:51 -07:00
b642b562f0 optimizations A round 1 2025-07-26 11:04:04 -07:00
40ae537f7a memory stuff 2025-07-26 09:56:39 -07:00
28aa663b7b debug data 2025-07-26 09:31:50 -07:00
0244ba5204 fix some stuff 2025-07-26 09:24:30 -07:00
141302cccf ffmpegize 2025-07-26 09:16:45 -07:00
6b0eb6104d debug data 2025-07-26 09:14:11 -07:00
0f8818259e debug data 2025-07-26 09:10:59 -07:00
86274ba04a video debug 2025-07-26 09:07:57 -07:00
99c4da83af fix temp file 2025-07-26 09:01:38 -07:00
c4af7baf3d decord 2025-07-26 08:55:27 -07:00
3e21fd8678 fix again 2025-07-26 08:54:03 -07:00
d933d6b606 fix wrapper 2025-07-26 08:51:48 -07:00
7852303b40 maybe fix 2025-07-26 08:47:50 -07:00
15 changed files with 2919 additions and 164 deletions

193
analyze_memory_profile.py Normal file
View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""
import json
import glob
import os
import sys
from pathlib import Path
def analyze_memory_files():
"""Analyze partial memory profile files"""
# Get all partial files in order
files = sorted(glob.glob('memory_profile_partial_*.json'))
if not files:
print("❌ No memory profile files found!")
print("Expected files like: memory_profile_partial_0.json")
return
print(f"🔍 Found {len(files)} memory profile files")
print("=" * 60)
peak_memory = 0
peak_vram = 0
critical_points = []
all_checkpoints = []
for i, file in enumerate(files):
try:
with open(file, 'r') as f:
data = json.load(f)
timeline = data.get('timeline', [])
if not timeline:
continue
# Find peaks in this file
file_peak_rss = max([d['rss_gb'] for d in timeline])
file_peak_vram = max([d['vram_gb'] for d in timeline])
if file_peak_rss > peak_memory:
peak_memory = file_peak_rss
if file_peak_vram > peak_vram:
peak_vram = file_peak_vram
# Find memory growth spikes (>3GB increase)
for j in range(1, len(timeline)):
prev_rss = timeline[j-1]['rss_gb']
curr_rss = timeline[j]['rss_gb']
growth = curr_rss - prev_rss
if growth > 3.0: # >3GB growth spike
checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
critical_points.append({
'file': file,
'file_index': i,
'sample': j,
'timestamp': timeline[j]['timestamp'],
'rss_gb': curr_rss,
'vram_gb': timeline[j]['vram_gb'],
'growth_gb': growth,
'checkpoint': checkpoint
})
# Collect all checkpoints
checkpoints = [d for d in timeline if 'checkpoint' in d]
for cp in checkpoints:
cp['file'] = file
cp['file_index'] = i
all_checkpoints.append(cp)
# Show progress for this file
if timeline:
start_rss = timeline[0]['rss_gb']
end_rss = timeline[-1]['rss_gb']
growth = end_rss - start_rss
samples = len(timeline)
print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
f"(+{growth:4.1f}GB) [{samples:3d} samples]")
# Show significant checkpoints from this file
if checkpoints:
for cp in checkpoints:
print(f" 📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")
except Exception as e:
print(f"❌ Error reading {file}: {e}")
print("\n" + "=" * 60)
print("🎯 ANALYSIS SUMMARY")
print("=" * 60)
print(f"📈 Peak Memory: {peak_memory:.1f} GB")
print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")
if critical_points:
print(f"\n💥 MEMORY GROWTH SPIKES (>3GB):")
print(" Location Growth Total VRAM")
print(" " + "-" * 55)
for point in critical_points:
location = point['checkpoint'][:30].ljust(30)
print(f" {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB {point['vram_gb']:4.1f}GB")
if all_checkpoints:
print(f"\n📍 CHECKPOINT PROGRESSION:")
print(" Checkpoint Memory VRAM File")
print(" " + "-" * 55)
for cp in all_checkpoints:
checkpoint = cp['checkpoint'][:30].ljust(30)
file_num = cp['file_index'] + 1
print(f" {checkpoint} {cp['rss_gb']:5.1f}GB {cp['vram_gb']:4.1f}GB #{file_num}")
# Memory growth analysis
if len(all_checkpoints) > 1:
print(f"\n📊 MEMORY GROWTH ANALYSIS:")
# Find the biggest memory jumps between checkpoints
big_jumps = []
for i in range(1, len(all_checkpoints)):
prev_cp = all_checkpoints[i-1]
curr_cp = all_checkpoints[i]
growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
if growth > 2.0: # >2GB jump
big_jumps.append({
'from': prev_cp['checkpoint'],
'to': curr_cp['checkpoint'],
'growth': growth,
'from_memory': prev_cp['rss_gb'],
'to_memory': curr_cp['rss_gb']
})
if big_jumps:
print(" Major jumps (>2GB):")
for jump in big_jumps:
print(f" {jump['from']}{jump['to']}: "
f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f}{jump['to_memory']:.1f}GB)")
else:
print(" ✅ No major memory jumps detected")
# Diagnosis
print(f"\n🔬 DIAGNOSIS:")
if peak_memory > 400:
print(" 🔴 CRITICAL: Memory usage exceeded 400GB")
print(" 💡 Recommendation: Reduce chunk_size to 200-300 frames")
elif peak_memory > 200:
print(" 🟡 HIGH: Memory usage over 200GB")
print(" 💡 Recommendation: Reduce chunk_size to 400 frames")
else:
print(" 🟢 MODERATE: Memory usage under 200GB")
if critical_points:
# Find most common growth spike locations
spike_locations = {}
for point in critical_points:
location = point['checkpoint']
spike_locations[location] = spike_locations.get(location, 0) + 1
print("\n 🎯 Most problematic locations:")
for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
print(f" {location}: {count} spikes")
print(f"\n💡 NEXT STEPS:")
if 'merge' in str(critical_points).lower():
print(" 1. Chunk merging still causing memory accumulation")
print(" 2. Check if streaming merge is actually being used")
print(" 3. Verify chunk files are being deleted immediately")
elif 'propagation' in str(critical_points).lower():
print(" 1. SAM2 propagation using too much memory")
print(" 2. Reduce chunk_size further (try 300 frames)")
print(" 3. Enable more aggressive frame release")
else:
print(" 1. Review the checkpoint progression above")
print(" 2. Focus on locations with biggest memory spikes")
print(" 3. Consider reducing chunk_size if spikes are large")
def main():
print("🔍 MEMORY PROFILE ANALYZER")
print("Analyzing memory profile files for OOM causes...")
print()
analyze_memory_files()
if __name__ == "__main__":
main()

View File

@@ -3,8 +3,8 @@ input:
processing:
scale_factor: 0.5 # A40 can handle 0.5 well
chunk_size: 0 # Auto-calculate based on A40's 48GB VRAM
overlap_frames: 60
chunk_size: 600 # Category A.4: Larger chunks for better VRAM utilization (was 200)
overlap_frames: 30 # Reduced overlap
detection:
confidence_threshold: 0.7
@@ -19,9 +19,11 @@ matting:
output:
path: "/workspace/output/matted_video.mp4"
format: "alpha"
format: "greenscreen" # Changed to greenscreen for easier testing
background_color: [0, 255, 0]
maintain_sbs: true
preserve_audio: true # Category A.1: Audio preservation
verify_sync: true # Category A.2: Frame count validation
hardware:
device: "cuda"

151
debug_memory_leak.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""
import psutil
import gc
from pathlib import Path
import sys
def detailed_memory_check(label):
"""Get detailed memory info"""
process = psutil.Process()
memory_info = process.memory_info()
rss_gb = memory_info.rss / (1024**3)
vms_gb = memory_info.vms / (1024**3)
# System memory
sys_memory = psutil.virtual_memory()
available_gb = sys_memory.available / (1024**3)
print(f"🔍 {label}:")
print(f" RSS: {rss_gb:.2f} GB (physical memory)")
print(f" VMS: {vms_gb:.2f} GB (virtual memory)")
print(f" Available: {available_gb:.2f} GB")
return rss_gb
def simulate_chunk_processing(config_path: str = 'config.yaml'):
"""Simulate the chunk processing to see where memory accumulates"""
print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
print("=" * 60)
base_memory = detailed_memory_check("0. Baseline")
# Step 1: Import everything (with lazy loading)
print("\n📦 Step 1: Imports")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
import_memory = detailed_memory_check("1. After imports")
import_growth = import_memory - base_memory
print(f" Growth: +{import_growth:.2f} GB")
# Step 2: Load config
print("\n⚙️ Step 2: Config loading")
config = VR180Config.from_yaml(config_path)
config_memory = detailed_memory_check("2. After config load")
config_growth = config_memory - import_memory
print(f" Growth: +{config_growth:.2f} GB")
# Step 3: Initialize processor (models still lazy)
print("\n🏗️ Step 3: Processor initialization")
processor = VR180Processor(config)
processor_memory = detailed_memory_check("3. After processor init")
processor_growth = processor_memory - config_memory
print(f" Growth: +{processor_growth:.2f} GB")
# Step 4: Load video info (lightweight)
print("\n🎬 Step 4: Video info loading")
try:
video_info = processor.load_video_info(config.input.video_path)
print(f" Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
f"{video_info.get('total_frames', 'unknown')} frames")
except Exception as e:
print(f" Warning: Could not load video info: {e}")
video_info_memory = detailed_memory_check("4. After video info")
video_info_growth = video_info_memory - processor_memory
print(f" Growth: +{video_info_growth:.2f} GB")
# Step 5: Simulate chunk 0 processing (this is where models actually load)
print("\n🔄 Step 5: Simulating chunk 0 processing...")
# This is where the real memory usage starts
print(" Loading first 10 frames to trigger model loading...")
try:
# Read a small number of frames to trigger model loading
frames = processor.read_video_frames(
config.input.video_path,
start_frame=0,
num_frames=10, # Just 10 frames to trigger model loading
scale_factor=config.processing.scale_factor
)
frames_memory = detailed_memory_check("5a. After reading 10 frames")
frames_growth = frames_memory - video_info_memory
print(f" 10 frames growth: +{frames_growth:.2f} GB")
# Free frames
del frames
gc.collect()
after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
free_improvement = frames_memory - after_free_memory
print(f" Memory freed: -{free_improvement:.2f} GB")
except Exception as e:
print(f" Could not simulate frame loading: {e}")
after_free_memory = video_info_memory
print(f"\n📊 MEMORY ANALYSIS:")
print(f" Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
print(f" Total growth: +{after_free_memory - base_memory:.2f}GB")
if after_free_memory - base_memory > 10:
print(f" 🔴 HIGH: Memory growth > 10GB before any real processing")
print(f" 💡 This suggests model loading is using too much memory")
elif after_free_memory - base_memory > 5:
print(f" 🟡 MODERATE: Memory growth 5-10GB")
print(f" 💡 Normal for model loading, but monitor chunk processing")
else:
print(f" 🟢 GOOD: Memory growth < 5GB")
print(f" 💡 Initialization memory usage is reasonable")
print(f"\n🎯 KEY INSIGHTS:")
if import_growth > 1:
print(f" - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
if processor_growth > 10:
print(f" - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")
print(f"\n💡 RECOMMENDATIONS:")
if after_free_memory - base_memory > 15:
print(f" 1. Reduce chunk_size to 200-300 frames")
print(f" 2. Use smaller models (yolov8n instead of yolov8m)")
print(f" 3. Enable FP16 mode for SAM2")
elif after_free_memory - base_memory > 8:
print(f" 1. Monitor chunk processing carefully")
print(f" 2. Use streaming merge (should be automatic)")
print(f" 3. Current settings may be acceptable")
else:
print(f" 1. Settings look good for initialization")
print(f" 2. Focus on chunk processing memory leaks")
def main():
if len(sys.argv) != 2:
print("Usage: python debug_memory_leak.py <config.yaml>")
print("This simulates initialization to find memory leaks")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
simulate_chunk_processing(config_path)
if __name__ == "__main__":
main()

249
memory_profiler_script.py Normal file
View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""
import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json
class MemoryProfiler:
def __init__(self, output_file: str = "memory_profile.json"):
self.output_file = output_file
self.data = []
self.process = psutil.Process()
self.running = False
self.thread = None
self.checkpoint_counter = 0
def start_monitoring(self, interval: float = 1.0):
"""Start continuous memory monitoring"""
tracemalloc.start()
self.running = True
self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
self.thread.daemon = True
self.thread.start()
print(f"🔍 Memory monitoring started (interval: {interval}s)")
def stop_monitoring(self):
"""Stop monitoring and save results"""
self.running = False
if self.thread:
self.thread.join()
# Get tracemalloc snapshot
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
# Save detailed results
results = {
'timeline': self.data,
'top_memory_allocations': [
{
'file': stat.traceback.format()[0],
'size_mb': stat.size / 1024 / 1024,
'count': stat.count
}
for stat in top_stats[:20] # Top 20 allocations
],
'summary': {
'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
'total_samples': len(self.data)
}
}
with open(self.output_file, 'w') as f:
json.dump(results, f, indent=2)
tracemalloc.stop()
print(f"📊 Memory profile saved to {self.output_file}")
def _monitor_loop(self, interval: float):
"""Continuous monitoring loop"""
while self.running:
try:
# System memory
memory_info = self.process.memory_info()
rss_gb = memory_info.rss / (1024**3)
# System-wide memory
sys_memory = psutil.virtual_memory()
sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
sys_available_gb = sys_memory.available / (1024**3)
# GPU memory (if available)
vram_gb = 0
vram_free_gb = 0
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=5)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
if lines and lines[0]:
used, free = lines[0].split(', ')
vram_gb = float(used) / 1024
vram_free_gb = float(free) / 1024
except Exception:
pass
# Tracemalloc current usage
try:
current, peak = tracemalloc.get_traced_memory()
traced_mb = current / (1024**2)
except Exception:
traced_mb = 0
data_point = {
'timestamp': time.time(),
'rss_gb': rss_gb,
'vram_gb': vram_gb,
'vram_free_gb': vram_free_gb,
'sys_used_gb': sys_used_gb,
'sys_available_gb': sys_available_gb,
'traced_mb': traced_mb
}
self.data.append(data_point)
# Print periodic updates and save partial data
if len(self.data) % 10 == 0: # Every 10 samples
print(f"🔍 Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")
# Save partial data every 30 samples in case of crash
if len(self.data) % 30 == 0:
self._save_partial_data()
except Exception as e:
print(f"Monitoring error: {e}")
time.sleep(interval)
def _save_partial_data(self):
"""Save partial data to prevent loss on crash"""
try:
partial_file = f"memory_profile_partial_{self.checkpoint_counter}.json"
with open(partial_file, 'w') as f:
json.dump({
'timeline': self.data,
'status': 'partial_save',
'samples': len(self.data)
}, f, indent=2)
self.checkpoint_counter += 1
except Exception as e:
print(f"Failed to save partial data: {e}")
def log_checkpoint(self, checkpoint_name: str):
"""Log a specific checkpoint"""
if self.data:
self.data[-1]['checkpoint'] = checkpoint_name
latest = self.data[-1]
print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")
# Save checkpoint data immediately
self._save_partial_data()
def run_with_profiling(config_path: str):
"""Run the VR180 matting with memory profiling"""
profiler = MemoryProfiler("memory_profile_detailed.json")
try:
# Start monitoring
profiler.start_monitoring(interval=2.0) # Sample every 2 seconds
# Log initial state
profiler.log_checkpoint("STARTUP")
# Import after starting profiler to catch import memory usage
print("Importing VR180 processor...")
from vr180_matting.vr180_processor import VR180Processor
from vr180_matting.config import VR180Config
profiler.log_checkpoint("IMPORTS_COMPLETE")
# Load config
print(f"Loading config from {config_path}")
config = VR180Config.from_yaml(config_path)
profiler.log_checkpoint("CONFIG_LOADED")
# Initialize processor
print("Initializing VR180 processor...")
processor = VR180Processor(config)
profiler.log_checkpoint("PROCESSOR_INITIALIZED")
# Force garbage collection
gc.collect()
profiler.log_checkpoint("INITIAL_GC_COMPLETE")
# Run processing
print("Starting VR180 processing...")
processor.process_video()
profiler.log_checkpoint("PROCESSING_COMPLETE")
except Exception as e:
print(f"❌ Error during processing: {e}")
profiler.log_checkpoint(f"ERROR: {str(e)}")
raise
finally:
# Stop monitoring and save results
profiler.stop_monitoring()
# Print summary
print("\n" + "="*60)
print("MEMORY PROFILING SUMMARY")
print("="*60)
if profiler.data:
peak_rss = max([d['rss_gb'] for d in profiler.data])
peak_vram = max([d['vram_gb'] for d in profiler.data])
print(f"Peak RSS Memory: {peak_rss:.2f} GB")
print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
print(f"Total Samples: {len(profiler.data)}")
# Show checkpoints
checkpoints = [d for d in profiler.data if 'checkpoint' in d]
if checkpoints:
print(f"\nCheckpoints ({len(checkpoints)}):")
for cp in checkpoints:
print(f" {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")
print(f"\nDetailed profile saved to: {profiler.output_file}")
def main():
if len(sys.argv) != 2:
print("Usage: python memory_profiler_script.py <config.yaml>")
print("\nThis script runs VR180 matting with detailed memory profiling")
print("It will:")
print("- Monitor RSS, VRAM, and system memory every 2 seconds")
print("- Track memory allocations with tracemalloc")
print("- Log checkpoints at key processing stages")
print("- Save detailed JSON report for analysis")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"❌ Config file not found: {config_path}")
sys.exit(1)
print("🚀 Starting VR180 Memory Profiling")
print(f"Config: {config_path}")
print("="*60)
run_with_profiling(config_path)
if __name__ == "__main__":
main()

125
quick_memory_check.py Normal file
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""
import psutil
import subprocess
import sys
from pathlib import Path
def check_system():
"""Check system resources before starting"""
print("🔍 SYSTEM RESOURCE CHECK")
print("=" * 50)
# Memory info
memory = psutil.virtual_memory()
print(f"📊 RAM:")
print(f" Total: {memory.total / (1024**3):.1f} GB")
print(f" Available: {memory.available / (1024**3):.1f} GB")
print(f" Used: {(memory.total - memory.available) / (1024**3):.1f} GB ({memory.percent:.1f}%)")
# GPU info
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=10)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
print(f"\n🎮 GPU:")
for i, line in enumerate(lines):
if line.strip():
parts = line.split(', ')
if len(parts) >= 4:
name, total, used, free = parts[:4]
total_gb = float(total) / 1024
used_gb = float(used) / 1024
free_gb = float(free) / 1024
print(f" GPU {i}: {name}")
print(f" VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
print(f" Free: {free_gb:.1f} GB")
except Exception as e:
print(f"\n⚠️ Could not get GPU info: {e}")
# Disk space
disk = psutil.disk_usage('/')
print(f"\n💾 Disk (/):")
print(f" Total: {disk.total / (1024**3):.1f} GB")
print(f" Used: {disk.used / (1024**3):.1f} GB ({disk.used/disk.total*100:.1f}%)")
print(f" Free: {disk.free / (1024**3):.1f} GB")
# Check config file
if len(sys.argv) > 1:
config_path = sys.argv[1]
if Path(config_path).exists():
print(f"\n✅ Config file found: {config_path}")
# Try to load and show key settings
try:
import yaml
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
print(f"📋 Key Settings:")
if 'processing' in config:
proc = config['processing']
print(f" Chunk size: {proc.get('chunk_size', 'default')}")
print(f" Scale factor: {proc.get('scale_factor', 'default')}")
if 'hardware' in config:
hw = config['hardware']
print(f" Max VRAM: {hw.get('max_vram_gb', 'default')} GB")
if 'input' in config:
inp = config['input']
video_path = inp.get('video_path', '')
if video_path and Path(video_path).exists():
size_gb = Path(video_path).stat().st_size / (1024**3)
print(f" Input video: {video_path} ({size_gb:.1f} GB)")
else:
print(f" ⚠️ Input video not found: {video_path}")
except Exception as e:
print(f" ⚠️ Could not parse config: {e}")
else:
print(f"\n❌ Config file not found: {config_path}")
return False
# Memory safety warnings
print(f"\n⚠️ MEMORY SAFETY CHECKS:")
available_gb = memory.available / (1024**3)
if available_gb < 10:
print(f" 🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
print(" Consider: reducing chunk_size or scale_factor")
return False
elif available_gb < 20:
print(f" 🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
print(" Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
else:
print(f" 🟢 GOOD MEMORY: {available_gb:.1f}GB available")
print(f"\n" + "=" * 50)
return True
def main():
if len(sys.argv) != 2:
print("Usage: python quick_memory_check.py <config.yaml>")
print("\nThis checks system resources before running VR180 matting")
sys.exit(1)
safe_to_run = check_system()
if safe_to_run:
print("✅ System check passed - safe to run VR180 matting")
print("\nTo run with memory profiling:")
print(f" python memory_profiler_script.py {sys.argv[1]}")
print("\nTo run normally:")
print(f" vr180-matting {sys.argv[1]}")
else:
print("❌ System check failed - address issues before running")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -9,3 +9,7 @@ ultralytics>=8.0.0
tqdm>=4.65.0
psutil>=5.9.0
ffmpeg-python>=0.2.0
decord>=0.6.0
# GPU acceleration (optional but recommended for stereo validation speedup)
# cupy-cuda11x>=12.0.0 # For CUDA 11.x
# cupy-cuda12x>=12.0.0 # For CUDA 12.x - uncomment appropriate version
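The commented-out CuPy lines above are optional. Below is a minimal sketch of the fallback pattern that optional GPU acceleration implies: import CuPy when it is installed, otherwise fall back to NumPy. The helper function and its name are illustrative assumptions, not taken from the repository.

```python
# Sketch only: prefer CuPy when installed, fall back to NumPy otherwise.
try:
    import cupy as xp          # GPU arrays when cupy-cuda11x / cupy-cuda12x is present
    GPU_AVAILABLE = True
except ImportError:
    import numpy as xp         # CPU fallback keeps the same array API
    GPU_AVAILABLE = False

def mask_overlap_ratio(left_mask, right_mask):
    """Illustrative stereo-consistency score: IoU of the two eye masks."""
    left = xp.asarray(left_mask, dtype=bool)
    right = xp.asarray(right_mask, dtype=bool)
    union = xp.logical_or(left, right).sum()
    if union == 0:
        return 0.0
    return float(xp.logical_and(left, right).sum()) / float(union)
```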

View File

@@ -14,6 +14,32 @@ echo "🐍 Installing Python dependencies..."
pip install --upgrade pip
pip install -r requirements.txt
# Install decord for SAM2 video loading
echo "📹 Installing decord for video processing..."
pip install decord
# Install CuPy for GPU acceleration of stereo validation
echo "🚀 Installing CuPy for GPU acceleration..."
# Auto-detect CUDA version and install appropriate CuPy
python -c "
import torch
if torch.cuda.is_available():
cuda_version = torch.version.cuda
print(f'CUDA version detected: {cuda_version}')
if cuda_version.startswith('11.'):
import subprocess
subprocess.run(['pip', 'install', 'cupy-cuda11x>=12.0.0'])
print('Installed CuPy for CUDA 11.x')
elif cuda_version.startswith('12.'):
import subprocess
subprocess.run(['pip', 'install', 'cupy-cuda12x>=12.0.0'])
print('Installed CuPy for CUDA 12.x')
else:
print(f'Unsupported CUDA version: {cuda_version}')
else:
print('CUDA not available, skipping CuPy installation')
"
# Install SAM2 separately (not on PyPI)
echo "🎯 Installing SAM2..."
pip install git+https://github.com/facebookresearch/segment-anything-2.git

198
spec.md
View File

@@ -123,6 +123,204 @@ hardware:
3. **Performance Profiling**: Detailed resource usage analytics
4. **Quality Validation**: Comprehensive testing suite
## Post-Implementation Optimization Opportunities
*Based on first successful 30-second test clip execution results (A40 GPU, 50% scale, 9x200 frame chunks)*
### Performance Analysis Findings
- **Processing Speed**: ~0.54s per frame (64.4s for 120 frames per chunk)
- **VRAM Utilization**: Only 2.5% (1.11GB of 45GB available) - significantly underutilized
- **RAM Usage**: 106GB used of 494GB available (21.5%)
- **Primary Bottleneck**: Intermediate ffmpeg encoding operations per chunk
### Identified Optimization Categories
#### Category A: Performance Improvements (Quick Wins)
1. **Audio Track Preservation** ⚠️ **CRITICAL**
- Issue: Output video missing audio track from input
- Solution: Use ffmpeg to copy audio stream during final video creation
- Implementation: Add `-c:a copy` to final ffmpeg command
- Impact: Essential for production usability
- Risk: Low, standard ffmpeg operation
2. **Frame Count Synchronization** ⚠️ **CRITICAL**
- Issue: Audio sync drift if input/output frame counts differ
- Solution: Validate exact frame count preservation throughout pipeline
- Implementation: Frame count verification + duration matching (see the verification sketch after this list)
- Impact: Prevents audio desync in long videos
- Risk: Low, validation feature
3. **Memory Usage Reality Check** ⚠️ **IMPORTANT**
- Current assumption: Unlimited RAM for memory-only pipeline
- Reality: RunPod container limited to ~48GB RAM
- Risk calculation: 1-hour video = ~213k frames = potential 20-40GB+ memory usage
- Solution: Implement streaming output instead of full in-memory accumulation
- Impact: Enables processing of long-form content
- Risk: Medium, requires pipeline restructuring
4. **Larger Chunk Sizes**
- Current: 200 frames per chunk (conservative for 10GB RTX 3080)
- Opportunity: 600-800 frames per chunk on high-VRAM systems
- Impact: Reduce 9 chunks to 2-3 chunks, fewer intermediate operations
- Risk: Low, easily configurable
5. **Streaming Output Pipeline**
- Current: Accumulate all processed frames in memory, write once
- Opportunity: Write processed chunks to temporary segments, merge at end
- Impact: Constant memory usage regardless of video length
- Risk: Medium, requires temporary file management
6. **Enhanced Performance Profiling**
- Current: Basic memory monitoring
- Opportunity: Detailed timing per processing stage (detection, propagation, encoding)
- Impact: Identify exact bottlenecks for targeted optimization
- Risk: Low, debugging feature
7. **Parallel Eye Processing**
- Current: Sequential left eye → right eye processing
- Opportunity: Process both eyes simultaneously
- Impact: Potential 50% speedup, better GPU utilization
- Risk: Medium, memory management complexity
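
For item 2 above (frame count synchronization), a minimal verification sketch using ffprobe. The integration point in the pipeline is not shown in this diff, so the helpers below are assumptions rather than the project's actual implementation:

```python
import json
import subprocess

def count_frames(video_path: str) -> int:
    """Count decoded frames with ffprobe (assumes ffprobe is on PATH)."""
    cmd = [
        "ffprobe", "-v", "error", "-select_streams", "v:0",
        "-count_frames", "-show_entries", "stream=nb_read_frames",
        "-of", "json", video_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    return int(json.loads(result.stdout)["streams"][0]["nb_read_frames"])

def verify_frame_sync(input_path: str, output_path: str) -> bool:
    """Return True when the matted output has exactly as many frames as the input."""
    in_frames = count_frames(input_path)
    out_frames = count_frames(output_path)
    if in_frames != out_frames:
        print(f"⚠️ Frame count mismatch: input={in_frames}, output={out_frames}")
        return False
    return True
```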
#### Category B: Stereo Consistency Fixes (Critical for VR)
1. **Master-Slave Eye Processing**
- Issue: Independent detection leads to mismatched person counts between eyes
- Solution: Use left eye detections as "seeds" for right eye processing (see the sketches after this list)
- Impact: Ensures identical person detection across stereo pair
- Risk: Low, maintains current quality while improving consistency
2. **Cross-Eye Detection Validation**
- Issue: Hair/clothing included on one eye but not the other
- Solution: Compare detection results, flag inconsistencies for reprocessing
- Impact: 90%+ stereo alignment improvement
- Risk: Low, fallback to current behavior
3. **Disparity-Aware Segmentation**
- Issue: Segmentation boundaries differ between eyes despite same person
- Solution: Use stereo disparity to correlate features between eyes
- Impact: True stereo-consistent matting
- Risk: High, complex implementation
4. **Joint Stereo Detection**
- Issue: YOLO runs independently on each eye
- Solution: Run YOLO on the full SBS frame, then split detections spatially (see the sketches after this list)
- Impact: Guaranteed identical detection counts
- Risk: Medium, requires detection coordinate mapping
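
Two of the items above lend themselves to short sketches. For item 1 (master-slave processing), left-eye person boxes can seed the right eye; for item 4 (joint stereo detection), detections on the full SBS frame can be split by horizontal position. Both functions are assumptions about the eventual implementation, reusing the detection-dictionary shape (a `bbox` of pixel coordinates plus confidence) described by the detector elsewhere in this diff:

```python
def seed_right_eye_from_left(left_detections, max_disparity_px=0):
    """Sketch for item 1: reuse left-eye person boxes as right-eye prompts.
    The two eyes share nearly the same vertical framing, so boxes are reused
    as-is and only widened horizontally by an assumed maximum disparity."""
    seeds = []
    for det in left_detections:
        x1, y1, x2, y2 = det["bbox"]
        seeds.append({**det, "bbox": (max(0, x1 - max_disparity_px), y1,
                                      x2 + max_disparity_px, y2)})
    return seeds

def split_sbs_detections(detections, sbs_width):
    """Sketch for item 4: split full-SBS detections into per-eye lists,
    shifting right-eye boxes into eye-local coordinates."""
    half = sbs_width // 2
    left_eye, right_eye = [], []
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        if (x1 + x2) / 2 < half:
            left_eye.append(det)
        else:
            right_eye.append({**det, "bbox": (x1 - half, y1, x2 - half, y2)})
    return left_eye, right_eye
```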
#### Category C: Advanced Optimizations (Future)
1. **Adaptive Memory Management**
- Opportunity: Dynamic chunk sizing based on real-time VRAM usage
- Impact: Optimal resource utilization across different hardware
- Risk: Medium, complex heuristics
2. **Multi-Resolution Processing**
- Opportunity: Initial processing at lower resolution, edge refinement at full
- Impact: Speed improvement while maintaining quality
- Risk: Medium, quality validation required
3. **Enhanced Workflow Documentation**
- Issue: Unclear intermediate data lifecycle
- Solution: Detailed logging of chunk processing, optional intermediate preservation
- Impact: Better debugging and user understanding
- Risk: Low, documentation feature
### Implementation Strategy
- **Phase A**: Quick performance wins (larger chunks, profiling)
- **Phase B**: Stereo consistency (master-slave, validation)
- **Phase C**: Advanced features (disparity-aware, memory optimization)
### Configuration Extensions Required
```yaml
processing:
chunk_size: 600 # Increase from 200 for high-VRAM systems
memory_pipeline: false # Skip intermediate video creation (disabled due to RAM limits)
streaming_output: true # Write chunks progressively instead of accumulating
parallel_eyes: false # Process eyes simultaneously
max_memory_gb: 40 # Realistic RAM limit for RunPod containers
audio:
preserve_audio: true # Copy audio track from input to output
verify_sync: true # Validate frame count and duration matching
audio_codec: "copy" # Preserve original audio codec
stereo:
consistency_mode: "master_slave" # "independent", "master_slave", "joint"
validation_threshold: 0.8 # Similarity threshold between eyes
correction_method: "transfer" # "transfer", "reprocess", "ensemble"
performance:
profile_enabled: true # Detailed timing analysis
preserve_intermediates: false # For debugging workflow
debugging:
log_intermediate_workflow: true # Document chunk lifecycle
save_detection_visualization: false # Debug detection mismatches
frame_count_validation: true # Ensure exact frame preservation
```
### Technical Implementation Details
#### Audio Preservation Implementation
```python
# During final video save, include audio stream copy
ffmpeg_cmd = [
'ffmpeg', '-y',
'-framerate', str(fps),
'-i', frame_pattern, # Video frames
'-i', input_video_path, # Original video for audio
'-c:v', 'h264_nvenc', # GPU video codec (with CPU fallback)
'-c:a', 'copy', # Copy audio without re-encoding
'-map', '0:v:0', # Map video from first input
'-map', '1:a:0', # Map audio from second input
'-shortest', # Match shortest stream duration
output_path
]
```
#### Streaming Output Implementation
```python
# Instead of accumulating frames in memory:
class StreamingVideoWriter:
def __init__(self, output_path, fps, audio_source):
self.temp_segments = []
self.current_segment = 0
def write_chunk(self, processed_frames):
# Write chunk to temporary segment
segment_path = f"temp_segment_{self.current_segment}.mp4"
self.write_video_segment(processed_frames, segment_path)
self.temp_segments.append(segment_path)
self.current_segment += 1
def finalize(self):
# Merge all segments with audio preservation
self.merge_segments_with_audio()
```
#### Memory Usage Calculation
```python
def estimate_memory_requirements(duration_seconds, fps, resolution_scale=0.5):
"""Calculate memory usage for different video lengths"""
frames = duration_seconds * fps
# Per-frame memory (rough estimates for VR180 at 50% scale)
frame_size_mb = (3072 * 1536 * 3 * 4) / (1024 * 1024) # ~18MB per frame
total_memory_gb = (frames * frame_size_mb) / 1024
return {
'duration': duration_seconds,
'total_frames': frames,
'estimated_memory_gb': total_memory_gb,
'safe_for_48gb': total_memory_gb < 40
}
# Example outputs:
# 30 seconds: ~2.7GB (safe)
# 5 minutes: ~27GB (borderline)
# 1 hour: ~324GB (requires streaming)
```
## Success Criteria
### Technical Feasibility

148
test_inter_chunk_cleanup.py Normal file
View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""
import psutil
import gc
import sys
from pathlib import Path
def get_memory_usage():
"""Get current memory usage in GB"""
process = psutil.Process()
return process.memory_info().rss / (1024**3)
def test_inter_chunk_cleanup(config_path: str = 'config.yaml'):
"""Test that models are properly destroyed between chunks"""
print("🧪 TESTING INTER-CHUNK CLEANUP")
print("=" * 50)
baseline_memory = get_memory_usage()
print(f"📊 Baseline memory: {baseline_memory:.2f} GB")
# Import and create processor
print("\n1⃣ Creating processor...")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
config = VR180Config.from_yaml(config_path)
processor = VR180Processor(config)
init_memory = get_memory_usage()
print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")
# Simulate chunk processing (just trigger model loading)
print("\n2⃣ Simulating chunk 0 processing...")
# Test 1: Force YOLO model loading
try:
detector = processor.detector
detector._load_model() # Force load
yolo_memory = get_memory_usage()
print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
except Exception as e:
print(f"❌ YOLO loading failed: {e}")
yolo_memory = init_memory
# Test 2: Force SAM2 model loading
try:
sam2_model = processor.sam2_model
sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
sam2_memory = get_memory_usage()
print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
except Exception as e:
print(f"❌ SAM2 loading failed: {e}")
sam2_memory = yolo_memory
total_model_memory = sam2_memory - init_memory
print(f"📊 Total model memory: {total_model_memory:.2f} GB")
# Test 3: Inter-chunk cleanup
print("\n3⃣ Testing inter-chunk cleanup...")
processor._complete_inter_chunk_cleanup(chunk_idx=0)
cleanup_memory = get_memory_usage()
cleanup_improvement = sam2_memory - cleanup_memory
print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")
# Test 4: Verify models reload fresh
print("\n4⃣ Testing fresh model reload...")
# Check YOLO state
yolo_reloaded = processor.detector.model is None
print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")
# Check SAM2 state
sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")
# Test 5: Force reload to verify they work
print("\n5⃣ Testing model reload...")
try:
# Force YOLO reload
processor.detector._load_model()
yolo_reload_memory = get_memory_usage()
# Force SAM2 reload
processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
sam2_reload_memory = get_memory_usage()
reload_growth = sam2_reload_memory - cleanup_memory
print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")
if abs(reload_growth - total_model_memory) < 1.0: # Within 1GB
print("✅ Models reloaded with similar memory usage (good)")
else:
print("⚠️ Model reload memory differs significantly")
except Exception as e:
print(f"❌ Model reload failed: {e}")
# Final summary
print(f"\n📊 SUMMARY:")
print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
print(f" Memory freed: {cleanup_improvement:.2f}GB")
print(f" Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")
# Success criteria: Both models destroyed AND can reload
models_destroyed = yolo_reloaded and sam2_reloaded
can_reload = 'reload_growth' in locals()
if models_destroyed and can_reload:
print("✅ Inter-chunk cleanup working effectively")
print("💡 Models destroyed and can reload fresh (memory will be freed during real processing)")
return True
elif models_destroyed:
print("⚠️ Models destroyed but reload test incomplete")
print("💡 This should still prevent accumulation during real processing")
return True
else:
print("❌ Inter-chunk cleanup not freeing enough memory")
return False
def main():
if len(sys.argv) != 2:
print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
success = test_inter_chunk_cleanup(config_path)
if success:
print(f"\n🎉 SUCCESS: Inter-chunk cleanup is working!")
print(f"💡 This should prevent 15-20GB model accumulation between chunks")
else:
print(f"\n❌ FAILURE: Inter-chunk cleanup needs improvement")
print(f"💡 Check model destruction logic in _complete_inter_chunk_cleanup")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,220 @@
"""
Checkpoint manager for resumable video processing
Saves progress to avoid reprocessing after OOM or crashes
"""
import json
import hashlib
from pathlib import Path
from typing import Dict, Any, Optional, List
import os
import shutil
from datetime import datetime
class CheckpointManager:
"""Manages processing checkpoints for resumable execution"""
def __init__(self, video_path: str, output_path: str, checkpoint_dir: Optional[Path] = None):
"""
Initialize checkpoint manager
Args:
video_path: Input video path
output_path: Output video path
checkpoint_dir: Directory for checkpoint files (default: .vr180_checkpoints in CWD)
"""
self.video_path = Path(video_path)
self.output_path = Path(output_path)
# Create unique checkpoint ID based on video file
self.video_hash = self._compute_video_hash()
# Setup checkpoint directory
if checkpoint_dir is None:
self.checkpoint_dir = Path.cwd() / ".vr180_checkpoints" / self.video_hash
else:
self.checkpoint_dir = Path(checkpoint_dir) / self.video_hash
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Checkpoint files
self.status_file = self.checkpoint_dir / "processing_status.json"
self.chunks_dir = self.checkpoint_dir / "chunks"
self.chunks_dir.mkdir(exist_ok=True)
# Load existing status or create new
self.status = self._load_status()
def _compute_video_hash(self) -> str:
"""Compute hash of video file for unique identification"""
# Use file path, size, and modification time for quick hash
stat = self.video_path.stat()
hash_str = f"{self.video_path}_{stat.st_size}_{stat.st_mtime}"
return hashlib.md5(hash_str.encode()).hexdigest()[:12]
def _load_status(self) -> Dict[str, Any]:
"""Load processing status from checkpoint file"""
if self.status_file.exists():
with open(self.status_file, 'r') as f:
status = json.load(f)
print(f"📋 Loaded checkpoint: {status['completed_chunks']}/{status['total_chunks']} chunks completed")
return status
else:
# Create new status
return {
'video_path': str(self.video_path),
'output_path': str(self.output_path),
'video_hash': self.video_hash,
'start_time': datetime.now().isoformat(),
'total_chunks': 0,
'completed_chunks': 0,
'chunk_info': {},
'processing_complete': False,
'merge_complete': False
}
def _save_status(self):
"""Save current status to checkpoint file"""
self.status['last_update'] = datetime.now().isoformat()
with open(self.status_file, 'w') as f:
json.dump(self.status, f, indent=2)
def set_total_chunks(self, total_chunks: int):
"""Set total number of chunks to process"""
self.status['total_chunks'] = total_chunks
self._save_status()
def is_chunk_completed(self, chunk_idx: int) -> bool:
"""Check if a chunk has already been processed"""
chunk_key = f"chunk_{chunk_idx}"
return chunk_key in self.status['chunk_info'] and \
self.status['chunk_info'][chunk_key].get('completed', False)
def get_chunk_file(self, chunk_idx: int) -> Optional[Path]:
"""Get saved chunk file path if it exists"""
chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
if chunk_file.exists() and self.is_chunk_completed(chunk_idx):
return chunk_file
return None
def save_chunk(self, chunk_idx: int, frames: List, source_chunk_path: Optional[Path] = None):
"""
Save processed chunk and mark as completed
Args:
chunk_idx: Chunk index
frames: Processed frames (can be None if using source_chunk_path)
source_chunk_path: If provided, copy this file instead of saving frames
"""
chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
try:
if source_chunk_path and source_chunk_path.exists():
# Copy existing chunk file
shutil.copy2(source_chunk_path, chunk_file)
print(f"💾 Copied chunk {chunk_idx} to checkpoint: {chunk_file.name}")
elif frames is not None:
# Save new frames
import numpy as np
np.savez_compressed(str(chunk_file), frames=frames)
print(f"💾 Saved chunk {chunk_idx} to checkpoint: {chunk_file.name}")
else:
raise ValueError("Either frames or source_chunk_path must be provided")
# Update status
chunk_key = f"chunk_{chunk_idx}"
self.status['chunk_info'][chunk_key] = {
'completed': True,
'file': chunk_file.name,
'timestamp': datetime.now().isoformat()
}
self.status['completed_chunks'] = len([c for c in self.status['chunk_info'].values() if c['completed']])
self._save_status()
print(f"✅ Chunk {chunk_idx} checkpoint saved ({self.status['completed_chunks']}/{self.status['total_chunks']})")
except Exception as e:
print(f"❌ Failed to save chunk {chunk_idx} checkpoint: {e}")
def get_completed_chunk_files(self) -> List[Path]:
"""Get list of all completed chunk files in order"""
chunk_files = []
missing_chunks = []
for i in range(self.status['total_chunks']):
chunk_file = self.get_chunk_file(i)
if chunk_file:
chunk_files.append(chunk_file)
else:
# Check if chunk is marked as completed but file is missing
if self.is_chunk_completed(i):
missing_chunks.append(i)
print(f"⚠️ Chunk {i} marked complete but file missing!")
else:
break # Stop at first unprocessed chunk
if missing_chunks:
print(f"❌ Missing checkpoint files for chunks: {missing_chunks}")
print(f" This may happen if files were deleted during streaming merge")
print(f" These chunks may need to be reprocessed")
return chunk_files
def mark_processing_complete(self):
"""Mark all chunk processing as complete"""
self.status['processing_complete'] = True
self._save_status()
print(f"✅ All chunks processed and checkpointed")
def mark_merge_complete(self):
"""Mark final merge as complete"""
self.status['merge_complete'] = True
self._save_status()
print(f"✅ Video merge completed")
def cleanup_checkpoints(self, keep_chunks: bool = False):
"""
Clean up checkpoint files after successful completion
Args:
keep_chunks: If True, keep chunk files but remove status
"""
if keep_chunks:
# Just remove status file
if self.status_file.exists():
self.status_file.unlink()
print(f"🗑️ Removed checkpoint status file")
else:
# Remove entire checkpoint directory
if self.checkpoint_dir.exists():
shutil.rmtree(self.checkpoint_dir)
print(f"🗑️ Removed all checkpoint files: {self.checkpoint_dir}")
def get_resume_info(self) -> Dict[str, Any]:
"""Get information about what can be resumed"""
return {
'can_resume': self.status['completed_chunks'] > 0,
'completed_chunks': self.status['completed_chunks'],
'total_chunks': self.status['total_chunks'],
'processing_complete': self.status['processing_complete'],
'merge_complete': self.status['merge_complete'],
'checkpoint_dir': str(self.checkpoint_dir)
}
def print_status(self):
"""Print current checkpoint status"""
print(f"\n📊 CHECKPOINT STATUS:")
print(f" Video: {self.video_path.name}")
print(f" Hash: {self.video_hash}")
print(f" Progress: {self.status['completed_chunks']}/{self.status['total_chunks']} chunks")
print(f" Processing complete: {self.status['processing_complete']}")
print(f" Merge complete: {self.status['merge_complete']}")
print(f" Checkpoint dir: {self.checkpoint_dir}")
if self.status['completed_chunks'] > 0:
print(f"\n Completed chunks:")
for i in range(self.status['completed_chunks']):
chunk_info = self.status['chunk_info'].get(f'chunk_{i}', {})
if chunk_info.get('completed'):
print(f" ✓ Chunk {i}: {chunk_info.get('file', 'unknown')}")

View File

@@ -29,6 +29,11 @@ class MattingConfig:
fp16: bool = True
sam2_model_cfg: str = "sam2.1_hiera_l"
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
# Det-SAM2 optimizations
continuous_correction: bool = True
correction_interval: int = 60 # Add correction prompts every N frames
frame_release_interval: int = 50 # Release old frames every N frames
frame_window_size: int = 30 # Keep N frames in memory
@dataclass
@@ -37,6 +42,8 @@ class OutputConfig:
format: str = "alpha"
background_color: List[int] = None
maintain_sbs: bool = True
preserve_audio: bool = True
verify_sync: bool = True
def __post_init__(self):
if self.background_color is None:
@@ -99,7 +106,9 @@ class VR180Config:
'path': self.output.path,
'format': self.output.format,
'background_color': self.output.background_color,
'maintain_sbs': self.output.maintain_sbs
'maintain_sbs': self.output.maintain_sbs,
'preserve_audio': self.output.preserve_audio,
'verify_sync': self.output.verify_sync
},
'hardware': {
'device': self.hardware.device,
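
Assuming these MattingConfig fields map one-to-one onto keys under the existing `matting:` section of config.yaml, the new Det-SAM2 options would be configured as in the sketch below; the values simply restate the dataclass defaults:

```yaml
matting:
  # Det-SAM2 optimizations (new fields above)
  continuous_correction: true
  correction_interval: 60       # add correction prompts every N frames
  frame_release_interval: 50    # release old frames every N frames
  frame_window_size: 30         # keep N frames in memory
```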

View File

@@ -1,6 +1,4 @@
import torch
import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any
import cv2
@@ -13,14 +11,23 @@ class YOLODetector:
self.confidence_threshold = confidence_threshold
self.device = device
self.model = None
self._load_model()
# Don't load model during init - load lazily when first used
def _load_model(self):
"""Load YOLOv8 model"""
"""Load YOLOv8 model lazily"""
if self.model is not None:
return # Already loaded
try:
# Import heavy dependencies only when needed
import torch
from ultralytics import YOLO
self.model = YOLO(f"{self.model_name}.pt")
if self.device == "cuda" and torch.cuda.is_available():
self.model.to("cuda")
print(f"🎯 Loaded YOLO model: {self.model_name}")
except Exception as e:
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
@@ -34,8 +41,9 @@ class YOLODetector:
Returns:
List of detection dictionaries with bbox, confidence, and class info
"""
# Load model lazily on first use
if self.model is None:
raise RuntimeError("YOLO model not loaded")
self._load_model()
results = self.model(frame, verbose=False)
detections = []

View File

@@ -5,13 +5,20 @@ import cv2
from pathlib import Path
import warnings
import os
import tempfile
import shutil
import gc
# Check SAM2 availability without importing heavy modules
def _check_sam2_available():
try:
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
import sam2
return True
except ImportError:
SAM2_AVAILABLE = False
return False
SAM2_AVAILABLE = _check_sam2_available()
if not SAM2_AVAILABLE:
warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -30,15 +37,25 @@ class SAM2VideoMatting:
self.device = device
self.memory_offload = memory_offload
self.fp16 = fp16
self.model_cfg = model_cfg
self.checkpoint_path = checkpoint_path
self.predictor = None
self.inference_state = None
self.video_segments = {}
self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path)
# Don't load model during init - load lazily when needed
self._model_loaded = False
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations"""
"""Load SAM2 video predictor lazily"""
if self._model_loaded and self.predictor is not None:
return # Already loaded and predictor exists
try:
# Import heavy SAM2 modules only when needed
from sam2.build_sam import build_sam2_video_predictor
# Check for checkpoint in SAM2 repo structure
if not Path(checkpoint_path).exists():
# Try in segment-anything-2/checkpoints/
@@ -57,36 +74,63 @@ class SAM2VideoMatting:
if sam2_repo_path.exists():
checkpoint_path = str(sam2_repo_path)
# Use the config path as-is (should be relative to SAM2 package)
# Example: "configs/sam2.1/sam2.1_hiera_l.yaml"
print(f"🎯 Loading SAM2 model: {model_cfg}")
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
# The predictor IS the model - no .model attribute needed
self.predictor = build_sam2_video_predictor(
model_cfg,
checkpoint_path,
config_file=model_cfg,
ckpt_path=checkpoint_path,
device=self.device
)
# Enable memory optimizations
if self.memory_offload:
self.predictor.fill_hole_area = 8
if self.fp16 and self.device == "cuda":
self.predictor.model.half()
self._model_loaded = True
print(f"✅ SAM2 model loaded successfully")
except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray]) -> None:
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
"""Initialize video inference state"""
if self.predictor is None:
raise RuntimeError("SAM2 model not loaded")
# Load model lazily on first use
if not self._model_loaded:
self._load_model(self.model_cfg, self.checkpoint_path)
# Create temporary directory for frames if needed
if video_path is not None:
# Use video path directly (SAM2's preferred method)
self.inference_state = self.predictor.init_state(
video_path=None,
video_frames=video_frames,
video_path=video_path,
offload_video_to_cpu=self.memory_offload,
async_loading_frames=True
)
else:
# For frame arrays, we need to save them as a temporary video first
if video_frames is None or len(video_frames) == 0:
raise ValueError("Either video_path or video_frames must be provided")
# Create temporary video file in current directory
import uuid
temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
temp_video_path = Path.cwd() / temp_video_name
# Write frames to temporary video
height, width = video_frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
for frame in video_frames:
writer.write(frame)
writer.release()
# Initialize with temporary video
self.inference_state = self.predictor.init_state(
video_path=str(temp_video_path),
offload_video_to_cpu=self.memory_offload,
async_loading_frames=True
)
# Store temp path for cleanup
self.temp_video_path = temp_video_path
def add_person_prompts(self,
frame_idx: int,
@@ -123,13 +167,16 @@ class SAM2VideoMatting:
return object_ids
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Propagate masks through video
Propagate masks through video with Det-SAM2 style memory management
Args:
start_frame: Starting frame index
max_frames: Maximum number of frames to process
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
@@ -153,9 +200,108 @@ class SAM2VideoMatting:
video_segments[out_frame_idx] = frame_masks
# Memory management: release old frames periodically
if self.memory_offload and out_frame_idx % 100 == 0:
self._release_old_frames(out_frame_idx - 50)
# Det-SAM2 style memory management: more aggressive frame release
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
return video_segments
def propagate_masks_with_continuous_correction(self,
detector,
temp_video_path: str,
start_frame: int = 0,
max_frames: Optional[int] = None,
correction_interval: int = 60,
frame_release_interval: int = 50,
frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Det-SAM2 style: Propagate masks with continuous prompt correction
Args:
detector: YOLODetector instance for generating correction prompts
temp_video_path: Path to video file for frame access
start_frame: Starting frame index
max_frames: Maximum number of frames to process
correction_interval: Add correction prompts every N frames
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
"""
if self.inference_state is None:
raise RuntimeError("Video state not initialized")
video_segments = {}
max_frames = max_frames or 10000 # Default limit
# Open video for accessing frames during propagation
cap = cv2.VideoCapture(str(temp_video_path))
try:
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=start_frame,
max_frame_num_to_track=max_frames,
reverse=False
):
frame_masks = {}
for i, out_obj_id in enumerate(out_obj_ids):
mask = (out_mask_logits[i] > 0.0).cpu().numpy()
frame_masks[out_obj_id] = mask
video_segments[out_frame_idx] = frame_masks
# Det-SAM2 optimization: Add correction prompts at keyframes
if (out_frame_idx % correction_interval == 0 and
out_frame_idx > start_frame and
out_frame_idx < max_frames - 1):
# Read frame for detection
cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
ret, correction_frame = cap.read()
if ret:
# Run detection on this keyframe
detections = detector.detect_persons(correction_frame)
if detections:
# Convert to prompts and add as corrections
box_prompts, labels = detector.convert_to_sam_prompts(detections)
# Add correction prompts (SAM2 will propagate backward)
correction_count = 0
try:
for i, (box, label) in enumerate(zip(box_prompts, labels)):
# Use existing object IDs if available, otherwise create new ones
obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=out_frame_idx,
obj_id=obj_id,
box=box,
)
correction_count += 1
print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
except Exception as e:
warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
# Memory management: More aggressive frame release (Det-SAM2 style)
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
finally:
cap.release()
return video_segments
@@ -231,17 +377,58 @@ class SAM2VideoMatting:
"""Clean up resources"""
if self.inference_state is not None:
try:
if hasattr(self.predictor, 'cleanup_state'):
# Reset SAM2 state first (critical for memory cleanup)
if self.predictor is not None and hasattr(self.predictor, 'reset_state'):
self.predictor.reset_state(self.inference_state)
# Fallback to cleanup_state if available
elif self.predictor is not None and hasattr(self.predictor, 'cleanup_state'):
self.predictor.cleanup_state(self.inference_state)
# Explicitly delete inference state and video segments
del self.inference_state
if hasattr(self, 'video_segments') and self.video_segments:
del self.video_segments
self.video_segments = {}
except Exception as e:
warnings.warn(f"Failed to cleanup SAM2 state: {e}")
finally:
self.inference_state = None
# Clean up temporary video file
if self.temp_video_path is not None:
try:
if self.temp_video_path.exists():
# Remove the temporary video file
self.temp_video_path.unlink()
self.temp_video_path = None
except Exception as e:
warnings.warn(f"Failed to cleanup temp video: {e}")
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Explicitly delete predictor for fresh creation next time
if self.predictor is not None:
try:
del self.predictor
except Exception as e:
warnings.warn(f"Failed to delete predictor: {e}")
finally:
self.predictor = None
# Reset model loaded state for fresh reload
self._model_loaded = False
# Force garbage collection (critical for memory leak prevention)
gc.collect()
def __del__(self):
"""Destructor to ensure cleanup"""
try:
self.cleanup()
except Exception:
# Ignore errors during Python shutdown
pass

View File

@@ -7,6 +7,12 @@ import tempfile
import shutil
from tqdm import tqdm
import warnings
import time
import subprocess
import gc
import psutil
import os
import sys
from .config import VR180Config
from .detector import YOLODetector
@@ -35,8 +41,137 @@ class VideoProcessor:
self.frame_width = 0
self.frame_height = 0
# Processing statistics
self.processing_stats = {
'start_time': None,
'end_time': None,
'total_duration': 0,
'processing_fps': 0,
'chunks_processed': 0,
'frames_processed': 0
}
self._initialize_models()
def _get_process_memory_info(self) -> Dict[str, float]:
"""Get detailed memory usage for current process and children"""
current_process = psutil.Process(os.getpid())
# Get memory info for current process
memory_info = current_process.memory_info()
current_rss = memory_info.rss / 1024**3 # Convert to GB
current_vms = memory_info.vms / 1024**3 # Virtual memory
# Get memory info for all children
children_rss = 0
children_vms = 0
child_count = 0
try:
for child in current_process.children(recursive=True):
try:
child_memory = child.memory_info()
children_rss += child_memory.rss / 1024**3
children_vms += child_memory.vms / 1024**3
child_count += 1
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
except psutil.NoSuchProcess:
pass
# System memory info
system_memory = psutil.virtual_memory()
system_total = system_memory.total / 1024**3
system_available = system_memory.available / 1024**3
system_used = system_memory.used / 1024**3
system_percent = system_memory.percent
return {
'process_rss_gb': current_rss,
'process_vms_gb': current_vms,
'children_rss_gb': children_rss,
'children_vms_gb': children_vms,
'total_process_gb': current_rss + children_rss,
'child_count': child_count,
'system_total_gb': system_total,
'system_used_gb': system_used,
'system_available_gb': system_available,
'system_percent': system_percent
}
def _print_memory_step(self, step_name: str):
"""Print memory usage for a specific processing step"""
memory_info = self._get_process_memory_info()
print(f"\n📊 MEMORY: {step_name}")
print(f" Process RSS: {memory_info['process_rss_gb']:.2f} GB")
if memory_info['children_rss_gb'] > 0:
print(f" Children RSS: {memory_info['children_rss_gb']:.2f} GB ({memory_info['child_count']} processes)")
print(f" Total Process: {memory_info['total_process_gb']:.2f} GB")
print(f" System: {memory_info['system_used_gb']:.1f}/{memory_info['system_total_gb']:.1f} GB ({memory_info['system_percent']:.1f}%)")
print(f" Available: {memory_info['system_available_gb']:.1f} GB")
def _aggressive_memory_cleanup(self, step_name: str = ""):
"""Perform aggressive memory cleanup and report before/after"""
if step_name:
print(f"\n🧹 CLEANUP: Before {step_name}")
before_info = self._get_process_memory_info()
before_rss = before_info['total_process_gb']
# Multiple rounds of garbage collection
for i in range(3):
gc.collect()
# Clear torch cache if available
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
except ImportError:
pass
# Clear OpenCV internal caches
try:
# Clear OpenCV video capture cache
cv2.setUseOptimized(False)
cv2.setUseOptimized(True)
except Exception:
pass
# Clear CuPy caches if available
try:
import cupy as cp
cp._default_memory_pool.free_all_blocks()
cp._default_pinned_memory_pool.free_all_blocks()
cp.get_default_memory_pool().free_all_blocks()
cp.get_default_pinned_memory_pool().free_all_blocks()
except ImportError:
pass
except Exception as e:
print(f" Warning: Could not clear CuPy cache: {e}")
# Force Linux to release memory back to OS
if sys.platform == 'linux':
try:
import ctypes
libc = ctypes.CDLL("libc.so.6")
libc.malloc_trim(0)
except Exception as e:
print(f" Warning: Could not trim memory: {e}")
# Brief pause to allow cleanup
time.sleep(0.1)
after_info = self._get_process_memory_info()
after_rss = after_info['total_process_gb']
freed_memory = before_rss - after_rss
if step_name:
print(f" Before: {before_rss:.2f} GB → After: {after_rss:.2f} GB")
print(f" Freed: {freed_memory:.2f} GB")
def _initialize_models(self):
"""Initialize YOLO detector and SAM2 model"""
print("Initializing models...")
@@ -146,6 +281,116 @@ class VideoProcessor:
print(f"Read {len(frames)} frames")
return frames
def read_video_frames_dual_resolution(self,
video_path: str,
start_frame: int = 0,
num_frames: Optional[int] = None,
scale_factor: float = 0.5) -> Dict[str, List[np.ndarray]]:
"""
Read video frames at both original and scaled resolution for dual-resolution processing
Args:
video_path: Path to video file
start_frame: Starting frame index
num_frames: Number of frames to read (None for all)
scale_factor: Scaling factor for inference frames
Returns:
Dictionary with 'original' and 'scaled' frame lists
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise RuntimeError(f"Could not open video file: {video_path}")
# Set starting position
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
original_frames = []
scaled_frames = []
frame_count = 0
# Progress tracking
total_to_read = num_frames if num_frames else self.total_frames - start_frame
with tqdm(total=total_to_read, desc="Reading dual-resolution frames") as pbar:
while True:
ret, frame = cap.read()
if not ret:
break
# Store original frame
original_frames.append(frame.copy())
# Create scaled frame for inference
if scale_factor != 1.0:
new_width = int(frame.shape[1] * scale_factor)
new_height = int(frame.shape[0] * scale_factor)
scaled_frame = cv2.resize(frame, (new_width, new_height),
interpolation=cv2.INTER_AREA)
else:
scaled_frame = frame.copy()
scaled_frames.append(scaled_frame)
frame_count += 1
pbar.update(1)
if num_frames is not None and frame_count >= num_frames:
break
cap.release()
print(f"Loaded {len(original_frames)} frames:")
print(f" Original: {original_frames[0].shape} per frame")
print(f" Scaled: {scaled_frames[0].shape} per frame (scale_factor={scale_factor})")
return {
'original': original_frames,
'scaled': scaled_frames
}
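For example, reading a short dual-resolution window might look like this (a sketch; the processor instance and frame shapes are illustrative):

frame_data = processor.read_video_frames_dual_resolution(
    'input_vr180.mp4', start_frame=0, num_frames=8, scale_factor=0.5)
# frame_data['original'][0].shape -> e.g. (2160, 4320, 3)  full-resolution frame
# frame_data['scaled'][0].shape   -> e.g. (1080, 2160, 3)  0.5x inference frame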
def upscale_mask(self, mask: np.ndarray, target_shape: tuple, method: str = 'cubic') -> np.ndarray:
"""
Upscale a mask from inference resolution to original resolution
Args:
mask: Low-resolution mask (H, W)
target_shape: Target shape (H, W) for upscaling
method: Upscaling method ('nearest', 'cubic', 'area')
Returns:
Upscaled mask at target resolution
"""
if mask.shape[:2] == target_shape[:2]:
return mask # Already correct size
# Ensure mask is 2D
if mask.ndim == 3:
mask = mask.squeeze()
# Choose interpolation method
if method == 'nearest':
interpolation = cv2.INTER_NEAREST # Crisp edges, good for sharp subjects
elif method == 'cubic':
interpolation = cv2.INTER_CUBIC # Smooth edges, good for most content
elif method == 'area':
interpolation = cv2.INTER_AREA # Good for downscaling, not upscaling
else:
interpolation = cv2.INTER_CUBIC # Default to cubic
# Upscale mask
upscaled_mask = cv2.resize(
mask.astype(np.uint8),
(target_shape[1], target_shape[0]), # (width, height) for cv2.resize
interpolation=interpolation
)
# Convert back to boolean if it was originally boolean
if mask.dtype == bool:
upscaled_mask = upscaled_mask.astype(bool)
return upscaled_mask
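A quick round-trip example for the mask path above (a sketch; processor is an initialized VideoProcessor and the shapes are illustrative):

import numpy as np

mask_low = np.zeros((540, 960), dtype=bool)    # boolean mask produced at 0.5x inference resolution
mask_low[200:300, 400:500] = True
mask_full = processor.upscale_mask(mask_low, target_shape=(1080, 1920), method='cubic')
assert mask_full.shape == (1080, 1920) and mask_full.dtype == bool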
def calculate_optimal_chunking(self) -> Tuple[int, int]:
"""
Calculate optimal chunk size and overlap based on memory constraints
@@ -234,6 +479,92 @@ class VideoProcessor:
return matted_frames
def process_chunk_dual_resolution(self,
frame_data: Dict[str, List[np.ndarray]],
chunk_idx: int = 0) -> List[np.ndarray]:
"""
Process a chunk using dual-resolution approach: inference at low-res, output at full-res
Args:
frame_data: Dictionary with 'original' and 'scaled' frame lists
chunk_idx: Chunk index for logging
Returns:
List of matted frames at original resolution
"""
original_frames = frame_data['original']
scaled_frames = frame_data['scaled']
print(f"Processing chunk {chunk_idx} with dual-resolution ({len(original_frames)} frames)")
print(f" Inference: {scaled_frames[0].shape} → Output: {original_frames[0].shape}")
with self.memory_manager.memory_monitor(f"dual-res chunk {chunk_idx}"):
# Initialize SAM2 with scaled frames for inference
self.sam2_model.init_video_state(scaled_frames)
# Detect persons in first scaled frame
first_scaled_frame = scaled_frames[0]
detections = self.detector.detect_persons(first_scaled_frame)
if not detections:
warnings.warn(f"No persons detected in chunk {chunk_idx}")
return self._create_empty_masks(original_frames)
print(f"Detected {len(detections)} persons in first frame (at inference resolution)")
# Convert detections to SAM2 prompts (detections are already at scaled resolution)
box_prompts, labels = self.detector.convert_to_sam_prompts(detections)
# Add prompts to SAM2
object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
print(f"Added prompts for {len(object_ids)} objects")
# Propagate masks through chunk at inference resolution
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=len(scaled_frames)
)
# Apply upscaled masks to original resolution frames
matted_frames = []
original_shape = original_frames[0].shape[:2] # (H, W)
for frame_idx, original_frame in enumerate(tqdm(original_frames, desc="Applying upscaled masks")):
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
# Get combined mask at inference resolution
combined_mask_scaled = self.sam2_model.get_combined_mask(frame_masks)
if combined_mask_scaled is not None:
# Upscale mask to original resolution
combined_mask_full = self.upscale_mask(
combined_mask_scaled,
target_shape=original_shape,
method='cubic' # Smooth upscaling for masks
)
# Apply upscaled mask to original resolution frame
matted_frame = self.sam2_model.apply_mask_to_frame(
original_frame, combined_mask_full,
output_format=self.config.output.format,
background_color=self.config.output.background_color
)
else:
# No mask for this frame
matted_frame = self._create_empty_mask_frame(original_frame)
else:
# No mask for this frame
matted_frame = self._create_empty_mask_frame(original_frame)
matted_frames.append(matted_frame)
# Cleanup SAM2 state
self.sam2_model.cleanup()
print(f"✅ Dual-resolution processing complete: {len(matted_frames)} frames at full resolution")
return matted_frames
def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
"""Create empty masks when no persons detected"""
empty_frames = []
@@ -252,19 +583,213 @@ class VideoProcessor:
# Green screen background
return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
overlap_frames: int = 0, audio_source: str = None) -> None:
"""
Merge processed chunks using streaming approach (no memory accumulation)
Args:
chunk_files: List of chunk result files (.npz)
output_path: Final output video path
overlap_frames: Number of overlapping frames
audio_source: Audio source file for final video
"""
if not chunk_files:
raise ValueError("No chunk files to merge")
print(f"🎬 TRUE Streaming merge: {len(chunk_files)} chunks → {output_path}")
# Create temporary directory for frame images
import tempfile
temp_frames_dir = Path(tempfile.mkdtemp(prefix="merge_frames_"))
frame_counter = 0
try:
print(f"📁 Using temp frames dir: {temp_frames_dir}")
# Process each chunk frame-by-frame (true streaming)
for i, chunk_file in enumerate(chunk_files):
print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
# Open the chunk archive (note: indexing a compressed .npz decompresses the full
# 'frames' array into memory; it is not memory-mapped)
chunk_data = np.load(str(chunk_file))
frames_array = chunk_data['frames']
total_frames_in_chunk = frames_array.shape[0]
# Determine which frames to skip for overlap
start_frame_idx = overlap_frames if i > 0 and overlap_frames > 0 else 0
frames_to_process = total_frames_in_chunk - start_frame_idx
if start_frame_idx > 0:
print(f" ✂️ Skipping first {start_frame_idx} overlapping frames")
print(f" 🔄 Processing {frames_to_process} frames one-by-one...")
# Process frames ONE AT A TIME (true streaming)
for frame_idx in range(start_frame_idx, total_frames_in_chunk):
# Take one frame at a time from the chunk array and write it straight to disk
frame = frames_array[frame_idx]
# Save frame directly to disk
frame_path = temp_frames_dir / f"frame_{frame_counter:06d}.jpg"
success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
if not success:
raise RuntimeError(f"Failed to save frame {frame_counter}")
frame_counter += 1
# Periodic progress and cleanup
if frame_counter % 100 == 0:
print(f" 💾 Saved {frame_counter} frames...")
gc.collect() # Periodic cleanup
print(f" ✅ Saved {frames_to_process} frames to disk (total: {frame_counter})")
# Close chunk file and cleanup
chunk_data.close()
del chunk_data, frames_array
# Don't delete checkpoint files - they're needed for potential resume
# The checkpoint system manages cleanup separately
print(f" 📋 Keeping checkpoint file: {chunk_file.name}")
# Aggressive cleanup and memory monitoring after each chunk
self._aggressive_memory_cleanup(f"After streaming merge chunk {i}")
# Memory safety check
memory_info = self._get_process_memory_info()
if memory_info['total_process_gb'] > 35: # Warning if approaching 46GB limit
print(f"⚠️ High memory usage: {memory_info['total_process_gb']:.1f}GB - forcing cleanup")
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Create final video directly from frame images using ffmpeg
print(f"📹 Creating final video from {frame_counter} frames...")
self._create_video_from_frames(temp_frames_dir, Path(output_path), frame_counter)
# Add audio if provided
if audio_source:
self._add_audio_to_video(output_path, audio_source)
except Exception as e:
print(f"❌ Streaming merge failed: {e}")
raise
finally:
# Cleanup temporary frames directory
try:
if temp_frames_dir.exists():
import shutil
shutil.rmtree(temp_frames_dir)
print(f"🗑️ Cleaned up temp frames dir: {temp_frames_dir}")
except Exception as e:
print(f"⚠️ Could not cleanup temp frames dir: {e}")
# Memory cleanup
gc.collect()
print(f"✅ TRUE Streaming merge complete: {output_path}")
def _create_video_from_frames(self, frames_dir: Path, output_path: Path, frame_count: int):
"""Create video directly from frame images using ffmpeg (memory efficient)"""
import subprocess
frame_pattern = str(frames_dir / "frame_%06d.jpg")
fps = self.video_info['fps'] if hasattr(self, 'video_info') and self.video_info else 30.0
print(f"🎬 Creating video with ffmpeg: {frame_count} frames at {fps} fps")
# Use GPU encoding if available, fallback to CPU
gpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(fps),
'-i', frame_pattern,
'-c:v', 'h264_nvenc', # NVIDIA GPU encoder
'-preset', 'fast',
'-cq', '18', # Quality for GPU encoding
'-pix_fmt', 'yuv420p',
str(output_path)
]
cpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(fps),
'-i', frame_pattern,
'-c:v', 'libx264', # CPU encoder
'-preset', 'medium',
'-crf', '18', # Quality for CPU encoding
'-pix_fmt', 'yuv420p',
str(output_path)
]
# Try GPU first
print(f"🚀 Trying GPU encoding...")
result = subprocess.run(gpu_cmd, capture_output=True, text=True)
if result.returncode != 0:
print("⚠️ GPU encoding failed, using CPU...")
print(f"🔄 CPU encoding...")
result = subprocess.run(cpu_cmd, capture_output=True, text=True)
else:
print("✅ GPU encoding successful!")
if result.returncode != 0:
print(f"❌ FFmpeg stdout: {result.stdout}")
print(f"❌ FFmpeg stderr: {result.stderr}")
raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
print(f"✅ Video created successfully: {output_path}")
def _add_audio_to_video(self, video_path: str, audio_source: str):
"""Add audio to video using ffmpeg"""
import subprocess
import tempfile
try:
# Create temporary file for output with audio
temp_path = Path(video_path).with_suffix('.temp.mp4')
cmd = [
'ffmpeg', '-y',
'-i', str(video_path), # Input video (no audio)
'-i', str(audio_source), # Input audio source
'-c:v', 'copy', # Copy video without re-encoding
'-c:a', 'aac', # Encode audio as AAC
'-map', '0:v:0', # Map video from first input
'-map', '1:a:0', # Map audio from second input
'-shortest', # Match shortest stream duration
str(temp_path)
]
print(f"🎵 Adding audio: {audio_source}{video_path}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"⚠️ Audio addition failed: {result.stderr}")
# Keep original video without audio
return
# Replace original with audio version
Path(video_path).unlink()
temp_path.rename(video_path)
print(f"✅ Audio added successfully")
except Exception as e:
print(f"⚠️ Could not add audio: {e}")
def merge_overlapping_chunks(self,
chunk_results: List[List[np.ndarray]],
overlap_frames: int) -> List[np.ndarray]:
"""
Merge overlapping chunks with blending in overlap regions
Args:
chunk_results: List of chunk results
overlap_frames: Number of overlapping frames
Returns:
Merged frame sequence
Legacy merge method - DEPRECATED due to memory accumulation
Use merge_chunks_streaming() instead for memory efficiency
"""
import warnings
warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
DeprecationWarning, stacklevel=2)
if len(chunk_results) == 1:
return chunk_results[0]
@@ -348,70 +873,307 @@ class VideoProcessor:
print(f"Saved {len(frames)} PNG frames to {output_dir}")
def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
"""Save frames as MP4 video"""
"""Save frames as MP4 video with audio preservation"""
if not frames:
return
height, width = frames[0].shape[:2]
output_path = Path(output_path)
temp_frames_dir = output_path.parent / f"temp_frames_{output_path.stem}"
temp_frames_dir.mkdir(exist_ok=True)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))
for frame in tqdm(frames, desc="Writing video"):
try:
# Save frames as images
print("Saving frames as images...")
for i, frame in enumerate(tqdm(frames, desc="Saving frames")):
if frame.shape[2] == 4: # Convert RGBA to BGR
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
writer.write(frame)
writer.release()
frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
# Create video with ffmpeg
self._create_video_with_ffmpeg(temp_frames_dir, output_path, len(frames))
finally:
# Cleanup temporary frames
if temp_frames_dir.exists():
shutil.rmtree(temp_frames_dir)
def _create_video_with_ffmpeg(self, frames_dir: Path, output_path: Path, frame_count: int):
"""Create video using ffmpeg with audio preservation"""
frame_pattern = str(frames_dir / "frame_%06d.jpg")
if self.config.output.preserve_audio:
# Create video with audio from input
cmd = [
'ffmpeg', '-y',
'-framerate', str(self.fps),
'-i', frame_pattern,
'-i', str(self.config.input.video_path), # Input video for audio
'-c:v', 'h264_nvenc', # Try GPU encoding first
'-preset', 'fast',
'-cq', '18',
'-c:a', 'copy', # Copy audio without re-encoding
'-map', '0:v:0', # Map video from frames
'-map', '1:a:0', # Map audio from input video
'-shortest', # Match shortest stream duration
'-pix_fmt', 'yuv420p',
str(output_path)
]
else:
# Create video without audio
cmd = [
'ffmpeg', '-y',
'-framerate', str(self.fps),
'-i', frame_pattern,
'-c:v', 'h264_nvenc',
'-preset', 'fast',
'-cq', '18',
'-pix_fmt', 'yuv420p',
str(output_path)
]
print(f"Creating video with ffmpeg...")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
# Try CPU encoding as fallback
print("GPU encoding failed, trying CPU encoding...")
cmd[cmd.index('h264_nvenc')] = 'libx264'
cmd[cmd.index('-cq')] = '-crf' # Change quality parameter for CPU
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"FFmpeg stdout: {result.stdout}")
print(f"FFmpeg stderr: {result.stderr}")
raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
# Verify frame count if sync verification is enabled
if self.config.output.verify_sync:
self._verify_frame_count(output_path, frame_count)
print(f"Saved video to {output_path}")
def _verify_frame_count(self, video_path: Path, expected_frames: int):
"""Verify output video has correct frame count"""
try:
probe = ffmpeg.probe(str(video_path))
video_stream = next(
(stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
None
)
if video_stream:
actual_frames = int(video_stream.get('nb_frames', 0))
if actual_frames != expected_frames:
print(f"⚠️ Frame count mismatch: expected {expected_frames}, got {actual_frames}")
else:
print(f"✅ Frame count verified: {actual_frames} frames")
except Exception as e:
print(f"⚠️ Could not verify frame count: {e}")
def process_video(self) -> None:
"""Main video processing pipeline"""
"""Main video processing pipeline with checkpoint/resume support"""
self.processing_stats['start_time'] = time.time()
print("Starting VR180 video processing...")
# Load video info
self.load_video_info(self.config.input.video_path)
# Initialize checkpoint manager
from .checkpoint_manager import CheckpointManager
checkpoint_mgr = CheckpointManager(
self.config.input.video_path,
self.config.output.path
)
# Check for existing checkpoints
resume_info = checkpoint_mgr.get_resume_info()
if resume_info['can_resume']:
print(f"\n🔄 RESUME DETECTED:")
print(f" Found {resume_info['completed_chunks']} completed chunks")
print(f" Continue from where we left off? (saves time!)")
checkpoint_mgr.print_status()
# Calculate chunking parameters
chunk_size, overlap_frames = self.calculate_optimal_chunking()
# Process video in chunks
chunk_results = []
# Calculate total chunks
total_chunks = 0
for _ in range(0, self.total_frames, chunk_size - overlap_frames):
total_chunks += 1
checkpoint_mgr.set_total_chunks(total_chunks)
# Process video in chunks
chunk_files = [] # Store file paths instead of frame data
temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))
try:
chunk_idx = 0
for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
end_frame = min(start_frame + chunk_size, self.total_frames)
frames_to_read = end_frame - start_frame
chunk_idx = len(chunk_results)
# Check if this chunk was already processed
existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
if existing_chunk:
print(f"\n✅ Chunk {chunk_idx} already processed: {existing_chunk.name}")
chunk_files.append(existing_chunk)
chunk_idx += 1
continue
print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
# Read chunk frames
# Choose processing approach based on scale factor
if self.config.processing.scale_factor == 1.0:
# No scaling needed - use original single-resolution approach
print(f"🔄 Reading frames at original resolution (no scaling)")
frames = self.read_video_frames(
self.config.input.video_path,
start_frame=start_frame,
num_frames=frames_to_read,
scale_factor=1.0
)
# Process chunk normally (single resolution)
matted_frames = self.process_chunk(frames, chunk_idx)
else:
# Scaling required - use dual-resolution approach
print(f"🔄 Reading frames at dual resolution (scale_factor={self.config.processing.scale_factor})")
frame_data = self.read_video_frames_dual_resolution(
self.config.input.video_path,
start_frame=start_frame,
num_frames=frames_to_read,
scale_factor=self.config.processing.scale_factor
)
# Process chunk
matted_frames = self.process_chunk(frames, chunk_idx)
chunk_results.append(matted_frames)
# Process chunk with dual-resolution approach
matted_frames = self.process_chunk_dual_resolution(frame_data, chunk_idx)
# Memory cleanup
# Save chunk to disk immediately to free memory
chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
print(f"Saving chunk {chunk_idx} to disk...")
np.savez_compressed(str(chunk_path), frames=matted_frames)
# Save to checkpoint
checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)
chunk_files.append(chunk_path)
chunk_idx += 1
# Free the frames from memory immediately
del matted_frames
if self.config.processing.scale_factor == 1.0:
del frames
else:
del frame_data
# Update statistics
self.processing_stats['chunks_processed'] += 1
self.processing_stats['frames_processed'] += frames_to_read
# Aggressive memory cleanup after each chunk
self._aggressive_memory_cleanup(f"chunk {chunk_idx} completion")
# Also use memory manager cleanup
self.memory_manager.cleanup_memory()
if self.memory_manager.should_emergency_cleanup():
self.memory_manager.emergency_cleanup()
# Merge chunks if multiple
print("\nMerging chunks...")
final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
# Mark chunk processing as complete
checkpoint_mgr.mark_processing_complete()
# Save results
print(f"Saving {len(final_frames)} processed frames...")
self.save_video(final_frames, self.config.output.path)
# Check if merge was already done
if resume_info.get('merge_complete', False):
print("\n✅ Merge already completed in previous run!")
print(f" Output: {self.config.output.path}")
else:
# Use streaming merge to avoid memory accumulation (fixes OOM)
print("\n🎬 Using streaming merge (no memory accumulation)...")
# For resume scenarios, make sure we have all chunk files
if resume_info['can_resume']:
checkpoint_chunk_files = checkpoint_mgr.get_completed_chunk_files()
if len(checkpoint_chunk_files) != len(chunk_files):
print(f"⚠️ Using {len(checkpoint_chunk_files)} checkpoint files instead of {len(chunk_files)} temp files")
chunk_files = checkpoint_chunk_files
# Determine audio source for final video
audio_source = None
if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
audio_source = self.config.input.video_path
# Stream merge chunks directly to output (no memory accumulation)
self.merge_chunks_streaming(
chunk_files=chunk_files,
output_path=self.config.output.path,
overlap_frames=overlap_frames,
audio_source=audio_source
)
# Mark merge as complete
checkpoint_mgr.mark_merge_complete()
print("✅ Streaming merge complete - no memory accumulation!")
# Calculate final statistics
self.processing_stats['end_time'] = time.time()
self.processing_stats['total_duration'] = self.processing_stats['end_time'] - self.processing_stats['start_time']
if self.processing_stats['total_duration'] > 0:
self.processing_stats['processing_fps'] = self.processing_stats['frames_processed'] / self.processing_stats['total_duration']
# Print processing statistics
self._print_processing_statistics()
# Print final memory report
self.memory_manager.print_memory_report()
print("Video processing completed!")
# Option to clean up checkpoints
print("\n🗄️ CHECKPOINT CLEANUP OPTIONS:")
print(" Checkpoints saved successfully and can be cleaned up")
print(" - Keep checkpoints for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
print(" - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
print(f" - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")
# For now, keep checkpoints by default (user can manually clean)
print("\n💡 Checkpoints kept for safety. Delete manually when no longer needed.")
finally:
# Clean up temporary chunk files (but not checkpoints)
if temp_chunk_dir.exists():
print("Cleaning up temporary chunk files...")
try:
shutil.rmtree(temp_chunk_dir)
except Exception as e:
print(f"⚠️ Could not clean temp directory: {e}")
def _print_processing_statistics(self):
"""Print detailed processing statistics"""
stats = self.processing_stats
video_duration = self.total_frames / self.fps if self.fps > 0 else 0
print("\n" + "="*60)
print("PROCESSING STATISTICS")
print("="*60)
print(f"Input video duration: {video_duration:.1f} seconds ({self.total_frames} frames @ {self.fps:.2f} fps)")
print(f"Total processing time: {stats['total_duration']:.1f} seconds")
print(f"Processing speed: {stats['processing_fps']:.2f} fps")
print(f"Speedup factor: {self.fps / stats['processing_fps']:.1f}x slower than realtime")
print(f"Chunks processed: {stats['chunks_processed']}")
print(f"Frames processed: {stats['frames_processed']}")
if video_duration > 0:
efficiency = video_duration / stats['total_duration']
print(f"Processing efficiency: {efficiency:.3f} (1.0 = realtime)")
# Estimate time for different video lengths
print(f"\nEstimated processing times:")
print(f" 5 minutes: {(5 * 60) / efficiency / 60:.1f} minutes")
print(f" 30 minutes: {(30 * 60) / efficiency / 60:.1f} minutes")
print(f" 1 hour: {(60 * 60) / efficiency / 60:.1f} minutes")
print("="*60 + "\n")

View File

@@ -3,6 +3,7 @@ import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import warnings
import torch
from .video_processor import VideoProcessor
from .config import VR180Config
@@ -65,17 +66,31 @@ class VR180Processor(VideoProcessor):
Returns:
Tuple of (left_eye_frame, right_eye_frame)
"""
if self.sbs_split_point == 0:
self.sbs_split_point = frame.shape[1] // 2
# Always calculate split point based on current frame width
# This handles scaled frames correctly
frame_width = frame.shape[1]
current_split_point = frame_width // 2
left_eye = frame[:, :self.sbs_split_point]
right_eye = frame[:, self.sbs_split_point:]
# Debug info on first use
if self.sbs_split_point == 0:
print(f"Frame dimensions: {frame.shape[1]}x{frame.shape[0]}")
print(f"Split point: {current_split_point}")
self.sbs_split_point = current_split_point # Store for reference
left_eye = frame[:, :current_split_point]
right_eye = frame[:, current_split_point:]
# Validate both eyes have content
if left_eye.size == 0:
raise RuntimeError(f"Left eye frame is empty after split (frame width: {frame_width})")
if right_eye.size == 0:
raise RuntimeError(f"Right eye frame is empty after split (frame width: {frame_width})")
return left_eye, right_eye
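A small check of the split on a 4K side-by-side frame (a sketch; vr_processor is an initialized VR180Processor and the resolution is illustrative):

import numpy as np

sbs = np.zeros((2160, 3840, 3), dtype=np.uint8)   # hypothetical SBS frame: two 1920x2160 views
left, right = vr_processor.split_sbs_frame(sbs)
assert left.shape == (2160, 1920, 3)
assert right.shape == (2160, 1920, 3)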
def combine_sbs_frame(self, left_eye: np.ndarray, right_eye: np.ndarray) -> np.ndarray:
"""
Combine left and right eye frames back into side-by-side format
Combine left and right eye frames back into side-by-side format with GPU acceleration
Args:
left_eye: Left eye frame
@@ -84,14 +99,44 @@ class VR180Processor(VideoProcessor):
Returns:
Combined SBS frame
"""
try:
import cupy as cp
# Transfer to GPU for faster combination
left_gpu = cp.asarray(left_eye)
right_gpu = cp.asarray(right_eye)
# Ensure frames have same height
if left_gpu.shape[0] != right_gpu.shape[0]:
target_height = min(left_gpu.shape[0], right_gpu.shape[0])
# Note: OpenCV resize not available in CuPy, fall back to CPU for resize
left_eye = cv2.resize(left_eye, (left_eye.shape[1], target_height))
right_eye = cv2.resize(right_eye, (right_eye.shape[1], target_height))
left_gpu = cp.asarray(left_eye)
right_gpu = cp.asarray(right_eye)
# Combine horizontally on GPU (much faster for large arrays)
combined_gpu = cp.hstack([left_gpu, right_gpu])
# Transfer back to CPU and ensure we get a copy, not a view
combined = cp.asnumpy(combined_gpu).copy()
# Free GPU memory immediately
del left_gpu, right_gpu, combined_gpu
cp._default_memory_pool.free_all_blocks()
return combined
except ImportError:
# Fallback to CPU NumPy
# Ensure frames have same height
if left_eye.shape[0] != right_eye.shape[0]:
target_height = min(left_eye.shape[0], right_eye.shape[0])
left_eye = cv2.resize(left_eye, (left_eye.shape[1], target_height))
right_eye = cv2.resize(right_eye, (right_eye.shape[1], target_height))
# Combine horizontally
combined = np.hstack([left_eye, right_eye])
# Combine horizontally and ensure we get a copy, not a view
combined = np.hstack([left_eye, right_eye]).copy()
return combined
def process_with_disparity_mapping(self,
@@ -113,8 +158,23 @@ class VR180Processor(VideoProcessor):
left_eye_frames = []
right_eye_frames = []
for frame in frames:
for i, frame in enumerate(frames):
left, right = self.split_sbs_frame(frame)
# Debug: Check if frames are valid
if i == 0: # Only debug first frame
print(f"Original frame shape: {frame.shape}")
print(f"Left eye shape: {left.shape}")
print(f"Right eye shape: {right.shape}")
print(f"Left eye min/max: {left.min()}/{left.max()}")
print(f"Right eye min/max: {right.min()}/{right.max()}")
# Validate frames
if left.size == 0:
raise RuntimeError(f"Left eye frame {i} is empty")
if right.size == 0:
raise RuntimeError(f"Right eye frame {i} is empty")
left_eye_frames.append(left)
right_eye_frames.append(right)
@@ -123,6 +183,10 @@ class VR180Processor(VideoProcessor):
with self.memory_manager.memory_monitor(f"left eye chunk {chunk_idx}"):
left_matted = self._process_eye_sequence(left_eye_frames, "left", chunk_idx)
# Free left eye frames after processing (before right eye to save memory)
del left_eye_frames
self._aggressive_memory_cleanup(f"After left eye processing chunk {chunk_idx}")
# Process right eye with cross-validation
print("Processing right eye with cross-validation...")
with self.memory_manager.memory_monitor(f"right eye chunk {chunk_idx}"):
@@ -130,6 +194,10 @@ class VR180Processor(VideoProcessor):
right_eye_frames, left_matted, "right", chunk_idx
)
# Free right eye frames after processing
del right_eye_frames
self._aggressive_memory_cleanup(f"After right eye processing chunk {chunk_idx}")
# Combine results back to SBS format
combined_frames = []
for left_frame, right_frame in zip(left_matted, right_matted):
@@ -140,6 +208,15 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined)
# Free the individual eye results after combining
del left_matted
del right_matted
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
# This ensures models don't accumulate between chunks
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def _process_eye_sequence(self,
@@ -150,16 +227,148 @@ class VR180Processor(VideoProcessor):
if not eye_frames:
return []
# Initialize SAM2 with eye frames
self.sam2_model.init_video_state(eye_frames)
# Create a unique temporary video for this eye processing
import uuid
temp_video_name = f"temp_sam2_{eye_name}_chunk{chunk_idx}_{uuid.uuid4().hex[:8]}.mp4"
temp_video_path = Path.cwd() / temp_video_name
try:
# Use ffmpeg approach since OpenCV video writer is failing
height, width = eye_frames[0].shape[:2]
temp_video_path = temp_video_path.with_suffix('.mp4')
print(f"Creating temp video using ffmpeg: {temp_video_path}")
print(f"Video params: size=({width}, {height}), frames={len(eye_frames)}")
# Create a temporary directory for frame images
temp_frames_dir = temp_video_path.parent / f"frames_{temp_video_path.stem}"
temp_frames_dir.mkdir(exist_ok=True)
# Save frames as individual images (using JPEG for smaller file size)
print("Saving frames as images...")
for i, frame in enumerate(eye_frames):
# Check if frame is empty
if frame.size == 0:
raise RuntimeError(f"Frame {i} is empty (size=0)")
# Ensure frame is uint8
if frame.dtype != np.uint8:
frame = frame.astype(np.uint8)
# Debug first frame
if i == 0:
print(f"First frame to save: shape={frame.shape}, dtype={frame.dtype}, empty={frame.size == 0}")
# Use JPEG instead of PNG for smaller files (faster I/O, less disk space)
frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
# Use high quality JPEG to minimize compression artifacts
success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
if not success:
print(f"Frame {i} details: shape={frame.shape}, dtype={frame.dtype}, size={frame.size}")
raise RuntimeError(f"Failed to save frame {i} as image")
if i % 50 == 0:
print(f"Saved {i}/{len(eye_frames)} frames")
# Force garbage collection every 100 frames to free memory
if i % 100 == 0:
import gc
gc.collect()
# Use ffmpeg to create video from images
import subprocess
# Use the original video's framerate - access through parent class
original_fps = self.fps if hasattr(self, 'fps') else 30.0
print(f"Using framerate: {original_fps} fps")
# Memory monitoring before ffmpeg
self._print_memory_step(f"Before ffmpeg encoding ({eye_name} eye)")
# Try GPU encoding first, fallback to CPU
gpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(original_fps),
'-i', str(temp_frames_dir / 'frame_%06d.jpg'),
'-c:v', 'h264_nvenc', # NVIDIA GPU encoder
'-preset', 'fast', # GPU preset
'-cq', '18', # Quality for GPU encoding
'-pix_fmt', 'yuv420p',
str(temp_video_path)
]
cpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(original_fps),
'-i', str(temp_frames_dir / 'frame_%06d.jpg'),
'-c:v', 'libx264', # CPU encoder
'-pix_fmt', 'yuv420p',
'-crf', '18', # Quality for CPU encoding
'-preset', 'medium',
str(temp_video_path)
]
# Try GPU first
print(f"Trying GPU encoding: {' '.join(gpu_cmd)}")
result = subprocess.run(gpu_cmd, capture_output=True, text=True)
if result.returncode != 0:
print("GPU encoding failed, trying CPU...")
print(f"GPU error: {result.stderr}")
ffmpeg_cmd = cpu_cmd
print(f"Using CPU encoding: {' '.join(ffmpeg_cmd)}")
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
else:
print("GPU encoding successful!")
ffmpeg_cmd = gpu_cmd
print(f"Running ffmpeg: {' '.join(ffmpeg_cmd)}")
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"FFmpeg stdout: {result.stdout}")
print(f"FFmpeg stderr: {result.stderr}")
raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
# Clean up frame images
import shutil
shutil.rmtree(temp_frames_dir)
print(f"Created temp video successfully")
# Memory monitoring after ffmpeg
self._print_memory_step(f"After ffmpeg encoding ({eye_name} eye)")
# Verify the file was created and has content
if not temp_video_path.exists():
raise RuntimeError(f"Temporary video file was not created: {temp_video_path}")
file_size = temp_video_path.stat().st_size
if file_size == 0:
raise RuntimeError(f"Temporary video file is empty: {temp_video_path}")
print(f"Created temp video {temp_video_path} ({file_size / 1024 / 1024:.1f} MB)")
# Memory monitoring and cleanup before SAM2 initialization
num_frames = len(eye_frames) # Store count before freeing
first_frame = eye_frames[0].copy() # Copy first frame for detection before freeing
self._print_memory_step(f"Before SAM2 init ({eye_name} eye, {num_frames} frames)")
# CRITICAL: Explicitly free eye_frames from memory before SAM2 loads the same video
# This prevents the OOM issue where both Python frames and SAM2 frames exist simultaneously
del eye_frames # Free the frames array
self._aggressive_memory_cleanup(f"SAM2 init for {eye_name} eye")
# Initialize SAM2 with video path
self._print_memory_step(f"Starting SAM2 init ({eye_name} eye)")
self.sam2_model.init_video_state(video_path=str(temp_video_path))
self._print_memory_step(f"SAM2 initialized ({eye_name} eye)")
# Detect persons in first frame
first_frame = eye_frames[0]
detections = self.detector.detect_persons(first_frame)
if not detections:
warnings.warn(f"No persons detected in {eye_name} eye, chunk {chunk_idx}")
return self._create_empty_masks(eye_frames)
# Return empty masks for the number of frames
return self._create_empty_masks_from_count(num_frames, first_frame.shape)
print(f"Detected {len(detections)} persons in {eye_name} eye first frame")
@@ -169,15 +378,45 @@ class VR180Processor(VideoProcessor):
# Add prompts
object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
# Propagate masks
# Propagate masks (most expensive operation)
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
# Use Det-SAM2 continuous correction if enabled
if self.config.matting.continuous_correction:
video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
detector=self.detector,
temp_video_path=str(temp_video_path),
start_frame=0,
max_frames=num_frames,
correction_interval=self.config.matting.correction_interval,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
else:
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=len(eye_frames)
max_frames=num_frames,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
# Apply masks
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
# Apply masks with streaming approach (no frame accumulation)
self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
# Process frames one at a time without accumulation
cap = cv2.VideoCapture(str(temp_video_path))
matted_frames = []
for frame_idx, frame in enumerate(eye_frames):
try:
for frame_idx in range(num_frames):
ret, frame = cap.read()
if not ret:
break
# Apply mask to this single frame
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
combined_mask = self.sam2_model.get_combined_mask(frame_masks)
@@ -192,10 +431,34 @@ class VR180Processor(VideoProcessor):
matted_frames.append(matted_frame)
# Cleanup
# Free the original frame immediately (no accumulation)
del frame
# Periodic cleanup during processing
if frame_idx % 100 == 0 and frame_idx > 0:
import gc
gc.collect()
finally:
cap.release()
# Free video segments completely
del video_segments # This holds processed masks from SAM2
self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
return matted_frames
finally:
# Always cleanup
self.sam2_model.cleanup()
return matted_frames
# Remove temporary video file
try:
if temp_video_path.exists():
temp_video_path.unlink()
except Exception as e:
warnings.warn(f"Failed to cleanup temp video {temp_video_path}: {e}")
def _process_eye_sequence_with_validation(self,
right_eye_frames: List[np.ndarray],
@@ -223,13 +486,17 @@ class VR180Processor(VideoProcessor):
left_eye_results, right_matted
)
# CRITICAL: Free the intermediate results to prevent memory accumulation
del left_eye_results # Don't keep left eye results after validation
del right_matted # Don't keep unvalidated right results
return validated_results
def _validate_stereo_consistency(self,
left_results: List[np.ndarray],
right_results: List[np.ndarray]) -> List[np.ndarray]:
"""
Validate and correct stereo consistency between left and right eye results
Validate and correct stereo consistency between left and right eye results using GPU acceleration
Args:
left_results: Left eye processed frames
@@ -238,9 +505,120 @@ class VR180Processor(VideoProcessor):
Returns:
Validated right eye frames
"""
print(f"🔍 VALIDATION: Starting stereo consistency check ({len(left_results)} frames)")
try:
import cupy as cp
return self._validate_stereo_consistency_gpu(left_results, right_results)
except ImportError:
print(" Warning: CuPy not available, using CPU validation")
return self._validate_stereo_consistency_cpu(left_results, right_results)
def _validate_stereo_consistency_gpu(self,
left_results: List[np.ndarray],
right_results: List[np.ndarray]) -> List[np.ndarray]:
"""GPU-accelerated batch stereo validation using CuPy with memory-safe batching"""
import cupy as cp
print(" Using GPU acceleration for stereo validation")
# Process in batches to avoid GPU OOM
batch_size = 50 # Process 50 frames at a time (safe for 45GB GPU)
total_frames = len(left_results)
area_ratios_all = []
needs_correction_all = []
print(f" Processing {total_frames} frames in batches of {batch_size}...")
for batch_start in range(0, total_frames, batch_size):
batch_end = min(batch_start + batch_size, total_frames)
batch_frames = batch_end - batch_start
if batch_start % 100 == 0:
print(f" GPU batch {batch_start//batch_size + 1}: frames {batch_start}-{batch_end}")
# Get batch slices
left_batch = left_results[batch_start:batch_end]
right_batch = right_results[batch_start:batch_end]
# Convert batch to GPU
left_stack = cp.stack([cp.asarray(frame) for frame in left_batch])
right_stack = cp.stack([cp.asarray(frame) for frame in right_batch])
# Batch calculate mask areas for this batch
if left_stack.shape[3] == 4: # Alpha channel
left_masks = left_stack[:, :, :, 3] > 0
right_masks = right_stack[:, :, :, 3] > 0
else: # Green screen detection
bg_color = cp.array(self.config.output.background_color)
left_diff = cp.abs(left_stack.astype(cp.float32) - bg_color).sum(axis=3)
right_diff = cp.abs(right_stack.astype(cp.float32) - bg_color).sum(axis=3)
left_masks = left_diff > 30
right_masks = right_diff > 30
# Calculate areas for this batch
left_areas = cp.sum(left_masks, axis=(1, 2))
right_areas = cp.sum(right_masks, axis=(1, 2))
area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6)
# Find frames needing correction in this batch
needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0)
# Transfer batch results back to CPU and accumulate
area_ratios_all.extend(cp.asnumpy(area_ratios))
needs_correction_all.extend(cp.asnumpy(needs_correction))
# Free GPU memory for this batch
del left_stack, right_stack, left_masks, right_masks
del left_areas, right_areas, area_ratios, needs_correction
cp._default_memory_pool.free_all_blocks()
# CRITICAL: Release ALL CuPy memory back to system after validation
try:
# Force release of all GPU memory pools
cp._default_memory_pool.free_all_blocks()
cp._default_pinned_memory_pool.free_all_blocks()
# Clear CuPy cache completely
cp.get_default_memory_pool().free_all_blocks()
cp.get_default_pinned_memory_pool().free_all_blocks()
print(f" CuPy memory pools cleared")
except Exception as e:
print(f" Warning: Could not clear CuPy memory pools: {e}")
correction_count = sum(needs_correction_all)
print(f" GPU validation complete: {correction_count}/{total_frames} frames need correction")
# Apply corrections using CPU results
validated_frames = []
for i, (needs_fix, ratio) in enumerate(zip(needs_correction_all, area_ratios_all)):
if i % 100 == 0:
print(f" Processing validation results: {i}/{total_frames}")
if needs_fix:
# Apply correction
corrected_frame = self._apply_stereo_correction(
left_results[i], right_results[i], float(ratio)
)
validated_frames.append(corrected_frame)
else:
validated_frames.append(right_results[i])
print("✅ VALIDATION: GPU stereo consistency check complete")
return validated_frames
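The correction criterion above flags a frame whenever the right-eye mask area differs from the left-eye area by more than a factor of two. A quick numeric check (values illustrative):

left_area, right_area = 50_000, 18_000
ratio = right_area / (left_area + 1e-6)        # ~0.36
needs_fix = ratio < 0.5 or ratio > 2.0         # True -> frame is passed to _apply_stereo_correction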
def _validate_stereo_consistency_cpu(self,
left_results: List[np.ndarray],
right_results: List[np.ndarray]) -> List[np.ndarray]:
"""CPU fallback for stereo validation"""
print(" Using CPU validation (slower)")
validated_frames = []
for i, (left_frame, right_frame) in enumerate(zip(left_results, right_results)):
if i % 50 == 0: # Progress every 50 frames
print(f" CPU validation progress: {i}/{len(left_results)}")
# Simple validation: check if mask areas are similar
left_mask_area = self._get_mask_area(left_frame)
right_mask_area = self._get_mask_area(right_frame)
@@ -257,10 +635,44 @@ class VR180Processor(VideoProcessor):
else:
validated_frames.append(right_frame)
print("✅ VALIDATION: CPU stereo consistency check complete")
return validated_frames
def _create_empty_masks_from_count(self, num_frames: int, frame_shape: tuple) -> List[np.ndarray]:
"""Create empty masks when no persons detected (without frame array)"""
empty_frames = []
for _ in range(num_frames):
if self.config.output.format == "alpha":
# Transparent output
output = np.zeros((frame_shape[0], frame_shape[1], 4), dtype=np.uint8)
else:
# Green screen background
output = np.full((frame_shape[0], frame_shape[1], 3),
self.config.output.background_color, dtype=np.uint8)
empty_frames.append(output)
return empty_frames
def _get_mask_area(self, frame: np.ndarray) -> float:
"""Get mask area from processed frame"""
"""Get mask area from processed frame using GPU acceleration"""
try:
import cupy as cp
# Transfer to GPU
frame_gpu = cp.asarray(frame)
if frame.shape[2] == 4: # Alpha channel
mask_gpu = frame_gpu[:, :, 3] > 0
else: # Green screen - detect non-background pixels
bg_color_gpu = cp.array(self.config.output.background_color)
diff_gpu = cp.abs(frame_gpu.astype(cp.float32) - bg_color_gpu).sum(axis=2)
mask_gpu = diff_gpu > 30 # Threshold for non-background
# Calculate area on GPU and return as Python int
area = int(cp.sum(mask_gpu))
return area
except ImportError:
# Fallback to CPU NumPy if CuPy not available
if frame.shape[2] == 4: # Alpha channel
mask = frame[:, :, 3] > 0
else: # Green screen - detect non-background pixels
@@ -284,6 +696,64 @@ class VR180Processor(VideoProcessor):
# TODO: Implement proper stereo correction algorithm
return right_frame
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
"""
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
persist between chunks, causing OOM when processing subsequent chunks.
"""
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
# 1. Completely destroy SAM2 model (15-20GB)
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
self.sam2_model.cleanup() # Call existing cleanup
# Force complete destruction of the model
try:
# Reset the model's loaded state so it will reload fresh
if hasattr(self.sam2_model, '_model_loaded'):
self.sam2_model._model_loaded = False
# Clear any cached state
if hasattr(self.sam2_model, 'predictor'):
self.sam2_model.predictor = None
if hasattr(self.sam2_model, 'inference_state'):
self.sam2_model.inference_state = None
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ SAM2 destruction warning: {e}")
# 2. Completely destroy YOLO detector (400MB+)
if hasattr(self, 'detector') and self.detector is not None:
try:
# Force YOLO model to be reloaded fresh
if hasattr(self.detector, 'model') and self.detector.model is not None:
del self.detector.model
self.detector.model = None
print(f" ✅ YOLO model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ YOLO destruction warning: {e}")
# 3. Clear CUDA cache aggressively
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
print(f" ✅ CUDA cache cleared")
# 4. Force garbage collection
import gc
collected = gc.collect()
print(f" ✅ Garbage collection: {collected} objects freed")
# 5. Memory verification
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
def process_chunk(self,
frames: List[np.ndarray],
chunk_idx: int = 0) -> List[np.ndarray]:
@@ -343,6 +813,9 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined)
# CRITICAL: Complete inter-chunk cleanup for independent processing too
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def save_video(self, frames: List[np.ndarray], output_path: str):