diff --git a/test_inter_chunk_cleanup.py b/test_inter_chunk_cleanup.py new file mode 100644 index 0000000..b3330e0 --- /dev/null +++ b/test_inter_chunk_cleanup.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Test script to verify inter-chunk cleanup properly destroys models +""" + +import psutil +import gc +import sys +from pathlib import Path + +def get_memory_usage(): + """Get current memory usage in GB""" + process = psutil.Process() + return process.memory_info().rss / (1024**3) + +def test_inter_chunk_cleanup(): + """Test that models are properly destroyed between chunks""" + + print("๐Ÿงช TESTING INTER-CHUNK CLEANUP") + print("=" * 50) + + baseline_memory = get_memory_usage() + print(f"๐Ÿ“Š Baseline memory: {baseline_memory:.2f} GB") + + # Import and create processor + print("\n1๏ธโƒฃ Creating processor...") + from vr180_matting.config import VR180Config + from vr180_matting.vr180_processor import VR180Processor + + config = VR180Config.from_yaml('config.yaml') + processor = VR180Processor(config) + + init_memory = get_memory_usage() + print(f"๐Ÿ“Š After processor init: {init_memory:.2f} GB (+{init_memory - baseline_memory:.2f} GB)") + + # Simulate chunk processing (just trigger model loading) + print("\n2๏ธโƒฃ Simulating chunk 0 processing...") + + # Test 1: Force YOLO model loading + try: + detector = processor.detector + detector._load_model() # Force load + yolo_memory = get_memory_usage() + print(f"๐Ÿ“Š After YOLO load: {yolo_memory:.2f} GB (+{yolo_memory - init_memory:.2f} GB)") + except Exception as e: + print(f"โŒ YOLO loading failed: {e}") + yolo_memory = init_memory + + # Test 2: Force SAM2 model loading + try: + sam2_model = processor.sam2_model + sam2_model._load_model(sam2_model.model_cfg, sam2_model.checkpoint_path) + sam2_memory = get_memory_usage() + print(f"๐Ÿ“Š After SAM2 load: {sam2_memory:.2f} GB (+{sam2_memory - yolo_memory:.2f} GB)") + except Exception as e: + print(f"โŒ SAM2 loading failed: {e}") + sam2_memory = yolo_memory + + total_model_memory = sam2_memory - init_memory + print(f"๐Ÿ“Š Total model memory: {total_model_memory:.2f} GB") + + # Test 3: Inter-chunk cleanup + print("\n3๏ธโƒฃ Testing inter-chunk cleanup...") + processor._complete_inter_chunk_cleanup(chunk_idx=0) + + cleanup_memory = get_memory_usage() + cleanup_improvement = sam2_memory - cleanup_memory + print(f"๐Ÿ“Š After cleanup: {cleanup_memory:.2f} GB (-{cleanup_improvement:.2f} GB freed)") + + # Test 4: Verify models reload fresh + print("\n4๏ธโƒฃ Testing fresh model reload...") + + # Check YOLO state + yolo_reloaded = processor.detector.model is None + print(f"๐Ÿ” YOLO model destroyed: {'โœ… YES' if yolo_reloaded else 'โŒ NO'}") + + # Check SAM2 state + sam2_reloaded = not processor.sam2_model._model_loaded or processor.sam2_model.predictor is None + print(f"๐Ÿ” SAM2 model destroyed: {'โœ… YES' if sam2_reloaded else 'โŒ NO'}") + + # Test 5: Force reload to verify they work + print("\n5๏ธโƒฃ Testing model reload...") + try: + # Force YOLO reload + processor.detector._load_model() + yolo_reload_memory = get_memory_usage() + + # Force SAM2 reload + processor.sam2_model._load_model(processor.sam2_model.model_cfg, processor.sam2_model.checkpoint_path) + sam2_reload_memory = get_memory_usage() + + reload_growth = sam2_reload_memory - cleanup_memory + print(f"๐Ÿ“Š After reload: {sam2_reload_memory:.2f} GB (+{reload_growth:.2f} GB)") + + if abs(reload_growth - total_model_memory) < 1.0: # Within 1GB + print("โœ… Models reloaded with similar memory usage (good)") + else: + print("โš ๏ธ Model reload memory differs significantly") + + except Exception as e: + print(f"โŒ Model reload failed: {e}") + + # Final summary + print(f"\n๐Ÿ“Š SUMMARY:") + print(f" Baseline โ†’ Peak: {baseline_memory:.2f}GB โ†’ {sam2_memory:.2f}GB") + print(f" Peak โ†’ Cleanup: {sam2_memory:.2f}GB โ†’ {cleanup_memory:.2f}GB") + print(f" Memory freed: {cleanup_improvement:.2f}GB") + print(f" Models destroyed: YOLO={yolo_reloaded}, SAM2={sam2_reloaded}") + + if cleanup_improvement > total_model_memory * 0.5: # Freed >50% of model memory + print("โœ… Inter-chunk cleanup working effectively") + return True + else: + print("โŒ Inter-chunk cleanup not freeing enough memory") + return False + +def main(): + if len(sys.argv) != 2: + print("Usage: python test_inter_chunk_cleanup.py ") + sys.exit(1) + + config_path = sys.argv[1] + if not Path(config_path).exists(): + print(f"Config file not found: {config_path}") + sys.exit(1) + + success = test_inter_chunk_cleanup() + + if success: + print(f"\n๐ŸŽ‰ SUCCESS: Inter-chunk cleanup is working!") + print(f"๐Ÿ’ก This should prevent 15-20GB model accumulation between chunks") + else: + print(f"\nโŒ FAILURE: Inter-chunk cleanup needs improvement") + print(f"๐Ÿ’ก Check model destruction logic in _complete_inter_chunk_cleanup") + + return 0 if success else 1 + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file