Compare commits: 36f58acb8b...4d1361df46

2 commits: 4d1361df46, 884cb8dce2
test_inter_chunk_cleanup.py (new file, 139 lines)

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Test script to verify inter-chunk cleanup properly destroys models
"""

import psutil
import gc
import sys
from pathlib import Path

def get_memory_usage():
    """Get current memory usage in GB"""
    process = psutil.Process()
    return process.memory_info().rss / (1024**3)

def test_inter_chunk_cleanup(config_path):
    """Test that models are properly destroyed between chunks"""

    print("🧪 TESTING INTER-CHUNK CLEANUP")
    print("=" * 50)

    baseline_memory = get_memory_usage()
    print(f"📊 Baseline memory: {baseline_memory:.2f} GB")

    # Import and create processor
    print("\n1️⃣ Creating processor...")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor

    config = VR180Config.from_yaml(config_path)
    processor = VR180Processor(config)

    init_memory = get_memory_usage()
    print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")

    # Simulate chunk processing (just trigger model loading)
    print("\n2️⃣ Simulating chunk 0 processing...")

    # Test 1: Force YOLO model loading
    try:
        detector = processor.detector
        detector._load_model()  # Force load
        yolo_memory = get_memory_usage()
        print(f"📊 After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ YOLO loading failed: {e}")
        yolo_memory = init_memory

    # Test 2: Force SAM2 model loading
    try:
        sam2_model = processor.sam2_model
        sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path)
        sam2_memory = get_memory_usage()
        print(f"📊 After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)")
    except Exception as e:
        print(f"❌ SAM2 loading failed: {e}")
        sam2_memory = yolo_memory

    total_model_memory = sam2_memory - init_memory
    print(f"📊 Total model memory: {total_model_memory:.2f} GB")

    # Test 3: Inter-chunk cleanup
    print("\n3️⃣ Testing inter-chunk cleanup...")
    processor._complete_inter_chunk_cleanup(chunk_idx=0)

    cleanup_memory = get_memory_usage()
    cleanup_improvement = sam2_memory - cleanup_memory
    print(f"📊 After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)")

    # Test 4: Verify models reload fresh
    print("\n4️⃣ Testing fresh model reload...")

    # Check YOLO state
    yolo_reloaded = processor.detector.model is None
    print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_reloaded else '❌ NO'}")

    # Check SAM2 state
    sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None
    print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_reloaded else '❌ NO'}")

    # Test 5: Force reload to verify the models still work
    print("\n5️⃣ Testing model reload...")
    try:
        # Force YOLO reload
        processor.detector._load_model()
        yolo_reload_memory = get_memory_usage()

        # Force SAM2 reload
        processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path)
        sam2_reload_memory = get_memory_usage()

        reload_growth = sam2_reload_memory - cleanup_memory
        print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")

        if abs(reload_growth - total_model_memory) < 1.0:  # Within 1 GB
            print("✅ Models reloaded with similar memory usage (good)")
        else:
            print("⚠️ Model reload memory differs significantly")

    except Exception as e:
        print(f"❌ Model reload failed: {e}")

    # Final summary
    print("\n📊 SUMMARY:")
    print(f"   Baseline → Peak:  {baseline_memory:.2f} GB → {sam2_memory:.2f} GB")
    print(f"   Peak → Cleanup:   {sam2_memory:.2f} GB → {cleanup_memory:.2f} GB")
    print(f"   Memory freed:     {cleanup_improvement:.2f} GB")
    print(f"   Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}")

    if cleanup_improvement > total_model_memory * 0.5:  # Freed >50% of model memory
        print("✅ Inter-chunk cleanup working effectively")
        return True
    else:
        print("❌ Inter-chunk cleanup not freeing enough memory")
        return False

def main():
    if len(sys.argv) != 2:
        print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    success = test_inter_chunk_cleanup(config_path)

    if success:
        print("\n🎉 SUCCESS: Inter-chunk cleanup is working!")
        print("💡 This should prevent 15-20GB model accumulation between chunks")
    else:
        print("\n❌ FAILURE: Inter-chunk cleanup needs improvement")
        print("💡 Check model destruction logic in _complete_inter_chunk_cleanup")

    return 0 if success else 1

if __name__ == "__main__":
    sys.exit(main())
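One caveat worth noting: psutil's RSS only tracks host (CPU) memory, so if the SAM2 and YOLO weights live on the GPU, the numbers above understate the real footprint. A GPU-side companion helper could close that gap; this is a sketch, not part of the committed script:

    # Hypothetical companion to get_memory_usage(); assumes torch is installed.
    import torch

    def get_gpu_memory_usage():
        """Current CUDA memory allocated by live tensors, in GB (0.0 without CUDA)."""
        if not torch.cuda.is_available():
            return 0.0
        return torch.cuda.memory_allocated() / (1024**3)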
@@ -49,8 +49,8 @@ class SAM2VideoMatting:
 
     def _load_model(self, model_cfg: str, checkpoint_path: str):
         """Load SAM2 video predictor lazily"""
-        if self._model_loaded:
-            return  # Already loaded
+        if self._model_loaded and self.predictor is not None:
+            return  # Already loaded and predictor exists
 
         try:
             # Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
         finally:
             self.predictor = None
+
+            # Reset model loaded state for fresh reload
+            self._model_loaded = False
 
             # Force garbage collection (critical for memory leak prevention)
             gc.collect()
 
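These two changes work together: inter-chunk cleanup nulls out the predictor, so a guard that checked only the boolean flag could return early from _load_model and leave callers holding a None predictor, while a flag left True after teardown would block the fresh reload entirely. A minimal sketch of the lifecycle (an illustrative toy class, not the repo's actual code):

    # Toy model wrapper showing why the guard checks both flag and object.
    class LazyModel:
        def __init__(self):
            self._model_loaded = False
            self.predictor = None

        def _load_model(self):
            # Mirrors the fixed guard: reload whenever the predictor is gone.
            if self._model_loaded and self.predictor is not None:
                return
            self.predictor = object()  # stand-in for the heavy SAM2 predictor
            self._model_loaded = True

        def cleanup(self):
            # Mirrors the finally block: drop the object and reset the flag
            # so the next _load_model() call reloads from scratch.
            self.predictor = None
            self._model_loaded = False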
@@ -3,6 +3,7 @@ import numpy as np
 from typing import List, Dict, Any, Optional, Tuple
 from pathlib import Path
 import warnings
+import torch
 
 from .video_processor import VideoProcessor
 from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
         del right_matted
         self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
 
+        # CRITICAL: Complete inter-chunk cleanup to prevent model persistence
+        # This ensures models don't accumulate between chunks
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames
 
     def _process_eye_sequence(self,
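The cleanup call sits at the tail of the chunk-processing path, so each chunk ends at a model-free baseline and the lazy loaders rebuild everything at the start of the next one. Roughly, the intended per-chunk lifecycle looks like this (a sketch of a hypothetical driver loop; only process_chunk and _complete_inter_chunk_cleanup appear in this diff):

    # Hypothetical driver loop; chunks and writer are stand-ins.
    for chunk_idx, frames in enumerate(chunks):
        matted = processor.process_chunk(frames, chunk_idx=chunk_idx)  # models lazy-load here
        writer.write(matted)
        # process_chunk ends by calling _complete_inter_chunk_cleanup(chunk_idx),
        # so the next iteration starts without 15-20GB of stale model state.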
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
         # TODO: Implement proper stereo correction algorithm
         return right_frame
 
+    def _complete_inter_chunk_cleanup(self, chunk_idx: int):
+        """
+        Complete inter-chunk cleanup: destroy all models to prevent memory accumulation.
+
+        This addresses the core issue where SAM2 and YOLO models (~15-20GB)
+        persist between chunks, causing OOM when processing subsequent chunks.
+        """
+        print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
+
+        # 1. Completely destroy SAM2 model (15-20GB)
+        if hasattr(self, 'sam2_model') and self.sam2_model is not None:
+            self.sam2_model.cleanup()  # Call existing cleanup
+
+            # Force complete destruction of the model
+            try:
+                # Reset the model's loaded state so it will reload fresh
+                if hasattr(self.sam2_model, '_model_loaded'):
+                    self.sam2_model._model_loaded = False
+
+                # Clear any cached state
+                if hasattr(self.sam2_model, 'predictor'):
+                    self.sam2_model.predictor = None
+                if hasattr(self.sam2_model, 'inference_state'):
+                    self.sam2_model.inference_state = None
+
+                print("   ✅ SAM2 model destroyed and marked for fresh reload")
+
+            except Exception as e:
+                print(f"   ⚠️ SAM2 destruction warning: {e}")
+
+        # 2. Completely destroy YOLO detector (400MB+)
+        if hasattr(self, 'detector') and self.detector is not None:
+            try:
+                # Force YOLO model to be reloaded fresh
+                if hasattr(self.detector, 'model') and self.detector.model is not None:
+                    del self.detector.model
+                    self.detector.model = None
+                    print("   ✅ YOLO model destroyed and marked for fresh reload")
+
+            except Exception as e:
+                print(f"   ⚠️ YOLO destruction warning: {e}")
+
+        # 3. Clear CUDA cache aggressively
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()  # Wait for all operations to complete
+            print("   ✅ CUDA cache cleared")
+
+        # 4. Force garbage collection
+        import gc
+        collected = gc.collect()
+        print(f"   ✅ Garbage collection: {collected} objects freed")
+
+        # 5. Memory verification
+        self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
+
+        print("🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
+
     def process_chunk(self,
                       frames: List[np.ndarray],
                       chunk_idx: int = 0) -> List[np.ndarray]:
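One design note on steps 3 and 4: torch.cuda.empty_cache() can only return cached blocks whose tensors are no longer referenced, so collecting garbage before emptying the cache tends to release more memory than the reverse order. A hedged sketch of that ordering (a standalone helper, not code from this commit):

    # Illustrative cleanup ordering: drop references, collect, then empty the cache.
    import gc
    import torch

    def release_gpu_memory():
        gc.collect()                  # break cycles so dead tensors are actually freed
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # let in-flight kernels finish first
            torch.cuda.empty_cache()  # hand unused cached blocks back to the driver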
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
             combined = {'left': left_frame, 'right': right_frame}
             combined_frames.append(combined)
 
+        # CRITICAL: Complete inter-chunk cleanup for independent processing too
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames
 
     def save_video(self, frames: List[np.ndarray], output_path: str):
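With both processing paths now ending in _complete_inter_chunk_cleanup, the new script verifies the whole cycle end to end. It can also be driven programmatically; a minimal sketch (the config filename is only an example):

    # Programmatic equivalent of the script's CLI entry point.
    from test_inter_chunk_cleanup import test_inter_chunk_cleanup

    ok = test_inter_chunk_cleanup("config.yaml")  # any VR180Config YAML path
    print("cleanup effective" if ok else "cleanup needs work")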