16 Commits
cuda...det

SHA1 Message Date
277d554ecc fix scaling 1 2025-07-26 18:31:16 -07:00
d6d2b0aa93 full size babyyy 2025-07-26 18:09:48 -07:00
3a547b7c21 please god work 2025-07-26 17:44:23 -07:00
262cb00b69 checkpoints yay 2025-07-26 17:11:07 -07:00
caa4ddb5e0 actually fix streaming save 2025-07-26 17:05:50 -07:00
fa945b9c3e fix concat 2025-07-26 16:29:59 -07:00
4958c503dd please merge 2025-07-26 16:02:07 -07:00
366b132ef5 growth 2025-07-26 15:31:07 -07:00
4d1361df46 bigtime 2025-07-26 15:29:37 -07:00
884cb8dce2 lol 2025-07-26 15:29:28 -07:00
36f58acb8b foo 2025-07-26 15:18:32 -07:00
fb51e82fd4 stuff 2025-07-26 15:18:01 -07:00
9f572d4430 analyze 2025-07-26 15:10:34 -07:00
ba8706b7ae quick check 2025-07-26 14:52:44 -07:00
734445cf48 more memory fixes hopeufly 2025-07-26 14:33:36 -07:00
80f947c91b det core 2025-07-26 13:51:21 -07:00
11 changed files with 1873 additions and 105 deletions

analyze_memory_profile.py (new file, +193 lines)

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""
import json
import glob
import os
import sys
from pathlib import Path
def analyze_memory_files():
"""Analyze partial memory profile files"""
# Get all partial files in order
files = sorted(glob.glob('memory_profile_partial_*.json'))
if not files:
print("❌ No memory profile files found!")
print("Expected files like: memory_profile_partial_0.json")
return
print(f"🔍 Found {len(files)} memory profile files")
print("=" * 60)
peak_memory = 0
peak_vram = 0
critical_points = []
all_checkpoints = []
for i, file in enumerate(files):
try:
with open(file, 'r') as f:
data = json.load(f)
timeline = data.get('timeline', [])
if not timeline:
continue
# Find peaks in this file
file_peak_rss = max([d['rss_gb'] for d in timeline])
file_peak_vram = max([d['vram_gb'] for d in timeline])
if file_peak_rss > peak_memory:
peak_memory = file_peak_rss
if file_peak_vram > peak_vram:
peak_vram = file_peak_vram
# Find memory growth spikes (>3GB increase)
for j in range(1, len(timeline)):
prev_rss = timeline[j-1]['rss_gb']
curr_rss = timeline[j]['rss_gb']
growth = curr_rss - prev_rss
if growth > 3.0: # >3GB growth spike
checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
critical_points.append({
'file': file,
'file_index': i,
'sample': j,
'timestamp': timeline[j]['timestamp'],
'rss_gb': curr_rss,
'vram_gb': timeline[j]['vram_gb'],
'growth_gb': growth,
'checkpoint': checkpoint
})
# Collect all checkpoints
checkpoints = [d for d in timeline if 'checkpoint' in d]
for cp in checkpoints:
cp['file'] = file
cp['file_index'] = i
all_checkpoints.append(cp)
# Show progress for this file
if timeline:
start_rss = timeline[0]['rss_gb']
end_rss = timeline[-1]['rss_gb']
growth = end_rss - start_rss
samples = len(timeline)
print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
f"(+{growth:4.1f}GB) [{samples:3d} samples]")
# Show significant checkpoints from this file
if checkpoints:
for cp in checkpoints:
print(f" 📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")
except Exception as e:
print(f"❌ Error reading {file}: {e}")
print("\n" + "=" * 60)
print("🎯 ANALYSIS SUMMARY")
print("=" * 60)
print(f"📈 Peak Memory: {peak_memory:.1f} GB")
print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")
if critical_points:
print(f"\n💥 MEMORY GROWTH SPIKES (>3GB):")
print(" Location Growth Total VRAM")
print(" " + "-" * 55)
for point in critical_points:
location = point['checkpoint'][:30].ljust(30)
print(f" {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB {point['vram_gb']:4.1f}GB")
if all_checkpoints:
print(f"\n📍 CHECKPOINT PROGRESSION:")
print(" Checkpoint Memory VRAM File")
print(" " + "-" * 55)
for cp in all_checkpoints:
checkpoint = cp['checkpoint'][:30].ljust(30)
file_num = cp['file_index'] + 1
print(f" {checkpoint} {cp['rss_gb']:5.1f}GB {cp['vram_gb']:4.1f}GB #{file_num}")
# Memory growth analysis
if len(all_checkpoints) > 1:
print(f"\n📊 MEMORY GROWTH ANALYSIS:")
# Find the biggest memory jumps between checkpoints
big_jumps = []
for i in range(1, len(all_checkpoints)):
prev_cp = all_checkpoints[i-1]
curr_cp = all_checkpoints[i]
growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
if growth > 2.0: # >2GB jump
big_jumps.append({
'from': prev_cp['checkpoint'],
'to': curr_cp['checkpoint'],
'growth': growth,
'from_memory': prev_cp['rss_gb'],
'to_memory': curr_cp['rss_gb']
})
if big_jumps:
print(" Major jumps (>2GB):")
for jump in big_jumps:
print(f" {jump['from']}{jump['to']}: "
f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f}{jump['to_memory']:.1f}GB)")
else:
print(" ✅ No major memory jumps detected")
# Diagnosis
print(f"\n🔬 DIAGNOSIS:")
if peak_memory > 400:
print(" 🔴 CRITICAL: Memory usage exceeded 400GB")
print(" 💡 Recommendation: Reduce chunk_size to 200-300 frames")
elif peak_memory > 200:
print(" 🟡 HIGH: Memory usage over 200GB")
print(" 💡 Recommendation: Reduce chunk_size to 400 frames")
else:
print(" 🟢 MODERATE: Memory usage under 200GB")
if critical_points:
# Find most common growth spike locations
spike_locations = {}
for point in critical_points:
location = point['checkpoint']
spike_locations[location] = spike_locations.get(location, 0) + 1
print("\n 🎯 Most problematic locations:")
for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
print(f" {location}: {count} spikes")
print(f"\n💡 NEXT STEPS:")
if 'merge' in str(critical_points).lower():
print(" 1. Chunk merging still causing memory accumulation")
print(" 2. Check if streaming merge is actually being used")
print(" 3. Verify chunk files are being deleted immediately")
elif 'propagation' in str(critical_points).lower():
print(" 1. SAM2 propagation using too much memory")
print(" 2. Reduce chunk_size further (try 300 frames)")
print(" 3. Enable more aggressive frame release")
else:
print(" 1. Review the checkpoint progression above")
print(" 2. Focus on locations with biggest memory spikes")
print(" 3. Consider reducing chunk_size if spikes are large")
def main():
print("🔍 MEMORY PROFILE ANALYZER")
print("Analyzing memory profile files for OOM causes...")
print()
analyze_memory_files()
if __name__ == "__main__":
main()
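
For reference, a minimal sketch of the partial-profile files this analyzer consumes. The key names match the reads above; the values and the writer itself are illustrative only:

#!/usr/bin/env python3
"""Write a tiny memory_profile_partial_0.json the analyzer can parse (illustrative values)"""
import json

example = {
    "timeline": [
        {"timestamp": 1753571476.0, "rss_gb": 12.4, "vram_gb": 8.1},
        # A sample may optionally carry a 'checkpoint' label:
        {"timestamp": 1753571478.0, "rss_gb": 18.9, "vram_gb": 8.3,
         "checkpoint": "PROCESSOR_INITIALIZED"},
    ],
    "status": "partial_save",
    "samples": 2,
}

with open("memory_profile_partial_0.json", "w") as f:
    json.dump(example, f, indent=2)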

debug_memory_leak.py (new file, +151 lines)

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""
import psutil
import gc
from pathlib import Path
import sys
def detailed_memory_check(label):
"""Get detailed memory info"""
process = psutil.Process()
memory_info = process.memory_info()
rss_gb = memory_info.rss / (1024**3)
vms_gb = memory_info.vms / (1024**3)
# System memory
sys_memory = psutil.virtual_memory()
available_gb = sys_memory.available / (1024**3)
print(f"🔍 {label}:")
print(f" RSS: {rss_gb:.2f} GB (physical memory)")
print(f" VMS: {vms_gb:.2f} GB (virtual memory)")
print(f" Available: {available_gb:.2f} GB")
return rss_gb
def simulate_chunk_processing(config_path):
"""Simulate the chunk processing to see where memory accumulates"""
print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
print("=" * 60)
base_memory = detailed_memory_check("0. Baseline")
# Step 1: Import everything (with lazy loading)
print("\n📦 Step 1: Imports")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
import_memory = detailed_memory_check("1. After imports")
import_growth = import_memory - base_memory
print(f" Growth: +{import_growth:.2f} GB")
# Step 2: Load config
print("\n⚙️ Step 2: Config loading")
config = VR180Config.from_yaml(config_path)
config_memory = detailed_memory_check("2. After config load")
config_growth = config_memory - import_memory
print(f" Growth: +{config_growth:.2f} GB")
# Step 3: Initialize processor (models still lazy)
print("\n🏗️ Step 3: Processor initialization")
processor = VR180Processor(config)
processor_memory = detailed_memory_check("3. After processor init")
processor_growth = processor_memory - config_memory
print(f" Growth: +{processor_growth:.2f} GB")
# Step 4: Load video info (lightweight)
print("\n🎬 Step 4: Video info loading")
try:
video_info = processor.load_video_info(config.input.video_path)
print(f" Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
f"{video_info.get('total_frames', 'unknown')} frames")
except Exception as e:
print(f" Warning: Could not load video info: {e}")
video_info_memory = detailed_memory_check("4. After video info")
video_info_growth = video_info_memory - processor_memory
print(f" Growth: +{video_info_growth:.2f} GB")
# Step 5: Simulate chunk 0 processing (this is where models actually load)
print("\n🔄 Step 5: Simulating chunk 0 processing...")
# This is where the real memory usage starts
print(" Loading first 10 frames to trigger model loading...")
try:
# Read a small number of frames to trigger model loading
frames = processor.read_video_frames(
config.input.video_path,
start_frame=0,
num_frames=10, # Just 10 frames to trigger model loading
scale_factor=config.processing.scale_factor
)
frames_memory = detailed_memory_check("5a. After reading 10 frames")
frames_growth = frames_memory - video_info_memory
print(f" 10 frames growth: +{frames_growth:.2f} GB")
# Free frames
del frames
gc.collect()
after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
free_improvement = frames_memory - after_free_memory
print(f" Memory freed: -{free_improvement:.2f} GB")
except Exception as e:
print(f" Could not simulate frame loading: {e}")
after_free_memory = video_info_memory
print(f"\n📊 MEMORY ANALYSIS:")
print(f" Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
print(f" Total growth: +{after_free_memory - base_memory:.2f}GB")
if after_free_memory - base_memory > 10:
print(f" 🔴 HIGH: Memory growth > 10GB before any real processing")
print(f" 💡 This suggests model loading is using too much memory")
elif after_free_memory - base_memory > 5:
print(f" 🟡 MODERATE: Memory growth 5-10GB")
print(f" 💡 Normal for model loading, but monitor chunk processing")
else:
print(f" 🟢 GOOD: Memory growth < 5GB")
print(f" 💡 Initialization memory usage is reasonable")
print(f"\n🎯 KEY INSIGHTS:")
if import_growth > 1:
print(f" - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
if processor_growth > 10:
print(f" - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")
print(f"\n💡 RECOMMENDATIONS:")
if after_free_memory - base_memory > 15:
print(f" 1. Reduce chunk_size to 200-300 frames")
print(f" 2. Use smaller models (yolov8n instead of yolov8m)")
print(f" 3. Enable FP16 mode for SAM2")
elif after_free_memory - base_memory > 8:
print(f" 1. Monitor chunk processing carefully")
print(f" 2. Use streaming merge (should be automatic)")
print(f" 3. Current settings may be acceptable")
else:
print(f" 1. Settings look good for initialization")
print(f" 2. Focus on chunk processing memory leaks")
def main():
if len(sys.argv) != 2:
print("Usage: python debug_memory_leak.py <config.yaml>")
print("This simulates initialization to find memory leaks")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
simulate_chunk_processing(config_path)
if __name__ == "__main__":
main()

memory_profiler_script.py (new file, +249 lines)

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""
import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json
class MemoryProfiler:
def __init__(self, output_file: str = "memory_profile.json"):
self.output_file = output_file
self.data = []
self.process = psutil.Process()
self.running = False
self.thread = None
self.checkpoint_counter = 0
def start_monitoring(self, interval: float = 1.0):
"""Start continuous memory monitoring"""
tracemalloc.start()
self.running = True
self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
self.thread.daemon = True
self.thread.start()
print(f"🔍 Memory monitoring started (interval: {interval}s)")
def stop_monitoring(self):
"""Stop monitoring and save results"""
self.running = False
if self.thread:
self.thread.join()
# Get tracemalloc snapshot
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
# Save detailed results
results = {
'timeline': self.data,
'top_memory_allocations': [
{
'file': stat.traceback.format()[0],
'size_mb': stat.size / 1024 / 1024,
'count': stat.count
}
for stat in top_stats[:20] # Top 20 allocations
],
'summary': {
'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
'total_samples': len(self.data)
}
}
with open(self.output_file, 'w') as f:
json.dump(results, f, indent=2)
tracemalloc.stop()
print(f"📊 Memory profile saved to {self.output_file}")
def _monitor_loop(self, interval: float):
"""Continuous monitoring loop"""
while self.running:
try:
# System memory
memory_info = self.process.memory_info()
rss_gb = memory_info.rss / (1024**3)
# System-wide memory
sys_memory = psutil.virtual_memory()
sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
sys_available_gb = sys_memory.available / (1024**3)
# GPU memory (if available)
vram_gb = 0
vram_free_gb = 0
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=5)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
if lines and lines[0]:
used, free = lines[0].split(', ')
vram_gb = float(used) / 1024
vram_free_gb = float(free) / 1024
except Exception:
pass
# Tracemalloc current usage
try:
current, peak = tracemalloc.get_traced_memory()
traced_mb = current / (1024**2)
except Exception:
traced_mb = 0
data_point = {
'timestamp': time.time(),
'rss_gb': rss_gb,
'vram_gb': vram_gb,
'vram_free_gb': vram_free_gb,
'sys_used_gb': sys_used_gb,
'sys_available_gb': sys_available_gb,
'traced_mb': traced_mb
}
self.data.append(data_point)
# Print periodic updates and save partial data
if len(self.data) % 10 == 0: # Every 10 samples
print(f"🔍 Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")
# Save partial data every 30 samples in case of crash
if len(self.data) % 30 == 0:
self._save_partial_data()
except Exception as e:
print(f"Monitoring error: {e}")
time.sleep(interval)
def _save_partial_data(self):
"""Save partial data to prevent loss on crash"""
try:
partial_file = f"memory_profile_partial_{self.checkpoint_counter}.json"
with open(partial_file, 'w') as f:
json.dump({
'timeline': self.data,
'status': 'partial_save',
'samples': len(self.data)
}, f, indent=2)
self.checkpoint_counter += 1
except Exception as e:
print(f"Failed to save partial data: {e}")
def log_checkpoint(self, checkpoint_name: str):
"""Log a specific checkpoint"""
if self.data:
self.data[-1]['checkpoint'] = checkpoint_name
latest = self.data[-1]
print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")
# Save checkpoint data immediately
self._save_partial_data()
def run_with_profiling(config_path: str):
"""Run the VR180 matting with memory profiling"""
profiler = MemoryProfiler("memory_profile_detailed.json")
try:
# Start monitoring
profiler.start_monitoring(interval=2.0) # Sample every 2 seconds
# Log initial state
profiler.log_checkpoint("STARTUP")
# Import after starting profiler to catch import memory usage
print("Importing VR180 processor...")
from vr180_matting.vr180_processor import VR180Processor
from vr180_matting.config import VR180Config
profiler.log_checkpoint("IMPORTS_COMPLETE")
# Load config
print(f"Loading config from {config_path}")
config = VR180Config.from_yaml(config_path)
profiler.log_checkpoint("CONFIG_LOADED")
# Initialize processor
print("Initializing VR180 processor...")
processor = VR180Processor(config)
profiler.log_checkpoint("PROCESSOR_INITIALIZED")
# Force garbage collection
gc.collect()
profiler.log_checkpoint("INITIAL_GC_COMPLETE")
# Run processing
print("Starting VR180 processing...")
processor.process_video()
profiler.log_checkpoint("PROCESSING_COMPLETE")
except Exception as e:
print(f"❌ Error during processing: {e}")
profiler.log_checkpoint(f"ERROR: {str(e)}")
raise
finally:
# Stop monitoring and save results
profiler.stop_monitoring()
# Print summary
print("\n" + "="*60)
print("MEMORY PROFILING SUMMARY")
print("="*60)
if profiler.data:
peak_rss = max([d['rss_gb'] for d in profiler.data])
peak_vram = max([d['vram_gb'] for d in profiler.data])
print(f"Peak RSS Memory: {peak_rss:.2f} GB")
print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
print(f"Total Samples: {len(profiler.data)}")
# Show checkpoints
checkpoints = [d for d in profiler.data if 'checkpoint' in d]
if checkpoints:
print(f"\nCheckpoints ({len(checkpoints)}):")
for cp in checkpoints:
print(f" {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")
print(f"\nDetailed profile saved to: {profiler.output_file}")
def main():
if len(sys.argv) != 2:
print("Usage: python memory_profiler_script.py <config.yaml>")
print("\nThis script runs VR180 matting with detailed memory profiling")
print("It will:")
print("- Monitor RSS, VRAM, and system memory every 2 seconds")
print("- Track memory allocations with tracemalloc")
print("- Log checkpoints at key processing stages")
print("- Save detailed JSON report for analysis")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"❌ Config file not found: {config_path}")
sys.exit(1)
print("🚀 Starting VR180 Memory Profiling")
print(f"Config: {config_path}")
print("="*60)
run_with_profiling(config_path)
if __name__ == "__main__":
main()
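
The MemoryProfiler class above is reusable outside this entry point. A minimal sketch, where my_pipeline is a placeholder for any long-running function:

from memory_profiler_script import MemoryProfiler

def my_pipeline():
    pass  # placeholder for real work

profiler = MemoryProfiler("my_profile.json")
profiler.start_monitoring(interval=1.0)  # sample once per second
try:
    profiler.log_checkpoint("START")
    my_pipeline()
    profiler.log_checkpoint("DONE")
finally:
    profiler.stop_monitoring()  # writes my_profile.json and stops tracemalloc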

quick_memory_check.py (new file, +125 lines)

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""
import psutil
import subprocess
import sys
from pathlib import Path
def check_system():
"""Check system resources before starting"""
print("🔍 SYSTEM RESOURCE CHECK")
print("=" * 50)
# Memory info
memory = psutil.virtual_memory()
print(f"📊 RAM:")
print(f" Total: {memory.total / (1024**3):.1f} GB")
print(f" Available: {memory.available / (1024**3):.1f} GB")
print(f" Used: {(memory.total - memory.available) / (1024**3):.1f} GB ({memory.percent:.1f}%)")
# GPU info
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=10)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
print(f"\n🎮 GPU:")
for i, line in enumerate(lines):
if line.strip():
parts = line.split(', ')
if len(parts) >= 4:
name, total, used, free = parts[:4]
total_gb = float(total) / 1024
used_gb = float(used) / 1024
free_gb = float(free) / 1024
print(f" GPU {i}: {name}")
print(f" VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
print(f" Free: {free_gb:.1f} GB")
except Exception as e:
print(f"\n⚠️ Could not get GPU info: {e}")
# Disk space
disk = psutil.disk_usage('/')
print(f"\n💾 Disk (/):")
print(f" Total: {disk.total / (1024**3):.1f} GB")
print(f" Used: {disk.used / (1024**3):.1f} GB ({disk.used/disk.total*100:.1f}%)")
print(f" Free: {disk.free / (1024**3):.1f} GB")
# Check config file
if len(sys.argv) > 1:
config_path = sys.argv[1]
if Path(config_path).exists():
print(f"\n✅ Config file found: {config_path}")
# Try to load and show key settings
try:
import yaml
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
print(f"📋 Key Settings:")
if 'processing' in config:
proc = config['processing']
print(f" Chunk size: {proc.get('chunk_size', 'default')}")
print(f" Scale factor: {proc.get('scale_factor', 'default')}")
if 'hardware' in config:
hw = config['hardware']
print(f" Max VRAM: {hw.get('max_vram_gb', 'default')} GB")
if 'input' in config:
inp = config['input']
video_path = inp.get('video_path', '')
if video_path and Path(video_path).exists():
size_gb = Path(video_path).stat().st_size / (1024**3)
print(f" Input video: {video_path} ({size_gb:.1f} GB)")
else:
print(f" ⚠️ Input video not found: {video_path}")
except Exception as e:
print(f" ⚠️ Could not parse config: {e}")
else:
print(f"\n❌ Config file not found: {config_path}")
return False
# Memory safety warnings
print(f"\n⚠️ MEMORY SAFETY CHECKS:")
available_gb = memory.available / (1024**3)
if available_gb < 10:
print(f" 🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
print(" Consider: reducing chunk_size or scale_factor")
return False
elif available_gb < 20:
print(f" 🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
print(" Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
else:
print(f" 🟢 GOOD MEMORY: {available_gb:.1f}GB available")
print(f"\n" + "=" * 50)
return True
def main():
if len(sys.argv) != 2:
print("Usage: python quick_memory_check.py <config.yaml>")
print("\nThis checks system resources before running VR180 matting")
sys.exit(1)
safe_to_run = check_system()
if safe_to_run:
print("✅ System check passed - safe to run VR180 matting")
print("\nTo run with memory profiling:")
print(f" python memory_profiler_script.py {sys.argv[1]}")
print("\nTo run normally:")
print(f" vr180-matting {sys.argv[1]}")
else:
print("❌ System check failed - address issues before running")
sys.exit(1)
if __name__ == "__main__":
main()

test_inter_chunk_cleanup.py (new file, +148 lines)

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""
import psutil
import gc
import sys
from pathlib import Path
def get_memory_usage():
"""Get current memory usage in GB"""
process = psutil.Process()
return process.memory_info().rss / (1024**3)
def test_inter_chunk_cleanup(config_path):
"""Test that models are properly destroyed between chunks"""
print("🧪 TESTING INTER-CHUNK CLEANUP")
print("=" * 50)
baseline_memory = get_memory_usage()
print(f"📊 Baseline memory: {baseline_memory:.2f} GB")
# Import and create processor
print("\n1⃣ Creating processor...")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
config = VR180Config.from_yaml(config_path)
processor = VR180Processor(config)
init_memory = get_memory_usage()
print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")
# Simulate chunk processing (just trigger model loading)
print("\n2⃣ Simulating chunk 0 processing...")
# Test 1: Force YOLO model loading
try:
detector = processor.detector
detector._load_model() # Force load
yolo_memory = get_memory_usage()
print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
except Exception as e:
print(f"❌ YOLO loading failed: {e}")
yolo_memory = init_memory
# Test 2: Force SAM2 model loading
try:
sam2_model = processor.sam2_model
sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
sam2_memory = get_memory_usage()
print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
except Exception as e:
print(f"❌ SAM2 loading failed: {e}")
sam2_memory = yolo_memory
total_model_memory = sam2_memory - init_memory
print(f"📊 Total model memory: {total_model_memory:.2f} GB")
# Test 3: Inter-chunk cleanup
print("\n3⃣ Testing inter-chunk cleanup...")
processor._complete_inter_chunk_cleanup(chunk_idx=0)
cleanup_memory = get_memory_usage()
cleanup_improvement = sam2_memory - cleanup_memory
print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")
# Test 4: Verify models reload fresh
print("\n4⃣ Testing fresh model reload...")
# Check YOLO state
yolo_reloaded = processor.detector.model is None
print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")
# Check SAM2 state
sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")
# Test 5: Force reload to verify they work
print("\n5⃣ Testing model reload...")
try:
# Force YOLO reload
processor.detector._load_model()
yolo_reload_memory = get_memory_usage()
# Force SAM2 reload
processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
sam2_reload_memory = get_memory_usage()
reload_growth = sam2_reload_memory - cleanup_memory
print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")
if abs(reload_growth - total_model_memory) < 1.0: # Within 1GB
print("✅ Models reloaded with similar memory usage (good)")
else:
print("⚠️ Model reload memory differs significantly")
except Exception as e:
print(f"❌ Model reload failed: {e}")
# Final summary
print(f"\n📊 SUMMARY:")
print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
print(f" Memory freed: {cleanup_improvement:.2f}GB")
print(f" Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")
# Success criteria: Both models destroyed AND can reload
models_destroyed = yolo_reloaded and sam2_reloaded
can_reload = 'reload_growth' in locals()
if models_destroyed and can_reload:
print("✅ Inter-chunk cleanup working effectively")
print("💡 Models destroyed and can reload fresh (memory will be freed during real processing)")
return True
elif models_destroyed:
print("⚠️ Models destroyed but reload test incomplete")
print("💡 This should still prevent accumulation during real processing")
return True
else:
print("❌ Inter-chunk cleanup not freeing enough memory")
return False
def main():
if len(sys.argv) != 2:
print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
success = test_inter_chunk_cleanup(config_path)
if success:
print(f"\n🎉 SUCCESS: Inter-chunk cleanup is working!")
print(f"💡 This should prevent 15-20GB model accumulation between chunks")
else:
print(f"\n❌ FAILURE: Inter-chunk cleanup needs improvement")
print(f"💡 Check model destruction logic in _complete_inter_chunk_cleanup")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

checkpoint_manager.py (new file, +220 lines)

@@ -0,0 +1,220 @@
"""
Checkpoint manager for resumable video processing
Saves progress to avoid reprocessing after OOM or crashes
"""
import json
import hashlib
from pathlib import Path
from typing import Dict, Any, Optional, List
import os
import shutil
from datetime import datetime
class CheckpointManager:
"""Manages processing checkpoints for resumable execution"""
def __init__(self, video_path: str, output_path: str, checkpoint_dir: Optional[Path] = None):
"""
Initialize checkpoint manager
Args:
video_path: Input video path
output_path: Output video path
checkpoint_dir: Directory for checkpoint files (default: .vr180_checkpoints in CWD)
"""
self.video_path = Path(video_path)
self.output_path = Path(output_path)
# Create unique checkpoint ID based on video file
self.video_hash = self._compute_video_hash()
# Setup checkpoint directory
if checkpoint_dir is None:
self.checkpoint_dir = Path.cwd() / ".vr180_checkpoints" / self.video_hash
else:
self.checkpoint_dir = Path(checkpoint_dir) / self.video_hash
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Checkpoint files
self.status_file = self.checkpoint_dir / "processing_status.json"
self.chunks_dir = self.checkpoint_dir / "chunks"
self.chunks_dir.mkdir(exist_ok=True)
# Load existing status or create new
self.status = self._load_status()
def _compute_video_hash(self) -> str:
"""Compute hash of video file for unique identification"""
# Use file path, size, and modification time for quick hash
stat = self.video_path.stat()
hash_str = f"{self.video_path}_{stat.st_size}_{stat.st_mtime}"
return hashlib.md5(hash_str.encode()).hexdigest()[:12]
def _load_status(self) -> Dict[str, Any]:
"""Load processing status from checkpoint file"""
if self.status_file.exists():
with open(self.status_file, 'r') as f:
status = json.load(f)
print(f"📋 Loaded checkpoint: {status['completed_chunks']}/{status['total_chunks']} chunks completed")
return status
else:
# Create new status
return {
'video_path': str(self.video_path),
'output_path': str(self.output_path),
'video_hash': self.video_hash,
'start_time': datetime.now().isoformat(),
'total_chunks': 0,
'completed_chunks': 0,
'chunk_info': {},
'processing_complete': False,
'merge_complete': False
}
def _save_status(self):
"""Save current status to checkpoint file"""
self.status['last_update'] = datetime.now().isoformat()
with open(self.status_file, 'w') as f:
json.dump(self.status, f, indent=2)
def set_total_chunks(self, total_chunks: int):
"""Set total number of chunks to process"""
self.status['total_chunks'] = total_chunks
self._save_status()
def is_chunk_completed(self, chunk_idx: int) -> bool:
"""Check if a chunk has already been processed"""
chunk_key = f"chunk_{chunk_idx}"
return chunk_key in self.status['chunk_info'] and \
self.status['chunk_info'][chunk_key].get('completed', False)
def get_chunk_file(self, chunk_idx: int) -> Optional[Path]:
"""Get saved chunk file path if it exists"""
chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
if chunk_file.exists() and self.is_chunk_completed(chunk_idx):
return chunk_file
return None
def save_chunk(self, chunk_idx: int, frames: List, source_chunk_path: Optional[Path] = None):
"""
Save processed chunk and mark as completed
Args:
chunk_idx: Chunk index
frames: Processed frames (can be None if using source_chunk_path)
source_chunk_path: If provided, copy this file instead of saving frames
"""
chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
try:
if source_chunk_path and source_chunk_path.exists():
# Copy existing chunk file
shutil.copy2(source_chunk_path, chunk_file)
print(f"💾 Copied chunk {chunk_idx} to checkpoint: {chunk_file.name}")
elif frames is not None:
# Save new frames
import numpy as np
np.savez_compressed(str(chunk_file), frames=frames)
print(f"💾 Saved chunk {chunk_idx} to checkpoint: {chunk_file.name}")
else:
raise ValueError("Either frames or source_chunk_path must be provided")
# Update status
chunk_key = f"chunk_{chunk_idx}"
self.status['chunk_info'][chunk_key] = {
'completed': True,
'file': chunk_file.name,
'timestamp': datetime.now().isoformat()
}
self.status['completed_chunks'] = len([c for c in self.status['chunk_info'].values() if c['completed']])
self._save_status()
print(f"✅ Chunk {chunk_idx} checkpoint saved ({self.status['completed_chunks']}/{self.status['total_chunks']})")
except Exception as e:
print(f"❌ Failed to save chunk {chunk_idx} checkpoint: {e}")
def get_completed_chunk_files(self) -> List[Path]:
"""Get list of all completed chunk files in order"""
chunk_files = []
missing_chunks = []
for i in range(self.status['total_chunks']):
chunk_file = self.get_chunk_file(i)
if chunk_file:
chunk_files.append(chunk_file)
else:
# Check if chunk is marked as completed but file is missing
if self.is_chunk_completed(i):
missing_chunks.append(i)
print(f"⚠️ Chunk {i} marked complete but file missing!")
else:
break # Stop at first unprocessed chunk
if missing_chunks:
print(f"❌ Missing checkpoint files for chunks: {missing_chunks}")
print(f" This may happen if files were deleted during streaming merge")
print(f" These chunks may need to be reprocessed")
return chunk_files
def mark_processing_complete(self):
"""Mark all chunk processing as complete"""
self.status['processing_complete'] = True
self._save_status()
print(f"✅ All chunks processed and checkpointed")
def mark_merge_complete(self):
"""Mark final merge as complete"""
self.status['merge_complete'] = True
self._save_status()
print(f"✅ Video merge completed")
def cleanup_checkpoints(self, keep_chunks: bool = False):
"""
Clean up checkpoint files after successful completion
Args:
keep_chunks: If True, keep chunk files but remove status
"""
if keep_chunks:
# Just remove status file
if self.status_file.exists():
self.status_file.unlink()
print(f"🗑️ Removed checkpoint status file")
else:
# Remove entire checkpoint directory
if self.checkpoint_dir.exists():
shutil.rmtree(self.checkpoint_dir)
print(f"🗑️ Removed all checkpoint files: {self.checkpoint_dir}")
def get_resume_info(self) -> Dict[str, Any]:
"""Get information about what can be resumed"""
return {
'can_resume': self.status['completed_chunks'] > 0,
'completed_chunks': self.status['completed_chunks'],
'total_chunks': self.status['total_chunks'],
'processing_complete': self.status['processing_complete'],
'merge_complete': self.status['merge_complete'],
'checkpoint_dir': str(self.checkpoint_dir)
}
def print_status(self):
"""Print current checkpoint status"""
print(f"\n📊 CHECKPOINT STATUS:")
print(f" Video: {self.video_path.name}")
print(f" Hash: {self.video_hash}")
print(f" Progress: {self.status['completed_chunks']}/{self.status['total_chunks']} chunks")
print(f" Processing complete: {self.status['processing_complete']}")
print(f" Merge complete: {self.status['merge_complete']}")
print(f" Checkpoint dir: {self.checkpoint_dir}")
if self.status['completed_chunks'] > 0:
print(f"\n Completed chunks:")
for i in range(self.status['completed_chunks']):
chunk_info = self.status['chunk_info'].get(f'chunk_{i}', {})
if chunk_info.get('completed'):
print(f" ✓ Chunk {i}: {chunk_info.get('file', 'unknown')}")

vr180_matting/config.py (modified, +5 lines)

@@ -29,6 +29,11 @@ class MattingConfig:
    fp16: bool = True
    sam2_model_cfg: str = "sam2.1_hiera_l"
    sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
+   # Det-SAM2 optimizations
+   continuous_correction: bool = True
+   correction_interval: int = 60  # Add correction prompts every N frames
+   frame_release_interval: int = 50  # Release old frames every N frames
+   frame_window_size: int = 30  # Keep N frames in memory

@dataclass
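
For reference, a sketch of setting the new knobs programmatically. This assumes MattingConfig's other fields keep their defaults; in config.yaml the keys would likely mirror the field names:

from vr180_matting.config import MattingConfig  # assumed import location

matting = MattingConfig(
    continuous_correction=True,   # re-prompt SAM2 with fresh YOLO detections
    correction_interval=60,       # every 60 frames
    frame_release_interval=50,    # drop cached frames every 50 frames
    frame_window_size=30,         # keep a 30-frame window in memory
)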

YOLO detector module (modified; file name not shown in this view)

@@ -1,6 +1,4 @@
-import torch
 import numpy as np
-from ultralytics import YOLO
 from typing import List, Tuple, Dict, Any
 import cv2

@@ -13,14 +11,23 @@ class YOLODetector:
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.model = None
-       self._load_model()
+       # Don't load model during init - load lazily when first used

    def _load_model(self):
-       """Load YOLOv8 model"""
+       """Load YOLOv8 model lazily"""
+       if self.model is not None:
+           return  # Already loaded
        try:
+           # Import heavy dependencies only when needed
+           import torch
+           from ultralytics import YOLO
            self.model = YOLO(f"{self.model_name}.pt")
            if self.device == "cuda" and torch.cuda.is_available():
                self.model.to("cuda")
+           print(f"🎯 Loaded YOLO model: {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")

@@ -34,8 +41,9 @@ class YOLODetector:
        Returns:
            List of detection dictionaries with bbox, confidence, and class info
        """
+       # Load model lazily on first use
        if self.model is None:
-           raise RuntimeError("YOLO model not loaded")
+           self._load_model()
        results = self.model(frame, verbose=False)
        detections = []

SAM2 video matting module (modified; file name not shown in this view)

@@ -9,12 +9,16 @@ import tempfile
 import shutil
 import gc

-try:
-    from sam2.build_sam import build_sam2_video_predictor
-    from sam2.sam2_image_predictor import SAM2ImagePredictor
-    SAM2_AVAILABLE = True
-except ImportError:
-    SAM2_AVAILABLE = False
+# Check SAM2 availability without importing heavy modules
+def _check_sam2_available():
+    try:
+        import sam2
+        return True
+    except ImportError:
+        return False
+
+SAM2_AVAILABLE = _check_sam2_available()
+
+if not SAM2_AVAILABLE:
     warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
        self.video_segments = {}
        self.temp_video_path = None
-       self._load_model(model_cfg, checkpoint_path)
+       # Don't load model during init - load lazily when needed
+       self._model_loaded = False

    def _load_model(self, model_cfg: str, checkpoint_path: str):
-       """Load SAM2 video predictor with optimizations"""
+       """Load SAM2 video predictor lazily"""
+       if self._model_loaded and self.predictor is not None:
+           return  # Already loaded and predictor exists
        try:
+           # Import heavy SAM2 modules only when needed
+           from sam2.build_sam import build_sam2_video_predictor
            # Check for checkpoint in SAM2 repo structure
            if not Path(checkpoint_path).exists():
                # Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
            if sam2_repo_path.exists():
                checkpoint_path = str(sam2_repo_path)

+           print(f"🎯 Loading SAM2 model: {model_cfg}")
            # Use SAM2's build_sam2_video_predictor which returns the predictor directly
            # The predictor IS the model - no .model attribute needed
            self.predictor = build_sam2_video_predictor(
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
                device=self.device
            )

+           self._model_loaded = True
+           print(f"✅ SAM2 model loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")

    def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
        """Initialize video inference state"""
-       if self.predictor is None:
-           # Recreate predictor if it was cleaned up
+       # Load model lazily on first use
+       if not self._model_loaded:
            self._load_model(self.model_cfg, self.checkpoint_path)

        if video_path is not None:
@@ -152,13 +167,16 @@ class SAM2VideoMatting:
        return object_ids

-   def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
+   def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
+                       frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
        """
-       Propagate masks through video
+       Propagate masks through video with Det-SAM2 style memory management

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process
+           frame_release_interval: Release old frames every N frames
+           frame_window_size: Keep N frames in memory

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
@@ -182,9 +200,108 @@ class SAM2VideoMatting:
                video_segments[out_frame_idx] = frame_masks

-               # Memory management: release old frames periodically
-               if self.memory_offload and out_frame_idx % 100 == 0:
-                   self._release_old_frames(out_frame_idx - 50)
+               # Det-SAM2 style memory management: more aggressive frame release
+               if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                   self._release_old_frames(out_frame_idx - frame_window_size)
+
+                   # Optional: Log frame release for monitoring
+                   if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                       print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+       return video_segments
def propagate_masks_with_continuous_correction(self,
detector,
temp_video_path: str,
start_frame: int = 0,
max_frames: Optional[int] = None,
correction_interval: int = 60,
frame_release_interval: int = 50,
frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Det-SAM2 style: Propagate masks with continuous prompt correction
Args:
detector: YOLODetector instance for generating correction prompts
temp_video_path: Path to video file for frame access
start_frame: Starting frame index
max_frames: Maximum number of frames to process
correction_interval: Add correction prompts every N frames
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
"""
if self.inference_state is None:
raise RuntimeError("Video state not initialized")
video_segments = {}
max_frames = max_frames or 10000 # Default limit
# Open video for accessing frames during propagation
cap = cv2.VideoCapture(str(temp_video_path))
try:
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=start_frame,
max_frame_num_to_track=max_frames,
reverse=False
):
frame_masks = {}
for i, out_obj_id in enumerate(out_obj_ids):
mask = (out_mask_logits[i] > 0.0).cpu().numpy()
frame_masks[out_obj_id] = mask
video_segments[out_frame_idx] = frame_masks
# Det-SAM2 optimization: Add correction prompts at keyframes
if (out_frame_idx % correction_interval == 0 and
out_frame_idx > start_frame and
out_frame_idx < max_frames - 1):
# Read frame for detection
cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
ret, correction_frame = cap.read()
if ret:
# Run detection on this keyframe
detections = detector.detect_persons(correction_frame)
if detections:
# Convert to prompts and add as corrections
box_prompts, labels = detector.convert_to_sam_prompts(detections)
# Add correction prompts (SAM2 will propagate backward)
correction_count = 0
try:
for i, (box, label) in enumerate(zip(box_prompts, labels)):
# Use existing object IDs if available, otherwise create new ones
obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=out_frame_idx,
obj_id=obj_id,
box=box,
)
correction_count += 1
print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
except Exception as e:
warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
# Memory management: More aggressive frame release (Det-SAM2 style)
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
finally:
cap.release()
        return video_segments
@@ -302,6 +419,9 @@ class SAM2VideoMatting:
        finally:
            self.predictor = None

+           # Reset model loaded state for fresh reload
+           self._model_loaded = False

            # Force garbage collection (critical for memory leak prevention)
            gc.collect()
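
Taken together, a sketch of how the lazy loading and continuous correction above are meant to be driven. It assumes detector and sam2 are already-constructed YOLODetector and SAM2VideoMatting instances; argument names match the signatures above:

def run_chunk_with_correction(detector, sam2, temp_video_path, first_frame):
    # init_video_state() triggers the lazy model load on first use
    sam2.init_video_state(video_path=str(temp_video_path))

    # Seed SAM2 with person detections from the first frame
    detections = detector.detect_persons(first_frame)
    boxes, labels = detector.convert_to_sam_prompts(detections)
    sam2.add_person_prompts(0, boxes, labels)

    # Propagate with Det-SAM2 corrections and aggressive frame release
    segments = sam2.propagate_masks_with_continuous_correction(
        detector=detector,
        temp_video_path=str(temp_video_path),
        correction_interval=60,
        frame_release_interval=50,
        frame_window_size=30,
    )
    sam2.cleanup()  # destroys the predictor; it reloads lazily on the next chunk
    return segments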

Video processor module (modified; file name not shown in this view)

@@ -281,6 +281,116 @@ class VideoProcessor:
print(f"Read {len(frames)} frames") print(f"Read {len(frames)} frames")
return frames return frames
def read_video_frames_dual_resolution(self,
video_path: str,
start_frame: int = 0,
num_frames: Optional[int] = None,
scale_factor: float = 0.5) -> Dict[str, List[np.ndarray]]:
"""
Read video frames at both original and scaled resolution for dual-resolution processing
Args:
video_path: Path to video file
start_frame: Starting frame index
num_frames: Number of frames to read (None for all)
scale_factor: Scaling factor for inference frames
Returns:
Dictionary with 'original' and 'scaled' frame lists
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise RuntimeError(f"Could not open video file: {video_path}")
# Set starting position
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
original_frames = []
scaled_frames = []
frame_count = 0
# Progress tracking
total_to_read = num_frames if num_frames else self.total_frames - start_frame
with tqdm(total=total_to_read, desc="Reading dual-resolution frames") as pbar:
while True:
ret, frame = cap.read()
if not ret:
break
# Store original frame
original_frames.append(frame.copy())
# Create scaled frame for inference
if scale_factor != 1.0:
new_width = int(frame.shape[1] * scale_factor)
new_height = int(frame.shape[0] * scale_factor)
scaled_frame = cv2.resize(frame, (new_width, new_height),
interpolation=cv2.INTER_AREA)
else:
scaled_frame = frame.copy()
scaled_frames.append(scaled_frame)
frame_count += 1
pbar.update(1)
if num_frames is not None and frame_count >= num_frames:
break
cap.release()
print(f"Loaded {len(original_frames)} frames:")
print(f" Original: {original_frames[0].shape} per frame")
print(f" Scaled: {scaled_frames[0].shape} per frame (scale_factor={scale_factor})")
return {
'original': original_frames,
'scaled': scaled_frames
}
def upscale_mask(self, mask: np.ndarray, target_shape: tuple, method: str = 'cubic') -> np.ndarray:
"""
Upscale a mask from inference resolution to original resolution
Args:
mask: Low-resolution mask (H, W)
target_shape: Target shape (H, W) for upscaling
method: Upscaling method ('nearest', 'cubic', 'area')
Returns:
Upscaled mask at target resolution
"""
if mask.shape[:2] == target_shape[:2]:
return mask # Already correct size
# Ensure mask is 2D
if mask.ndim == 3:
mask = mask.squeeze()
# Choose interpolation method
if method == 'nearest':
interpolation = cv2.INTER_NEAREST # Crisp edges, good for sharp subjects
elif method == 'cubic':
interpolation = cv2.INTER_CUBIC # Smooth edges, good for most content
elif method == 'area':
interpolation = cv2.INTER_AREA # Good for downscaling, not upscaling
else:
interpolation = cv2.INTER_CUBIC # Default to cubic
# Upscale mask
upscaled_mask = cv2.resize(
mask.astype(np.uint8),
(target_shape[1], target_shape[0]), # (width, height) for cv2.resize
interpolation=interpolation
)
# Convert back to boolean if it was originally boolean
if mask.dtype == bool:
upscaled_mask = upscaled_mask.astype(bool)
return upscaled_mask
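
For example, a half-resolution boolean mask comes back to frame size like this (shapes are illustrative; processor stands for an instance of this class):

import numpy as np

mask_small = np.zeros((540, 960), dtype=bool)  # mask at inference resolution
mask_full = processor.upscale_mask(mask_small, target_shape=(1080, 1920), method='cubic')
assert mask_full.shape == (1080, 1920) and mask_full.dtype == bool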
def calculate_optimal_chunking(self) -> Tuple[int, int]:
"""
Calculate optimal chunk size and overlap based on memory constraints
@@ -369,6 +479,92 @@ class VideoProcessor:
return matted_frames
def process_chunk_dual_resolution(self,
frame_data: Dict[str, List[np.ndarray]],
chunk_idx: int = 0) -> List[np.ndarray]:
"""
Process a chunk using dual-resolution approach: inference at low-res, output at full-res
Args:
frame_data: Dictionary with 'original' and 'scaled' frame lists
chunk_idx: Chunk index for logging
Returns:
List of matted frames at original resolution
"""
original_frames = frame_data['original']
scaled_frames = frame_data['scaled']
print(f"Processing chunk {chunk_idx} with dual-resolution ({len(original_frames)} frames)")
print(f" Inference: {scaled_frames[0].shape} → Output: {original_frames[0].shape}")
with self.memory_manager.memory_monitor(f"dual-res chunk {chunk_idx}"):
# Initialize SAM2 with scaled frames for inference
self.sam2_model.init_video_state(scaled_frames)
# Detect persons in first scaled frame
first_scaled_frame = scaled_frames[0]
detections = self.detector.detect_persons(first_scaled_frame)
if not detections:
warnings.warn(f"No persons detected in chunk {chunk_idx}")
return self._create_empty_masks(original_frames)
print(f"Detected {len(detections)} persons in first frame (at inference resolution)")
# Convert detections to SAM2 prompts (detections are already at scaled resolution)
box_prompts, labels = self.detector.convert_to_sam_prompts(detections)
# Add prompts to SAM2
object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
print(f"Added prompts for {len(object_ids)} objects")
# Propagate masks through chunk at inference resolution
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=len(scaled_frames)
)
# Apply upscaled masks to original resolution frames
matted_frames = []
original_shape = original_frames[0].shape[:2] # (H, W)
for frame_idx, original_frame in enumerate(tqdm(original_frames, desc="Applying upscaled masks")):
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
# Get combined mask at inference resolution
combined_mask_scaled = self.sam2_model.get_combined_mask(frame_masks)
if combined_mask_scaled is not None:
# Upscale mask to original resolution
combined_mask_full = self.upscale_mask(
combined_mask_scaled,
target_shape=original_shape,
method='cubic' # Smooth upscaling for masks
)
# Apply upscaled mask to original resolution frame
matted_frame = self.sam2_model.apply_mask_to_frame(
original_frame, combined_mask_full,
output_format=self.config.output.format,
background_color=self.config.output.background_color
)
else:
# No mask for this frame
matted_frame = self._create_empty_mask_frame(original_frame)
else:
# No mask for this frame
matted_frame = self._create_empty_mask_frame(original_frame)
matted_frames.append(matted_frame)
# Cleanup SAM2 state
self.sam2_model.cleanup()
print(f"✅ Dual-resolution processing complete: {len(matted_frames)} frames at full resolution")
return matted_frames
def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
"""Create empty masks when no persons detected"""
empty_frames = []
@@ -387,19 +583,213 @@ class VideoProcessor:
# Green screen background
return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
overlap_frames: int = 0, audio_source: str = None) -> None:
"""
Merge processed chunks using streaming approach (no memory accumulation)
Args:
chunk_files: List of chunk result files (.npz)
output_path: Final output video path
overlap_frames: Number of overlapping frames
audio_source: Audio source file for final video
"""
if not chunk_files:
raise ValueError("No chunk files to merge")
print(f"🎬 TRUE Streaming merge: {len(chunk_files)} chunks → {output_path}")
# Create temporary directory for frame images
import tempfile
temp_frames_dir = Path(tempfile.mkdtemp(prefix="merge_frames_"))
frame_counter = 0
try:
print(f"📁 Using temp frames dir: {temp_frames_dir}")
# Process each chunk frame-by-frame (true streaming)
for i, chunk_file in enumerate(chunk_files):
print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
# Load chunk metadata without loading frames array
chunk_data = np.load(str(chunk_file))
frames_array = chunk_data['frames'] # This is still mmap'd, not loaded
total_frames_in_chunk = frames_array.shape[0]
# Determine which frames to skip for overlap
start_frame_idx = overlap_frames if i > 0 and overlap_frames > 0 else 0
frames_to_process = total_frames_in_chunk - start_frame_idx
if start_frame_idx > 0:
print(f" ✂️ Skipping first {start_frame_idx} overlapping frames")
print(f" 🔄 Processing {frames_to_process} frames one-by-one...")
# Process frames ONE AT A TIME (true streaming)
for frame_idx in range(start_frame_idx, total_frames_in_chunk):
# Load only ONE frame at a time
frame = frames_array[frame_idx] # Load single frame
# Save frame directly to disk
frame_path = temp_frames_dir / f"frame_{frame_counter:06d}.jpg"
success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
if not success:
raise RuntimeError(f"Failed to save frame {frame_counter}")
frame_counter += 1
# Periodic progress and cleanup
if frame_counter % 100 == 0:
print(f" 💾 Saved {frame_counter} frames...")
gc.collect() # Periodic cleanup
print(f" ✅ Saved {frames_to_process} frames to disk (total: {frame_counter})")
# Close chunk file and cleanup
chunk_data.close()
del chunk_data, frames_array
# Don't delete checkpoint files - they're needed for potential resume
# The checkpoint system manages cleanup separately
print(f" 📋 Keeping checkpoint file: {chunk_file.name}")
# Aggressive cleanup and memory monitoring after each chunk
self._aggressive_memory_cleanup(f"After streaming merge chunk {i}")
# Memory safety check
memory_info = self._get_process_memory_info()
if memory_info['total_process_gb'] > 35: # Warning if approaching 46GB limit
print(f"⚠️ High memory usage: {memory_info['total_process_gb']:.1f}GB - forcing cleanup")
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Create final video directly from frame images using ffmpeg
print(f"📹 Creating final video from {frame_counter} frames...")
self._create_video_from_frames(temp_frames_dir, Path(output_path), frame_counter)
# Add audio if provided
if audio_source:
self._add_audio_to_video(output_path, audio_source)
except Exception as e:
print(f"❌ Streaming merge failed: {e}")
raise
finally:
# Cleanup temporary frames directory
try:
if temp_frames_dir.exists():
import shutil
shutil.rmtree(temp_frames_dir)
print(f"🗑️ Cleaned up temp frames dir: {temp_frames_dir}")
except Exception as e:
print(f"⚠️ Could not cleanup temp frames dir: {e}")
# Memory cleanup
gc.collect()
print(f"✅ TRUE Streaming merge complete: {output_path}")
def _create_video_from_frames(self, frames_dir: Path, output_path: Path, frame_count: int):
"""Create video directly from frame images using ffmpeg (memory efficient)"""
import subprocess
frame_pattern = str(frames_dir / "frame_%06d.jpg")
fps = self.video_info['fps'] if hasattr(self, 'video_info') and self.video_info else 30.0
print(f"🎬 Creating video with ffmpeg: {frame_count} frames at {fps} fps")
# Use GPU encoding if available, fallback to CPU
gpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(fps),
'-i', frame_pattern,
'-c:v', 'h264_nvenc', # NVIDIA GPU encoder
'-preset', 'fast',
'-cq', '18', # Quality for GPU encoding
'-pix_fmt', 'yuv420p',
str(output_path)
]
cpu_cmd = [
'ffmpeg', '-y', # -y to overwrite output file
'-framerate', str(fps),
'-i', frame_pattern,
'-c:v', 'libx264', # CPU encoder
'-preset', 'medium',
'-crf', '18', # Quality for CPU encoding
'-pix_fmt', 'yuv420p',
str(output_path)
]
# Try GPU first
print(f"🚀 Trying GPU encoding...")
result = subprocess.run(gpu_cmd, capture_output=True, text=True)
if result.returncode != 0:
print("⚠️ GPU encoding failed, using CPU...")
print(f"🔄 CPU encoding...")
result = subprocess.run(cpu_cmd, capture_output=True, text=True)
else:
print("✅ GPU encoding successful!")
if result.returncode != 0:
print(f"❌ FFmpeg stdout: {result.stdout}")
print(f"❌ FFmpeg stderr: {result.stderr}")
raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
print(f"✅ Video created successfully: {output_path}")
def _add_audio_to_video(self, video_path: str, audio_source: str):
"""Add audio to video using ffmpeg"""
import subprocess
import tempfile
try:
# Create temporary file for output with audio
temp_path = Path(video_path).with_suffix('.temp.mp4')
cmd = [
'ffmpeg', '-y',
'-i', str(video_path), # Input video (no audio)
'-i', str(audio_source), # Input audio source
'-c:v', 'copy', # Copy video without re-encoding
'-c:a', 'aac', # Encode audio as AAC
'-map', '0:v:0', # Map video from first input
'-map', '1:a:0', # Map audio from second input
'-shortest', # Match shortest stream duration
str(temp_path)
]
print(f"🎵 Adding audio: {audio_source}{video_path}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"⚠️ Audio addition failed: {result.stderr}")
# Keep original video without audio
return
# Replace original with audio version
Path(video_path).unlink()
temp_path.rename(video_path)
print(f"✅ Audio added successfully")
except Exception as e:
print(f"⚠️ Could not add audio: {e}")
    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Legacy merge method - DEPRECATED due to memory accumulation.
        Use merge_chunks_streaming() instead for memory efficiency.

        Args:
            chunk_results: List of chunk results
            overlap_frames: Number of overlapping frames

        Returns:
            Merged frame sequence
        """
        import warnings
        warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
                      DeprecationWarning, stacklevel=2)

        if len(chunk_results) == 1:
            return chunk_results[0]
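For reference, the blending this legacy method performs in the overlap region amounts to a linear crossfade between the tail of one chunk and the head of the next. A minimal sketch of that idea (not the exact implementation, whose body is elided from this diff):

import numpy as np

def crossfade_overlap(tail: np.ndarray, head: np.ndarray) -> np.ndarray:
    """Linearly blend two stacks of overlapping frames, each of shape (n, H, W, C)."""
    n = tail.shape[0]
    # Weight ramps 1 -> 0 for the outgoing chunk, 0 -> 1 for the incoming one
    w = np.linspace(1.0, 0.0, n).reshape(n, 1, 1, 1)
    return (tail * w + head * (1.0 - w)).astype(tail.dtype)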
@@ -584,48 +974,100 @@ class VideoProcessor:
                 print(f"⚠️ Could not verify frame count: {e}")
 
     def process_video(self) -> None:
-        """Main video processing pipeline"""
+        """Main video processing pipeline with checkpoint/resume support"""
         self.processing_stats['start_time'] = time.time()
         print("Starting VR180 video processing...")
 
         # Load video info
         self.load_video_info(self.config.input.video_path)
 
+        # Initialize checkpoint manager
+        from .checkpoint_manager import CheckpointManager
+        checkpoint_mgr = CheckpointManager(
+            self.config.input.video_path,
+            self.config.output.path
+        )
+
+        # Check for existing checkpoints
+        resume_info = checkpoint_mgr.get_resume_info()
+        if resume_info['can_resume']:
+            print(f"\n🔄 RESUME DETECTED:")
+            print(f"   Found {resume_info['completed_chunks']} completed chunks")
+            print(f"   Continue from where we left off? (saves time!)")
+            checkpoint_mgr.print_status()
+
         # Calculate chunking parameters
         chunk_size, overlap_frames = self.calculate_optimal_chunking()
 
+        # Calculate total chunks
+        total_chunks = 0
+        for _ in range(0, self.total_frames, chunk_size - overlap_frames):
+            total_chunks += 1
+        checkpoint_mgr.set_total_chunks(total_chunks)
+
         # Process video in chunks
         chunk_files = []  # Store file paths instead of frame data
         temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))
 
         try:
+            chunk_idx = 0
             for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
                 end_frame = min(start_frame + chunk_size, self.total_frames)
                 frames_to_read = end_frame - start_frame
-                chunk_idx = len(chunk_files)
+                # Check if this chunk was already processed
+                existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
+                if existing_chunk:
+                    print(f"\n✅ Chunk {chunk_idx} already processed: {existing_chunk.name}")
+                    chunk_files.append(existing_chunk)
+                    chunk_idx += 1
+                    continue
 
                 print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
 
-                # Read chunk frames
-                frames = self.read_video_frames(
-                    self.config.input.video_path,
-                    start_frame=start_frame,
-                    num_frames=frames_to_read,
-                    scale_factor=self.config.processing.scale_factor
-                )
-
-                # Process chunk
-                matted_frames = self.process_chunk(frames, chunk_idx)
+                # Choose processing approach based on scale factor
+                if self.config.processing.scale_factor == 1.0:
+                    # No scaling needed - use original single-resolution approach
+                    print(f"🔄 Reading frames at original resolution (no scaling)")
+                    frames = self.read_video_frames(
+                        self.config.input.video_path,
+                        start_frame=start_frame,
+                        num_frames=frames_to_read,
+                        scale_factor=1.0
+                    )
+                    # Process chunk normally (single resolution)
+                    matted_frames = self.process_chunk(frames, chunk_idx)
+                else:
+                    # Scaling required - use dual-resolution approach
+                    print(f"🔄 Reading frames at dual resolution (scale_factor={self.config.processing.scale_factor})")
+                    frame_data = self.read_video_frames_dual_resolution(
+                        self.config.input.video_path,
+                        start_frame=start_frame,
+                        num_frames=frames_to_read,
+                        scale_factor=self.config.processing.scale_factor
+                    )
+                    # Process chunk with dual-resolution approach
+                    matted_frames = self.process_chunk_dual_resolution(frame_data, chunk_idx)
 
                 # Save chunk to disk immediately to free memory
                 chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
                 print(f"Saving chunk {chunk_idx} to disk...")
                 np.savez_compressed(str(chunk_path), frames=matted_frames)
+
+                # Save to checkpoint
+                checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)
                 chunk_files.append(chunk_path)
+                chunk_idx += 1
 
                 # Free the frames from memory immediately
                 del matted_frames
-                del frames
+                if self.config.processing.scale_factor == 1.0:
+                    del frames
+                else:
+                    del frame_data
 
                 # Update statistics
                 self.processing_stats['chunks_processed'] += 1
@@ -640,36 +1082,41 @@ class VideoProcessor:
                 if self.memory_manager.should_emergency_cleanup():
                     self.memory_manager.emergency_cleanup()
 
-            # Load and merge chunks from disk
-            print("\nLoading and merging chunks...")
-            chunk_results = []
-            for i, chunk_file in enumerate(chunk_files):
-                print(f"Loading {chunk_file.name}...")
-                chunk_data = np.load(str(chunk_file))
-                chunk_results.append(chunk_data['frames'])
-                chunk_data.close()  # Close the file
-
-                # Delete chunk file immediately after loading to free disk space
-                try:
-                    chunk_file.unlink()
-                    print(f"  Deleted chunk file {chunk_file.name}")
-                except Exception as e:
-                    print(f"  Warning: Could not delete chunk file: {e}")
-
-                # Aggressive cleanup every few chunks to prevent accumulation
-                if i % 3 == 0 and i > 0:
-                    self._aggressive_memory_cleanup(f"after loading chunk {i}")
-
-            # Merge chunks
-            final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
-
-            # Free chunk results after merging - this is critical!
-            del chunk_results
-            self._aggressive_memory_cleanup("after merging chunks")
-
-            # Save results
-            print(f"Saving {len(final_frames)} processed frames...")
-            self.save_video(final_frames, self.config.output.path)
+            # Mark chunk processing as complete
+            checkpoint_mgr.mark_processing_complete()
+
+            # Check if merge was already done
+            if resume_info.get('merge_complete', False):
+                print("\n✅ Merge already completed in previous run!")
+                print(f"   Output: {self.config.output.path}")
+            else:
+                # Use streaming merge to avoid memory accumulation (fixes OOM)
+                print("\n🎬 Using streaming merge (no memory accumulation)...")
+
+                # For resume scenarios, make sure we have all chunk files
+                if resume_info['can_resume']:
+                    checkpoint_chunk_files = checkpoint_mgr.get_completed_chunk_files()
+                    if len(checkpoint_chunk_files) != len(chunk_files):
+                        print(f"⚠️ Using {len(checkpoint_chunk_files)} checkpoint files instead of {len(chunk_files)} temp files")
+                        chunk_files = checkpoint_chunk_files
+
+                # Determine audio source for final video
+                audio_source = None
+                if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
+                    audio_source = self.config.input.video_path
+
+                # Stream merge chunks directly to output (no memory accumulation)
+                self.merge_chunks_streaming(
+                    chunk_files=chunk_files,
+                    output_path=self.config.output.path,
+                    overlap_frames=overlap_frames,
+                    audio_source=audio_source
+                )
+
+                # Mark merge as complete
+                checkpoint_mgr.mark_merge_complete()
+
+                print("✅ Streaming merge complete - no memory accumulation!")
 
             # Calculate final statistics
             self.processing_stats['end_time'] = time.time()
@@ -685,11 +1132,24 @@ class VideoProcessor:
             print("Video processing completed!")
 
+            # Option to clean up checkpoints
+            print("\n🗄️ CHECKPOINT CLEANUP OPTIONS:")
+            print("   Checkpoints saved successfully and can be cleaned up")
+            print("   - Keep checkpoints for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
+            print("   - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
+            print(f"   - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")
+
+            # For now, keep checkpoints by default (user can manually clean)
+            print("\n💡 Checkpoints kept for safety. Delete manually when no longer needed.")
+
         finally:
-            # Clean up temporary chunk files
+            # Clean up temporary chunk files (but not checkpoints)
             if temp_chunk_dir.exists():
                 print("Cleaning up temporary chunk files...")
-                shutil.rmtree(temp_chunk_dir)
+                try:
+                    shutil.rmtree(temp_chunk_dir)
+                except Exception as e:
+                    print(f"⚠️ Could not clean temp directory: {e}")
 
     def _print_processing_statistics(self):
         """Print detailed processing statistics"""
View File
@@ -3,6 +3,7 @@ import numpy as np
 from typing import List, Dict, Any, Optional, Tuple
 from pathlib import Path
 import warnings
+import torch
 
 from .video_processor import VideoProcessor
 from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
             del right_matted
             self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
 
+        # CRITICAL: Complete inter-chunk cleanup to prevent model persistence
+        # This ensures models don't accumulate between chunks
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames
 
     def _process_eye_sequence(self,
@@ -375,31 +380,43 @@ class VR180Processor(VideoProcessor):
             # Propagate masks (most expensive operation)
             self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
 
-            video_segments = self.sam2_model.propagate_masks(
-                start_frame=0,
-                max_frames=num_frames
-            )
+            # Use Det-SAM2 continuous correction if enabled
+            if self.config.matting.continuous_correction:
+                video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
+                    detector=self.detector,
+                    temp_video_path=str(temp_video_path),
+                    start_frame=0,
+                    max_frames=num_frames,
+                    correction_interval=self.config.matting.correction_interval,
+                    frame_release_interval=self.config.matting.frame_release_interval,
+                    frame_window_size=self.config.matting.frame_window_size
+                )
+                print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
+            else:
+                video_segments = self.sam2_model.propagate_masks(
+                    start_frame=0,
+                    max_frames=num_frames,
+                    frame_release_interval=self.config.matting.frame_release_interval,
+                    frame_window_size=self.config.matting.frame_window_size
+                )
 
             self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
 
-            # Apply masks - need to reload frames from temp video since we freed the original frames
-            self._print_memory_step(f"Before reloading frames for mask application ({eye_name} eye)")
-
-            # Read frames back from the temp video for mask application
-            cap = cv2.VideoCapture(str(temp_video_path))
-            reloaded_frames = []
-            for frame_idx in range(num_frames):
-                ret, frame = cap.read()
-                if not ret:
-                    break
-                reloaded_frames.append(frame)
-            cap.release()
-
-            self._print_memory_step(f"Reloaded {len(reloaded_frames)} frames for mask application")
-
-            # Apply masks
-            matted_frames = []
-            for frame_idx, frame in enumerate(reloaded_frames):
-                if frame_idx in video_segments:
-                    frame_masks = video_segments[frame_idx]
-                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)
+            # Apply masks with streaming approach (no frame accumulation)
+            self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
+
+            # Process frames one at a time without accumulation
+            cap = cv2.VideoCapture(str(temp_video_path))
+            matted_frames = []
+            try:
+                for frame_idx in range(num_frames):
+                    ret, frame = cap.read()
+                    if not ret:
+                        break
+
+                    # Apply mask to this single frame
+                    if frame_idx in video_segments:
+                        frame_masks = video_segments[frame_idx]
+                        combined_mask = self.sam2_model.get_combined_mask(frame_masks)
@@ -414,11 +431,22 @@ class VR180Processor(VideoProcessor):
                     matted_frames.append(matted_frame)
 
-            # Free reloaded frames and video segments completely
-            del reloaded_frames
-            del video_segments  # This holds processed masks from SAM2
-            self._aggressive_memory_cleanup(f"After mask application ({eye_name} eye)")
+                    # Free the original frame immediately (no accumulation)
+                    del frame
+
+                    # Periodic cleanup during processing
+                    if frame_idx % 100 == 0 and frame_idx > 0:
+                        import gc
+                        gc.collect()
+            finally:
+                cap.release()
+
+            # Free video segments completely
+            del video_segments  # This holds processed masks from SAM2
+            self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
+            self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
 
             return matted_frames
 
         finally:
@@ -668,6 +696,64 @@ class VR180Processor(VideoProcessor):
         # TODO: Implement proper stereo correction algorithm
         return right_frame
 
+    def _complete_inter_chunk_cleanup(self, chunk_idx: int):
+        """
+        Complete inter-chunk cleanup: destroy all models to prevent memory accumulation.
+
+        This addresses the core issue where SAM2 and YOLO models (~15-20GB)
+        persist between chunks, causing OOM when processing subsequent chunks.
+        """
+        print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
+
+        # 1. Completely destroy SAM2 model (15-20GB)
+        if hasattr(self, 'sam2_model') and self.sam2_model is not None:
+            self.sam2_model.cleanup()  # Call existing cleanup
+
+            # Force complete destruction of the model
+            try:
+                # Reset the model's loaded state so it will reload fresh
+                if hasattr(self.sam2_model, '_model_loaded'):
+                    self.sam2_model._model_loaded = False
+                # Clear any cached state
+                if hasattr(self.sam2_model, 'predictor'):
+                    self.sam2_model.predictor = None
+                if hasattr(self.sam2_model, 'inference_state'):
+                    self.sam2_model.inference_state = None
+                print(f"   ✅ SAM2 model destroyed and marked for fresh reload")
+            except Exception as e:
+                print(f"   ⚠️ SAM2 destruction warning: {e}")
+
+        # 2. Completely destroy YOLO detector (400MB+)
+        if hasattr(self, 'detector') and self.detector is not None:
+            try:
+                # Force YOLO model to be reloaded fresh
+                if hasattr(self.detector, 'model') and self.detector.model is not None:
+                    del self.detector.model
+                    self.detector.model = None
+                print(f"   ✅ YOLO model destroyed and marked for fresh reload")
+            except Exception as e:
+                print(f"   ⚠️ YOLO destruction warning: {e}")
+
+        # 3. Clear CUDA cache aggressively
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()  # Wait for all operations to complete
+            print(f"   ✅ CUDA cache cleared")
+
+        # 4. Force garbage collection
+        import gc
+        collected = gc.collect()
+        print(f"   ✅ Garbage collection: {collected} objects freed")
+
+        # 5. Memory verification
+        self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
+        print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
+
     def process_chunk(self,
                       frames: List[np.ndarray],
                       chunk_idx: int = 0) -> List[np.ndarray]:
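One way to confirm the cleanup above actually releases VRAM is to bracket it with torch's allocator counters. A small sketch using standard torch.cuda calls; report_vram is an illustrative helper, not part of this codebase:

import torch

def report_vram(label: str) -> None:
    """Print allocated vs reserved CUDA memory in GB (no-op without a GPU)."""
    if not torch.cuda.is_available():
        return
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"{label}: {alloc:.1f}GB allocated / {reserved:.1f}GB reserved")

# Usage around the cleanup:
#   report_vram("before cleanup")
#   self._complete_inter_chunk_cleanup(chunk_idx)
#   report_vram("after cleanup")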
@@ -727,6 +813,9 @@ class VR180Processor(VideoProcessor):
                 combined = {'left': left_frame, 'right': right_frame}
                 combined_frames.append(combined)
 
+        # CRITICAL: Complete inter-chunk cleanup for independent processing too
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames
 
     def save_video(self, frames: List[np.ndarray], output_path: str):