first commit
This commit is contained in:
241
vr180_matting/memory_manager.py
Normal file
241
vr180_matting/memory_manager.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import torch
|
||||
import psutil
|
||||
import gc
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any
|
||||
from contextlib import contextmanager
|
||||
import time
|
||||
|
||||
|
||||
class VRAMManager:
|
||||
"""VRAM and memory optimization manager"""
|
||||
|
||||
def __init__(self, max_vram_gb: float = 10.0, device: str = "cuda"):
|
||||
self.max_vram_gb = max_vram_gb
|
||||
self.device = device
|
||||
self.max_vram_bytes = max_vram_gb * 1024**3
|
||||
|
||||
# Memory tracking
|
||||
self.memory_stats = {
|
||||
'peak_allocated': 0,
|
||||
'peak_reserved': 0,
|
||||
'allocations': 0,
|
||||
'deallocations': 0
|
||||
}
|
||||
|
||||
self._check_device()
|
||||
|
||||
def _check_device(self):
|
||||
"""Check if CUDA is available and get device info"""
|
||||
if self.device == "cuda":
|
||||
if not torch.cuda.is_available():
|
||||
warnings.warn("CUDA not available, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
return
|
||||
|
||||
device_props = torch.cuda.get_device_properties(0)
|
||||
total_memory = device_props.total_memory
|
||||
|
||||
print(f"GPU: {device_props.name}")
|
||||
print(f"Total VRAM: {total_memory / 1024**3:.1f} GB")
|
||||
print(f"Max VRAM limit: {self.max_vram_gb:.1f} GB")
|
||||
|
||||
if self.max_vram_bytes > total_memory * 0.9:
|
||||
warnings.warn(f"Max VRAM limit ({self.max_vram_gb:.1f} GB) is close to total VRAM")
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, float]:
|
||||
"""Get current memory usage statistics"""
|
||||
stats = {}
|
||||
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
stats['vram_allocated'] = torch.cuda.memory_allocated() / 1024**3
|
||||
stats['vram_reserved'] = torch.cuda.memory_reserved() / 1024**3
|
||||
stats['vram_free'] = (torch.cuda.get_device_properties(0).total_memory -
|
||||
torch.cuda.memory_reserved()) / 1024**3
|
||||
else:
|
||||
stats['vram_allocated'] = 0
|
||||
stats['vram_reserved'] = 0
|
||||
stats['vram_free'] = 0
|
||||
|
||||
# System RAM
|
||||
ram_info = psutil.virtual_memory()
|
||||
stats['ram_used'] = ram_info.used / 1024**3
|
||||
stats['ram_available'] = ram_info.available / 1024**3
|
||||
stats['ram_percent'] = ram_info.percent
|
||||
|
||||
return stats
|
||||
|
||||
def check_memory_available(self, required_gb: float) -> bool:
|
||||
"""Check if enough memory is available for operation"""
|
||||
stats = self.get_memory_usage()
|
||||
|
||||
if self.device == "cuda":
|
||||
return stats['vram_free'] >= required_gb
|
||||
else:
|
||||
return stats['ram_available'] >= required_gb
|
||||
|
||||
def cleanup_memory(self, aggressive: bool = False):
|
||||
"""Clean up memory"""
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if aggressive:
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Python garbage collection
|
||||
gc.collect()
|
||||
|
||||
if aggressive:
|
||||
# Force garbage collection multiple times
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
|
||||
def estimate_processing_memory(self,
|
||||
frame_height: int,
|
||||
frame_width: int,
|
||||
num_frames: int,
|
||||
fp16: bool = True) -> float:
|
||||
"""
|
||||
Estimate memory requirements for processing
|
||||
|
||||
Args:
|
||||
frame_height: Frame height in pixels
|
||||
frame_width: Frame width in pixels
|
||||
num_frames: Number of frames to process
|
||||
fp16: Whether using FP16 precision
|
||||
|
||||
Returns:
|
||||
Estimated memory usage in GB
|
||||
"""
|
||||
bytes_per_pixel = 2 if fp16 else 4 # FP16 vs FP32
|
||||
|
||||
# Estimate memory components
|
||||
frame_memory = frame_height * frame_width * 3 * bytes_per_pixel * num_frames
|
||||
model_memory = 2.0 * 1024**3 # ~2GB for SAM2 model
|
||||
yolo_memory = 0.5 * 1024**3 # ~0.5GB for YOLO
|
||||
working_memory = frame_memory * 2 # Working space for masks, etc.
|
||||
|
||||
total_memory = frame_memory + model_memory + yolo_memory + working_memory
|
||||
|
||||
return total_memory / 1024**3
|
||||
|
||||
def get_optimal_chunk_size(self,
|
||||
frame_height: int,
|
||||
frame_width: int,
|
||||
target_memory_gb: Optional[float] = None,
|
||||
fp16: bool = True) -> int:
|
||||
"""
|
||||
Calculate optimal chunk size for processing
|
||||
|
||||
Args:
|
||||
frame_height: Frame height in pixels
|
||||
frame_width: Frame width in pixels
|
||||
target_memory_gb: Target memory usage (defaults to 80% of max VRAM)
|
||||
fp16: Whether using FP16 precision
|
||||
|
||||
Returns:
|
||||
Optimal number of frames per chunk
|
||||
"""
|
||||
if target_memory_gb is None:
|
||||
target_memory_gb = self.max_vram_gb * 0.8
|
||||
|
||||
# Binary search for optimal chunk size
|
||||
min_frames = 1
|
||||
max_frames = 1000
|
||||
optimal_frames = min_frames
|
||||
|
||||
while min_frames <= max_frames:
|
||||
mid_frames = (min_frames + max_frames) // 2
|
||||
estimated_memory = self.estimate_processing_memory(
|
||||
frame_height, frame_width, mid_frames, fp16
|
||||
)
|
||||
|
||||
if estimated_memory <= target_memory_gb:
|
||||
optimal_frames = mid_frames
|
||||
min_frames = mid_frames + 1
|
||||
else:
|
||||
max_frames = mid_frames - 1
|
||||
|
||||
return max(optimal_frames, 1)
|
||||
|
||||
@contextmanager
|
||||
def memory_monitor(self, operation_name: str = "operation"):
|
||||
"""Context manager for monitoring memory usage during operations"""
|
||||
start_stats = self.get_memory_usage()
|
||||
start_time = time.time()
|
||||
|
||||
print(f"Starting {operation_name}")
|
||||
print(f"Initial VRAM: {start_stats['vram_allocated']:.2f} GB allocated, "
|
||||
f"{start_stats['vram_free']:.2f} GB free")
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
end_stats = self.get_memory_usage()
|
||||
end_time = time.time()
|
||||
|
||||
vram_diff = end_stats['vram_allocated'] - start_stats['vram_allocated']
|
||||
duration = end_time - start_time
|
||||
|
||||
print(f"Completed {operation_name} in {duration:.1f}s")
|
||||
print(f"Final VRAM: {end_stats['vram_allocated']:.2f} GB allocated, "
|
||||
f"{end_stats['vram_free']:.2f} GB free")
|
||||
print(f"VRAM change: {vram_diff:+.2f} GB")
|
||||
|
||||
# Update peak stats
|
||||
self.memory_stats['peak_allocated'] = max(
|
||||
self.memory_stats['peak_allocated'],
|
||||
end_stats['vram_allocated']
|
||||
)
|
||||
self.memory_stats['peak_reserved'] = max(
|
||||
self.memory_stats['peak_reserved'],
|
||||
end_stats['vram_reserved']
|
||||
)
|
||||
|
||||
def print_memory_report(self):
|
||||
"""Print detailed memory usage report"""
|
||||
stats = self.get_memory_usage()
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("MEMORY USAGE REPORT")
|
||||
print("="*50)
|
||||
|
||||
if self.device == "cuda":
|
||||
print(f"VRAM Allocated: {stats['vram_allocated']:.2f} GB")
|
||||
print(f"VRAM Reserved: {stats['vram_reserved']:.2f} GB")
|
||||
print(f"VRAM Free: {stats['vram_free']:.2f} GB")
|
||||
print(f"Peak Allocated: {self.memory_stats['peak_allocated']:.2f} GB")
|
||||
print(f"Peak Reserved: {self.memory_stats['peak_reserved']:.2f} GB")
|
||||
print(f"Max VRAM Limit: {self.max_vram_gb:.2f} GB")
|
||||
|
||||
utilization = (stats['vram_allocated'] / self.max_vram_gb) * 100
|
||||
print(f"VRAM Utilization: {utilization:.1f}%")
|
||||
|
||||
print(f"\nSystem RAM Used: {stats['ram_used']:.2f} GB")
|
||||
print(f"System RAM Available: {stats['ram_available']:.2f} GB")
|
||||
print(f"System RAM Usage: {stats['ram_percent']:.1f}%")
|
||||
print("="*50 + "\n")
|
||||
|
||||
def emergency_cleanup(self):
|
||||
"""Emergency memory cleanup when running low"""
|
||||
print("WARNING: Running low on memory, performing emergency cleanup...")
|
||||
|
||||
self.cleanup_memory(aggressive=True)
|
||||
|
||||
# Additional cleanup steps
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
stats = self.get_memory_usage()
|
||||
print(f"After cleanup - VRAM: {stats['vram_allocated']:.2f} GB, "
|
||||
f"Free: {stats['vram_free']:.2f} GB")
|
||||
|
||||
def should_emergency_cleanup(self) -> bool:
|
||||
"""Check if emergency cleanup is needed"""
|
||||
stats = self.get_memory_usage()
|
||||
|
||||
if self.device == "cuda":
|
||||
return stats['vram_free'] < 1.0 # Less than 1GB free
|
||||
else:
|
||||
return stats['ram_available'] < 2.0 # Less than 2GB RAM available
|
||||
Reference in New Issue
Block a user