import yaml
from dataclasses import dataclass
from typing import List, Optional, Union
from pathlib import Path


@dataclass
class InputConfig:
    """Input video settings."""
    video_path: str


@dataclass
class ProcessingConfig:
    """Frame scaling and chunked-processing settings."""
    scale_factor: float = 0.5
    chunk_size: int = 900
    overlap_frames: int = 60


@dataclass
class DetectionConfig:
    """YOLO detector settings."""
    confidence_threshold: float = 0.7
    model: str = "yolov8n"


@dataclass
class MattingConfig:
    """SAM2-based matting settings."""
    use_disparity_mapping: bool = True
    memory_offload: bool = True
    fp16: bool = True
    sam2_model_cfg: str = "sam2.1_hiera_l"
    sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
    # Det-SAM2 optimizations
    continuous_correction: bool = True
    correction_interval: int = 60  # Add correction prompts every N frames
    frame_release_interval: int = 50  # Release old frames every N frames
    frame_window_size: int = 30  # Keep N frames in memory


@dataclass
class OutputConfig:
    """Output format, background, and audio settings."""
    path: str
    format: str = "alpha"
    background_color: Optional[List[int]] = None
    maintain_sbs: bool = True
    preserve_audio: bool = True
    verify_sync: bool = True

    def __post_init__(self):
        # Default to a green-screen background (RGB) when none is given.
        if self.background_color is None:
            self.background_color = [0, 255, 0]


@dataclass
class HardwareConfig:
    """Compute device and VRAM budget."""
    device: str = "cuda"
    max_vram_gb: int = 10


@dataclass
class VR180Config:
    """Top-level configuration composed of all sub-configs."""
    input: InputConfig
    processing: ProcessingConfig
    detection: DetectionConfig
    matting: MattingConfig
    output: OutputConfig
    hardware: HardwareConfig

    @classmethod
    def from_yaml(cls, config_path: Union[str, Path]) -> "VR180Config":
        """Load configuration from YAML file"""
        with open(config_path, 'r') as f:
            data = yaml.safe_load(f)

        return cls(
            input=InputConfig(**data['input']),
            processing=ProcessingConfig(**data.get('processing', {})),
            detection=DetectionConfig(**data.get('detection', {})),
            matting=MattingConfig(**data.get('matting', {})),
            output=OutputConfig(**data['output']),
            hardware=HardwareConfig(**data.get('hardware', {}))
        )

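    # Illustrative sketch of the YAML layout that from_yaml() expects. Section
    # and key names come from the dataclasses above; the values shown are
    # placeholders/defaults, not recommendations, and the paths are assumed.
    #
    #   input:
    #     video_path: /path/to/input_sbs.mp4
    #   processing:
    #     scale_factor: 0.5
    #     chunk_size: 900
    #     overlap_frames: 60
    #   detection:
    #     confidence_threshold: 0.7
    #     model: yolov8n
    #   matting:
    #     fp16: true
    #     memory_offload: true
    #   output:
    #     path: /path/to/output.mp4
    #     format: alpha
    #   hardware:
    #     device: cuda
    #     max_vram_gb: 10
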
    def to_yaml(self, config_path: Union[str, Path]) -> None:
        """Save configuration to YAML file"""
        data = {
            'input': {
                'video_path': self.input.video_path
            },
            'processing': {
                'scale_factor': self.processing.scale_factor,
                'chunk_size': self.processing.chunk_size,
                'overlap_frames': self.processing.overlap_frames
            },
            'detection': {
                'confidence_threshold': self.detection.confidence_threshold,
                'model': self.detection.model
            },
            'matting': {
                'use_disparity_mapping': self.matting.use_disparity_mapping,
                'memory_offload': self.matting.memory_offload,
                'fp16': self.matting.fp16,
                'sam2_model_cfg': self.matting.sam2_model_cfg,
                'sam2_checkpoint': self.matting.sam2_checkpoint,
                # Include the Det-SAM2 fields so from_yaml(to_yaml(cfg)) round-trips.
                'continuous_correction': self.matting.continuous_correction,
                'correction_interval': self.matting.correction_interval,
                'frame_release_interval': self.matting.frame_release_interval,
                'frame_window_size': self.matting.frame_window_size
            },
            'output': {
                'path': self.output.path,
                'format': self.output.format,
                'background_color': self.output.background_color,
                'maintain_sbs': self.output.maintain_sbs,
                'preserve_audio': self.output.preserve_audio,
                'verify_sync': self.output.verify_sync
            },
            'hardware': {
                'device': self.hardware.device,
                'max_vram_gb': self.hardware.max_vram_gb
            }
        }

        with open(config_path, 'w') as f:
            yaml.dump(data, f, default_flow_style=False, indent=2)

    def validate(self) -> List[str]:
        """Validate configuration and return list of errors"""
        errors = []

        if not Path(self.input.video_path).exists():
            errors.append(f"Input video path does not exist: {self.input.video_path}")

        if not 0.1 <= self.processing.scale_factor <= 1.0:
            errors.append("Scale factor must be between 0.1 and 1.0")

        if self.processing.chunk_size < 0:
            errors.append("Chunk size must be non-negative (0 for full video)")

        if not 0.1 <= self.detection.confidence_threshold <= 1.0:
            errors.append("Confidence threshold must be between 0.1 and 1.0")

        if self.detection.model not in ["yolov8n", "yolov8s", "yolov8m", "yolov8l", "yolov8x"]:
            errors.append(f"Unsupported YOLO model: {self.detection.model}")

        if self.output.format not in ["alpha", "greenscreen"]:
            errors.append("Output format must be 'alpha' or 'greenscreen'")

        if len(self.output.background_color) != 3:
            errors.append("Background color must be RGB list with 3 values")

        if not all(0 <= c <= 255 for c in self.output.background_color):
            errors.append("Background color values must be between 0 and 255")

        if self.hardware.device not in ["cuda", "cpu"]:
            errors.append("Device must be 'cuda' or 'cpu'")

        if self.hardware.max_vram_gb <= 0:
            errors.append("Max VRAM must be positive")

        return errors
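

if __name__ == "__main__":
    # Minimal usage sketch of the load -> validate -> save flow defined above.
    # "config.yaml" is an assumed default file name, not one shipped with this module.
    import sys

    cfg_path = sys.argv[1] if len(sys.argv) > 1 else "config.yaml"
    config = VR180Config.from_yaml(cfg_path)

    problems = config.validate()
    if problems:
        for msg in problems:
            print(f"config error: {msg}")
        sys.exit(1)

    # Write the resolved configuration (defaults filled in) back to disk.
    config.to_yaml(Path(cfg_path).with_name("config_resolved.yaml"))
    print(f"Configuration OK: {cfg_path}")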