Compare commits: cuda ... 36f58acb8b

6 commits: 36f58acb8b, fb51e82fd4, 9f572d4430, ba8706b7ae, 734445cf48, 80f947c91b
analyze_memory_profile.py (new file, 193 lines)
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""

import json
import glob
import os
import sys
from pathlib import Path


def analyze_memory_files():
    """Analyze partial memory profile files"""

    # Get all partial files in order
    files = sorted(glob.glob('memory_profile_partial_*.json'))

    if not files:
        print("❌ No memory profile files found!")
        print("Expected files like: memory_profile_partial_0.json")
        return

    print(f"🔍 Found {len(files)} memory profile files")
    print("=" * 60)

    peak_memory = 0
    peak_vram = 0
    critical_points = []
    all_checkpoints = []

    for i, file in enumerate(files):
        try:
            with open(file, 'r') as f:
                data = json.load(f)

            timeline = data.get('timeline', [])
            if not timeline:
                continue

            # Find peaks in this file
            file_peak_rss = max([d['rss_gb'] for d in timeline])
            file_peak_vram = max([d['vram_gb'] for d in timeline])

            if file_peak_rss > peak_memory:
                peak_memory = file_peak_rss
            if file_peak_vram > peak_vram:
                peak_vram = file_peak_vram

            # Find memory growth spikes (>3GB increase)
            for j in range(1, len(timeline)):
                prev_rss = timeline[j-1]['rss_gb']
                curr_rss = timeline[j]['rss_gb']
                growth = curr_rss - prev_rss

                if growth > 3.0:  # >3GB growth spike
                    checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
                    critical_points.append({
                        'file': file,
                        'file_index': i,
                        'sample': j,
                        'timestamp': timeline[j]['timestamp'],
                        'rss_gb': curr_rss,
                        'vram_gb': timeline[j]['vram_gb'],
                        'growth_gb': growth,
                        'checkpoint': checkpoint
                    })

            # Collect all checkpoints
            checkpoints = [d for d in timeline if 'checkpoint' in d]
            for cp in checkpoints:
                cp['file'] = file
                cp['file_index'] = i
                all_checkpoints.append(cp)

            # Show progress for this file
            if timeline:
                start_rss = timeline[0]['rss_gb']
                end_rss = timeline[-1]['rss_gb']
                growth = end_rss - start_rss
                samples = len(timeline)

                print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
                      f"(+{growth:4.1f}GB) [{samples:3d} samples]")

                # Show significant checkpoints from this file
                if checkpoints:
                    for cp in checkpoints:
                        print(f"   📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")

        except Exception as e:
            print(f"❌ Error reading {file}: {e}")

    print("\n" + "=" * 60)
    print("🎯 ANALYSIS SUMMARY")
    print("=" * 60)

    print(f"📈 Peak Memory: {peak_memory:.1f} GB")
    print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
    print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")

    if critical_points:
        print("\n💥 MEMORY GROWTH SPIKES (>3GB):")
        print("   Location                       Growth    Total    VRAM")
        print("   " + "-" * 55)

        for point in critical_points:
            location = point['checkpoint'][:30].ljust(30)
            print(f"   {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB  {point['vram_gb']:4.1f}GB")

    if all_checkpoints:
        print("\n📍 CHECKPOINT PROGRESSION:")
        print("   Checkpoint                     Memory   VRAM   File")
        print("   " + "-" * 55)

        for cp in all_checkpoints:
            checkpoint = cp['checkpoint'][:30].ljust(30)
            file_num = cp['file_index'] + 1
            print(f"   {checkpoint} {cp['rss_gb']:5.1f}GB  {cp['vram_gb']:4.1f}GB  #{file_num}")

    # Memory growth analysis
    if len(all_checkpoints) > 1:
        print("\n📊 MEMORY GROWTH ANALYSIS:")

        # Find the biggest memory jumps between checkpoints
        big_jumps = []
        for i in range(1, len(all_checkpoints)):
            prev_cp = all_checkpoints[i-1]
            curr_cp = all_checkpoints[i]

            growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
            if growth > 2.0:  # >2GB jump
                big_jumps.append({
                    'from': prev_cp['checkpoint'],
                    'to': curr_cp['checkpoint'],
                    'growth': growth,
                    'from_memory': prev_cp['rss_gb'],
                    'to_memory': curr_cp['rss_gb']
                })

        if big_jumps:
            print("   Major jumps (>2GB):")
            for jump in big_jumps:
                print(f"     {jump['from']} → {jump['to']}: "
                      f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f}→{jump['to_memory']:.1f}GB)")
        else:
            print("   ✅ No major memory jumps detected")

    # Diagnosis
    print("\n🔬 DIAGNOSIS:")

    if peak_memory > 400:
        print("   🔴 CRITICAL: Memory usage exceeded 400GB")
        print("   💡 Recommendation: Reduce chunk_size to 200-300 frames")
    elif peak_memory > 200:
        print("   🟡 HIGH: Memory usage over 200GB")
        print("   💡 Recommendation: Reduce chunk_size to 400 frames")
    else:
        print("   🟢 MODERATE: Memory usage under 200GB")

    if critical_points:
        # Find most common growth spike locations
        spike_locations = {}
        for point in critical_points:
            location = point['checkpoint']
            spike_locations[location] = spike_locations.get(location, 0) + 1

        print("\n   🎯 Most problematic locations:")
        for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
            print(f"     {location}: {count} spikes")

    print("\n💡 NEXT STEPS:")
    if 'merge' in str(critical_points).lower():
        print("   1. Chunk merging still causing memory accumulation")
        print("   2. Check if streaming merge is actually being used")
        print("   3. Verify chunk files are being deleted immediately")
    elif 'propagation' in str(critical_points).lower():
        print("   1. SAM2 propagation using too much memory")
        print("   2. Reduce chunk_size further (try 300 frames)")
        print("   3. Enable more aggressive frame release")
    else:
        print("   1. Review the checkpoint progression above")
        print("   2. Focus on locations with biggest memory spikes")
        print("   3. Consider reducing chunk_size if spikes are large")


def main():
    print("🔍 MEMORY PROFILE ANALYZER")
    print("Analyzing memory profile files for OOM causes...")
    print()

    analyze_memory_files()


if __name__ == "__main__":
    main()
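The analyzer only depends on the JSON shape written by MemoryProfiler._save_partial_data() in memory_profiler_script.py (further down in this change), so it can be exercised without a real pipeline run. A minimal sketch that fabricates one partial file containing a synthetic >3 GB spike; the checkpoint label is made up for illustration:

import json
import time

# Build a 10-sample timeline with the fields analyze_memory_files() reads
timeline = [
    {"timestamp": time.time() + t, "rss_gb": 20.0 + 0.5 * t, "vram_gb": 8.0}
    for t in range(10)
]
timeline[5]["rss_gb"] += 4.0                     # inject a >3GB growth spike
timeline[5]["checkpoint"] = "CHUNK_MERGE_START"  # hypothetical checkpoint name

with open("memory_profile_partial_0.json", "w") as f:
    json.dump({"timeline": timeline}, f, indent=2)

Running python analyze_memory_profile.py in the same directory should then report exactly one growth spike at that checkpoint.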
debug_memory_leak.py (new file, 151 lines)
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""

import psutil
import gc
from pathlib import Path
import sys


def detailed_memory_check(label):
    """Get detailed memory info"""
    process = psutil.Process()
    memory_info = process.memory_info()

    rss_gb = memory_info.rss / (1024**3)
    vms_gb = memory_info.vms / (1024**3)

    # System memory
    sys_memory = psutil.virtual_memory()
    available_gb = sys_memory.available / (1024**3)

    print(f"🔍 {label}:")
    print(f"   RSS: {rss_gb:.2f} GB (physical memory)")
    print(f"   VMS: {vms_gb:.2f} GB (virtual memory)")
    print(f"   Available: {available_gb:.2f} GB")

    return rss_gb


def simulate_chunk_processing(config_path: str):
    """Simulate the chunk processing to see where memory accumulates"""

    print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
    print("=" * 60)

    base_memory = detailed_memory_check("0. Baseline")

    # Step 1: Import everything (with lazy loading)
    print("\n📦 Step 1: Imports")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor

    import_memory = detailed_memory_check("1. After imports")
    import_growth = import_memory - base_memory
    print(f"   Growth: +{import_growth:.2f} GB")

    # Step 2: Load config (use the path passed in from the command line)
    print("\n⚙️ Step 2: Config loading")
    config = VR180Config.from_yaml(config_path)
    config_memory = detailed_memory_check("2. After config load")
    config_growth = config_memory - import_memory
    print(f"   Growth: +{config_growth:.2f} GB")

    # Step 3: Initialize processor (models still lazy)
    print("\n🏗️ Step 3: Processor initialization")
    processor = VR180Processor(config)
    processor_memory = detailed_memory_check("3. After processor init")
    processor_growth = processor_memory - config_memory
    print(f"   Growth: +{processor_growth:.2f} GB")

    # Step 4: Load video info (lightweight)
    print("\n🎬 Step 4: Video info loading")
    try:
        video_info = processor.load_video_info(config.input.video_path)
        print(f"   Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
              f"{video_info.get('total_frames', 'unknown')} frames")
    except Exception as e:
        print(f"   Warning: Could not load video info: {e}")

    video_info_memory = detailed_memory_check("4. After video info")
    video_info_growth = video_info_memory - processor_memory
    print(f"   Growth: +{video_info_growth:.2f} GB")

    # Step 5: Simulate chunk 0 processing (this is where models actually load)
    print("\n🔄 Step 5: Simulating chunk 0 processing...")

    # This is where the real memory usage starts
    print("   Loading first 10 frames to trigger model loading...")
    try:
        # Read a small number of frames to trigger model loading
        frames = processor.read_video_frames(
            config.input.video_path,
            start_frame=0,
            num_frames=10,  # Just 10 frames to trigger model loading
            scale_factor=config.processing.scale_factor
        )

        frames_memory = detailed_memory_check("5a. After reading 10 frames")
        frames_growth = frames_memory - video_info_memory
        print(f"   10 frames growth: +{frames_growth:.2f} GB")

        # Free frames
        del frames
        gc.collect()

        after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
        free_improvement = frames_memory - after_free_memory
        print(f"   Memory freed: -{free_improvement:.2f} GB")

    except Exception as e:
        print(f"   Could not simulate frame loading: {e}")
        after_free_memory = video_info_memory

    print("\n📊 MEMORY ANALYSIS:")
    print(f"   Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
    print(f"   Total growth: +{after_free_memory - base_memory:.2f}GB")

    if after_free_memory - base_memory > 10:
        print("   🔴 HIGH: Memory growth > 10GB before any real processing")
        print("   💡 This suggests model loading is using too much memory")
    elif after_free_memory - base_memory > 5:
        print("   🟡 MODERATE: Memory growth 5-10GB")
        print("   💡 Normal for model loading, but monitor chunk processing")
    else:
        print("   🟢 GOOD: Memory growth < 5GB")
        print("   💡 Initialization memory usage is reasonable")

    print("\n🎯 KEY INSIGHTS:")
    if import_growth > 1:
        print(f"   - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
    if processor_growth > 10:
        print(f"   - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")

    print("\n💡 RECOMMENDATIONS:")
    if after_free_memory - base_memory > 15:
        print("   1. Reduce chunk_size to 200-300 frames")
        print("   2. Use smaller models (yolov8n instead of yolov8m)")
        print("   3. Enable FP16 mode for SAM2")
    elif after_free_memory - base_memory > 8:
        print("   1. Monitor chunk processing carefully")
        print("   2. Use streaming merge (should be automatic)")
        print("   3. Current settings may be acceptable")
    else:
        print("   1. Settings look good for initialization")
        print("   2. Focus on chunk processing memory leaks")


def main():
    if len(sys.argv) != 2:
        print("Usage: python debug_memory_leak.py <config.yaml>")
        print("This simulates initialization to find memory leaks")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    simulate_chunk_processing(config_path)


if __name__ == "__main__":
    main()
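The before/after pattern in detailed_memory_check() generalizes to any suspect step. A small self-contained helper in the same spirit; this is a sketch, not part of the repository:

import psutil
from contextlib import contextmanager

@contextmanager
def rss_delta(label: str):
    """Print the resident-memory delta of the enclosed block."""
    proc = psutil.Process()
    before = proc.memory_info().rss / (1024**3)
    yield
    after = proc.memory_info().rss / (1024**3)
    print(f"{label}: {after - before:+.2f} GB (now {after:.2f} GB)")

# e.g. around the frame-reading step in simulate_chunk_processing():
#     with rss_delta("read 10 frames"):
#         frames = processor.read_video_frames(...)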
memory_profiler_script.py (new file, 249 lines)
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""

import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json


class MemoryProfiler:
    def __init__(self, output_file: str = "memory_profile.json"):
        self.output_file = output_file
        self.data = []
        self.process = psutil.Process()
        self.running = False
        self.thread = None
        self.checkpoint_counter = 0

    def start_monitoring(self, interval: float = 1.0):
        """Start continuous memory monitoring"""
        tracemalloc.start()
        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
        self.thread.daemon = True
        self.thread.start()
        print(f"🔍 Memory monitoring started (interval: {interval}s)")

    def stop_monitoring(self):
        """Stop monitoring and save results"""
        self.running = False
        if self.thread:
            self.thread.join()

        # Get tracemalloc snapshot
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')

        # Save detailed results
        results = {
            'timeline': self.data,
            'top_memory_allocations': [
                {
                    'file': stat.traceback.format()[0],
                    'size_mb': stat.size / 1024 / 1024,
                    'count': stat.count
                }
                for stat in top_stats[:20]  # Top 20 allocations
            ],
            'summary': {
                'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
                'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
                'total_samples': len(self.data)
            }
        }

        with open(self.output_file, 'w') as f:
            json.dump(results, f, indent=2)

        tracemalloc.stop()
        print(f"📊 Memory profile saved to {self.output_file}")

    def _monitor_loop(self, interval: float):
        """Continuous monitoring loop"""
        while self.running:
            try:
                # Process memory
                memory_info = self.process.memory_info()
                rss_gb = memory_info.rss / (1024**3)

                # System-wide memory
                sys_memory = psutil.virtual_memory()
                sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
                sys_available_gb = sys_memory.available / (1024**3)

                # GPU memory (if available)
                vram_gb = 0
                vram_free_gb = 0
                try:
                    result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
                                             '--format=csv,noheader,nounits'],
                                            capture_output=True, text=True, timeout=5)
                    if result.returncode == 0:
                        lines = result.stdout.strip().split('\n')
                        if lines and lines[0]:
                            used, free = lines[0].split(', ')
                            vram_gb = float(used) / 1024
                            vram_free_gb = float(free) / 1024
                except Exception:
                    pass

                # Tracemalloc current usage
                try:
                    current, peak = tracemalloc.get_traced_memory()
                    traced_mb = current / (1024**2)
                except Exception:
                    traced_mb = 0

                data_point = {
                    'timestamp': time.time(),
                    'rss_gb': rss_gb,
                    'vram_gb': vram_gb,
                    'vram_free_gb': vram_free_gb,
                    'sys_used_gb': sys_used_gb,
                    'sys_available_gb': sys_available_gb,
                    'traced_mb': traced_mb
                }

                self.data.append(data_point)

                # Print periodic updates and save partial data
                if len(self.data) % 10 == 0:  # Every 10 samples
                    print(f"🔍 Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")

                # Save partial data every 30 samples in case of crash
                if len(self.data) % 30 == 0:
                    self._save_partial_data()

            except Exception as e:
                print(f"Monitoring error: {e}")

            time.sleep(interval)

    def _save_partial_data(self):
        """Save partial data to prevent loss on crash"""
        try:
            partial_file = f"memory_profile_partial_{self.checkpoint_counter}.json"
            with open(partial_file, 'w') as f:
                json.dump({
                    'timeline': self.data,
                    'status': 'partial_save',
                    'samples': len(self.data)
                }, f, indent=2)
            self.checkpoint_counter += 1
        except Exception as e:
            print(f"Failed to save partial data: {e}")

    def log_checkpoint(self, checkpoint_name: str):
        """Log a specific checkpoint"""
        if self.data:
            self.data[-1]['checkpoint'] = checkpoint_name
            latest = self.data[-1]
            print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")

        # Save checkpoint data immediately
        self._save_partial_data()


def run_with_profiling(config_path: str):
    """Run the VR180 matting with memory profiling"""
    profiler = MemoryProfiler("memory_profile_detailed.json")

    try:
        # Start monitoring
        profiler.start_monitoring(interval=2.0)  # Sample every 2 seconds

        # Log initial state
        profiler.log_checkpoint("STARTUP")

        # Import after starting profiler to catch import memory usage
        print("Importing VR180 processor...")
        from vr180_matting.vr180_processor import VR180Processor
        from vr180_matting.config import VR180Config

        profiler.log_checkpoint("IMPORTS_COMPLETE")

        # Load config
        print(f"Loading config from {config_path}")
        config = VR180Config.from_yaml(config_path)

        profiler.log_checkpoint("CONFIG_LOADED")

        # Initialize processor
        print("Initializing VR180 processor...")
        processor = VR180Processor(config)

        profiler.log_checkpoint("PROCESSOR_INITIALIZED")

        # Force garbage collection
        gc.collect()
        profiler.log_checkpoint("INITIAL_GC_COMPLETE")

        # Run processing
        print("Starting VR180 processing...")
        processor.process_video()

        profiler.log_checkpoint("PROCESSING_COMPLETE")

    except Exception as e:
        print(f"❌ Error during processing: {e}")
        profiler.log_checkpoint(f"ERROR: {str(e)}")
        raise
    finally:
        # Stop monitoring and save results
        profiler.stop_monitoring()

        # Print summary
        print("\n" + "=" * 60)
        print("MEMORY PROFILING SUMMARY")
        print("=" * 60)

        if profiler.data:
            peak_rss = max([d['rss_gb'] for d in profiler.data])
            peak_vram = max([d['vram_gb'] for d in profiler.data])

            print(f"Peak RSS Memory: {peak_rss:.2f} GB")
            print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
            print(f"Total Samples: {len(profiler.data)}")

            # Show checkpoints
            checkpoints = [d for d in profiler.data if 'checkpoint' in d]
            if checkpoints:
                print(f"\nCheckpoints ({len(checkpoints)}):")
                for cp in checkpoints:
                    print(f"  {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")

        print(f"\nDetailed profile saved to: {profiler.output_file}")


def main():
    if len(sys.argv) != 2:
        print("Usage: python memory_profiler_script.py <config.yaml>")
        print("\nThis script runs VR180 matting with detailed memory profiling")
        print("It will:")
        print("- Monitor RSS, VRAM, and system memory every 2 seconds")
        print("- Track memory allocations with tracemalloc")
        print("- Log checkpoints at key processing stages")
        print("- Save detailed JSON report for analysis")
        sys.exit(1)

    config_path = sys.argv[1]

    if not Path(config_path).exists():
        print(f"❌ Config file not found: {config_path}")
        sys.exit(1)

    print("🚀 Starting VR180 Memory Profiling")
    print(f"Config: {config_path}")
    print("=" * 60)

    run_with_profiling(config_path)


if __name__ == "__main__":
    main()
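Because MemoryProfiler has no pipeline dependencies, it can also be embedded around any stage of your own code. A minimal sketch with a stand-in workload:

def do_heavy_work():
    # Stand-in workload: allocate roughly 100 MB so the timeline shows a bump
    return [bytearray(10**7) for _ in range(10)]

profiler = MemoryProfiler("my_profile.json")
profiler.start_monitoring(interval=0.5)
try:
    data = do_heavy_work()
    profiler.log_checkpoint("HEAVY_WORK_DONE")  # annotates the latest sample
finally:
    profiler.stop_monitoring()  # writes timeline, summary, and tracemalloc top-20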
quick_memory_check.py (new file, 125 lines)
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""

import psutil
import subprocess
import sys
from pathlib import Path


def check_system():
    """Check system resources before starting"""
    print("🔍 SYSTEM RESOURCE CHECK")
    print("=" * 50)

    # Memory info
    memory = psutil.virtual_memory()
    print("📊 RAM:")
    print(f"   Total: {memory.total / (1024**3):.1f} GB")
    print(f"   Available: {memory.available / (1024**3):.1f} GB")
    print(f"   Used: {(memory.total - memory.available) / (1024**3):.1f} GB ({memory.percent:.1f}%)")

    # GPU info
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
                                 '--format=csv,noheader,nounits'],
                                capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            lines = result.stdout.strip().split('\n')
            print("\n🎮 GPU:")
            for i, line in enumerate(lines):
                if line.strip():
                    parts = line.split(', ')
                    if len(parts) >= 4:
                        name, total, used, free = parts[:4]
                        total_gb = float(total) / 1024
                        used_gb = float(used) / 1024
                        free_gb = float(free) / 1024
                        print(f"   GPU {i}: {name}")
                        print(f"   VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
                        print(f"   Free: {free_gb:.1f} GB")
    except Exception as e:
        print(f"\n⚠️ Could not get GPU info: {e}")

    # Disk space
    disk = psutil.disk_usage('/')
    print("\n💾 Disk (/):")
    print(f"   Total: {disk.total / (1024**3):.1f} GB")
    print(f"   Used: {disk.used / (1024**3):.1f} GB ({disk.used/disk.total*100:.1f}%)")
    print(f"   Free: {disk.free / (1024**3):.1f} GB")

    # Check config file
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
        if Path(config_path).exists():
            print(f"\n✅ Config file found: {config_path}")

            # Try to load and show key settings
            try:
                import yaml
                with open(config_path, 'r') as f:
                    config = yaml.safe_load(f)

                print("📋 Key Settings:")
                if 'processing' in config:
                    proc = config['processing']
                    print(f"   Chunk size: {proc.get('chunk_size', 'default')}")
                    print(f"   Scale factor: {proc.get('scale_factor', 'default')}")

                if 'hardware' in config:
                    hw = config['hardware']
                    print(f"   Max VRAM: {hw.get('max_vram_gb', 'default')} GB")

                if 'input' in config:
                    inp = config['input']
                    video_path = inp.get('video_path', '')
                    if video_path and Path(video_path).exists():
                        size_gb = Path(video_path).stat().st_size / (1024**3)
                        print(f"   Input video: {video_path} ({size_gb:.1f} GB)")
                    else:
                        print(f"   ⚠️ Input video not found: {video_path}")

            except Exception as e:
                print(f"   ⚠️ Could not parse config: {e}")
        else:
            print(f"\n❌ Config file not found: {config_path}")
            return False

    # Memory safety warnings
    print("\n⚠️ MEMORY SAFETY CHECKS:")
    available_gb = memory.available / (1024**3)

    if available_gb < 10:
        print(f"   🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
        print("      Consider: reducing chunk_size or scale_factor")
        return False
    elif available_gb < 20:
        print(f"   🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
        print("      Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
    else:
        print(f"   🟢 GOOD MEMORY: {available_gb:.1f}GB available")

    print("\n" + "=" * 50)
    return True


def main():
    if len(sys.argv) != 2:
        print("Usage: python quick_memory_check.py <config.yaml>")
        print("\nThis checks system resources before running VR180 matting")
        sys.exit(1)

    safe_to_run = check_system()

    if safe_to_run:
        print("✅ System check passed - safe to run VR180 matting")
        print("\nTo run with memory profiling:")
        print(f"   python memory_profiler_script.py {sys.argv[1]}")
        print("\nTo run normally:")
        print(f"   vr180-matting {sys.argv[1]}")
    else:
        print("❌ System check failed - address issues before running")
        sys.exit(1)


if __name__ == "__main__":
    main()
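For reference, a config.yaml shape that satisfies every key this script inspects. The values and the video path are illustrative only; the authoritative schema is whatever VR180Config defines, which is not shown in this change:

import yaml

cfg = yaml.safe_load("""
input:
  video_path: footage/vr180_sample.mp4   # hypothetical path
processing:
  chunk_size: 300
  scale_factor: 0.5
hardware:
  max_vram_gb: 24
""")
assert cfg["processing"]["chunk_size"] == 300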
@@ -29,6 +29,11 @@ class MattingConfig:
     fp16: bool = True
     sam2_model_cfg: str = "sam2.1_hiera_l"
     sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"

+    # Det-SAM2 optimizations
+    continuous_correction: bool = True
+    correction_interval: int = 60  # Add correction prompts every N frames
+    frame_release_interval: int = 50  # Release old frames every N frames
+    frame_window_size: int = 30  # Keep N frames in memory

 @dataclass
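The four new MattingConfig fields would presumably be driven from YAML like the existing ones. A hedged example; the "matting:" section name is an assumption based on the class name, since VR180Config.from_yaml() is not part of this diff:

import yaml

cfg = yaml.safe_load("""
matting:
  continuous_correction: true
  correction_interval: 60
  frame_release_interval: 50
  frame_window_size: 30
""")
print(cfg["matting"]["correction_interval"])  # 60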
@@ -1,6 +1,4 @@
-import torch
 import numpy as np
-from ultralytics import YOLO
 from typing import List, Tuple, Dict, Any
 import cv2

@@ -13,14 +11,23 @@ class YOLODetector:
         self.confidence_threshold = confidence_threshold
         self.device = device
         self.model = None
-        self._load_model()
+        # Don't load model during init - load lazily when first used

     def _load_model(self):
-        """Load YOLOv8 model"""
+        """Load YOLOv8 model lazily"""
+        if self.model is not None:
+            return  # Already loaded
+
         try:
+            # Import heavy dependencies only when needed
+            import torch
+            from ultralytics import YOLO
+
             self.model = YOLO(f"{self.model_name}.pt")
             if self.device == "cuda" and torch.cuda.is_available():
                 self.model.to("cuda")
+
+            print(f"🎯 Loaded YOLO model: {self.model_name}")
         except Exception as e:
             raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")

@@ -34,8 +41,9 @@ class YOLODetector:
         Returns:
             List of detection dictionaries with bbox, confidence, and class info
         """
+        # Load model lazily on first use
         if self.model is None:
-            raise RuntimeError("YOLO model not loaded")
+            self._load_model()

         results = self.model(frame, verbose=False)
         detections = []
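The net effect is that constructing a YOLODetector is now cheap, and the torch/ultralytics import plus weight loading happen on the first detection call. A self-contained toy version of the same lazy-loading pattern (no real model involved; class and timings are illustrative only) makes the cost shift visible:

import time

class LazyThing:
    """Same shape as YOLODetector above: defer the expensive load to first use."""
    def __init__(self):
        self._model = None

    def _load(self):
        if self._model is not None:
            return  # already loaded
        time.sleep(0.5)  # stand-in for importing torch and loading weights
        self._model = object()

    def __call__(self, x):
        if self._model is None:  # first call pays the cost
            self._load()
        return x

thing = LazyThing()  # fast: nothing loaded yet
t0 = time.time(); thing(1); print(f"first call:  {time.time() - t0:.2f}s")
t0 = time.time(); thing(1); print(f"second call: {time.time() - t0:.3f}s")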
@@ -9,12 +9,16 @@ import tempfile
 import shutil
 import gc

-try:
-    from sam2.build_sam import build_sam2_video_predictor
-    from sam2.sam2_image_predictor import SAM2ImagePredictor
-    SAM2_AVAILABLE = True
-except ImportError:
-    SAM2_AVAILABLE = False
+# Check SAM2 availability without importing heavy modules
+def _check_sam2_available():
+    try:
+        import sam2
+        return True
+    except ImportError:
+        return False
+
+SAM2_AVAILABLE = _check_sam2_available()
+if not SAM2_AVAILABLE:
     warnings.warn("SAM2 not available. Please install sam2 package.")

@@ -40,11 +44,18 @@ class SAM2VideoMatting:
         self.video_segments = {}
         self.temp_video_path = None

-        self._load_model(model_cfg, checkpoint_path)
+        # Don't load model during init - load lazily when needed
+        self._model_loaded = False

     def _load_model(self, model_cfg: str, checkpoint_path: str):
-        """Load SAM2 video predictor with optimizations"""
+        """Load SAM2 video predictor lazily"""
+        if self._model_loaded:
+            return  # Already loaded
+
         try:
+            # Import heavy SAM2 modules only when needed
+            from sam2.build_sam import build_sam2_video_predictor
+
             # Check for checkpoint in SAM2 repo structure
             if not Path(checkpoint_path).exists():
                 # Try in segment-anything-2/checkpoints/

@@ -63,6 +74,7 @@ class SAM2VideoMatting:
             if sam2_repo_path.exists():
                 checkpoint_path = str(sam2_repo_path)

+            print(f"🎯 Loading SAM2 model: {model_cfg}")
             # Use SAM2's build_sam2_video_predictor which returns the predictor directly
             # The predictor IS the model - no .model attribute needed
             self.predictor = build_sam2_video_predictor(

@@ -71,13 +83,16 @@ class SAM2VideoMatting:
                 device=self.device
             )

+            self._model_loaded = True
+            print("✅ SAM2 model loaded successfully")
+
         except Exception as e:
             raise RuntimeError(f"Failed to load SAM2 model: {e}")

     def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
         """Initialize video inference state"""
-        if self.predictor is None:
-            # Recreate predictor if it was cleaned up
+        # Load model lazily on first use
+        if not self._model_loaded:
             self._load_model(self.model_cfg, self.checkpoint_path)

         if video_path is not None:

@@ -152,13 +167,16 @@ class SAM2VideoMatting:

         return object_ids

-    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
+    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
+                        frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
         """
-        Propagate masks through video
+        Propagate masks through video with Det-SAM2 style memory management

         Args:
             start_frame: Starting frame index
             max_frames: Maximum number of frames to process
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory

         Returns:
             Dictionary mapping frame_idx -> {obj_id: mask}

@@ -182,9 +200,108 @@ class SAM2VideoMatting:

             video_segments[out_frame_idx] = frame_masks

-            # Memory management: release old frames periodically
-            if self.memory_offload and out_frame_idx % 100 == 0:
-                self._release_old_frames(out_frame_idx - 50)
+            # Det-SAM2 style memory management: more aggressive frame release
+            if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                self._release_old_frames(out_frame_idx - frame_window_size)
+                # Optional: Log frame release for monitoring
+                if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                    print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        return video_segments
+
+    def propagate_masks_with_continuous_correction(self,
+                                                   detector,
+                                                   temp_video_path: str,
+                                                   start_frame: int = 0,
+                                                   max_frames: Optional[int] = None,
+                                                   correction_interval: int = 60,
+                                                   frame_release_interval: int = 50,
+                                                   frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
+        """
+        Det-SAM2 style: Propagate masks with continuous prompt correction
+
+        Args:
+            detector: YOLODetector instance for generating correction prompts
+            temp_video_path: Path to video file for frame access
+            start_frame: Starting frame index
+            max_frames: Maximum number of frames to process
+            correction_interval: Add correction prompts every N frames
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory
+
+        Returns:
+            Dictionary mapping frame_idx -> {obj_id: mask}
+        """
+        if self.inference_state is None:
+            raise RuntimeError("Video state not initialized")
+
+        video_segments = {}
+        max_frames = max_frames or 10000  # Default limit
+
+        # Open video for accessing frames during propagation
+        cap = cv2.VideoCapture(str(temp_video_path))
+
+        try:
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=start_frame,
+                max_frame_num_to_track=max_frames,
+                reverse=False
+            ):
+                frame_masks = {}
+
+                for i, out_obj_id in enumerate(out_obj_ids):
+                    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
+                    frame_masks[out_obj_id] = mask
+
+                video_segments[out_frame_idx] = frame_masks
+
+                # Det-SAM2 optimization: Add correction prompts at keyframes
+                if (out_frame_idx % correction_interval == 0 and
+                        out_frame_idx > start_frame and
+                        out_frame_idx < max_frames - 1):

+                    # Read frame for detection
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
+                    ret, correction_frame = cap.read()
+
+                    if ret:
+                        # Run detection on this keyframe
+                        detections = detector.detect_persons(correction_frame)
+
+                        if detections:
+                            # Convert to prompts and add as corrections
+                            box_prompts, labels = detector.convert_to_sam_prompts(detections)
+
+                            # Add correction prompts (SAM2 will propagate backward)
+                            correction_count = 0
+                            try:
+                                for i, (box, label) in enumerate(zip(box_prompts, labels)):
+                                    # Use existing object IDs if available, otherwise create new ones
+                                    obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
+
+                                    self.predictor.add_new_points_or_box(
+                                        inference_state=self.inference_state,
+                                        frame_idx=out_frame_idx,
+                                        obj_id=obj_id,
+                                        box=box,
+                                    )
+                                    correction_count += 1
+
+                                print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
+
+                            except Exception as e:
+                                warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
+
+                # Memory management: More aggressive frame release (Det-SAM2 style)
+                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                    self._release_old_frames(out_frame_idx - frame_window_size)
+                    # Optional: Log frame release for monitoring
+                    if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        finally:
+            cap.release()
+
         return video_segments
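With the defaults introduced above (correction_interval=60, frame_release_interval=50, frame_window_size=30), the modulo tests fire on a fixed cadence; a throwaway snippet to sanity-check it for a 300-frame chunk starting at frame 0:

correction_interval, frame_release_interval, window = 60, 50, 30
num_frames = 300
corrections = [i for i in range(1, num_frames) if i % correction_interval == 0]
releases = [i for i in range(1, num_frames) if i % frame_release_interval == 0]
print("correction prompts at:", corrections)  # [60, 120, 180, 240]
print("frame releases at:", releases)         # [50, 100, 150, 200, 250]
# at each release frame r, cached frames older than r - window are dropped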
@@ -387,19 +387,83 @@ class VideoProcessor:
             # Green screen background
             return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

+    def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
+                               overlap_frames: int = 0, audio_source: str = None) -> None:
+        """
+        Merge processed chunks using streaming approach (no memory accumulation)
+
+        Args:
+            chunk_files: List of chunk result files (.npz)
+            output_path: Final output video path
+            overlap_frames: Number of overlapping frames
+            audio_source: Audio source file for final video
+        """
+        from .streaming_video_writer import StreamingVideoWriter
+
+        if not chunk_files:
+            raise ValueError("No chunk files to merge")
+
+        print(f"🎬 Streaming merge: {len(chunk_files)} chunks → {output_path}")
+
+        # Initialize streaming writer
+        writer = StreamingVideoWriter(
+            output_path=output_path,
+            fps=self.video_info['fps'],
+            audio_source=audio_source
+        )
+
+        try:
+            # Process each chunk without accumulation
+            for i, chunk_file in enumerate(chunk_files):
+                print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
+
+                # Load chunk (this is the only copy in memory)
+                chunk_data = np.load(str(chunk_file))
+                frames = chunk_data['frames'].tolist()  # Convert to list of arrays
+                chunk_data.close()
+
+                # Write chunk with streaming writer
+                writer.write_chunk(
+                    frames=frames,
+                    chunk_index=i,
+                    overlap_frames=overlap_frames if i > 0 else 0,
+                    blend_with_previous=(i > 0 and overlap_frames > 0)
+                )
+
+                # Immediately free memory
+                del frames, chunk_data
+
+                # Delete chunk file to free disk space
+                try:
+                    chunk_file.unlink()
+                    print(f"   🗑️ Deleted {chunk_file.name}")
+                except Exception as e:
+                    print(f"   ⚠️ Could not delete {chunk_file.name}: {e}")
+
+                # Aggressive cleanup every chunk
+                self._aggressive_memory_cleanup(f"After processing chunk {i}")
+
+            # Finalize the video
+            writer.finalize()
+
+        except Exception as e:
+            print(f"❌ Streaming merge failed: {e}")
+            writer.cleanup()
+            raise
+
+        print(f"✅ Streaming merge complete: {output_path}")
+
     def merge_overlapping_chunks(self,
                                  chunk_results: List[List[np.ndarray]],
                                  overlap_frames: int) -> List[np.ndarray]:
         """
-        Merge overlapping chunks with blending in overlap regions
-
-        Args:
-            chunk_results: List of chunk results
-            overlap_frames: Number of overlapping frames
-
-        Returns:
-            Merged frame sequence
+        Legacy merge method - DEPRECATED due to memory accumulation
+        Use merge_chunks_streaming() instead for memory efficiency
         """
+        import warnings
+        warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
+                      DeprecationWarning, stacklevel=2)
+
         if len(chunk_results) == 1:
             return chunk_results[0]

@@ -640,36 +704,23 @@ class VideoProcessor:
         if self.memory_manager.should_emergency_cleanup():
             self.memory_manager.emergency_cleanup()

-        # Load and merge chunks from disk
-        print("\nLoading and merging chunks...")
-        chunk_results = []
-        for i, chunk_file in enumerate(chunk_files):
-            print(f"Loading {chunk_file.name}...")
-            chunk_data = np.load(str(chunk_file))
-            chunk_results.append(chunk_data['frames'])
-            chunk_data.close()  # Close the file
-
-            # Delete chunk file immediately after loading to free disk space
-            try:
-                chunk_file.unlink()
-                print(f"   Deleted chunk file {chunk_file.name}")
-            except Exception as e:
-                print(f"   Warning: Could not delete chunk file: {e}")
-
-            # Aggressive cleanup every few chunks to prevent accumulation
-            if i % 3 == 0 and i > 0:
-                self._aggressive_memory_cleanup(f"after loading chunk {i}")
-
-        # Merge chunks
-        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
-
-        # Free chunk results after merging - this is critical!
-        del chunk_results
-        self._aggressive_memory_cleanup("after merging chunks")
-
-        # Save results
-        print(f"Saving {len(final_frames)} processed frames...")
-        self.save_video(final_frames, self.config.output.path)
+        # Use streaming merge to avoid memory accumulation (fixes OOM)
+        print("\n🎬 Using streaming merge (no memory accumulation)...")
+
+        # Determine audio source for final video
+        audio_source = None
+        if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
+            audio_source = self.config.input.video_path
+
+        # Stream merge chunks directly to output (no memory accumulation)
+        self.merge_chunks_streaming(
+            chunk_files=chunk_files,
+            output_path=self.config.output.path,
+            overlap_frames=overlap_frames,
+            audio_source=audio_source
+        )
+
+        print("✅ Streaming merge complete - no memory accumulation!")

         # Calculate final statistics
         self.processing_stats['end_time'] = time.time()
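merge_chunks_streaming() leans on a StreamingVideoWriter whose implementation is not included in this diff. Inferring only from the call sites above (constructor arguments, write_chunk, finalize, cleanup), a hypothetical minimal version might look like the following; the real class in vr180_matting/streaming_video_writer.py presumably also handles audio muxing and overlap blending, which this sketch skips:

import cv2
import numpy as np

class StreamingVideoWriter:
    """Hypothetical minimal writer matching the call sites in merge_chunks_streaming()."""

    def __init__(self, output_path: str, fps: float, audio_source: str = None):
        self.output_path = output_path
        self.fps = fps
        self.audio_source = audio_source  # audio muxing omitted in this sketch
        self.writer = None

    def write_chunk(self, frames, chunk_index, overlap_frames=0, blend_with_previous=False):
        # Simplification: skip overlapping frames instead of blending them
        start = overlap_frames if blend_with_previous else 0
        for frame in frames[start:]:
            frame = np.asarray(frame, dtype=np.uint8)
            if self.writer is None:
                h, w = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                self.writer = cv2.VideoWriter(self.output_path, fourcc, self.fps, (w, h))
            self.writer.write(frame)

    def finalize(self):
        if self.writer is not None:
            self.writer.release()

    def cleanup(self):
        self.finalize()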
@@ -375,31 +375,43 @@ class VR180Processor(VideoProcessor):

         # Propagate masks (most expensive operation)
         self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")

-        video_segments = self.sam2_model.propagate_masks(
-            start_frame=0,
-            max_frames=num_frames
-        )
+        # Use Det-SAM2 continuous correction if enabled
+        if self.config.matting.continuous_correction:
+            video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
+                detector=self.detector,
+                temp_video_path=str(temp_video_path),
+                start_frame=0,
+                max_frames=num_frames,
+                correction_interval=self.config.matting.correction_interval,
+                frame_release_interval=self.config.matting.frame_release_interval,
+                frame_window_size=self.config.matting.frame_window_size
+            )
+            print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
+        else:
+            video_segments = self.sam2_model.propagate_masks(
+                start_frame=0,
+                max_frames=num_frames,
+                frame_release_interval=self.config.matting.frame_release_interval,
+                frame_window_size=self.config.matting.frame_window_size
+            )

         self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")

-        # Apply masks - need to reload frames from temp video since we freed the original frames
-        self._print_memory_step(f"Before reloading frames for mask application ({eye_name} eye)")
+        # Apply masks with streaming approach (no frame accumulation)
+        self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")

-        # Read frames back from the temp video for mask application
+        # Process frames one at a time without accumulation
         cap = cv2.VideoCapture(str(temp_video_path))
-        reloaded_frames = []
+        matted_frames = []

-        for frame_idx in range(num_frames):
-            ret, frame = cap.read()
-            if not ret:
-                break
-            reloaded_frames.append(frame)
-        cap.release()
-
-        self._print_memory_step(f"Reloaded {len(reloaded_frames)} frames for mask application")
-
-        # Apply masks
-        matted_frames = []
-        for frame_idx, frame in enumerate(reloaded_frames):
-            if frame_idx in video_segments:
+        try:
+            for frame_idx in range(num_frames):
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                # Apply mask to this single frame
+                if frame_idx in video_segments:
                     frame_masks = video_segments[frame_idx]
                     combined_mask = self.sam2_model.get_combined_mask(frame_masks)

@@ -414,11 +426,22 @@ class VR180Processor(VideoProcessor):

                 matted_frames.append(matted_frame)

-        # Free reloaded frames and video segments completely
-        del reloaded_frames
-        del video_segments  # This holds processed masks from SAM2
-        self._aggressive_memory_cleanup(f"After mask application ({eye_name} eye)")
+                # Free the original frame immediately (no accumulation)
+                del frame
+
+                # Periodic cleanup during processing
+                if frame_idx % 100 == 0 and frame_idx > 0:
+                    import gc
+                    gc.collect()
+
+        finally:
+            cap.release()
+
+        # Free video segments completely
+        del video_segments  # This holds processed masks from SAM2
+        self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")

+        self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
         return matted_frames

     finally: