streaming part1

vr180_streaming/streaming_processor.py (new file, 418 lines)

@@ -0,0 +1,418 @@
"""
Main VR180 streaming processor - orchestrates all components for true streaming
"""

import time
import gc
import json
import cv2
import psutil
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
import warnings

from .frame_reader import StreamingFrameReader
from .frame_writer import StreamingFrameWriter
from .stereo_manager import StereoConsistencyManager
from .sam2_streaming import SAM2StreamingProcessor
from .detector import PersonDetector
from .config import StreamingConfig


class VR180StreamingProcessor:
    """Main processor for streaming VR180 human matting"""

    def __init__(self, config: StreamingConfig):
        self.config = config

        # Initialize components
        self.frame_reader = None
        self.frame_writer = None
        self.stereo_manager = None
        self.sam2_processor = None
        self.detector = None

        # Processing state
        self.start_time = None
        self.frames_processed = 0
        self.checkpoint_state = {}

        # Performance monitoring
        self.process = psutil.Process()
        self.performance_stats = {
            'fps': 0.0,
            'avg_frame_time': 0.0,
            'peak_memory_gb': 0.0,
            'gpu_utilization': 0.0
        }

    def initialize(self) -> None:
        """Initialize all components"""
        print("\n🚀 Initializing VR180 Streaming Processor")
        print("=" * 60)

        # Initialize frame reader
        start_frame = self._load_checkpoint() if self.config.recovery.auto_resume else 0
        self.frame_reader = StreamingFrameReader(
            self.config.input.video_path,
            start_frame=start_frame
        )

        # Get video info
        video_info = self.frame_reader.get_video_info()

        # Apply scaling to dimensions
        scale = self.config.processing.scale_factor
        output_width = int(video_info['width'] * scale)
        output_height = int(video_info['height'] * scale)

        # Initialize frame writer
        self.frame_writer = StreamingFrameWriter(
            output_path=self.config.output.path,
            width=output_width,
            height=output_height,
            fps=video_info['fps'],
            audio_source=self.config.input.video_path if self.config.output.maintain_sbs else None,
            video_codec=self.config.output.video_codec,
            quality_preset=self.config.output.quality_preset,
            crf=self.config.output.crf
        )

        # Initialize stereo manager
        self.stereo_manager = StereoConsistencyManager(self.config.to_dict())

        # Initialize SAM2 processor
        self.sam2_processor = SAM2StreamingProcessor(self.config.to_dict())

        # Initialize detector; warm up with single-eye dimensions (SBS frames
        # are split left|right, so one eye is full height and half width)
        self.detector = PersonDetector(self.config.to_dict())
        self.detector.warmup((output_height, output_width // 2, 3))

        print("\n✅ All components initialized successfully!")
        print(f"   Input: {video_info['width']}x{video_info['height']} @ {video_info['fps']}fps")
        print(f"   Output: {output_width}x{output_height} @ {video_info['fps']}fps")
        print(f"   Scale factor: {scale}")
        print(f"   Starting from frame: {start_frame}")
        print("=" * 60 + "\n")

    def process_video(self) -> None:
        """Main processing loop"""
        try:
            self.initialize()
            self.start_time = time.time()

            # Initialize SAM2 states for both eyes
            print("🎯 Initializing SAM2 streaming states...")
            left_state = self.sam2_processor.init_state(
                self.config.input.video_path,
                eye='left'
            )
            right_state = self.sam2_processor.init_state(
                self.config.input.video_path,
                eye='right'
            )

            # Process first frame to establish detections
            print("🔍 Processing first frame for initial detection...")
            if not self._initialize_tracking(left_state, right_state):
                raise RuntimeError("Failed to initialize tracking - no persons detected")

            # Main streaming loop
            print("\n🎬 Starting streaming processing loop...")
            self._streaming_loop(left_state, right_state)

        except KeyboardInterrupt:
            print("\n⚠️ Processing interrupted by user")
            self._save_checkpoint()

        except Exception as e:
            print(f"\n❌ Error during processing: {e}")
            self._save_checkpoint()
            raise

        finally:
            self._finalize()

    def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
        """Initialize tracking with first frame detection"""
        # Read and process first frame
        first_frame = self.frame_reader.read_frame()
        if first_frame is None:
            raise RuntimeError("Cannot read first frame")

        # Scale frame if needed
        if self.config.processing.scale_factor != 1.0:
            first_frame = self._scale_frame(first_frame)

        # Split into eyes
        left_eye, right_eye = self.stereo_manager.split_frame(first_frame)

        # Detect on master eye
        master_is_left = self.stereo_manager.master_eye == 'left'
        master_eye = left_eye if master_is_left else right_eye
        detections = self.detector.detect_persons(master_eye)

        if not detections:
            warnings.warn("No persons detected in first frame")
            return False

        print(f"   Detected {len(detections)} person(s) in first frame")

        # Seed the master eye with the raw detections
        master_state = left_state if master_is_left else right_state
        slave_state = right_state if master_is_left else left_state
        self.sam2_processor.add_detections(master_state, detections, frame_idx=0)

        # Transfer detections to the slave eye
        transferred_detections = self.stereo_manager.transfer_detections(
            detections,
            'left_to_right' if master_is_left else 'right_to_left'
        )
        self.sam2_processor.add_detections(slave_state, transferred_detections, frame_idx=0)

        # Process and write first frame
        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)

        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
        self.frame_writer.write_frame(processed_frame)

        self.frames_processed = 1
        return True

    def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
        """Main streaming processing loop"""
        frame_times = []

        # Start from frame 1 (frame 0 was already processed during initialization)
        for frame_idx, frame in enumerate(self.frame_reader, start=1):
            frame_start_time = time.time()

            # Scale frame if needed
            if self.config.processing.scale_factor != 1.0:
                frame = self._scale_frame(frame)

            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)

            # Propagate masks for both eyes
            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
                left_state, right_state, left_eye, right_eye, frame_idx
            )

            # Validate stereo consistency (returns the corrected slave-eye masks)
            right_masks = self.stereo_manager.validate_masks(
                left_masks, right_masks, frame_idx
            )

            # Apply continuous correction if enabled
            if (self.config.matting.continuous_correction and
                    frame_idx % self.config.matting.correction_interval == 0):
                self._apply_continuous_correction(
                    left_state, right_state, left_eye, right_eye, frame_idx
                )

            # Apply masks and write frame
            processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
            self.frame_writer.write_frame(processed_frame)

            # Update stats
            frame_time = time.time() - frame_start_time
            frame_times.append(frame_time)
            self.frames_processed += 1

            # Periodic logging and cleanup
            if frame_idx % self.config.performance.log_interval == 0:
                self._log_progress(frame_idx, frame_times)
                frame_times = frame_times[-100:]  # Keep only recent times

            # Checkpoint saving
            if (self.config.recovery.enable_checkpoints and
                    frame_idx % self.config.recovery.checkpoint_interval == 0):
                self._save_checkpoint()

            # Memory monitoring and cleanup
            if frame_idx % 50 == 0:
                self._monitor_and_cleanup()

            # Check max frames limit
            if (self.config.input.max_frames is not None and
                    self.frames_processed >= self.config.input.max_frames):
                print(f"\n✅ Reached max frames limit ({self.config.input.max_frames})")
                break

    def _scale_frame(self, frame: np.ndarray) -> np.ndarray:
        """Scale frame according to configuration"""
        scale = self.config.processing.scale_factor
        if scale == 1.0:
            return frame

        new_width = int(frame.shape[1] * scale)
        new_height = int(frame.shape[0] * scale)

        return cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)

    def _apply_masks_to_frame(self,
                              frame: np.ndarray,
                              left_masks: np.ndarray,
                              right_masks: np.ndarray) -> np.ndarray:
        """Apply masks to frame and combine results"""
        # Split frame
        left_eye, right_eye = self.stereo_manager.split_frame(frame)

        # Apply masks to each eye
        left_processed = self.sam2_processor.apply_mask_to_frame(
            left_eye, left_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )

        right_processed = self.sam2_processor.apply_mask_to_frame(
            right_eye, right_masks,
            output_format=self.config.output.format,
            background_color=self.config.output.background_color
        )

        # Combine back to SBS
        if self.config.output.maintain_sbs:
            return self.stereo_manager.combine_frames(left_processed, right_processed)
        else:
            # Return just the left eye for non-SBS output
            return left_processed

    def _apply_continuous_correction(self,
                                     left_state: Dict,
                                     right_state: Dict,
                                     left_eye: np.ndarray,
                                     right_eye: np.ndarray,
                                     frame_idx: int) -> None:
        """Apply continuous correction to maintain tracking accuracy"""
        print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

        # Detect on master eye
        master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
        master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state

        self.sam2_processor.apply_continuous_correction(
            master_state, master_eye, frame_idx, self.detector
        )

        # Transfer corrections to slave eye
        # Note: this is simplified - the actual implementation would transfer
        # the refined prompts as well
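        #
        # One possible shape for that transfer, sketched in comments only; it
        # mirrors the first-frame flow (detect on master, map across eyes,
        # re-seed the slave state). `get_refined_prompts` is a hypothetical
        # accessor, not part of the current SAM2StreamingProcessor API:
        #
        #   refined = self.sam2_processor.get_refined_prompts(master_state)  # hypothetical
        #   direction = ('left_to_right' if self.stereo_manager.master_eye == 'left'
        #                else 'right_to_left')
        #   mirrored = self.stereo_manager.transfer_detections(refined, direction)
        #   slave_state = right_state if self.stereo_manager.master_eye == 'left' else left_state
        #   self.sam2_processor.add_detections(slave_state, mirrored, frame_idx=frame_idx)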

    def _log_progress(self, frame_idx: int, frame_times: list) -> None:
        """Log processing progress"""
        elapsed = time.time() - self.start_time
        avg_frame_time = np.mean(frame_times) if frame_times else 0.0
        fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0.0

        # Memory stats
        memory_info = self.process.memory_info()
        memory_gb = memory_info.rss / (1024**3)

        # GPU stats if available
        gpu_stats = self.sam2_processor.get_memory_usage()

        # Progress percentage
        progress = self.frame_reader.get_progress()

        print(f"\n📊 Progress: Frame {frame_idx} ({progress:.1f}%)")
        print(f"   Speed: {fps:.1f} FPS (avg: {avg_frame_time*1000:.1f}ms/frame)")
        print(f"   Memory: {memory_gb:.1f}GB RAM", end="")
        if 'cuda_allocated_gb' in gpu_stats:
            print(f", {gpu_stats['cuda_allocated_gb']:.1f}GB VRAM")
        else:
            print()
        print(f"   Time elapsed: {elapsed/60:.1f} minutes")

        # Update performance stats
        self.performance_stats['fps'] = fps
        self.performance_stats['avg_frame_time'] = avg_frame_time
        self.performance_stats['peak_memory_gb'] = max(
            self.performance_stats['peak_memory_gb'], memory_gb
        )

    def _monitor_and_cleanup(self) -> None:
        """Monitor memory and perform cleanup if needed"""
        memory_info = self.process.memory_info()
        memory_gb = memory_info.rss / (1024**3)

        # Check if approaching limits
        if memory_gb > self.config.hardware.max_ram_gb * 0.8:
            print(f"\n⚠️ High memory usage ({memory_gb:.1f}GB) - running cleanup")
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

    def _save_checkpoint(self) -> None:
        """Save processing checkpoint"""
        if not self.config.recovery.enable_checkpoints:
            return

        checkpoint_dir = Path(self.config.recovery.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json"

        checkpoint_data = {
            'frame_index': self.frames_processed,
            'timestamp': time.time(),
            'input_video': self.config.input.video_path,
            'output_video': self.config.output.path,
            'config': self.config.to_dict()
        }

        with open(checkpoint_file, 'w') as f:
            json.dump(checkpoint_data, f, indent=2)

        print(f"💾 Checkpoint saved at frame {self.frames_processed}")
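
    # For reference, a checkpoint produced by _save_checkpoint looks roughly
    # like this (values illustrative):
    #
    #   {
    #     "frame_index": 1500,
    #     "timestamp": 1712345678.9,
    #     "input_video": "input/vr180_sbs.mp4",
    #     "output_video": "output/vr180_matted.mp4",
    #     "config": { ... }
    #   }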

    def _load_checkpoint(self) -> int:
        """Load checkpoint if one exists"""
        checkpoint_dir = Path(self.config.recovery.checkpoint_dir)
        checkpoint_file = checkpoint_dir / f"{Path(self.config.output.path).stem}_checkpoint.json"

        if checkpoint_file.exists():
            with open(checkpoint_file, 'r') as f:
                checkpoint_data = json.load(f)

            if checkpoint_data['input_video'] == self.config.input.video_path:
                start_frame = checkpoint_data['frame_index']
                print(f"📂 Found checkpoint - resuming from frame {start_frame}")
                return start_frame

        return 0

    def _finalize(self) -> None:
        """Finalize processing and cleanup"""
        print("\n🏁 Finalizing processing...")

        # Close components
        if self.frame_writer:
            self.frame_writer.close()

        if self.frame_reader:
            self.frame_reader.close()

        if self.sam2_processor:
            self.sam2_processor.cleanup()

        # Print final statistics
        if self.start_time:
            total_time = time.time() - self.start_time
            print("\n📈 Final Statistics:")
            print(f"   Total frames: {self.frames_processed}")
            print(f"   Total time: {total_time/60:.1f} minutes")
            print(f"   Average FPS: {self.frames_processed/total_time:.1f}")
            print(f"   Peak memory: {self.performance_stats['peak_memory_gb']:.1f}GB")

            # Stereo consistency stats
            stereo_stats = self.stereo_manager.get_stats()
            print("\n👀 Stereo Consistency:")
            print(f"   Corrections applied: {stereo_stats['corrections_applied']}")
            print(f"   Correction rate: {stereo_stats['correction_rate']*100:.1f}%")

        print("\n✅ Processing complete!")