"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""

import yaml
import os
from typing import Dict, Any, List, Union
import logging

logger = logging.getLogger(__name__)


class ConfigLoader:
    """Loads and validates configuration from YAML files.

    The file must contain a YAML mapping with at least the ``input``,
    ``output``, ``processing`` and ``models`` sections (each a mapping).
    Validation runs once in ``__init__``; the getters then read from
    ``self.config``, falling back to built-in defaults for optional keys.
    """

    def __init__(self, config_path: str):
        """Load and validate the configuration at *config_path*.

        Raises:
            FileNotFoundError: if *config_path* does not exist.
            ValueError: if the file is not valid YAML, is empty, is not a
                mapping, or fails validation.
        """
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Returns:
            The parsed top-level mapping.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the YAML cannot be parsed, the document is empty,
                or its top level is not a mapping.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        try:
            with open(self.config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file: {e}") from e

        # safe_load returns None for an empty document and may return a list
        # or scalar; everything downstream indexes the result as a dict.
        if config is None:
            raise ValueError(f"Configuration file is empty: {self.config_path}")
        if not isinstance(config, dict):
            raise ValueError(
                "Configuration file must contain a mapping at the top level, "
                f"got {type(config).__name__}: {self.config_path}"
            )

        logger.info("Loaded configuration from %s", self.config_path)
        return config

    def _validate_config(self):
        """Validate required configuration fields.

        Raises:
            ValueError: on any missing or malformed section/field.
        """
        required_sections = ['input', 'output', 'processing', 'models']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")
            # An empty section in YAML (e.g. ``input:``) loads as None; reject
            # non-mappings here so the field checks below give a clear error.
            if not isinstance(self.config[section], dict):
                raise ValueError(
                    f"Configuration section '{section}' must be a mapping, "
                    f"got {type(self.config[section]).__name__}"
                )

        # Optional sections may be absent or left empty (None), but if they
        # hold a value it must be a mapping for the getters to read from.
        for section in ('video', 'advanced'):
            value = self.config.get(section)
            if value is not None and not isinstance(value, dict):
                raise ValueError(
                    f"Configuration section '{section}' must be a mapping, "
                    f"got {type(value).__name__}"
                )

        # Validate input section
        if 'video_path' not in self.config['input']:
            raise ValueError("Missing required field: input.video_path")

        # Validate output section
        for field in ('directory', 'filename'):
            if field not in self.config['output']:
                raise ValueError(f"Missing required field: output.{field}")

        models = self.config['models']

        # Validate models section
        for field in ('sam2_checkpoint', 'sam2_config'):
            if field not in models:
                raise ValueError(f"Missing required field: models.{field}")

        # Validate YOLO model configuration
        yolo_mode = models.get('yolo_mode', 'detection')
        if yolo_mode not in ['detection', 'segmentation']:
            raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")

        # Check for legacy yolo_model field vs new structure
        has_legacy_yolo_model = 'yolo_model' in models
        has_new_yolo_models = 'yolo_detection_model' in models or 'yolo_segmentation_model' in models

        if not has_legacy_yolo_model and not has_new_yolo_models:
            raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")

        # Validate that the required model for the current mode exists
        if yolo_mode == 'detection':
            if has_new_yolo_models and 'yolo_detection_model' not in models:
                raise ValueError("yolo_mode is 'detection' but yolo_detection_model not specified")
        elif yolo_mode == 'segmentation':
            if has_new_yolo_models and 'yolo_segmentation_model' not in models:
                raise ValueError("yolo_mode is 'segmentation' but yolo_segmentation_model not specified")

        # Validate processing.detect_segments format: the literal 'all' or a
        # list of integer segment indices. Any other string was previously
        # accepted and would be iterated character-by-character downstream.
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if isinstance(detect_segments, str):
            if detect_segments != 'all':
                raise ValueError("detect_segments must be 'all' or a list of integers")
        elif isinstance(detect_segments, list):
            # bool is a subclass of int; reject it explicitly.
            if not all(isinstance(x, int) and not isinstance(x, bool) for x in detect_segments):
                raise ValueError("detect_segments list must contain only integers")
        else:
            raise ValueError("detect_segments must be 'all' or a list of integers")

    def _optional_section(self, name: str) -> Dict[str, Any]:
        """Return optional section *name*, or an empty dict if absent/empty."""
        return self.config.get(name) or {}

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        value = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
            return value
        except (KeyError, TypeError):
            return default

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds (default 5)."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor (default 0.5)."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold (default 0.6)."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection: 'all' or a list of ints."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path (legacy method for backward compatibility).

        A legacy ``models.yolo_model`` entry wins; otherwise the model for
        the configured ``yolo_mode`` is returned, with a default filename.
        """
        models = self.config['models']
        # Check for legacy configuration first
        if 'yolo_model' in models:
            return models['yolo_model']

        # Use new configuration based on mode
        yolo_mode = models.get('yolo_mode', 'detection')
        if yolo_mode == 'detection':
            return models.get('yolo_detection_model', 'yolov8n.pt')
        # segmentation mode
        return models.get('yolo_segmentation_model', 'yolov8n-seg.pt')

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding (default True)."""
        return self._optional_section('video').get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio (default True)."""
        return self._optional_section('video').get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate (default '50M')."""
        return self._optional_section('video').get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color (default [0, 255, 0])."""
        return self._optional_section('advanced').get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color (default [255, 0, 0]).

        NOTE(review): [255, 0, 0] is blue only in BGR channel order —
        presumably the consumer uses OpenCV/BGR; confirm.
        """
        return self._optional_section('advanced').get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID (default 0)."""
        return self._optional_section('advanced').get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level (default 'INFO')."""
        return self._optional_section('advanced').get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files (default True)."""
        return self._optional_section('advanced').get('cleanup_intermediate_files', True)