inital commit
This commit is contained in:
158
core/config_loader.py
Normal file
158
core/config_loader.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Configuration loader for YOLO + SAM2 video processing pipeline.
|
||||
Handles loading and validation of YAML configuration files.
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import os
|
||||
from typing import Dict, Any, List, Union
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigLoader:
|
||||
"""Loads and validates configuration from YAML files."""
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self._validate_config()
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""Load configuration from YAML file."""
|
||||
if not os.path.exists(self.config_path):
|
||||
raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
|
||||
|
||||
try:
|
||||
with open(self.config_path, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
logger.info(f"Loaded configuration from {self.config_path}")
|
||||
return config
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Error parsing YAML file: {e}")
|
||||
|
||||
def _validate_config(self):
|
||||
"""Validate required configuration fields."""
|
||||
required_sections = ['input', 'output', 'processing', 'models']
|
||||
|
||||
for section in required_sections:
|
||||
if section not in self.config:
|
||||
raise ValueError(f"Missing required configuration section: {section}")
|
||||
|
||||
# Validate input section
|
||||
if 'video_path' not in self.config['input']:
|
||||
raise ValueError("Missing required field: input.video_path")
|
||||
|
||||
# Validate output section
|
||||
required_output_fields = ['directory', 'filename']
|
||||
for field in required_output_fields:
|
||||
if field not in self.config['output']:
|
||||
raise ValueError(f"Missing required field: output.{field}")
|
||||
|
||||
# Validate models section
|
||||
required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config']
|
||||
for field in required_model_fields:
|
||||
if field not in self.config['models']:
|
||||
raise ValueError(f"Missing required field: models.{field}")
|
||||
|
||||
# Validate processing.detect_segments format
|
||||
detect_segments = self.config['processing'].get('detect_segments', 'all')
|
||||
if not isinstance(detect_segments, (str, list)):
|
||||
raise ValueError("detect_segments must be 'all' or a list of integers")
|
||||
|
||||
if isinstance(detect_segments, list):
|
||||
if not all(isinstance(x, int) for x in detect_segments):
|
||||
raise ValueError("detect_segments list must contain only integers")
|
||||
|
||||
def get(self, key_path: str, default=None):
|
||||
"""
|
||||
Get configuration value using dot notation.
|
||||
|
||||
Args:
|
||||
key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
Configuration value or default
|
||||
"""
|
||||
keys = key_path.split('.')
|
||||
value = self.config
|
||||
|
||||
try:
|
||||
for key in keys:
|
||||
value = value[key]
|
||||
return value
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
def get_input_video_path(self) -> str:
|
||||
"""Get input video path."""
|
||||
return self.config['input']['video_path']
|
||||
|
||||
def get_output_directory(self) -> str:
|
||||
"""Get output directory path."""
|
||||
return self.config['output']['directory']
|
||||
|
||||
def get_output_filename(self) -> str:
|
||||
"""Get output filename."""
|
||||
return self.config['output']['filename']
|
||||
|
||||
def get_segment_duration(self) -> int:
|
||||
"""Get segment duration in seconds."""
|
||||
return self.config['processing'].get('segment_duration', 5)
|
||||
|
||||
def get_inference_scale(self) -> float:
|
||||
"""Get inference scale factor."""
|
||||
return self.config['processing'].get('inference_scale', 0.5)
|
||||
|
||||
def get_yolo_confidence(self) -> float:
|
||||
"""Get YOLO confidence threshold."""
|
||||
return self.config['processing'].get('yolo_confidence', 0.6)
|
||||
|
||||
def get_detect_segments(self) -> Union[str, List[int]]:
|
||||
"""Get segments for YOLO detection."""
|
||||
return self.config['processing'].get('detect_segments', 'all')
|
||||
|
||||
def get_yolo_model_path(self) -> str:
|
||||
"""Get YOLO model path."""
|
||||
return self.config['models']['yolo_model']
|
||||
|
||||
def get_sam2_checkpoint(self) -> str:
|
||||
"""Get SAM2 checkpoint path."""
|
||||
return self.config['models']['sam2_checkpoint']
|
||||
|
||||
def get_sam2_config(self) -> str:
|
||||
"""Get SAM2 config path."""
|
||||
return self.config['models']['sam2_config']
|
||||
|
||||
def get_use_nvenc(self) -> bool:
|
||||
"""Get whether to use NVIDIA encoding."""
|
||||
return self.config.get('video', {}).get('use_nvenc', True)
|
||||
|
||||
def get_preserve_audio(self) -> bool:
|
||||
"""Get whether to preserve audio."""
|
||||
return self.config.get('video', {}).get('preserve_audio', True)
|
||||
|
||||
def get_output_bitrate(self) -> str:
|
||||
"""Get output video bitrate."""
|
||||
return self.config.get('video', {}).get('output_bitrate', '50M')
|
||||
|
||||
def get_green_color(self) -> List[int]:
|
||||
"""Get green screen color."""
|
||||
return self.config.get('advanced', {}).get('green_color', [0, 255, 0])
|
||||
|
||||
def get_blue_color(self) -> List[int]:
|
||||
"""Get blue screen color."""
|
||||
return self.config.get('advanced', {}).get('blue_color', [255, 0, 0])
|
||||
|
||||
def get_human_class_id(self) -> int:
|
||||
"""Get YOLO human class ID."""
|
||||
return self.config.get('advanced', {}).get('human_class_id', 0)
|
||||
|
||||
def get_log_level(self) -> str:
|
||||
"""Get logging level."""
|
||||
return self.config.get('advanced', {}).get('log_level', 'INFO')
|
||||
|
||||
def should_cleanup_intermediate_files(self) -> bool:
|
||||
"""Get whether to cleanup intermediate files."""
|
||||
return self.config.get('advanced', {}).get('cleanup_intermediate_files', True)
|
||||
Reference in New Issue
Block a user