diff --git a/memory_profiler_script.py b/memory_profiler_script.py
new file mode 100644
index 0000000..80f0655
--- /dev/null
+++ b/memory_profiler_script.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+"""
+Memory profiling script for VR180 matting pipeline
+Tracks memory usage during processing to identify leaks
+"""
+
+import sys
+import time
+import psutil
+import tracemalloc
+import subprocess
+import gc
+from pathlib import Path
+from typing import Dict, List, Tuple
+import threading
+import json
+
+class MemoryProfiler:
+    def __init__(self, output_file: str = "memory_profile.json"):
+        self.output_file = output_file
+        self.data = []
+        self.process = psutil.Process()
+        self.running = False
+        self.thread = None
+
+    def start_monitoring(self, interval: float = 1.0):
+        """Start continuous memory monitoring"""
+        tracemalloc.start()
+        self.running = True
+        self.thread = threading.Thread(target=self._monitor_loop, args=(interval,))
+        self.thread.daemon = True
+        self.thread.start()
+        print(f"šŸ” Memory monitoring started (interval: {interval}s)")
+
+    def stop_monitoring(self):
+        """Stop monitoring and save results"""
+        self.running = False
+        if self.thread:
+            self.thread.join()
+
+        # Get tracemalloc snapshot
+        snapshot = tracemalloc.take_snapshot()
+        top_stats = snapshot.statistics('lineno')
+
+        # Save detailed results
+        results = {
+            'timeline': self.data,
+            'top_memory_allocations': [
+                {
+                    'file': stat.traceback.format()[0],
+                    'size_mb': stat.size / 1024 / 1024,
+                    'count': stat.count
+                }
+                for stat in top_stats[:20]  # Top 20 allocations
+            ],
+            'summary': {
+                'peak_rss_gb': max([d['rss_gb'] for d in self.data]) if self.data else 0,
+                'peak_vram_gb': max([d['vram_gb'] for d in self.data]) if self.data else 0,
+                'total_samples': len(self.data)
+            }
+        }
+
+        with open(self.output_file, 'w') as f:
+            json.dump(results, f, indent=2)
+
+        tracemalloc.stop()
+        print(f"šŸ“Š Memory profile saved to {self.output_file}")
+
+    def _monitor_loop(self, interval: float):
+        """Continuous monitoring loop"""
+        while self.running:
+            try:
+                # System memory
+                memory_info = self.process.memory_info()
+                rss_gb = memory_info.rss / (1024**3)
+
+                # System-wide memory
+                sys_memory = psutil.virtual_memory()
+                sys_used_gb = (sys_memory.total - sys_memory.available) / (1024**3)
+                sys_available_gb = sys_memory.available / (1024**3)
+
+                # GPU memory (if available)
+                vram_gb = 0
+                vram_free_gb = 0
+                try:
+                    result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free',
+                                             '--format=csv,noheader,nounits'],
+                                            capture_output=True, text=True, timeout=5)
+                    if result.returncode == 0:
+                        lines = result.stdout.strip().split('\n')
+                        if lines and lines[0]:
+                            used, free = lines[0].split(', ')
+                            vram_gb = float(used) / 1024
+                            vram_free_gb = float(free) / 1024
+                except Exception:
+                    pass
+
+                # Tracemalloc current usage
+                try:
+                    current, peak = tracemalloc.get_traced_memory()
+                    traced_mb = current / (1024**2)
+                except Exception:
+                    traced_mb = 0
+
+                data_point = {
+                    'timestamp': time.time(),
+                    'rss_gb': rss_gb,
+                    'vram_gb': vram_gb,
+                    'vram_free_gb': vram_free_gb,
+                    'sys_used_gb': sys_used_gb,
+                    'sys_available_gb': sys_available_gb,
+                    'traced_mb': traced_mb
+                }
+
+                self.data.append(data_point)
+
+                # Print periodic updates
+                if len(self.data) % 10 == 0:  # Every 10 samples
+                    print(f"šŸ” Memory: RSS={rss_gb:.2f}GB, VRAM={vram_gb:.2f}GB, Sys={sys_used_gb:.1f}GB")
+
+            except Exception as e:
+                print(f"Monitoring error: {e}")
+
+            time.sleep(interval)
+
+    def log_checkpoint(self, checkpoint_name: str):
+        """Log a specific checkpoint"""
+        if self.data:
+            self.data[-1]['checkpoint'] = checkpoint_name
+            latest = self.data[-1]
+            print(f"šŸ“ CHECKPOINT [{checkpoint_name}]: RSS={latest['rss_gb']:.2f}GB, VRAM={latest['vram_gb']:.2f}GB")
+
+def run_with_profiling(config_path: str):
+    """Run the VR180 matting with memory profiling"""
+    profiler = MemoryProfiler("memory_profile_detailed.json")
+
+    try:
+        # Start monitoring
+        profiler.start_monitoring(interval=2.0)  # Sample every 2 seconds
+
+        # Log initial state
+        profiler.log_checkpoint("STARTUP")
+
+        # Import after starting profiler to catch import memory usage
+        print("Importing VR180 processor...")
+        from vr180_matting.vr180_processor import VR180Processor
+        from vr180_matting.config import VR180Config
+
+        profiler.log_checkpoint("IMPORTS_COMPLETE")
+
+        # Load config
+        print(f"Loading config from {config_path}")
+        config = VR180Config.from_yaml(config_path)
+
+        profiler.log_checkpoint("CONFIG_LOADED")
+
+        # Initialize processor
+        print("Initializing VR180 processor...")
+        processor = VR180Processor(config)
+
+        profiler.log_checkpoint("PROCESSOR_INITIALIZED")
+
+        # Force garbage collection
+        gc.collect()
+        profiler.log_checkpoint("INITIAL_GC_COMPLETE")
+
+        # Run processing
+        print("Starting VR180 processing...")
+        processor.process_video()
+
+        profiler.log_checkpoint("PROCESSING_COMPLETE")
+
+    except Exception as e:
+        print(f"āŒ Error during processing: {e}")
+        profiler.log_checkpoint(f"ERROR: {str(e)}")
+        raise
+    finally:
+        # Stop monitoring and save results
+        profiler.stop_monitoring()
+
+        # Print summary
+        print("\n" + "="*60)
+        print("MEMORY PROFILING SUMMARY")
+        print("="*60)
+
+        if profiler.data:
+            peak_rss = max([d['rss_gb'] for d in profiler.data])
+            peak_vram = max([d['vram_gb'] for d in profiler.data])
+
+            print(f"Peak RSS Memory: {peak_rss:.2f} GB")
+            print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
+            print(f"Total Samples: {len(profiler.data)}")
+
+            # Show checkpoints
+            checkpoints = [d for d in profiler.data if 'checkpoint' in d]
+            if checkpoints:
+                print(f"\nCheckpoints ({len(checkpoints)}):")
+                for cp in checkpoints:
+                    print(f"  {cp['checkpoint']}: RSS={cp['rss_gb']:.2f}GB, VRAM={cp['vram_gb']:.2f}GB")
+
+        print(f"\nDetailed profile saved to: {profiler.output_file}")
+
+def main():
+    if len(sys.argv) != 2:
+        print("Usage: python memory_profiler_script.py <config_path>")
+        print("\nThis script runs VR180 matting with detailed memory profiling")
+        print("It will:")
+        print("- Monitor RSS, VRAM, and system memory every 2 seconds")
+        print("- Track memory allocations with tracemalloc")
+        print("- Log checkpoints at key processing stages")
+        print("- Save detailed JSON report for analysis")
+        sys.exit(1)
+
+    config_path = sys.argv[1]
+
+    if not Path(config_path).exists():
+        print(f"āŒ Config file not found: {config_path}")
+        sys.exit(1)
+
+    print("šŸš€ Starting VR180 Memory Profiling")
+    print(f"Config: {config_path}")
+    print("="*60)
+
+    run_with_profiling(config_path)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
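Note on the next file: merge_chunks_streaming() in video_processor.py imports StreamingVideoWriter from vr180_matting/streaming_video_writer.py, a module that is not part of this patch. Going only by the calls made in that method, the writer is assumed to expose roughly the interface sketched below. This is a minimal OpenCV-based sketch, not the project's actual implementation; the lazy writer creation, codec choice, and the omission of overlap blending and audio muxing are all assumptions.

    from pathlib import Path
    from typing import List, Optional

    import cv2
    import numpy as np


    class StreamingVideoWriter:
        """Sketch of the assumed interface: chunks are written as they arrive, never accumulated."""

        def __init__(self, output_path: str, fps: float, audio_source: Optional[str] = None):
            self.output_path = output_path
            self.fps = fps
            self.audio_source = audio_source
            self._writer = None  # created lazily once the first frame defines the resolution

        def write_chunk(self, frames: List[np.ndarray], chunk_index: int,
                        overlap_frames: int = 0, blend_with_previous: bool = False) -> None:
            # Skip frames already covered by the previous chunk; a real implementation
            # would blend the overlap region instead of simply dropping it.
            for frame in frames[overlap_frames:]:
                if self._writer is None:
                    h, w = frame.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    self._writer = cv2.VideoWriter(self.output_path, fourcc, self.fps, (w, h))
                self._writer.write(frame)

        def finalize(self) -> None:
            if self._writer is not None:
                self._writer.release()
            # Audio from self.audio_source would be muxed back in here (e.g. with ffmpeg);
            # omitted in this sketch.

        def cleanup(self) -> None:
            if self._writer is not None:
                self._writer.release()
            Path(self.output_path).unlink(missing_ok=True)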
diff --git a/vr180_matting/video_processor.py b/vr180_matting/video_processor.py
index c319dd3..d9e4dcd 100644
--- a/vr180_matting/video_processor.py
+++ b/vr180_matting/video_processor.py
@@ -387,19 +387,83 @@ class VideoProcessor:
         # Green screen background
         return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)
 
+    def merge_chunks_streaming(self, chunk_files: List[Path], output_path: str,
+                               overlap_frames: int = 0, audio_source: str = None) -> None:
+        """
+        Merge processed chunks using streaming approach (no memory accumulation)
+
+        Args:
+            chunk_files: List of chunk result files (.npz)
+            output_path: Final output video path
+            overlap_frames: Number of overlapping frames
+            audio_source: Audio source file for final video
+        """
+        from .streaming_video_writer import StreamingVideoWriter
+
+        if not chunk_files:
+            raise ValueError("No chunk files to merge")
+
+        print(f"šŸŽ¬ Streaming merge: {len(chunk_files)} chunks → {output_path}")
+
+        # Initialize streaming writer
+        writer = StreamingVideoWriter(
+            output_path=output_path,
+            fps=self.video_info['fps'],
+            audio_source=audio_source
+        )
+
+        try:
+            # Process each chunk without accumulation
+            for i, chunk_file in enumerate(chunk_files):
+                print(f"šŸ“¼ Processing chunk {i+1}/{len(chunk_files)}: {chunk_file.name}")
+
+                # Load chunk (this is the only copy in memory)
+                chunk_data = np.load(str(chunk_file))
+                frames = list(chunk_data['frames'])  # Convert to list of arrays
+                chunk_data.close()
+
+                # Write chunk with streaming writer
+                writer.write_chunk(
+                    frames=frames,
+                    chunk_index=i,
+                    overlap_frames=overlap_frames if i > 0 else 0,
+                    blend_with_previous=(i > 0 and overlap_frames > 0)
+                )
+
+                # Immediately free memory
+                del frames, chunk_data
+
+                # Delete chunk file to free disk space
+                try:
+                    chunk_file.unlink()
+                    print(f"  šŸ—‘ļø Deleted {chunk_file.name}")
+                except Exception as e:
+                    print(f"  āš ļø Could not delete {chunk_file.name}: {e}")
+
+                # Aggressive cleanup every chunk
+                self._aggressive_memory_cleanup(f"After processing chunk {i}")
+
+            # Finalize the video
+            writer.finalize()
+
+        except Exception as e:
+            print(f"āŒ Streaming merge failed: {e}")
+            writer.cleanup()
+            raise
+
+        print(f"āœ… Streaming merge complete: {output_path}")
+
     def merge_overlapping_chunks(self, chunk_results: List[List[np.ndarray]],
                                  overlap_frames: int) -> List[np.ndarray]:
         """
-        Merge overlapping chunks with blending in overlap regions
-
-        Args:
-            chunk_results: List of chunk results
-            overlap_frames: Number of overlapping frames
-
-        Returns:
-            Merged frame sequence
+        Legacy merge method - DEPRECATED due to memory accumulation
+        Use merge_chunks_streaming() instead for memory efficiency
         """
+        import warnings
+        warnings.warn("merge_overlapping_chunks() is deprecated due to memory accumulation. Use merge_chunks_streaming()",
+                      DeprecationWarning, stacklevel=2)
+
+        if len(chunk_results) == 1:
+            return chunk_results[0]
@@ -640,36 +704,23 @@ class VideoProcessor:
             if self.memory_manager.should_emergency_cleanup():
                 self.memory_manager.emergency_cleanup()
 
-        # Load and merge chunks from disk
-        print("\nLoading and merging chunks...")
-        chunk_results = []
-        for i, chunk_file in enumerate(chunk_files):
-            print(f"Loading {chunk_file.name}...")
-            chunk_data = np.load(str(chunk_file))
-            chunk_results.append(chunk_data['frames'])
-            chunk_data.close()  # Close the file
-
-            # Delete chunk file immediately after loading to free disk space
-            try:
-                chunk_file.unlink()
-                print(f"  Deleted chunk file {chunk_file.name}")
-            except Exception as e:
-                print(f"  Warning: Could not delete chunk file: {e}")
-
-            # Aggressive cleanup every few chunks to prevent accumulation
-            if i % 3 == 0 and i > 0:
-                self._aggressive_memory_cleanup(f"after loading chunk {i}")
+        # Use streaming merge to avoid memory accumulation (fixes OOM)
+        print("\nšŸŽ¬ Using streaming merge (no memory accumulation)...")
 
-        # Merge chunks
-        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)
+        # Determine audio source for final video
+        audio_source = None
+        if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
+            audio_source = self.config.input.video_path
 
-        # Free chunk results after merging - this is critical!
-        del chunk_results
-        self._aggressive_memory_cleanup("after merging chunks")
+        # Stream merge chunks directly to output (no memory accumulation)
+        self.merge_chunks_streaming(
+            chunk_files=chunk_files,
+            output_path=self.config.output.path,
+            overlap_frames=overlap_frames,
+            audio_source=audio_source
+        )
 
-        # Save results
-        print(f"Saving {len(final_frames)} processed frames...")
-        self.save_video(final_frames, self.config.output.path)
+        print("āœ… Streaming merge complete - no memory accumulation!")
 
         # Calculate final statistics
         self.processing_stats['end_time'] = time.time()
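Aside on the overlap handling above: the blend_with_previous and overlap_frames arguments suggest the writer is expected to crossfade the overlapping frames between consecutive chunks rather than write them twice. A rough sketch of such a linear crossfade follows; the function name and signature are purely illustrative, since the actual blending would live in the StreamingVideoWriter, which this patch does not include.

    from typing import List

    import numpy as np


    def crossfade_overlap(prev_tail: List[np.ndarray], next_head: List[np.ndarray]) -> List[np.ndarray]:
        """Linearly blend the last frames of one chunk into the first frames of the next."""
        assert len(prev_tail) == len(next_head), "overlap regions must be the same length"
        n = len(prev_tail)
        blended = []
        for k, (a, b) in enumerate(zip(prev_tail, next_head)):
            w = (k + 1) / (n + 1)  # weight ramps toward the newer chunk across the overlap
            mixed = a.astype(np.float32) * (1.0 - w) + b.astype(np.float32) * w
            blended.append(mixed.astype(a.dtype))
        return blended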
diff --git a/vr180_matting/vr180_processor.py b/vr180_matting/vr180_processor.py
index 5fe875e..738299b 100644
--- a/vr180_matting/vr180_processor.py
+++ b/vr180_matting/vr180_processor.py
@@ -398,44 +398,50 @@ class VR180Processor(VideoProcessor):
 
             self._print_memory_step(f"After SAM2 propagation ({eye_name} eye)")
 
-            # Apply masks - need to reload frames from temp video since we freed the original frames
-            self._print_memory_step(f"Before reloading frames for mask application ({eye_name} eye)")
+            # Apply masks with streaming approach (no frame accumulation)
+            self._print_memory_step(f"Before streaming mask application ({eye_name} eye)")
 
-            # Read frames back from the temp video for mask application
+            # Process frames one at a time without accumulation
             cap = cv2.VideoCapture(str(temp_video_path))
-            reloaded_frames = []
-
-            for frame_idx in range(num_frames):
-                ret, frame = cap.read()
-                if not ret:
-                    break
-                reloaded_frames.append(frame)
-            cap.release()
-
-            self._print_memory_step(f"Reloaded {len(reloaded_frames)} frames for mask application")
-
-            # Apply masks
             matted_frames = []
-            for frame_idx, frame in enumerate(reloaded_frames):
-                if frame_idx in video_segments:
-                    frame_masks = video_segments[frame_idx]
-                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)
+
+            try:
+                for frame_idx in range(num_frames):
+                    ret, frame = cap.read()
+                    if not ret:
+                        break
 
-                    matted_frame = self.sam2_model.apply_mask_to_frame(
-                        frame, combined_mask,
-                        output_format=self.config.output.format,
-                        background_color=self.config.output.background_color
-                    )
-                else:
-                    matted_frame = self._create_empty_mask_frame(frame)
-
-                matted_frames.append(matted_frame)
+                    # Apply mask to this single frame
+                    if frame_idx in video_segments:
+                        frame_masks = video_segments[frame_idx]
+                        combined_mask = self.sam2_model.get_combined_mask(frame_masks)
+
+                        matted_frame = self.sam2_model.apply_mask_to_frame(
+                            frame, combined_mask,
+                            output_format=self.config.output.format,
+                            background_color=self.config.output.background_color
+                        )
+                    else:
+                        matted_frame = self._create_empty_mask_frame(frame)
+
+                    matted_frames.append(matted_frame)
+
+                    # Free the original frame immediately (no accumulation)
+                    del frame
+
+                    # Periodic cleanup during processing
+                    if frame_idx % 100 == 0 and frame_idx > 0:
+                        import gc
+                        gc.collect()
+
+            finally:
+                cap.release()
 
-            # Free reloaded frames and video segments completely
-            del reloaded_frames
+            # Free video segments completely
             del video_segments  # This holds processed masks from SAM2
-            self._aggressive_memory_cleanup(f"After mask application ({eye_name} eye)")
+            self._aggressive_memory_cleanup(f"After streaming mask application ({eye_name} eye)")
+            self._print_memory_step(f"Completed streaming mask application ({eye_name} eye)")
 
             return matted_frames
 
         finally: