checkpoints yay

vr180_matting/checkpoint_manager.py (new file, 207 lines)

@@ -0,0 +1,207 @@
"""
Checkpoint manager for resumable video processing.

Saves progress to avoid reprocessing after OOM kills or crashes.
"""

import json
import hashlib
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, List


class CheckpointManager:
    """Manages processing checkpoints for resumable execution"""

    def __init__(self, video_path: str, output_path: str, checkpoint_dir: Optional[Path] = None):
        """
        Initialize checkpoint manager

        Args:
            video_path: Input video path
            output_path: Output video path
            checkpoint_dir: Directory for checkpoint files (default: .vr180_checkpoints in CWD)
        """
        self.video_path = Path(video_path)
        self.output_path = Path(output_path)

        # Create unique checkpoint ID based on the video file
        self.video_hash = self._compute_video_hash()

        # Set up checkpoint directory
        if checkpoint_dir is None:
            self.checkpoint_dir = Path.cwd() / ".vr180_checkpoints" / self.video_hash
        else:
            self.checkpoint_dir = Path(checkpoint_dir) / self.video_hash

        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Checkpoint files
        self.status_file = self.checkpoint_dir / "processing_status.json"
        self.chunks_dir = self.checkpoint_dir / "chunks"
        self.chunks_dir.mkdir(exist_ok=True)

        # Load existing status or create new
        self.status = self._load_status()

    def _compute_video_hash(self) -> str:
        """Compute a short hash of the video file for unique identification"""
        # Use file path, size, and modification time for a quick hash;
        # MD5 is fine here since this is an identifier, not a security check
        stat = self.video_path.stat()
        hash_str = f"{self.video_path}_{stat.st_size}_{stat.st_mtime}"
        return hashlib.md5(hash_str.encode()).hexdigest()[:12]

    def _load_status(self) -> Dict[str, Any]:
        """Load processing status from the checkpoint file"""
        if self.status_file.exists():
            with open(self.status_file, 'r') as f:
                status = json.load(f)
            print(f"📋 Loaded checkpoint: {status['completed_chunks']}/{status['total_chunks']} chunks completed")
            return status
        else:
            # Create new status
            return {
                'video_path': str(self.video_path),
                'output_path': str(self.output_path),
                'video_hash': self.video_hash,
                'start_time': datetime.now().isoformat(),
                'total_chunks': 0,
                'completed_chunks': 0,
                'chunk_info': {},
                'processing_complete': False,
                'merge_complete': False
            }

    def _save_status(self):
        """Save current status to the checkpoint file"""
        self.status['last_update'] = datetime.now().isoformat()
        with open(self.status_file, 'w') as f:
            json.dump(self.status, f, indent=2)

    def set_total_chunks(self, total_chunks: int):
        """Set total number of chunks to process"""
        self.status['total_chunks'] = total_chunks
        self._save_status()

    def is_chunk_completed(self, chunk_idx: int) -> bool:
        """Check if a chunk has already been processed"""
        chunk_info = self.status['chunk_info'].get(f"chunk_{chunk_idx}", {})
        return chunk_info.get('completed', False)

    def get_chunk_file(self, chunk_idx: int) -> Optional[Path]:
        """Get saved chunk file path if it exists"""
        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"
        if chunk_file.exists() and self.is_chunk_completed(chunk_idx):
            return chunk_file
        return None

    def save_chunk(self, chunk_idx: int, frames: List, source_chunk_path: Optional[Path] = None):
        """
        Save a processed chunk and mark it as completed

        Args:
            chunk_idx: Chunk index
            frames: Processed frames (may be None if using source_chunk_path)
            source_chunk_path: If provided, copy this file instead of saving frames
        """
        chunk_file = self.chunks_dir / f"chunk_{chunk_idx:04d}.npz"

        try:
            if source_chunk_path and source_chunk_path.exists():
                # Copy an existing chunk file
                shutil.copy2(source_chunk_path, chunk_file)
                print(f"💾 Copied chunk {chunk_idx} to checkpoint: {chunk_file.name}")
            elif frames is not None:
                # Save new frames (numpy imported lazily so the manager
                # stays importable without it)
                import numpy as np
                np.savez_compressed(str(chunk_file), frames=frames)
                print(f"💾 Saved chunk {chunk_idx} to checkpoint: {chunk_file.name}")
            else:
                raise ValueError("Either frames or source_chunk_path must be provided")

            # Update status
            chunk_key = f"chunk_{chunk_idx}"
            self.status['chunk_info'][chunk_key] = {
                'completed': True,
                'file': chunk_file.name,
                'timestamp': datetime.now().isoformat()
            }
            self.status['completed_chunks'] = len(
                [c for c in self.status['chunk_info'].values() if c['completed']])
            self._save_status()

            print(f"✅ Chunk {chunk_idx} checkpoint saved ({self.status['completed_chunks']}/{self.status['total_chunks']})")

        except Exception as e:
            print(f"❌ Failed to save chunk {chunk_idx} checkpoint: {e}")

    def get_completed_chunk_files(self) -> List[Path]:
        """Get list of completed chunk files, in order, up to the first gap"""
        chunk_files = []
        for i in range(self.status['total_chunks']):
            chunk_file = self.get_chunk_file(i)
            if chunk_file:
                chunk_files.append(chunk_file)
            else:
                break  # Stop at first missing chunk
        return chunk_files

    def mark_processing_complete(self):
        """Mark all chunk processing as complete"""
        self.status['processing_complete'] = True
        self._save_status()
        print("✅ All chunks processed and checkpointed")

    def mark_merge_complete(self):
        """Mark final merge as complete"""
        self.status['merge_complete'] = True
        self._save_status()
        print("✅ Video merge completed")

    def cleanup_checkpoints(self, keep_chunks: bool = False):
        """
        Clean up checkpoint files after successful completion

        Args:
            keep_chunks: If True, keep chunk files but remove the status file
        """
        if keep_chunks:
            # Just remove the status file
            if self.status_file.exists():
                self.status_file.unlink()
                print("🗑️ Removed checkpoint status file")
        else:
            # Remove the entire checkpoint directory
            if self.checkpoint_dir.exists():
                shutil.rmtree(self.checkpoint_dir)
                print(f"🗑️ Removed all checkpoint files: {self.checkpoint_dir}")

    def get_resume_info(self) -> Dict[str, Any]:
        """Get information about what can be resumed"""
        return {
            'can_resume': self.status['completed_chunks'] > 0,
            'completed_chunks': self.status['completed_chunks'],
            'total_chunks': self.status['total_chunks'],
            'processing_complete': self.status['processing_complete'],
            'merge_complete': self.status['merge_complete'],
            'checkpoint_dir': str(self.checkpoint_dir)
        }

    def print_status(self):
        """Print current checkpoint status"""
        print("\n📊 CHECKPOINT STATUS:")
        print(f"   Video: {self.video_path.name}")
        print(f"   Hash: {self.video_hash}")
        print(f"   Progress: {self.status['completed_chunks']}/{self.status['total_chunks']} chunks")
        print(f"   Processing complete: {self.status['processing_complete']}")
        print(f"   Merge complete: {self.status['merge_complete']}")
        print(f"   Checkpoint dir: {self.checkpoint_dir}")

        if self.status['completed_chunks'] > 0:
            print("\n   Completed chunks:")
            # Scan every chunk slot: completed chunks need not be a
            # contiguous prefix, so range(completed_chunks) could miss some
            for i in range(self.status['total_chunks']):
                chunk_info = self.status['chunk_info'].get(f'chunk_{i}', {})
                if chunk_info.get('completed'):
                    print(f"   ✓ Chunk {i}: {chunk_info.get('file', 'unknown')}")
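For context, a minimal usage sketch of the class above (not part of this commit): the paths and the dummy matting step are hypothetical, and the input file is assumed to exist so that _compute_video_hash can stat it.

import numpy as np
from vr180_matting.checkpoint_manager import CheckpointManager

def matte_chunk(idx):
    # Hypothetical stand-in for the real per-chunk matting step
    return np.zeros((8, 4, 4, 3), dtype=np.uint8)  # 8 tiny dummy frames

mgr = CheckpointManager("input_vr180.mp4", "output_matted.mp4")
mgr.set_total_chunks(4)

for idx in range(4):
    if mgr.is_chunk_completed(idx):
        continue  # checkpointed by an earlier run; skip the work
    mgr.save_chunk(idx, matte_chunk(idx))

mgr.mark_processing_complete()

# Chunks round-trip as compressed .npz files with a single 'frames' array
for chunk_file in mgr.get_completed_chunk_files():
    frames = np.load(chunk_file)["frames"]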
@@ -781,26 +781,55 @@ class VideoProcessor:
            print(f"⚠️ Could not verify frame count: {e}")

    def process_video(self) -> None:
        """Main video processing pipeline with checkpoint/resume support"""
        self.processing_stats['start_time'] = time.time()
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Initialize checkpoint manager
        from .checkpoint_manager import CheckpointManager
        checkpoint_mgr = CheckpointManager(
            self.config.input.video_path,
            self.config.output.path
        )

        # Check for existing checkpoints
        resume_info = checkpoint_mgr.get_resume_info()
        if resume_info['can_resume']:
            print("\n🔄 RESUME DETECTED:")
            print(f"   Found {resume_info['completed_chunks']} completed chunks")
            print("   Continuing from the last completed chunk")
            checkpoint_mgr.print_status()

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Calculate total chunks
        total_chunks = 0
        for _ in range(0, self.total_frames, chunk_size - overlap_frames):
            total_chunks += 1
        checkpoint_mgr.set_total_chunks(total_chunks)

        # Process video in chunks
        chunk_files = []  # Store file paths instead of frame data
        temp_chunk_dir = Path(tempfile.mkdtemp(prefix="vr180_chunks_"))

        try:
            for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
                end_frame = min(start_frame + chunk_size, self.total_frames)
                frames_to_read = end_frame - start_frame

                # chunk_idx always mirrors len(chunk_files), whether the
                # chunk is reused from a checkpoint or processed fresh
                chunk_idx = len(chunk_files)

                # Check if this chunk was already processed
                existing_chunk = checkpoint_mgr.get_chunk_file(chunk_idx)
                if existing_chunk:
                    print(f"\n✅ Chunk {chunk_idx} already processed: {existing_chunk.name}")
                    chunk_files.append(existing_chunk)
                    continue

                print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

                # Read chunk frames
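Aside: the chunk-counting loop above is just a ceiling division over the effective step. A quick check, with assumed example numbers:

import math

total_frames, chunk_size, overlap_frames = 1000, 120, 20   # example values only
step = chunk_size - overlap_frames                          # each chunk advances 100 frames
assert math.ceil(total_frames / step) == len(range(0, total_frames, step))  # both give 10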
@@ -818,7 +847,12 @@ class VideoProcessor:
                chunk_path = temp_chunk_dir / f"chunk_{chunk_idx:04d}.npz"
                print(f"Saving chunk {chunk_idx} to disk...")
                np.savez_compressed(str(chunk_path), frames=matted_frames)

                # Save to checkpoint
                checkpoint_mgr.save_chunk(chunk_idx, None, source_chunk_path=chunk_path)

                chunk_files.append(chunk_path)

                # Free the frames from memory immediately
                del matted_frames
@@ -837,6 +871,14 @@ class VideoProcessor:
                if self.memory_manager.should_emergency_cleanup():
                    self.memory_manager.emergency_cleanup()

            # Mark chunk processing as complete
            checkpoint_mgr.mark_processing_complete()

            # Check if merge was already done
            if resume_info.get('merge_complete', False):
                print("\n✅ Merge already completed in previous run!")
                print(f"   Output: {self.config.output.path}")
            else:
                # Use streaming merge to avoid memory accumulation (fixes OOM)
                print("\n🎬 Using streaming merge (no memory accumulation)...")
@@ -853,6 +895,9 @@ class VideoProcessor:
                    audio_source=audio_source
                )

                # Mark merge as complete
                checkpoint_mgr.mark_merge_complete()

                print("✅ Streaming merge complete - no memory accumulation!")

            # Calculate final statistics
@@ -869,11 +914,24 @@ class VideoProcessor:
            print("Video processing completed!")

            # Option to clean up checkpoints
            print("\n🗄️ CHECKPOINT CLEANUP OPTIONS:")
            print("   Checkpoints saved successfully and can be cleaned up")
            print("   - Keep chunk files for debugging: checkpoint_mgr.cleanup_checkpoints(keep_chunks=True)")
            print("   - Remove all checkpoints: checkpoint_mgr.cleanup_checkpoints()")
            print(f"   - Checkpoint location: {checkpoint_mgr.checkpoint_dir}")

            # For now, keep checkpoints by default (the user can clean up manually)
            print("\n💡 Checkpoints kept for safety. Delete manually when no longer needed.")

        finally:
            # Clean up temporary chunk files (but not checkpoints)
            if temp_chunk_dir.exists():
                print("Cleaning up temporary chunk files...")
                try:
                    shutil.rmtree(temp_chunk_dir)
                except Exception as e:
                    print(f"⚠️ Could not clean temp directory: {e}")

    def _print_processing_statistics(self):
        """Print detailed processing statistics"""
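Since the commit leaves cleanup to the user, here is a hedged sketch of a post-run cleanup pass (hypothetical wrapper, not part of this commit; it assumes the original input file is still on disk so the manager resolves to the same hash, and that the output path matches the earlier run):

from pathlib import Path
from vr180_matting.checkpoint_manager import CheckpointManager

mgr = CheckpointManager("input_vr180.mp4", "output_matted.mp4")
if mgr.get_resume_info()['merge_complete'] and Path("output_matted.mp4").exists():
    mgr.cleanup_checkpoints()                  # remove .vr180_checkpoints/<hash> entirely
else:
    mgr.cleanup_checkpoints(keep_chunks=True)  # drop the status file, keep chunk .npz files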