#!/usr/bin/env python3
"""Analyze memory profile JSON files to identify OOM causes."""

import glob
import json
import re


def _partial_index(path):
    """Numeric sort key so e.g. partial_10 sorts after partial_9, not after partial_1."""
    match = re.search(r'(\d+)\.json$', path)
    return int(match.group(1)) if match else -1


def analyze_memory_files():
    """Analyze partial memory profile files."""
    # Get all partial files in numeric order (a plain lexicographic sort would
    # put memory_profile_partial_10.json before memory_profile_partial_2.json)
    files = sorted(glob.glob('memory_profile_partial_*.json'), key=_partial_index)
    if not files:
        print("āŒ No memory profile files found!")
        print("Expected files like: memory_profile_partial_0.json")
        return

    print(f"šŸ” Found {len(files)} memory profile files")
    print("=" * 60)

    peak_memory = 0
    peak_vram = 0
    critical_points = []
    all_checkpoints = []

    for i, file in enumerate(files):
        try:
            with open(file, 'r') as f:
                data = json.load(f)

            timeline = data.get('timeline', [])
            if not timeline:
                continue

            # Find peaks in this file
            file_peak_rss = max(d['rss_gb'] for d in timeline)
            file_peak_vram = max(d['vram_gb'] for d in timeline)
            peak_memory = max(peak_memory, file_peak_rss)
            peak_vram = max(peak_vram, file_peak_vram)

            # Find memory growth spikes (>3GB increase between samples)
            for j in range(1, len(timeline)):
                prev_rss = timeline[j - 1]['rss_gb']
                curr_rss = timeline[j]['rss_gb']
                growth = curr_rss - prev_rss

                if growth > 3.0:
                    checkpoint = timeline[j].get('checkpoint', f'sample_{j}')
                    critical_points.append({
                        'file': file,
                        'file_index': i,
                        'sample': j,
                        'timestamp': timeline[j]['timestamp'],
                        'rss_gb': curr_rss,
                        'vram_gb': timeline[j]['vram_gb'],
                        'growth_gb': growth,
                        'checkpoint': checkpoint,
                    })

            # Collect all checkpoints
            checkpoints = [d for d in timeline if 'checkpoint' in d]
            for cp in checkpoints:
                cp['file'] = file
                cp['file_index'] = i
                all_checkpoints.append(cp)

            # Show progress for this file
            start_rss = timeline[0]['rss_gb']
            end_rss = timeline[-1]['rss_gb']
            growth = end_rss - start_rss
            samples = len(timeline)

            print(f"šŸ“Š File {i + 1:2d}: {start_rss:5.1f}GB → {end_rss:5.1f}GB "
                  f"(+{growth:4.1f}GB) [{samples:3d} samples]")

            # Show checkpoints recorded in this file
            for cp in checkpoints:
                print(f"     šŸ“ {cp['checkpoint']}: {cp['rss_gb']:.1f}GB")

        except Exception as e:
            print(f"āŒ Error reading {file}: {e}")

    print("\n" + "=" * 60)
    print("šŸŽÆ ANALYSIS SUMMARY")
    print("=" * 60)
    print(f"šŸ“ˆ Peak Memory: {peak_memory:.1f} GB")
    print(f"šŸŽ® Peak VRAM:   {peak_vram:.1f} GB")
    print(f"⚔ Growth Spikes: {len(critical_points)} events >3GB")

    if critical_points:
        print("\nšŸ’„ MEMORY GROWTH SPIKES (>3GB):")
        print("   Location                       Growth     Total    VRAM")
        print("   " + "-" * 55)
        for point in critical_points:
            location = point['checkpoint'][:30].ljust(30)
            print(f"   {location} +{point['growth_gb']:4.1f}GB → "
                  f"{point['rss_gb']:5.1f}GB  {point['vram_gb']:4.1f}GB")

    if all_checkpoints:
        print("\nšŸ“ CHECKPOINT PROGRESSION:")
        print("   Checkpoint                     Memory    VRAM    File")
        print("   " + "-" * 55)
        for cp in all_checkpoints:
            checkpoint = cp['checkpoint'][:30].ljust(30)
            file_num = cp['file_index'] + 1
            print(f"   {checkpoint} {cp['rss_gb']:5.1f}GB  {cp['vram_gb']:4.1f}GB   #{file_num}")

    # Memory growth analysis between consecutive checkpoints
    if len(all_checkpoints) > 1:
        print("\nšŸ“Š MEMORY GROWTH ANALYSIS:")
        # Find the biggest memory jumps between checkpoints
        big_jumps = []
        for k in range(1, len(all_checkpoints)):
            prev_cp = all_checkpoints[k - 1]
            curr_cp = all_checkpoints[k]
            growth = curr_cp['rss_gb'] - prev_cp['rss_gb']
            if growth > 2.0:  # >2GB jump
                big_jumps.append({
                    'from': prev_cp['checkpoint'],
                    'to': curr_cp['checkpoint'],
                    'growth': growth,
                    'from_memory': prev_cp['rss_gb'],
                    'to_memory': curr_cp['rss_gb'],
                })

        if big_jumps:
            print("   Major jumps (>2GB):")
            for jump in big_jumps:
                print(f"   {jump['from']} → {jump['to']}: "
                      f"+{jump['growth']:.1f}GB "
                      f"({jump['from_memory']:.1f}→{jump['to_memory']:.1f}GB)")
        else:
            print("   āœ… No major memory jumps detected")

    # Diagnosis
    print("\nšŸ”¬ DIAGNOSIS:")
    if peak_memory > 400:
        print("   šŸ”“ CRITICAL: Memory usage exceeded 400GB")
        print("   šŸ’” Recommendation: Reduce chunk_size to 200-300 frames")
    elif peak_memory > 200:
        print("   🟔 HIGH: Memory usage over 200GB")
        print("   šŸ’” Recommendation: Reduce chunk_size to 400 frames")
    else:
        print("   🟢 MODERATE: Memory usage under 200GB")

    if critical_points:
        # Find the most common growth-spike locations
        spike_locations = {}
        for point in critical_points:
            location = point['checkpoint']
            spike_locations[location] = spike_locations.get(location, 0) + 1

        print("\n   šŸŽÆ Most problematic locations:")
        for location, count in sorted(spike_locations.items(),
                                      key=lambda x: x[1], reverse=True)[:3]:
            print(f"      {location}: {count} spikes")

    print("\nšŸ’” NEXT STEPS:")
    if 'merge' in str(critical_points).lower():
        print("   1. Chunk merging is still causing memory accumulation")
        print("   2. Check whether streaming merge is actually being used")
        print("   3. Verify chunk files are being deleted immediately")
    elif 'propagation' in str(critical_points).lower():
        print("   1. SAM2 propagation is using too much memory")
        print("   2. Reduce chunk_size further (try 300 frames)")
        print("   3. Enable more aggressive frame release")
    else:
        print("   1. Review the checkpoint progression above")
        print("   2. Focus on the locations with the biggest memory spikes")
        print("   3. Consider reducing chunk_size if the spikes are large")


def main():
    print("šŸ” MEMORY PROFILE ANALYZER")
    print("Analyzing memory profile files for OOM causes...")
    print()
    analyze_memory_files()


if __name__ == "__main__":
    main()
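# ---------------------------------------------------------------------------
# Reference: a sketch of the input layout this script assumes, inferred from
# the fields it reads ('timeline', 'timestamp', 'rss_gb', 'vram_gb', and the
# optional 'checkpoint' label). The values and the checkpoint name below are
# illustrative only; the real profiler output may carry additional keys.
#
#   memory_profile_partial_0.json
#   {
#     "timeline": [
#       {"timestamp": 1712345678.0, "rss_gb": 42.5, "vram_gb": 18.2},
#       {"timestamp": 1712345683.0, "rss_gb": 46.1, "vram_gb": 18.2,
#        "checkpoint": "chunk_0_propagation_start"},
#       ...
#     ]
#   }
# ---------------------------------------------------------------------------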