Files
samyolo_on_segments/core/config_loader.py
2025-07-27 11:43:07 -07:00

158 lines
5.9 KiB
Python

"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""
import yaml
import os
from typing import Dict, Any, List, Union
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ConfigLoader:
    """Loads and validates configuration from YAML files.

    The file is read and validated once, at construction time. The parsed
    mapping is kept in ``self.config`` and exposed through the getters below.

    Required sections: ``input``, ``output``, ``processing``, ``models``.
    Optional sections: ``video``, ``advanced`` (their getters fall back to
    built-in defaults when the section or the key is absent).
    """

    def __init__(self, config_path: str):
        """Load and validate the configuration found at ``config_path``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the file is not valid YAML or fails validation.
        """
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Returns:
            The parsed YAML document. Its shape is not checked here;
            ``_validate_config`` rejects anything that is not a mapping
            (for example an empty file, which parses to ``None``).

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file cannot be parsed as YAML.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        try:
            # Explicit UTF-8 so parsing does not depend on the platform's
            # default locale encoding.
            with open(self.config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            # Chain the original error so the parser's details are kept.
            raise ValueError(f"Error parsing YAML file: {e}") from e
        logger.info("Loaded configuration from %s", self.config_path)
        return config

    def _validate_config(self) -> None:
        """Validate required configuration fields.

        Raises:
            ValueError: If the top-level document or any section is not a
                mapping, a required section or field is missing, or
                ``processing.detect_segments`` has an invalid format.
        """
        # An empty file parses to None and a scalar/list document is not
        # subscriptable by section name; report both as validation errors.
        if not isinstance(self.config, dict):
            raise ValueError("Configuration file must contain a YAML mapping at the top level")

        required_sections = ['input', 'output', 'processing', 'models']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")
            # A section written as "input:" with nothing under it is None.
            if not isinstance(self.config[section], dict):
                raise ValueError(f"Configuration section must be a mapping: {section}")

        # Optional sections may be absent or empty (None), but if they carry
        # a value it has to be a mapping for the getters to work.
        for section in ('video', 'advanced'):
            value = self.config.get(section)
            if value is not None and not isinstance(value, dict):
                raise ValueError(f"Configuration section must be a mapping: {section}")

        # Validate input section
        if 'video_path' not in self.config['input']:
            raise ValueError("Missing required field: input.video_path")

        # Validate output section
        for field in ('directory', 'filename'):
            if field not in self.config['output']:
                raise ValueError(f"Missing required field: output.{field}")

        # Validate models section
        for field in ('yolo_model', 'sam2_checkpoint', 'sam2_config'):
            if field not in self.config['models']:
                raise ValueError(f"Missing required field: models.{field}")

        # Validate processing.detect_segments format
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if not isinstance(detect_segments, (str, list)):
            raise ValueError("detect_segments must be 'all' or a list of integers")
        # NOTE(review): any string is accepted here, not only 'all'; left
        # as-is because consumers of this value are not visible in this file.
        if isinstance(detect_segments, list):
            # bool is a subclass of int, so exclude it explicitly; otherwise
            # [true, false] would pass as a list of integers.
            if not all(
                isinstance(x, int) and not isinstance(x, bool)
                for x in detect_segments
            ):
                raise ValueError("detect_segments list must contain only integers")

    def _optional_section(self, name: str) -> Dict[str, Any]:
        """Return the optional section ``name``, or an empty dict if absent or empty."""
        section = self.config.get(name)
        return {} if section is None else section

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        value = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
        except (KeyError, TypeError):
            # KeyError: the key is missing. TypeError: an intermediate value
            # cannot be indexed with a string key (e.g. None, a number, a list).
            return default
        return value

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds (default 5)."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor (default 0.5)."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold (default 0.6)."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection (default 'all')."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path."""
        return self.config['models']['yolo_model']

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding (default True)."""
        return self._optional_section('video').get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio (default True)."""
        return self._optional_section('video').get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate (default '50M')."""
        return self._optional_section('video').get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color (default [0, 255, 0])."""
        return self._optional_section('advanced').get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color (default [255, 0, 0])."""
        return self._optional_section('advanced').get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID (default 0)."""
        return self._optional_section('advanced').get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level (default 'INFO')."""
        return self._optional_section('advanced').get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files (default True)."""
        return self._optional_section('advanced').get('cleanup_intermediate_files', True)