Compare commits

8 Commits

Author SHA1 Message Date
4d1361df46 bigtime 2025-07-26 15:29:37 -07:00
884cb8dce2 lol 2025-07-26 15:29:28 -07:00
36f58acb8b foo 2025-07-26 15:18:32 -07:00
fb51e82fd4 stuff 2025-07-26 15:18:01 -07:00
9f572d4430 analyze 2025-07-26 15:10:34 -07:00
ba8706b7ae quick check 2025-07-26 14:52:44 -07:00
734445cf48 more memory fixes hopeufly 2025-07-26 14:33:36 -07:00
80f947c91b det core 2025-07-26 13:51:21 -07:00
10 changed files with 1220 additions and 90 deletions

193
analyze_memory_profile.py Normal file
View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""
import json
import glob
import os
import sys
from pathlib import Path
def _partial_file_order(path):
    """Return a sort key ordering partial profile files by numeric suffix.

    A plain string sort places ``..._10.json`` before ``..._2.json``; this key
    compares the trailing number as an integer instead. Names whose suffix is
    not a number sort after the numbered ones, alphabetically.
    """
    stem = path.rsplit('.', 1)[0]
    suffix = stem.rsplit('_', 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    return (1, 0, path)


def analyze_memory_files():
    """Analyze ``memory_profile_partial_*.json`` files in the current directory.

    Each file is expected to hold a ``timeline`` list of samples with at least
    ``rss_gb`` and ``vram_gb`` keys, optionally ``timestamp`` and
    ``checkpoint``. The function prints, to stdout:

    * one progress line per file (first/last RSS and sample count),
    * overall peak RSS / VRAM,
    * RSS growth spikes of more than 3 GB between consecutive samples,
    * the checkpoint progression and jumps of more than 2 GB between
      consecutive checkpoints,
    * a coarse diagnosis and suggested next steps.

    Partial files may overlap (a later file can repeat the samples of an
    earlier one), so spikes and checkpoints are de-duplicated by sample
    timestamp; samples without a timestamp are never treated as duplicates.

    Returns:
        None. A file that cannot be read or lacks required keys is reported
        and skipped.
    """
    # Numeric ordering so that file #10 follows file #9 rather than file #1.
    files = sorted(glob.glob('memory_profile_partial_*.json'), key=_partial_file_order)
    if not files:
        print("❌ No memory profile files found!")
        print("Expected files like: memory_profile_partial_0.json")
        return

    print(f"🔍 Found {len(files)} memory profile files")
    print("=" * 60)

    peak_memory = 0
    peak_vram = 0
    critical_points = []
    all_checkpoints = []
    seen_spike_stamps = set()
    seen_checkpoint_keys = set()

    for i, file in enumerate(files):
        try:
            with open(file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            timeline = data.get('timeline', [])
            if not timeline:
                continue

            # Peaks in this file
            file_peak_rss = max(d['rss_gb'] for d in timeline)
            file_peak_vram = max(d['vram_gb'] for d in timeline)
            peak_memory = max(peak_memory, file_peak_rss)
            peak_vram = max(peak_vram, file_peak_vram)

            # Memory growth spikes (>3GB between consecutive samples)
            for j, (previous, current) in enumerate(zip(timeline, timeline[1:]), start=1):
                growth = current['rss_gb'] - previous['rss_gb']
                if growth <= 3.0:
                    continue
                stamp = current.get('timestamp')
                if stamp is not None:
                    if stamp in seen_spike_stamps:
                        continue  # same sample already reported from an earlier file
                    seen_spike_stamps.add(stamp)
                critical_points.append({
                    'file': file,
                    'file_index': i,
                    'sample': j,
                    'timestamp': stamp,
                    'rss_gb': current['rss_gb'],
                    'vram_gb': current['vram_gb'],
                    'growth_gb': growth,
                    'checkpoint': current.get('checkpoint', f'sample_{j}'),
                })

            # Checkpoints first seen in this file
            checkpoints = []
            for cp in timeline:
                if 'checkpoint' not in cp:
                    continue
                stamp = cp.get('timestamp')
                if stamp is not None:
                    key = (stamp, cp['checkpoint'])
                    if key in seen_checkpoint_keys:
                        continue
                    seen_checkpoint_keys.add(key)
                cp['file'] = file
                cp['file_index'] = i
                checkpoints.append(cp)
            all_checkpoints.extend(checkpoints)

            # Progress line for this file
            start_rss = timeline[0]['rss_gb']
            end_rss = timeline[-1]['rss_gb']
            growth = end_rss - start_rss
            samples = len(timeline)
            # ``+`` format flag keeps the sign correct when memory shrank.
            print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
                  f"({growth:+.1f}GB) [{samples:3d} samples]")

            for cp in checkpoints:
                print(f" 📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            # Unreadable file, invalid JSON or unexpected schema: report and move on.
            print(f"❌ Error reading {file}: {e}")

    print("\n" + "=" * 60)
    print("🎯 ANALYSIS SUMMARY")
    print("=" * 60)
    print(f"📈 Peak Memory: {peak_memory:.1f} GB")
    print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
    print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")

    if critical_points:
        print("\n💥 MEMORY GROWTH SPIKES (>3GB):")
        print(" Location Growth Total VRAM")
        print(" " + "-" * 55)
        for point in critical_points:
            location = point['checkpoint'][:30].ljust(30)
            print(f" {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB {point['vram_gb']:4.1f}GB")

    if all_checkpoints:
        print("\n📍 CHECKPOINT PROGRESSION:")
        print(" Checkpoint Memory VRAM File")
        print(" " + "-" * 55)
        for cp in all_checkpoints:
            checkpoint = cp['checkpoint'][:30].ljust(30)
            file_num = cp['file_index'] + 1
            print(f" {checkpoint} {cp['rss_gb']:5.1f}GB {cp['vram_gb']:4.1f}GB #{file_num}")

    # Memory growth between consecutive checkpoints
    if len(all_checkpoints) > 1:
        print("\n📊 MEMORY GROWTH ANALYSIS:")
        big_jumps = []
        for prev_cp, curr_cp in zip(all_checkpoints, all_checkpoints[1:]):
            growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
            if growth > 2.0:  # >2GB jump
                big_jumps.append({
                    'from': prev_cp['checkpoint'],
                    'to': curr_cp['checkpoint'],
                    'growth': growth,
                    'from_memory': prev_cp['rss_gb'],
                    'to_memory': curr_cp['rss_gb'],
                })
        if big_jumps:
            print(" Major jumps (>2GB):")
            for jump in big_jumps:
                # Explicit separators: without them names and values run together.
                print(f" {jump['from']} → {jump['to']}: "
                      f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f} → {jump['to_memory']:.1f}GB)")
        else:
            print(" ✅ No major memory jumps detected")

    # Diagnosis
    print("\n🔬 DIAGNOSIS:")
    if peak_memory > 400:
        print(" 🔴 CRITICAL: Memory usage exceeded 400GB")
        print(" 💡 Recommendation: Reduce chunk_size to 200-300 frames")
    elif peak_memory > 200:
        print(" 🟡 HIGH: Memory usage over 200GB")
        print(" 💡 Recommendation: Reduce chunk_size to 400 frames")
    else:
        print(" 🟢 MODERATE: Memory usage under 200GB")

    if critical_points:
        # Count spikes per checkpoint label and show the three most frequent.
        spike_locations = {}
        for point in critical_points:
            location = point['checkpoint']
            spike_locations[location] = spike_locations.get(location, 0) + 1
        print("\n 🎯 Most problematic locations:")
        for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
            print(f" {location}: {count} spikes")

    print("\n💡 NEXT STEPS:")
    # Match on the spike labels only, not on file names or dict keys.
    spike_labels = ' '.join(str(point['checkpoint']) for point in critical_points).lower()
    if 'merge' in spike_labels:
        print(" 1. Chunk merging still causing memory accumulation")
        print(" 2. Check if streaming merge is actually being used")
        print(" 3. Verify chunk files are being deleted immediately")
    elif 'propagation' in spike_labels:
        print(" 1. SAM2 propagation using too much memory")
        print(" 2. Reduce chunk_size further (try 300 frames)")
        print(" 3. Enable more aggressive frame release")
    else:
        print(" 1. Review the checkpoint progression above")
        print(" 2. Focus on locations with biggest memory spikes")
        print(" 3. Consider reducing chunk_size if spikes are large")
def main():
    """Print the analyzer banner, then analyze the partial profile files."""
    banner = (
        "🔍 MEMORY PROFILE ANALYZER",
        "Analyzing memory profile files for OOM causes...",
        "",
    )
    for banner_line in banner:
        print(banner_line)
    analyze_memory_files()


if __name__ == "__main__":
    main()

151
debug_memory_leak.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""
import psutil
import gc
from pathlib import Path
import sys
def detailed_memory_check(label):
    """Print this process's RSS/VMS and the system's available memory.

    Args:
        label: Heading printed above the three figures.

    Returns:
        The resident set size of the current process, in GB (bytes / 1024**3).
    """
    bytes_per_gb = 1024**3

    usage = psutil.Process().memory_info()
    rss_gb = usage.rss / bytes_per_gb
    vms_gb = usage.vms / bytes_per_gb

    # System-wide figure, not specific to this process
    available_gb = psutil.virtual_memory().available / bytes_per_gb

    report = (
        f"🔍 {label}:",
        f" RSS: {rss_gb:.2f} GB (physical memory)",
        f" VMS: {vms_gb:.2f} GB (virtual memory)",
        f" Available: {available_gb:.2f} GB",
    )
    for report_line in report:
        print(report_line)
    return rss_gb
def simulate_chunk_processing(config_path='config.yaml'):
    """Step through pipeline start-up and print memory growth after each step.

    The steps are: importing the ``vr180_matting`` modules, loading the
    config, constructing ``VR180Processor``, reading video info, and reading
    then freeing 10 frames. After each step ``detailed_memory_check`` is
    called and the RSS delta against the previous step is printed, followed by
    a summary with threshold-based recommendations.

    Args:
        config_path: Path of the YAML file passed to
            ``VR180Config.from_yaml``. Defaults to ``'config.yaml'`` so that
            callers using the old zero-argument form keep working.
    """
    print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
    print("=" * 60)
    base_memory = detailed_memory_check("0. Baseline")

    # Step 1: imports are done here, not at module top, so their cost is measured
    print("\n📦 Step 1: Imports")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor
    import_memory = detailed_memory_check("1. After imports")
    import_growth = import_memory - base_memory
    print(f" Growth: +{import_growth:.2f} GB")

    # Step 2: load the config the caller asked for (previously hard-coded)
    print("\n⚙️ Step 2: Config loading")
    config = VR180Config.from_yaml(config_path)
    config_memory = detailed_memory_check("2. After config load")
    config_growth = config_memory - import_memory
    print(f" Growth: +{config_growth:.2f} GB")

    # Step 3: processor construction (models are presumably still unloaded here)
    print("\n🏗️ Step 3: Processor initialization")
    processor = VR180Processor(config)
    processor_memory = detailed_memory_check("3. After processor init")
    processor_growth = processor_memory - config_memory
    print(f" Growth: +{processor_growth:.2f} GB")

    # Step 4: video metadata; a failure here is reported but not fatal
    print("\n🎬 Step 4: Video info loading")
    try:
        video_info = processor.load_video_info(config.input.video_path)
        print(f" Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
              f"{video_info.get('total_frames', 'unknown')} frames")
    except Exception as e:
        print(f" Warning: Could not load video info: {e}")
    video_info_memory = detailed_memory_check("4. After video info")
    video_info_growth = video_info_memory - processor_memory
    print(f" Growth: +{video_info_growth:.2f} GB")

    # Step 5: read a handful of frames, then free them again
    print("\n🔄 Step 5: Simulating chunk 0 processing...")
    print(" Loading first 10 frames to trigger model loading...")
    try:
        frames = processor.read_video_frames(
            config.input.video_path,
            start_frame=0,
            num_frames=10,
            scale_factor=config.processing.scale_factor
        )
        frames_memory = detailed_memory_check("5a. After reading 10 frames")
        frames_growth = frames_memory - video_info_memory
        print(f" 10 frames growth: +{frames_growth:.2f} GB")

        del frames
        gc.collect()
        after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
        free_improvement = frames_memory - after_free_memory
        print(f" Memory freed: -{free_improvement:.2f} GB")
    except Exception as e:
        print(f" Could not simulate frame loading: {e}")
        # Fall back to the last successful measurement for the summary below
        after_free_memory = video_info_memory

    total_growth = after_free_memory - base_memory

    print("\n📊 MEMORY ANALYSIS:")
    print(f" Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
    print(f" Total growth: +{total_growth:.2f}GB")
    if total_growth > 10:
        print(" 🔴 HIGH: Memory growth > 10GB before any real processing")
        print(" 💡 This suggests model loading is using too much memory")
    elif total_growth > 5:
        print(" 🟡 MODERATE: Memory growth 5-10GB")
        print(" 💡 Normal for model loading, but monitor chunk processing")
    else:
        print(" 🟢 GOOD: Memory growth < 5GB")
        print(" 💡 Initialization memory usage is reasonable")

    print("\n🎯 KEY INSIGHTS:")
    if import_growth > 1:
        print(f" - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
    if processor_growth > 10:
        print(f" - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")

    print("\n💡 RECOMMENDATIONS:")
    if total_growth > 15:
        print(" 1. Reduce chunk_size to 200-300 frames")
        print(" 2. Use smaller models (yolov8n instead of yolov8m)")
        print(" 3. Enable FP16 mode for SAM2")
    elif total_growth > 8:
        print(" 1. Monitor chunk processing carefully")
        print(" 2. Use streaming merge (should be automatic)")
        print(" 3. Current settings may be acceptable")
    else:
        print(" 1. Settings look good for initialization")
        print(" 2. Focus on chunk processing memory leaks")


def main():
    """Validate the config path given on the command line and run the simulation.

    Exits with status 1 when the argument count is wrong or the file is missing.
    """
    if len(sys.argv) != 2:
        print("Usage: python debug_memory_leak.py <config.yaml>")
        print("This simulates initialization to find memory leaks")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    # Hand the validated path through; it used to be checked and then ignored.
    simulate_chunk_processing(config_path)


if __name__ == "__main__":
    main()

249
memory_profiler_script.py Normal file
View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""
import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json
class MemoryProfiler:
    """Record process, system and GPU memory samples on a background thread.

    Samples are dicts appended to ``self.data``. ``log_checkpoint`` tags a
    sample with a name, ``_save_partial_data`` writes numbered
    ``memory_profile_partial_<n>.json`` snapshots into the current directory,
    and ``stop_monitoring`` writes the final report to ``output_file``.
    """

    def __init__(self, output_file: str = "memory_profile.json"):
        """Create an idle profiler that will write its report to *output_file*."""
        self.output_file = output_file
        self.data = []
        self.process = psutil.Process()
        self.running = False
        self.thread = None
        self.checkpoint_counter = 0
        # self.data and checkpoint_counter are touched by both the monitor
        # thread and the thread calling log_checkpoint().
        self._lock = threading.Lock()

    def start_monitoring(self, interval: float = 1.0):
        """Start tracemalloc and a daemon thread sampling every *interval* seconds."""
        tracemalloc.start()
        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
        self.thread.daemon = True
        self.thread.start()
        print(f"🔍 Memory monitoring started (interval: {interval}s)")

    def stop_monitoring(self):
        """Stop the monitor thread and write the final JSON report.

        The report contains the full timeline, up to 20 tracemalloc
        allocation sites (empty when tracemalloc is not tracing, e.g. when
        ``start_monitoring`` was never called) and a peak summary.
        """
        self.running = False
        if self.thread:
            self.thread.join()
            self.thread = None

        # take_snapshot() raises RuntimeError unless tracing is active.
        top_stats = []
        if tracemalloc.is_tracing():
            top_stats = tracemalloc.take_snapshot().statistics('lineno')

        with self._lock:
            timeline = list(self.data)

        results = {
            'timeline': timeline,
            'top_memory_allocations': [
                {
                    'file': stat.traceback.format()[0],
                    'size_mb': stat.size / 1024 / 1024,
                    'count': stat.count
                }
                for stat in top_stats[:20]  # Top 20 allocations
            ],
            'summary': {
                'peak_rss_gb': max(d['rss_gb'] for d in timeline) if timeline else 0,
                'peak_vram_gb': max(d['vram_gb'] for d in timeline) if timeline else 0,
                'total_samples': len(timeline)
            }
        }
        with open(self.output_file, 'w') as f:
            json.dump(results, f, indent=2)
        tracemalloc.stop()
        print(f"📊 Memory profile saved to {self.output_file}")

    @staticmethod
    def _query_gpu_memory():
        """Return ``(used_gb, free_gb)`` from the first line nvidia-smi prints.

        The raw numbers are divided by 1024 (presumably MiB to GiB — the
        query uses ``nounits``). Returns ``(0.0, 0.0)`` when nvidia-smi is
        missing, times out, fails, or prints something unparsable.
        """
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=memory.used,memory.free',
                 '--format=csv,noheader,nounits'],
                capture_output=True, text=True, timeout=5)
        except (OSError, subprocess.SubprocessError):
            return 0.0, 0.0
        if result.returncode != 0:
            return 0.0, 0.0
        first_line = result.stdout.strip().split('\n')[0]
        if not first_line:
            return 0.0, 0.0
        try:
            used, free = first_line.split(', ')
            return float(used) / 1024, float(free) / 1024
        except ValueError:
            return 0.0, 0.0

    def _collect_sample(self):
        """Measure memory right now and return it as one timeline entry."""
        memory_info = self.process.memory_info()
        sys_memory = psutil.virtual_memory()
        vram_gb, vram_free_gb = self._query_gpu_memory()
        # Returns (0, 0) rather than raising when tracemalloc is not tracing.
        current, _peak = tracemalloc.get_traced_memory()
        return {
            'timestamp': time.time(),
            'rss_gb': memory_info.rss / (1024**3),
            'vram_gb': vram_gb,
            'vram_free_gb': vram_free_gb,
            'sys_used_gb': (sys_memory.total - sys_memory.available) / (1024**3),
            'sys_available_gb': sys_memory.available / (1024**3),
            'traced_mb': current / (1024**2)
        }

    def _monitor_loop(self, interval: float):
        """Sample every *interval* seconds until ``self.running`` is cleared."""
        while self.running:
            try:
                sample = self._collect_sample()
                with self._lock:
                    self.data.append(sample)
                    count = len(self.data)

                if count % 10 == 0:  # Every 10 samples
                    print(f"🔍 Memory: RSS={sample['rss_gb']:.2f}GB, "
                          f"VRAM={sample['vram_gb']:.2f}GB, Sys={sample['sys_used_gb']:.1f}GB")
                # Save partial data every 30 samples in case of crash
                if count % 30 == 0:
                    self._save_partial_data()
            except Exception as e:
                # Thread boundary: report and keep sampling rather than die.
                print(f"Monitoring error: {e}")
            time.sleep(interval)

    def _save_partial_data(self):
        """Write the timeline collected so far to the next numbered partial file."""
        try:
            with self._lock:
                # Copy under the lock so serialization never sees a sample
                # dict while another thread is adding a key to it.
                snapshot = [dict(sample) for sample in self.data]
                index = self.checkpoint_counter
                self.checkpoint_counter += 1
            partial_file = f"memory_profile_partial_{index}.json"
            with open(partial_file, 'w') as f:
                json.dump({
                    'timeline': snapshot,
                    'status': 'partial_save',
                    'samples': len(snapshot)
                }, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            print(f"Failed to save partial data: {e}")

    def log_checkpoint(self, checkpoint_name: str):
        """Record a named checkpoint with a fresh measurement and save a partial file.

        A new sample is taken at call time, so the checkpoint is not lost
        when the monitor thread has not produced a sample yet, and two
        checkpoints logged in quick succession do not overwrite each other.
        If the fresh measurement fails, the most recent existing sample is
        tagged instead.
        """
        try:
            sample = self._collect_sample()
        except (psutil.Error, OSError) as e:
            print(f"Monitoring error: {e}")
            sample = None

        with self._lock:
            if sample is not None:
                self.data.append(sample)
            if self.data:
                self.data[-1]['checkpoint'] = checkpoint_name
                latest = dict(self.data[-1])
            else:
                latest = None

        if latest is None:
            print(f"Monitoring error: no sample available for checkpoint {checkpoint_name}")
            return
        print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")
        # Save checkpoint data immediately
        self._save_partial_data()
def run_with_profiling(config_path: str):
    """Run the VR180 pipeline under a ``MemoryProfiler`` and print a summary.

    Checkpoints are logged after start-up, imports, config load, processor
    construction, an explicit garbage collection and the processing run.
    Any exception is logged as an ``ERROR: ...`` checkpoint and re-raised;
    monitoring is stopped and the summary printed either way.

    Args:
        config_path: YAML file passed to ``VR180Config.from_yaml``.
    """
    profiler = MemoryProfiler("memory_profile_detailed.json")
    mark = profiler.log_checkpoint

    try:
        # Sample every 2 seconds
        profiler.start_monitoring(interval=2.0)
        mark("STARTUP")

        # Imported only now so the profiler sees the memory the imports cost
        print("Importing VR180 processor...")
        from vr180_matting.vr180_processor import VR180Processor
        from vr180_matting.config import VR180Config
        mark("IMPORTS_COMPLETE")

        print(f"Loading config from {config_path}")
        config = VR180Config.from_yaml(config_path)
        mark("CONFIG_LOADED")

        print("Initializing VR180 processor...")
        processor = VR180Processor(config)
        mark("PROCESSOR_INITIALIZED")

        gc.collect()
        mark("INITIAL_GC_COMPLETE")

        print("Starting VR180 processing...")
        processor.process_video()
        mark("PROCESSING_COMPLETE")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        mark(f"ERROR: {e}")
        raise
    finally:
        # Always flush the report, then summarise what was collected
        profiler.stop_monitoring()

        rule = "=" * 60
        print("\n" + rule)
        print("MEMORY PROFILING SUMMARY")
        print(rule)

        samples = profiler.data
        if samples:
            highest_rss = max(entry['rss_gb'] for entry in samples)
            highest_vram = max(entry['vram_gb'] for entry in samples)
            print(f"Peak RSS Memory: {highest_rss:.2f} GB")
            print(f"Peak VRAM Usage: {highest_vram:.2f} GB")
            print(f"Total Samples: {len(samples)}")

            tagged = [entry for entry in samples if 'checkpoint' in entry]
            if tagged:
                print(f"\nCheckpoints ({len(tagged)}):")
                for entry in tagged:
                    print(f" {entry['checkpoint']}: RSS={entry['rss_gb']:.2f}GB, VRAM={entry['vram_gb']:.2f}GB")

        print(f"\nDetailed profile saved to: {profiler.output_file}")
def main():
    """Check the command line, print a banner and start the profiled run.

    Exits with status 1 when the argument count is not exactly one or the
    config file does not exist.
    """
    if len(sys.argv) != 2:
        usage_lines = (
            "Usage: python memory_profiler_script.py <config.yaml>",
            "\nThis script runs VR180 matting with detailed memory profiling",
            "It will:",
            "- Monitor RSS, VRAM, and system memory every 2 seconds",
            "- Track memory allocations with tracemalloc",
            "- Log checkpoints at key processing stages",
            "- Save detailed JSON report for analysis",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    config_path = sys.argv[1]
    config_exists = Path(config_path).exists()
    if not config_exists:
        print(f"❌ Config file not found: {config_path}")
        sys.exit(1)

    print("🚀 Starting VR180 Memory Profiling")
    print(f"Config: {config_path}")
    print("=" * 60)
    run_with_profiling(config_path)


if __name__ == "__main__":
    main()

125
quick_memory_check.py Normal file
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""
import psutil
import subprocess
import sys
from pathlib import Path
def check_system():
    """Print a RAM / GPU / disk / config report and judge whether to proceed.

    The config path, when given, is read from ``sys.argv[1]``.

    Returns:
        ``False`` when the config file named on the command line does not
        exist or less than 10 GB of RAM is available; ``True`` otherwise.
    """
    gib = 1024**3

    print("🔍 SYSTEM RESOURCE CHECK")
    print("=" * 50)

    # RAM
    memory = psutil.virtual_memory()
    in_use = memory.total - memory.available
    print("📊 RAM:")
    print(f" Total: {memory.total / gib:.1f} GB")
    print(f" Available: {memory.available / gib:.1f} GB")
    print(f" Used: {in_use / gib:.1f} GB ({memory.percent:.1f}%)")

    # GPU: any failure (missing tool, timeout, bad output) is only reported
    try:
        query = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True, timeout=10)
        if query.returncode == 0:
            rows = query.stdout.strip().split('\n')
            print("\n🎮 GPU:")
            for gpu_index, row in enumerate(rows):
                if not row.strip():
                    continue
                fields = row.split(', ')
                if len(fields) < 4:
                    continue
                name, total, used, free = fields[:4]
                total_gb, used_gb, free_gb = (float(raw) / 1024 for raw in (total, used, free))
                print(f" GPU {gpu_index}: {name}")
                print(f" VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
                print(f" Free: {free_gb:.1f} GB")
    except Exception as e:
        print(f"\n⚠️ Could not get GPU info: {e}")

    # Disk space of the root filesystem
    disk = psutil.disk_usage('/')
    print("\n💾 Disk (/):")
    print(f" Total: {disk.total / gib:.1f} GB")
    print(f" Used: {disk.used / gib:.1f} GB ({disk.used/disk.total*100:.1f}%)")
    print(f" Free: {disk.free / gib:.1f} GB")

    # Config file named on the command line, if any
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
        if not Path(config_path).exists():
            print(f"\n❌ Config file not found: {config_path}")
            return False

        print(f"\n✅ Config file found: {config_path}")
        # Best effort: show a few settings, but never fail the check over them
        try:
            import yaml
            with open(config_path, 'r') as handle:
                config = yaml.safe_load(handle)
            print("📋 Key Settings:")
            if 'processing' in config:
                processing = config['processing']
                print(f" Chunk size: {processing.get('chunk_size', 'default')}")
                print(f" Scale factor: {processing.get('scale_factor', 'default')}")
            if 'hardware' in config:
                hardware = config['hardware']
                print(f" Max VRAM: {hardware.get('max_vram_gb', 'default')} GB")
            if 'input' in config:
                video_path = config['input'].get('video_path', '')
                if not video_path or not Path(video_path).exists():
                    print(f" ⚠️ Input video not found: {video_path}")
                else:
                    size_gb = Path(video_path).stat().st_size / gib
                    print(f" Input video: {video_path} ({size_gb:.1f} GB)")
        except Exception as e:
            print(f" ⚠️ Could not parse config: {e}")

    # Memory safety verdict
    print("\n⚠️ MEMORY SAFETY CHECKS:")
    available_gb = memory.available / gib
    if available_gb < 10:
        print(f" 🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
        print(" Consider: reducing chunk_size or scale_factor")
        return False

    if available_gb < 20:
        print(f" 🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
        print(" Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
    else:
        print(f" 🟢 GOOD MEMORY: {available_gb:.1f}GB available")

    print("\n" + "=" * 50)
    return True
def main():
    """Run the system check and tell the user how to proceed.

    Exits with status 1 on a wrong argument count or a failed check.
    """
    if len(sys.argv) != 2:
        print("Usage: python quick_memory_check.py <config.yaml>")
        print("\nThis checks system resources before running VR180 matting")
        sys.exit(1)

    if not check_system():
        print("❌ System check failed - address issues before running")
        sys.exit(1)

    config_arg = sys.argv[1]
    next_steps = (
        "✅ System check passed - safe to run VR180 matting",
        "\nTo run with memory profiling:",
        f" python memory_profiler_script.py {config_arg}",
        "\nTo run normally:",
        f" vr180-matting {config_arg}",
    )
    for step_line in next_steps:
        print(step_line)


if __name__ == "__main__":
    main()

139
test_inter_chunk_cleanup.py Normal file
View File

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""
import psutil
import gc
import sys
from pathlib import Path
def get_memory_usage():
    """Return the resident set size of the current process in GB (bytes / 1024**3)."""
    resident_bytes = psutil.Process().memory_info().rss
    return resident_bytes / (1024**3)
def test_inter_chunk_cleanup(config_path='config.yaml'):
    """Check that ``_complete_inter_chunk_cleanup`` releases the loaded models.

    The function builds a ``VR180Processor``, forces the YOLO and SAM2 models
    to load, runs the inter-chunk cleanup, inspects the model attributes and
    reloads both models, printing process RSS after each step.

    Args:
        config_path: YAML file passed to ``VR180Config.from_yaml``. Defaults
            to ``'config.yaml'`` so the old zero-argument call still works.

    Returns:
        ``True`` when the cleanup freed more than half of the memory the
        models added, ``False`` otherwise.
    """
    print("🧪 TESTING INTER-CHUNK CLEANUP")
    print("=" * 50)
    baseline_memory = get_memory_usage()
    print(f"📊 Baseline memory: {baseline_memory:.2f} GB")

    # Import and create processor
    print("\n1⃣ Creating processor...")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor
    config = VR180Config.from_yaml(config_path)
    processor = VR180Processor(config)
    init_memory = get_memory_usage()
    print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")

    # Simulate chunk processing (just trigger model loading)
    print("\n2⃣ Simulating chunk 0 processing...")

    # Test 1: force YOLO model loading; on failure carry the previous reading on
    try:
        detector = processor.detector
        detector._load_model()
        yolo_memory = get_memory_usage()
        print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ YOLO loading failed: {e}")
        yolo_memory = init_memory

    # Test 2: force SAM2 model loading
    try:
        sam2_model = processor.sam2_model
        sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
        sam2_memory = get_memory_usage()
        print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ SAM2 loading failed: {e}")
        sam2_memory = yolo_memory

    total_model_memory = sam2_memory - init_memory
    print(f"📊 Total model memory: {total_model_memory:.2f} GB")

    # Test 3: inter-chunk cleanup
    print("\n3⃣ Testing inter-chunk cleanup...")
    processor._complete_inter_chunk_cleanup(chunk_idx=0)
    cleanup_memory = get_memory_usage()
    cleanup_improvement = sam2_memory - cleanup_memory
    print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")

    # Test 4: the model attributes should show the models as gone
    print("\n4⃣ Testing fresh model reload...")
    yolo_destroyed = processor.detector.model is None
    print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_destroyed else '❌ NO'}")
    sam2_destroyed = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
    print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_destroyed else '❌ NO'}")

    # Test 5: force a reload to verify the models still work afterwards
    print("\n5⃣ Testing model reload...")
    try:
        processor.detector._load_model()
        processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
        sam2_reload_memory = get_memory_usage()
        reload_growth = sam2_reload_memory - cleanup_memory
        print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")
        if abs(reload_growth - total_model_memory) < 1.0:  # Within 1GB
            print("✅ Models reloaded with similar memory usage (good)")
        else:
            print("⚠️ Model reload memory differs significantly")
    except Exception as e:
        print(f"❌ Model reload failed: {e}")

    # Final summary
    print("\n📊 SUMMARY:")
    print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
    print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
    print(f" Memory freed: {cleanup_improvement:.2f}GB")
    print(f" Models destroyed: YOLO={yolo_destroyed}, SAM2={sam2_destroyed}")

    if cleanup_improvement > total_model_memory * 0.5:  # Freed >50% of model memory
        print("✅ Inter-chunk cleanup working effectively")
        return True
    print("❌ Inter-chunk cleanup not freeing enough memory")
    return False


def main():
    """Validate the config path from the command line and run the cleanup test.

    Returns:
        ``0`` when the cleanup test passed, ``1`` otherwise. Exits with
        status 1 directly on a wrong argument count or a missing file.
    """
    if len(sys.argv) != 2:
        print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    # Hand the validated path through; it used to be checked and then ignored.
    success = test_inter_chunk_cleanup(config_path)
    if success:
        print("\n🎉 SUCCESS: Inter-chunk cleanup is working!")
        print("💡 This should prevent 15-20GB model accumulation between chunks")
    else:
        print("\n❌ FAILURE: Inter-chunk cleanup needs improvement")
        print("💡 Check model destruction logic in _complete_inter_chunk_cleanup")
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -29,6 +29,11 @@ class MattingConfig:
fp16: bool = True fp16: bool = True
sam2_model_cfg: str = "sam2.1_hiera_l" sam2_model_cfg: str = "sam2.1_hiera_l"
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
# Det-SAM2 optimizations
continuous_correction: bool = True
correction_interval: int = 60 # Add correction prompts every N frames
frame_release_interval: int = 50 # Release old frames every N frames
frame_window_size: int = 30 # Keep N frames in memory
@dataclass @dataclass

View File

@@ -1,6 +1,4 @@
import torch
import numpy as np import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any
import cv2 import cv2
@@ -13,14 +11,23 @@ class YOLODetector:
self.confidence_threshold = confidence_threshold self.confidence_threshold = confidence_threshold
self.device = device self.device = device
self.model = None self.model = None
self._load_model() # Don't load model during init - load lazily when first used
def _load_model(self): def _load_model(self):
"""Load YOLOv8 model""" """Load YOLOv8 model lazily"""
if self.model is not None:
return # Already loaded
try: try:
# Import heavy dependencies only when needed
import torch
from ultralytics import YOLO
self.model = YOLO(f"{self.model_name}.pt") self.model = YOLO(f"{self.model_name}.pt")
if self.device == "cuda" and torch.cuda.is_available(): if self.device == "cuda" and torch.cuda.is_available():
self.model.to("cuda") self.model.to("cuda")
print(f"🎯 Loaded YOLO model: {self.model_name}")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}") raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
@@ -34,8 +41,9 @@ class YOLODetector:
Returns: Returns:
List of detection dictionaries with bbox, confidence, and class info List of detection dictionaries with bbox, confidence, and class info
""" """
# Load model lazily on first use
if self.model is None: if self.model is None:
raise RuntimeError("YOLO model not loaded") self._load_model()
results = self.model(frame, verbose=False) results = self.model(frame, verbose=False)
detections = [] detections = []

View File

@@ -9,12 +9,16 @@ import tempfile
import shutil import shutil
import gc import gc
try: # Check SAM2 availability without importing heavy modules
from sam2.build_sam import build_sam2_video_predictor def _check_sam2_available():
from sam2.sam2_image_predictor import SAM2ImagePredictor try:
SAM2_AVAILABLE = True import sam2
except ImportError: return True
SAM2_AVAILABLE = False except ImportError:
return False
SAM2_AVAILABLE = _check_sam2_available()
if not SAM2_AVAILABLE:
warnings.warn("SAM2 not available. Please install sam2 package.") warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
self.video_segments = {} self.video_segments = {}
self.temp_video_path = None self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path) # Don't load model during init - load lazily when needed
self._model_loaded = False
def _load_model(self, model_cfg: str, checkpoint_path: str): def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations""" """Load SAM2 video predictor lazily"""
if self._model_loaded and self.predictor is not None:
return # Already loaded and predictor exists
try: try:
# Import heavy SAM2 modules only when needed
from sam2.build_sam import build_sam2_video_predictor
# Check for checkpoint in SAM2 repo structure # Check for checkpoint in SAM2 repo structure
if not Path(checkpoint_path).exists(): if not Path(checkpoint_path).exists():
# Try in segment-anything-2/checkpoints/ # Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
if sam2_repo_path.exists(): if sam2_repo_path.exists():
checkpoint_path = str(sam2_repo_path) checkpoint_path = str(sam2_repo_path)
print(f"🎯 Loading SAM2 model: {model_cfg}")
# Use SAM2's build_sam2_video_predictor which returns the predictor directly # Use SAM2's build_sam2_video_predictor which returns the predictor directly
# The predictor IS the model - no .model attribute needed # The predictor IS the model - no .model attribute needed
self.predictor = build_sam2_video_predictor( self.predictor = build_sam2_video_predictor(
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
device=self.device device=self.device
) )
self._model_loaded = True
print(f"✅ SAM2 model loaded successfully")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}") raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None: def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
"""Initialize video inference state""" """Initialize video inference state"""
if self.predictor is None: # Load model lazily on first use
# Recreate predictor if it was cleaned up if not self._model_loaded:
self._load_model(self.model_cfg, self.checkpoint_path) self._load_model(self.model_cfg, self.checkpoint_path)
if video_path is not None: if video_path is not None:
@@ -152,13 +167,16 @@ class SAM2VideoMatting:
return object_ids return object_ids
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]: def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
""" """
Propagate masks through video Propagate masks through video with Det-SAM2 style memory management
Args: Args:
start_frame: Starting frame index start_frame: Starting frame index
max_frames: Maximum number of frames to process max_frames: Maximum number of frames to process
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns: Returns:
Dictionary mapping frame_idx -> {obj_id: mask} Dictionary mapping frame_idx -> {obj_id: mask}
@@ -182,9 +200,108 @@ class SAM2VideoMatting:
video_segments[out_frame_idx] = frame_masks video_segments[out_frame_idx] = frame_masks
# Memory management: release old frames periodically # Det-SAM2 style memory management: more aggressive frame release
if self.memory_offload and out_frame_idx % 100 == 0: if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - 50) self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
return video_segments
def propagate_masks_with_continuous_correction(self,
                                               detector,
                                               temp_video_path: str,
                                               start_frame: int = 0,
                                               max_frames: Optional[int] = None,
                                               correction_interval: int = 60,
                                               frame_release_interval: int = 50,
                                               frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
    """
    Det-SAM2 style: Propagate masks with continuous prompt correction

    While the predictor yields per-frame masks, every `correction_interval`
    frames the matching frame is read from `temp_video_path`, run through
    `detector`, and the resulting boxes are fed back to the predictor as
    box prompts on that frame.

    Args:
        detector: YOLODetector instance for generating correction prompts
            (must provide `detect_persons` and `convert_to_sam_prompts`)
        temp_video_path: Path to video file for frame access
        start_frame: Starting frame index
        max_frames: Maximum number of frames to process (defaults to 10000)
        correction_interval: Add correction prompts every N frames
        frame_release_interval: Release old frames every N frames
        frame_window_size: Keep N frames in memory

    Returns:
        Dictionary mapping frame_idx -> {obj_id: mask}

    Raises:
        RuntimeError: If the video state has not been initialized
        ValueError: If `correction_interval` or `frame_release_interval`
            is not positive (both are used as modulo divisors)
    """
    if self.inference_state is None:
        raise RuntimeError("Video state not initialized")
    if correction_interval <= 0 or frame_release_interval <= 0:
        # Fail up front instead of with ZeroDivisionError mid-propagation.
        raise ValueError(
            "correction_interval and frame_release_interval must be positive, "
            f"got {correction_interval} and {frame_release_interval}"
        )

    video_segments = {}
    max_frames = max_frames or 10000  # Default limit
    # max_frames is a count relative to start_frame, so the last frame the
    # predictor is asked to track is offset by start_frame.
    last_tracked_frame = start_frame + max_frames - 1

    # Open video for accessing frames during propagation
    cap = cv2.VideoCapture(str(temp_video_path))
    try:
        if not cap.isOpened():
            warnings.warn(f"Could not open {temp_video_path}; correction prompts will be skipped")

        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            # Binarize logits at 0 and move each object's mask to host memory
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

            # Det-SAM2 optimization: add correction prompts at keyframes,
            # excluding the first and last tracked frame
            is_keyframe = (out_frame_idx % correction_interval == 0
                           and start_frame < out_frame_idx < last_tracked_frame)
            if is_keyframe:
                cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
                ret, correction_frame = cap.read()
                if ret:
                    self._add_detsam2_corrections(
                        detector, correction_frame, out_frame_idx, out_obj_ids
                    )

            # Memory management: More aggressive frame release (Det-SAM2 style)
            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
                release_before = out_frame_idx - frame_window_size
                # Nothing lies before a non-positive index (e.g. at frame 0)
                if release_before > 0:
                    self._release_old_frames(release_before)

                    # Optional: Log every 4x release interval
                    if out_frame_idx % (frame_release_interval * 4) == 0:
                        print(f"Det-SAM2: Released frames before {release_before}, keeping {frame_window_size} frames")
    finally:
        cap.release()

    return video_segments

def _add_detsam2_corrections(self, detector, frame, frame_idx: int, tracked_obj_ids) -> int:
    """
    Detect persons on `frame` and add the boxes as prompts at `frame_idx`

    Boxes are paired with `tracked_obj_ids` by position; any surplus boxes
    get fresh ids above the largest tracked id so they never collide with
    an existing object. Returns the number of prompts that were accepted.
    """
    detections = detector.detect_persons(frame)
    if not detections:
        return 0

    box_prompts, labels = detector.convert_to_sam_prompts(detections)
    tracked_ids = list(tracked_obj_ids)
    next_new_id = max(tracked_ids, default=0) + 1

    correction_count = 0
    for i, (box, _label) in enumerate(zip(box_prompts, labels)):
        # NOTE(review): positional pairing assumes detection order matches
        # the predictor's object order - confirm, or match by overlap.
        if i < len(tracked_ids):
            obj_id = tracked_ids[i]
        else:
            obj_id = next_new_id
            next_new_id += 1
        try:
            # NOTE(review): the predictor may reject prompts (notably new
            # object ids) while propagation is running - verify against SAM2.
            self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )
            correction_count += 1
        except Exception as e:
            # Best effort: one rejected prompt must not discard the others
            warnings.warn(f"Failed to add correction prompt at frame {frame_idx}: {e}")

    if correction_count:
        print(f"Det-SAM2: Added {correction_count} correction prompts at frame {frame_idx}")
    return correction_count
@@ -302,6 +419,9 @@ class SAM2VideoMatting:
finally: finally:
self.predictor = None self.predictor = None
# Reset model loaded state for fresh reload
self._model_loaded = False
# Force garbage collection (critical for memory leak prevention) # Force garbage collection (critical for memory leak prevention)
gc.collect() gc.collect()

View File

@@ -387,19 +387,83 @@ class VideoProcessor:
# Green screen background # Green screen background
return np.full_like(frame, self.config.output.background_color, dtype=np.uint8) return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
                           overlap_frames: int = 0, audio_source: str = None) -> None:
    """
    Merge processed chunks using streaming approach (no memory accumulation)

    Each chunk file is loaded, handed to the streaming writer, deleted from
    disk and released before the next one is touched, so at most one
    chunk's frames are referenced by this method at a time.

    Args:
        chunk_files: List of chunk result files (.npz) with a 'frames' entry
        output_path: Final output video path
        overlap_frames: Number of overlapping frames
        audio_source: Audio source file for final video

    Raises:
        ValueError: If `chunk_files` is empty
        Exception: Any error from loading or writing is re-raised after the
            writer's `cleanup()` has been called
    """
    # Validate before importing or constructing anything
    if not chunk_files:
        raise ValueError("No chunk files to merge")

    from .streaming_video_writer import StreamingVideoWriter

    print(f"🎬 Streaming merge: {len(chunk_files)} chunks → {output_path}")

    # Initialize streaming writer
    writer = StreamingVideoWriter(
        output_path=output_path,
        fps=self.video_info['fps'],
        audio_source=audio_source
    )

    try:
        # Process each chunk without accumulation
        for i, chunk_file in enumerate(chunk_files):
            print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")

            # The context manager closes the archive even if reading fails.
            with np.load(str(chunk_file)) as chunk_data:
                # list() yields one ndarray view per frame. ndarray.tolist()
                # would instead expand every pixel into a Python int object
                # and multiply the memory footprint.
                frames = list(chunk_data['frames'])

            # Write chunk with streaming writer; the first chunk has no
            # predecessor to overlap or blend with
            writer.write_chunk(
                frames=frames,
                chunk_index=i,
                overlap_frames=overlap_frames if i > 0 else 0,
                blend_with_previous=(i > 0 and overlap_frames > 0)
            )

            # Immediately drop our reference to the chunk's frames
            del frames

            # Delete chunk file to free disk space (best effort)
            try:
                chunk_file.unlink()
                print(f" 🗑️ Deleted {chunk_file.name}")
            except OSError as e:
                print(f" ⚠️ Could not delete {chunk_file.name}: {e}")

            # Aggressive cleanup every chunk
            self._aggressive_memory_cleanup(f"After processing chunk {i}")

        # Finalize the video
        writer.finalize()

    except Exception as e:
        print(f"❌ Streaming merge failed: {e}")
        writer.cleanup()
        raise

    print(f"✅ Streaming merge complete: {output_path}")
def merge_overlapping_chunks(self, def merge_overlapping_chunks(self,
chunk_results: List[List[np.ndarray]], chunk_results: List[List[np.ndarray]],
overlap_frames: int) -> List[np.ndarray]: overlap_frames: int) -> List[np.ndarray]:
""" """
Merge overlapping chunks with blending in overlap regions Legacy merge method - DEPRECATED due to memory accumulation
Use merge_chunks_streaming() instead for memory efficiency
Args:
chunk_results: List of chunk results
overlap_frames: Number of overlapping frames
Returns:
Merged frame sequence
""" """
import warnings
warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
DeprecationWarning, stacklevel=2)
if len(chunk_results) == 1: if len(chunk_results) == 1:
return chunk_results[0] return chunk_results[0]
@@ -640,36 +704,23 @@ class VideoProcessor:
if self.memory_manager.should_emergency_cleanup(): if self.memory_manager.should_emergency_cleanup():
self.memory_manager.emergency_cleanup() self.memory_manager.emergency_cleanup()
# Load and merge chunks from disk # Use streaming merge to avoid memory accumulation (fixes OOM)
print("\nLoading and merging chunks...") print("\n🎬 Using streaming merge (no memory accumulation)...")
chunk_results = []
for i, chunk_file in enumerate(chunk_files):
print(f"Loading {chunk_file.name}...")
chunk_data = np.load(str(chunk_file))
chunk_results.append(chunk_data['frames'])
chunk_data.close() # Close the file
# Delete chunk file immediately after loading to free disk space # Determine audio source for final video
try: audio_source = None
chunk_file.unlink() if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
print(f" Deleted chunk file {chunk_file.name}") audio_source = self.config.input.video_path
except Exception as e:
print(f" Warning: Could not delete chunk file: {e}")
# Aggressive cleanup every few chunks to prevent accumulation # Stream merge chunks directly to output (no memory accumulation)
if i % 3 == 0 and i > 0: self.merge_chunks_streaming(
self._aggressive_memory_cleanup(f"after loading chunk {i}") chunk_files=chunk_files,
output_path=self.config.output.path,
overlap_frames=overlap_frames,
audio_source=audio_source
)
# Merge chunks print("✅ Streaming merge complete - no memory accumulation!")
final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
# Free chunk results after merging - this is critical!
del chunk_results
self._aggressive_memory_cleanup("after merging chunks")
# Save results
print(f"Saving {len(final_frames)} processed frames...")
self.save_video(final_frames, self.config.output.path)
# Calculate final statistics # Calculate final statistics
self.processing_stats['end_time'] = time.time() self.processing_stats['end_time'] = time.time()

View File

@@ -3,6 +3,7 @@ import numpy as np
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path from pathlib import Path
import warnings import warnings
import torch
from .video_processor import VideoProcessor from .video_processor import VideoProcessor
from .config import VR180Config from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
del right_matted del right_matted
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}") self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
# This ensures models don't accumulate between chunks
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames return combined_frames
def _process_eye_sequence(self, def _process_eye_sequence(self,
@@ -375,50 +380,73 @@ class VR180Processor(VideoProcessor):
# Propagate masks (most expensive operation) # Propagate masks (most expensive operation)
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)") self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
video_segments = self.sam2_model.propagate_masks(
start_frame=0, # Use Det-SAM2 continuous correction if enabled
max_frames=num_frames if self.config.matting.continuous_correction:
) video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
detector=self.detector,
temp_video_path=str(temp_video_path),
start_frame=0,
max_frames=num_frames,
correction_interval=self.config.matting.correction_interval,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
else:
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=num_frames,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)") self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
# Apply masks - need to reload frames from temp video since we freed the original frames # Apply masks with streaming approach (no frame accumulation)
self._print_memory_step(f"Before reloading frames for mask application ({eye_name} eye)") self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
# Read frames back from the temp video for mask application # Process frames one at a time without accumulation
cap = cv2.VideoCapture(str(temp_video_path)) cap = cv2.VideoCapture(str(temp_video_path))
reloaded_frames = []
for frame_idx in range(num_frames):
ret, frame = cap.read()
if not ret:
break
reloaded_frames.append(frame)
cap.release()
self._print_memory_step(f"Reloaded {len(reloaded_frames)} frames for mask application")
# Apply masks
matted_frames = [] matted_frames = []
for frame_idx, frame in enumerate(reloaded_frames):
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
combined_mask = self.sam2_model.get_combined_mask(frame_masks)
matted_frame = self.sam2_model.apply_mask_to_frame( try:
frame, combined_mask, for frame_idx in range(num_frames):
output_format=self.config.output.format, ret, frame = cap.read()
background_color=self.config.output.background_color if not ret:
) break
else:
matted_frame = self._create_empty_mask_frame(frame)
matted_frames.append(matted_frame) # Apply mask to this single frame
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
combined_mask = self.sam2_model.get_combined_mask(frame_masks)
# Free reloaded frames and video segments completely matted_frame = self.sam2_model.apply_mask_to_frame(
del reloaded_frames frame, combined_mask,
output_format=self.config.output.format,
background_color=self.config.output.background_color
)
else:
matted_frame = self._create_empty_mask_frame(frame)
matted_frames.append(matted_frame)
# Free the original frame immediately (no accumulation)
del frame
# Periodic cleanup during processing
if frame_idx % 100 == 0 and frame_idx > 0:
import gc
gc.collect()
finally:
cap.release()
# Free video segments completely
del video_segments # This holds processed masks from SAM2 del video_segments # This holds processed masks from SAM2
self._aggressive_memory_cleanup(f"After mask application ({eye_name} eye)") self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
return matted_frames return matted_frames
finally: finally:
@@ -668,6 +696,64 @@ class VR180Processor(VideoProcessor):
# TODO: Implement proper stereo correction algorithm # TODO: Implement proper stereo correction algorithm
return right_frame return right_frame
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
"""
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
persist between chunks, causing OOM when processing subsequent chunks.
"""
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
# 1. Completely destroy SAM2 model (15-20GB)
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
self.sam2_model.cleanup() # Call existing cleanup
# Force complete destruction of the model
try:
# Reset the model's loaded state so it will reload fresh
if hasattr(self.sam2_model, '_model_loaded'):
self.sam2_model._model_loaded = False
# Clear any cached state
if hasattr(self.sam2_model, 'predictor'):
self.sam2_model.predictor = None
if hasattr(self.sam2_model, 'inference_state'):
self.sam2_model.inference_state = None
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ SAM2 destruction warning: {e}")
# 2. Completely destroy YOLO detector (400MB+)
if hasattr(self, 'detector') and self.detector is not None:
try:
# Force YOLO model to be reloaded fresh
if hasattr(self.detector, 'model') and self.detector.model is not None:
del self.detector.model
self.detector.model = None
print(f" ✅ YOLO model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ YOLO destruction warning: {e}")
# 3. Clear CUDA cache aggressively
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
print(f" ✅ CUDA cache cleared")
# 4. Force garbage collection
import gc
collected = gc.collect()
print(f" ✅ Garbage collection: {collected} objects freed")
# 5. Memory verification
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
def process_chunk(self, def process_chunk(self,
frames: List[np.ndarray], frames: List[np.ndarray],
chunk_idx: int = 0) -> List[np.ndarray]: chunk_idx: int = 0) -> List[np.ndarray]:
@@ -727,6 +813,9 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame} combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined) combined_frames.append(combined)
# CRITICAL: Complete inter-chunk cleanup for independent processing too
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames return combined_frames
def save_video(self, frames: List[np.ndarray], output_path: str): def save_video(self, frames: List[np.ndarray], output_path: str):