diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py index 64a37d0..ccb4b34 100644 --- a/vr180_matting/sam2_wrapper.py +++ b/vr180_matting/sam2_wrapper.py @@ -49,8 +49,8 @@ class SAM2VideoMatting: def _load_model(self, model_cfg: str, checkpoint_path: str): """Load SAM2 video predictor lazily""" - if self._model_loaded: - return # Already loaded + if self._model_loaded and self.predictor is not None: + return # Already loaded and predictor exists try: # Import heavy SAM2 modules only when needed @@ -419,6 +419,9 @@ class SAM2VideoMatting: finally: self.predictor = None + # Reset model loaded state for fresh reload + self._model_loaded = False + # Force garbage collection (critical for memory leak prevention) gc.collect() diff --git a/vr180_matting/vr180_processor.py b/vr180_matting/vr180_processor.py index 738299b..1292056 100644 --- a/vr180_matting/vr180_processor.py +++ b/vr180_matting/vr180_processor.py @@ -3,6 +3,7 @@ import numpy as np from typing import List, Dict, Any, Optional, Tuple from pathlib import Path import warnings +import torch from .video_processor import VideoProcessor from .config import VR180Config @@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor): del right_matted self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}") + # CRITICAL: Complete inter-chunk cleanup to prevent model persistence + # This ensures models don't accumulate between chunks + self._complete_inter_chunk_cleanup(chunk_idx) + return combined_frames def _process_eye_sequence(self, @@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor): # TODO: Implement proper stereo correction algorithm return right_frame + def _complete_inter_chunk_cleanup(self, chunk_idx: int): + """ + Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation + + This addresses the core issue where SAM2 and YOLO models (~15-20GB) + persist between chunks, causing OOM when processing subsequent chunks. 
+ """ + print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}") + + # 1. Completely destroy SAM2 model (15-20GB) + if hasattr(self, 'sam2_model') and self.sam2_model is not None: + self.sam2_model.cleanup() # Call existing cleanup + + # Force complete destruction of the model + try: + # Reset the model's loaded state so it will reload fresh + if hasattr(self.sam2_model, '_model_loaded'): + self.sam2_model._model_loaded = False + + # Clear any cached state + if hasattr(self.sam2_model, 'predictor'): + self.sam2_model.predictor = None + if hasattr(self.sam2_model, 'inference_state'): + self.sam2_model.inference_state = None + + print(f" ✅ SAM2 model destroyed and marked for fresh reload") + + except Exception as e: + print(f" ⚠️ SAM2 destruction warning: {e}") + + # 2. Completely destroy YOLO detector (400MB+) + if hasattr(self, 'detector') and self.detector is not None: + try: + # Force YOLO model to be reloaded fresh + if hasattr(self.detector, 'model') and self.detector.model is not None: + del self.detector.model + self.detector.model = None + print(f" ✅ YOLO model destroyed and marked for fresh reload") + + except Exception as e: + print(f" ⚠️ YOLO destruction warning: {e}") + + # 3. Clear CUDA cache aggressively + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() # Wait for all operations to complete + print(f" ✅ CUDA cache cleared") + + # 4. Force garbage collection + import gc + collected = gc.collect() + print(f" ✅ Garbage collection: {collected} objects freed") + + # 5. 
Memory verification + self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})") + + print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)") + def process_chunk(self, frames: List[np.ndarray], chunk_idx: int = 0) -> List[np.ndarray]: @@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor): combined = {'left': left_frame, 'right': right_frame} combined_frames.append(combined) + # CRITICAL: Complete inter-chunk cleanup for independent processing too + self._complete_inter_chunk_cleanup(chunk_idx) + return combined_frames def save_video(self, frames: List[np.ndarray], output_path: str):