Compare commits

6 Commits

Author SHA1 Message Date
36f58acb8b foo 2025-07-26 15:18:32 -07:00
fb51e82fd4 stuff 2025-07-26 15:18:01 -07:00
9f572d4430 analyze 2025-07-26 15:10:34 -07:00
ba8706b7ae quick check 2025-07-26 14:52:44 -07:00
734445cf48 more memory fixes hopeufly 2025-07-26 14:33:36 -07:00
80f947c91b det core 2025-07-26 13:51:21 -07:00
9 changed files with 1012 additions and 90 deletions

193
analyze_memory_profile.py Normal file
View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""
import json
import glob
import os
import sys
from pathlib import Path
def analyze_memory_files():
"""Analyze partial memory profile files"""
# Get all partial files in order
files = sorted(glob.glob('memory_profile_partial_*.json'))
if not files:
print("❌ No memory profile files found!")
print("Expected files like: memory_profile_partial_0.json")
return
print(f"🔍 Found {len(files)} memory profile files")
print("=" * 60)
peak_memory = 0
peak_vram = 0
critical_points = []
all_checkpoints = []
for i, file in enumerate(files):
try:
with open(file, 'r') as f:
data = json.load(f)
timeline = data.get('timeline', [])
if not timeline:
continue
# Find peaks in this file
file_peak_rss = max([d['rss_gb'] for d in timeline])
file_peak_vram = max([d['vram_gb'] for d in timeline])
if file_peak_rss > peak_memory:
peak_memory = file_peak_rss
if file_peak_vram > peak_vram:
peak_vram = file_peak_vram
# Find memory growth spikes (>3GB increase)
for j in range(1, len(timeline)):
prev_rss = timeline[j-1]['rss_gb']
curr_rss = timeline[j]['rss_gb']
growth = curr_rss - prev_rss
if growth > 3.0: # >3GB growth spike
checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
critical_points.append({
'file': file,
'file_index': i,
'sample': j,
'timestamp': timeline[j]['timestamp'],
'rss_gb': curr_rss,
'vram_gb': timeline[j]['vram_gb'],
'growth_gb': growth,
'checkpoint': checkpoint
})
# Collect all checkpoints
checkpoints = [d for d in timeline if 'checkpoint' in d]
for cp in checkpoints:
cp['file'] = file
cp['file_index'] = i
all_checkpoints.append(cp)
# Show progress for this file
if timeline:
start_rss = timeline[0]['rss_gb']
end_rss = timeline[-1]['rss_gb']
growth = end_rss - start_rss
samples = len(timeline)
print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
f"(+{growth:4.1f}GB) [{samples:3d} samples]")
# Show significant checkpoints from this file
if checkpoints:
for cp in checkpoints:
print(f" 📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")
except Exception as e:
print(f"❌ Error reading {file}: {e}")
print("\n" + "=" * 60)
print("🎯 ANALYSIS SUMMARY")
print("=" * 60)
print(f"📈 Peak Memory: {peak_memory:.1f} GB")
print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")
if critical_points:
print(f"\n💥 MEMORY GROWTH SPIKES (>3GB):")
print(" Location Growth Total VRAM")
print(" " + "-" * 55)
for point in critical_points:
location = point['checkpoint'][:30].ljust(30)
print(f" {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB {point['vram_gb']:4.1f}GB")
if all_checkpoints:
print(f"\n📍 CHECKPOINT PROGRESSION:")
print(" Checkpoint Memory VRAM File")
print(" " + "-" * 55)
for cp in all_checkpoints:
checkpoint = cp['checkpoint'][:30].ljust(30)
file_num = cp['file_index'] + 1
print(f" {checkpoint} {cp['rss_gb']:5.1f}GB {cp['vram_gb']:4.1f}GB #{file_num}")
# Memory growth analysis
if len(all_checkpoints) > 1:
print(f"\n📊 MEMORY GROWTH ANALYSIS:")
# Find the biggest memory jumps between checkpoints
big_jumps = []
for i in range(1, len(all_checkpoints)):
prev_cp = all_checkpoints[i-1]
curr_cp = all_checkpoints[i]
growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
if growth > 2.0: # >2GB jump
big_jumps.append({
'from': prev_cp['checkpoint'],
'to': curr_cp['checkpoint'],
'growth': growth,
'from_memory': prev_cp['rss_gb'],
'to_memory': curr_cp['rss_gb']
})
if big_jumps:
print(" Major jumps (>2GB):")
for jump in big_jumps:
print(f" {jump['from']}{jump['to']}: "
f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f}{jump['to_memory']:.1f}GB)")
else:
print(" ✅ No major memory jumps detected")
# Diagnosis
print(f"\n🔬 DIAGNOSIS:")
if peak_memory > 400:
print(" 🔴 CRITICAL: Memory usage exceeded 400GB")
print(" 💡 Recommendation: Reduce chunk_size to 200-300 frames")
elif peak_memory > 200:
print(" 🟡 HIGH: Memory usage over 200GB")
print(" 💡 Recommendation: Reduce chunk_size to 400 frames")
else:
print(" 🟢 MODERATE: Memory usage under 200GB")
if critical_points:
# Find most common growth spike locations
spike_locations = {}
for point in critical_points:
location = point['checkpoint']
spike_locations[location] = spike_locations.get(location, 0) + 1
print("\n 🎯 Most problematic locations:")
for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
print(f" {location}: {count} spikes")
print(f"\n💡 NEXT STEPS:")
if 'merge' in str(critical_points).lower():
print(" 1. Chunk merging still causing memory accumulation")
print(" 2. Check if streaming merge is actually being used")
print(" 3. Verify chunk files are being deleted immediately")
elif 'propagation' in str(critical_points).lower():
print(" 1. SAM2 propagation using too much memory")
print(" 2. Reduce chunk_size further (try 300 frames)")
print(" 3. Enable more aggressive frame release")
else:
print(" 1. Review the checkpoint progression above")
print(" 2. Focus on locations with biggest memory spikes")
print(" 3. Consider reducing chunk_size if spikes are large")
def main():
print("🔍 MEMORY PROFILE ANALYZER")
print("Analyzing memory profile files for OOM causes...")
print()
analyze_memory_files()
if __name__ == "__main__":
main()

