Compare commits

...

2 Commits

Author SHA1 Message Date
4d1361df46 bigtime 2025-07-26 15:29:37 -07:00
884cb8dce2 lol 2025-07-26 15:29:28 -07:00
3 changed files with 210 additions and 2 deletions

139
test_inter_chunk_cleanup.py Normal file
View File

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""
import psutil
import gc
import sys
from pathlib import Path
def get_memory_usage():
"""Get current memory usage in GB"""
process = psutil.Process()
return process.memory_info().rss / (1024**3)
def test_inter_chunk_cleanup():
"""Test that models are properly destroyed between chunks"""
print("🧪 TESTING INTER-CHUNK CLEANUP")
print("=" * 50)
baseline_memory = get_memory_usage()
print(f"📊 Baseline memory: {baseline_memory:.2f} GB")
# Import and create processor
print("\n1⃣ Creating processor...")
from vr180_matting.config import VR180Config
from vr180_matting.vr180_processor import VR180Processor
config = VR180Config.from_yaml('config.yaml')
processor = VR180Processor(config)
init_memory = get_memory_usage()
print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")
# Simulate chunk processing (just trigger model loading)
print("\n2⃣ Simulating chunk 0 processing...")
# Test 1: Force YOLO model loading
try:
detector = processor.detector
detector._load_model() # Force load
yolo_memory = get_memory_usage()
print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
except Exception as e:
print(f"❌ YOLO loading failed: {e}")
yolo_memory = init_memory
# Test 2: Force SAM2 model loading
try:
sam2_model = processor.sam2_model
sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
sam2_memory = get_memory_usage()
print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
except Exception as e:
print(f"❌ SAM2 loading failed: {e}")
sam2_memory = yolo_memory
total_model_memory = sam2_memory - init_memory
print(f"📊 Total model memory: {total_model_memory:.2f} GB")
# Test 3: Inter-chunk cleanup
print("\n3⃣ Testing inter-chunk cleanup...")
processor._complete_inter_chunk_cleanup(chunk_idx=0)
cleanup_memory = get_memory_usage()
cleanup_improvement = sam2_memory - cleanup_memory
print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")
# Test 4: Verify models reload fresh
print("\n4⃣ Testing fresh model reload...")
# Check YOLO state
yolo_reloaded = processor.detector.model is None
print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")
# Check SAM2 state
sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")
# Test 5: Force reload to verify they work
print("\n5⃣ Testing model reload...")
try:
# Force YOLO reload
processor.detector._load_model()
yolo_reload_memory = get_memory_usage()
# Force SAM2 reload
processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
sam2_reload_memory = get_memory_usage()
reload_growth = sam2_reload_memory - cleanup_memory
print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")
if abs(reload_growth - total_model_memory) < 1.0: # Within 1GB
print("✅ Models reloaded with similar memory usage (good)")
else:
print("⚠️ Model reload memory differs significantly")
except Exception as e:
print(f"❌ Model reload failed: {e}")
# Final summary
print(f"\n📊 SUMMARY:")
print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
print(f" Memory freed: {cleanup_improvement:.2f}GB")
print(f" Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")
if cleanup_improvement > total_model_memory * 0.5: # Freed >50% of model memory
print("✅ Inter-chunk cleanup working effectively")
return True
else:
print("❌ Inter-chunk cleanup not freeing enough memory")
return False
def main():
if len(sys.argv) != 2:
print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
sys.exit(1)
config_path = sys.argv[1]
if not Path(config_path).exists():
print(f"Config file not found: {config_path}")
sys.exit(1)
success = test_inter_chunk_cleanup()
if success:
print(f"\n🎉 SUCCESS: Inter-chunk cleanup is working!")
print(f"💡 This should prevent 15-20GB model accumulation between chunks")
else:
print(f"\n❌ FAILURE: Inter-chunk cleanup needs improvement")
print(f"💡 Check model destruction logic in _complete_inter_chunk_cleanup")
return 0 if success else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -49,8 +49,8 @@ class SAM2VideoMatting:
def _load_model(self, model_cfg: str, checkpoint_path: str): def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor lazily""" """Load SAM2 video predictor lazily"""
if self._model_loaded: if self._model_loaded and self.predictor is not None:
return # Already loaded return # Already loaded and predictor exists
try: try:
# Import heavy SAM2 modules only when needed # Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
finally: finally:
self.predictor = None self.predictor = None
# Reset model loaded state for fresh reload
self._model_loaded = False
# Force garbage collection (critical for memory leak prevention) # Force garbage collection (critical for memory leak prevention)
gc.collect() gc.collect()

View File

@@ -3,6 +3,7 @@ import numpy as np
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path from pathlib import Path
import warnings import warnings
import torch
from .video_processor import VideoProcessor from .video_processor import VideoProcessor
from .config import VR180Config from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
del right_matted del right_matted
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}") self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
# This ensures models don't accumulate between chunks
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames return combined_frames
def _process_eye_sequence(self, def _process_eye_sequence(self,
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
# TODO: Implement proper stereo correction algorithm # TODO: Implement proper stereo correction algorithm
return right_frame return right_frame
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
"""
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
persist between chunks, causing OOM when processing subsequent chunks.
"""
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
# 1. Completely destroy SAM2 model (15-20GB)
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
self.sam2_model.cleanup() # Call existing cleanup
# Force complete destruction of the model
try:
# Reset the model's loaded state so it will reload fresh
if hasattr(self.sam2_model, '_model_loaded'):
self.sam2_model._model_loaded = False
# Clear any cached state
if hasattr(self.sam2_model, 'predictor'):
self.sam2_model.predictor = None
if hasattr(self.sam2_model, 'inference_state'):
self.sam2_model.inference_state = None
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ SAM2 destruction warning: {e}")
# 2. Completely destroy YOLO detector (400MB+)
if hasattr(self, 'detector') and self.detector is not None:
try:
# Force YOLO model to be reloaded fresh
if hasattr(self.detector, 'model') and self.detector.model is not None:
del self.detector.model
self.detector.model = None
print(f" ✅ YOLO model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ YOLO destruction warning: {e}")
# 3. Clear CUDA cache aggressively
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
print(f" ✅ CUDA cache cleared")
# 4. Force garbage collection
import gc
collected = gc.collect()
print(f" ✅ Garbage collection: {collected} objects freed")
# 5. Memory verification
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
def process_chunk(self, def process_chunk(self,
frames: List[np.ndarray], frames: List[np.ndarray],
chunk_idx: int = 0) -> List[np.ndarray]: chunk_idx: int = 0) -> List[np.ndarray]:
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame} combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined) combined_frames.append(combined)
# CRITICAL: Complete inter-chunk cleanup for independent processing too
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames return combined_frames
def save_video(self, frames: List[np.ndarray], output_path: str): def save_video(self, frames: List[np.ndarray], output_path: str):