import gc
import time
import warnings
from contextlib import contextmanager
from typing import Optional, Dict, Any

import psutil
import torch

# Number of bytes in one gibibyte; every "GB" figure in this module uses it.
_BYTES_PER_GB = 1024**3


class VRAMManager:
    """VRAM and memory optimization manager.

    Tracks GPU memory (through ``torch.cuda``) and system RAM (through
    ``psutil``), estimates how much memory a batch of frames needs, and
    offers cleanup helpers.

    The configured ``device`` may be ``"cpu"``, ``"cuda"`` or an indexed CUDA
    device such as ``"cuda:1"``.  When a CUDA device is requested but cannot
    be used, the manager warns and falls back to ``"cpu"``.
    """

    # Fixed overheads (in GB) assumed by ``estimate_processing_memory``.
    MODEL_MEMORY_GB = 2.0  # ~2GB for SAM2 model
    YOLO_MEMORY_GB = 0.5   # ~0.5GB for YOLO

    def __init__(self, max_vram_gb: float = 10.0, device: str = "cuda"):
        """Create a manager.

        Args:
            max_vram_gb: Soft VRAM budget in GB used for chunk sizing and
                reporting.  It is not enforced on the allocator.
            device: Device string, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
        """
        self.max_vram_gb = max_vram_gb
        self.device = device
        self.max_vram_bytes = max_vram_gb * _BYTES_PER_GB

        # Memory tracking
        self.memory_stats = {
            'peak_allocated': 0,
            'peak_reserved': 0,
            'allocations': 0,
            'deallocations': 0,
        }

        self._check_device()

    # ------------------------------------------------------------------
    # Device helpers
    # ------------------------------------------------------------------
    def _wants_cuda(self) -> bool:
        """Return True when the configured device names a CUDA device.

        Accepts both ``"cuda"`` and indexed forms such as ``"cuda:1"``.
        """
        return str(self.device).split(":", 1)[0] == "cuda"

    def _cuda_active(self) -> bool:
        """Return True when a CUDA device is configured and CUDA is usable."""
        return self._wants_cuda() and torch.cuda.is_available()

    def _cuda_device(self) -> torch.device:
        """Return the configured device as a ``torch.device``.

        For plain ``"cuda"`` the index is ``None``, which the ``torch.cuda``
        query functions resolve to the current device.
        """
        return torch.device(self.device)

    def _check_device(self):
        """Check if CUDA is available and print device info.

        Falls back to ``"cpu"`` (with a warning) when CUDA is unavailable or
        the requested device index does not exist.
        """
        if not self._wants_cuda():
            return

        if not torch.cuda.is_available():
            warnings.warn("CUDA not available, falling back to CPU")
            self.device = "cpu"
            return

        index = self._cuda_device().index
        if index is not None and index >= torch.cuda.device_count():
            warnings.warn(
                f"CUDA device {self.device} not found, falling back to CPU"
            )
            self.device = "cpu"
            return

        device_props = torch.cuda.get_device_properties(self._cuda_device())
        total_memory = device_props.total_memory

        print(f"GPU: {device_props.name}")
        print(f"Total VRAM: {total_memory / 1024**3:.1f} GB")
        print(f"Max VRAM limit: {self.max_vram_gb:.1f} GB")

        if self.max_vram_bytes > total_memory * 0.9:
            warnings.warn(
                f"Max VRAM limit ({self.max_vram_gb:.1f} GB) is close to total VRAM"
            )

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics.

        Returns:
            A dict with values in GB (``ram_percent`` is a percentage):
            ``vram_allocated``, ``vram_reserved``, ``vram_free`` (total device
            memory minus the memory reserved by PyTorch's allocator),
            ``ram_used``, ``ram_available`` and ``ram_percent``.  The VRAM
            entries are 0 when CUDA is not in use.
        """
        stats = {}

        if self._cuda_active():
            # Query one device consistently for all three figures.
            dev = self._cuda_device()
            reserved = torch.cuda.memory_reserved(dev)
            total = torch.cuda.get_device_properties(dev).total_memory
            stats['vram_allocated'] = torch.cuda.memory_allocated(dev) / _BYTES_PER_GB
            stats['vram_reserved'] = reserved / _BYTES_PER_GB
            stats['vram_free'] = (total - reserved) / _BYTES_PER_GB
        else:
            stats['vram_allocated'] = 0
            stats['vram_reserved'] = 0
            stats['vram_free'] = 0

        # System RAM
        ram_info = psutil.virtual_memory()
        stats['ram_used'] = ram_info.used / _BYTES_PER_GB
        stats['ram_available'] = ram_info.available / _BYTES_PER_GB
        stats['ram_percent'] = ram_info.percent

        return stats

    def check_memory_available(self, required_gb: float) -> bool:
        """Check if enough memory is available for an operation.

        Compares ``required_gb`` against free VRAM on a CUDA device, or
        against available system RAM otherwise.
        """
        stats = self.get_memory_usage()
        key = 'vram_free' if self._wants_cuda() else 'ram_available'
        return stats[key] >= required_gb

    def should_emergency_cleanup(self) -> bool:
        """Check if emergency cleanup is needed.

        Returns True when less than 1 GB of VRAM is free (CUDA) or less than
        2 GB of system RAM is available (CPU).
        """
        stats = self.get_memory_usage()

        if self._wants_cuda():
            return stats['vram_free'] < 1.0  # Less than 1GB free
        return stats['ram_available'] < 2.0  # Less than 2GB RAM available

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------
    def cleanup_memory(self, aggressive: bool = False):
        """Clean up memory.

        Runs Python garbage collection and then, on CUDA, empties PyTorch's
        allocator cache.

        Args:
            aggressive: Run extra garbage-collection passes and also call
                ``torch.cuda.ipc_collect()``.
        """
        # Collect Python garbage *before* emptying the CUDA cache: blocks
        # belonging to tensors that are only freed by the collector cannot be
        # returned to the driver until those tensors are gone.
        gc_passes = 4 if aggressive else 1
        for _ in range(gc_passes):
            gc.collect()

        if self._cuda_active():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.ipc_collect()
            torch.cuda.synchronize(self._cuda_device())

    def emergency_cleanup(self):
        """Emergency memory cleanup when running low."""
        print("WARNING: Running low on memory, performing emergency cleanup...")

        self.cleanup_memory(aggressive=True)

        # Additional cleanup steps
        if self._cuda_active():
            # Preserve the high-water marks before torch's counters are reset.
            self._update_peak_stats()
            torch.cuda.reset_peak_memory_stats(self._cuda_device())

        stats = self.get_memory_usage()
        print(f"After cleanup - VRAM: {stats['vram_allocated']:.2f} GB, "
              f"Free: {stats['vram_free']:.2f} GB")

    # ------------------------------------------------------------------
    # Estimation
    # ------------------------------------------------------------------
    def estimate_processing_memory(self, frame_height: int, frame_width: int,
                                   num_frames: int, fp16: bool = True) -> float:
        """
        Estimate memory requirements for processing

        The estimate is three times the raw 3-channel frame data (the frames
        plus twice that as working space) plus fixed model overheads.

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            num_frames: Number of frames to process
            fp16: Whether using FP16 precision

        Returns:
            Estimated memory usage in GB
        """
        bytes_per_pixel = 2 if fp16 else 4  # FP16 vs FP32

        # Estimate memory components
        frame_memory = frame_height * frame_width * 3 * bytes_per_pixel * num_frames
        model_memory = self.MODEL_MEMORY_GB * _BYTES_PER_GB
        yolo_memory = self.YOLO_MEMORY_GB * _BYTES_PER_GB
        working_memory = frame_memory * 2  # Working space for masks, etc.

        total_memory = frame_memory + model_memory + yolo_memory + working_memory

        return total_memory / _BYTES_PER_GB

    def get_optimal_chunk_size(self, frame_height: int, frame_width: int,
                               target_memory_gb: Optional[float] = None,
                               fp16: bool = True,
                               max_chunk_frames: int = 1000) -> int:
        """
        Calculate optimal chunk size for processing

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            target_memory_gb: Target memory usage (defaults to 80% of max VRAM)
            fp16: Whether using FP16 precision
            max_chunk_frames: Upper bound on the returned chunk size

        Returns:
            Optimal number of frames per chunk (always at least 1, even when
            a single frame exceeds the target)
        """
        if target_memory_gb is None:
            target_memory_gb = self.max_vram_gb * 0.8

        # Binary search for the largest frame count whose estimate fits; the
        # estimate grows monotonically with the number of frames.
        low, high = 1, max_chunk_frames
        best = 1

        while low <= high:
            mid = (low + high) // 2
            estimated_memory = self.estimate_processing_memory(
                frame_height, frame_width, mid, fp16
            )

            if estimated_memory <= target_memory_gb:
                best = mid
                low = mid + 1
            else:
                high = mid - 1

        return max(best, 1)

    # ------------------------------------------------------------------
    # Monitoring and reporting
    # ------------------------------------------------------------------
    def _update_peak_stats(self, stats: Optional[Dict[str, float]] = None):
        """Fold current and allocator-tracked peaks into ``memory_stats``.

        Args:
            stats: A snapshot from ``get_memory_usage``; taken fresh if omitted.
        """
        if stats is None:
            stats = self.get_memory_usage()

        peak_allocated = stats['vram_allocated']
        peak_reserved = stats['vram_reserved']

        if self._cuda_active():
            # The allocator's own high-water marks capture spikes that happen
            # between snapshots.
            dev = self._cuda_device()
            peak_allocated = max(
                peak_allocated, torch.cuda.max_memory_allocated(dev) / _BYTES_PER_GB
            )
            peak_reserved = max(
                peak_reserved, torch.cuda.max_memory_reserved(dev) / _BYTES_PER_GB
            )

        self.memory_stats['peak_allocated'] = max(
            self.memory_stats['peak_allocated'], peak_allocated
        )
        self.memory_stats['peak_reserved'] = max(
            self.memory_stats['peak_reserved'], peak_reserved
        )

    @contextmanager
    def memory_monitor(self, operation_name: str = "operation"):
        """Context manager for monitoring memory usage during operations.

        Prints VRAM usage before and after the wrapped block, the elapsed
        time, and updates the peak statistics.  Yields this manager.
        """
        start_stats = self.get_memory_usage()
        start_time = time.time()

        print(f"Starting {operation_name}")
        print(f"Initial VRAM: {start_stats['vram_allocated']:.2f} GB allocated, "
              f"{start_stats['vram_free']:.2f} GB free")

        try:
            yield self
        finally:
            end_stats = self.get_memory_usage()
            end_time = time.time()

            vram_diff = end_stats['vram_allocated'] - start_stats['vram_allocated']
            duration = end_time - start_time

            print(f"Completed {operation_name} in {duration:.1f}s")
            print(f"Final VRAM: {end_stats['vram_allocated']:.2f} GB allocated, "
                  f"{end_stats['vram_free']:.2f} GB free")
            print(f"VRAM change: {vram_diff:+.2f} GB")

            # Update peak stats
            self._update_peak_stats(end_stats)

    def print_memory_report(self):
        """Print detailed memory usage report."""
        stats = self.get_memory_usage()

        print("\n" + "="*50)
        print("MEMORY USAGE REPORT")
        print("="*50)

        if self._wants_cuda():
            print(f"VRAM Allocated: {stats['vram_allocated']:.2f} GB")
            print(f"VRAM Reserved: {stats['vram_reserved']:.2f} GB")
            print(f"VRAM Free: {stats['vram_free']:.2f} GB")
            print(f"Peak Allocated: {self.memory_stats['peak_allocated']:.2f} GB")
            print(f"Peak Reserved: {self.memory_stats['peak_reserved']:.2f} GB")
            print(f"Max VRAM Limit: {self.max_vram_gb:.2f} GB")

            # Guard against a zero/negative limit to avoid ZeroDivisionError.
            if self.max_vram_gb > 0:
                utilization = (stats['vram_allocated'] / self.max_vram_gb) * 100
                print(f"VRAM Utilization: {utilization:.1f}%")
            else:
                print("VRAM Utilization: n/a")

        print(f"\nSystem RAM Used: {stats['ram_used']:.2f} GB")
        print(f"System RAM Available: {stats['ram_available']:.2f} GB")
        print(f"System RAM Usage: {stats['ram_percent']:.1f}%")
        print("="*50 + "\n")