first commit
vr180_matting/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""VR180 Human Matting with Det-SAM2"""

__version__ = "0.1.0"

vr180_matting/config.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import yaml
from dataclasses import dataclass
from typing import List, Optional, Union
from pathlib import Path


@dataclass
class InputConfig:
    video_path: str


@dataclass
class ProcessingConfig:
    scale_factor: float = 0.5
    chunk_size: int = 900
    overlap_frames: int = 60


@dataclass
class DetectionConfig:
    confidence_threshold: float = 0.7
    model: str = "yolov8n"


@dataclass
class MattingConfig:
    use_disparity_mapping: bool = True
    memory_offload: bool = True
    fp16: bool = True


@dataclass
class OutputConfig:
    path: str
    format: str = "alpha"
    background_color: Optional[List[int]] = None
    maintain_sbs: bool = True

    def __post_init__(self):
        if self.background_color is None:
            self.background_color = [0, 255, 0]


@dataclass
class HardwareConfig:
    device: str = "cuda"
    max_vram_gb: int = 10


@dataclass
class VR180Config:
    input: InputConfig
    processing: ProcessingConfig
    detection: DetectionConfig
    matting: MattingConfig
    output: OutputConfig
    hardware: HardwareConfig

    @classmethod
    def from_yaml(cls, config_path: Union[str, Path]) -> "VR180Config":
        """Load configuration from YAML file"""
        with open(config_path, 'r') as f:
            data = yaml.safe_load(f)

        return cls(
            input=InputConfig(**data['input']),
            processing=ProcessingConfig(**data.get('processing', {})),
            detection=DetectionConfig(**data.get('detection', {})),
            matting=MattingConfig(**data.get('matting', {})),
            output=OutputConfig(**data['output']),
            hardware=HardwareConfig(**data.get('hardware', {}))
        )

    def to_yaml(self, config_path: Union[str, Path]) -> None:
        """Save configuration to YAML file"""
        data = {
            'input': {
                'video_path': self.input.video_path
            },
            'processing': {
                'scale_factor': self.processing.scale_factor,
                'chunk_size': self.processing.chunk_size,
                'overlap_frames': self.processing.overlap_frames
            },
            'detection': {
                'confidence_threshold': self.detection.confidence_threshold,
                'model': self.detection.model
            },
            'matting': {
                'use_disparity_mapping': self.matting.use_disparity_mapping,
                'memory_offload': self.matting.memory_offload,
                'fp16': self.matting.fp16
            },
            'output': {
                'path': self.output.path,
                'format': self.output.format,
                'background_color': self.output.background_color,
                'maintain_sbs': self.output.maintain_sbs
            },
            'hardware': {
                'device': self.hardware.device,
                'max_vram_gb': self.hardware.max_vram_gb
            }
        }

        with open(config_path, 'w') as f:
            yaml.dump(data, f, default_flow_style=False, indent=2)

    def validate(self) -> List[str]:
        """Validate configuration and return a list of errors"""
        errors = []

        if not Path(self.input.video_path).exists():
            errors.append(f"Input video path does not exist: {self.input.video_path}")

        if not 0.1 <= self.processing.scale_factor <= 1.0:
            errors.append("Scale factor must be between 0.1 and 1.0")

        if self.processing.chunk_size < 0:
            errors.append("Chunk size must be non-negative (0 for full video)")

        if not 0.1 <= self.detection.confidence_threshold <= 1.0:
            errors.append("Confidence threshold must be between 0.1 and 1.0")

        if self.detection.model not in ["yolov8n", "yolov8s", "yolov8m", "yolov8l", "yolov8x"]:
            errors.append(f"Unsupported YOLO model: {self.detection.model}")

        if self.output.format not in ["alpha", "greenscreen"]:
            errors.append("Output format must be 'alpha' or 'greenscreen'")

        if len(self.output.background_color) != 3:
            errors.append("Background color must be an RGB list with 3 values")

        if not all(0 <= c <= 255 for c in self.output.background_color):
            errors.append("Background color values must be between 0 and 255")

        if self.hardware.device not in ["cuda", "cpu"]:
            errors.append("Device must be 'cuda' or 'cpu'")

        if self.hardware.max_vram_gb <= 0:
            errors.append("Max VRAM must be positive")

        return errors
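
A minimal usage sketch for the loader above, assuming the package is importable as vr180_matting and that ./config.yaml is a hypothetical config written by hand or via --generate-config:

# Hedged sketch: load and validate a config (config.yaml is a hypothetical path).
from vr180_matting.config import VR180Config

config = VR180Config.from_yaml("config.yaml")
errors = config.validate()
if errors:
    for err in errors:
        print(f"config error: {err}")
else:
    print(f"scale={config.processing.scale_factor}, device={config.hardware.device}")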

vr180_matting/detector.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import torch
import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any, Optional
import cv2


class YOLODetector:
    """YOLOv8-based person detector for automatic SAM2 prompting"""

    def __init__(self, model_name: str = "yolov8n", confidence_threshold: float = 0.7, device: str = "cuda"):
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the YOLOv8 model"""
        try:
            self.model = YOLO(f"{self.model_name}.pt")
            if self.device == "cuda" and torch.cuda.is_available():
                self.model.to("cuda")
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")

    def detect_persons(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect persons in a frame and return bounding boxes

        Args:
            frame: Input frame (H, W, 3)

        Returns:
            List of detection dictionaries with bbox, confidence, and class info
        """
        if self.model is None:
            raise RuntimeError("YOLO model not loaded")

        results = self.model(frame, verbose=False)
        detections = []

        for result in results:
            boxes = result.boxes
            if boxes is not None:
                for box in boxes:
                    # Only keep person detections (class 0 in COCO)
                    if int(box.cls) == 0 and float(box.conf) >= self.confidence_threshold:
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                        confidence = float(box.conf)

                        detection = {
                            'bbox': [int(x1), int(y1), int(x2), int(y2)],
                            'confidence': confidence,
                            'class': 'person',
                            'area': (x2 - x1) * (y2 - y1)
                        }
                        detections.append(detection)

        # Sort by confidence (highest first)
        detections.sort(key=lambda x: x['confidence'], reverse=True)
        return detections

    def convert_to_sam_prompts(self, detections: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Convert YOLO detections to SAM2 box prompts

        Args:
            detections: List of detection dictionaries

        Returns:
            Tuple of (box_prompts, labels) for SAM2
        """
        if not detections:
            return np.array([]), np.array([])

        box_prompts = []
        labels = []

        for detection in detections:
            bbox = detection['bbox']
            box_prompts.append(bbox)
            labels.append(1)  # Positive prompt

        return np.array(box_prompts), np.array(labels)

    def visualize_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
        """
        Draw detection boxes on a frame for debugging

        Args:
            frame: Input frame
            detections: List of detections

        Returns:
            Frame with drawn bounding boxes
        """
        vis_frame = frame.copy()

        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']
            confidence = detection['confidence']

            # Draw bounding box
            cv2.rectangle(vis_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Draw confidence score
            label = f"Person: {confidence:.2f}"
            label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(vis_frame, (x1, y1 - label_size[1] - 10),
                          (x1 + label_size[0], y1), (0, 255, 0), -1)
            cv2.putText(vis_frame, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)

        return vis_frame

    def get_largest_person(self, detections: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Get the largest detected person (by bounding box area)"""
        if not detections:
            return None

        return max(detections, key=lambda x: x['area'])

    def filter_by_size(self, detections: List[Dict[str, Any]], min_area: int = 1000) -> List[Dict[str, Any]]:
        """Filter detections by minimum bounding box area"""
        return [d for d in detections if d['area'] >= min_area]
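
A short single-frame sketch of the detector-to-prompt handoff, assuming ultralytics is installed, YOLO weights can download on first use, and frame.jpg is a hypothetical test image:

# Hedged sketch: detect persons in one frame and convert them to SAM2 box prompts.
import cv2
from vr180_matting.detector import YOLODetector

detector = YOLODetector(model_name="yolov8n", confidence_threshold=0.7, device="cpu")
frame = cv2.imread("frame.jpg")  # BGR frame, as used elsewhere in the pipeline
detections = detector.filter_by_size(detector.detect_persons(frame), min_area=1000)
boxes, labels = detector.convert_to_sam_prompts(detections)
print(f"{len(detections)} persons -> box prompts with shape {boxes.shape}")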

vr180_matting/main.py (new file, 240 lines)
@@ -0,0 +1,240 @@
#!/usr/bin/env python3
"""
VR180 Human Matting with Det-SAM2
Main CLI entry point
"""

import argparse
import sys
from pathlib import Path
import traceback

from .config import VR180Config
from .vr180_processor import VR180Processor


def create_parser() -> argparse.ArgumentParser:
    """Create the command line argument parser"""
    parser = argparse.ArgumentParser(
        description="VR180 Human Matting with Det-SAM2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process video with default config
  vr180-matting config.yaml

  # Process with custom output path
  vr180-matting config.yaml --output /path/to/output.mp4

  # Generate example config
  vr180-matting --generate-config config_example.yaml

  # Process with different scale factor
  vr180-matting config.yaml --scale 0.25
        """
    )

    parser.add_argument(
        "config",
        nargs="?",
        help="Path to YAML configuration file"
    )

    parser.add_argument(
        "--generate-config",
        metavar="PATH",
        help="Generate example configuration file at specified path"
    )

    parser.add_argument(
        "--output", "-o",
        metavar="PATH",
        help="Override output path from config"
    )

    parser.add_argument(
        "--scale",
        type=float,
        metavar="FACTOR",
        help="Override scale factor (0.25, 0.5, 1.0)"
    )

    parser.add_argument(
        "--chunk-size",
        type=int,
        metavar="FRAMES",
        help="Override chunk size in frames (0 for auto)"
    )

    parser.add_argument(
        "--device",
        choices=["cuda", "cpu"],
        help="Override processing device"
    )

    parser.add_argument(
        "--format",
        choices=["alpha", "greenscreen"],
        help="Override output format"
    )

    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate configuration without processing"
    )

    return parser


def generate_example_config(output_path: str) -> None:
    """Generate an example configuration file"""
    config_content = '''input:
  video_path: "path/to/input.mp4"

processing:
  scale_factor: 0.5        # 0.25, 0.5, 1.0
  chunk_size: 900          # frames, 0 for full video
  overlap_frames: 60       # for chunked processing

detection:
  confidence_threshold: 0.7
  model: "yolov8n"         # yolov8n, yolov8s, yolov8m

matting:
  use_disparity_mapping: true
  memory_offload: true
  fp16: true

output:
  path: "path/to/output/"
  format: "alpha"                # "alpha" or "greenscreen"
  background_color: [0, 255, 0]  # for greenscreen
  maintain_sbs: true             # keep side-by-side format

hardware:
  device: "cuda"
  max_vram_gb: 10          # RTX 3080 limit
'''

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w') as f:
        f.write(config_content)

    print(f"Generated example configuration: {output_path}")
    print("Edit the configuration file and run:")
    print(f"  vr180-matting {output_path}")


def validate_config(config: VR180Config, verbose: bool = False) -> bool:
    """Validate the configuration and print any errors"""
    errors = config.validate()

    if errors:
        print("Configuration validation failed:")
        for error in errors:
            print(f"  ❌ {error}")
        return False

    if verbose:
        print("Configuration validation passed ✅")
        print(f"Input video: {config.input.video_path}")
        print(f"Output path: {config.output.path}")
        print(f"Scale factor: {config.processing.scale_factor}")
        print(f"Device: {config.hardware.device}")
        print(f"Output format: {config.output.format}")

    return True


def apply_cli_overrides(config: VR180Config, args: argparse.Namespace) -> None:
    """Apply command line overrides to the configuration"""
    if args.output:
        config.output.path = args.output

    if args.scale is not None:
        if not 0.1 <= args.scale <= 1.0:
            raise ValueError("Scale factor must be between 0.1 and 1.0")
        config.processing.scale_factor = args.scale

    if args.chunk_size is not None:
        if args.chunk_size < 0:
            raise ValueError("Chunk size must be non-negative")
        config.processing.chunk_size = args.chunk_size

    if args.device:
        config.hardware.device = args.device

    if args.format:
        config.output.format = args.format


def main() -> int:
    """Main entry point"""
    parser = create_parser()
    args = parser.parse_args()

    try:
        # Handle config generation
        if args.generate_config:
            generate_example_config(args.generate_config)
            return 0

        # Require a config file for processing
        if not args.config:
            parser.print_help()
            print("\nError: Configuration file required")
            return 1

        # Load configuration
        config_path = Path(args.config)
        if not config_path.exists():
            print(f"Error: Configuration file not found: {config_path}")
            return 1

        print(f"Loading configuration from {config_path}")
        config = VR180Config.from_yaml(config_path)

        # Apply CLI overrides
        apply_cli_overrides(config, args)

        # Validate configuration
        if not validate_config(config, verbose=args.verbose):
            return 1

        # Dry run mode
        if args.dry_run:
            print("Dry run completed successfully ✅")
            return 0

        # Initialize processor
        print("Initializing VR180 processor...")
        processor = VR180Processor(config)

        # Process video
        processor.process_video()

        print("✅ Processing completed successfully!")
        return 0

    except KeyboardInterrupt:
        print("\n⚠️ Processing interrupted by user")
        return 130

    except Exception as e:
        print(f"❌ Error: {e}")
        if args.verbose:
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())
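
The same generate, load, override, validate flow the CLI wires together, driven programmatically as a hedged sketch; it assumes the package's dependencies are installed, and the paths and override values are illustrative only:

# Hedged sketch: exercise the config round trip without running the processor.
import argparse
from vr180_matting.main import generate_example_config, apply_cli_overrides, validate_config
from vr180_matting.config import VR180Config

generate_example_config("config_example.yaml")
config = VR180Config.from_yaml("config_example.yaml")
overrides = argparse.Namespace(output="out/", scale=0.25, chunk_size=None, device="cpu", format=None)
apply_cli_overrides(config, overrides)
validate_config(config, verbose=True)  # reports errors (the placeholder video path does not exist)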

vr180_matting/memory_manager.py (new file, 241 lines)
@@ -0,0 +1,241 @@
import torch
import psutil
import gc
import warnings
from typing import Optional, Dict, Any
from contextlib import contextmanager
import time


class VRAMManager:
    """VRAM and memory optimization manager"""

    def __init__(self, max_vram_gb: float = 10.0, device: str = "cuda"):
        self.max_vram_gb = max_vram_gb
        self.device = device
        self.max_vram_bytes = max_vram_gb * 1024**3

        # Memory tracking
        self.memory_stats = {
            'peak_allocated': 0,
            'peak_reserved': 0,
            'allocations': 0,
            'deallocations': 0
        }

        self._check_device()

    def _check_device(self):
        """Check if CUDA is available and get device info"""
        if self.device == "cuda":
            if not torch.cuda.is_available():
                warnings.warn("CUDA not available, falling back to CPU")
                self.device = "cpu"
                return

            device_props = torch.cuda.get_device_properties(0)
            total_memory = device_props.total_memory

            print(f"GPU: {device_props.name}")
            print(f"Total VRAM: {total_memory / 1024**3:.1f} GB")
            print(f"Max VRAM limit: {self.max_vram_gb:.1f} GB")

            if self.max_vram_bytes > total_memory * 0.9:
                warnings.warn(f"Max VRAM limit ({self.max_vram_gb:.1f} GB) is close to total VRAM")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage statistics"""
        stats = {}

        if self.device == "cuda" and torch.cuda.is_available():
            stats['vram_allocated'] = torch.cuda.memory_allocated() / 1024**3
            stats['vram_reserved'] = torch.cuda.memory_reserved() / 1024**3
            stats['vram_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                  torch.cuda.memory_reserved()) / 1024**3
        else:
            stats['vram_allocated'] = 0
            stats['vram_reserved'] = 0
            stats['vram_free'] = 0

        # System RAM
        ram_info = psutil.virtual_memory()
        stats['ram_used'] = ram_info.used / 1024**3
        stats['ram_available'] = ram_info.available / 1024**3
        stats['ram_percent'] = ram_info.percent

        return stats

    def check_memory_available(self, required_gb: float) -> bool:
        """Check if enough memory is available for an operation"""
        stats = self.get_memory_usage()

        if self.device == "cuda":
            return stats['vram_free'] >= required_gb
        else:
            return stats['ram_available'] >= required_gb

    def cleanup_memory(self, aggressive: bool = False):
        """Clean up memory"""
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()

            if aggressive:
                torch.cuda.ipc_collect()
                torch.cuda.synchronize()

        # Python garbage collection
        gc.collect()

        if aggressive:
            # Force garbage collection multiple times
            for _ in range(3):
                gc.collect()

    def estimate_processing_memory(self,
                                   frame_height: int,
                                   frame_width: int,
                                   num_frames: int,
                                   fp16: bool = True) -> float:
        """
        Estimate memory requirements for processing

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            num_frames: Number of frames to process
            fp16: Whether using FP16 precision

        Returns:
            Estimated memory usage in GB
        """
        bytes_per_pixel = 2 if fp16 else 4  # FP16 vs FP32

        # Estimate memory components
        frame_memory = frame_height * frame_width * 3 * bytes_per_pixel * num_frames
        model_memory = 2.0 * 1024**3  # ~2GB for SAM2 model
        yolo_memory = 0.5 * 1024**3   # ~0.5GB for YOLO
        working_memory = frame_memory * 2  # Working space for masks, etc.

        total_memory = frame_memory + model_memory + yolo_memory + working_memory

        return total_memory / 1024**3

    def get_optimal_chunk_size(self,
                               frame_height: int,
                               frame_width: int,
                               target_memory_gb: Optional[float] = None,
                               fp16: bool = True) -> int:
        """
        Calculate optimal chunk size for processing

        Args:
            frame_height: Frame height in pixels
            frame_width: Frame width in pixels
            target_memory_gb: Target memory usage (defaults to 80% of max VRAM)
            fp16: Whether using FP16 precision

        Returns:
            Optimal number of frames per chunk
        """
        if target_memory_gb is None:
            target_memory_gb = self.max_vram_gb * 0.8

        # Binary search for optimal chunk size
        min_frames = 1
        max_frames = 1000
        optimal_frames = min_frames

        while min_frames <= max_frames:
            mid_frames = (min_frames + max_frames) // 2
            estimated_memory = self.estimate_processing_memory(
                frame_height, frame_width, mid_frames, fp16
            )

            if estimated_memory <= target_memory_gb:
                optimal_frames = mid_frames
                min_frames = mid_frames + 1
            else:
                max_frames = mid_frames - 1

        return max(optimal_frames, 1)

    @contextmanager
    def memory_monitor(self, operation_name: str = "operation"):
        """Context manager for monitoring memory usage during operations"""
        start_stats = self.get_memory_usage()
        start_time = time.time()

        print(f"Starting {operation_name}")
        print(f"Initial VRAM: {start_stats['vram_allocated']:.2f} GB allocated, "
              f"{start_stats['vram_free']:.2f} GB free")

        try:
            yield self
        finally:
            end_stats = self.get_memory_usage()
            end_time = time.time()

            vram_diff = end_stats['vram_allocated'] - start_stats['vram_allocated']
            duration = end_time - start_time

            print(f"Completed {operation_name} in {duration:.1f}s")
            print(f"Final VRAM: {end_stats['vram_allocated']:.2f} GB allocated, "
                  f"{end_stats['vram_free']:.2f} GB free")
            print(f"VRAM change: {vram_diff:+.2f} GB")

            # Update peak stats
            self.memory_stats['peak_allocated'] = max(
                self.memory_stats['peak_allocated'],
                end_stats['vram_allocated']
            )
            self.memory_stats['peak_reserved'] = max(
                self.memory_stats['peak_reserved'],
                end_stats['vram_reserved']
            )

    def print_memory_report(self):
        """Print detailed memory usage report"""
        stats = self.get_memory_usage()

        print("\n" + "=" * 50)
        print("MEMORY USAGE REPORT")
        print("=" * 50)

        if self.device == "cuda":
            print(f"VRAM Allocated: {stats['vram_allocated']:.2f} GB")
            print(f"VRAM Reserved: {stats['vram_reserved']:.2f} GB")
            print(f"VRAM Free: {stats['vram_free']:.2f} GB")
            print(f"Peak Allocated: {self.memory_stats['peak_allocated']:.2f} GB")
            print(f"Peak Reserved: {self.memory_stats['peak_reserved']:.2f} GB")
            print(f"Max VRAM Limit: {self.max_vram_gb:.2f} GB")

            utilization = (stats['vram_allocated'] / self.max_vram_gb) * 100
            print(f"VRAM Utilization: {utilization:.1f}%")

        print(f"\nSystem RAM Used: {stats['ram_used']:.2f} GB")
        print(f"System RAM Available: {stats['ram_available']:.2f} GB")
        print(f"System RAM Usage: {stats['ram_percent']:.1f}%")
        print("=" * 50 + "\n")

    def emergency_cleanup(self):
        """Emergency memory cleanup when running low"""
        print("WARNING: Running low on memory, performing emergency cleanup...")

        self.cleanup_memory(aggressive=True)

        # Additional cleanup steps
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        stats = self.get_memory_usage()
        print(f"After cleanup - VRAM: {stats['vram_allocated']:.2f} GB, "
              f"Free: {stats['vram_free']:.2f} GB")

    def should_emergency_cleanup(self) -> bool:
        """Check if emergency cleanup is needed"""
        stats = self.get_memory_usage()

        if self.device == "cuda":
            return stats['vram_free'] < 1.0  # Less than 1GB free
        else:
            return stats['ram_available'] < 2.0  # Less than 2GB RAM available
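
A rough sketch of how the chunk-size heuristic behaves for a half-scale VR180 eye view, assuming torch and psutil are installed; the numbers are illustrative, not benchmarks, and device="cpu" keeps it runnable without a GPU:

# Hedged sketch: estimate memory for 300 frames and query the suggested chunk size.
from vr180_matting.memory_manager import VRAMManager

manager = VRAMManager(max_vram_gb=10.0, device="cpu")  # falls back to CPU with a warning
est_gb = manager.estimate_processing_memory(frame_height=1440, frame_width=1440,
                                            num_frames=300, fp16=True)
chunk = manager.get_optimal_chunk_size(frame_height=1440, frame_width=1440, fp16=True)
print(f"~{est_gb:.1f} GB estimated for 300 frames; suggested chunk: {chunk} frames")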

vr180_matting/sam2_wrapper.py (new file, 226 lines)
@@ -0,0 +1,226 @@
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings

try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    warnings.warn("SAM2 not available. Please install sam2 package.")


class SAM2VideoMatting:
    """SAM2-based video matting with memory optimization"""

    def __init__(self,
                 model_cfg: str = "sam2_hiera_l.yaml",
                 checkpoint_path: str = "sam2_hiera_large.pt",
                 device: str = "cuda",
                 memory_offload: bool = True,
                 fp16: bool = True):
        if not SAM2_AVAILABLE:
            raise ImportError("SAM2 not available. Please install sam2 package.")

        self.device = device
        self.memory_offload = memory_offload
        self.fp16 = fp16
        self.predictor = None
        self.inference_state = None
        self.video_segments = {}

        self._load_model(model_cfg, checkpoint_path)

    def _load_model(self, model_cfg: str, checkpoint_path: str):
        """Load SAM2 video predictor with optimizations"""
        try:
            self.predictor = build_sam2_video_predictor(
                model_cfg,
                checkpoint_path,
                device=self.device
            )

            # Enable memory optimizations
            if self.memory_offload:
                self.predictor.fill_hole_area = 8

            if self.fp16 and self.device == "cuda":
                self.predictor.model.half()

        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")

    def init_video_state(self, video_frames: List[np.ndarray]) -> None:
        """Initialize video inference state"""
        if self.predictor is None:
            raise RuntimeError("SAM2 model not loaded")

        # Create temporary directory for frames if needed
        self.inference_state = self.predictor.init_state(
            video_path=None,
            video_frames=video_frames,
            offload_video_to_cpu=self.memory_offload,
            async_loading_frames=True
        )

    def add_person_prompts(self,
                           frame_idx: int,
                           box_prompts: np.ndarray,
                           labels: np.ndarray) -> List[int]:
        """
        Add person detection prompts to SAM2

        Args:
            frame_idx: Frame index to add prompts
            box_prompts: Bounding boxes (N, 4)
            labels: Prompt labels (N,)

        Returns:
            List of object IDs
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        object_ids = []

        for i, (box, label) in enumerate(zip(box_prompts, labels)):
            obj_id = i + 1  # Start from 1

            # Add box prompt
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                box=box,
            )

            object_ids.extend(out_obj_ids)

        return object_ids

    def propagate_masks(self, start_frame: int = 0, max_frames: Optional[int] = None) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks through video

        Args:
            start_frame: Starting frame index
            max_frames: Maximum number of frames to process

        Returns:
            Dictionary mapping frame_idx -> {obj_id: mask}
        """
        if self.inference_state is None:
            raise RuntimeError("Video state not initialized")

        video_segments = {}

        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
            self.inference_state,
            start_frame_idx=start_frame,
            max_frame_num_to_track=max_frames,
            reverse=False
        ):
            frame_masks = {}

            for i, out_obj_id in enumerate(out_obj_ids):
                mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                frame_masks[out_obj_id] = mask

            video_segments[out_frame_idx] = frame_masks

            # Memory management: release old frames periodically
            if self.memory_offload and out_frame_idx % 100 == 0:
                self._release_old_frames(out_frame_idx - 50)

        return video_segments

    def _release_old_frames(self, before_frame_idx: int):
        """Release old frames from memory"""
        try:
            if hasattr(self.predictor, 'release_old_frames'):
                self.predictor.release_old_frames(self.inference_state, before_frame_idx)
        except Exception as e:
            warnings.warn(f"Failed to release old frames: {e}")

    def get_combined_mask(self, frame_masks: Dict[int, np.ndarray]) -> np.ndarray:
        """Combine masks from multiple objects into single mask"""
        if not frame_masks:
            return None

        combined_mask = np.zeros_like(next(iter(frame_masks.values())), dtype=bool)

        for obj_id, mask in frame_masks.items():
            if mask.ndim == 3:
                mask = mask.squeeze()
            combined_mask = np.logical_or(combined_mask, mask)

        return combined_mask

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = "alpha",
                            background_color: List[int] = [0, 255, 0]) -> np.ndarray:
        """
        Apply mask to frame to create matted output

        Args:
            frame: Input frame (H, W, 3)
            mask: Binary mask (H, W)
            output_format: "alpha" or "greenscreen"
            background_color: RGB background color for greenscreen

        Returns:
            Matted frame
        """
        if mask is None:
            return frame

        # Ensure mask is 2D
        if mask.ndim == 3:
            mask = mask.squeeze()

        # Resize mask to match frame if needed
        if mask.shape[:2] != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8),
                              (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)

        if output_format == "alpha":
            # Create RGBA output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            output[:, :, :3] = frame
            output[:, :, 3] = mask.astype(np.uint8) * 255
            return output

        elif output_format == "greenscreen":
            # Create RGB output with background
            output = np.full_like(frame, background_color, dtype=np.uint8)
            output[mask] = frame[mask]
            return output

        else:
            raise ValueError(f"Unsupported output format: {output_format}")

    def cleanup(self):
        """Clean up resources"""
        if self.inference_state is not None:
            try:
                if hasattr(self.predictor, 'cleanup_state'):
                    self.predictor.cleanup_state(self.inference_state)
            except Exception as e:
                warnings.warn(f"Failed to cleanup SAM2 state: {e}")

            self.inference_state = None

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Destructor to ensure cleanup"""
        self.cleanup()
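
The compositing step that apply_mask_to_frame performs, reproduced standalone on a tiny synthetic frame so it runs without SAM2 weights; shapes and values here are made up for illustration:

# Hedged sketch: alpha output keeps the masked region, greenscreen paints the rest.
import numpy as np

frame = np.random.randint(0, 255, (4, 6, 3), dtype=np.uint8)  # tiny synthetic BGR frame
mask = np.zeros((4, 6), dtype=bool)
mask[1:3, 2:5] = True                                          # pretend "person" region

alpha_out = np.dstack([frame, mask.astype(np.uint8) * 255])    # H x W x 4, like format="alpha"
green_out = np.full_like(frame, (0, 255, 0))                   # like format="greenscreen"
green_out[mask] = frame[mask]
print(alpha_out.shape, green_out[1, 2], green_out[0, 0])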

vr180_matting/video_processor.py (new file, 415 lines)
@@ -0,0 +1,415 @@
import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Generator
from pathlib import Path
from fractions import Fraction
import ffmpeg
import tempfile
import shutil
from tqdm import tqdm
import warnings

from .config import VR180Config
from .detector import YOLODetector
from .sam2_wrapper import SAM2VideoMatting
from .memory_manager import VRAMManager


class VideoProcessor:
    """Main video processing pipeline for VR180 matting"""

    def __init__(self, config: VR180Config):
        self.config = config
        self.memory_manager = VRAMManager(
            max_vram_gb=config.hardware.max_vram_gb,
            device=config.hardware.device
        )

        # Initialize components
        self.detector = None
        self.sam2_model = None

        # Video properties
        self.video_info = None
        self.total_frames = 0
        self.fps = 30.0
        self.frame_width = 0
        self.frame_height = 0

        self._initialize_models()

    def _initialize_models(self):
        """Initialize the YOLO detector and SAM2 model"""
        print("Initializing models...")

        with self.memory_manager.memory_monitor("model loading"):
            # Initialize YOLO detector
            self.detector = YOLODetector(
                model_name=self.config.detection.model,
                confidence_threshold=self.config.detection.confidence_threshold,
                device=self.config.hardware.device
            )

            # Initialize SAM2 model
            self.sam2_model = SAM2VideoMatting(
                device=self.config.hardware.device,
                memory_offload=self.config.matting.memory_offload,
                fp16=self.config.matting.fp16
            )

    def load_video_info(self, video_path: str) -> Dict[str, Any]:
        """Load video metadata using ffmpeg"""
        try:
            probe = ffmpeg.probe(video_path)
            video_stream = next(
                (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
                None
            )

            if video_stream is None:
                raise ValueError("No video stream found")

            self.video_info = {
                'width': int(video_stream['width']),
                'height': int(video_stream['height']),
                # r_frame_rate is a fraction string such as "30000/1001"
                'fps': float(Fraction(video_stream['r_frame_rate'])),
                'duration': float(video_stream.get('duration', 0)),
                'nb_frames': int(video_stream.get('nb_frames', 0)),
                'codec': video_stream['codec_name'],
                'pix_fmt': video_stream.get('pix_fmt', 'yuv420p')
            }

            self.frame_width = self.video_info['width']
            self.frame_height = self.video_info['height']
            self.fps = self.video_info['fps']
            self.total_frames = self.video_info['nb_frames']

            print(f"Video info: {self.frame_width}x{self.frame_height} @ {self.fps:.2f}fps")
            print(f"Total frames: {self.total_frames}, Duration: {self.video_info['duration']:.1f}s")

            return self.video_info

        except Exception as e:
            raise RuntimeError(f"Failed to load video info: {e}")

    def read_video_frames(self,
                          video_path: str,
                          start_frame: int = 0,
                          num_frames: Optional[int] = None,
                          scale_factor: float = 1.0) -> List[np.ndarray]:
        """
        Read video frames with optional scaling

        Args:
            video_path: Path to video file
            start_frame: Starting frame index
            num_frames: Number of frames to read (None for all)
            scale_factor: Scaling factor for frames

        Returns:
            List of video frames
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        # Set starting position
        if start_frame > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        frames = []
        frame_count = 0

        with tqdm(desc="Reading frames", total=num_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Apply scaling if needed
                if scale_factor != 1.0:
                    new_width = int(frame.shape[1] * scale_factor)
                    new_height = int(frame.shape[0] * scale_factor)
                    frame = cv2.resize(frame, (new_width, new_height),
                                       interpolation=cv2.INTER_AREA)

                frames.append(frame)
                frame_count += 1
                pbar.update(1)

                if num_frames is not None and frame_count >= num_frames:
                    break

        cap.release()
        print(f"Read {len(frames)} frames")
        return frames

    def calculate_optimal_chunking(self) -> Tuple[int, int]:
        """
        Calculate optimal chunk size and overlap based on memory constraints

        Returns:
            Tuple of (chunk_size, overlap_frames)
        """
        if self.config.processing.chunk_size > 0:
            return self.config.processing.chunk_size, self.config.processing.overlap_frames

        # Calculate based on memory constraints
        scaled_height = int(self.frame_height * self.config.processing.scale_factor)
        scaled_width = int(self.frame_width * self.config.processing.scale_factor)

        optimal_chunk = self.memory_manager.get_optimal_chunk_size(
            scaled_height, scaled_width, fp16=self.config.matting.fp16
        )

        overlap = min(60, optimal_chunk // 10)  # 10% overlap, max 60 frames

        print(f"Calculated optimal chunk size: {optimal_chunk} frames with {overlap} frame overlap")
        return optimal_chunk, overlap

    def process_chunk(self,
                      frames: List[np.ndarray],
                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process a chunk of frames through the matting pipeline

        Args:
            frames: List of frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of matted frames
        """
        print(f"Processing chunk {chunk_idx} ({len(frames)} frames)")

        with self.memory_manager.memory_monitor(f"chunk {chunk_idx}"):
            # Initialize SAM2 with frames
            self.sam2_model.init_video_state(frames)

            # Detect persons in the first frame
            first_frame = frames[0]
            detections = self.detector.detect_persons(first_frame)

            if not detections:
                warnings.warn(f"No persons detected in chunk {chunk_idx}")
                return self._create_empty_masks(frames)

            print(f"Detected {len(detections)} persons in first frame")

            # Convert detections to SAM2 prompts
            box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

            # Add prompts to SAM2
            object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)
            print(f"Added prompts for {len(object_ids)} objects")

            # Propagate masks through the chunk
            video_segments = self.sam2_model.propagate_masks(
                start_frame=0,
                max_frames=len(frames)
            )

            # Apply masks to frames
            matted_frames = []
            for frame_idx, frame in enumerate(tqdm(frames, desc="Applying masks")):
                if frame_idx in video_segments:
                    frame_masks = video_segments[frame_idx]
                    combined_mask = self.sam2_model.get_combined_mask(frame_masks)

                    matted_frame = self.sam2_model.apply_mask_to_frame(
                        frame, combined_mask,
                        output_format=self.config.output.format,
                        background_color=self.config.output.background_color
                    )
                else:
                    # No mask for this frame
                    matted_frame = self._create_empty_mask_frame(frame)

                matted_frames.append(matted_frame)

            # Cleanup SAM2 state
            self.sam2_model.cleanup()

        return matted_frames

    def _create_empty_masks(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """Create empty masks when no persons are detected"""
        empty_frames = []
        for frame in frames:
            empty_frame = self._create_empty_mask_frame(frame)
            empty_frames.append(empty_frame)
        return empty_frames

    def _create_empty_mask_frame(self, frame: np.ndarray) -> np.ndarray:
        """Create a frame with an empty mask (all background)"""
        if self.config.output.format == "alpha":
            # Transparent output
            output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            return output
        else:
            # Green screen background
            return np.full_like(frame, self.config.output.background_color, dtype=np.uint8)

    def merge_overlapping_chunks(self,
                                 chunk_results: List[List[np.ndarray]],
                                 overlap_frames: int) -> List[np.ndarray]:
        """
        Merge overlapping chunks with blending in the overlap regions

        Args:
            chunk_results: List of chunk results
            overlap_frames: Number of overlapping frames

        Returns:
            Merged frame sequence
        """
        if len(chunk_results) == 1:
            return chunk_results[0]

        merged_frames = []

        # Add first chunk completely
        merged_frames.extend(chunk_results[0])

        # Process remaining chunks
        for chunk_idx in range(1, len(chunk_results)):
            chunk = chunk_results[chunk_idx]

            if overlap_frames > 0:
                # Blend overlap region
                overlap_start = len(merged_frames) - overlap_frames

                for i in range(overlap_frames):
                    if i < len(chunk):
                        # Linear blending
                        alpha = i / overlap_frames

                        prev_frame = merged_frames[overlap_start + i]
                        curr_frame = chunk[i]

                        blended = self._blend_frames(prev_frame, curr_frame, alpha)
                        merged_frames[overlap_start + i] = blended

                # Add remaining frames from current chunk
                merged_frames.extend(chunk[overlap_frames:])
            else:
                # No overlap, just concatenate
                merged_frames.extend(chunk)

        return merged_frames

    def _blend_frames(self, frame1: np.ndarray, frame2: np.ndarray, alpha: float) -> np.ndarray:
        """Blend two frames with alpha blending"""
        if frame1.shape != frame2.shape:
            return frame2  # Fallback to the second frame

        blended = (1 - alpha) * frame1.astype(np.float32) + alpha * frame2.astype(np.float32)
        return blended.astype(np.uint8)

    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Save processed frames as video

        Args:
            frames: List of processed frames
            output_path: Output video path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Determine codec and format based on the output format
        if self.config.output.format == "alpha":
            # Use a PNG sequence for the alpha channel
            self._save_png_sequence(frames, output_path.parent / f"{output_path.stem}_frames")
        else:
            # Save as regular video
            self._save_mp4_video(frames, str(output_path))

    def _save_png_sequence(self, frames: List[np.ndarray], output_dir: Path):
        """Save frames as a PNG sequence with alpha channel"""
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, frame in enumerate(tqdm(frames, desc="Saving PNG sequence")):
            frame_path = output_dir / f"frame_{i:06d}.png"

            # cv2.imwrite expects BGR(A) channel order, so keep frames in BGRA
            if frame.shape[2] == 4:  # Already has an alpha channel (BGRA)
                frame_bgra = frame
            else:  # BGR to BGRA
                frame_bgra = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

            cv2.imwrite(str(frame_path), frame_bgra)

        print(f"Saved {len(frames)} PNG frames to {output_dir}")

    def _save_mp4_video(self, frames: List[np.ndarray], output_path: str):
        """Save frames as an MP4 video"""
        if not frames:
            return

        height, width = frames[0].shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))

        for frame in tqdm(frames, desc="Writing video"):
            if frame.shape[2] == 4:  # Drop the alpha channel (BGRA to BGR)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
            writer.write(frame)

        writer.release()
        print(f"Saved video to {output_path}")

    def process_video(self) -> None:
        """Main video processing pipeline"""
        print("Starting VR180 video processing...")

        # Load video info
        self.load_video_info(self.config.input.video_path)

        # Calculate chunking parameters
        chunk_size, overlap_frames = self.calculate_optimal_chunking()

        # Process video in chunks
        chunk_results = []

        for start_frame in range(0, self.total_frames, chunk_size - overlap_frames):
            end_frame = min(start_frame + chunk_size, self.total_frames)
            frames_to_read = end_frame - start_frame

            chunk_idx = len(chunk_results)
            print(f"\nProcessing chunk {chunk_idx}: frames {start_frame}-{end_frame}")

            # Read chunk frames
            frames = self.read_video_frames(
                self.config.input.video_path,
                start_frame=start_frame,
                num_frames=frames_to_read,
                scale_factor=self.config.processing.scale_factor
            )

            # Process chunk
            matted_frames = self.process_chunk(frames, chunk_idx)
            chunk_results.append(matted_frames)

            # Memory cleanup
            self.memory_manager.cleanup_memory()

            if self.memory_manager.should_emergency_cleanup():
                self.memory_manager.emergency_cleanup()

        # Merge chunks if multiple
        print("\nMerging chunks...")
        final_frames = self.merge_overlapping_chunks(chunk_results, overlap_frames)

        # Save results
        print(f"Saving {len(final_frames)} processed frames...")
        self.save_video(final_frames, self.config.output.path)

        # Print final memory report
        self.memory_manager.print_memory_report()

        print("Video processing completed!")
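
The chunk scheduling arithmetic used in process_video, isolated as a sketch: chunks advance by chunk_size - overlap_frames, and the shared overlap_frames frames are what merge_overlapping_chunks later blends linearly; the frame counts below are illustrative:

# Hedged sketch: how chunk boundaries fall for a 2000-frame clip.
total_frames, chunk_size, overlap = 2000, 900, 60
for start in range(0, total_frames, chunk_size - overlap):
    end = min(start + chunk_size, total_frames)
    print(f"chunk frames {start}-{end} ({end - start} frames)")
# -> 0-900, 840-1740, 1680-2000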

vr180_matting/vr180_processor.py (new file, 396 lines)
@@ -0,0 +1,396 @@
import cv2
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import warnings

from .video_processor import VideoProcessor
from .config import VR180Config


class VR180Processor(VideoProcessor):
    """Enhanced video processor with VR180-specific optimizations"""

    def __init__(self, config: VR180Config):
        super().__init__(config)

        # VR180 specific properties
        self.left_eye_width = 0
        self.right_eye_width = 0
        self.eye_height = 0
        self.sbs_split_point = 0

    def analyze_sbs_layout(self) -> Dict[str, Any]:
        """
        Analyze side-by-side layout and determine eye regions

        Returns:
            Dictionary with eye region information
        """
        if self.video_info is None:
            raise RuntimeError("Video info not loaded")

        total_width = self.video_info['width']
        total_height = self.video_info['height']

        # Assume equal split for VR180 SBS
        self.sbs_split_point = total_width // 2
        self.left_eye_width = self.sbs_split_point
        self.right_eye_width = total_width - self.sbs_split_point
        self.eye_height = total_height

        layout_info = {
            'total_width': total_width,
            'total_height': total_height,
            'split_point': self.sbs_split_point,
            'left_eye_region': (0, 0, self.left_eye_width, self.eye_height),
            'right_eye_region': (self.sbs_split_point, 0, self.right_eye_width, self.eye_height),
            'eye_aspect_ratio': self.left_eye_width / self.eye_height
        }

        print(f"VR180 SBS Layout: {total_width}x{total_height}")
        print(f"Split point: {self.sbs_split_point}")
        print(f"Left eye: {self.left_eye_width}x{self.eye_height}")
        print(f"Right eye: {self.right_eye_width}x{self.eye_height}")

        return layout_info

    def split_sbs_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Split side-by-side frame into left and right eye views

        Args:
            frame: Input SBS frame

        Returns:
            Tuple of (left_eye_frame, right_eye_frame)
        """
        if self.sbs_split_point == 0:
            self.sbs_split_point = frame.shape[1] // 2

        left_eye = frame[:, :self.sbs_split_point]
        right_eye = frame[:, self.sbs_split_point:]

        return left_eye, right_eye

    def combine_sbs_frame(self, left_eye: np.ndarray, right_eye: np.ndarray) -> np.ndarray:
        """
        Combine left and right eye frames back into side-by-side format

        Args:
            left_eye: Left eye frame
            right_eye: Right eye frame

        Returns:
            Combined SBS frame
        """
        # Ensure frames have same height
        if left_eye.shape[0] != right_eye.shape[0]:
            target_height = min(left_eye.shape[0], right_eye.shape[0])
            left_eye = cv2.resize(left_eye, (left_eye.shape[1], target_height))
            right_eye = cv2.resize(right_eye, (right_eye.shape[1], target_height))

        # Combine horizontally
        combined = np.hstack([left_eye, right_eye])
        return combined

    def process_with_disparity_mapping(self,
                                       frames: List[np.ndarray],
                                       chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Process frames using disparity mapping optimization

        Args:
            frames: List of SBS frames
            chunk_idx: Chunk index

        Returns:
            List of processed SBS frames
        """
        print(f"Processing chunk {chunk_idx} with disparity mapping ({len(frames)} frames)")

        # Split all frames into left/right eyes
        left_eye_frames = []
        right_eye_frames = []

        for frame in frames:
            left, right = self.split_sbs_frame(frame)
            left_eye_frames.append(left)
            right_eye_frames.append(right)

        # Process left eye at full quality
        print("Processing left eye...")
        with self.memory_manager.memory_monitor(f"left eye chunk {chunk_idx}"):
            left_matted = self._process_eye_sequence(left_eye_frames, "left", chunk_idx)

        # Process right eye with cross-validation
        print("Processing right eye with cross-validation...")
        with self.memory_manager.memory_monitor(f"right eye chunk {chunk_idx}"):
            right_matted = self._process_eye_sequence_with_validation(
                right_eye_frames, left_matted, "right", chunk_idx
            )

        # Combine results back to SBS format
        combined_frames = []
        for left_frame, right_frame in zip(left_matted, right_matted):
            if self.config.output.maintain_sbs:
                combined = self.combine_sbs_frame(left_frame, right_frame)
            else:
                # Return as separate eye outputs
                combined = {'left': left_frame, 'right': right_frame}
            combined_frames.append(combined)

        return combined_frames

    def _process_eye_sequence(self,
                              eye_frames: List[np.ndarray],
                              eye_name: str,
                              chunk_idx: int) -> List[np.ndarray]:
        """Process a single eye sequence"""
        if not eye_frames:
            return []

        # Initialize SAM2 with eye frames
        self.sam2_model.init_video_state(eye_frames)

        # Detect persons in first frame
        first_frame = eye_frames[0]
        detections = self.detector.detect_persons(first_frame)

        if not detections:
            warnings.warn(f"No persons detected in {eye_name} eye, chunk {chunk_idx}")
            return self._create_empty_masks(eye_frames)

        print(f"Detected {len(detections)} persons in {eye_name} eye first frame")

        # Convert to SAM2 prompts
        box_prompts, labels = self.detector.convert_to_sam_prompts(detections)

        # Add prompts
        object_ids = self.sam2_model.add_person_prompts(0, box_prompts, labels)

        # Propagate masks
        video_segments = self.sam2_model.propagate_masks(
            start_frame=0,
            max_frames=len(eye_frames)
        )

        # Apply masks
        matted_frames = []
        for frame_idx, frame in enumerate(eye_frames):
            if frame_idx in video_segments:
                frame_masks = video_segments[frame_idx]
                combined_mask = self.sam2_model.get_combined_mask(frame_masks)

                matted_frame = self.sam2_model.apply_mask_to_frame(
                    frame, combined_mask,
                    output_format=self.config.output.format,
                    background_color=self.config.output.background_color
                )
            else:
                matted_frame = self._create_empty_mask_frame(frame)

            matted_frames.append(matted_frame)

        # Cleanup
        self.sam2_model.cleanup()

        return matted_frames

    def _process_eye_sequence_with_validation(self,
                                              right_eye_frames: List[np.ndarray],
                                              left_eye_results: List[np.ndarray],
                                              eye_name: str,
                                              chunk_idx: int) -> List[np.ndarray]:
        """
        Process right eye with validation against left eye results

        Args:
            right_eye_frames: Right eye frame sequence
            left_eye_results: Processed left eye results for validation
            eye_name: Eye identifier
            chunk_idx: Chunk index

        Returns:
            Processed right eye frames
        """
        # For now, process right eye independently
        # TODO: Implement stereo consistency validation
        right_matted = self._process_eye_sequence(right_eye_frames, eye_name, chunk_idx)

        # Apply stereo consistency checks
        validated_results = self._validate_stereo_consistency(
            left_eye_results, right_matted
        )

        return validated_results

    def _validate_stereo_consistency(self,
                                     left_results: List[np.ndarray],
                                     right_results: List[np.ndarray]) -> List[np.ndarray]:
        """
        Validate and correct stereo consistency between left and right eye results

        Args:
            left_results: Left eye processed frames
            right_results: Right eye processed frames

        Returns:
            Validated right eye frames
        """
        validated_frames = []

        for i, (left_frame, right_frame) in enumerate(zip(left_results, right_results)):
            # Simple validation: check if mask areas are similar
            left_mask_area = self._get_mask_area(left_frame)
            right_mask_area = self._get_mask_area(right_frame)

            # If areas differ significantly, apply correction
            area_ratio = right_mask_area / (left_mask_area + 1e-6)

            if area_ratio < 0.5 or area_ratio > 2.0:
                # Significant difference - apply correction
                corrected_frame = self._apply_stereo_correction(
                    left_frame, right_frame, area_ratio
                )
                validated_frames.append(corrected_frame)
            else:
                validated_frames.append(right_frame)

        return validated_frames

    def _get_mask_area(self, frame: np.ndarray) -> float:
        """Get mask area from a processed frame"""
        if frame.shape[2] == 4:  # Alpha channel
            mask = frame[:, :, 3] > 0
        else:  # Green screen - detect non-background pixels
            bg_color = np.array(self.config.output.background_color)
            diff = np.abs(frame.astype(np.float32) - bg_color).sum(axis=2)
            mask = diff > 30  # Threshold for non-background

        return np.sum(mask)

    def _apply_stereo_correction(self,
                                 left_frame: np.ndarray,
                                 right_frame: np.ndarray,
                                 area_ratio: float) -> np.ndarray:
        """
        Apply stereo correction to the right frame based on the left frame

        This is a simplified correction - in production, you'd use
        proper disparity mapping and stereo geometry
        """
        # For now, return the right frame as-is
        # TODO: Implement proper stereo correction algorithm
        return right_frame

    def process_chunk(self,
                      frames: List[np.ndarray],
                      chunk_idx: int = 0) -> List[np.ndarray]:
        """
        Override parent method to handle VR180-specific processing

        Args:
            frames: List of SBS frames to process
            chunk_idx: Chunk index for logging

        Returns:
            List of processed frames
        """
        if not frames:
            return []

        # Analyze SBS layout if not done yet
        if self.sbs_split_point == 0:
            sample_frame = frames[0]
            self.sbs_split_point = sample_frame.shape[1] // 2

        # Choose processing method based on configuration
        if self.config.matting.use_disparity_mapping:
            return self.process_with_disparity_mapping(frames, chunk_idx)
        else:
            # Process each eye independently and combine
            return self._process_eyes_independently(frames, chunk_idx)

    def _process_eyes_independently(self,
                                    frames: List[np.ndarray],
                                    chunk_idx: int) -> List[np.ndarray]:
        """Process left and right eyes independently"""
        print(f"Processing chunk {chunk_idx} with independent eye processing")

        # Split frames
        left_eye_frames = []
        right_eye_frames = []

        for frame in frames:
            left, right = self.split_sbs_frame(frame)
            left_eye_frames.append(left)
            right_eye_frames.append(right)

        # Process each eye
        print("Processing left eye...")
        left_matted = self._process_eye_sequence(left_eye_frames, "left", chunk_idx)

        print("Processing right eye...")
        right_matted = self._process_eye_sequence(right_eye_frames, "right", chunk_idx)

        # Combine results
        combined_frames = []
        for left_frame, right_frame in zip(left_matted, right_matted):
            if self.config.output.maintain_sbs:
                combined = self.combine_sbs_frame(left_frame, right_frame)
            else:
                combined = {'left': left_frame, 'right': right_frame}
            combined_frames.append(combined)

        return combined_frames

    def save_video(self, frames: List[np.ndarray], output_path: str):
        """
        Override parent method to handle VR180-specific output formats

        Args:
            frames: List of processed frames
            output_path: Output path
        """
        if not frames:
            raise ValueError("No frames to save")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Check if frames are in separate eye format
        if isinstance(frames[0], dict) and 'left' in frames[0]:
            # Save separate eye videos
            self._save_separate_eye_videos(frames, output_path)
        else:
            # Save as combined SBS video
            super().save_video(frames, str(output_path))

    def _save_separate_eye_videos(self, frames: List[Dict[str, np.ndarray]], output_path: Path):
        """Save left and right eye videos separately"""
        left_frames = [frame['left'] for frame in frames]
        right_frames = [frame['right'] for frame in frames]

        # Save left eye
        left_path = output_path.parent / f"{output_path.stem}_left{output_path.suffix}"
        super().save_video(left_frames, str(left_path))

        # Save right eye
        right_path = output_path.parent / f"{output_path.stem}_right{output_path.suffix}"
        super().save_video(right_frames, str(right_path))

        print(f"Saved separate eye videos: {left_path}, {right_path}")

    def process_video(self) -> None:
        """
        Override parent method to add VR180-specific initialization
        """
        print("Starting VR180 video processing...")

        # Load video info and analyze SBS layout
        self.load_video_info(self.config.input.video_path)
        self.analyze_sbs_layout()

        # Continue with parent processing
        super().process_video()
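
A self-contained sketch of the SBS split/recombine round trip that VR180Processor applies per frame, run on a synthetic frame so it needs no video input or GPU models; the 2880x1440 size is only an example:

# Hedged sketch: split a side-by-side frame into eyes and stack them back.
import numpy as np

sbs = np.zeros((1440, 2880, 3), dtype=np.uint8)  # synthetic 2:1 side-by-side frame
split = sbs.shape[1] // 2
left, right = sbs[:, :split], sbs[:, split:]     # what split_sbs_frame returns
recombined = np.hstack([left, right])            # what combine_sbs_frame does
assert recombined.shape == sbs.shape
print(left.shape, right.shape, recombined.shape)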