bigtime
This commit is contained in:
@@ -49,8 +49,8 @@ class SAM2VideoMatting:
|
|||||||
|
|
||||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||||
"""Load SAM2 video predictor lazily"""
|
"""Load SAM2 video predictor lazily"""
|
||||||
if self._model_loaded:
|
if self._model_loaded and self.predictor is not None:
|
||||||
return # Already loaded
|
return # Already loaded and predictor exists
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Import heavy SAM2 modules only when needed
|
# Import heavy SAM2 modules only when needed
|
||||||
@@ -419,6 +419,9 @@ class SAM2VideoMatting:
|
|||||||
finally:
|
finally:
|
||||||
self.predictor = None
|
self.predictor = None
|
||||||
|
|
||||||
|
# Reset model loaded state for fresh reload
|
||||||
|
self._model_loaded = False
|
||||||
|
|
||||||
# Force garbage collection (critical for memory leak prevention)
|
# Force garbage collection (critical for memory leak prevention)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import numpy as np
|
|||||||
from typing import List, Dict, Any, Optional, Tuple
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
|
import torch
|
||||||
|
|
||||||
from .video_processor import VideoProcessor
|
from .video_processor import VideoProcessor
|
||||||
from .config import VR180Config
|
from .config import VR180Config
|
||||||
@@ -212,6 +213,10 @@ class VR180Processor(VideoProcessor):
|
|||||||
del right_matted
|
del right_matted
|
||||||
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
|
self._aggressive_memory_cleanup(f"After combining frames chunk {chunk_idx}")
|
||||||
|
|
||||||
|
# CRITICAL: Complete inter-chunk cleanup to prevent model persistence
|
||||||
|
# This ensures models don't accumulate between chunks
|
||||||
|
self._complete_inter_chunk_cleanup(chunk_idx)
|
||||||
|
|
||||||
return combined_frames
|
return combined_frames
|
||||||
|
|
||||||
def _process_eye_sequence(self,
|
def _process_eye_sequence(self,
|
||||||
@@ -691,6 +696,64 @@ class VR180Processor(VideoProcessor):
|
|||||||
# TODO: Implement proper stereo correction algorithm
|
# TODO: Implement proper stereo correction algorithm
|
||||||
return right_frame
|
return right_frame
|
||||||
|
|
||||||
|
def _complete_inter_chunk_cleanup(self, chunk_idx: int):
|
||||||
|
"""
|
||||||
|
Complete inter-chunk cleanup: Destroy all models to prevent memory accumulation
|
||||||
|
|
||||||
|
This addresses the core issue where SAM2 and YOLO models (~15-20GB)
|
||||||
|
persist between chunks, causing OOM when processing subsequent chunks.
|
||||||
|
"""
|
||||||
|
print(f"🧹 INTER-CHUNK CLEANUP: Destroying all models after chunk {chunk_idx}")
|
||||||
|
|
||||||
|
# 1. Completely destroy SAM2 model (15-20GB)
|
||||||
|
if hasattr(self, 'sam2_model') and self.sam2_model is not None:
|
||||||
|
self.sam2_model.cleanup() # Call existing cleanup
|
||||||
|
|
||||||
|
# Force complete destruction of the model
|
||||||
|
try:
|
||||||
|
# Reset the model's loaded state so it will reload fresh
|
||||||
|
if hasattr(self.sam2_model, '_model_loaded'):
|
||||||
|
self.sam2_model._model_loaded = False
|
||||||
|
|
||||||
|
# Clear any cached state
|
||||||
|
if hasattr(self.sam2_model, 'predictor'):
|
||||||
|
self.sam2_model.predictor = None
|
||||||
|
if hasattr(self.sam2_model, 'inference_state'):
|
||||||
|
self.sam2_model.inference_state = None
|
||||||
|
|
||||||
|
print(f" ✅ SAM2 model destroyed and marked for fresh reload")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ SAM2 destruction warning: {e}")
|
||||||
|
|
||||||
|
# 2. Completely destroy YOLO detector (400MB+)
|
||||||
|
if hasattr(self, 'detector') and self.detector is not None:
|
||||||
|
try:
|
||||||
|
# Force YOLO model to be reloaded fresh
|
||||||
|
if hasattr(self.detector, 'model') and self.detector.model is not None:
|
||||||
|
del self.detector.model
|
||||||
|
self.detector.model = None
|
||||||
|
print(f" ✅ YOLO model destroyed and marked for fresh reload")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ YOLO destruction warning: {e}")
|
||||||
|
|
||||||
|
# 3. Clear CUDA cache aggressively
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize() # Wait for all operations to complete
|
||||||
|
print(f" ✅ CUDA cache cleared")
|
||||||
|
|
||||||
|
# 4. Force garbage collection
|
||||||
|
import gc
|
||||||
|
collected = gc.collect()
|
||||||
|
print(f" ✅ Garbage collection: {collected} objects freed")
|
||||||
|
|
||||||
|
# 5. Memory verification
|
||||||
|
self._print_memory_step(f"After complete inter-chunk cleanup (chunk {chunk_idx})")
|
||||||
|
|
||||||
|
print(f"🎯 RESULT: Models will reload fresh for next chunk (prevents 15-20GB accumulation)")
|
||||||
|
|
||||||
def process_chunk(self,
|
def process_chunk(self,
|
||||||
frames: List[np.ndarray],
|
frames: List[np.ndarray],
|
||||||
chunk_idx: int = 0) -> List[np.ndarray]:
|
chunk_idx: int = 0) -> List[np.ndarray]:
|
||||||
@@ -750,6 +813,9 @@ class VR180Processor(VideoProcessor):
|
|||||||
combined = {'left': left_frame, 'right': right_frame}
|
combined = {'left': left_frame, 'right': right_frame}
|
||||||
combined_frames.append(combined)
|
combined_frames.append(combined)
|
||||||
|
|
||||||
|
# CRITICAL: Complete inter-chunk cleanup for independent processing too
|
||||||
|
self._complete_inter_chunk_cleanup(chunk_idx)
|
||||||
|
|
||||||
return combined_frames
|
return combined_frames
|
||||||
|
|
||||||
def save_video(self, frames: List[np.ndarray], output_path: str):
|
def save_video(self, frames: List[np.ndarray], output_path: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user