151
debug_memory_leak.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""
import psutil
import gc
from pathlib import Path
import sys
def detailed_memory_check(label):
"""Get detailed memory info"""
process = psutil.Process()
memory_info = process.memory_info()
rss_gb = memory_info.rss / (1024**3)
vms_gb = memory_info.vms / (1024**3)
# System memory
sys_memory = psutil.virtual_memory()
available_gb = sys_memory.available / (1024**3)
print(f"🔍 {label}:")
print(f" RSS: {rss_gb:.2f} GB (physical memory)")
print(f" VMS: {vms_gb:.2f} GB (virtual memory)")
print(f" Available: {available_gb:.2f} GB")
return rss_gb
def simulate_chunk_processing():
"""Simulate the chunk processing to see where memory accumulates"""
print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
print("=" * 60)
base_memory = detailed_memory_check("0. Baseline")
# Step 1: Import everything (with lazy loading)
print("\n📦 Step 1: Imports")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
import_memory = detailed_memory_check("1. After imports")
import_growth = import_memory - base_memory
print(f" Growth: +{import_growth:.2f} GB")
# Step 2: Load config
print("\n⚙️ Step 2: Config loading")
config = VR180Config.from_yaml('config.yaml')
config_memory = detailed_memory_check("2. After config load")
config_growth = config_memory - import_memory
print(f" Growth: +{config_growth:.2f} GB")
# Step 3: Initialize processor (models still lazy)
print("\n🏗️ Step 3: Processor initialization")
processor = VR180Processor(config)
processor_memory = detailed_memory_check("3. After processor init")
processor_growth = processor_memory - config_memory
print(f" Growth: +{processor_growth:.2f} GB")
# Step 4: Load video info (lightweight)
print("\n🎬 Step 4: Video info loading")
try:
video_info = processor.load_video_info(config.input.video_path)
print(f" Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
f"{video_info.get('total_frames', 'unknown')} frames")
except Exception as e:
print(f" Warning: Could not load video info: {e}")
video_info_memory = detailed_memory_check("4. After video info")
video_info_growth = video_info_memory - processor_memory
print(f" Growth: +{video_info_growth:.2f} GB")
# Step 5: Simulate chunk 0 processing (this is where models actually load)
print("\n🔄 Step 5: Simulating chunk 0 processing...")
# This is where the real memory usage starts
print(" Loading first 10 frames to trigger model loading...")
try:
# Read a small number of frames to trigger model loading
frames = processor.read_video_frames(
config.input.video_path,
start_frame=0,
num_frames=10, # Just 10 frames to trigger model loading
scale_factor=config.processing.scale_factor
)
frames_memory = detailed_memory_check("5a. After reading 10 frames")
frames_growth = frames_memory - video_info_memory
print(f" 10 frames growth: +{frames_growth:.2f} GB")
# Free frames
del frames
gc.collect()
after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
free_improvement = frames_memory - after_free_memory
print(f" Memory freed: -{free_improvement:.2f} GB")
except Exception as e:
print(f" Could not simulate frame loading: {e}")
after_free_memory = video_info_memory
print(f"\n📊 MEMORY ANALYSIS:")
print(f" Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
print(f" Total growth: +{after_free_memory - base_memory:.2f}GB")
if after_free_memory - base_memory > 10:
print(f" 🔴 HIGH: Memory growth > 10GB before any real processing")
print(f" 💡 This suggests model loading is using too much memory")
elif after_free_memory - base_memory > 5:
print(f" 🟡 MODERATE: Memory growth 5-10GB")
print(f" 💡 Normal for model loading, but monitor chunk processing")
else:
print(f" 🟢 GOOD: Memory growth < 5GB")
print(f" 💡 Initialization memory usage is reasonable")
print(f"\n🎯 KEY INSIGHTS:")
if import_growth > 1:
print(f" - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
if processor_growth > 10:
print(f" - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")
print(f"\n💡 RECOMMENDATIONS:")
if after_free_memory - base_memory > 15:
print(f" 1. Reduce chunk_size to 200-300 frames")
print(f" 2. Use smaller models (yolov8n instead of yolov8m)")
print(f" 3. Enable FP16 mode for SAM2")
elif after_free_memory - base_memory > 8:
print(f" 1. Monitor chunk processing carefully")
print(f" 2. Use streaming merge (should be automatic)")
print(f" 3. Current settings may be acceptable")
else:
print(f" 1. Settings look good for initialization")
print(f" 2. Focus on chunk processing memory leaks")
def main():
if len(sys.argv) != 2:
print("Usage: python debug_memory_leak.py <config.yaml>")
print("This simulates initialization to find memory leaks")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
simulate_chunk_processing()
if __name__ == "__main__":
main()

249
memory_profiler_script.py Normal file
View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""
import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json
class MemoryProfiler:
def __init__(self, output_file: str = "memory_profile.json"):
self.output_file = output_file
self.data = []
self.process = psutil.Process()
self.running = False
self.thread = None
self.checkpoint_counter = 0
def start_monitoring(self, interval: float = 1.0):
"""Start continuous memory monitoring"""
tracemalloc.start()
self.running = True
self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
self.thread.daemon = True
self.thread.start()
print(f"🔍 Memory monitoring started (interval: {interval}s)")
def stop_monitoring(self):
"""Stop monitoring and save results"""
self.running = False
if self.thread:
self.thread.join()
# Get tracemalloc snapshot
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
# Save detailed results
results = {
'timeline': self.data,
'top_memory_allocations': [
{
'file': stat.traceback.format()[0],
'size_mb': stat.size / 1024 / 1024,
'count': stat.count
}
for stat in top_stats[:20] # Top 20 allocations
],
'summary': {
'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
'total_samples': len(self.data)
}
}
with open(self.output_file, 'w') as f:
json.dump(results, f, indent=2)
tracemalloc.stop()
print(f"📊 Memory profile saved to {self.output_file}")
def _monitor_loop(self, interval: float):
"""Continuous monitoring loop"""
while self.running:
try:
# System memory
memory_info = self.process.memory_info()
rss_gb = memory_info.rss / (1024**3)
# System-wide memory
sys_memory = psutil.virtual_memory()
sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
sys_available_gb = sys_memory.available / (1024**3)
# GPU memory (if available)
vram_gb = 0
vram_free_gb = 0
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=5)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
if lines and lines[0]:
used, free = lines[0].split(', ')
vram_gb = float(used) / 1024
vram_free_gb = float(free) / 1024
except Exception:
pass
# Tracemalloc current usage
try:
current, peak = tracemalloc.get_traced_memory()
traced_mb = current / (1024**2)
except Exception:
traced_mb = 0
data_point = {
'timestamp': time.time(),
'rss_gb': rss_gb,
'vram_gb': vram_gb,
'vram_free_gb': vram_free_gb,
'sys_used_gb': sys_used_gb,
'sys_available_gb': sys_available_gb,
'traced_mb': traced_mb
}
self.data.append(data_point)
# Print periodic updates and save partial data
if len(self.data) % 10 == 0: # Every 10 samples
print(f"🔍 Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")
# Save partial data every 30 samples in case of crash
if len(self.data) % 30 == 0:
self._save_partial_data()
except Exception as e:
print(f"Monitoring error: {e}")
time.sleep(interval)
def _save_partial_data(self):
"""Save partial data to prevent loss on crash"""
try:
partial_file = f"memory_profile_partial_{self.checkpoint_counter}.json"
with open(partial_file, 'w') as f:
json.dump({
'timeline': self.data,
'status': 'partial_save',
'samples': len(self.data)
}, f, indent=2)
self.checkpoint_counter += 1
except Exception as e:
print(f"Failed to save partial data: {e}")
def log_checkpoint(self, checkpoint_name: str):
"""Log a specific checkpoint"""
if self.data:
self.data[-1]['checkpoint'] = checkpoint_name
latest = self.data[-1]
print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")
# Save checkpoint data immediately
self._save_partial_data()
def run_with_profiling(config_path: str):
"""Run the VR180 matting with memory profiling"""
profiler = MemoryProfiler("memory_profile_detailed.json")
try:
# Start monitoring
profiler.start_monitoring(interval=2.0) # Sample every 2 seconds
# Log initial state
profiler.log_checkpoint("STARTUP")
# Import after starting profiler to catch import memory usage
print("Importing VR180 processor...")
from vr180_matting.vr180_processor import VR180Processor
from vr180_matting.config import VR180Config
profiler.log_checkpoint("IMPORTS_COMPLETE")
# Load config
print(f"Loading config from {config_path}")
config = VR180Config.from_yaml(config_path)
profiler.log_checkpoint("CONFIG_LOADED")
# Initialize processor
print("Initializing VR180 processor...")
processor = VR180Processor(config)
profiler.log_checkpoint("PROCESSOR_INITIALIZED")
# Force garbage collection
gc.collect()
profiler.log_checkpoint("INITIAL_GC_COMPLETE")
# Run processing
print("Starting VR180 processing...")
processor.process_video()
profiler.log_checkpoint("PROCESSING_COMPLETE")
except Exception as e:
print(f"❌ Error during processing: {e}")
profiler.log_checkpoint(f"ERROR: {str(e)}")
raise
finally:
# Stop monitoring and save results
profiler.stop_monitoring()
# Print summary
print("\n" + "="*60)
print("MEMORY PROFILING SUMMARY")
print("="*60)
if profiler.data:
peak_rss = max([d['rss_gb'] for d in profiler.data])
peak_vram = max([d['vram_gb'] for d in profiler.data])
print(f"Peak RSS Memory: {peak_rss:.2f} GB")
print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
print(f"Total Samples: {len(profiler.data)}")
# Show checkpoints
checkpoints = [d for d in profiler.data if 'checkpoint' in d]
if checkpoints:
print(f"\nCheckpoints ({len(checkpoints)}):")
for cp in checkpoints:
print(f" {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")
print(f"\nDetailed profile saved to: {profiler.output_file}")
def main():
if len(sys.argv) != 2:
print("Usage: python memory_profiler_script.py <config.yaml>")
print("\nThis script runs VR180 matting with detailed memory profiling")
print("It will:")
print("- Monitor RSS, VRAM, and system memory every 2 seconds")
print("- Track memory allocations with tracemalloc")
print("- Log checkpoints at key processing stages")
print("- Save detailed JSON report for analysis")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"❌ Config file not found: {config_path}")
sys.exit(1)
print("🚀 Starting VR180 Memory Profiling")
print(f"Config: {config_path}")
print("="*60)
run_with_profiling(config_path)
if __name__ == "__main__":
main()

