"""
|
|
Main VR180 streaming processor - orchestrates all components for true streaming
|
|
"""
|
|
|
|
import time
|
|
import gc
|
|
import json
|
|
import psutil
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, Tuple
|
|
import warnings
|
|
|
|
from .frame_reader import StreamingFrameReader
|
|
from .frame_writer import StreamingFrameWriter
|
|
from .stereo_manager import StereoConsistencyManager
|
|
from .sam2_streaming import SAM2StreamingProcessor
|
|
from .detector import PersonDetector
|
|
from .config import StreamingConfig
|
|
|
|
|
|
class VR180StreamingProcessor:
    """Main processor for streaming VR180 human matting"""

    def __init__(self, config: StreamingConfig):
        self.config = config

        # Components (created in initialize())
        self.frame_reader = None
        self.frame_writer = None
        self.stereo_manager = None
        self.sam2_processor = None
        self.detector = None

        # Processing state
        self.start_time = None
        self.frames_processed = 0
        self.checkpoint_state = {}

        # Performance monitoring
        self.process = psutil.Process()
        self.performance_stats = {
            'fps': 0.0,
            'avg_frame_time': 0.0,
            'peak_memory_gb': 0.0,
            'gpu_utilization': 0.0,
        }
    def initialize(self) -> None:
        """Initialize all components"""
        print("\n🚀 Initializing VR180 Streaming Processor")
        print("=" * 60)

        # Initialize frame reader, resuming from a checkpoint when configured
        start_frame = self._load_checkpoint() if self.config.recovery.auto_resume else 0
        self.frame_reader = StreamingFrameReader(
            self.config.input.video_path,
            start_frame=start_frame
        )

        # Get video info (expects at least 'width', 'height', and 'fps' keys)
        video_info = self.frame_reader.get_video_info()

        # Apply scaling to output dimensions
        scale = self.config.processing.scale_factor
        output_width = int(video_info['width'] * scale)
        output_height = int(video_info['height'] * scale)

        # Initialize frame writer
        self.frame_writer = StreamingFrameWriter(
            output_path=self.config.output.path,
            width=output_width,
            height=output_height,
            fps=video_info['fps'],
            audio_source=self.config.input.video_path if self.config.output.maintain_sbs else None,
            video_codec=self.config.output.video_codec,
            quality_preset=self.config.output.quality_preset,
            crf=self.config.output.crf
        )

        # Initialize stereo manager
        self.stereo_manager = StereoConsistencyManager(self.config.to_dict())

        # Initialize SAM2 processor
        self.sam2_processor = SAM2StreamingProcessor(self.config.to_dict())

        # Initialize detector and warm it up with single-eye dimensions
        # (one eye of an SBS frame is half the width at full height)
        self.detector = PersonDetector(self.config.to_dict())
        self.detector.warmup((output_height, output_width // 2, 3))

        print("\n✅ All components initialized successfully!")
        print(f"   Input: {video_info['width']}x{video_info['height']} @ {video_info['fps']}fps")
        print(f"   Output: {output_width}x{output_height} @ {video_info['fps']}fps")
        print(f"   Scale factor: {scale}")
        print(f"   Starting from frame: {start_frame}")
        print("=" * 60 + "\n")
    def process_video(self) -> None:
        """Main processing loop"""
        try:
            self.initialize()
            self.start_time = time.time()

            # Initialize SAM2 states for both eyes (streaming mode - no video preloading)
            print("🎯 Initializing SAM2 streaming states...")
            video_info = self.frame_reader.get_video_info()
            left_state = self.sam2_processor.init_state(video_info, eye='left')
            right_state = self.sam2_processor.init_state(video_info, eye='right')

            # Process the first frame to establish detections
            print("🔍 Processing first frame for initial detection...")
            if not self._initialize_tracking(left_state, right_state):
                raise RuntimeError("Failed to initialize tracking - no persons detected")

            # Main streaming loop
            print("\n🎬 Starting streaming processing loop...")
            self._streaming_loop(left_state, right_state)

        except KeyboardInterrupt:
            print("\n⚠️ Processing interrupted by user")
            self._save_checkpoint()

        except Exception as e:
            print(f"\n❌ Error during processing: {e}")
            self._save_checkpoint()
            raise

        finally:
            self._finalize()
    def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
        """Initialize tracking with first-frame detection"""
        # Read the first frame
        first_frame = self.frame_reader.read_frame()
        if first_frame is None:
            raise RuntimeError("Cannot read first frame")

        # Scale frame if needed
        if self.config.processing.scale_factor != 1.0:
            first_frame = self._scale_frame(first_frame)

        # Split into eyes
        left_eye, right_eye = self.stereo_manager.split_frame(first_frame)

        # Detect on the master eye
        master_is_left = self.stereo_manager.master_eye == 'left'
        master_eye = left_eye if master_is_left else right_eye
        detections = self.detector.detect_persons(master_eye)

        if not detections:
            warnings.warn("No persons detected in first frame")
            return False

        print(f"   Detected {len(detections)} person(s) in first frame")

        # Transfer detections from the master eye to the slave eye
        transferred_detections = self.stereo_manager.transfer_detections(
            detections,
            'left_to_right' if master_is_left else 'right_to_left'
        )

        # Add raw detections to the master eye's state and transferred ones to
        # the slave's (streaming mode - pass frame data directly)
        if master_is_left:
            self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
            self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
        else:
            self.sam2_processor.add_detections(right_state, right_eye, detections, frame_idx=0)
            self.sam2_processor.add_detections(left_state, left_eye, transferred_detections, frame_idx=0)

        # Process the first frame through SAM2 for both eyes
        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)

        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
        self.frame_writer.write_frame(processed_frame)

        self.frames_processed = 1
        return True
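    # For intuition, a toy version of what a master-to-slave detection transfer
    # could look like. This is an assumption for illustration only - the real
    # StereoConsistencyManager.transfer_detections may use disparity estimation
    # or lens calibration rather than a fixed horizontal shift:
    #
    #     def naive_transfer(boxes, disparity_px=12):
    #         """Shift [x1, y1, x2, y2] boxes horizontally by an assumed disparity."""
    #         return [[x1 - disparity_px, y1, x2 - disparity_px, y2]
    #                 for (x1, y1, x2, y2) in boxes]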
    def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
        """Main streaming processing loop"""
        frame_times = []

        # Start from frame 1 (frame 0 was already processed during initialization)
        for frame_idx, frame in enumerate(self.frame_reader, start=1):
            frame_start_time = time.time()

            # Scale frame if needed
            if self.config.processing.scale_factor != 1.0:
                frame = self._scale_frame(frame)

            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)

            # Propagate masks for both eyes (streaming approach)
            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)

            # Validate stereo consistency
            right_masks = self.stereo_manager.validate_masks(
                left_masks, right_masks, frame_idx
            )

            # Apply continuous correction if enabled
            if (self.config.matting.continuous_correction and
                    frame_idx % self.config.matting.correction_interval == 0):
                self._apply_continuous_correction(
                    left_state, right_state, left_eye, right_eye, frame_idx
                )

            # Apply masks and write frame
            processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
            self.frame_writer.write_frame(processed_frame)

            # Update stats
            frame_time = time.time() - frame_start_time
            frame_times.append(frame_time)
            self.frames_processed += 1

            # Periodic logging
            if frame_idx % self.config.performance.log_interval == 0:
                self._log_progress(frame_idx, frame_times)
                frame_times = frame_times[-100:]  # keep only recent samples

            # Checkpoint saving
            if (self.config.recovery.enable_checkpoints and
                    frame_idx % self.config.recovery.checkpoint_interval == 0):
                self._save_checkpoint()

            # Memory monitoring and cleanup
            if frame_idx % 50 == 0:
                self._monitor_and_cleanup()

            # Check max frames limit
            if (self.config.input.max_frames is not None and
                    self.frames_processed >= self.config.input.max_frames):
                print(f"\n✅ Reached max frames limit ({self.config.input.max_frames})")
                break
    def _scale_frame(self, frame: np.ndarray) -> np.ndarray:
        """Scale frame according to configuration"""
        scale = self.config.processing.scale_factor
        if scale == 1.0:
            return frame

        new_width = int(frame.shape[1] * scale)
        new_height = int(frame.shape[0] * scale)

        # INTER_AREA is the preferred interpolation for downscaling
        return cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
    def _apply_masks_to_frame(self,
                              frame: np.ndarray,
                              left_masks: np.ndarray,
                              right_masks: np.ndarray) -> np.ndarray:
        """Apply masks to frame and combine results"""
        # Split frame
        left_eye, right_eye = self.stereo_manager.split_frame(frame)

        # Apply masks to each eye
        left_processed = self.sam2_processor.apply_mask_to_frame(
            left_eye, left_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )
        right_processed = self.sam2_processor.apply_mask_to_frame(
            right_eye, right_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )

        # Combine back to SBS, or return just the left eye for non-SBS output
        if self.config.output.maintain_sbs:
            return self.stereo_manager.combine_frames(left_processed, right_processed)
        else:
            return left_processed
    def _apply_continuous_correction(self,
                                     left_state: Dict,
                                     right_state: Dict,
                                     left_eye: np.ndarray,
                                     right_eye: np.ndarray,
                                     frame_idx: int) -> None:
        """Apply continuous correction to maintain tracking accuracy"""
        print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

        # Re-detect on the master eye and let SAM2 refine its prompts
        master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
        master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state

        self.sam2_processor.apply_continuous_correction(
            master_state, master_eye, frame_idx, self.detector
        )

        # Note: This is simplified - a full implementation would also transfer
        # the refined prompts to the slave eye (see the sketch below)
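    # A possible shape for the slave-eye transfer mentioned above. The helpers
    # `get_refined_prompts` and `add_prompts` are hypothetical - they do not
    # exist on SAM2StreamingProcessor as written and are shown only to make
    # the idea concrete:
    #
    #     slave_state = right_state if self.stereo_manager.master_eye == 'left' else left_state
    #     direction = 'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
    #     refined = self.sam2_processor.get_refined_prompts(master_state)          # hypothetical
    #     transferred = self.stereo_manager.transfer_detections(refined, direction)
    #     self.sam2_processor.add_prompts(slave_state, transferred, frame_idx)     # hypothetical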
    def _log_progress(self, frame_idx: int, frame_times: list) -> None:
        """Log processing progress"""
        elapsed = time.time() - self.start_time
        avg_frame_time = np.mean(frame_times) if frame_times else 0
        fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0

        # Memory stats
        memory_info = self.process.memory_info()
        memory_gb = memory_info.rss / (1024**3)

        # GPU stats if available
        gpu_stats = self.sam2_processor.get_memory_usage()

        # Progress percentage
        progress = self.frame_reader.get_progress()

        print(f"\n📊 Progress: Frame {frame_idx} ({progress:.1f}%)")
        print(f"   Speed: {fps:.1f} FPS (avg: {avg_frame_time*1000:.1f}ms/frame)")
        print(f"   Memory: {memory_gb:.1f}GB RAM", end="")
        if 'cuda_allocated_gb' in gpu_stats:
            print(f", {gpu_stats['cuda_allocated_gb']:.1f}GB VRAM")
        else:
            print()
        print(f"   Time elapsed: {elapsed/60:.1f} minutes")

        # Update performance stats
        self.performance_stats['fps'] = fps
        self.performance_stats['avg_frame_time'] = avg_frame_time
        self.performance_stats['peak_memory_gb'] = max(
            self.performance_stats['peak_memory_gb'], memory_gb
        )
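    # Example of the periodic progress block printed above (values illustrative):
    #
    #     📊 Progress: Frame 300 (12.5%)
    #        Speed: 4.2 FPS (avg: 238.1ms/frame)
    #        Memory: 18.3GB RAM, 9.1GB VRAM
    #        Time elapsed: 1.2 minutes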
    def _monitor_and_cleanup(self) -> None:
        """Monitor memory and perform cleanup if needed"""
        memory_info = self.process.memory_info()
        memory_gb = memory_info.rss / (1024**3)

        # Trigger cleanup when usage crosses 80% of the configured RAM budget
        if memory_gb > self.config.hardware.max_ram_gb * 0.8:
            print(f"\n⚠️ High memory usage ({memory_gb:.1f}GB) - running cleanup")
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
    def _save_checkpoint(self) -> None:
        """Save processing checkpoint"""
        if not self.config.recovery.enable_checkpoints:
            return

        checkpoint_dir = Path(self.config.recovery.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json"

        checkpoint_data = {
            'frame_index': self.frames_processed,
            'timestamp': time.time(),
            'input_video': self.config.input.video_path,
            'output_video': self.config.output.path,
            'config': self.config.to_dict()
        }

        with open(checkpoint_file, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)

        print(f"💾 Checkpoint saved at frame {self.frames_processed}")
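    # Example checkpoint file written above (keys match checkpoint_data,
    # values illustrative):
    #
    #     {
    #       "frame_index": 1500,
    #       "timestamp": 1700000000.0,
    #       "input_video": "input/vr180_sbs.mp4",
    #       "output_video": "output/vr180_matted.mp4",
    #       "config": { ... }
    #     }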
    def _load_checkpoint(self) -> int:
        """Load checkpoint if one exists; return the frame to resume from"""
        checkpoint_dir = Path(self.config.recovery.checkpoint_dir)
        checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json"

        if checkpoint_file.exists():
            with open(checkpoint_file, 'r') as f:
                checkpoint_data = json.load(f)

            # Only resume if the checkpoint belongs to the same input video
            if checkpoint_data['input_video'] == self.config.input.video_path:
                start_frame = checkpoint_data['frame_index']
                print(f"📂 Found checkpoint - resuming from frame {start_frame}")
                return start_frame

        return 0
    def _finalize(self) -> None:
        """Finalize processing and cleanup"""
        print("\n🏁 Finalizing processing...")

        # Close components
        if self.frame_writer:
            self.frame_writer.close()

        if self.frame_reader:
            self.frame_reader.close()

        if self.sam2_processor:
            self.sam2_processor.cleanup()

        # Print final statistics
        if self.start_time:
            total_time = time.time() - self.start_time
            print("\n📈 Final Statistics:")
            print(f"   Total frames: {self.frames_processed}")
            print(f"   Total time: {total_time/60:.1f} minutes")
            print(f"   Average FPS: {self.frames_processed/total_time:.1f}")
            print(f"   Peak memory: {self.performance_stats['peak_memory_gb']:.1f}GB")

        # Stereo consistency stats (guarded in case initialization failed early)
        if self.stereo_manager:
            stereo_stats = self.stereo_manager.get_stats()
            print("\n👀 Stereo Consistency:")
            print(f"   Corrections applied: {stereo_stats['corrections_applied']}")
            print(f"   Correction rate: {stereo_stats['correction_rate']*100:.1f}%")

        print("\n✅ Processing complete!")