148 lines
5.5 KiB
Python
148 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Test script to verify inter-chunk cleanup properly destroys models
|
||
"""
|
||
|
||
import psutil
|
||
import gc
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
def get_memory_usage():
    """Return this process's resident set size (RSS) in GiB (1024**3 bytes)."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_gib
|
||
|
||
def _load_and_measure(label, load, previous_memory):
    """Force-load one model and report the RSS change it caused.

    Args:
        label: Short model name used in the printed messages (e.g. ``"YOLO"``).
        load: Zero-argument callable that loads the model.
        previous_memory: RSS in GB measured before this load.

    Returns:
        RSS in GB after the load, or ``previous_memory`` if ``load`` raised,
        so later deltas are still computed against a real measurement.
    """
    try:
        load()
    except Exception as e:  # report and continue: the cleanup check can still run
        print(f"❌ {label} loading failed: {e}")
        return previous_memory
    memory = get_memory_usage()
    print(f"📊 After {label} load: {memory:.2f} GB (+{memory - previous_memory:.2f} GB)")
    return memory


def test_inter_chunk_cleanup(config_path='config.yaml'):
    """Test that models are properly destroyed between chunks.

    Steps:
      1. Build a ``VR180Processor`` from ``config_path``.
      2. Force-load the YOLO detector and the SAM2 model, printing RSS after each.
      3. Call ``processor._complete_inter_chunk_cleanup(chunk_idx=0)``.
      4. Check that ``detector.model is None`` and that SAM2 is unloaded
         (``_model_loaded`` false or ``predictor is None``).
      5. Force-reload both models to confirm they still load after cleanup.

    Args:
        config_path: YAML file passed to ``VR180Config.from_yaml``. Defaults to
            ``'config.yaml'`` to match the original hard-coded path.

    Returns:
        True if both models were destroyed by the cleanup (whether or not the
        reload step succeeded), False otherwise.
    """
    print("🧪 TESTING INTER-CHUNK CLEANUP")
    print("=" * 50)

    baseline_memory = get_memory_usage()
    print(f"📊 Baseline memory: {baseline_memory:.2f} GB")

    # Import and create processor. Imported here rather than at module top, so
    # main()'s argument checks run before the package is loaded.
    print("\n1️⃣ Creating processor...")
    from vr180_matting.config import VR180Config
    from vr180_matting.vr180_processor import VR180Processor

    config = VR180Config.from_yaml(config_path)
    processor = VR180Processor(config)

    init_memory = get_memory_usage()
    print(f"📊 After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)")

    # Simulate chunk processing (just trigger model loading)
    print("\n2️⃣ Simulating chunk 0 processing...")
    yolo_memory = _load_and_measure(
        "YOLO",
        lambda: processor.detector._load_model(),
        init_memory,
    )
    sam2_memory = _load_and_measure(
        "SAM2",
        lambda: processor.sam2_model._load_model(
            processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path
        ),
        yolo_memory,
    )

    total_model_memory = sam2_memory - init_memory
    print(f"📊 Total model memory: {total_model_memory:.2f} GB")

    # Inter-chunk cleanup
    print("\n3️⃣ Testing inter-chunk cleanup...")
    processor._complete_inter_chunk_cleanup(chunk_idx=0)

    cleanup_memory = get_memory_usage()
    # Positive means RSS dropped; negative means it grew during cleanup.
    cleanup_improvement = sam2_memory - cleanup_memory
    print(f"📊 After cleanup: {cleanup_memory:.2f} GB ({cleanup_improvement:.2f} GB freed)")

    # Verify the cleanup actually dropped the model objects
    print("\n4️⃣ Testing fresh model reload...")
    yolo_destroyed = processor.detector.model is None
    print(f"🔍 YOLO model destroyed: {'✅ YES' if yolo_destroyed else '❌ NO'}")

    sam2_destroyed = (
        not processor.sam2_model._model_loaded
        or processor.sam2_model.predictor is None
    )
    print(f"🔍 SAM2 model destroyed: {'✅ YES' if sam2_destroyed else '❌ NO'}")

    # Force reload to verify the models still work after destruction
    print("\n5️⃣ Testing model reload...")
    can_reload = False
    try:
        processor.detector._load_model()
        yolo_reload_memory = get_memory_usage()
        print(f"📊 After YOLO reload: {yolo_reload_memory:.2f} GB "
              f"(+{yolo_reload_memory - cleanup_memory:.2f} GB)")

        processor.sam2_model._load_model(
            processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path
        )
        sam2_reload_memory = get_memory_usage()
    except Exception as e:
        print(f"❌ Model reload failed: {e}")
    else:
        can_reload = True
        reload_growth = sam2_reload_memory - cleanup_memory
        print(f"📊 After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)")

        # A reload costing roughly what the first load did suggests nothing
        # from the first load is still resident.
        if abs(reload_growth - total_model_memory) < 1.0:  # Within 1GB
            print("✅ Models reloaded with similar memory usage (good)")
        else:
            print("⚠️ Model reload memory differs significantly")

    # Final summary
    print("\n📊 SUMMARY:")
    print(f" Baseline → Peak: {baseline_memory:.2f}GB → {sam2_memory:.2f}GB")
    print(f" Peak → Cleanup: {sam2_memory:.2f}GB → {cleanup_memory:.2f}GB")
    print(f" Memory freed: {cleanup_improvement:.2f}GB")
    print(f" Models destroyed: YOLO={yolo_destroyed}, SAM2={sam2_destroyed}")

    # Success criterion: both models destroyed; a successful reload is a bonus.
    models_destroyed = yolo_destroyed and sam2_destroyed
    if not models_destroyed:
        print("❌ Inter-chunk cleanup did not destroy the models")
        return False

    if can_reload:
        print("✅ Inter-chunk cleanup working effectively")
        print("💡 Models destroyed and can reload fresh (memory will be freed during real processing)")
    else:
        print("⚠️ Models destroyed but reload test incomplete")
        print("💡 This should still prevent accumulation during real processing")
    return True


def main(argv=None):
    """Command-line entry point: ``python test_inter_chunk_cleanup.py <config.yaml>``.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        Exit code: 0 if the cleanup check passed, 1 otherwise. Calls
        ``sys.exit(1)`` directly on bad usage or a missing config file.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 2:
        print("Usage: python test_inter_chunk_cleanup.py <config.yaml>")
        sys.exit(1)

    config_path = argv[1]
    if not Path(config_path).exists():
        print(f"Config file not found: {config_path}")
        sys.exit(1)

    # Pass the validated path through; previously it was checked then ignored.
    success = test_inter_chunk_cleanup(config_path)

    if success:
        print("\n🎉 SUCCESS: Inter-chunk cleanup is working!")
        print("💡 This should prevent 15-20GB model accumulation between chunks")
    else:
        print("\n❌ FAILURE: Inter-chunk cleanup needs improvement")
        print("💡 Check model destruction logic in _complete_inter_chunk_cleanup")

    return 0 if success else 1
|
||
|
||
if __name__ == "__main__":
    # Raising SystemExit directly is what sys.exit() does; main()'s return
    # value becomes the process exit status.
    raise SystemExit(main())