125
quick_memory_check.py Normal file
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""
import psutil
import subprocess
import sys
from pathlib import Path
def check_system():
"""Check system resources before starting"""
print("🔍 SYSTEM RESOURCE CHECK")
print("=" * 50)
# Memory info
memory = psutil.virtual_memory()
print(f"📊 RAM:")
print(f" Total: {memory.total / (1024**3):.1f} GB")
print(f" Available: {memory.available / (1024**3):.1f} GB")
print(f" Used: {(memory.total - memory.available) / (1024**3):.1f} GB ({memory.percent:.1f}%)")
# GPU info
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
'--format=csv,noheader,nounits'],
capture_output=True, text=True, timeout=10)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
print(f"\n🎮 GPU:")
for i, line in enumerate(lines):
if line.strip():
parts = line.split(', ')
if len(parts) >= 4:
name, total, used, free = parts[:4]
total_gb = float(total) / 1024
used_gb = float(used) / 1024
free_gb = float(free) / 1024
print(f" GPU {i}: {name}")
print(f" VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
print(f" Free: {free_gb:.1f} GB")
except Exception as e:
print(f"\n⚠️ Could not get GPU info: {e}")
# Disk space
disk = psutil.disk_usage('/')
print(f"\n💾 Disk (/):")
print(f" Total: {disk.total / (1024**3):.1f} GB")
print(f" Used: {disk.used / (1024**3):.1f} GB ({disk.used/disk.total*100:.1f}%)")
print(f" Free: {disk.free / (1024**3):.1f} GB")
# Check config file
if len(sys.argv) > 1:
config_path = sys.argv[1]
if Path(config_path).exists():
print(f"\n✅ Config file found: {config_path}")
# Try to load and show key settings
try:
import yaml
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
print(f"📋 Key Settings:")
if 'processing' in config:
proc = config['processing']
print(f" Chunk size: {proc.get('chunk_size', 'default')}")
print(f" Scale factor: {proc.get('scale_factor', 'default')}")
if 'hardware' in config:
hw = config['hardware']
print(f" Max VRAM: {hw.get('max_vram_gb', 'default')} GB")
if 'input' in config:
inp = config['input']
video_path = inp.get('video_path', '')
if video_path and Path(video_path).exists():
size_gb = Path(video_path).stat().st_size / (1024**3)
print(f" Input video: {video_path} ({size_gb:.1f} GB)")
else:
print(f" ⚠️ Input video not found: {video_path}")
except Exception as e:
print(f" ⚠️ Could not parse config: {e}")
else:
print(f"\n❌ Config file not found: {config_path}")
return False
# Memory safety warnings
print(f"\n⚠️ MEMORY SAFETY CHECKS:")
available_gb = memory.available / (1024**3)
if available_gb < 10:
print(f" 🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
print(" Consider: reducing chunk_size or scale_factor")
return False
elif available_gb < 20:
print(f" 🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
print(" Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
else:
print(f" 🟢 GOOD MEMORY: {available_gb:.1f}GB available")
print(f"\n" + "=" * 50)
return True
def main():
if len(sys.argv) != 2:
print("Usage: python quick_memory_check.py <config.yaml>")
print("\nThis checks system resources before running VR180 matting")
sys.exit(1)
safe_to_run = check_system()
if safe_to_run:
print("✅ System check passed - safe to run VR180 matting")
print("\nTo run with memory profiling:")
print(f" python memory_profiler_script.py {sys.argv[1]}")
print("\nTo run normally:")
print(f" vr180-matting {sys.argv[1]}")
else:
print("❌ System check failed - address issues before running")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -29,6 +29,11 @@ class MattingConfig:
fp16: bool = True
sam2_model_cfg: str = "sam2.1_hiera_l"
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
# Det-SAM2 optimizations
continuous_correction: bool = True
correction_interval: int = 60 # Add correction prompts every N frames
frame_release_interval: int = 50 # Release old frames every N frames
frame_window_size: int = 30 # Keep N frames in memory
@dataclass

View File

@@ -1,6 +1,4 @@
import torch
import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any
import cv2
@@ -13,14 +11,23 @@ class YOLODetector:
self.confidence_threshold = confidence_threshold
self.device = device
self.model = None
self._load_model()
# Don't load model during init - load lazily when first used
def _load_model(self):
"""Load YOLOv8 model"""
"""Load YOLOv8 model lazily"""
if self.model is not None:
return # Already loaded
try:
# Import heavy dependencies only when needed
import torch
from ultralytics import YOLO
self.model = YOLO(f"{self.model_name}.pt")
if self.device == "cuda" and torch.cuda.is_available():
self.model.to("cuda")
print(f"🎯 Loaded YOLO model: {self.model_name}")
except Exception as e:
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
@@ -34,8 +41,9 @@ class YOLODetector:
Returns:
List of detection dictionaries with bbox, confidence, and class info
"""
# Load model lazily on first use
if self.model is None:
raise RuntimeError("YOLO model not loaded")
self._load_model()
results = self.model(frame, verbose=False)
detections = []

View File

@@ -9,12 +9,16 @@ import tempfile
import shutil
import gc
# Check SAM2 availability without importing heavy modules
def _check_sam2_available():
try:
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
import sam2
return True
except ImportError:
SAM2_AVAILABLE = False
return False
SAM2_AVAILABLE = _check_sam2_available()
if not SAM2_AVAILABLE:
warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
self.video_segments = {}
self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path)
# Don't load model during init - load lazily when needed
self._model_loaded = False
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations"""
"""Load SAM2 video predictor lazily"""
if self._model_loaded:
return # Already loaded
try:
# Import heavy SAM2 modules only when needed
from sam2.build_sam import build_sam2_video_predictor
# Check for checkpoint in SAM2 repo structure
if not Path(checkpoint_path).exists():
# Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
if sam2_repo_path.exists():
checkpoint_path = str(sam2_repo_path)
print(f"🎯 Loading SAM2 model: {model_cfg}")
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
# The predictor IS the model - no .model attribute needed
self.predictor = build_sam2_video_predictor(
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
device=self.device
)
self._model_loaded = True
print(f"✅ SAM2 model loaded successfully")
except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
"""Initialize video inference state"""
if self.predictor is None:
# Recreate predictor if it was cleaned up
# Load model lazily on first use
if not self._model_loaded:
self._load_model(self.model_cfg, self.checkpoint_path)
if video_path is not None:
@@ -152,13 +167,16 @@ class SAM2VideoMatting:
return object_ids
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Propagate masks through video
Propagate masks through video with Det-SAM2 style memory management
Args:
start_frame: Starting frame index
max_frames: Maximum number of frames to process
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
@@ -182,9 +200,108 @@ class SAM2VideoMatting:
video_segments[out_frame_idx] = frame_masks
# Memory management: release old frames periodically
if self.memory_offload and out_frame_idx % 100 == 0:
self._release_old_frames(out_frame_idx - 50)
# Det-SAM2 style memory management: more aggressive frame release
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
return video_segments
def propagate_masks_with_continuous_correction(self,
detector,
temp_video_path: str,
start_frame: int = 0,
max_frames: Optional[int] = None,
correction_interval: int = 60,
frame_release_interval: int = 50,
frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
"""
Det-SAM2 style: Propagate masks with continuous prompt correction
Args:
detector: YOLODetector instance for generating correction prompts
temp_video_path: Path to video file for frame access
start_frame: Starting frame index
max_frames: Maximum number of frames to process
correction_interval: Add correction prompts every N frames
frame_release_interval: Release old frames every N frames
frame_window_size: Keep N frames in memory
Returns:
Dictionary mapping frame_idx -> {obj_id: mask}
"""
if self.inference_state is None:
raise RuntimeError("Video state not initialized")
video_segments = {}
max_frames = max_frames or 10000 # Default limit
# Open video for accessing frames during propagation
cap = cv2.VideoCapture(str(temp_video_path))
try:
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=start_frame,
max_frame_num_to_track=max_frames,
reverse=False
):
frame_masks = {}
for i, out_obj_id in enumerate(out_obj_ids):
mask = (out_mask_logits[i] > 0.0).cpu().numpy()
frame_masks[out_obj_id] = mask
video_segments[out_frame_idx] = frame_masks
# Det-SAM2 optimization: Add correction prompts at keyframes
if (out_frame_idx % correction_interval == 0 and
out_frame_idx > start_frame and
out_frame_idx < max_frames - 1):
# Read frame for detection
cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
ret, correction_frame = cap.read()
if ret:
# Run detection on this keyframe
detections = detector.detect_persons(correction_frame)
if detections:
# Convert to prompts and add as corrections
box_prompts, labels = detector.convert_to_sam_prompts(detections)
# Add correction prompts (SAM2 will propagate backward)
correction_count = 0
try:
for i, (box, label) in enumerate(zip(box_prompts, labels)):
# Use existing object IDs if available, otherwise create new ones
obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=out_frame_idx,
obj_id=obj_id,
box=box,
)
correction_count += 1
print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
except Exception as e:
warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
# Memory management: More aggressive frame release (Det-SAM2 style)
if self.memory_offload and out_frame_idx % frame_release_interval == 0:
self._release_old_frames(out_frame_idx - frame_window_size)
# Optional: Log frame release for monitoring
if out_frame_idx % (frame_release_interval * 4) == 0: # Log every 4x release interval
print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
finally:
cap.release()
return video_segments

View File

@@ -387,19 +387,83 @@ class VideoProcessor:
# Green screen background
return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
overlap_frames: int = 0, audio_source: str = None) -> None:
"""
Merge processed chunks using streaming approach (no memory accumulation)
Args:
chunk_files: List of chunk result files (.npz)
output_path: Final output video path
overlap_frames: Number of overlapping frames
audio_source: Audio source file for final video
"""
from .streaming_video_writer import StreamingVideoWriter
if not chunk_files:
raise ValueError("No chunk files to merge")
print(f"🎬 Streaming merge: {len(chunk_files)} chunks → {output_path}")
# Initialize streaming writer
writer = StreamingVideoWriter(
output_path=output_path,
fps=self.video_info['fps'],
audio_source=audio_source
)
try:
# Process each chunk without accumulation
for i, chunk_file in enumerate(chunk_files):
print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
# Load chunk (this is the only copy in memory)
chunk_data = np.load(str(chunk_file))
frames = chunk_data['frames'].tolist() # Convert to list of arrays
chunk_data.close()
# Write chunk with streaming writer
writer.write_chunk(
frames=frames,
chunk_index=i,
overlap_frames=overlap_frames if i > 0 else 0,
blend_with_previous=(i > 0 and overlap_frames > 0)
)
# Immediately free memory
del frames, chunk_data
# Delete chunk file to free disk space
try:
chunk_file.unlink()
print(f" 🗑️ Deleted {chunk_file.name}")
except Exception as e:
print(f" ⚠️ Could not delete {chunk_file.name}: {e}")
# Aggressive cleanup every chunk
self._aggressive_memory_cleanup(f"After processing chunk {i}")
# Finalize the video
writer.finalize()
except Exception as e:
print(f"❌ Streaming merge failed: {e}")
writer.cleanup()
raise
print(f"✅ Streaming merge complete: {output_path}")
def merge_overlapping_chunks(self,
chunk_results: List[List[np.ndarray]],
overlap_frames: int) -> List[np.ndarray]:
"""
Merge overlapping chunks with blending in overlap regions
Args:
chunk_results: List of chunk results
overlap_frames: Number of overlapping frames
Returns:
Merged frame sequence
Legacy merge method - DEPRECATED due to memory accumulation
Use merge_chunks_streaming() instead for memory efficiency
"""
import warnings
warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
DeprecationWarning, stacklevel=2)
if len(chunk_results) == 1:
return chunk_results[0]
@@ -640,36 +704,23 @@ class VideoProcessor:
if self.memory_manager.should_emergency_cleanup():
self.memory_manager.emergency_cleanup()
# Load and merge chunks from disk
print("\nLoading and merging chunks...")
chunk_results = []
for i, chunk_file in enumerate(chunk_files):
print(f"Loading {chunk_file.name}...")
chunk_data = np.load(str(chunk_file))
chunk_results.append(chunk_data['frames'])
chunk_data.close() # Close the file
# Use streaming merge to avoid memory accumulation (fixes OOM)
print("\n🎬 Using streaming merge (no memory accumulation)...")
# Delete chunk file immediately after loading to free disk space
try:
chunk_file.unlink()
print(f" Deleted chunk file {chunk_file.name}")
except Exception as e:
print(f" Warning: Could not delete chunk file: {e}")
# Determine audio source for final video
audio_source = None
if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
audio_source = self.config.input.video_path
# Aggressive cleanup every few chunks to prevent accumulation
if i % 3 == 0 and i > 0:
self._aggressive_memory_cleanup(f"after loading chunk {i}")
# Stream merge chunks directly to output (no memory accumulation)
self.merge_chunks_streaming(
chunk_files=chunk_files,
output_path=self.config.output.path,
overlap_frames=overlap_frames,
audio_source=audio_source
)
# Merge chunks
final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
# Free chunk results after merging - this is critical!
del chunk_results
self._aggressive_memory_cleanup("after merging chunks")
# Save results
print(f"Saving {len(final_frames)} processed frames...")
self.save_video(final_frames, self.config.output.path)
print("✅ Streaming merge complete - no memory accumulation!")
# Calculate final statistics
self.processing_stats['end_time'] = time.time()

View File

@@ -375,31 +375,43 @@ class VR180Processor(VideoProcessor):
# Propagate masks (most expensive operation)
self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
# Use Det-SAM2 continuous correction if enabled
if self.config.matting.continuous_correction:
video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
detector=self.detector,
temp_video_path=str(temp_video_path),
start_frame=0,
max_frames=num_frames,
correction_interval=self.config.matting.correction_interval,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
else:
video_segments = self.sam2_model.propagate_masks(
start_frame=0,
max_frames=num_frames
max_frames=num_frames,
frame_release_interval=self.config.matting.frame_release_interval,
frame_window_size=self.config.matting.frame_window_size
)
self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
# Apply masks - need to reload frames from temp video since we freed the original frames
self._print_memory_step(f"Before reloading frames for mask application ({eye_name} eye)")
# Apply masks with streaming approach (no frame accumulation)
self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
# Read frames back from the temp video for mask application
# Process frames one at a time without accumulation
cap = cv2.VideoCapture(str(temp_video_path))
reloaded_frames = []
matted_frames = []
try:
for frame_idx in range(num_frames):
ret, frame = cap.read()
if not ret:
break
reloaded_frames.append(frame)
cap.release()
self._print_memory_step(f"Reloaded {len(reloaded_frames)} frames for mask application")
# Apply masks
matted_frames = []
for frame_idx, frame in enumerate(reloaded_frames):
# Apply mask to this single frame
if frame_idx in video_segments:
frame_masks = video_segments[frame_idx]
combined_mask = self.sam2_model.get_combined_mask(frame_masks)
@@ -414,11 +426,22 @@ class VR180Processor(VideoProcessor):
matted_frames.append(matted_frame)
# Free reloaded frames and video segments completely
del reloaded_frames
del video_segments # This holds processed masks from SAM2
self._aggressive_memory_cleanup(f"After mask application ({eye_name} eye)")
# Free the original frame immediately (no accumulation)
del frame
# Periodic cleanup during processing
if frame_idx % 100 == 0 and frame_idx > 0:
import gc
gc.collect()
finally:
cap.release()
# Free video segments completely
del video_segments # This holds processed masks from SAM2
self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
return matted_frames
finally: