From 262cb00b698e2a05c72c873a33dc6fb66adcea68 Mon Sep 17 00:00:00 2001
From: Scott Register
Date: Sat, 26 Jul 2025 17:11:07 -0700
Subject: [PATCH] Add checkpoint/resume support for chunked video processing

---
 vr180_matting/checkpoint_manager.py | 207 ++++++++++++++++++++++++++++
 vr180_matting/video_processor.py    |  94 ++++++++++---
 2 files changed, 283 insertions(+), 18 deletions(-)
 create mode 100644 vr180_matting/checkpoint_manager.py

diff --git a/vr180_matting/checkpoint_manager.py b/vr180_matting/checkpoint_manager.py
new file mode 100644
index 0000000..3272d89
--- /dev/null
+++ b/vr180_matting/checkpoint_manager.py
@@ -0,0 +1,207 @@
+"""
+Checkpoint manager for resumable video processing
+Saves progress to avoid reprocessing after OOM or crashes
+"""
+
+import json
+import hashlib
+from pathlib import Path
+from typing import Dict, Any, Optional, List
+import os
+import shutil
+from datetime import datetime
+
+
+class CheckpointManager:
+    """Manages processing checkpoints for resumable execution"""
+
+    def __init__(self, video_path: str, output_path: str, checkpoint_dir: Optional[Path] = None):
+        """
+        Initialize checkpoint manager
+
+        Args:
+            video_path: Input video path
+            output_path: Output video path
+            checkpoint_dir: Directory for checkpoint files (default: .vr180_checkpoints in CWD)
+        """
+        self.video_path = Path(video_path)
+        self.output_path = Path(output_path)
+
+        # Create unique checkpoint ID based on video file
+        self.video_hash = self._compute_video_hash()
+
+        # Setup checkpoint directory
+        if checkpoint_dir is None:
+            self.checkpoint_dir = Path.cwd() / ".vr180_checkpoints" / self.video_hash
+        else:
+            self.checkpoint_dir = Path(checkpoint_dir) / self.video_hash
+
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        # Checkpoint files
+        self.status_file = self.checkpoint_dir / "processing_status.json"
+        self.chunks_dir = self.checkpoint_dir / "chunks"
+        self.chunks_dir.mkdir(exist_ok=True)
+
+        # Load existing status or create new
+        self.status = self._load_status()
+
+    def _compute_video_hash(self) -> str:
+        """Compute hash of video file for unique identification"""
+        # Use file path, size, and modification time for quick hash
+        stat = self.video_path.stat()
+        hash_str = f"{self.video_path}_{stat.st_size}_{stat.st_mtime}"
+        return hashlib.md5(hash_str.encode()).hexdigest()[:12]
+
+    def _load_status(self) -> Dict[str, Any]:
+        """Load processing status from checkpoint file"""
+        if self.status_file.exists():
+            with open(self.status_file, 'r') as f:
+                status = json.load(f)
+            print(f"šŸ“‹ Loaded checkpoint: {status['completed_chunks']}/{status['total_chunks']} chunks completed")
+            return status
+        else:
+            # Create new status
+            return {
+                'video_path': str(self.video_path),
+                'output_path': str(self.output_path),
+                'video_hash': self.video_hash,
+                'start_time': datetime.now().isoformat(),
+                'total_chunks': 0,
+                'completed_chunks': 0,
+                'chunk_info': {},
+                'processing_complete': False,
+                'merge_complete': False
+            }
+
+    def _save_status(self):
+        """Save current status to checkpoint file"""
+        self.status['last_update'] = datetime.now().isoformat()
+        with open(self.status_file, 'w') as f:
+            json.dump(self.status, f, indent=2)
+
+    def set_total_chunks(self, total_chunks: int):
+        """Set total number of chunks to process"""
+        self.status['total_chunks'] = total_chunks
+        self._save_status()
+
+    def is_chunk_completed(self, chunk_idx: int) -> bool:
+        """Check if a chunk has already been processed"""
+        chunk_key = f"chunk_{chunk_idx}"
+        return chunk_key in self.status['chunk_info'] and \
+               self.status['chunk_info'][chunk_key].get('completed', False)
+
+    def get_chunk_file(self, chunk_idx: int) -> Optional[Path]:
+        """Get saved chunk file path if it exists"""
+        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
+        if chunk_file.exists() and self.is_chunk_completed(chunk_idx):
+            return chunk_file
+        return None
+
+    def save_chunk(self, chunk_idx: int, frames: List, source_chunk_path: Optional[Path] = None):
+        """
+        Save processed chunk and mark as completed
+
+        Args:
+            chunk_idx: Chunk index
+            frames: Processed frames (can be None if using source_chunk_path)
+            source_chunk_path: If provided, copy this file instead of saving frames
+        """
+        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
+
+        try:
+            if source_chunk_path and source_chunk_path.exists():
+                # Copy existing chunk file
+                shutil.copy2(source_chunk_path, chunk_file)
+                print(f"šŸ’¾ Copied chunk {chunk_idx} to checkpoint: {chunk_file.name}")
+            elif frames is not None:
+                # Save new frames
+                import numpy as np
+                np.savez_compressed(str(chunk_file), frames=frames)
+                print(f"šŸ’¾ Saved chunk {chunk_idx} to checkpoint: {chunk_file.name}")
+            else:
+                raise ValueError("Either frames or source_chunk_path must be provided")
+
+            # Update status
+            chunk_key = f"chunk_{chunk_idx}"
+            self.status['chunk_info'][chunk_key] = {
+                'completed': True,
+                'file': chunk_file.name,
+                'timestamp': datetime.now().isoformat()
+            }
+            self.status['completed_chunks'] = len([c for c in self.status['chunk_info'].values() if c.get('completed', False)])
+            self._save_status()
+
+            print(f"āœ… Chunk {chunk_idx} checkpoint saved ({self.status['completed_chunks']}/{self.status['total_chunks']})")
+
+        except Exception as e:
+            print(f"āŒ Failed to save chunk {chunk_idx} checkpoint: {e}")
+
+    def get_completed_chunk_files(self) -> List[Path]:
+        """Get list of all completed chunk files in order"""
+        chunk_files = []
+        for i in range(self.status['total_chunks']):
+            chunk_file = self.get_chunk_file(i)
+            if chunk_file:
+                chunk_files.append(chunk_file)
+            else:
+                break  # Stop at first missing chunk
+        return chunk_files
+
+    def mark_processing_complete(self):
+        """Mark all chunk processing as complete"""
+        self.status['processing_complete'] = True
+        self._save_status()
+        print(f"āœ… All chunks processed and checkpointed")
+
+    def mark_merge_complete(self):
+        """Mark final merge as complete"""
+        self.status['merge_complete'] = True
+        self._save_status()
+        print(f"āœ… Video merge completed")
+
+    def cleanup_checkpoints(self, keep_chunks: bool = False):
+        """
+        Clean up checkpoint files after successful completion
+
+        Args:
+            keep_chunks: If True, keep chunk files but remove status
+        """
+        if keep_chunks:
+            # Just remove status file
+            if self.status_file.exists():
+                self.status_file.unlink()
+                print(f"šŸ—‘ļø Removed checkpoint status file")
+        else:
+            # Remove entire checkpoint directory
+            if self.checkpoint_dir.exists():
+                shutil.rmtree(self.checkpoint_dir)
+                print(f"šŸ—‘ļø Removed all checkpoint files: {self.checkpoint_dir}")
+
+    def get_resume_info(self) -> Dict[str, Any]:
+        """Get information about what can be resumed"""
+        return {
+            'can_resume': self.status['completed_chunks'] > 0,
+            'completed_chunks': self.status['completed_chunks'],
+            'total_chunks': self.status['total_chunks'],
+            'processing_complete': self.status['processing_complete'],
+            'merge_complete': self.status['merge_complete'],
+            'checkpoint_dir': str(self.checkpoint_dir)
+        }
+
+    def print_status(self):
+        """Print current checkpoint status"""
+        print(f"\nšŸ“Š CHECKPOINT STATUS:")
+        print(f"   Video: {self.video_path.name}")
+        print(f"   Hash: {self.video_hash}")
+        print(f"   Progress: {self.status['completed_chunks']}/{self.status['total_chunks']} chunks")
+        print(f"   Processing complete: {self.status['processing_complete']}")
+        print(f"   Merge complete: {self.status['merge_complete']}")
+        print(f"   Checkpoint dir: {self.checkpoint_dir}")
+
+        if self.status['completed_chunks'] > 0:
+            print(f"\n   Completed chunks:")
+            for i in range(self.status['completed_chunks']):
+                chunk_info = self.status['chunk_info'].get(f'chunk_{i}', {})
+                if chunk_info.get('completed'):
+                    print(f"      āœ“ Chunk {i}: {chunk_info.get('file', 'unknown')}")
\ No newline at end of file
diff --git a/vr180_matting/video_processor.py b/vr180_matting/video_processor.py
index 7df799e..3ee8957 100644
--- a/vr180_matting/video_processor.py
+++ b/vr180_matting/video_processor.py
@@ -781,26 +781,55 @@ class VideoProcessor:
             print(f"āš ļø Could not verify frame count: {e}")
 
     def process_video(self) -> None:
-        """Main video processing pipeline"""
+        """Main video processing pipeline with checkpoint/resume support"""
         self.processing_stats['start_time'] = time.time()
 
         print("Starting VR180 video processing...")
 
         # Load video info
         self.load_video_info(self.config.input.video_path)
 
+        # Initialize checkpoint manager
+        from .checkpoint_manager import CheckpointManager
+        checkpoint_mgr = CheckpointManager(
+            self.config.input.video_path,
+            self.config.output.path
+        )
+
+        # Check for existing checkpoints
+        resume_info = checkpoint_mgr.get_resume_info()
+        if resume_info['can_resume']:
+            print(f"\nšŸ”„ RESUME DETECTED:")
+            print(f"   Found {resume_info['completed_chunks']} completed chunks")
+            print(f"   Resuming from the last completed chunk (saves time!)")
+            checkpoint_mgr.print_status()
+
         # Calculate chunking parameters
         chunk_size, overlap_frames = self.calculate_optimal_chunking()
 
+        # Calculate total chunks
+        total_chunks = 0
+        for _ in range(0, self.total_frames, chunk_size - overlap_frames):
+            total_chunks += 1
+        checkpoint_mgr.set_total_chunks(total_chunks)
+
         # Process video in chunks
         chunk_files = []  # Store file paths instead of frame data
         temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))
 
        try:
+            chunk_idx = 0
             for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
                 end_frame = min(start_frame + chunk_size, self.total_frames)
                 frames_to_read = end_frame - start_frame
-                chunk_idx = len(chunk_files)
+                # Check if this chunk was already processed
+                existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
+                if existing_chunk:
+                    print(f"\nāœ… Chunk {chunk_idx} already processed: {existing_chunk.name}")
+                    chunk_files.append(existing_chunk)
+                    chunk_idx += 1
+                    continue
+
                 print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")
 
                 # Read chunk frames
@@ -818,7 +847,12 @@
                 chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
                 print(f"Saving chunk {chunk_idx} to disk...")
                 np.savez_compressed(str(chunk_path), frames=matted_frames)
+
+                # Save to checkpoint
+                checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)
+
                 chunk_files.append(chunk_path)
+                chunk_idx += 1
 
                 # Free the frames from memory immediately
                 del matted_frames
@@ -837,21 +871,32 @@
                 if self.memory_manager.should_emergency_cleanup():
                     self.memory_manager.emergency_cleanup()
 
-            # Use streaming merge to avoid memory accumulation (fixes OOM)
-            print("\nšŸŽ¬ Using streaming merge (no memory accumulation)...")
+            # Mark chunk processing as complete
+            checkpoint_mgr.mark_processing_complete()
 
-            # Determine audio source for final video
-            audio_source = None
-            if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
-                audio_source = self.config.input.video_path
-
-            # Stream merge chunks directly to output (no memory accumulation)
-            self.merge_chunks_streaming(
-                chunk_files=chunk_files,
-                output_path=self.config.output.path,
-                overlap_frames=overlap_frames,
-                audio_source=audio_source
-            )
+            # Check if merge was already done
+            if resume_info.get('merge_complete', False):
+                print("\nāœ… Merge already completed in previous run!")
+                print(f"   Output: {self.config.output.path}")
+            else:
+                # Use streaming merge to avoid memory accumulation (fixes OOM)
+                print("\nšŸŽ¬ Using streaming merge (no memory accumulation)...")
+
+                # Determine audio source for final video
+                audio_source = None
+                if self.config.output.preserve_audio and Path(self.config.input.video_path).exists():
+                    audio_source = self.config.input.video_path
+
+                # Stream merge chunks directly to output (no memory accumulation)
+                self.merge_chunks_streaming(
+                    chunk_files=chunk_files,
+                    output_path=self.config.output.path,
+                    overlap_frames=overlap_frames,
+                    audio_source=audio_source
+                )
+
+                # Mark merge as complete
+                checkpoint_mgr.mark_merge_complete()
 
             print("āœ… Streaming merge complete - no memory accumulation!")
 
@@ -869,11 +914,24 @@
 
             print("Video processing completed!")
 
+            # Option to clean up checkpoints
+            print("\nšŸ—„ļø CHECKPOINT CLEANUP OPTIONS:")
+            print("   Checkpoints saved successfully and can be cleaned up")
+            print("   - Keep checkpoints for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
+            print("   - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
+            print(f"   - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")
+
+            # For now, keep checkpoints by default (user can manually clean)
+            print("\nšŸ’” Checkpoints kept for safety. Delete manually when no longer needed.")
+
         finally:
-            # Clean up temporary chunk files
+            # Clean up temporary chunk files (but not checkpoints)
             if temp_chunk_dir.exists():
                 print("Cleaning up temporary chunk files...")
-                shutil.rmtree(temp_chunk_dir)
+                try:
+                    shutil.rmtree(temp_chunk_dir)
+                except Exception as e:
+                    print(f"āš ļø Could not clean temp directory: {e}")
 
     def _print_processing_statistics(self):
         """Print detailed processing statistics"""
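
---

Two supplementary notes for reviewers (not part of the diff).

1. Chunk-count arithmetic: the counting loop added to process_video() is
   equivalent to a ceiling division over the effective stride, i.e. the new
   frames contributed per chunk (chunk_size - overlap_frames). A worked
   example with illustrative numbers:

       import math

       total_frames, chunk_size, overlap_frames = 1000, 120, 20
       stride = chunk_size - overlap_frames             # 100 new frames per chunk
       total_chunks = math.ceil(total_frames / stride)  # ceil(1000 / 100) = 10
       # range(0, 1000, 100) also yields 10 chunk starts: 0, 100, ..., 900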
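2. Driving the checkpoint manager by hand: a minimal sketch of inspecting and
   cleaning up a previous run, assuming the package layout above; the file
   paths are illustrative only.

       from vr180_matting.checkpoint_manager import CheckpointManager

       mgr = CheckpointManager("input_vr180.mp4", "output_matted.mp4")

       info = mgr.get_resume_info()
       if info['can_resume']:
           mgr.print_status()                      # per-chunk progress report
           done = mgr.get_completed_chunk_files()  # contiguous completed .npz chunks
           print(f"{len(done)} chunks will be skipped by the next run")

       # Once the merged output has been verified:
       # mgr.cleanup_checkpoints()                   # remove status + chunks
       # mgr.cleanup_checkpoints(keep_chunks=True)   # keep chunk files, drop status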