import torch
import psutil
import gc
import warnings
from typing import Optional, Dict, Any
from contextlib import contextmanager
import time


class VRAMManager:
    """VRAM and memory optimization manager"""

    def __init__(self, max_vram_gb: float = 10.0, device: str = "cuda"):
        self.max_vram_gb = max_vram_gb
        self.device = device
        self.max_vram_bytes = max_vram_gb * 1024**3

        # Memory tracking
        self.memory_stats = {
            'peak_allocated': 0,
            'peak_reserved': 0,
            'allocations': 0,
            'deallocations': 0
        }

        self._check_device()

    def _check_device(self):
        """Check if CUDA is available and get device info"""
        if self.device == "cuda":
            if not torch.cuda.is_available():
                warnings.warn("CUDA not available, falling back to CPU")
                self.device = "cpu"
                return

            device_props = torch.cuda.get_device_properties(0)
            total_memory = device_props.total_memory

            print(f"GPU: {device_props.name}")
            print(f"Total VRAM: {total_memory / 1024**3:.1f} GB")
            print(f"Max VRAM limit: {self.max_vram_gb:.1f} GB")

            if self.max_vram_bytes > total_memory * 0.9:
                warnings.warn(f"Max VRAM limit ({self.max_vram_gb:.1f} GB) is close to total VRAM")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics"""
        stats = {}

        if self.device == "cuda" and torch.cuda.is_available():
            stats['vram_allocated'] = torch.cuda.memory_allocated() / 1024**3
            stats['vram_reserved'] = torch.cuda.memory_reserved() / 1024**3
            stats['vram_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                  torch.cuda.memory_reserved()) / 1024**3
        else:
            stats['vram_allocated'] = 0.0
            stats['vram_reserved'] = 0.0
            stats['vram_free'] = 0.0

        # System RAM
        ram_info = psutil.virtual_memory()
        stats['ram_used'] = ram_info.used / 1024**3
        stats['ram_available'] = ram_info.available / 1024**3
        stats['ram_percent'] = ram_info.percent

        return stats

    def check_memory_available(self, required_gb: float) -> bool:
        """Check if enough memory is available for the operation"""
        stats = self.get_memory_usage()

        if self.device == "cuda":
            return stats['vram_free'] >= required_gb
        else:
            return stats['ram_available'] >= required_gb

    def cleanup_memory(self, aggressive: bool = False):
        """Clean up memory"""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()

            if aggressive:
                torch.cuda.ipc_collect()
                torch.cuda.synchronize()

        # Python garbage collection
        gc.collect()

        if aggressive:
            # Force garbage collection multiple times
            for _ in range(3):
                gc.collect()

    def estimate_processing_memory(self,
                                   frame_height: int,
                                   frame_width: int,
                                   num_frames: int,
                                   fp16: bool = True) -> float:
        """
        Estimate memory requirements for processing

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            num_frames: Number of frames to process
            fp16: Whether using FP16 precision

        Returns:
            Estimated memory usage in GB
        """
        bytes_per_pixel = 2 if fp16 else 4  # FP16 vs FP32

        # Estimate memory components
        frame_memory = frame_height * frame_width * 3 * bytes_per_pixel * num_frames
        model_memory = 2.0 * 1024**3  # ~2GB for SAM2 model
        yolo_memory = 0.5 * 1024**3  # ~0.5GB for YOLO
        working_memory = frame_memory * 2  # Working space for masks, etc.

        total_memory = frame_memory + model_memory + yolo_memory + working_memory

        return total_memory / 1024**3

    def get_optimal_chunk_size(self,
                               frame_height: int,
                               frame_width: int,
                               target_memory_gb: Optional[float] = None,
                               fp16: bool = True) -> int:
        """
        Calculate optimal chunk size for processing

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            target_memory_gb: Target memory usage (defaults to 80% of max VRAM)
            fp16: Whether using FP16 precision

        Returns:
            Optimal number of frames per chunk
        """
        if target_memory_gb is None:
            target_memory_gb = self.max_vram_gb * 0.8

        # Binary search for optimal chunk size
        min_frames = 1
        max_frames = 1000
        optimal_frames = min_frames

        while min_frames <= max_frames:
            mid_frames = (min_frames + max_frames) // 2
            estimated_memory = self.estimate_processing_memory(
                frame_height, frame_width, mid_frames, fp16
            )

            if estimated_memory <= target_memory_gb:
                optimal_frames = mid_frames
                min_frames = mid_frames + 1
            else:
                max_frames = mid_frames - 1

        return max(optimal_frames, 1)

    @contextmanager
    def memory_monitor(self, operation_name: str = "operation"):
        """Context manager for monitoring memory usage during operations"""
        start_stats = self.get_memory_usage()
        start_time = time.time()

        print(f"Starting {operation_name}")
        print(f"Initial VRAM: {start_stats['vram_allocated']:.2f} GB allocated, "
              f"{start_stats['vram_free']:.2f} GB free")

        try:
            yield self
        finally:
            end_stats = self.get_memory_usage()
            end_time = time.time()

            vram_diff = end_stats['vram_allocated'] - start_stats['vram_allocated']
            duration = end_time - start_time

            print(f"Completed {operation_name} in {duration:.1f}s")
            print(f"Final VRAM: {end_stats['vram_allocated']:.2f} GB allocated, "
                  f"{end_stats['vram_free']:.2f} GB free")
            print(f"VRAM change: {vram_diff:+.2f} GB")

            # Update peak stats
            self.memory_stats['peak_allocated'] = max(
                self.memory_stats['peak_allocated'],
                end_stats['vram_allocated']
            )
            self.memory_stats['peak_reserved'] = max(
                self.memory_stats['peak_reserved'],
                end_stats['vram_reserved']
            )

    def print_memory_report(self):
        """Print detailed memory usage report"""
        stats = self.get_memory_usage()

        print("\n" + "=" * 50)
        print("MEMORY USAGE REPORT")
        print("=" * 50)

        if self.device == "cuda":
            print(f"VRAM Allocated: {stats['vram_allocated']:.2f} GB")
            print(f"VRAM Reserved: {stats['vram_reserved']:.2f} GB")
            print(f"VRAM Free: {stats['vram_free']:.2f} GB")
            print(f"Peak Allocated: {self.memory_stats['peak_allocated']:.2f} GB")
            print(f"Peak Reserved: {self.memory_stats['peak_reserved']:.2f} GB")
            print(f"Max VRAM Limit: {self.max_vram_gb:.2f} GB")

            utilization = (stats['vram_allocated'] / self.max_vram_gb) * 100
            print(f"VRAM Utilization: {utilization:.1f}%")

        print(f"\nSystem RAM Used: {stats['ram_used']:.2f} GB")
        print(f"System RAM Available: {stats['ram_available']:.2f} GB")
        print(f"System RAM Usage: {stats['ram_percent']:.1f}%")
        print("=" * 50 + "\n")

    def emergency_cleanup(self):
        """Emergency memory cleanup when running low"""
        print("WARNING: Running low on memory, performing emergency cleanup...")

        self.cleanup_memory(aggressive=True)

        # Additional cleanup steps
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        stats = self.get_memory_usage()
        print(f"After cleanup - VRAM: {stats['vram_allocated']:.2f} GB, "
              f"Free: {stats['vram_free']:.2f} GB")

    def should_emergency_cleanup(self) -> bool:
        """Check if emergency cleanup is needed"""
        stats = self.get_memory_usage()

        if self.device == "cuda":
            return stats['vram_free'] < 1.0  # Less than 1GB VRAM free
        else:
            return stats['ram_available'] < 2.0  # Less than 2GB RAM available
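

# Illustrative usage sketch: the resolution, frame count, and dummy tensor size
# below are assumptions for demonstration, not values taken from a real pipeline.
if __name__ == "__main__":
    manager = VRAMManager(max_vram_gb=10.0)

    # Estimate memory for a hypothetical 1080p clip and derive a chunk size
    # that should fit within ~80% of the configured VRAM budget.
    estimated_gb = manager.estimate_processing_memory(
        frame_height=1080, frame_width=1920, num_frames=300, fp16=True
    )
    chunk_size = manager.get_optimal_chunk_size(frame_height=1080, frame_width=1920)
    print(f"Estimated memory for 300 frames: {estimated_gb:.2f} GB")
    print(f"Suggested chunk size: {chunk_size} frames")

    # Track VRAM around a placeholder workload and clean up afterwards.
    with manager.memory_monitor("dummy allocation"):
        if manager.device == "cuda":
            tensor = torch.zeros(1024, 1024, 256, dtype=torch.float16, device="cuda")
            del tensor

        if manager.should_emergency_cleanup():
            manager.emergency_cleanup()

    manager.cleanup_memory(aggressive=True)
    manager.print_memory_report()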