Fix GPU out-of-memory issue by batching stereo validation

This commit is contained in:
2025-07-26 12:42:16 -07:00
parent 725a781456
commit df7b009a7b

View File

@@ -465,19 +465,35 @@ class VR180Processor(VideoProcessor):
def _validate_stereo_consistency_gpu(self, def _validate_stereo_consistency_gpu(self,
left_results: List[np.ndarray], left_results: List[np.ndarray],
right_results: List[np.ndarray]) -> List[np.ndarray]: right_results: List[np.ndarray]) -> List[np.ndarray]:
"""GPU-accelerated batch stereo validation using CuPy""" """GPU-accelerated batch stereo validation using CuPy with memory-safe batching"""
import cupy as cp import cupy as cp
print(" Using GPU acceleration for stereo validation") print(" Using GPU acceleration for stereo validation")
# Convert all frames to GPU at once (batch processing) # Process in batches to avoid GPU OOM
print(" Transferring frames to GPU...") batch_size = 50 # Process 50 frames at a time (safe for 45GB GPU)
left_stack = cp.stack([cp.asarray(frame) for frame in left_results]) total_frames = len(left_results)
right_stack = cp.stack([cp.asarray(frame) for frame in right_results]) area_ratios_all = []
needs_correction_all = []
print(" Computing mask areas on GPU...") print(f" Processing {total_frames} frames in batches of {batch_size}...")
# Batch calculate all mask areas for batch_start in range(0, total_frames, batch_size):
batch_end = min(batch_start + batch_size, total_frames)
batch_frames = batch_end - batch_start
if batch_start % 100 == 0:
print(f" GPU batch {batch_start//batch_size + 1}: frames {batch_start}-{batch_end}")
# Get batch slices
left_batch = left_results[batch_start:batch_end]
right_batch = right_results[batch_start:batch_end]
# Convert batch to GPU
left_stack = cp.stack([cp.asarray(frame) for frame in left_batch])
right_stack = cp.stack([cp.asarray(frame) for frame in right_batch])
# Batch calculate mask areas for this batch
if left_stack.shape[3] == 4: # Alpha channel if left_stack.shape[3] == 4: # Alpha channel
left_masks = left_stack[:, :, :, 3] > 0 left_masks = left_stack[:, :, :, 3] > 0
right_masks = right_stack[:, :, :, 3] > 0 right_masks = right_stack[:, :, :, 3] > 0
@@ -488,25 +504,31 @@ class VR180Processor(VideoProcessor):
left_masks = left_diff > 30 left_masks = left_diff > 30
right_masks = right_diff > 30 right_masks = right_diff > 30
# Calculate all areas at once (massive parallel speedup) # Calculate areas for this batch
left_areas = cp.sum(left_masks, axis=(1, 2)) left_areas = cp.sum(left_masks, axis=(1, 2))
right_areas = cp.sum(right_masks, axis=(1, 2)) right_areas = cp.sum(right_masks, axis=(1, 2))
area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6) area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6)
# Find frames needing correction # Find frames needing correction in this batch
needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0) needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0)
correction_count = int(cp.sum(needs_correction))
print(f" GPU validation complete: {correction_count}/{len(left_results)} frames need correction") # Transfer batch results back to CPU and accumulate
area_ratios_all.extend(cp.asnumpy(area_ratios))
needs_correction_all.extend(cp.asnumpy(needs_correction))
# Transfer results back to CPU for processing # Free GPU memory for this batch
area_ratios_cpu = cp.asnumpy(area_ratios) del left_stack, right_stack, left_masks, right_masks
needs_correction_cpu = cp.asnumpy(needs_correction) del left_areas, right_areas, area_ratios, needs_correction
cp._default_memory_pool.free_all_blocks()
correction_count = sum(needs_correction_all)
print(f" GPU validation complete: {correction_count}/{total_frames} frames need correction")
# Apply corrections using CPU results
validated_frames = [] validated_frames = []
for i, (needs_fix, ratio) in enumerate(zip(needs_correction_cpu, area_ratios_cpu)): for i, (needs_fix, ratio) in enumerate(zip(needs_correction_all, area_ratios_all)):
if i % 100 == 0: if i % 100 == 0:
print(f" Processing validation results: {i}/{len(left_results)}") print(f" Processing validation results: {i}/{total_frames}")
if needs_fix: if needs_fix:
# Apply correction # Apply correction