bigtime
@@ -49,8 +49,8 @@ class SAM2VideoMatting:

     def _load_model(self, model_cfg: str, checkpoint_path: str):
         """Load SAM2 video predictor lazily"""
-        if self._model_loaded:
-            return  # Already loaded
+        if self._model_loaded and self.predictor is not None:
+            return  # Already loaded and predictor exists

         try:
             # Import heavy SAM2 modules only when needed
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
         finally:
             self.predictor = None

+            # Reset model loaded state for fresh reload
+            self._model_loaded = False
+
             # Force garbage collection (critical for memory leak prevention)
             gc.collect()

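Taken together, the two SAM2VideoMatting hunks above pair a stricter lazy-load guard (both `_model_loaded` and a live `predictor` must be present) with a teardown that resets `_model_loaded`, so a cleanup that drops the predictor also forces a genuine reload. A condensed standalone sketch of that load/teardown cycle follows; the wrapper class, its constructor, and the `build_sam2_video_predictor` call are assumptions about the surrounding code, not part of this diff.

import gc
import torch


class LazySAM2Wrapper:
    """Hypothetical condensed illustration of the lazy-load / reset pattern above."""

    def __init__(self, model_cfg: str, checkpoint_path: str):
        self.model_cfg = model_cfg
        self.checkpoint_path = checkpoint_path
        self.predictor = None
        self._model_loaded = False

    def _load_model(self):
        # Skip only when the flag AND the predictor agree the model is usable
        if self._model_loaded and self.predictor is not None:
            return
        # Import heavy SAM2 modules only when needed (assumed entry point)
        from sam2.build_sam import build_sam2_video_predictor
        self.predictor = build_sam2_video_predictor(self.model_cfg, self.checkpoint_path)
        self._model_loaded = True

    def cleanup(self):
        self.predictor = None
        self._model_loaded = False  # next _load_model() rebuilds from scratch
        gc.collect()                # force garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache()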
@@ -3,6 +3,7 @@ import numpy as np
 from typing import List, Dict, Any, Optional, Tuple
 from pathlib import Path
 import warnings
+import torch

 from .video_processor import VideoProcessor
 from .config import VR180Config
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
             del right_matted
             self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")

+        # CRITICAL: Complete inter-chunk cleanup to prevent model persistence
+        # This ensures models don't accumulate between chunks
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames

     def _process_eye_sequence(self,
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
         # TODO: Implement proper stereo correction algorithm
         return right_frame

+    def _complete_inter_chunk_cleanup(self, chunk_idx: int):
+        """
+        Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
+
+        This addresses the core issue where SAM2 and YOLO models (~15-20GB)
+        persist between chunks, causing OOM when processing subsequent chunks.
+        """
+        print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
+
+        # 1. Completely destroy SAM2 model (15-20GB)
+        if hasattr(self, 'sam2_model') and self.sam2_model is not None:
+            self.sam2_model.cleanup()  # Call existing cleanup
+
+            # Force complete destruction of the model
+            try:
+                # Reset the model's loaded state so it will reload fresh
+                if hasattr(self.sam2_model, '_model_loaded'):
+                    self.sam2_model._model_loaded = False
+
+                # Clear any cached state
+                if hasattr(self.sam2_model, 'predictor'):
+                    self.sam2_model.predictor = None
+                if hasattr(self.sam2_model, 'inference_state'):
+                    self.sam2_model.inference_state = None
+
+                print(f" ✅ SAM2 model destroyed and marked for fresh reload")
+
+            except Exception as e:
+                print(f" ⚠️ SAM2 destruction warning: {e}")
+
+        # 2. Completely destroy YOLO detector (400MB+)
+        if hasattr(self, 'detector') and self.detector is not None:
+            try:
+                # Force YOLO model to be reloaded fresh
+                if hasattr(self.detector, 'model') and self.detector.model is not None:
+                    del self.detector.model
+                    self.detector.model = None
+                    print(f" ✅ YOLO model destroyed and marked for fresh reload")
+
+            except Exception as e:
+                print(f" ⚠️ YOLO destruction warning: {e}")
+
+        # 3. Clear CUDA cache aggressively
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()  # Wait for all operations to complete
+            print(f" ✅ CUDA cache cleared")
+
+        # 4. Force garbage collection
+        import gc
+        collected = gc.collect()
+        print(f" ✅ Garbage collection: {collected} objects freed")
+
+        # 5. Memory verification
+        self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
+
+        print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
+
     def process_chunk(self,
                       frames: List[np.ndarray],
                       chunk_idx: int = 0) -> List[np.ndarray]:
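The method above is meant to run once per chunk, immediately after that chunk's frames are produced. A rough sketch of a driving loop follows; the `run_in_chunks` helper and the chunk size are assumptions for illustration, while `process_chunk(frames, chunk_idx=...)` is the method shown in this diff.

def run_in_chunks(processor, frames, chunk_size=100):
    """Hypothetical driver: process a long video in bounded-memory chunks."""
    outputs = []
    for chunk_idx, start in enumerate(range(0, len(frames), chunk_size)):
        chunk = frames[start:start + chunk_size]
        # process_chunk() ends with _complete_inter_chunk_cleanup(), so the
        # 15-20GB of SAM2/YOLO weights are released before the next iteration
        # and reloaded lazily when the next chunk needs them.
        outputs.extend(processor.process_chunk(chunk, chunk_idx=chunk_idx))
    return outputs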
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
             combined = {'left': left_frame, 'right': right_frame}
             combined_frames.append(combined)

+        # CRITICAL: Complete inter-chunk cleanup for independent processing too
+        self._complete_inter_chunk_cleanup(chunk_idx)
+
         return combined_frames

     def save_video(self, frames: List[np.ndarray], output_path: str):
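The cleanup's final memory-verification step calls `_print_memory_step`, whose body is not part of this diff. Assuming it only reports CUDA allocator statistics, a minimal sketch might look like this (the name is taken from the diff; the format and fields are guesses):

import torch


def _print_memory_step(self, label: str) -> None:
    """Hypothetical memory probe: log GPU allocator usage at a named step."""
    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / 1024**3
        reserved_gb = torch.cuda.memory_reserved() / 1024**3
        print(f"📊 {label}: {allocated_gb:.2f} GB allocated, {reserved_gb:.2f} GB reserved")
    else:
        print(f"📊 {label}: CUDA not available")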