"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""

import logging
import os
from typing import Dict, Any, List, Union

import yaml

logger = logging.getLogger(__name__)


class ConfigLoader:
    """Loads and validates configuration from YAML files.

    The parsed document is stored in ``self.config``. Construction fails
    with ``FileNotFoundError`` when the file does not exist and with
    ``ValueError`` when the file is not valid YAML or fails validation.
    """

    # Top-level sections that must be present and must be mappings.
    _REQUIRED_SECTIONS = ('input', 'output', 'processing', 'models')

    # Top-level sections that may be absent or null, but must be mappings
    # when they carry a value.
    _OPTIONAL_SECTIONS = ('video', 'advanced')

    # Keys that must exist inside each required section, checked in order.
    _REQUIRED_FIELDS = (
        ('input', ('video_path',)),
        ('output', ('directory', 'filename')),
        ('models', ('yolo_model', 'sam2_checkpoint', 'sam2_config')),
    )

    def __init__(self, config_path: str):
        """Load the YAML file at ``config_path`` and validate its contents."""
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Any:
        """Load configuration from YAML file.

        Returns:
            The parsed YAML document exactly as ``yaml.safe_load`` produced
            it. This is ``None`` for an empty file; ``_validate_config``
            rejects anything that is not a mapping.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
            ValueError: If the file cannot be parsed as YAML.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        try:
            # Explicit encoding so parsing does not depend on the locale.
            with open(self.config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file: {e}") from e

        logger.info("Loaded configuration from %s", self.config_path)
        return config

    def _validate_config(self):
        """Validate required configuration fields.

        Raises:
            ValueError: If the document root or any section is not a
                mapping, a required section or field is missing, or
                ``processing.detect_segments`` has an unsupported type.
        """
        config = self.config

        # An empty YAML file parses to None; a bare scalar or list is
        # equally unusable as a configuration root.
        if not isinstance(config, dict):
            raise ValueError(
                "Configuration root must be a mapping of sections, "
                f"got {type(config).__name__}"
            )

        for section in self._REQUIRED_SECTIONS:
            if section not in config:
                raise ValueError(f"Missing required configuration section: {section}")
            # A section header with no keys under it parses to None.
            if not isinstance(config[section], dict):
                raise ValueError(f"Configuration section '{section}' must be a mapping")

        for section in self._OPTIONAL_SECTIONS:
            value = config.get(section)
            if value is not None and not isinstance(value, dict):
                raise ValueError(f"Configuration section '{section}' must be a mapping")

        for section, fields in self._REQUIRED_FIELDS:
            for field in fields:
                if field not in config[section]:
                    raise ValueError(f"Missing required field: {section}.{field}")

        # Validate processing.detect_segments format
        detect_segments = config['processing'].get('detect_segments', 'all')
        if not isinstance(detect_segments, (str, list)):
            raise ValueError("detect_segments must be 'all' or a list of integers")
        # NOTE(review): any string is accepted here, not only 'all'; kept
        # lenient so configurations that load today keep loading.
        if isinstance(detect_segments, list):
            # bool is a subclass of int, so exclude it explicitly.
            if not all(
                isinstance(x, int) and not isinstance(x, bool)
                for x in detect_segments
            ):
                raise ValueError("detect_segments list must contain only integers")

    def _optional_section(self, name: str) -> Dict[str, Any]:
        """Return the optional section ``name``, or ``{}`` if absent or null."""
        section = self.config.get(name)
        return section if isinstance(section, dict) else {}

    def get(self, key_path: str, default: Any = None) -> Any:
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        value = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
        except (KeyError, TypeError):
            # KeyError: missing key. TypeError: an intermediate value is
            # not subscriptable by a string key (e.g. None or a list).
            return default
        return value

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds (default 5)."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor (default 0.5)."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold (default 0.6)."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection (default 'all')."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path."""
        return self.config['models']['yolo_model']

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding (default True)."""
        return self._optional_section('video').get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio (default True)."""
        return self._optional_section('video').get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate (default '50M')."""
        return self._optional_section('video').get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color (default [0, 255, 0])."""
        return self._optional_section('advanced').get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color (default [255, 0, 0])."""
        return self._optional_section('advanced').get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID (default 0)."""
        return self._optional_section('advanced').get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level (default 'INFO')."""
        return self._optional_section('advanced').get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files (default True)."""
        return self._optional_section('advanced').get('cleanup_intermediate_files', True)