diff --git a/vr180_matting/vr180_processor.py b/vr180_matting/vr180_processor.py index 145ad30..94a5d8a 100644 --- a/vr180_matting/vr180_processor.py +++ b/vr180_matting/vr180_processor.py @@ -465,48 +465,70 @@ class VR180Processor(VideoProcessor): def _validate_stereo_consistency_gpu(self, left_results: List[np.ndarray], right_results: List[np.ndarray]) -> List[np.ndarray]: - """GPU-accelerated batch stereo validation using CuPy""" + """GPU-accelerated batch stereo validation using CuPy with memory-safe batching""" import cupy as cp print(" Using GPU acceleration for stereo validation") - # Convert all frames to GPU at once (batch processing) - print(" Transferring frames to GPU...") - left_stack = cp.stack([cp.asarray(frame) for frame in left_results]) - right_stack = cp.stack([cp.asarray(frame) for frame in right_results]) + # Process in batches to avoid GPU OOM + batch_size = 50 # Process 50 frames at a time (safe for 45GB GPU) + total_frames = len(left_results) + area_ratios_all = [] + needs_correction_all = [] - print(" Computing mask areas on GPU...") + print(f" Processing {total_frames} frames in batches of {batch_size}...") - # Batch calculate all mask areas - if left_stack.shape[3] == 4: # Alpha channel - left_masks = left_stack[:, :, :, 3] > 0 - right_masks = right_stack[:, :, :, 3] > 0 - else: # Green screen detection - bg_color = cp.array(self.config.output.background_color) - left_diff = cp.abs(left_stack.astype(cp.float32) - bg_color).sum(axis=3) - right_diff = cp.abs(right_stack.astype(cp.float32) - bg_color).sum(axis=3) - left_masks = left_diff > 30 - right_masks = right_diff > 30 + for batch_start in range(0, total_frames, batch_size): + batch_end = min(batch_start + batch_size, total_frames) + batch_frames = batch_end - batch_start + + if batch_start % 100 == 0: + print(f" GPU batch {batch_start//batch_size + 1}: frames {batch_start}-{batch_end}") + + # Get batch slices + left_batch = left_results[batch_start:batch_end] + right_batch = right_results[batch_start:batch_end] + + # Convert batch to GPU + left_stack = cp.stack([cp.asarray(frame) for frame in left_batch]) + right_stack = cp.stack([cp.asarray(frame) for frame in right_batch]) + + # Batch calculate mask areas for this batch + if left_stack.shape[3] == 4: # Alpha channel + left_masks = left_stack[:, :, :, 3] > 0 + right_masks = right_stack[:, :, :, 3] > 0 + else: # Green screen detection + bg_color = cp.array(self.config.output.background_color) + left_diff = cp.abs(left_stack.astype(cp.float32) - bg_color).sum(axis=3) + right_diff = cp.abs(right_stack.astype(cp.float32) - bg_color).sum(axis=3) + left_masks = left_diff > 30 + right_masks = right_diff > 30 + + # Calculate areas for this batch + left_areas = cp.sum(left_masks, axis=(1, 2)) + right_areas = cp.sum(right_masks, axis=(1, 2)) + area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6) + + # Find frames needing correction in this batch + needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0) + + # Transfer batch results back to CPU and accumulate + area_ratios_all.extend(cp.asnumpy(area_ratios)) + needs_correction_all.extend(cp.asnumpy(needs_correction)) + + # Free GPU memory for this batch + del left_stack, right_stack, left_masks, right_masks + del left_areas, right_areas, area_ratios, needs_correction + cp._default_memory_pool.free_all_blocks() - # Calculate all areas at once (massive parallel speedup) - left_areas = cp.sum(left_masks, axis=(1, 2)) - right_areas = cp.sum(right_masks, axis=(1, 2)) - area_ratios = right_areas.astype(cp.float32) / (left_areas.astype(cp.float32) + 1e-6) - - # Find frames needing correction - needs_correction = (area_ratios < 0.5) | (area_ratios > 2.0) - correction_count = int(cp.sum(needs_correction)) - - print(f" GPU validation complete: {correction_count}/{len(left_results)} frames need correction") - - # Transfer results back to CPU for processing - area_ratios_cpu = cp.asnumpy(area_ratios) - needs_correction_cpu = cp.asnumpy(needs_correction) + correction_count = sum(needs_correction_all) + print(f" GPU validation complete: {correction_count}/{total_frames} frames need correction") + # Apply corrections using CPU results validated_frames = [] - for i, (needs_fix, ratio) in enumerate(zip(needs_correction_cpu, area_ratios_cpu)): + for i, (needs_fix, ratio) in enumerate(zip(needs_correction_all, area_ratios_all)): if i % 100 == 0: - print(f" Processing validation results: {i}/{len(left_results)}") + print(f" Processing validation results: {i}/{total_frames}") if needs_fix: # Apply correction