# Source: test2/vr180_streaming/streaming_processor.py
# Last modified: 2025-07-27 10:37:40 -07:00 (418 lines, 17 KiB, Python)
"""
Main VR180 streaming processor - orchestrates all components for true streaming
"""
import time
import gc
import json
import psutil
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
import warnings
from .frame_reader import StreamingFrameReader
from .frame_writer import StreamingFrameWriter
from .stereo_manager import StereoConsistencyManager
from .sam2_streaming_simple import SAM2StreamingProcessor
from .detector import PersonDetector
from .config import StreamingConfig
class VR180StreamingProcessor:
    """Main processor for streaming VR180 human matting.

    Wires together the frame reader/writer, stereo manager, SAM2 streaming
    processor and person detector, then streams every input frame through
    split -> (periodic) detection -> SAM2 masks -> mask application -> write.
    """

    def __init__(self, config: "StreamingConfig"):
        self.config = config

        # Components (created in initialize())
        self.frame_reader = None
        self.frame_writer = None
        self.stereo_manager = None
        self.sam2_processor = None
        self.detector = None

        # Processing state
        self.start_time = None
        # Absolute input frame index this run started (or resumed) from.
        self.start_frame = 0
        # Number of frames written during *this* run (relative to start_frame).
        self.frames_processed = 0
        self.checkpoint_state = {}

        # Performance monitoring
        self.process = psutil.Process()
        self.performance_stats = {
            'fps': 0.0,
            'avg_frame_time': 0.0,
            'peak_memory_gb': 0.0,
            'gpu_utilization': 0.0
        }

    def initialize(self) -> None:
        """Create all pipeline components.

        When ``config.recovery.auto_resume`` is set, the reader starts at the
        frame index stored in the checkpoint (see ``_load_checkpoint``).
        """
        print("\n🚀 Initializing VR180 Streaming Processor")
        print("=" * 60)

        # Initialize frame reader
        self.start_frame = self._load_checkpoint() if self.config.recovery.auto_resume else 0
        self.frame_reader = StreamingFrameReader(
            self.config.input.video_path,
            start_frame=self.start_frame
        )

        # Get video info
        video_info = self.frame_reader.get_video_info()

        # Apply scaling to dimensions (same int() truncation as _scale_frame)
        scale = self.config.processing.scale_factor
        output_width = int(video_info['width'] * scale)
        output_height = int(video_info['height'] * scale)

        # Side-by-side layout: each eye is full height and half the width.
        # NOTE(review): assumes StereoConsistencyManager.split_frame splits
        # horizontally into two halves - confirm.
        eye_width = output_width // 2
        # _apply_masks_to_frame emits only the left eye when maintain_sbs is
        # off, so the writer must be sized for a single eye in that case.
        writer_width = output_width if self.config.output.maintain_sbs else eye_width

        # Initialize frame writer
        # NOTE(review): on resume the writer is recreated at output.path;
        # whether earlier frames survive depends on StreamingFrameWriter.
        self.frame_writer = StreamingFrameWriter(
            output_path=self.config.output.path,
            width=writer_width,
            height=output_height,
            fps=video_info['fps'],
            audio_source=self.config.input.video_path if self.config.output.maintain_sbs else None,
            video_codec=self.config.output.video_codec,
            quality_preset=self.config.output.quality_preset,
            crf=self.config.output.crf
        )

        # Initialize stereo manager
        self.stereo_manager = StereoConsistencyManager(self.config.to_dict())

        # Initialize SAM2 processor
        self.sam2_processor = SAM2StreamingProcessor(self.config.to_dict())

        # Initialize detector; warm up with single-eye dimensions
        self.detector = PersonDetector(self.config.to_dict())
        self.detector.warmup((output_height, eye_width, 3))

        print("\n✅ All components initialized successfully!")
        print(f" Input: {video_info['width']}x{video_info['height']} @ {video_info['fps']}fps")
        print(f" Output: {writer_width}x{output_height} @ {video_info['fps']}fps")
        print(f" Scale factor: {scale}")
        print(f" Starting from frame: {self.start_frame}")
        print("=" * 60 + "\n")

    def process_video(self) -> None:
        """Main processing loop.

        Initializes components, seeds tracking from the first frame, then
        streams the remaining frames. A checkpoint is saved on Ctrl-C and on
        any other exception (which is re-raised); components are always
        finalized.
        """
        try:
            self.initialize()
            self.start_time = time.time()

            # Simple SAM2 initialization (no complex state management needed)
            print("🎯 SAM2 streaming processor ready...")

            # Process first frame to establish detections
            print("🔍 Processing first frame for initial detection...")
            if not self._initialize_tracking():
                raise RuntimeError("Failed to initialize tracking - no persons detected")

            # Main streaming loop
            print("\n🎬 Starting streaming processing loop...")
            self._streaming_loop()

        except KeyboardInterrupt:
            print("\n⚠️ Processing interrupted by user")
            self._save_checkpoint()

        except Exception as e:
            print(f"\n❌ Error during processing: {e}")
            self._save_checkpoint()
            raise

        finally:
            self._finalize()

    def _detect_on_master_eye(self, left_eye: np.ndarray, right_eye: np.ndarray) -> list:
        """Run person detection on whichever eye the stereo manager marks as master."""
        master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
        return self.detector.detect_persons(master_eye)

    def _process_stereo_pair(self,
                             left_eye: np.ndarray,
                             right_eye: np.ndarray,
                             detections: list,
                             frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Feed both eyes to SAM2 and return ``(left_masks, right_masks)``.

        ``detections`` are in master-eye coordinates: the master eye receives
        them unchanged and the other eye receives the stereo-transferred copy.
        With no detections both eyes are submitted with an empty list. The
        left eye is always submitted before the right eye.
        """
        master_is_left = self.stereo_manager.master_eye == 'left'

        if detections:
            transferred = self.stereo_manager.transfer_detections(
                detections,
                'left_to_right' if master_is_left else 'right_to_left'
            )
        else:
            transferred = []

        # Route each detection set to the eye whose coordinates it is in.
        if master_is_left:
            left_detections, right_detections = detections, transferred
        else:
            left_detections, right_detections = transferred, detections

        left_masks = self.sam2_processor.add_frame_and_detections(left_eye, left_detections, frame_idx)
        right_masks = self.sam2_processor.add_frame_and_detections(right_eye, right_detections, frame_idx)
        return left_masks, right_masks

    @staticmethod
    def _is_due(frame_idx: int, interval: Optional[int]) -> bool:
        """Return True when ``frame_idx`` is a multiple of ``interval``.

        A zero/``None`` interval disables the periodic task instead of
        raising ``ZeroDivisionError``.
        """
        return bool(interval) and frame_idx % interval == 0

    def _max_frames_reached(self) -> bool:
        """Return True once this run has written ``config.input.max_frames`` frames."""
        max_frames = self.config.input.max_frames
        return max_frames is not None and self.frames_processed >= max_frames

    def _initialize_tracking(self) -> bool:
        """Initialize tracking with first frame detection.

        Returns False (after a warning) when no person is found on the master
        eye; otherwise writes the processed first frame and returns True.

        Raises:
            RuntimeError: if the reader cannot return a first frame.
        """
        # Read and process first frame
        first_frame = self.frame_reader.read_frame()
        if first_frame is None:
            raise RuntimeError("Cannot read first frame")

        # Scale frame if needed
        if self.config.processing.scale_factor != 1.0:
            first_frame = self._scale_frame(first_frame)

        # Split into eyes and detect on master eye
        left_eye, right_eye = self.stereo_manager.split_frame(first_frame)
        detections = self._detect_on_master_eye(left_eye, right_eye)

        if not detections:
            warnings.warn("No persons detected in first frame")
            return False

        print(f" Detected {len(detections)} person(s) in first frame")

        # Seed SAM2 on both eyes (master gets raw detections, other eye the transfer)
        left_masks, right_masks = self._process_stereo_pair(left_eye, right_eye, detections, 0)

        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
        self.frame_writer.write_frame(processed_frame)

        self.frames_processed = 1
        return True

    def _streaming_loop(self) -> None:
        """Process every remaining frame of the input video.

        Frame 0 was consumed by ``_initialize_tracking``, so indices here start
        at 1 and are relative to the start of this run.
        """
        frame_times = []

        # max_frames may already be satisfied by the first frame alone
        # (e.g. max_frames == 1): don't read or write any further frames.
        if self._max_frames_reached():
            print(f"\n✅ Reached max frames limit ({self.config.input.max_frames})")
            return

        for frame_idx, frame in enumerate(self.frame_reader, start=1):
            frame_start_time = time.time()

            # Scale frame if needed
            if self.config.processing.scale_factor != 1.0:
                frame = self._scale_frame(frame)

            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)

            # Periodic re-detection for continuous correction
            detections = []
            if (self.config.matting.continuous_correction and
                    self._is_due(frame_idx, self.config.matting.correction_interval)):
                print(f"\n🔄 Running YOLO detection for correction at frame {frame_idx}")
                detections = self._detect_on_master_eye(left_eye, right_eye)
                if detections:
                    print(f" Detected {len(detections)} person(s) for correction")
                else:
                    print(f" No persons detected for correction")

            # Process both eyes (with detections if this is a correction frame)
            left_masks, right_masks = self._process_stereo_pair(
                left_eye, right_eye, detections, frame_idx
            )

            # Validate stereo consistency
            right_masks = self.stereo_manager.validate_masks(
                left_masks, right_masks, frame_idx
            )

            # Apply masks and write frame
            processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
            self.frame_writer.write_frame(processed_frame)

            # Update stats
            frame_times.append(time.time() - frame_start_time)
            self.frames_processed += 1

            # Periodic logging; keep only recent times afterwards
            if self._is_due(frame_idx, self.config.performance.log_interval):
                self._log_progress(frame_idx, frame_times)
                frame_times = frame_times[-100:]

            # Checkpoint saving
            if (self.config.recovery.enable_checkpoints and
                    self._is_due(frame_idx, self.config.recovery.checkpoint_interval)):
                self._save_checkpoint()

            # Memory monitoring and cleanup
            if frame_idx % 50 == 0:
                self._monitor_and_cleanup()

            # Check max frames limit
            if self._max_frames_reached():
                print(f"\n✅ Reached max frames limit ({self.config.input.max_frames})")
                break

    def _scale_frame(self, frame: np.ndarray) -> np.ndarray:
        """Resize ``frame`` by ``config.processing.scale_factor`` (INTER_AREA)."""
        scale = self.config.processing.scale_factor
        if scale == 1.0:
            return frame

        new_width = int(frame.shape[1] * scale)
        new_height = int(frame.shape[0] * scale)

        import cv2
        return cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def _apply_masks_to_frame(self,
                              frame: np.ndarray,
                              left_masks: np.ndarray,
                              right_masks: np.ndarray) -> np.ndarray:
        """Apply per-eye masks and return the output frame.

        Returns the recombined side-by-side frame when ``output.maintain_sbs``
        is set, otherwise only the processed left eye.
        """
        # Split frame
        left_eye, right_eye = self.stereo_manager.split_frame(frame)

        # Apply masks to each eye
        left_processed = self.sam2_processor.apply_mask_to_frame(
            left_eye, left_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )
        right_processed = self.sam2_processor.apply_mask_to_frame(
            right_eye, right_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )

        # Combine back to SBS
        if self.config.output.maintain_sbs:
            return self.stereo_manager.combine_frames(left_processed, right_processed)
        # Return just left eye for non-SBS output
        return left_processed

    def _apply_continuous_correction(self,
                                     left_eye: np.ndarray,
                                     right_eye: np.ndarray,
                                     frame_idx: int) -> None:
        """Add fresh master-eye detections to SAM2 to counter tracking drift.

        Only the master eye is updated. Not called anywhere in this file;
        ``_streaming_loop`` performs its own periodic re-detection.
        """
        print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

        # Detect on master eye and add fresh detections
        master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
        detections = self.detector.detect_persons(master_eye)

        if detections:
            print(f" Adding {len(detections)} fresh detection(s) for correction")
            # Add fresh detections to help correct drift
            self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx)

        # Transfer corrections to slave eye
        # Note: This is simplified - actual implementation would transfer the refined prompts

    def _log_progress(self, frame_idx: int, frame_times: list) -> None:
        """Print throughput, memory and elapsed time, and update performance_stats."""
        elapsed = time.time() - self.start_time
        avg_frame_time = np.mean(frame_times) if frame_times else 0
        fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0

        # Memory stats
        memory_gb = self.process.memory_info().rss / (1024**3)

        # GPU stats if available
        gpu_stats = self.sam2_processor.get_memory_usage()

        # Progress percentage
        progress = self.frame_reader.get_progress()

        print(f"\n📊 Progress: Frame {frame_idx} ({progress:.1f}%)")
        print(f" Speed: {fps:.1f} FPS (avg: {avg_frame_time*1000:.1f}ms/frame)")
        print(f" Memory: {memory_gb:.1f}GB RAM", end="")
        if 'cuda_allocated_gb' in gpu_stats:
            print(f", {gpu_stats['cuda_allocated_gb']:.1f}GB VRAM")
        else:
            print()
        print(f" Time elapsed: {elapsed/60:.1f} minutes")

        # Update performance stats
        self.performance_stats['fps'] = fps
        self.performance_stats['avg_frame_time'] = avg_frame_time
        self.performance_stats['peak_memory_gb'] = max(
            self.performance_stats['peak_memory_gb'], memory_gb
        )

    def _monitor_and_cleanup(self) -> None:
        """Record peak RSS and run GC / CUDA cache cleanup above 80% of max_ram_gb."""
        memory_gb = self.process.memory_info().rss / (1024**3)
        # Sample peak memory here too, not only at log intervals.
        self.performance_stats['peak_memory_gb'] = max(
            self.performance_stats['peak_memory_gb'], memory_gb
        )

        # Check if approaching limits
        if memory_gb > self.config.hardware.max_ram_gb * 0.8:
            print(f"\n⚠️ High memory usage ({memory_gb:.1f}GB) - running cleanup")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

    def _checkpoint_path(self) -> Path:
        """Path of the JSON checkpoint for the configured output file."""
        checkpoint_dir = Path(self.config.recovery.checkpoint_dir)
        return checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json"

    def _save_checkpoint(self) -> None:
        """Save a resumable checkpoint (no-op unless checkpoints are enabled).

        ``frame_index`` is absolute (``start_frame + frames_processed``) so a
        resumed run that is interrupted again resumes at the right frame.
        The file is written to a temporary sibling and then moved into place
        so an interruption mid-write cannot leave a truncated checkpoint.
        """
        if not self.config.recovery.enable_checkpoints:
            return

        checkpoint_file = self._checkpoint_path()
        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

        frame_index = self.start_frame + self.frames_processed
        checkpoint_data = {
            'frame_index': frame_index,
            'timestamp': time.time(),
            # str() so Path-valued config entries remain JSON-serialisable
            'input_video': str(self.config.input.video_path),
            'output_video': str(self.config.output.path),
            'config': self.config.to_dict()
        }

        tmp_file = checkpoint_file.with_name(checkpoint_file.name + '.tmp')
        with open(tmp_file, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)
        tmp_file.replace(checkpoint_file)

        print(f"💾 Checkpoint saved at frame {frame_index}")

    def _load_checkpoint(self) -> int:
        """Return the frame index to resume from, or 0.

        Returns 0 when no checkpoint exists, when it was written for a
        different input video, or (with a warning) when it is unreadable.
        """
        checkpoint_file = self._checkpoint_path()
        if not checkpoint_file.exists():
            return 0

        try:
            with open(checkpoint_file, 'r') as f:
                checkpoint_data = json.load(f)
            if checkpoint_data.get('input_video') != str(self.config.input.video_path):
                return 0
            start_frame = int(checkpoint_data['frame_index'])
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
            # Corrupt/partial checkpoint: restart from scratch rather than crash.
            warnings.warn(f"Ignoring unreadable checkpoint {checkpoint_file}: {exc}")
            return 0

        print(f"📂 Found checkpoint - resuming from frame {start_frame}")
        return start_frame

    def _finalize(self) -> None:
        """Close components and print final statistics.

        Safe to call after a partial initialize(): components that were never
        created are skipped.
        """
        print("\n🏁 Finalizing processing...")

        # Close components
        if self.frame_writer is not None:
            self.frame_writer.close()
        if self.frame_reader is not None:
            self.frame_reader.close()
        if self.sam2_processor is not None:
            self.sam2_processor.cleanup()

        # Print final statistics
        if self.start_time:
            total_time = time.time() - self.start_time
            average_fps = self.frames_processed / total_time if total_time > 0 else 0.0
            print(f"\n📈 Final Statistics:")
            print(f" Total frames: {self.frames_processed}")
            print(f" Total time: {total_time/60:.1f} minutes")
            print(f" Average FPS: {average_fps:.1f}")
            print(f" Peak memory: {self.performance_stats['peak_memory_gb']:.1f}GB")

        # Stereo consistency stats (stereo_manager is None if init failed early)
        if self.stereo_manager is not None:
            stereo_stats = self.stereo_manager.get_stats()
            print(f"\n👀 Stereo Consistency:")
            print(f" Corrections applied: {stereo_stats['corrections_applied']}")
            print(f" Correction rate: {stereo_stats['correction_rate']*100:.1f}%")

        print("\n✅ Processing complete!")