commit 4d1361df46
parent 884cb8dce2
Date: 2025-07-26 15:29:37 -07:00
2 changed files with 71 additions and 2 deletions

@@ -49,8 +49,8 @@ class SAM2VideoMatting:
     def _load_model(self, model_cfg: str, checkpoint_path: str):
         """Load SAM2 video predictor lazily"""
-        if self._model_loaded:
-            return  # Already loaded
+        if self._model_loaded and self.predictor is not None:
+            return  # Already loaded and predictor exists
         try:
             # Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
         finally:
             self.predictor = None
+            # Reset model loaded state for fresh reload
+            self._model_loaded = False
             # Force garbage collection (critical for memory leak prevention)
             gc.collect()
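
The two hunks above tighten the lazy-loading contract: the guard now requires both the loaded flag and a live predictor object, and the cleanup path resets the flag together with the predictor so the next call rebuilds the model instead of trusting stale state. A minimal sketch of that pattern, assuming only the attribute names visible in the diff (the loader body is a placeholder, not the real SAM2 build code):

class LazyModelHolder:
    """Illustrative sketch of the load/cleanup contract above; not repo code."""

    def __init__(self):
        self.predictor = None
        self._model_loaded = False

    def _load_model(self):
        # Short-circuit only when the flag AND the object agree; a True flag
        # with a None predictor means cleanup ran and a fresh load is needed.
        if self._model_loaded and self.predictor is not None:
            return
        self.predictor = object()  # placeholder for building the real predictor
        self._model_loaded = True

    def cleanup(self):
        # Reset both together so a later _load_model() reloads from scratch.
        self.predictor = None
        self._model_loaded = False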

@@ -3,6 +3,7 @@ import numpy as np
 from typing import List, Dict, Any, Optional, Tuple
 from pathlib import Path
 import warnings
+import torch
 from .video_processor import VideoProcessor
 from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
         del right_matted
         self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
+
+        # CRITICAL: Complete inter-chunk cleanup to prevent model persistence
+        # This ensures models don't accumulate between chunks
+        self._complete_inter_chunk_cleanup(chunk_idx)
         return combined_frames

     def _process_eye_sequence(self,
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
         # TODO: Implement proper stereo correction algorithm
         return right_frame

+    def _complete_inter_chunk_cleanup(self, chunk_idx: int):
+        """
+        Complete inter-chunk cleanup: destroy all models to prevent memory accumulation.
+
+        This addresses the core issue where SAM2 and YOLO models (~15-20GB)
+        persist between chunks, causing OOM when processing subsequent chunks.
+        """
+        print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
+
+        # 1. Completely destroy SAM2 model (15-20GB)
+        if hasattr(self, 'sam2_model') and self.sam2_model is not None:
+            self.sam2_model.cleanup()  # Call existing cleanup
+
+            # Force complete destruction of the model
+            try:
+                # Reset the model's loaded state so it will reload fresh
+                if hasattr(self.sam2_model, '_model_loaded'):
+                    self.sam2_model._model_loaded = False
+
+                # Clear any cached state
+                if hasattr(self.sam2_model, 'predictor'):
+                    self.sam2_model.predictor = None
+                if hasattr(self.sam2_model, 'inference_state'):
+                    self.sam2_model.inference_state = None
+
+                print(f"   ✅ SAM2 model destroyed and marked for fresh reload")
+            except Exception as e:
+                print(f"   ⚠️ SAM2 destruction warning: {e}")
+
+        # 2. Completely destroy YOLO detector (400MB+)
+        if hasattr(self, 'detector') and self.detector is not None:
+            try:
+                # Force YOLO model to be reloaded fresh
+                if hasattr(self.detector, 'model') and self.detector.model is not None:
+                    del self.detector.model
+                    self.detector.model = None
+
+                print(f"   ✅ YOLO model destroyed and marked for fresh reload")
+            except Exception as e:
+                print(f"   ⚠️ YOLO destruction warning: {e}")
+
+        # 3. Clear CUDA cache aggressively
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()  # Wait for all operations to complete
+            print(f"   ✅ CUDA cache cleared")
+
+        # 4. Force garbage collection
+        import gc
+        collected = gc.collect()
+        print(f"   ✅ Garbage collection: {collected} objects freed")
+
+        # 5. Memory verification
+        self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
+        print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
+
     def process_chunk(self,
                       frames: List[np.ndarray],
                       chunk_idx: int = 0) -> List[np.ndarray]:
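
One detail worth noting about steps 3 and 4 in the new method: torch.cuda.empty_cache() can only return allocator blocks that no live tensor still references, so dropping the Python references (and letting gc.collect() break any reference cycles) is what actually makes the GPU memory reclaimable. A minimal, general-purpose sketch of that ordering, independent of this repository:

import gc

import torch

def release_gpu_memory() -> None:
    # Collect unreachable Python objects first so the CUDA tensors they hold
    # are deallocated, then ask the caching allocator to return unused blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # let in-flight kernels finish
        torch.cuda.empty_cache()   # hand cached, unused blocks back to the driver
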
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
             combined = {'left': left_frame, 'right': right_frame}
             combined_frames.append(combined)
+
+        # CRITICAL: Complete inter-chunk cleanup for independent processing too
+        self._complete_inter_chunk_cleanup(chunk_idx)
         return combined_frames

     def save_video(self, frames: List[np.ndarray], output_path: str):
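
Both call sites place the cleanup at the chunk boundary, so peak GPU usage is bounded by a single chunk's models instead of growing with each processed chunk. A hypothetical driver loop illustrating the intended lifecycle (the processor object, the run_chunked name, and the chunks list are assumptions for illustration, not code from this repository):

from typing import List

import numpy as np

def run_chunked(processor, chunks: List[List[np.ndarray]]) -> list:
    # Assumed driver: models are (re)loaded lazily inside each chunk and torn
    # down by _complete_inter_chunk_cleanup() at the end of process_chunk(),
    # so nothing heavy survives from one iteration to the next.
    combined = []
    for chunk_idx, frames in enumerate(chunks):
        combined.extend(processor.process_chunk(frames, chunk_idx=chunk_idx))
    return combined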