Compare commits
37 Commits
e195d23584...
analyze_memory_profile.py (Normal file, 193 lines)
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Analyze memory profile JSON files to identify OOM causes
"""

import json
import glob
import os
import sys
from pathlib import Path

def analyze_memory_files():
    """Analyze partial memory profile files"""

    # Get all partial files in order
    files = sorted(glob.glob('memory_profile_partial_*.json'))

    if not files:
        print("❌ No memory profile files found!")
        print("Expected files like: memory_profile_partial_0.json")
        return

    print(f"🔍 Found {len(files)} memory profile files")
    print("=" * 60)

    peak_memory = 0
    peak_vram = 0
    critical_points = []
    all_checkpoints = []

    for i, file in enumerate(files):
        try:
            with open(file, 'r') as f:
                data = json.load(f)

            timeline = data.get('timeline', [])
            if not timeline:
                continue

            # Find peaks in this file
            file_peak_rss = max([d['rss_gb'] for d in timeline])
            file_peak_vram = max([d['vram_gb'] for d in timeline])

            if file_peak_rss > peak_memory:
                peak_memory = file_peak_rss
            if file_peak_vram > peak_vram:
                peak_vram = file_peak_vram

            # Find memory growth spikes (>3GB increase)
            for j in range(1, len(timeline)):
                prev_rss = timeline[j-1]['rss_gb']
                curr_rss = timeline[j]['rss_gb']
                growth = curr_rss - prev_rss

                if growth > 3.0:  # >3GB growth spike
                    checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
                    critical_points.append({
                        'file': file,
                        'file_index': i,
                        'sample': j,
                        'timestamp': timeline[j]['timestamp'],
                        'rss_gb': curr_rss,
                        'vram_gb': timeline[j]['vram_gb'],
                        'growth_gb': growth,
                        'checkpoint': checkpoint
                    })

            # Collect all checkpoints
            checkpoints = [d for d in timeline if 'checkpoint' in d]
            for cp in checkpoints:
                cp['file'] = file
                cp['file_index'] = i
                all_checkpoints.append(cp)

            # Show progress for this file
            if timeline:
                start_rss = timeline[0]['rss_gb']
                end_rss = timeline[-1]['rss_gb']
                growth = end_rss - start_rss
                samples = len(timeline)

                print(f"📊 File {i+1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
                      f"(+{growth:4.1f}GB) [{samples:3d} samples]")

                # Show significant checkpoints from this file
                if checkpoints:
                    for cp in checkpoints:
                        print(f"   📍 {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")

        except Exception as e:
            print(f"❌ Error reading {file}: {e}")

    print("\n" + "=" * 60)
    print("🎯 ANALYSIS SUMMARY")
    print("=" * 60)

    print(f"📈 Peak Memory: {peak_memory:.1f} GB")
    print(f"🎮 Peak VRAM: {peak_vram:.1f} GB")
    print(f"⚡ Growth Spikes: {len(critical_points)} events >3GB")

    if critical_points:
        print(f"\n💥 MEMORY GROWTH SPIKES (>3GB):")
        print("   Location                       Growth    Total    VRAM")
        print("   " + "-" * 55)

        for point in critical_points:
            location = point['checkpoint'][:30].ljust(30)
            print(f"   {location} +{point['growth_gb']:4.1f}GB → {point['rss_gb']:5.1f}GB {point['vram_gb']:4.1f}GB")

    if all_checkpoints:
        print(f"\n📍 CHECKPOINT PROGRESSION:")
        print("   Checkpoint                     Memory    VRAM    File")
        print("   " + "-" * 55)

        for cp in all_checkpoints:
            checkpoint = cp['checkpoint'][:30].ljust(30)
            file_num = cp['file_index'] + 1
            print(f"   {checkpoint} {cp['rss_gb']:5.1f}GB {cp['vram_gb']:4.1f}GB #{file_num}")

    # Memory growth analysis
    if len(all_checkpoints) > 1:
        print(f"\n📊 MEMORY GROWTH ANALYSIS:")

        # Find the biggest memory jumps between checkpoints
        big_jumps = []
        for i in range(1, len(all_checkpoints)):
            prev_cp = all_checkpoints[i-1]
            curr_cp = all_checkpoints[i]

            growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
            if growth > 2.0:  # >2GB jump
                big_jumps.append({
                    'from': prev_cp['checkpoint'],
                    'to': curr_cp['checkpoint'],
                    'growth': growth,
                    'from_memory': prev_cp['rss_gb'],
                    'to_memory': curr_cp['rss_gb']
                })

        if big_jumps:
            print("   Major jumps (>2GB):")
            for jump in big_jumps:
                print(f"   {jump['from']} → {jump['to']}: "
                      f"+{jump['growth']:.1f}GB ({jump['from_memory']:.1f}→{jump['to_memory']:.1f}GB)")
        else:
            print("   ✅ No major memory jumps detected")

    # Diagnosis
    print(f"\n🔬 DIAGNOSIS:")

    if peak_memory > 400:
        print("   🔴 CRITICAL: Memory usage exceeded 400GB")
        print("   💡 Recommendation: Reduce chunk_size to 200-300 frames")
    elif peak_memory > 200:
        print("   🟡 HIGH: Memory usage over 200GB")
        print("   💡 Recommendation: Reduce chunk_size to 400 frames")
    else:
        print("   🟢 MODERATE: Memory usage under 200GB")

    if critical_points:
        # Find most common growth spike locations
        spike_locations = {}
        for point in critical_points:
            location = point['checkpoint']
            spike_locations[location] = spike_locations.get(location, 0) + 1

        print("\n   🎯 Most problematic locations:")
        for location, count in sorted(spike_locations.items(), key=lambda x: x[1], reverse=True)[:3]:
            print(f"   {location}: {count} spikes")

    print(f"\n💡 NEXT STEPS:")
    if 'merge' in str(critical_points).lower():
        print("   1. Chunk merging still causing memory accumulation")
        print("   2. Check if streaming merge is actually being used")
        print("   3. Verify chunk files are being deleted immediately")
    elif 'propagation' in str(critical_points).lower():
        print("   1. SAM2 propagation using too much memory")
        print("   2. Reduce chunk_size further (try 300 frames)")
        print("   3. Enable more aggressive frame release")
    else:
        print("   1. Review the checkpoint progression above")
        print("   2. Focus on locations with biggest memory spikes")
        print("   3. Consider reducing chunk_size if spikes are large")

def main():
    print("🔍 MEMORY PROFILE ANALYZER")
    print("Analyzing memory profile files for OOM causes...")
    print()

    analyze_memory_files()

if __name__ == "__main__":
    main()
config.yaml
@@ -3,8 +3,8 @@ input:
 
 processing:
   scale_factor: 0.5  # A40 can handle 0.5 well
-  chunk_size: 0  # Auto-calculate based on A40's 48GB VRAM
+  chunk_size: 600  # Category A.4: Larger chunks for better VRAM utilization (was 200)
-  overlap_frames: 60
+  overlap_frames: 30  # Reduced overlap
 
 detection:
   confidence_threshold: 0.7
@@ -19,9 +19,11 @@ matting:
 
 output:
   path: "/workspace/output/matted_video.mp4"
-  format: "alpha"
+  format: "greenscreen"  # Changed to greenscreen for easier testing
   background_color: [0, 255, 0]
   maintain_sbs: true
+  preserve_audio: true  # Category A.1: Audio preservation
+  verify_sync: true  # Category A.2: Frame count validation
 
 hardware:
   device: "cuda"
debug_memory_leak.py (Normal file, 151 lines)
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""
Debug memory leak between chunks - track exactly where memory accumulates
"""

import psutil
import gc
from pathlib import Path
import sys

def detailed_memory_check(label):
    """Get detailed memory info"""
    process = psutil.Process()
    memory_info = process.memory_info()

    rss_gb = memory_info.rss / (1024**3)
    vms_gb = memory_info.vms / (1024**3)

    # System memory
    sys_memory = psutil.virtual_memory()
    available_gb = sys_memory.available / (1024**3)

    print(f"🔍 {label}:")
    print(f"   RSS: {rss_gb:.2f} GB (physical memory)")
    print(f"   VMS: {vms_gb:.2f} GB (virtual memory)")
    print(f"   Available: {available_gb:.2f} GB")

    return rss_gb

def simulate_chunk_processing(config_path):
    """Simulate the chunk processing to see where memory accumulates"""

    print("🚀 SIMULATING CHUNK PROCESSING TO FIND MEMORY LEAK")
    print("=" * 60)

    base_memory = detailed_memory_check("0. Baseline")

    # Step 1: Import everything (with lazy loading)
    print("\n📦 Step 1: Imports")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor

    import_memory = detailed_memory_check("1. After imports")
    import_growth = import_memory - base_memory
    print(f"   Growth: +{import_growth:.2f} GB")

    # Step 2: Load config (use the path validated in main(), not a hardcoded file)
    print("\n⚙️ Step 2: Config loading")
    config = VR180Config.from_yaml(config_path)
    config_memory = detailed_memory_check("2. After config load")
    config_growth = config_memory - import_memory
    print(f"   Growth: +{config_growth:.2f} GB")

    # Step 3: Initialize processor (models still lazy)
    print("\n🏗️ Step 3: Processor initialization")
    processor = VR180Processor(config)
    processor_memory = detailed_memory_check("3. After processor init")
    processor_growth = processor_memory - config_memory
    print(f"   Growth: +{processor_growth:.2f} GB")

    # Step 4: Load video info (lightweight)
    print("\n🎬 Step 4: Video info loading")
    try:
        video_info = processor.load_video_info(config.input.video_path)
        print(f"   Video: {video_info.get('width', 'unknown')}x{video_info.get('height', 'unknown')}, "
              f"{video_info.get('total_frames', 'unknown')} frames")
    except Exception as e:
        print(f"   Warning: Could not load video info: {e}")

    video_info_memory = detailed_memory_check("4. After video info")
    video_info_growth = video_info_memory - processor_memory
    print(f"   Growth: +{video_info_growth:.2f} GB")

    # Step 5: Simulate chunk 0 processing (this is where models actually load)
    print("\n🔄 Step 5: Simulating chunk 0 processing...")

    # This is where the real memory usage starts
    print("   Loading first 10 frames to trigger model loading...")
    try:
        # Read a small number of frames to trigger model loading
        frames = processor.read_video_frames(
            config.input.video_path,
            start_frame=0,
            num_frames=10,  # Just 10 frames to trigger model loading
            scale_factor=config.processing.scale_factor
        )

        frames_memory = detailed_memory_check("5a. After reading 10 frames")
        frames_growth = frames_memory - video_info_memory
        print(f"   10 frames growth: +{frames_growth:.2f} GB")

        # Free frames
        del frames
        gc.collect()

        after_free_memory = detailed_memory_check("5b. After freeing 10 frames")
        free_improvement = frames_memory - after_free_memory
        print(f"   Memory freed: -{free_improvement:.2f} GB")

    except Exception as e:
        print(f"   Could not simulate frame loading: {e}")
        after_free_memory = video_info_memory

    print(f"\n📊 MEMORY ANALYSIS:")
    print(f"   Baseline → Final: {base_memory:.2f}GB → {after_free_memory:.2f}GB")
    print(f"   Total growth: +{after_free_memory - base_memory:.2f}GB")

    if after_free_memory - base_memory > 10:
        print(f"   🔴 HIGH: Memory growth > 10GB before any real processing")
        print(f"   💡 This suggests model loading is using too much memory")
    elif after_free_memory - base_memory > 5:
        print(f"   🟡 MODERATE: Memory growth 5-10GB")
        print(f"   💡 Normal for model loading, but monitor chunk processing")
    else:
        print(f"   🟢 GOOD: Memory growth < 5GB")
        print(f"   💡 Initialization memory usage is reasonable")

    print(f"\n🎯 KEY INSIGHTS:")
    if import_growth > 1:
        print(f"   - Import growth: {import_growth:.2f}GB (fixed with lazy loading)")
    if processor_growth > 10:
        print(f"   - Processor init: {processor_growth:.2f}GB (investigate model pre-loading)")

    print(f"\n💡 RECOMMENDATIONS:")
    if after_free_memory - base_memory > 15:
        print(f"   1. Reduce chunk_size to 200-300 frames")
        print(f"   2. Use smaller models (yolov8n instead of yolov8m)")
        print(f"   3. Enable FP16 mode for SAM2")
    elif after_free_memory - base_memory > 8:
        print(f"   1. Monitor chunk processing carefully")
        print(f"   2. Use streaming merge (should be automatic)")
        print(f"   3. Current settings may be acceptable")
    else:
        print(f"   1. Settings look good for initialization")
        print(f"   2. Focus on chunk processing memory leaks")

def main():
    if len(sys.argv) != 2:
        print("Usage: python debug_memory_leak.py <config.yaml>")
        print("This simulates initialization to find memory leaks")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    simulate_chunk_processing(config_path)

if __name__ == "__main__":
    main()
memory_profiler_script.py (Normal file, 249 lines)
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Memory profiling script for VR180 matting pipeline
Tracks memory usage during processing to identify leaks
"""

import sys
import time
import psutil
import tracemalloc
import subprocess
import gc
from pathlib import Path
from typing import Dict, List, Tuple
import threading
import json

class MemoryProfiler:
    def __init__(self, output_file: str = "memory_profile.json"):
        self.output_file = output_file
        self.data = []
        self.process = psutil.Process()
        self.running = False
        self.thread = None
        self.checkpoint_counter = 0

    def start_monitoring(self, interval: float = 1.0):
        """Start continuous memory monitoring"""
        tracemalloc.start()
        self.running = True
        self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
        self.thread.daemon = True
        self.thread.start()
        print(f"🔍 Memory monitoring started (interval: {interval}s)")

    def stop_monitoring(self):
        """Stop monitoring and save results"""
        self.running = False
        if self.thread:
            self.thread.join()

        # Get tracemalloc snapshot
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')

        # Save detailed results
        results = {
            'timeline': self.data,
            'top_memory_allocations': [
                {
                    'file': stat.traceback.format()[0],
                    'size_mb': stat.size / 1024 / 1024,
                    'count': stat.count
                }
                for stat in top_stats[:20]  # Top 20 allocations
            ],
            'summary': {
                'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
                'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
                'total_samples': len(self.data)
            }
        }

        with open(self.output_file, 'w') as f:
            json.dump(results, f, indent=2)

        tracemalloc.stop()
        print(f"📊 Memory profile saved to {self.output_file}")

    def _monitor_loop(self, interval: float):
        """Continuous monitoring loop"""
        while self.running:
            try:
                # System memory
                memory_info = self.process.memory_info()
                rss_gb = memory_info.rss / (1024**3)

                # System-wide memory
                sys_memory = psutil.virtual_memory()
                sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
                sys_available_gb = sys_memory.available / (1024**3)

                # GPU memory (if available)
                vram_gb = 0
                vram_free_gb = 0
                try:
                    result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
                                             '--format=csv,noheader,nounits'],
                                            capture_output=True, text=True, timeout=5)
                    if result.returncode == 0:
                        lines = result.stdout.strip().split('\n')
                        if lines and lines[0]:
                            used, free = lines[0].split(', ')
                            vram_gb = float(used) / 1024
                            vram_free_gb = float(free) / 1024
                except Exception:
                    pass

                # Tracemalloc current usage
                try:
                    current, peak = tracemalloc.get_traced_memory()
                    traced_mb = current / (1024**2)
                except Exception:
                    traced_mb = 0

                data_point = {
                    'timestamp': time.time(),
                    'rss_gb': rss_gb,
                    'vram_gb': vram_gb,
                    'vram_free_gb': vram_free_gb,
                    'sys_used_gb': sys_used_gb,
                    'sys_available_gb': sys_available_gb,
                    'traced_mb': traced_mb
                }

                self.data.append(data_point)

                # Print periodic updates and save partial data
                if len(self.data) % 10 == 0:  # Every 10 samples
                    print(f"🔍 Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")

                # Save partial data every 30 samples in case of crash
                if len(self.data) % 30 == 0:
                    self._save_partial_data()

            except Exception as e:
                print(f"Monitoring error: {e}")

            time.sleep(interval)

    def _save_partial_data(self):
        """Save partial data to prevent loss on crash"""
        try:
            partial_file = f"memory_profile_partial_{self.checkpoint_counter}.json"
            with open(partial_file, 'w') as f:
                json.dump({
                    'timeline': self.data,
                    'status': 'partial_save',
                    'samples': len(self.data)
                }, f, indent=2)
            self.checkpoint_counter += 1
        except Exception as e:
            print(f"Failed to save partial data: {e}")

    def log_checkpoint(self, checkpoint_name: str):
        """Log a specific checkpoint"""
        if self.data:
            self.data[-1]['checkpoint'] = checkpoint_name
            latest = self.data[-1]
            print(f"📍 CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")

            # Save checkpoint data immediately
            self._save_partial_data()

def run_with_profiling(config_path: str):
    """Run the VR180 matting with memory profiling"""
    profiler = MemoryProfiler("memory_profile_detailed.json")

    try:
        # Start monitoring
        profiler.start_monitoring(interval=2.0)  # Sample every 2 seconds

        # Log initial state
        profiler.log_checkpoint("STARTUP")

        # Import after starting profiler to catch import memory usage
        print("Importing VR180 processor...")
        from vr180_matting.vr180_processor import VR180Processor
        from vr180_matting.config import VR180Config

        profiler.log_checkpoint("IMPORTS_COMPLETE")

        # Load config
        print(f"Loading config from {config_path}")
        config = VR180Config.from_yaml(config_path)

        profiler.log_checkpoint("CONFIG_LOADED")

        # Initialize processor
        print("Initializing VR180 processor...")
        processor = VR180Processor(config)

        profiler.log_checkpoint("PROCESSOR_INITIALIZED")

        # Force garbage collection
        gc.collect()
        profiler.log_checkpoint("INITIAL_GC_COMPLETE")

        # Run processing
        print("Starting VR180 processing...")
        processor.process_video()

        profiler.log_checkpoint("PROCESSING_COMPLETE")

    except Exception as e:
        print(f"❌ Error during processing: {e}")
        profiler.log_checkpoint(f"ERROR: {str(e)}")
        raise
    finally:
        # Stop monitoring and save results
        profiler.stop_monitoring()

        # Print summary
        print("\n" + "="*60)
        print("MEMORY PROFILING SUMMARY")
        print("="*60)

        if profiler.data:
            peak_rss = max([d['rss_gb'] for d in profiler.data])
            peak_vram = max([d['vram_gb'] for d in profiler.data])

            print(f"Peak RSS Memory: {peak_rss:.2f} GB")
            print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
            print(f"Total Samples: {len(profiler.data)}")

            # Show checkpoints
            checkpoints = [d for d in profiler.data if 'checkpoint' in d]
            if checkpoints:
                print(f"\nCheckpoints ({len(checkpoints)}):")
                for cp in checkpoints:
                    print(f"  {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")

        print(f"\nDetailed profile saved to: {profiler.output_file}")

def main():
    if len(sys.argv) != 2:
        print("Usage: python memory_profiler_script.py <config.yaml>")
        print("\nThis script runs VR180 matting with detailed memory profiling")
        print("It will:")
        print("- Monitor RSS, VRAM, and system memory every 2 seconds")
        print("- Track memory allocations with tracemalloc")
        print("- Log checkpoints at key processing stages")
        print("- Save detailed JSON report for analysis")
        sys.exit(1)

    config_path = sys.argv[1]

    if not Path(config_path).exists():
        print(f"❌ Config file not found: {config_path}")
        sys.exit(1)

    print("🚀 Starting VR180 Memory Profiling")
    print(f"Config: {config_path}")
    print("="*60)

    run_with_profiling(config_path)

if __name__ == "__main__":
    main()
quick_memory_check.py (Normal file, 125 lines)
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Quick memory and system check before running full pipeline
"""

import psutil
import subprocess
import sys
from pathlib import Path

def check_system():
    """Check system resources before starting"""
    print("🔍 SYSTEM RESOURCE CHECK")
    print("=" * 50)

    # Memory info
    memory = psutil.virtual_memory()
    print(f"📊 RAM:")
    print(f"   Total: {memory.total / (1024**3):.1f} GB")
    print(f"   Available: {memory.available / (1024**3):.1f} GB")
    print(f"   Used: {(memory.total - memory.available) / (1024**3):.1f} GB ({memory.percent:.1f}%)")

    # GPU info
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.used,memory.free',
                                 '--format=csv,noheader,nounits'],
                                capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            lines = result.stdout.strip().split('\n')
            print(f"\n🎮 GPU:")
            for i, line in enumerate(lines):
                if line.strip():
                    parts = line.split(', ')
                    if len(parts) >= 4:
                        name, total, used, free = parts[:4]
                        total_gb = float(total) / 1024
                        used_gb = float(used) / 1024
                        free_gb = float(free) / 1024
                        print(f"   GPU {i}: {name}")
                        print(f"   VRAM: {used_gb:.1f}/{total_gb:.1f} GB ({used_gb/total_gb*100:.1f}% used)")
                        print(f"   Free: {free_gb:.1f} GB")
    except Exception as e:
        print(f"\n⚠️ Could not get GPU info: {e}")

    # Disk space
    disk = psutil.disk_usage('/')
    print(f"\n💾 Disk (/):")
    print(f"   Total: {disk.total / (1024**3):.1f} GB")
    print(f"   Used: {disk.used / (1024**3):.1f} GB ({disk.used/disk.total*100:.1f}%)")
    print(f"   Free: {disk.free / (1024**3):.1f} GB")

    # Check config file
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
        if Path(config_path).exists():
            print(f"\n✅ Config file found: {config_path}")

            # Try to load and show key settings
            try:
                import yaml
                with open(config_path, 'r') as f:
                    config = yaml.safe_load(f)

                print(f"📋 Key Settings:")
                if 'processing' in config:
                    proc = config['processing']
                    print(f"   Chunk size: {proc.get('chunk_size', 'default')}")
                    print(f"   Scale factor: {proc.get('scale_factor', 'default')}")

                if 'hardware' in config:
                    hw = config['hardware']
                    print(f"   Max VRAM: {hw.get('max_vram_gb', 'default')} GB")

                if 'input' in config:
                    inp = config['input']
                    video_path = inp.get('video_path', '')
                    if video_path and Path(video_path).exists():
                        size_gb = Path(video_path).stat().st_size / (1024**3)
                        print(f"   Input video: {video_path} ({size_gb:.1f} GB)")
                    else:
                        print(f"   ⚠️ Input video not found: {video_path}")

            except Exception as e:
                print(f"   ⚠️ Could not parse config: {e}")
        else:
            print(f"\n❌ Config file not found: {config_path}")
            return False

    # Memory safety warnings
    print(f"\n⚠️ MEMORY SAFETY CHECKS:")
    available_gb = memory.available / (1024**3)

    if available_gb < 10:
        print(f"   🔴 LOW MEMORY: Only {available_gb:.1f}GB available")
        print("   Consider: reducing chunk_size or scale_factor")
        return False
    elif available_gb < 20:
        print(f"   🟡 MODERATE MEMORY: {available_gb:.1f}GB available")
        print("   Recommend: chunk_size ≤ 300, scale_factor ≤ 0.5")
    else:
        print(f"   🟢 GOOD MEMORY: {available_gb:.1f}GB available")

    print(f"\n" + "=" * 50)
    return True

def main():
    if len(sys.argv) != 2:
        print("Usage: python quick_memory_check.py <config.yaml>")
        print("\nThis checks system resources before running VR180 matting")
        sys.exit(1)

    safe_to_run = check_system()

    if safe_to_run:
        print("✅ System check passed - safe to run VR180 matting")
        print("\nTo run with memory profiling:")
        print(f"  python memory_profiler_script.py {sys.argv[1]}")
        print("\nTo run normally:")
        print(f"  vr180-matting {sys.argv[1]}")
    else:
        print("❌ System check failed - address issues before running")
        sys.exit(1)

if __name__ == "__main__":
    main()
requirements.txt
@@ -9,3 +9,7 @@ ultralytics>=8.0.0
 tqdm>=4.65.0
 psutil>=5.9.0
 ffmpeg-python>=0.2.0
+decord>=0.6.0
+# GPU acceleration (optional but recommended for stereo validation speedup)
+# cupy-cuda11x>=12.0.0  # For CUDA 11.x
+# cupy-cuda12x>=12.0.0  # For CUDA 12.x - uncomment appropriate version
@@ -14,6 +14,32 @@ echo "🐍 Installing Python dependencies..."
 pip install --upgrade pip
 pip install -r requirements.txt
+
+# Install decord for SAM2 video loading
+echo "📹 Installing decord for video processing..."
+pip install decord
+
+# Install CuPy for GPU acceleration of stereo validation
+echo "🚀 Installing CuPy for GPU acceleration..."
+# Auto-detect CUDA version and install appropriate CuPy
+python -c "
+import torch
+if torch.cuda.is_available():
+    cuda_version = torch.version.cuda
+    print(f'CUDA version detected: {cuda_version}')
+    if cuda_version.startswith('11.'):
+        import subprocess
+        subprocess.run(['pip', 'install', 'cupy-cuda11x>=12.0.0'])
+        print('Installed CuPy for CUDA 11.x')
+    elif cuda_version.startswith('12.'):
+        import subprocess
+        subprocess.run(['pip', 'install', 'cupy-cuda12x>=12.0.0'])
+        print('Installed CuPy for CUDA 12.x')
+    else:
+        print(f'Unsupported CUDA version: {cuda_version}')
+else:
+    print('CUDA not available, skipping CuPy installation')
+"
+
 # Install SAM2 separately (not on PyPI)
 echo "🎯 Installing SAM2..."
 pip install git+https://github.com/facebookresearch/segment-anything-2.git
spec.md
@@ -123,6 +123,204 @@ hardware:
3. **Performance Profiling**: Detailed resource usage analytics
4. **Quality Validation**: Comprehensive testing suite

## Post-Implementation Optimization Opportunities

*Based on the first successful 30-second test clip execution (A40 GPU, 50% scale, 9x200-frame chunks)*

### Performance Analysis Findings
- **Processing Speed**: ~0.54s per frame (64.4s for 120 frames per chunk)
- **VRAM Utilization**: Only 2.5% (1.11GB of 45GB available) - significantly underutilized
- **RAM Usage**: 106GB used of 494GB available (21.5%)
- **Primary Bottleneck**: Intermediate ffmpeg encoding operations per chunk

### Identified Optimization Categories

#### Category A: Performance Improvements (Quick Wins)
1. **Audio Track Preservation** ⚠️ **CRITICAL**
   - Issue: Output video missing audio track from input
   - Solution: Use ffmpeg to copy audio stream during final video creation
   - Implementation: Add `-c:a copy` to final ffmpeg command
   - Impact: Essential for production usability
   - Risk: Low, standard ffmpeg operation

2. **Frame Count Synchronization** ⚠️ **CRITICAL**
   - Issue: Audio sync drift if input/output frame counts differ
   - Solution: Validate exact frame count preservation throughout pipeline
   - Implementation: Frame count verification + duration matching
   - Impact: Prevents audio desync in long videos
   - Risk: Low, validation feature

3. **Memory Usage Reality Check** ⚠️ **IMPORTANT**
   - Current assumption: Unlimited RAM for memory-only pipeline
   - Reality: RunPod container limited to ~48GB RAM
   - Risk calculation: 1-hour video = ~213k frames = potential 20-40GB+ memory usage
   - Solution: Implement streaming output instead of full in-memory accumulation
   - Impact: Enables processing of long-form content
   - Risk: Medium, requires pipeline restructuring

4. **Larger Chunk Sizes**
   - Current: 200 frames per chunk (conservative for 10GB RTX 3080)
   - Opportunity: 600-800 frames per chunk on high-VRAM systems
   - Impact: Reduce 9 chunks to 2-3 chunks, fewer intermediate operations
   - Risk: Low, easily configurable

5. **Streaming Output Pipeline**
   - Current: Accumulate all processed frames in memory, write once
   - Opportunity: Write processed chunks to temporary segments, merge at end
   - Impact: Constant memory usage regardless of video length
   - Risk: Medium, requires temporary file management

6. **Enhanced Performance Profiling**
   - Current: Basic memory monitoring
   - Opportunity: Detailed timing per processing stage (detection, propagation, encoding)
   - Impact: Identify exact bottlenecks for targeted optimization
   - Risk: Low, debugging feature

7. **Parallel Eye Processing** (see the sketch after this list)
   - Current: Sequential left eye → right eye processing
   - Opportunity: Process both eyes simultaneously
   - Impact: Potential 50% speedup, better GPU utilization
   - Risk: Medium, memory management complexity
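A minimal sketch of what item A.7 could look like, assuming each eye's pipeline is exposed as a callable; `process_eye` is a hypothetical stand-in, and the speedup only materializes if the per-eye work releases the GIL (GPU kernels, ffmpeg subprocesses):

```python
from concurrent.futures import ThreadPoolExecutor

def process_eyes_in_parallel(left_frames, right_frames, process_eye):
    """Run both eye pipelines concurrently (hypothetical driver)."""
    with ThreadPoolExecutor(max_workers=2) as pool:
        left = pool.submit(process_eye, left_frames, "left")
        right = pool.submit(process_eye, right_frames, "right")
        # Waiting on both results propagates any per-eye exception
        return left.result(), right.result()
```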
#### Category B: Stereo Consistency Fixes (Critical for VR)
1. **Master-Slave Eye Processing**
   - Issue: Independent detection leads to mismatched person counts between eyes
   - Solution: Use left eye detections as "seeds" for right eye processing
   - Impact: Ensures identical person detection across stereo pair
   - Risk: Low, maintains current quality while improving consistency

2. **Cross-Eye Detection Validation**
   - Issue: Hair/clothing included on one eye but not the other
   - Solution: Compare detection results, flag inconsistencies for reprocessing
   - Impact: 90%+ stereo alignment improvement
   - Risk: Low, fallback to current behavior

3. **Disparity-Aware Segmentation**
   - Issue: Segmentation boundaries differ between eyes despite same person
   - Solution: Use stereo disparity to correlate features between eyes
   - Impact: True stereo-consistent matting
   - Risk: High, complex implementation

4. **Joint Stereo Detection** (see the sketch after this list)
   - Issue: YOLO runs independently on each eye
   - Solution: Run YOLO on full SBS frame, split detections spatially
   - Impact: Guaranteed identical detection counts
   - Risk: Medium, requires detection coordinate mapping
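A hedged sketch of the coordinate-splitting step behind item B.4. It assumes the detector returns person boxes as `(x1, y1, x2, y2)` pixel tuples on the full side-by-side frame; the box format and midline rule are illustrative, not this pipeline's confirmed behavior:

```python
def split_sbs_detections(boxes, sbs_width):
    """Assign each detection to the eye whose half contains its center,
    remapping right-eye boxes into single-eye coordinates."""
    half = sbs_width // 2
    left_eye, right_eye = [], []
    for (x1, y1, x2, y2) in boxes:
        center_x = (x1 + x2) / 2
        if center_x < half:
            # Clamp boxes that bleed across the seam (crude; illustrative)
            left_eye.append((x1, y1, min(x2, half), y2))
        else:
            right_eye.append((x1 - half, y1, x2 - half, y2))
    return left_eye, right_eye
```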
#### Category C: Advanced Optimizations (Future)
1. **Adaptive Memory Management** (see the sketch after this list)
   - Opportunity: Dynamic chunk sizing based on real-time VRAM usage
   - Impact: Optimal resource utilization across different hardware
   - Risk: Medium, complex heuristics

2. **Multi-Resolution Processing**
   - Opportunity: Initial processing at lower resolution, edge refinement at full resolution
   - Impact: Speed improvement while maintaining quality
   - Risk: Medium, quality validation required

3. **Enhanced Workflow Documentation**
   - Issue: Unclear intermediate data lifecycle
   - Solution: Detailed logging of chunk processing, optional intermediate preservation
   - Impact: Better debugging and user understanding
   - Risk: Low, documentation feature
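One possible shape for item C.1, sketched under assumptions: `torch.cuda.mem_get_info` returns `(free, total)` bytes for the current device, while the per-frame VRAM cost and the clamping bounds are tunable guesses, not measured constants from this codebase:

```python
import torch

def adaptive_chunk_size(per_frame_vram_mb=25.0, safety_fraction=0.7,
                        min_frames=100, max_frames=800):
    """Pick a chunk size from currently free VRAM (heuristic sketch)."""
    if not torch.cuda.is_available():
        return min_frames
    free_bytes, _total = torch.cuda.mem_get_info()
    budget_mb = (free_bytes / 1024**2) * safety_fraction
    frames = int(budget_mb / per_frame_vram_mb)
    return max(min_frames, min(frames, max_frames))
```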
### Implementation Strategy
- **Phase A**: Quick performance wins (larger chunks, profiling)
- **Phase B**: Stereo consistency (master-slave, validation)
- **Phase C**: Advanced features (disparity-aware, memory optimization)

### Configuration Extensions Required
```yaml
processing:
  chunk_size: 600          # Increase from 200 for high-VRAM systems
  memory_pipeline: false   # Skip intermediate video creation (disabled due to RAM limits)
  streaming_output: true   # Write chunks progressively instead of accumulating
  parallel_eyes: false     # Process eyes simultaneously
  max_memory_gb: 40        # Realistic RAM limit for RunPod containers

audio:
  preserve_audio: true     # Copy audio track from input to output
  verify_sync: true        # Validate frame count and duration matching
  audio_codec: "copy"      # Preserve original audio codec

stereo:
  consistency_mode: "master_slave"  # "independent", "master_slave", "joint"
  validation_threshold: 0.8         # Similarity threshold between eyes
  correction_method: "transfer"     # "transfer", "reprocess", "ensemble"

performance:
  profile_enabled: true          # Detailed timing analysis
  preserve_intermediates: false  # For debugging workflow

debugging:
  log_intermediate_workflow: true      # Document chunk lifecycle
  save_detection_visualization: false  # Debug detection mismatches
  frame_count_validation: true         # Ensure exact frame preservation
```

### Technical Implementation Details

#### Audio Preservation Implementation
```python
# During final video save, include audio stream copy
ffmpeg_cmd = [
    'ffmpeg', '-y',
    '-framerate', str(fps),
    '-i', frame_pattern,     # Video frames
    '-i', input_video_path,  # Original video for audio
    '-c:v', 'h264_nvenc',    # GPU video codec (with CPU fallback)
    '-c:a', 'copy',          # Copy audio without re-encoding
    '-map', '0:v:0',         # Map video from first input
    '-map', '1:a:0',         # Map audio from second input
    '-shortest',             # Match shortest stream duration
    output_path
]
```

#### Streaming Output Implementation
```python
# Instead of accumulating frames in memory:
class StreamingVideoWriter:
    def __init__(self, output_path, fps, audio_source):
        self.temp_segments = []
        self.current_segment = 0

    def write_chunk(self, processed_frames):
        # Write chunk to temporary segment
        segment_path = f"temp_segment_{self.current_segment}.mp4"
        self.write_video_segment(processed_frames, segment_path)
        self.temp_segments.append(segment_path)
        self.current_segment += 1

    def finalize(self):
        # Merge all segments with audio preservation
        self.merge_segments_with_audio()
```
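The `finalize` step above leaves `merge_segments_with_audio` unspecified. A minimal sketch using ffmpeg's concat demuxer; the segment list file and `-c copy` remux are assumptions about how the merge could work, not the pipeline's confirmed implementation:

```python
import subprocess

def merge_segments_with_audio(temp_segments, audio_source, output_path):
    # The concat demuxer reads segment paths from a list file, in order
    with open("segments.txt", "w") as f:
        for seg in temp_segments:
            f.write(f"file '{seg}'\n")
    subprocess.run([
        'ffmpeg', '-y',
        '-f', 'concat', '-safe', '0', '-i', 'segments.txt',  # merged video
        '-i', audio_source,                                  # original audio
        '-map', '0:v:0', '-map', '1:a:0',
        '-c:v', 'copy', '-c:a', 'copy',                      # no re-encode
        '-shortest',
        output_path
    ], check=True)
```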

#### Memory Usage Calculation
```python
def estimate_memory_requirements(duration_seconds, fps, resolution_scale=0.5):
    """Calculate memory usage for different video lengths"""
    frames = duration_seconds * fps

    # Per-frame memory (rough estimates for VR180 at 50% scale)
    frame_size_mb = (3072 * 1536 * 3 * 4) / (1024 * 1024)  # ~18MB per frame

    total_memory_gb = (frames * frame_size_mb) / 1024

    return {
        'duration': duration_seconds,
        'total_frames': frames,
        'estimated_memory_gb': total_memory_gb,
        'safe_for_48gb': total_memory_gb < 40
    }

# Example outputs:
# 30 seconds: ~2.7GB (safe)
# 5 minutes: ~27GB (borderline)
# 1 hour: ~324GB (requires streaming)
```

## Success Criteria

### Technical Feasibility
test_inter_chunk_cleanup.py (Normal file, 148 lines)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""

import psutil
import gc
import sys
from pathlib import Path

def get_memory_usage():
    """Get current memory usage in GB"""
    process = psutil.Process()
    return process.memory_info().rss / (1024**3)

def test_inter_chunk_cleanup(config_path):
    """Test that models are properly destroyed between chunks"""

    print("🧪 TESTING INTER-CHUNK CLEANUP")
    print("=" * 50)

    baseline_memory = get_memory_usage()
    print(f"📊 Baseline memory: {baseline_memory:.2f} GB")

    # Import and create processor
    print("\n1️⃣ Creating processor...")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor

    # Use the path validated in main(), not a hardcoded 'config.yaml'
    config = VR180Config.from_yaml(config_path)
    processor = VR180Processor(config)

    init_memory = get_memory_usage()
    print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")

    # Simulate chunk processing (just trigger model loading)
    print("\n2️⃣ Simulating chunk 0 processing...")

    # Test 1: Force YOLO model loading
    try:
        detector = processor.detector
        detector._load_model()  # Force load
        yolo_memory = get_memory_usage()
        print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ YOLO loading failed: {e}")
        yolo_memory = init_memory

    # Test 2: Force SAM2 model loading
    try:
        sam2_model = processor.sam2_model
        sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
        sam2_memory = get_memory_usage()
        print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ SAM2 loading failed: {e}")
        sam2_memory = yolo_memory

    total_model_memory = sam2_memory - init_memory
    print(f"📊 Total model memory: {total_model_memory:.2f} GB")

    # Test 3: Inter-chunk cleanup
    print("\n3️⃣ Testing inter-chunk cleanup...")
    processor._complete_inter_chunk_cleanup(chunk_idx=0)

    cleanup_memory = get_memory_usage()
    cleanup_improvement = sam2_memory - cleanup_memory
    print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")

    # Test 4: Verify models reload fresh
    print("\n4️⃣ Testing fresh model reload...")

    # Check YOLO state
    yolo_reloaded = processor.detector.model is None
    print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")

    # Check SAM2 state
    sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
    print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")

    # Test 5: Force reload to verify they work
    print("\n5️⃣ Testing model reload...")
    try:
        # Force YOLO reload
        processor.detector._load_model()
        yolo_reload_memory = get_memory_usage()

        # Force SAM2 reload
        processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
        sam2_reload_memory = get_memory_usage()

        reload_growth = sam2_reload_memory - cleanup_memory
        print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")

        if abs(reload_growth - total_model_memory) < 1.0:  # Within 1GB
            print("✅ Models reloaded with similar memory usage (good)")
        else:
            print("⚠️ Model reload memory differs significantly")

    except Exception as e:
        print(f"❌ Model reload failed: {e}")

    # Final summary
    print(f"\n📊 SUMMARY:")
    print(f"   Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
    print(f"   Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
    print(f"   Memory freed: {cleanup_improvement:.2f}GB")
    print(f"   Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")

    # Success criteria: Both models destroyed AND can reload
    models_destroyed = yolo_reloaded and sam2_reloaded
    can_reload = 'reload_growth' in locals()

    if models_destroyed and can_reload:
        print("✅ Inter-chunk cleanup working effectively")
        print("💡 Models destroyed and can reload fresh (memory will be freed during real processing)")
        return True
    elif models_destroyed:
        print("⚠️ Models destroyed but reload test incomplete")
        print("💡 This should still prevent accumulation during real processing")
        return True
    else:
        print("❌ Inter-chunk cleanup not freeing enough memory")
        return False

def main():
    if len(sys.argv) != 2:
        print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    success = test_inter_chunk_cleanup(config_path)

    if success:
        print(f"\n🎉 SUCCESS: Inter-chunk cleanup is working!")
        print(f"💡 This should prevent 15-20GB model accumulation between chunks")
    else:
        print(f"\n❌ FAILURE: Inter-chunk cleanup needs improvement")
        print(f"💡 Check model destruction logic in _complete_inter_chunk_cleanup")

    return 0 if success else 1

if __name__ == "__main__":
    sys.exit(main())
vr180_matting/checkpoint_manager.py (Normal file, 220 lines)
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Checkpoint manager for resumable video processing
|
||||||
|
Saves progress to avoid reprocessing after OOM or crashes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointManager:
|
||||||
|
"""Manages processing checkpoints for resumable execution"""
|
||||||
|
|
||||||
|
def __init__(self, video_path: str, output_path: str, checkpoint_dir: Optional[Path] = None):
|
||||||
|
"""
|
||||||
|
Initialize checkpoint manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Input video path
|
||||||
|
output_path: Output video path
|
||||||
|
checkpoint_dir: Directory for checkpoint files (default: .vr180_checkpoints in CWD)
|
||||||
|
"""
|
||||||
|
self.video_path = Path(video_path)
|
||||||
|
self.output_path = Path(output_path)
|
||||||
|
|
||||||
|
# Create unique checkpoint ID based on video file
|
||||||
|
self.video_hash = self._compute_video_hash()
|
||||||
|
|
||||||
|
# Setup checkpoint directory
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
self.checkpoint_dir = Path.cwd() / ".vr180_checkpoints" / self.video_hash
|
||||||
|
else:
|
||||||
|
self.checkpoint_dir = Path(checkpoint_dir) / self.video_hash
|
||||||
|
|
||||||
|
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Checkpoint files
|
||||||
|
self.status_file = self.checkpoint_dir / "processing_status.json"
|
||||||
|
self.chunks_dir = self.checkpoint_dir / "chunks"
|
||||||
|
self.chunks_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Load existing status or create new
|
||||||
|
self.status = self._load_status()
|
||||||
|
|
||||||
|
def _compute_video_hash(self) -> str:
|
||||||
|
"""Compute hash of video file for unique identification"""
|
||||||
|
# Use file path, size, and modification time for quick hash
|
||||||
|
stat = self.video_path.stat()
|
||||||
|
hash_str = f"{self.video_path}_{stat.st_size}_{stat.st_mtime}"
|
||||||
|
return hashlib.md5(hash_str.encode()).hexdigest()[:12]
|
||||||
|
|
||||||
|
def _load_status(self) -> Dict[str, Any]:
|
||||||
|
"""Load processing status from checkpoint file"""
|
||||||
|
if self.status_file.exists():
|
||||||
|
with open(self.status_file, 'r') as f:
|
||||||
|
status = json.load(f)
|
||||||
|
print(f"📋 Loaded checkpoint: {status['completed_chunks']}/{status['total_chunks']} chunks completed")
|
||||||
|
return status
|
||||||
|
else:
|
||||||
|
# Create new status
|
||||||
|
return {
|
||||||
|
'video_path': str(self.video_path),
|
||||||
|
'output_path': str(self.output_path),
|
||||||
|
'video_hash': self.video_hash,
|
||||||
|
'start_time': datetime.now().isoformat(),
|
||||||
|
'total_chunks': 0,
|
||||||
|
'completed_chunks': 0,
|
||||||
|
'chunk_info': {},
|
||||||
|
'processing_complete': False,
|
||||||
|
'merge_complete': False
|
||||||
|
}
|
||||||
|
|
||||||
|
def _save_status(self):
|
||||||
|
"""Save current status to checkpoint file"""
|
||||||
|
self.status['last_update'] = datetime.now().isoformat()
|
||||||
|
with open(self.status_file, 'w') as f:
|
||||||
|
json.dump(self.status, f, indent=2)
|
||||||
|
|
||||||
|
def set_total_chunks(self, total_chunks: int):
|
||||||
|
"""Set total number of chunks to process"""
|
||||||
|
self.status['total_chunks'] = total_chunks
|
||||||
|
self._save_status()
|
||||||
|
|
||||||
|
def is_chunk_completed(self, chunk_idx: int) -> bool:
|
||||||
|
"""Check if a chunk has already been processed"""
|
||||||
|
chunk_key = f"chunk_{chunk_idx}"
|
||||||
|
return chunk_key in self.status['chunk_info'] and \
|
||||||
|
self.status['chunk_info'][chunk_key].get('completed', False)
|
||||||
|
|
||||||
|
    def get_chunk_file(self, chunk_idx: int) -> Optional[Path]:
        """Get saved chunk file path if it exists"""
        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
        if chunk_file.exists() and self.is_chunk_completed(chunk_idx):
            return chunk_file
        return None

    def save_chunk(self, chunk_idx: int, frames: List, source_chunk_path: Optional[Path] = None):
        """
        Save processed chunk and mark as completed

        Args:
            chunk_idx: Chunk index
            frames: Processed frames (can be None if using source_chunk_path)
            source_chunk_path: If provided, copy this file instead of saving frames
        """
        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"

        try:
            if source_chunk_path and source_chunk_path.exists():
                # Copy existing chunk file
                shutil.copy2(source_chunk_path, chunk_file)
                print(f"💾 Copied chunk {chunk_idx} to checkpoint: {chunk_file.name}")
            elif frames is not None:
                # Save new frames
                import numpy as np
                np.savez_compressed(str(chunk_file), frames=frames)
                print(f"💾 Saved chunk {chunk_idx} to checkpoint: {chunk_file.name}")
            else:
                raise ValueError("Either frames or source_chunk_path must be provided")

            # Update status
            chunk_key = f"chunk_{chunk_idx}"
            self.status['chunk_info'][chunk_key] = {
                'completed': True,
                'file': chunk_file.name,
                'timestamp': datetime.now().isoformat()
            }
            self.status['completed_chunks'] = len([c for c in self.status['chunk_info'].values() if c['completed']])
            self._save_status()

            print(f"✅ Chunk {chunk_idx} checkpoint saved ({self.status['completed_chunks']}/{self.status['total_chunks']})")

        except Exception as e:
            print(f"❌ Failed to save chunk {chunk_idx} checkpoint: {e}")

    def get_completed_chunk_files(self) -> List[Path]:
        """Get list of all completed chunk files in order"""
        chunk_files = []
        missing_chunks = []

        for i in range(self.status['total_chunks']):
            chunk_file = self.get_chunk_file(i)
            if chunk_file:
                chunk_files.append(chunk_file)
            else:
                # Check if chunk is marked as completed but file is missing
                if self.is_chunk_completed(i):
                    missing_chunks.append(i)
                    print(f"⚠️ Chunk {i} marked complete but file missing!")
                else:
                    break  # Stop at first unprocessed chunk

        if missing_chunks:
            print(f"❌ Missing checkpoint files for chunks: {missing_chunks}")
            print(f"   This may happen if files were deleted during streaming merge")
            print(f"   These chunks may need to be reprocessed")

        return chunk_files

    def mark_processing_complete(self):
        """Mark all chunk processing as complete"""
        self.status['processing_complete'] = True
        self._save_status()
        print(f"✅ All chunks processed and checkpointed")

    def mark_merge_complete(self):
        """Mark final merge as complete"""
        self.status['merge_complete'] = True
        self._save_status()
        print(f"✅ Video merge completed")

    def cleanup_checkpoints(self, keep_chunks: bool = False):
        """
        Clean up checkpoint files after successful completion

        Args:
            keep_chunks: If True, keep chunk files but remove status
        """
        if keep_chunks:
            # Just remove status file
            if self.status_file.exists():
                self.status_file.unlink()
                print(f"🗑️ Removed checkpoint status file")
        else:
            # Remove entire checkpoint directory
            if self.checkpoint_dir.exists():
                shutil.rmtree(self.checkpoint_dir)
                print(f"🗑️ Removed all checkpoint files: {self.checkpoint_dir}")

    def get_resume_info(self) -> Dict[str, Any]:
        """Get information about what can be resumed"""
        return {
            'can_resume': self.status['completed_chunks'] > 0,
            'completed_chunks': self.status['completed_chunks'],
            'total_chunks': self.status['total_chunks'],
            'processing_complete': self.status['processing_complete'],
            'merge_complete': self.status['merge_complete'],
            'checkpoint_dir': str(self.checkpoint_dir)
        }

    def print_status(self):
        """Print current checkpoint status"""
        print(f"\n📊 CHECKPOINT STATUS:")
        print(f"   Video: {self.video_path.name}")
        print(f"   Hash: {self.video_hash}")
        print(f"   Progress: {self.status['completed_chunks']}/{self.status['total_chunks']} chunks")
        print(f"   Processing complete: {self.status['processing_complete']}")
        print(f"   Merge complete: {self.status['merge_complete']}")
        print(f"   Checkpoint dir: {self.checkpoint_dir}")

        if self.status['completed_chunks'] > 0:
            print(f"\n   Completed chunks:")
            for i in range(self.status['completed_chunks']):
                chunk_info = self.status['chunk_info'].get(f'chunk_{i}', {})
                if chunk_info.get('completed'):
                    print(f"      ✓ Chunk {i}: {chunk_info.get('file', 'unknown')}")
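The resume decision in get_resume_info() reduces to a threshold on completed_chunks. A standalone sketch of that bookkeeping (plain Python; the status values are invented for illustration):

# Sketch only - mirrors the status dictionary kept by CheckpointManager
status = {'completed_chunks': 3, 'total_chunks': 8,
          'processing_complete': False, 'merge_complete': False}

can_resume = status['completed_chunks'] > 0
if can_resume:
    remaining = status['total_chunks'] - status['completed_chunks']
    print(f"Resume possible: {status['completed_chunks']} chunks done, {remaining} remaining")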
@@ -29,6 +29,11 @@ class MattingConfig:
     fp16: bool = True
     sam2_model_cfg: str = "sam2.1_hiera_l"
     sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
+    # Det-SAM2 optimizations
+    continuous_correction: bool = True
+    correction_interval: int = 60  # Add correction prompts every N frames
+    frame_release_interval: int = 50  # Release old frames every N frames
+    frame_window_size: int = 30  # Keep N frames in memory
 
 
 @dataclass
@@ -37,6 +42,8 @@ class OutputConfig:
     format: str = "alpha"
     background_color: List[int] = None
     maintain_sbs: bool = True
+    preserve_audio: bool = True
+    verify_sync: bool = True
 
     def __post_init__(self):
         if self.background_color is None:
@@ -99,7 +106,9 @@ class VR180Config:
                 'path': self.output.path,
                 'format': self.output.format,
                 'background_color': self.output.background_color,
-                'maintain_sbs': self.output.maintain_sbs
+                'maintain_sbs': self.output.maintain_sbs,
+                'preserve_audio': self.output.preserve_audio,
+                'verify_sync': self.output.verify_sync
             },
             'hardware': {
                 'device': self.hardware.device,
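The five Det-SAM2 knobs added to MattingConfig travel together. A standalone mirror of just those fields (names and defaults copied from the hunk above; the wrapper dataclass itself is hypothetical):

from dataclasses import dataclass

# Hypothetical standalone mirror of the new MattingConfig fields
@dataclass
class DetSAM2Options:
    continuous_correction: bool = True
    correction_interval: int = 60      # correction prompts every N frames
    frame_release_interval: int = 50   # release old frames every N frames
    frame_window_size: int = 30        # frames kept in memory

opts = DetSAM2Options(correction_interval=30)  # tighter re-anchoring
print(opts)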
@@ -1,6 +1,4 @@
-import torch
 import numpy as np
-from ultralytics import YOLO
 from typing import List, Tuple, Dict, Any
 import cv2
 
@@ -13,14 +11,23 @@ class YOLODetector:
         self.confidence_threshold = confidence_threshold
         self.device = device
         self.model = None
-        self._load_model()
+        # Don't load model during init - load lazily when first used
 
     def _load_model(self):
-        """Load YOLOv8 model"""
+        """Load YOLOv8 model lazily"""
+        if self.model is not None:
+            return  # Already loaded
+
         try:
+            # Import heavy dependencies only when needed
+            import torch
+            from ultralytics import YOLO
+
             self.model = YOLO(f"{self.model_name}.pt")
             if self.device == "cuda" and torch.cuda.is_available():
                 self.model.to("cuda")
+
+            print(f"🎯 Loaded YOLO model: {self.model_name}")
         except Exception as e:
             raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
 
@@ -34,8 +41,9 @@ class YOLODetector:
         Returns:
             List of detection dictionaries with bbox, confidence, and class info
         """
+        # Load model lazily on first use
        if self.model is None:
-            raise RuntimeError("YOLO model not loaded")
+            self._load_model()
 
         results = self.model(frame, verbose=False)
         detections = []
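With these hunks, constructing YOLODetector is cheap; torch and ultralytics are imported only on the first detect call. The same lazy-load idiom in a runnable toy (json stands in for the heavy dependency; the class name is made up):

import importlib

class LazyModel:
    """Defer a heavy import/load until the model is actually used."""
    def __init__(self, name: str):
        self.name = name
        self.model = None  # nothing loaded at construction time

    def _load(self):
        if self.model is not None:
            return  # already loaded
        # stand-in for a heavy dependency such as ultralytics
        self.model = importlib.import_module("json")

    def __call__(self, data):
        self._load()  # first call pays the load cost
        return self.model.dumps(data)

print(LazyModel("demo")({"ok": True}))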
@@ -5,13 +5,20 @@ import cv2
 from pathlib import Path
 import warnings
 import os
+import tempfile
+import shutil
+import gc
+
+# Check SAM2 availability without importing heavy modules
+def _check_sam2_available():
+    try:
+        import sam2
+        return True
+    except ImportError:
+        return False
 
-try:
-    from sam2.build_sam import build_sam2_video_predictor
-    from sam2.sam2_image_predictor import SAM2ImagePredictor
-    SAM2_AVAILABLE = True
-except ImportError:
-    SAM2_AVAILABLE = False
+SAM2_AVAILABLE = _check_sam2_available()
+if not SAM2_AVAILABLE:
     warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -30,15 +37,25 @@ class SAM2VideoMatting:
         self.device = device
         self.memory_offload = memory_offload
         self.fp16 = fp16
+        self.model_cfg = model_cfg
+        self.checkpoint_path = checkpoint_path
         self.predictor = None
         self.inference_state = None
         self.video_segments = {}
+        self.temp_video_path = None
 
-        self._load_model(model_cfg, checkpoint_path)
+        # Don't load model during init - load lazily when needed
+        self._model_loaded = False
 
     def _load_model(self, model_cfg: str, checkpoint_path: str):
-        """Load SAM2 video predictor with optimizations"""
+        """Load SAM2 video predictor lazily"""
+        if self._model_loaded and self.predictor is not None:
+            return  # Already loaded and predictor exists
+
         try:
+            # Import heavy SAM2 modules only when needed
+            from sam2.build_sam import build_sam2_video_predictor
+
             # Check for checkpoint in SAM2 repo structure
             if not Path(checkpoint_path).exists():
                 # Try in segment-anything-2/checkpoints/
@@ -57,36 +74,63 @@ class SAM2VideoMatting:
                 if sam2_repo_path.exists():
                     checkpoint_path = str(sam2_repo_path)
 
-            # Use the config path as-is (should be relative to SAM2 package)
-            # Example: "configs/sam2.1/sam2.1_hiera_l.yaml"
+            print(f"🎯 Loading SAM2 model: {model_cfg}")
+            # Use SAM2's build_sam2_video_predictor which returns the predictor directly
+            # The predictor IS the model - no .model attribute needed
             self.predictor = build_sam2_video_predictor(
-                model_cfg,
-                checkpoint_path,
+                config_file=model_cfg,
+                ckpt_path=checkpoint_path,
                 device=self.device
             )
 
-            # Enable memory optimizations
-            if self.memory_offload:
-                self.predictor.fill_hole_area = 8
-
-            if self.fp16 and self.device == "cuda":
-                self.predictor.model.half()
+            self._model_loaded = True
+            print(f"✅ SAM2 model loaded successfully")
 
         except Exception as e:
             raise RuntimeError(f"Failed to load SAM2 model: {e}")
 
-    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
+    def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
         """Initialize video inference state"""
-        if self.predictor is None:
-            raise RuntimeError("SAM2 model not loaded")
+        # Load model lazily on first use
+        if not self._model_loaded:
+            self._load_model(self.model_cfg, self.checkpoint_path)
 
-        # Create temporary directory for frames if needed
-        self.inference_state = self.predictor.init_state(
-            video_path=None,
-            video_frames=video_frames,
-            offload_video_to_cpu=self.memory_offload,
-            async_loading_frames=True
-        )
+        if video_path is not None:
+            # Use video path directly (SAM2's preferred method)
+            self.inference_state = self.predictor.init_state(
+                video_path=video_path,
+                offload_video_to_cpu=self.memory_offload,
+                async_loading_frames=True
+            )
+        else:
+            # For frame arrays, we need to save them as a temporary video first
+            if video_frames is None or len(video_frames) == 0:
+                raise ValueError("Either video_path or video_frames must be provided")
+
+            # Create temporary video file in current directory
+            import uuid
+            temp_video_name = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"
+            temp_video_path = Path.cwd() / temp_video_name
+
+            # Write frames to temporary video
+            height, width = video_frames[0].shape[:2]
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            writer = cv2.VideoWriter(str(temp_video_path), fourcc, 30.0, (width, height))
+
+            for frame in video_frames:
+                writer.write(frame)
+            writer.release()
+
+            # Initialize with temporary video
+            self.inference_state = self.predictor.init_state(
+                video_path=str(temp_video_path),
+                offload_video_to_cpu=self.memory_offload,
+                async_loading_frames=True
+            )
+
+            # Store temp path for cleanup
+            self.temp_video_path = temp_video_path
 
     def add_person_prompts(self,
                            frame_idx: int,
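Note that the video_frames branch above round-trips arrays through a lossy mp4v file at a hard-coded 30 fps. A standalone reproduction of that round trip, assuming only cv2 and numpy and synthetic frames:

import os
import uuid
import cv2
import numpy as np

# Sketch only - same write-then-read path as init_video_state's frames branch
frames = [np.random.randint(0, 255, (120, 160, 3), np.uint8) for _ in range(10)]
path = f"temp_sam2_{uuid.uuid4().hex[:8]}.mp4"

h, w = frames[0].shape[:2]
writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 30.0, (w, h))
for f in frames:
    writer.write(f)
writer.release()

cap = cv2.VideoCapture(path)
print("frames written back:", int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
cap.release()
os.remove(path)  # cleanup, as the class does in cleanup()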
@@ -123,13 +167,16 @@ class SAM2VideoMatting:
 
         return object_ids
 
-    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
+    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None,
+                        frame_release_interval: int = 50, frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
         """
-        Propagate masks through video
+        Propagate masks through video with Det-SAM2 style memory management
 
         Args:
             start_frame: Starting frame index
             max_frames: Maximum number of frames to process
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory
 
         Returns:
             Dictionary mapping frame_idx -> {obj_id: mask}
@@ -153,9 +200,108 @@ class SAM2VideoMatting:
 
                 video_segments[out_frame_idx] = frame_masks
 
-                # Memory management: release old frames periodically
-                if self.memory_offload and out_frame_idx % 100 == 0:
-                    self._release_old_frames(out_frame_idx - 50)
+                # Det-SAM2 style memory management: more aggressive frame release
+                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                    self._release_old_frames(out_frame_idx - frame_window_size)
+                    # Optional: Log frame release for monitoring
+                    if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        return video_segments
+
+    def propagate_masks_with_continuous_correction(self,
+                                                   detector,
+                                                   temp_video_path: str,
+                                                   start_frame: int = 0,
+                                                   max_frames: Optional[int] = None,
+                                                   correction_interval: int = 60,
+                                                   frame_release_interval: int = 50,
+                                                   frame_window_size: int = 30) -> Dict[int, Dict[int, np.ndarray]]:
+        """
+        Det-SAM2 style: Propagate masks with continuous prompt correction
+
+        Args:
+            detector: YOLODetector instance for generating correction prompts
+            temp_video_path: Path to video file for frame access
+            start_frame: Starting frame index
+            max_frames: Maximum number of frames to process
+            correction_interval: Add correction prompts every N frames
+            frame_release_interval: Release old frames every N frames
+            frame_window_size: Keep N frames in memory
+
+        Returns:
+            Dictionary mapping frame_idx -> {obj_id: mask}
+        """
+        if self.inference_state is None:
+            raise RuntimeError("Video state not initialized")
+
+        video_segments = {}
+        max_frames = max_frames or 10000  # Default limit
+
+        # Open video for accessing frames during propagation
+        cap = cv2.VideoCapture(str(temp_video_path))
+
+        try:
+            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
+                self.inference_state,
+                start_frame_idx=start_frame,
+                max_frame_num_to_track=max_frames,
+                reverse=False
+            ):
+                frame_masks = {}
+
+                for i, out_obj_id in enumerate(out_obj_ids):
+                    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
+                    frame_masks[out_obj_id] = mask
+
+                video_segments[out_frame_idx] = frame_masks
+
+                # Det-SAM2 optimization: Add correction prompts at keyframes
+                if (out_frame_idx % correction_interval == 0 and
+                        out_frame_idx > start_frame and
+                        out_frame_idx < max_frames - 1):
+
+                    # Read frame for detection
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, out_frame_idx)
+                    ret, correction_frame = cap.read()
+
+                    if ret:
+                        # Run detection on this keyframe
+                        detections = detector.detect_persons(correction_frame)
+
+                        if detections:
+                            # Convert to prompts and add as corrections
+                            box_prompts, labels = detector.convert_to_sam_prompts(detections)
+
+                            # Add correction prompts (SAM2 will propagate backward)
+                            correction_count = 0
+                            try:
+                                for i, (box, label) in enumerate(zip(box_prompts, labels)):
+                                    # Use existing object IDs if available, otherwise create new ones
+                                    obj_id = out_obj_ids[i] if i < len(out_obj_ids) else len(out_obj_ids) + i + 1
+
+                                    self.predictor.add_new_points_or_box(
+                                        inference_state=self.inference_state,
+                                        frame_idx=out_frame_idx,
+                                        obj_id=obj_id,
+                                        box=box,
+                                    )
+                                    correction_count += 1
+
+                                print(f"Det-SAM2: Added {correction_count} correction prompts at frame {out_frame_idx}")
+
+                            except Exception as e:
+                                warnings.warn(f"Failed to add correction prompt at frame {out_frame_idx}: {e}")
+
+                # Memory management: More aggressive frame release (Det-SAM2 style)
+                if self.memory_offload and out_frame_idx % frame_release_interval == 0:
+                    self._release_old_frames(out_frame_idx - frame_window_size)
+                    # Optional: Log frame release for monitoring
+                    if out_frame_idx % (frame_release_interval * 4) == 0:  # Log every 4x release interval
+                        print(f"Det-SAM2: Released frames before {out_frame_idx - frame_window_size}, keeping {frame_window_size} frames")
+
+        finally:
+            cap.release()
+
         return video_segments
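How often the tracker is re-anchored, and how much of the frame history survives, follows from correction_interval, frame_release_interval, and frame_window_size. A pure-arithmetic sketch of the resulting schedule (no SAM2 needed; parameter names match the method above):

def detsam2_schedule(n_frames, correction_interval=60,
                     frame_release_interval=50, frame_window_size=30):
    """Report which frame indices trigger corrections and frame releases."""
    events = []
    for idx in range(n_frames):
        if idx > 0 and idx % correction_interval == 0:
            events.append((idx, "correction prompt"))
        if idx > 0 and idx % frame_release_interval == 0:
            events.append((idx, f"release frames < {idx - frame_window_size}"))
    return events

for frame, what in detsam2_schedule(200):
    print(frame, what)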
@@ -231,17 +377,58 @@ class SAM2VideoMatting:
         """Clean up resources"""
         if self.inference_state is not None:
             try:
-                if hasattr(self.predictor, 'cleanup_state'):
+                # Reset SAM2 state first (critical for memory cleanup)
+                if self.predictor is not None and hasattr(self.predictor, 'reset_state'):
+                    self.predictor.reset_state(self.inference_state)
+
+                # Fallback to cleanup_state if available
+                elif self.predictor is not None and hasattr(self.predictor, 'cleanup_state'):
                     self.predictor.cleanup_state(self.inference_state)
+
+                # Explicitly delete inference state and video segments
+                del self.inference_state
+                if hasattr(self, 'video_segments') and self.video_segments:
+                    del self.video_segments
+                    self.video_segments = {}
+
             except Exception as e:
                 warnings.warn(f"Failed to cleanup SAM2 state: {e}")
+            finally:
                 self.inference_state = None
 
+        # Clean up temporary video file
+        if self.temp_video_path is not None:
+            try:
+                if self.temp_video_path.exists():
+                    # Remove the temporary video file
+                    self.temp_video_path.unlink()
+                self.temp_video_path = None
+            except Exception as e:
+                warnings.warn(f"Failed to cleanup temp video: {e}")
+
         # Clear CUDA cache
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
+        # Explicitly delete predictor for fresh creation next time
+        if self.predictor is not None:
+            try:
+                del self.predictor
+            except Exception as e:
+                warnings.warn(f"Failed to delete predictor: {e}")
+            finally:
+                self.predictor = None
+
+        # Reset model loaded state for fresh reload
+        self._model_loaded = False
+
+        # Force garbage collection (critical for memory leak prevention)
+        gc.collect()
+
     def __del__(self):
         """Destructor to ensure cleanup"""
-        self.cleanup()
+        try:
+            self.cleanup()
+        except Exception:
+            # Ignore errors during Python shutdown
+            pass
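The rewritten cleanup layers four defenses: predictor state reset, explicit deletion, a CUDA cache flush, and gc.collect(). The same guarded-teardown idiom in standalone form (torch is optional here; the import guard is the point):

import gc

def teardown(obj_holder: dict):
    """Drop held references, flush GPU cache if torch is present, then collect."""
    obj_holder.clear()  # stands in for the explicit del of predictor/state
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass  # CPU-only environment: nothing to flush
    gc.collect()

state = {"predictor": object(), "inference_state": object()}
teardown(state)
print("released:", state)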
@@ -7,6 +7,12 @@ import tempfile
 import shutil
 from tqdm import tqdm
 import warnings
+import time
+import subprocess
+import gc
+import psutil
+import os
+import sys
 
 from .config import VR180Config
 from .detector import YOLODetector
@@ -35,8 +41,137 @@ class VideoProcessor:
         self.frame_width = 0
         self.frame_height = 0
 
+        # Processing statistics
+        self.processing_stats = {
+            'start_time': None,
+            'end_time': None,
+            'total_duration': 0,
+            'processing_fps': 0,
+            'chunks_processed': 0,
+            'frames_processed': 0
+        }
+
         self._initialize_models()
 
+    def _get_process_memory_info(self) -> Dict[str, float]:
+        """Get detailed memory usage for current process and children"""
+        current_process = psutil.Process(os.getpid())
+
+        # Get memory info for current process
+        memory_info = current_process.memory_info()
+        current_rss = memory_info.rss / 1024**3  # Convert to GB
+        current_vms = memory_info.vms / 1024**3  # Virtual memory
+
+        # Get memory info for all children
+        children_rss = 0
+        children_vms = 0
+        child_count = 0
+
+        try:
+            for child in current_process.children(recursive=True):
+                try:
+                    child_memory = child.memory_info()
+                    children_rss += child_memory.rss / 1024**3
+                    children_vms += child_memory.vms / 1024**3
+                    child_count += 1
+                except (psutil.NoSuchProcess, psutil.AccessDenied):
+                    pass
+        except psutil.NoSuchProcess:
+            pass
+
+        # System memory info
+        system_memory = psutil.virtual_memory()
+        system_total = system_memory.total / 1024**3
+        system_available = system_memory.available / 1024**3
+        system_used = system_memory.used / 1024**3
+        system_percent = system_memory.percent
+
+        return {
+            'process_rss_gb': current_rss,
+            'process_vms_gb': current_vms,
+            'children_rss_gb': children_rss,
+            'children_vms_gb': children_vms,
+            'total_process_gb': current_rss + children_rss,
+            'child_count': child_count,
+            'system_total_gb': system_total,
+            'system_used_gb': system_used,
+            'system_available_gb': system_available,
+            'system_percent': system_percent
+        }
+
+    def _print_memory_step(self, step_name: str):
+        """Print memory usage for a specific processing step"""
+        memory_info = self._get_process_memory_info()
+
+        print(f"\n📊 MEMORY: {step_name}")
+        print(f"   Process RSS: {memory_info['process_rss_gb']:.2f} GB")
+        if memory_info['children_rss_gb'] > 0:
+            print(f"   Children RSS: {memory_info['children_rss_gb']:.2f} GB ({memory_info['child_count']} processes)")
+        print(f"   Total Process: {memory_info['total_process_gb']:.2f} GB")
+        print(f"   System: {memory_info['system_used_gb']:.1f}/{memory_info['system_total_gb']:.1f} GB ({memory_info['system_percent']:.1f}%)")
+        print(f"   Available: {memory_info['system_available_gb']:.1f} GB")
+
+    def _aggressive_memory_cleanup(self, step_name: str = ""):
+        """Perform aggressive memory cleanup and report before/after"""
+        if step_name:
+            print(f"\n🧹 CLEANUP: Before {step_name}")
+
+        before_info = self._get_process_memory_info()
+        before_rss = before_info['total_process_gb']
+
+        # Multiple rounds of garbage collection
+        for i in range(3):
+            gc.collect()
+
+        # Clear torch cache if available
+        try:
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+        except ImportError:
+            pass
+
+        # Clear OpenCV internal caches
+        try:
+            # Clear OpenCV video capture cache
+            cv2.setUseOptimized(False)
+            cv2.setUseOptimized(True)
+        except Exception:
+            pass
+
+        # Clear CuPy caches if available
+        try:
+            import cupy as cp
+            cp._default_memory_pool.free_all_blocks()
+            cp._default_pinned_memory_pool.free_all_blocks()
+            cp.get_default_memory_pool().free_all_blocks()
+            cp.get_default_pinned_memory_pool().free_all_blocks()
+        except ImportError:
+            pass
+        except Exception as e:
+            print(f"   Warning: Could not clear CuPy cache: {e}")
+
+        # Force Linux to release memory back to OS
+        if sys.platform == 'linux':
+            try:
+                import ctypes
+                libc = ctypes.CDLL("libc.so.6")
+                libc.malloc_trim(0)
+            except Exception as e:
+                print(f"   Warning: Could not trim memory: {e}")
+
+        # Brief pause to allow cleanup
+        time.sleep(0.1)
+
+        after_info = self._get_process_memory_info()
+        after_rss = after_info['total_process_gb']
+        freed_memory = before_rss - after_rss
+
+        if step_name:
+            print(f"   Before: {before_rss:.2f} GB → After: {after_rss:.2f} GB")
+            print(f"   Freed: {freed_memory:.2f} GB")
+
     def _initialize_models(self):
         """Initialize YOLO detector and SAM2 model"""
         print("Initializing models...")
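The RSS-plus-children accounting these helpers rely on can be exercised on its own. A minimal snapshot using only psutil (values vary per machine; child processes may be absent, in which case the second term is zero):

import os
import psutil

proc = psutil.Process(os.getpid())
rss_gb = proc.memory_info().rss / 1024**3
children_gb = sum(
    c.memory_info().rss for c in proc.children(recursive=True)
) / 1024**3
print(f"total process memory: {rss_gb + children_gb:.2f} GB")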
@@ -146,6 +281,116 @@ class VideoProcessor:
         print(f"Read {len(frames)} frames")
         return frames
 
+    def read_video_frames_dual_resolution(self,
+                                          video_path: str,
+                                          start_frame: int = 0,
+                                          num_frames: Optional[int] = None,
+                                          scale_factor: float = 0.5) -> Dict[str, List[np.ndarray]]:
+        """
+        Read video frames at both original and scaled resolution for dual-resolution processing
+
+        Args:
+            video_path: Path to video file
+            start_frame: Starting frame index
+            num_frames: Number of frames to read (None for all)
+            scale_factor: Scaling factor for inference frames
+
+        Returns:
+            Dictionary with 'original' and 'scaled' frame lists
+        """
+        cap = cv2.VideoCapture(video_path)
+
+        if not cap.isOpened():
+            raise RuntimeError(f"Could not open video file: {video_path}")
+
+        # Set starting position
+        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+
+        original_frames = []
+        scaled_frames = []
+        frame_count = 0
+
+        # Progress tracking
+        total_to_read = num_frames if num_frames else self.total_frames - start_frame
+
+        with tqdm(total=total_to_read, desc="Reading dual-resolution frames") as pbar:
+            while True:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                # Store original frame
+                original_frames.append(frame.copy())
+
+                # Create scaled frame for inference
+                if scale_factor != 1.0:
+                    new_width = int(frame.shape[1] * scale_factor)
+                    new_height = int(frame.shape[0] * scale_factor)
+                    scaled_frame = cv2.resize(frame, (new_width, new_height),
+                                              interpolation=cv2.INTER_AREA)
+                else:
+                    scaled_frame = frame.copy()
+
+                scaled_frames.append(scaled_frame)
+                frame_count += 1
+                pbar.update(1)
+
+                if num_frames is not None and frame_count >= num_frames:
+                    break
+
+        cap.release()
+
+        print(f"Loaded {len(original_frames)} frames:")
+        print(f"   Original: {original_frames[0].shape} per frame")
+        print(f"   Scaled: {scaled_frames[0].shape} per frame (scale_factor={scale_factor})")
+
+        return {
+            'original': original_frames,
+            'scaled': scaled_frames
+        }
+
+    def upscale_mask(self, mask: np.ndarray, target_shape: tuple, method: str = 'cubic') -> np.ndarray:
+        """
+        Upscale a mask from inference resolution to original resolution
+
+        Args:
+            mask: Low-resolution mask (H, W)
+            target_shape: Target shape (H, W) for upscaling
+            method: Upscaling method ('nearest', 'cubic', 'area')
+
+        Returns:
+            Upscaled mask at target resolution
+        """
+        if mask.shape[:2] == target_shape[:2]:
+            return mask  # Already correct size
+
+        # Ensure mask is 2D
+        if mask.ndim == 3:
+            mask = mask.squeeze()
+
+        # Choose interpolation method
+        if method == 'nearest':
+            interpolation = cv2.INTER_NEAREST  # Crisp edges, good for sharp subjects
+        elif method == 'cubic':
+            interpolation = cv2.INTER_CUBIC  # Smooth edges, good for most content
+        elif method == 'area':
+            interpolation = cv2.INTER_AREA  # Good for downscaling, not upscaling
+        else:
+            interpolation = cv2.INTER_CUBIC  # Default to cubic
+
+        # Upscale mask
+        upscaled_mask = cv2.resize(
+            mask.astype(np.uint8),
+            (target_shape[1], target_shape[0]),  # (width, height) for cv2.resize
+            interpolation=interpolation
+        )
+
+        # Convert back to boolean if it was originally boolean
+        if mask.dtype == bool:
+            upscaled_mask = upscaled_mask.astype(bool)
+
+        return upscaled_mask
+
     def calculate_optimal_chunking(self) -> Tuple[int, int]:
         """
         Calculate optimal chunk size and overlap based on memory constraints
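A quick sanity check of the mask-upscaling path with a synthetic mask: cubic interpolation at 4x per axis should grow the masked area by roughly 16x (numpy and cv2 only; resolutions are illustrative):

import cv2
import numpy as np

mask = np.zeros((540, 960), dtype=bool)
mask[100:300, 200:600] = True  # fake person region at inference resolution

# cv2.resize takes (width, height); 960x540 -> 3840x2160 is 4x per axis
up = cv2.resize(mask.astype(np.uint8), (3840, 2160),
                interpolation=cv2.INTER_CUBIC).astype(bool)
print(up.shape, up.sum() / mask.sum())  # area scales ~ 4^2 = 16x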
@@ -234,6 +479,92 @@ class VideoProcessor:
 
         return matted_frames
 
+    def process_chunk_dual_resolution(self,
+                                      frame_data: Dict[str, List[np.ndarray]],
+                                      chunk_idx: int = 0) -> List[np.ndarray]:
+        """
+        Process a chunk using dual-resolution approach: inference at low-res, output at full-res
+
+        Args:
+            frame_data: Dictionary with 'original' and 'scaled' frame lists
+            chunk_idx: Chunk index for logging
+
+        Returns:
+            List of matted frames at original resolution
+        """
+        original_frames = frame_data['original']
+        scaled_frames = frame_data['scaled']
+
+        print(f"Processing chunk {chunk_idx} with dual-resolution ({len(original_frames)} frames)")
+        print(f"   Inference: {scaled_frames[0].shape} → Output: {original_frames[0].shape}")
+
+        with self.memory_manager.memory_monitor(f"dual-res chunk {chunk_idx}"):
+            # Initialize SAM2 with scaled frames for inference
+            self.sam2_model.init_video_state(scaled_frames)
+
+            # Detect persons in first scaled frame
+            first_scaled_frame = scaled_frames[0]
+            detections = self.detector.detect_persons(first_scaled_frame)
+
+            if not detections:
+                warnings.warn(f"No persons detected in chunk {chunk_idx}")
+                return self._create_empty_masks(original_frames)
+
+            print(f"Detected {len(detections)} persons in first frame (at inference resolution)")
+
+            # Convert detections to SAM2 prompts (detections are already at scaled resolution)
+            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)
+
+            # Add prompts to SAM2
+            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
+            print(f"Added prompts for {len(object_ids)} objects")
+
+            # Propagate masks through chunk at inference resolution
+            video_segments = self.sam2_model.propagate_masks(
+                start_frame=0,
+                max_frames=len(scaled_frames)
+            )
+
+            # Apply upscaled masks to original resolution frames
+            matted_frames = []
+            original_shape = original_frames[0].shape[:2]  # (H, W)
+
+            for frame_idx, original_frame in enumerate(tqdm(original_frames, desc="Applying upscaled masks")):
+                if frame_idx in video_segments:
+                    frame_masks = video_segments[frame_idx]
+
+                    # Get combined mask at inference resolution
+                    combined_mask_scaled = self.sam2_model.get_combined_mask(frame_masks)
+
+                    if combined_mask_scaled is not None:
+                        # Upscale mask to original resolution
+                        combined_mask_full = self.upscale_mask(
+                            combined_mask_scaled,
+                            target_shape=original_shape,
+                            method='cubic'  # Smooth upscaling for masks
+                        )
+
+                        # Apply upscaled mask to original resolution frame
+                        matted_frame = self.sam2_model.apply_mask_to_frame(
+                            original_frame, combined_mask_full,
+                            output_format=self.config.output.format,
+                            background_color=self.config.output.background_color
+                        )
+                    else:
+                        # No mask for this frame
+                        matted_frame = self._create_empty_mask_frame(original_frame)
+                else:
+                    # No mask for this frame
+                    matted_frame = self._create_empty_mask_frame(original_frame)
+
+                matted_frames.append(matted_frame)
+
+            # Cleanup SAM2 state
+            self.sam2_model.cleanup()
+
+        print(f"✅ Dual-resolution processing complete: {len(matted_frames)} frames at full resolution")
+        return matted_frames
+
     def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
         """Create empty masks when no persons detected"""
         empty_frames = []
@@ -252,19 +583,213 @@ class VideoProcessor:
         # Green screen background
         return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
 
+    def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
+                               overlap_frames: int = 0, audio_source: str = None) -> None:
+        """
+        Merge processed chunks using streaming approach (no memory accumulation)
+
+        Args:
+            chunk_files: List of chunk result files (.npz)
+            output_path: Final output video path
+            overlap_frames: Number of overlapping frames
+            audio_source: Audio source file for final video
+        """
+        if not chunk_files:
+            raise ValueError("No chunk files to merge")
+
+        print(f"🎬 TRUE Streaming merge: {len(chunk_files)} chunks → {output_path}")
+
+        # Create temporary directory for frame images
+        import tempfile
+        temp_frames_dir = Path(tempfile.mkdtemp(prefix="merge_frames_"))
+        frame_counter = 0
+
+        try:
+            print(f"📁 Using temp frames dir: {temp_frames_dir}")
+
+            # Process each chunk frame-by-frame (true streaming)
+            for i, chunk_file in enumerate(chunk_files):
+                print(f"📼 Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
+
+                # Load chunk metadata without loading frames array
+                chunk_data = np.load(str(chunk_file))
+                frames_array = chunk_data['frames']  # This is still mmap'd, not loaded
+                total_frames_in_chunk = frames_array.shape[0]
+
+                # Determine which frames to skip for overlap
+                start_frame_idx = overlap_frames if i > 0 and overlap_frames > 0 else 0
+                frames_to_process = total_frames_in_chunk - start_frame_idx
+
+                if start_frame_idx > 0:
+                    print(f"   ✂️ Skipping first {start_frame_idx} overlapping frames")
+
+                print(f"   🔄 Processing {frames_to_process} frames one-by-one...")
+
+                # Process frames ONE AT A TIME (true streaming)
+                for frame_idx in range(start_frame_idx, total_frames_in_chunk):
+                    # Load only ONE frame at a time
+                    frame = frames_array[frame_idx]  # Load single frame
+
+                    # Save frame directly to disk
+                    frame_path = temp_frames_dir / f"frame_{frame_counter:06d}.jpg"
+                    success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
+                    if not success:
+                        raise RuntimeError(f"Failed to save frame {frame_counter}")
+
+                    frame_counter += 1
+
+                    # Periodic progress and cleanup
+                    if frame_counter % 100 == 0:
+                        print(f"   💾 Saved {frame_counter} frames...")
+                        gc.collect()  # Periodic cleanup
+
+                print(f"   ✅ Saved {frames_to_process} frames to disk (total: {frame_counter})")
+
+                # Close chunk file and cleanup
+                chunk_data.close()
+                del chunk_data, frames_array
+
+                # Don't delete checkpoint files - they're needed for potential resume
+                # The checkpoint system manages cleanup separately
+                print(f"   📋 Keeping checkpoint file: {chunk_file.name}")
+
+                # Aggressive cleanup and memory monitoring after each chunk
+                self._aggressive_memory_cleanup(f"After streaming merge chunk {i}")
+
+                # Memory safety check
+                memory_info = self._get_process_memory_info()
+                if memory_info['total_process_gb'] > 35:  # Warning if approaching 46GB limit
+                    print(f"⚠️ High memory usage: {memory_info['total_process_gb']:.1f}GB - forcing cleanup")
+                    gc.collect()
+                    import torch
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+
+            # Create final video directly from frame images using ffmpeg
+            print(f"📹 Creating final video from {frame_counter} frames...")
+            self._create_video_from_frames(temp_frames_dir, Path(output_path), frame_counter)
+
+            # Add audio if provided
+            if audio_source:
+                self._add_audio_to_video(output_path, audio_source)
+
+        except Exception as e:
+            print(f"❌ Streaming merge failed: {e}")
+            raise
+
+        finally:
+            # Cleanup temporary frames directory
+            try:
+                if temp_frames_dir.exists():
+                    import shutil
+                    shutil.rmtree(temp_frames_dir)
+                    print(f"🗑️ Cleaned up temp frames dir: {temp_frames_dir}")
+            except Exception as e:
+                print(f"⚠️ Could not cleanup temp frames dir: {e}")
+
+            # Memory cleanup
+            gc.collect()
+
+        print(f"✅ TRUE Streaming merge complete: {output_path}")
+
+    def _create_video_from_frames(self, frames_dir: Path, output_path: Path, frame_count: int):
+        """Create video directly from frame images using ffmpeg (memory efficient)"""
+        import subprocess
+
+        frame_pattern = str(frames_dir / "frame_%06d.jpg")
+        fps = self.video_info['fps'] if hasattr(self, 'video_info') and self.video_info else 30.0
+
+        print(f"🎬 Creating video with ffmpeg: {frame_count} frames at {fps} fps")
+
+        # Use GPU encoding if available, fallback to CPU
+        gpu_cmd = [
+            'ffmpeg', '-y',  # -y to overwrite output file
+            '-framerate', str(fps),
+            '-i', frame_pattern,
+            '-c:v', 'h264_nvenc',  # NVIDIA GPU encoder
+            '-preset', 'fast',
+            '-cq', '18',  # Quality for GPU encoding
+            '-pix_fmt', 'yuv420p',
+            str(output_path)
+        ]
+
+        cpu_cmd = [
+            'ffmpeg', '-y',  # -y to overwrite output file
+            '-framerate', str(fps),
+            '-i', frame_pattern,
+            '-c:v', 'libx264',  # CPU encoder
+            '-preset', 'medium',
+            '-crf', '18',  # Quality for CPU encoding
+            '-pix_fmt', 'yuv420p',
+            str(output_path)
+        ]
+
+        # Try GPU first
+        print(f"🚀 Trying GPU encoding...")
+        result = subprocess.run(gpu_cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            print("⚠️ GPU encoding failed, using CPU...")
+            print(f"🔄 CPU encoding...")
+            result = subprocess.run(cpu_cmd, capture_output=True, text=True)
+        else:
+            print("✅ GPU encoding successful!")
+
+        if result.returncode != 0:
+            print(f"❌ FFmpeg stdout: {result.stdout}")
+            print(f"❌ FFmpeg stderr: {result.stderr}")
+            raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
+
+        print(f"✅ Video created successfully: {output_path}")
+
+    def _add_audio_to_video(self, video_path: str, audio_source: str):
+        """Add audio to video using ffmpeg"""
+        import subprocess
+        import tempfile
+
+        try:
+            # Create temporary file for output with audio
+            temp_path = Path(video_path).with_suffix('.temp.mp4')
+
+            cmd = [
+                'ffmpeg', '-y',
+                '-i', str(video_path),  # Input video (no audio)
+                '-i', str(audio_source),  # Input audio source
+                '-c:v', 'copy',  # Copy video without re-encoding
+                '-c:a', 'aac',  # Encode audio as AAC
+                '-map', '0:v:0',  # Map video from first input
+                '-map', '1:a:0',  # Map audio from second input
+                '-shortest',  # Match shortest stream duration
+                str(temp_path)
+            ]
+
+            print(f"🎵 Adding audio: {audio_source} → {video_path}")
+            result = subprocess.run(cmd, capture_output=True, text=True)
+
+            if result.returncode != 0:
+                print(f"⚠️ Audio addition failed: {result.stderr}")
+                # Keep original video without audio
+                return
+
+            # Replace original with audio version
+            Path(video_path).unlink()
+            temp_path.rename(video_path)
+            print(f"✅ Audio added successfully")
+
+        except Exception as e:
+            print(f"⚠️ Could not add audio: {e}")
+
     def merge_overlapping_chunks(self,
                                  chunk_results: List[List[np.ndarray]],
                                  overlap_frames: int) -> List[np.ndarray]:
         """
-        Merge overlapping chunks with blending in overlap regions
-
-        Args:
-            chunk_results: List of chunk results
-            overlap_frames: Number of overlapping frames
-
-        Returns:
-            Merged frame sequence
+        Legacy merge method - DEPRECATED due to memory accumulation
+        Use merge_chunks_streaming() instead for memory efficiency
         """
+        import warnings
+        warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
+                      DeprecationWarning, stacklevel=2)
+
         if len(chunk_results) == 1:
             return chunk_results[0]
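The encoder fallback in _create_video_from_frames amounts to: try h264_nvenc, then retry with libx264 and -crf in place of -cq. A compact standalone version of that probe (requires an ffmpeg binary on PATH; the commented invocation uses placeholder paths):

import subprocess

def encode(frames_pattern: str, out_path: str, fps: float = 30.0) -> None:
    """Try NVENC first, fall back to libx264 with the equivalent quality flag."""
    base = ['ffmpeg', '-y', '-framerate', str(fps), '-i', frames_pattern]
    for codec, qflag in (('h264_nvenc', '-cq'), ('libx264', '-crf')):
        cmd = base + ['-c:v', codec, qflag, '18', '-pix_fmt', 'yuv420p', out_path]
        if subprocess.run(cmd, capture_output=True).returncode == 0:
            print(f"encoded with {codec}")
            return
    raise RuntimeError("both encoders failed")

# encode("frames/frame_%06d.jpg", "out.mp4")  # placeholder paths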
@@ -348,70 +873,307 @@ class VideoProcessor:
|
|||||||
print(f"Saved {len(frames)} PNG frames to {output_dir}")
|
print(f"Saved {len(frames)} PNG frames to {output_dir}")
|
||||||
|
|
||||||
def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
|
def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
|
||||||
"""Save frames as MP4 video"""
|
"""Save frames as MP4 video with audio preservation"""
|
||||||
if not frames:
|
if not frames:
|
||||||
return
|
return
|
||||||
|
|
||||||
height, width = frames[0].shape[:2]
|
output_path = Path(output_path)
|
||||||
|
temp_frames_dir = output_path.parent / f"temp_frames_{output_path.stem}"
|
||||||
|
temp_frames_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
try:
|
||||||
writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))
|
# Save frames as images
|
||||||
|
print("Saving frames as images...")
|
||||||
for frame in tqdm(frames, desc="Writing video"):
|
for i, frame in enumerate(tqdm(frames, desc="Saving frames")):
|
||||||
if frame.shape[2] == 4: # Convert RGBA to BGR
|
if frame.shape[2] == 4: # Convert RGBA to BGR
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
|
||||||
writer.write(frame)
|
|
||||||
|
|
||||||
writer.release()
|
frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
|
||||||
|
cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||||
|
|
||||||
|
# Create video with ffmpeg
|
||||||
|
self._create_video_with_ffmpeg(temp_frames_dir, output_path, len(frames))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup temporary frames
|
||||||
|
if temp_frames_dir.exists():
|
||||||
|
shutil.rmtree(temp_frames_dir)
|
||||||
|
|
||||||
|
def _create_video_with_ffmpeg(self, frames_dir: Path, output_path: Path, frame_count: int):
|
||||||
|
"""Create video using ffmpeg with audio preservation"""
|
||||||
|
frame_pattern = str(frames_dir / "frame_%06d.jpg")
|
||||||
|
|
||||||
|
if self.config.output.preserve_audio:
|
||||||
|
# Create video with audio from input
|
||||||
|
cmd = [
|
||||||
|
'ffmpeg', '-y',
|
||||||
|
'-framerate', str(self.fps),
|
||||||
|
'-i', frame_pattern,
|
||||||
|
'-i', str(self.config.input.video_path), # Input video for audio
|
||||||
|
'-c:v', 'h264_nvenc', # Try GPU encoding first
|
||||||
|
'-preset', 'fast',
|
||||||
|
'-cq', '18',
|
||||||
|
'-c:a', 'copy', # Copy audio without re-encoding
|
||||||
|
'-map', '0:v:0', # Map video from frames
|
||||||
|
'-map', '1:a:0', # Map audio from input video
|
||||||
|
'-shortest', # Match shortest stream duration
|
||||||
|
'-pix_fmt', 'yuv420p',
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Create video without audio
|
||||||
|
cmd = [
|
||||||
|
'ffmpeg', '-y',
|
||||||
|
'-framerate', str(self.fps),
|
||||||
|
'-i', frame_pattern,
|
||||||
|
'-c:v', 'h264_nvenc',
|
||||||
|
'-preset', 'fast',
|
||||||
|
'-cq', '18',
|
||||||
|
'-pix_fmt', 'yuv420p',
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Creating video with ffmpeg...")
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
# Try CPU encoding as fallback
|
||||||
|
print("GPU encoding failed, trying CPU encoding...")
|
||||||
|
cmd[cmd.index('h264_nvenc')] = 'libx264'
|
||||||
|
cmd[cmd.index('-cq')] = '-crf' # Change quality parameter for CPU
|
||||||
|
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f"FFmpeg stdout: {result.stdout}")
|
||||||
|
print(f"FFmpeg stderr: {result.stderr}")
|
||||||
|
raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
|
||||||
|
|
||||||
|
# Verify frame count if sync verification is enabled
|
||||||
|
if self.config.output.verify_sync:
|
||||||
|
self._verify_frame_count(output_path, frame_count)
|
||||||
|
|
||||||
print(f"Saved video to {output_path}")
|
print(f"Saved video to {output_path}")
|
||||||
|
|
||||||
|
def _verify_frame_count(self, video_path: Path, expected_frames: int):
|
||||||
|
"""Verify output video has correct frame count"""
|
||||||
|
try:
|
||||||
|
probe = ffmpeg.probe(str(video_path))
|
||||||
|
video_stream = next(
|
||||||
|
(stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
if video_stream:
|
||||||
|
actual_frames = int(video_stream.get('nb_frames', 0))
|
||||||
|
if actual_frames != expected_frames:
|
||||||
|
print(f"⚠️ Frame count mismatch: expected {expected_frames}, got {actual_frames}")
|
||||||
|
else:
|
||||||
|
print(f"✅ Frame count verified: {actual_frames} frames")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Could not verify frame count: {e}")
|
||||||
|
|
||||||
def process_video(self) -> None:
|
def process_video(self) -> None:
|
||||||
"""Main video processing pipeline"""
|
"""Main video processing pipeline with checkpoint/resume support"""
|
||||||
|
self.processing_stats['start_time'] = time.time()
|
||||||
print("Starting VR180 video processing...")
|
print("Starting VR180 video processing...")
|
||||||
|
|
||||||
# Load video info
|
# Load video info
|
||||||
self.load_video_info(self.config.input.video_path)
|
self.load_video_info(self.config.input.video_path)
|
||||||
|
|
||||||
|
# Initialize checkpoint manager
|
||||||
|
from .checkpoint_manager import CheckpointManager
|
||||||
|
checkpoint_mgr = CheckpointManager(
|
||||||
|
self.config.input.video_path,
|
||||||
|
self.config.output.path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for existing checkpoints
|
||||||
|
resume_info = checkpoint_mgr.get_resume_info()
|
||||||
|
if resume_info['can_resume']:
|
||||||
|
print(f"\n🔄 RESUME DETECTED:")
|
||||||
|
print(f" Found {resume_info['completed_chunks']} completed chunks")
|
||||||
|
print(f" Continue from where we left off? (saves time!)")
|
||||||
|
checkpoint_mgr.print_status()
|
||||||
|
|
||||||
# Calculate chunking parameters
|
# Calculate chunking parameters
|
||||||
chunk_size, overlap_frames = self.calculate_optimal_chunking()
|
chunk_size, overlap_frames = self.calculate_optimal_chunking()
|
||||||
|
|
||||||
# Process video in chunks
|
# Calculate total chunks
|
||||||
chunk_results = []
|
total_chunks = 0
|
||||||
|
for _ in range(0, self.total_frames, chunk_size - overlap_frames):
|
||||||
|
total_chunks += 1
|
||||||
|
checkpoint_mgr.set_total_chunks(total_chunks)
|
||||||
|
|
||||||
|
# Process video in chunks
|
||||||
|
chunk_files = [] # Store file paths instead of frame data
|
||||||
|
temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk_idx = 0
|
||||||
for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
|
for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
|
||||||
end_frame = min(start_frame + chunk_size, self.total_frames)
|
end_frame = min(start_frame + chunk_size, self.total_frames)
|
||||||
frames_to_read = end_frame - start_frame
|
frames_to_read = end_frame - start_frame
|
||||||
|
|
||||||
chunk_idx = len(chunk_results)
|
# Check if this chunk was already processed
|
||||||
|
existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
|
||||||
|
if existing_chunk:
|
||||||
|
print(f"\n✅ Chunk {chunk_idx} already processed: {existing_chunk.name}")
|
||||||
|
chunk_files.append(existing_chunk)
|
||||||
|
chunk_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
|
print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
|
||||||
|
|
||||||
# Read chunk frames
|
# Choose processing approach based on scale factor
|
||||||
|
if self.config.processing.scale_factor == 1.0:
|
||||||
|
# No scaling needed - use original single-resolution approach
|
||||||
|
print(f"🔄 Reading frames at original resolution (no scaling)")
|
||||||
frames = self.read_video_frames(
|
frames = self.read_video_frames(
|
||||||
|
self.config.input.video_path,
|
||||||
|
start_frame=start_frame,
|
||||||
|
num_frames=frames_to_read,
|
||||||
|
scale_factor=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process chunk normally (single resolution)
|
||||||
|
matted_frames = self.process_chunk(frames, chunk_idx)
|
||||||
|
else:
|
||||||
|
# Scaling required - use dual-resolution approach
|
||||||
|
print(f"🔄 Reading frames at dual resolution (scale_factor={self.config.processing.scale_factor})")
|
||||||
|
frame_data = self.read_video_frames_dual_resolution(
|
||||||
self.config.input.video_path,
|
self.config.input.video_path,
|
||||||
start_frame=start_frame,
|
start_frame=start_frame,
|
||||||
num_frames=frames_to_read,
|
num_frames=frames_to_read,
|
||||||
scale_factor=self.config.processing.scale_factor
|
scale_factor=self.config.processing.scale_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process chunk
|
# Process chunk with dual-resolution approach
|
||||||
matted_frames = self.process_chunk(frames, chunk_idx)
|
matted_frames = self.process_chunk_dual_resolution(frame_data, chunk_idx)
|
||||||
chunk_results.append(matted_frames)
|
|
||||||
|
|
||||||
# Memory cleanup
|
# Save chunk to disk immediately to free memory
|
||||||
|
chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
|
||||||
|
print(f"Saving chunk {chunk_idx} to disk...")
|
||||||
|
np.savez_compressed(str(chunk_path), frames=matted_frames)
|
||||||
|
|
||||||
|
# Save to checkpoint
|
||||||
|
checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)
|
||||||
|
|
||||||
|
chunk_files.append(chunk_path)
|
||||||
|
chunk_idx += 1
|
||||||
|
|
||||||
|
# Free the frames from memory immediately
|
||||||
|
del matted_frames
|
||||||
|
if self.config.processing.scale_factor == 1.0:
|
||||||
|
del frames
|
||||||
|
else:
|
||||||
|
del frame_data
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
self.processing_stats['chunks_processed'] += 1
|
||||||
|
self.processing_stats['frames_processed'] += frames_to_read
|
||||||
|
|
||||||
|
# Aggressive memory cleanup after each chunk
|
||||||
|
self._aggressive_memory_cleanup(f"chunk {chunk_idx} completion")
|
||||||
|
|
||||||
|
# Also use memory manager cleanup
|
||||||
self.memory_manager.cleanup_memory()
|
self.memory_manager.cleanup_memory()
|
||||||
|
|
||||||
if self.memory_manager.should_emergency_cleanup():
|
if self.memory_manager.should_emergency_cleanup():
|
||||||
self.memory_manager.emergency_cleanup()
|
self.memory_manager.emergency_cleanup()
|
||||||
|
|
-            # Merge chunks if multiple
-            print("\nMerging chunks...")
-            final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
-
-            # Save results
-            print(f"Saving {len(final_frames)} processed frames...")
-            self.save_video(final_frames, self.config.output.path)
+            # Mark chunk processing as complete
+            checkpoint_mgr.mark_processing_complete()
+
+            # Check if merge was already done
+            if resume_info.get('merge_complete', False):
+                print("\n✅ Merge already completed in previous run!")
+                print(f"   Output: {self.config.output.path}")
+            else:
+                # Use streaming merge to avoid memory accumulation (fixes OOM)
+                print("\n🎬 Using streaming merge (no memory accumulation)...")
+
+                # For resume scenarios, make sure we have all chunk files
+                if resume_info['can_resume']:
+                    checkpoint_chunk_files = checkpoint_mgr.get_completed_chunk_files()
+                    if len(checkpoint_chunk_files) != len(chunk_files):
+                        print(f"⚠️ Using {len(checkpoint_chunk_files)} checkpoint files instead of {len(chunk_files)} temp files")
+                        chunk_files = checkpoint_chunk_files
+
+                # Determine audio source for final video
+                audio_source = None
+                if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
+                    audio_source = self.config.input.video_path
+
+                # Stream merge chunks directly to output (no memory accumulation)
+                self.merge_chunks_streaming(
+                    chunk_files=chunk_files,
+                    output_path=self.config.output.path,
+                    overlap_frames=overlap_frames,
+                    audio_source=audio_source
+                )
+
+                # Mark merge as complete
+                checkpoint_mgr.mark_merge_complete()
+
+                print("✅ Streaming merge complete - no memory accumulation!")
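merge_chunks_streaming is defined elsewhere in the branch; the essential pattern is to load one chunk at a time, crossfade its head with the tail of the previous chunk, and hand frames straight to the encoder so nothing accumulates. A simplified sketch under those assumptions (uniform crossfade weights, hypothetical write_frame callback):

import numpy as np

def merge_chunks_streaming(chunk_files, write_frame, overlap_frames):
    """Stream chunks to the output, crossfading the overlapping frames."""
    tail = []  # last `overlap_frames` frames of the previous chunk
    for path in chunk_files:
        with np.load(path) as data:
            frames = data["frames"]
        start = 0
        if tail:
            # Blend the overlap region instead of emitting it twice
            for j, prev in enumerate(tail):
                alpha = (j + 1) / (len(tail) + 1)
                blended = ((1 - alpha) * prev + alpha * frames[j]).astype(prev.dtype)
                write_frame(blended)
            start = len(tail)
        stop = -overlap_frames if overlap_frames else None
        for frame in frames[start:stop]:
            write_frame(frame)
        tail = list(frames[-overlap_frames:]) if overlap_frames else []
    for frame in tail:  # flush the final chunk's tail
        write_frame(frame)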
+            # Calculate final statistics
+            self.processing_stats['end_time'] = time.time()
+            self.processing_stats['total_duration'] = self.processing_stats['end_time'] - self.processing_stats['start_time']
+            if self.processing_stats['total_duration'] > 0:
+                self.processing_stats['processing_fps'] = self.processing_stats['frames_processed'] / self.processing_stats['total_duration']
+
+            # Print processing statistics
+            self._print_processing_statistics()
+
             # Print final memory report
             self.memory_manager.print_memory_report()

             print("Video processing completed!")

+            # Option to clean up checkpoints
+            print("\n🗄️ CHECKPOINT CLEANUP OPTIONS:")
+            print("   Checkpoints saved successfully and can be cleaned up")
+            print("   - Keep checkpoints for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
+            print("   - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
+            print(f"   - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")
+
+            # For now, keep checkpoints by default (user can manually clean)
+            print("\n💡 Checkpoints kept for safety. Delete manually when no longer needed.")
+
+        finally:
+            # Clean up temporary chunk files (but not checkpoints)
+            if temp_chunk_dir.exists():
+                print("Cleaning up temporary chunk files...")
+                try:
+                    shutil.rmtree(temp_chunk_dir)
+                except Exception as e:
+                    print(f"⚠️ Could not clean temp directory: {e}")
+
+    def _print_processing_statistics(self):
+        """Print detailed processing statistics"""
+        stats = self.processing_stats
+        video_duration = self.total_frames / self.fps if self.fps > 0 else 0
+
+        print("\n" + "="*60)
+        print("PROCESSING STATISTICS")
+        print("="*60)
+        print(f"Input video duration: {video_duration:.1f} seconds ({self.total_frames} frames @ {self.fps:.2f} fps)")
+        print(f"Total processing time: {stats['total_duration']:.1f} seconds")
+        print(f"Processing speed: {stats['processing_fps']:.2f} fps")
+        print(f"Slowdown factor: {self.fps / stats['processing_fps']:.1f}x slower than realtime")
+        print(f"Chunks processed: {stats['chunks_processed']}")
+        print(f"Frames processed: {stats['frames_processed']}")
+
+        if video_duration > 0:
+            efficiency = video_duration / stats['total_duration']
+            print(f"Processing efficiency: {efficiency:.3f} (1.0 = realtime)")
+
+            # Estimate time for different video lengths
+            print("\nEstimated processing times:")
+            print(f"  5 minutes: {(5 * 60) / efficiency / 60:.1f} minutes")
+            print(f"  30 minutes: {(30 * 60) / efficiency / 60:.1f} minutes")
+            print(f"  1 hour: {(60 * 60) / efficiency / 60:.1f} minutes")
+
+        print("="*60 + "\n")
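For intuition on the efficiency figure: efficiency = video_duration / total_duration, so a run that spends 40 minutes on a 10-minute clip reports 0.250, and the estimates above scale accordingly — a 1-hour video would take 60 / 0.25 = 240 minutes.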
@@ -3,6 +3,7 @@ import numpy as np
 from typing import List, Dict, Any, Optional, Tuple
 from pathlib import Path
 import warnings
+import torch

 from .video_processor import VideoProcessor
 from .config import VR180Config
@@ -65,17 +66,31 @@ class VR180Processor(VideoProcessor):
         Returns:
             Tuple of (left_eye_frame, right_eye_frame)
         """
-        if self.sbs_split_point == 0:
-            self.sbs_split_point = frame.shape[1] // 2
+        # Always calculate split point based on current frame width
+        # This handles scaled frames correctly
+        frame_width = frame.shape[1]
+        current_split_point = frame_width // 2

-        left_eye = frame[:, :self.sbs_split_point]
-        right_eye = frame[:, self.sbs_split_point:]
+        # Debug info on first use
+        if self.sbs_split_point == 0:
+            print(f"Frame dimensions: {frame.shape[1]}x{frame.shape[0]}")
+            print(f"Split point: {current_split_point}")
+            self.sbs_split_point = current_split_point  # Store for reference
+
+        left_eye = frame[:, :current_split_point]
+        right_eye = frame[:, current_split_point:]
+
+        # Validate both eyes have content
+        if left_eye.size == 0:
+            raise RuntimeError(f"Left eye frame is empty after split (frame width: {frame_width})")
+        if right_eye.size == 0:
+            raise RuntimeError(f"Right eye frame is empty after split (frame width: {frame_width})")

         return left_eye, right_eye
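The cached split point this hunk removes is exactly what breaks once frames can arrive downscaled: a split computed on a full-resolution frame slices a scaled frame into an everything/nothing pair. A quick illustration (frame sizes are arbitrary):

import numpy as np

full = np.zeros((2880, 5760, 3), dtype=np.uint8)    # full-res SBS frame
scaled = np.zeros((1440, 2880, 3), dtype=np.uint8)  # same frame at scale_factor=0.5

cached_split = full.shape[1] // 2        # 2880 — correct for `full` only
left = scaled[:, :cached_split]          # takes the *entire* scaled frame
right = scaled[:, cached_split:]         # empty: shape (1440, 0, 3)
assert right.size == 0                   # exactly the failure the new check raises on

per_frame_split = scaled.shape[1] // 2   # 1440 — recompute per frame instead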

     def combine_sbs_frame(self, left_eye: np.ndarray, right_eye: np.ndarray) -> np.ndarray:
         """
-        Combine left and right eye frames back into side-by-side format
+        Combine left and right eye frames back into side-by-side format with GPU acceleration

         Args:
             left_eye: Left eye frame
@@ -84,14 +99,44 @@ class VR180Processor(VideoProcessor):
         Returns:
             Combined SBS frame
         """
+        try:
+            import cupy as cp
+
+            # Transfer to GPU for faster combination
+            left_gpu = cp.asarray(left_eye)
+            right_gpu = cp.asarray(right_eye)
+
+            # Ensure frames have same height
+            if left_gpu.shape[0] != right_gpu.shape[0]:
+                target_height = min(left_gpu.shape[0], right_gpu.shape[0])
+                # Note: OpenCV resize not available in CuPy, fall back to CPU for resize
+                left_eye = cv2.resize(left_eye, (left_eye.shape[1], target_height))
+                right_eye = cv2.resize(right_eye, (right_eye.shape[1], target_height))
+                left_gpu = cp.asarray(left_eye)
+                right_gpu = cp.asarray(right_eye)
+
+            # Combine horizontally on GPU (much faster for large arrays)
+            combined_gpu = cp.hstack([left_gpu, right_gpu])
+
+            # Transfer back to CPU and ensure we get a copy, not a view
+            combined = cp.asnumpy(combined_gpu).copy()
+
+            # Free GPU memory immediately
+            del left_gpu, right_gpu, combined_gpu
+            cp._default_memory_pool.free_all_blocks()
+
+            return combined
+
+        except ImportError:
+            # Fallback to CPU NumPy
             # Ensure frames have same height
             if left_eye.shape[0] != right_eye.shape[0]:
                 target_height = min(left_eye.shape[0], right_eye.shape[0])
                 left_eye = cv2.resize(left_eye, (left_eye.shape[1], target_height))
                 right_eye = cv2.resize(right_eye, (right_eye.shape[1], target_height))

-        # Combine horizontally
-        combined = np.hstack([left_eye, right_eye])
+            # Combine horizontally and ensure we get a copy, not a view
+            combined = np.hstack([left_eye, right_eye]).copy()
             return combined
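Both .copy() calls above are defensive rather than strictly required: np.hstack (and cp.hstack) concatenate into freshly allocated arrays, and cp.asnumpy materializes a new host buffer, so the result cannot alias the inputs. A quick check:

import numpy as np

a = np.ones((2, 3), dtype=np.uint8)
b = np.zeros((2, 3), dtype=np.uint8)
c = np.hstack([a, b])  # concatenation always allocates a new array
print(np.shares_memory(c, a), np.shares_memory(c, b))  # False False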

     def process_with_disparity_mapping(self,
@@ -113,8 +158,23 @@ class VR180Processor(VideoProcessor):
         left_eye_frames = []
         right_eye_frames = []

-        for frame in frames:
+        for i, frame in enumerate(frames):
             left, right = self.split_sbs_frame(frame)

+            # Debug: Check if frames are valid
+            if i == 0:  # Only debug first frame
+                print(f"Original frame shape: {frame.shape}")
+                print(f"Left eye shape: {left.shape}")
+                print(f"Right eye shape: {right.shape}")
+                print(f"Left eye min/max: {left.min()}/{left.max()}")
+                print(f"Right eye min/max: {right.min()}/{right.max()}")
+
+            # Validate frames
+            if left.size == 0:
+                raise RuntimeError(f"Left eye frame {i} is empty")
+            if right.size == 0:
+                raise RuntimeError(f"Right eye frame {i} is empty")
+
             left_eye_frames.append(left)
             right_eye_frames.append(right)

@@ -123,6 +183,10 @@ class VR180Processor(VideoProcessor):
         with self.memory_manager.memory_monitor(f"left eye chunk {chunk_idx}"):
             left_matted = self._process_eye_sequence(left_eye_frames, "left", chunk_idx)

+        # Free left eye frames after processing (before right eye to save memory)
+        del left_eye_frames
+        self._aggressive_memory_cleanup(f"After left eye processing chunk {chunk_idx}")
+
         # Process right eye with cross-validation
         print("Processing right eye with cross-validation...")
         with self.memory_manager.memory_monitor(f"right eye chunk {chunk_idx}"):
@@ -130,6 +194,10 @@ class VR180Processor(VideoProcessor):
                 right_eye_frames, left_matted, "right", chunk_idx
             )

+        # Free right eye frames after processing
+        del right_eye_frames
+        self._aggressive_memory_cleanup(f"After right eye processing chunk {chunk_idx}")
+
         # Combine results back to SBS format
         combined_frames = []
         for left_frame, right_frame in zip(left_matted, right_matted):
@@ -140,6 +208,15 @@ class VR180Processor(VideoProcessor):
             combined = {'left': left_frame, 'right': right_frame}
             combined_frames.append(combined)

+        # Free the individual eye results after combining
+        del left_matted
+        del right_matted
+        self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
+
+        # CRITICAL: Complete inter-chunk cleanup to prevent model persistence
+        # This ensures models don't accumulate between chunks
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames
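memory_monitor comes from the project's MemoryManager, which this diff does not show; a minimal psutil-based sketch of such a context manager (the names here are assumptions):

import contextlib
import psutil

@contextlib.contextmanager
def memory_monitor(label: str):
    """Log RSS before and after a block, plus the delta it caused."""
    proc = psutil.Process()
    before = proc.memory_info().rss / 1e9
    print(f"[{label}] start: {before:.2f} GB RSS")
    try:
        yield
    finally:
        after = proc.memory_info().rss / 1e9
        print(f"[{label}] end: {after:.2f} GB RSS ({after - before:+.2f} GB)")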

     def _process_eye_sequence(self,
@@ -150,16 +227,148 @@ class VR180Processor(VideoProcessor):
         if not eye_frames:
             return []

-        # Initialize SAM2 with eye frames
-        self.sam2_model.init_video_state(eye_frames)
+        # Create a unique temporary video for this eye processing
+        import uuid
+        temp_video_name = f"temp_sam2_{eye_name}_chunk{chunk_idx}_{uuid.uuid4().hex[:8]}.mp4"
+        temp_video_path = Path.cwd() / temp_video_name
+
+        try:
+            # Use ffmpeg approach since OpenCV video writer is failing
+            height, width = eye_frames[0].shape[:2]
+            temp_video_path = temp_video_path.with_suffix('.mp4')
+
+            print(f"Creating temp video using ffmpeg: {temp_video_path}")
+            print(f"Video params: size=({width}, {height}), frames={len(eye_frames)}")
+
+            # Create a temporary directory for frame images
+            temp_frames_dir = temp_video_path.parent / f"frames_{temp_video_path.stem}"
+            temp_frames_dir.mkdir(exist_ok=True)
+
+            # Save frames as individual images (using JPEG for smaller file size)
+            print("Saving frames as images...")
+            for i, frame in enumerate(eye_frames):
+                # Check if frame is empty
+                if frame.size == 0:
+                    raise RuntimeError(f"Frame {i} is empty (size=0)")
+
+                # Ensure frame is uint8
+                if frame.dtype != np.uint8:
+                    frame = frame.astype(np.uint8)
+
+                # Debug first frame
+                if i == 0:
+                    print(f"First frame to save: shape={frame.shape}, dtype={frame.dtype}, empty={frame.size == 0}")
+
+                # Use JPEG instead of PNG for smaller files (faster I/O, less disk space)
+                frame_path = temp_frames_dir / f"frame_{i:06d}.jpg"
+                # Use high quality JPEG to minimize compression artifacts
+                success = cv2.imwrite(str(frame_path), frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
+                if not success:
+                    print(f"Frame {i} details: shape={frame.shape}, dtype={frame.dtype}, size={frame.size}")
+                    raise RuntimeError(f"Failed to save frame {i} as image")
+
+                if i % 50 == 0:
+                    print(f"Saved {i}/{len(eye_frames)} frames")
+
+                # Force garbage collection every 100 frames to free memory
+                if i % 100 == 0:
+                    import gc
+                    gc.collect()
+
+            # Use ffmpeg to create video from images
+            import subprocess
+            # Use the original video's framerate - access through parent class
+            original_fps = self.fps if hasattr(self, 'fps') else 30.0
+            print(f"Using framerate: {original_fps} fps")
+
+            # Memory monitoring before ffmpeg
+            self._print_memory_step(f"Before ffmpeg encoding ({eye_name} eye)")
+            # Try GPU encoding first, fallback to CPU
+            gpu_cmd = [
+                'ffmpeg', '-y',  # -y to overwrite output file
+                '-framerate', str(original_fps),
+                '-i', str(temp_frames_dir / 'frame_%06d.jpg'),
+                '-c:v', 'h264_nvenc',  # NVIDIA GPU encoder
+                '-preset', 'fast',  # GPU preset
+                '-cq', '18',  # Quality for GPU encoding
+                '-pix_fmt', 'yuv420p',
+                str(temp_video_path)
+            ]
+
+            cpu_cmd = [
+                'ffmpeg', '-y',  # -y to overwrite output file
+                '-framerate', str(original_fps),
+                '-i', str(temp_frames_dir / 'frame_%06d.jpg'),
+                '-c:v', 'libx264',  # CPU encoder
+                '-pix_fmt', 'yuv420p',
+                '-crf', '18',  # Quality for CPU encoding
+                '-preset', 'medium',
+                str(temp_video_path)
+            ]
+            # Try GPU first
+            print(f"Trying GPU encoding: {' '.join(gpu_cmd)}")
+            result = subprocess.run(gpu_cmd, capture_output=True, text=True)
+
+            if result.returncode != 0:
+                print("GPU encoding failed, trying CPU...")
+                print(f"GPU error: {result.stderr}")
+                ffmpeg_cmd = cpu_cmd
+                print(f"Using CPU encoding: {' '.join(ffmpeg_cmd)}")
+                result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
+            else:
+                print("GPU encoding successful!")
+
+            if result.returncode != 0:
+                print(f"FFmpeg stdout: {result.stdout}")
+                print(f"FFmpeg stderr: {result.stderr}")
+                raise RuntimeError(f"FFmpeg failed with return code {result.returncode}")
+            # Clean up frame images
+            import shutil
+            shutil.rmtree(temp_frames_dir)
+
+            print("Created temp video successfully")
+
+            # Memory monitoring after ffmpeg
+            self._print_memory_step(f"After ffmpeg encoding ({eye_name} eye)")
+
+            # Verify the file was created and has content
+            if not temp_video_path.exists():
+                raise RuntimeError(f"Temporary video file was not created: {temp_video_path}")
+
+            file_size = temp_video_path.stat().st_size
+            if file_size == 0:
+                raise RuntimeError(f"Temporary video file is empty: {temp_video_path}")
+
+            print(f"Created temp video {temp_video_path} ({file_size / 1024 / 1024:.1f} MB)")
+
+            # Memory monitoring and cleanup before SAM2 initialization
+            num_frames = len(eye_frames)  # Store count before freeing
+            first_frame = eye_frames[0].copy()  # Copy first frame for detection before freeing
+            self._print_memory_step(f"Before SAM2 init ({eye_name} eye, {num_frames} frames)")
+
+            # CRITICAL: Explicitly free eye_frames from memory before SAM2 loads the same video
+            # This prevents the OOM issue where both Python frames and SAM2 frames exist simultaneously
+            del eye_frames  # Free the frames array
+            self._aggressive_memory_cleanup(f"SAM2 init for {eye_name} eye")
+
+            # Initialize SAM2 with video path
+            self._print_memory_step(f"Starting SAM2 init ({eye_name} eye)")
+            self.sam2_model.init_video_state(video_path=str(temp_video_path))
+            self._print_memory_step(f"SAM2 initialized ({eye_name} eye)")
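Freeing eye_frames before SAM2 re-reads the temp video matters because raw uint8 frames dominate RSS; for example, one 2880×2880 RGB eye frame is ~24.9 MB, so a 300-frame chunk holds ~7.5 GB that would otherwise sit alongside SAM2's own decoded copy. The arithmetic (resolution chosen for illustration):

import numpy as np

frame = np.zeros((2880, 2880, 3), dtype=np.uint8)  # one eye of an SBS frame
per_frame_mb = frame.nbytes / 1e6                  # ~24.9 MB
chunk_gb = 300 * frame.nbytes / 1e9                # ~7.5 GB for a 300-frame chunk
print(f"{per_frame_mb:.1f} MB/frame, {chunk_gb:.1f} GB/chunk")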
             # Detect persons in first frame
-            first_frame = eye_frames[0]
             detections = self.detector.detect_persons(first_frame)

             if not detections:
                 warnings.warn(f"No persons detected in {eye_name} eye, chunk {chunk_idx}")
-                return self._create_empty_masks(eye_frames)
+                # Return empty masks for the number of frames
+                return self._create_empty_masks_from_count(num_frames, first_frame.shape)

             print(f"Detected {len(detections)} persons in {eye_name} eye first frame")

@@ -169,15 +378,45 @@ class VR180Processor(VideoProcessor):
             # Add prompts
             object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)

-            # Propagate masks
+            # Propagate masks (most expensive operation)
+            self._print_memory_step(f"Before SAM2 propagation ({eye_name} eye, {num_frames} frames)")
+
+            # Use Det-SAM2 continuous correction if enabled
+            if self.config.matting.continuous_correction:
+                video_segments = self.sam2_model.propagate_masks_with_continuous_correction(
+                    detector=self.detector,
+                    temp_video_path=str(temp_video_path),
+                    start_frame=0,
+                    max_frames=num_frames,
+                    correction_interval=self.config.matting.correction_interval,
+                    frame_release_interval=self.config.matting.frame_release_interval,
+                    frame_window_size=self.config.matting.frame_window_size
+                )
+                print(f"Used Det-SAM2 continuous correction (interval: {self.config.matting.correction_interval} frames)")
+            else:
                 video_segments = self.sam2_model.propagate_masks(
                     start_frame=0,
-                max_frames=len(eye_frames)
+                    max_frames=num_frames,
+                    frame_release_interval=self.config.matting.frame_release_interval,
+                    frame_window_size=self.config.matting.frame_window_size
                 )

-            # Apply masks
+            self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
+
+            # Apply masks with streaming approach (no frame accumulation)
+            self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
+
+            # Process frames one at a time without accumulation
+            cap = cv2.VideoCapture(str(temp_video_path))
             matted_frames = []
-            for frame_idx, frame in enumerate(eye_frames):
+            try:
+                for frame_idx in range(num_frames):
+                    ret, frame = cap.read()
+                    if not ret:
+                        break
+
+                    # Apply mask to this single frame
                     if frame_idx in video_segments:
                         frame_masks = video_segments[frame_idx]
                         combined_mask = self.sam2_model.get_combined_mask(frame_masks)
@@ -192,10 +431,34 @@ class VR180Processor(VideoProcessor):

                     matted_frames.append(matted_frame)

-                # Cleanup
+                    # Free the original frame immediately (no accumulation)
+                    del frame
+
+                    # Periodic cleanup during processing
+                    if frame_idx % 100 == 0 and frame_idx > 0:
+                        import gc
+                        gc.collect()
+
+            finally:
+                cap.release()
+
+            # Free video segments completely
+            del video_segments  # This holds processed masks from SAM2
+            self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
+
+            self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
+            return matted_frames
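propagate_masks_with_continuous_correction is this branch's Det-SAM2-style loop; the core idea is to re-run the person detector every correction_interval frames and feed fresh box prompts back into the tracker, so drift and late entrants get corrected. A schematic sketch — the tracker/detector interfaces below are placeholders, not SAM2's real API:

def propagate_with_correction(tracker, detector, frames, correction_interval=30):
    """Track masks frame-by-frame, re-prompting from detections periodically."""
    segments = {}
    for idx, frame in enumerate(frames):
        if idx % correction_interval == 0:
            boxes = detector(frame)              # fresh detections
            tracker.add_box_prompts(idx, boxes)  # correct drift / add new people
        segments[idx] = tracker.step(frame)      # propagate masks one frame
    return segments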
|
finally:
|
||||||
|
# Always cleanup
|
||||||
self.sam2_model.cleanup()
|
self.sam2_model.cleanup()
|
||||||
|
|
||||||
return matted_frames
|
# Remove temporary video file
|
||||||
|
try:
|
||||||
|
if temp_video_path.exists():
|
||||||
|
temp_video_path.unlink()
|
||||||
|
except Exception as e:
|
||||||
|
warnings.warn(f"Failed to cleanup temp video {temp_video_path}: {e}")
|
||||||
|
|

     def _process_eye_sequence_with_validation(self,
                                               right_eye_frames: List[np.ndarray],
@@ -223,13 +486,17 @@ class VR180Processor(VideoProcessor):
             left_eye_results, right_matted
         )

+        # CRITICAL: Free the intermediate results to prevent memory accumulation
+        del left_eye_results  # Don't keep left eye results after validation
+        del right_matted  # Don't keep unvalidated right results
+
         return validated_results

     def _validate_stereo_consistency(self,
                                      left_results: List[np.ndarray],
                                      right_results: List[np.ndarray]) -> List[np.ndarray]:
         """
-        Validate and correct stereo consistency between left and right eye results
+        Validate and correct stereo consistency between left and right eye results using GPU acceleration

         Args:
             left_results: Left eye processed frames
@@ -238,9 +505,120 @@ class VR180Processor(VideoProcessor):
         Returns:
             Validated right eye frames
         """
+        print(f"🔍 VALIDATION: Starting stereo consistency check ({len(left_results)} frames)")
+
+        try:
+            import cupy as cp
+            return self._validate_stereo_consistency_gpu(left_results, right_results)
+        except ImportError:
+            print("   Warning: CuPy not available, using CPU validation")
+            return self._validate_stereo_consistency_cpu(left_results, right_results)
|
def _validate_stereo_consistency_gpu(self,
|
||||||
|
left_results: List[np.ndarray],
|
||||||
|
right_results: List[np.ndarray]) -> List[np.ndarray]:
|
||||||
|
"""GPU-accelerated batch stereo validation using CuPy with memory-safe batching"""
|
||||||
|
import cupy as cp
|
||||||
|
|
||||||
|
print(" Using GPU acceleration for stereo validation")
|
||||||
|
|
||||||
|
# Process in batches to avoid GPU OOM
|
||||||
|
batch_size = 50 # Process 50 frames at a time (safe for 45GB GPU)
|
||||||
|
total_frames = len(left_results)
|
||||||
|
area_ratios_all = []
|
||||||
|
needs_correction_all = []
|
||||||
|
|
||||||
|
print(f" Processing {total_frames} frames in batches of {batch_size}...")
|
||||||
|
|
||||||
|
for batch_start in range(0, total_frames, batch_size):
|
||||||
|
batch_end = min(batch_start + batch_size, total_frames)
|
||||||
|
batch_frames = batch_end - batch_start
|
||||||
|
|
||||||
|
if batch_start % 100 == 0:
|
||||||
|
print(f" GPU batch {batch_start//batch_size + 1}: frames {batch_start}-{batch_end}")
|
||||||
|
|
||||||
|
# Get batch slices
|
||||||
|
left_batch = left_results[batch_start:batch_end]
|
||||||
|
right_batch = right_results[batch_start:batch_end]
|
||||||
|
|
||||||
|
# Convert batch to GPU
|
||||||
|
left_stack = cp.stack([cp.asarray(frame) for frame in left_batch])
|
||||||
|
right_stack = cp.stack([cp.asarray(frame) for frame in right_batch])
|
||||||
|
|
||||||
|
# Batch calculate mask areas for this batch
|
||||||
|
if left_stack.shape[3] == 4: # Alpha channel
|
||||||
|
left_masks = left_stack[:, :, :, 3] > 0
|
||||||
|
right_masks = right_stack[:, :, :, 3] > 0
|
||||||
|
else: # Green screen detection
|
||||||
|
bg_color = cp.array(self.config.output.background_color)
|
||||||
|
left_diff = cp.abs(left_stack.astype(cp.float32) - bg_color).sum(axis=3)
|
||||||
|
right_diff = cp.abs(right_stack.astype(cp.float32) - bg_color).sum(axis=3)
|
||||||
|
left_masks = left_diff > 30
|
||||||
|
right_masks = right_diff > 30
|
||||||
|
|
||||||
|
# Calculate areas for this batch
|
||||||
|
left_areas = cp.sum(left_masks, axis=(1, 2))
|
||||||
|
right_areas = cp.sum(right_masks, axis=(1, 2))
|
||||||
|
area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6)
|
||||||
|
|
||||||
|
# Find frames needing correction in this batch
|
||||||
|
needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0)
|
||||||
|
|
||||||
|
# Transfer batch results back to CPU and accumulate
|
||||||
|
area_ratios_all.extend(cp.asnumpy(area_ratios))
|
||||||
|
needs_correction_all.extend(cp.asnumpy(needs_correction))
|
||||||
|
|
||||||
|
# Free GPU memory for this batch
|
||||||
|
del left_stack, right_stack, left_masks, right_masks
|
||||||
|
del left_areas, right_areas, area_ratios, needs_correction
|
||||||
|
cp._default_memory_pool.free_all_blocks()
|
||||||
|
|
||||||
|
# CRITICAL: Release ALL CuPy memory back to system after validation
|
||||||
|
try:
|
||||||
|
# Force release of all GPU memory pools
|
||||||
|
cp._default_memory_pool.free_all_blocks()
|
||||||
|
cp._default_pinned_memory_pool.free_all_blocks()
|
||||||
|
|
||||||
|
# Clear CuPy cache completely
|
||||||
|
cp.get_default_memory_pool().free_all_blocks()
|
||||||
|
cp.get_default_pinned_memory_pool().free_all_blocks()
|
||||||
|
|
||||||
|
print(f" CuPy memory pools cleared")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clear CuPy memory pools: {e}")
|
||||||
|
|
||||||
|
correction_count = sum(needs_correction_all)
|
||||||
|
print(f" GPU validation complete: {correction_count}/{total_frames} frames need correction")
|
||||||
|
|
||||||
|
# Apply corrections using CPU results
|
||||||
|
validated_frames = []
|
||||||
|
for i, (needs_fix, ratio) in enumerate(zip(needs_correction_all, area_ratios_all)):
|
||||||
|
if i % 100 == 0:
|
||||||
|
print(f" Processing validation results: {i}/{total_frames}")
|
||||||
|
|
||||||
|
if needs_fix:
|
||||||
|
# Apply correction
|
||||||
|
corrected_frame = self._apply_stereo_correction(
|
||||||
|
left_results[i], right_results[i], float(ratio)
|
||||||
|
)
|
||||||
|
validated_frames.append(corrected_frame)
|
||||||
|
else:
|
||||||
|
validated_frames.append(right_results[i])
|
||||||
|
|
||||||
|
print("✅ VALIDATION: GPU stereo consistency check complete")
|
||||||
|
return validated_frames
|
||||||
|
|
||||||
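The correction criterion is the mask-area ratio r = right_area / left_area, flagged when r < 0.5 or r > 2.0. For example, a frame whose left-eye mask covers 120,000 pixels but whose right-eye mask covers only 48,000 gives r = 0.4 and is routed to _apply_stereo_correction; at 80,000 pixels (r ≈ 0.67) it passes unchanged.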
|
def _validate_stereo_consistency_cpu(self,
|
||||||
|
left_results: List[np.ndarray],
|
||||||
|
right_results: List[np.ndarray]) -> List[np.ndarray]:
|
||||||
|
"""CPU fallback for stereo validation"""
|
||||||
|
print(" Using CPU validation (slower)")
|
||||||
validated_frames = []
|
validated_frames = []
|
||||||
|
|
||||||
for i, (left_frame, right_frame) in enumerate(zip(left_results, right_results)):
|
for i, (left_frame, right_frame) in enumerate(zip(left_results, right_results)):
|
||||||
|
if i % 50 == 0: # Progress every 50 frames
|
||||||
|
print(f" CPU validation progress: {i}/{len(left_results)}")
|
||||||
|
|
||||||
# Simple validation: check if mask areas are similar
|
# Simple validation: check if mask areas are similar
|
||||||
left_mask_area = self._get_mask_area(left_frame)
|
left_mask_area = self._get_mask_area(left_frame)
|
||||||
right_mask_area = self._get_mask_area(right_frame)
|
right_mask_area = self._get_mask_area(right_frame)
|
||||||
@@ -257,10 +635,44 @@ class VR180Processor(VideoProcessor):
|
|||||||
else:
|
else:
|
||||||
validated_frames.append(right_frame)
|
validated_frames.append(right_frame)
|
||||||
|
|
||||||
|
print("✅ VALIDATION: CPU stereo consistency check complete")
|
||||||
return validated_frames
|
return validated_frames
|
||||||
|
|
||||||
|
def _create_empty_masks_from_count(self, num_frames: int, frame_shape: tuple) -> List[np.ndarray]:
|
||||||
|
"""Create empty masks when no persons detected (without frame array)"""
|
||||||
|
empty_frames = []
|
||||||
|
for _ in range(num_frames):
|
||||||
|
if self.config.output.format == "alpha":
|
||||||
|
# Transparent output
|
||||||
|
output = np.zeros((frame_shape[0], frame_shape[1], 4), dtype=np.uint8)
|
||||||
|
else:
|
||||||
|
# Green screen background
|
||||||
|
output = np.full((frame_shape[0], frame_shape[1], 3),
|
||||||
|
self.config.output.background_color, dtype=np.uint8)
|
||||||
|
empty_frames.append(output)
|
||||||
|
return empty_frames
|
||||||
|
|
||||||
def _get_mask_area(self, frame: np.ndarray) -> float:
|
def _get_mask_area(self, frame: np.ndarray) -> float:
|
||||||
"""Get mask area from processed frame"""
|
"""Get mask area from processed frame using GPU acceleration"""
|
||||||
|
try:
|
||||||
|
import cupy as cp
|
||||||
|
|
||||||
|
# Transfer to GPU
|
||||||
|
frame_gpu = cp.asarray(frame)
|
||||||
|
|
||||||
|
if frame.shape[2] == 4: # Alpha channel
|
||||||
|
mask_gpu = frame_gpu[:, :, 3] > 0
|
||||||
|
else: # Green screen - detect non-background pixels
|
||||||
|
bg_color_gpu = cp.array(self.config.output.background_color)
|
||||||
|
diff_gpu = cp.abs(frame_gpu.astype(cp.float32) - bg_color_gpu).sum(axis=2)
|
||||||
|
mask_gpu = diff_gpu > 30 # Threshold for non-background
|
||||||
|
|
||||||
|
# Calculate area on GPU and return as Python int
|
||||||
|
area = int(cp.sum(mask_gpu))
|
||||||
|
return area
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to CPU NumPy if CuPy not available
|
||||||
if frame.shape[2] == 4: # Alpha channel
|
if frame.shape[2] == 4: # Alpha channel
|
||||||
mask = frame[:, :, 3] > 0
|
mask = frame[:, :, 3] > 0
|
||||||
else: # Green screen - detect non-background pixels
|
else: # Green screen - detect non-background pixels
|
||||||
@@ -284,6 +696,64 @@ class VR180Processor(VideoProcessor):
|
|||||||
# TODO: Implement proper stereo correction algorithm
|
# TODO: Implement proper stereo correction algorithm
|
||||||
return right_frame
|
return right_frame
|
||||||
|
|
||||||
|
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
|
||||||
|
"""
|
||||||
|
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
|
||||||
|
|
||||||
|
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
|
||||||
|
persist between chunks, causing OOM when processing subsequent chunks.
|
||||||
|
"""
|
||||||
|
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
|
||||||
|
|
||||||
|
# 1. Completely destroy SAM2 model (15-20GB)
|
||||||
|
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
|
||||||
|
self.sam2_model.cleanup() # Call existing cleanup
|
||||||
|
|
||||||
|
# Force complete destruction of the model
|
||||||
|
try:
|
||||||
|
# Reset the model's loaded state so it will reload fresh
|
||||||
|
if hasattr(self.sam2_model, '_model_loaded'):
|
||||||
|
self.sam2_model._model_loaded = False
|
||||||
|
|
||||||
|
# Clear any cached state
|
||||||
|
if hasattr(self.sam2_model, 'predictor'):
|
||||||
|
self.sam2_model.predictor = None
|
||||||
|
if hasattr(self.sam2_model, 'inference_state'):
|
||||||
|
self.sam2_model.inference_state = None
|
||||||
|
|
||||||
|
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ SAM2 destruction warning: {e}")
|
||||||
|
|
||||||
|
# 2. Completely destroy YOLO detector (400MB+)
|
||||||
|
if hasattr(self, 'detector') and self.detector is not None:
|
||||||
|
try:
|
||||||
|
# Force YOLO model to be reloaded fresh
|
||||||
|
if hasattr(self.detector, 'model') and self.detector.model is not None:
|
||||||
|
del self.detector.model
|
||||||
|
self.detector.model = None
|
||||||
|
print(f" ✅ YOLO model destroyed and marked for fresh reload")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ YOLO destruction warning: {e}")
|
||||||
|
|
||||||
|
# 3. Clear CUDA cache aggressively
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize() # Wait for all operations to complete
|
||||||
|
print(f" ✅ CUDA cache cleared")
|
||||||
|
|
||||||
|
# 4. Force garbage collection
|
||||||
|
import gc
|
||||||
|
collected = gc.collect()
|
||||||
|
print(f" ✅ Garbage collection: {collected} objects freed")
|
||||||
|
|
||||||
|
# 5. Memory verification
|
||||||
|
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
|
||||||
|
|
||||||
|
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
|
||||||
|
|
||||||
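One way to confirm this teardown actually returns VRAM is to sample the PyTorch allocator around the call; a small helper using only standard torch calls (the placement around _complete_inter_chunk_cleanup is a suggestion):

import torch

def report_cuda_memory(tag: str):
    """Print allocator stats so inter-chunk cleanup can be verified empirically."""
    if not torch.cuda.is_available():
        return
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] allocated={alloc:.2f} GB, reserved={reserved:.2f} GB")

# e.g. report_cuda_memory("before cleanup") / report_cuda_memory("after cleanup")
# around the call above to confirm the 15-20GB is actually released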
def process_chunk(self,
|
def process_chunk(self,
|
||||||
frames: List[np.ndarray],
|
frames: List[np.ndarray],
|
||||||
chunk_idx: int = 0) -> List[np.ndarray]:
|
chunk_idx: int = 0) -> List[np.ndarray]:
|
||||||
@@ -343,6 +813,9 @@ class VR180Processor(VideoProcessor):
|
|||||||
combined = {'left': left_frame, 'right': right_frame}
|
combined = {'left': left_frame, 'right': right_frame}
|
||||||
combined_frames.append(combined)
|
combined_frames.append(combined)
|
||||||
|
|
||||||
|
# CRITICAL: Complete inter-chunk cleanup for independent processing too
|
||||||
|
self._complete_inter_chunk_cleanup(chunk_idx)
|
||||||
|
|
||||||
return combined_frames
|
return combined_frames
|
||||||
|
|
||||||
def save_video(self, frames: List[np.ndarray], output_path: str):
|
def save_video(self, frames: List[np.ndarray], output_path: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user