This commit is contained in:
2025-07-26 15:29:37 -07:00
parent 884cb8dce2
commit 4d1361df46
2 changed files with 71 additions and 2 deletions

View File

@@ -49,8 +49,8 @@ class SAM2VideoMatting:
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor lazily"""
if self._model_loaded:
return # Already loaded
if self._model_loaded and self.predictor is not None:
return # Already loaded and predictor exists
try:
# Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
finally:
self.predictor = None
# Reset model loaded state for fresh reload
self._model_loaded = False
# Force garbage collection (critical for memory leak prevention)
gc.collect()

View File

@@ -3,6 +3,7 @@ import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import warnings
import torch
from .video_processor import VideoProcessor
from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
del right_matted
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
# This ensures models don't accumulate between chunks
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def _process_eye_sequence(self,
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
# TODO: Implement proper stereo correction algorithm
return right_frame
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
"""
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
persist between chunks, causing OOM when processing subsequent chunks.
"""
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
# 1. Completely destroy SAM2 model (15-20GB)
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
self.sam2_model.cleanup() # Call existing cleanup
# Force complete destruction of the model
try:
# Reset the model's loaded state so it will reload fresh
if hasattr(self.sam2_model, '_model_loaded'):
self.sam2_model._model_loaded = False
# Clear any cached state
if hasattr(self.sam2_model, 'predictor'):
self.sam2_model.predictor = None
if hasattr(self.sam2_model, 'inference_state'):
self.sam2_model.inference_state = None
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ SAM2 destruction warning: {e}")
# 2. Completely destroy YOLO detector (400MB+)
if hasattr(self, 'detector') and self.detector is not None:
try:
# Force YOLO model to be reloaded fresh
if hasattr(self.detector, 'model') and self.detector.model is not None:
del self.detector.model
self.detector.model = None
print(f" ✅ YOLO model destroyed and marked for fresh reload")
except Exception as e:
print(f" ⚠️ YOLO destruction warning: {e}")
# 3. Clear CUDA cache aggressively
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
print(f" ✅ CUDA cache cleared")
# 4. Force garbage collection
import gc
collected = gc.collect()
print(f" ✅ Garbage collection: {collected} objects freed")
# 5. Memory verification
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
def process_chunk(self,
frames: List[np.ndarray],
chunk_idx: int = 0) -> List[np.ndarray]:
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
combined = {'left': left_frame, 'right': right_frame}
combined_frames.append(combined)
# CRITICAL: Complete inter-chunk cleanup for independent processing too
self._complete_inter_chunk_cleanup(chunk_idx)
return combined_frames
def save_video(self, frames: List[np.ndarray], output_path: str):