Compare commits

..

2 Commits

Author SHA1 Message Date
4d1361df46 bigtime 2025-07-26 15:29:37 -07:00
884cb8dce2 lol 2025-07-26 15:29:28 -07:00
3 changed files with 210 additions and 2 deletions

139
test_inter_chunk_cleanup.py Normal file
View File

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""
import psutil
import gc
import sys
from pathlib import Path
def get_memory_usage():
"""Get current memory usage in GB"""
process = psutil.Process()
return process.memory_info().rss / (1024**3)
def test_inter_chunk_cleanup():
"""Test that models are properly destroyed between chunks"""
print("🧪 TESTING INTER-CHUNK CLEANUP")
print("=" * 50)
baseline_memory = get_memory_usage()
print(f"📊 Baseline memory: {baseline_memory:.2f} GB")
# Import and create processor
print("\n1⃣ Creating processor...")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
config = VR180Config.from_yaml('config.yaml')
processor = VR180Processor(config)
init_memory = get_memory_usage()
print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")
# Simulate chunk processing (just trigger model loading)
print("\n2⃣ Simulating chunk 0 processing...")
# Test 1: Force YOLO model loading
try:
detector = processor.detector
detector._load_model() # Force load
yolo_memory = get_memory_usage()
print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
except Exception as e:
print(f"❌ YOLO loading failed: {e}")
yolo_memory = init_memory
# Test 2: Force SAM2 model loading
try:
sam2_model = processor.sam2_model
sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
sam2_memory = get_memory_usage()
print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
except Exception as e:
print(f"❌ SAM2 loading failed: {e}")
sam2_memory = yolo_memory
total_model_memory = sam2_memory - init_memory
print(f"📊 Total model memory: {total_model_memory:.2f} GB")
# Test 3: Inter-chunk cleanup
print("\n3⃣ Testing inter-chunk cleanup...")
processor._complete_inter_chunk_cleanup(chunk_idx=0)
cleanup_memory = get_memory_usage()
cleanup_improvement = sam2_memory - cleanup_memory
print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")
# Test 4: Verify models reload fresh
print("\n4⃣ Testing fresh model reload...")
# Check YOLO state
yolo_reloaded = processor.detector.model is None
print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")
# Check SAM2 state
sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")
# Test 5: Force reload to verify they work
print("\n5⃣ Testing model reload...")
try:
# Force YOLO reload
processor.detector._load_model()
yolo_reload_memory = get_memory_usage()
# Force SAM2 reload
processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
sam2_reload_memory = get_memory_usage()
reload_growth = sam2_reload_memory - cleanup_memory
print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")
if abs(reload_growth - total_model_memory) < 1.0: # Within 1GB
print("✅ Models reloaded with similar memory usage (good)")
else:
print("⚠️ Model reload memory differs significantly")
except Exception as e:
print(f"❌ Model reload failed: {e}")
# Final summary
print(f"\n📊 SUMMARY:")
print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
print(f" Memory freed: {cleanup_improvement:.2f}GB")
print(f" Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")
if cleanup_improvement > total_model_memory * 0.5: # Freed >50% of model memory
print("✅ Inter-chunk cleanup working effectively")
return True
else:
print("❌ Inter-chunk cleanup not freeing enough memory")
return False
def main():
if len(sys.argv) != 2:
print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
success = test_inter_chunk_cleanup()
if success:
print(f"\n🎉 SUCCESS: Inter-chunk cleanup is working!")
print(f"💡 This should prevent 15-20GB model accumulation between chunks")
else:
print(f"\n❌ FAILURE: Inter-chunk cleanup needs improvement")
print(f"💡 Check model destruction logic in _complete_inter_chunk_cleanup")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -49,8 +49,8 @@ class SAM2VideoMatting:
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor lazily"""
if self._model_loaded:
return # Already loaded
if self._model_loaded and self.predictor is not None:
return # Already loaded and predictor exists
try:
# Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
finally:
self.predictor = None
# Reset model loaded state for fresh reload
self._model_loaded = False
# Force garbage collection (critical for memory leak prevention)
gc.collect()

View File

@@ -3,6 +3,7 @@ import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import warnings
import torch
from .video_processor import VideoProcessor
from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
del right_matted
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
# This ensures models don't accumulate between chunks
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def _process_eye_sequence(self,
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
# TODO: Implement proper stereo correction algorithm
return right_frame
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
"""
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
persist between chunks, causing OOM when processing subsequent chunks.
"""
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
# 1. Completely destroy SAM2 model (15-20GB)
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
self.sam2_model.cleanup() # Call existing cleanup
# Force complete destruction of the model
try:
# Reset the model's loaded state so it will reload fresh
if hasattr(self.sam2_model, '_model_loaded'):
self.sam2_model._model_loaded = False
# Clear any cached state
if hasattr(self.sam2_model, 'predictor'):
self.sam2_model.predictor = None
if hasattr(self.sam2_model, 'inference_state'):
self.sam2_model.inference_state = None
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ SAM2 destruction warning: {e}")
# 2. Completely destroy YOLO detector (400MB+)
if hasattr(self, 'detector') and self.detector is not None:
try:
# Force YOLO model to be reloaded fresh
if hasattr(self.detector, 'model') and self.detector.model is not None:
del self.detector.model
self.detector.model = None
print(f" ✅ YOLO model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ YOLO destruction warning: {e}")
# 3. Clear CUDA cache aggressively
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
print(f" ✅ CUDA cache cleared")
# 4. Force garbage collection
import gc
collected = gc.collect()
print(f" ✅ Garbage collection: {collected} objects freed")
# 5. Memory verification
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
def process_chunk(self,
frames: List[np.ndarray],
chunk_idx: int = 0) -> List[np.ndarray]:
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined)
# CRITICAL: Complete inter-chunk cleanup for independent processing too
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def save_video(self, frames: List[np.ndarray], output_path: str):