Files
samyolo_on_segments/core/config_loader.py
2025-07-31 11:13:31 -07:00

195 lines
8.1 KiB
Python

"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""
import yaml
import os
from typing import Dict, Any, List, Union
import logging
logger = logging.getLogger(__name__)
class ConfigLoader:
    """Load a YAML configuration file and expose validated accessors.

    The file is parsed and validated at construction time. The parsed
    mapping is kept in ``self.config``; the ``get_*`` methods read individual
    settings from it and fall back to built-in defaults for optional keys.
    """

    # Top-level sections that must be present and must be mappings.
    _REQUIRED_SECTIONS = ('input', 'output', 'processing', 'models')
    # Top-level sections that may be omitted or left empty; the accessors
    # then return their defaults.
    _OPTIONAL_SECTIONS = ('video', 'advanced')

    def __init__(self, config_path: str):
        """Load and validate the configuration stored at ``config_path``.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the file is not valid YAML or fails validation.
        """
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Returns:
            Whatever ``yaml.safe_load`` produced; the structure is checked
            afterwards by ``_validate_config``.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file cannot be parsed as YAML.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        try:
            # Explicit encoding so parsing does not depend on the locale.
            with open(self.config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            # Chain the original error so the parser context stays visible.
            raise ValueError(f"Error parsing YAML file: {e}") from e
        logger.info("Loaded configuration from %s", self.config_path)
        return config

    def _validate_config(self) -> None:
        """Validate required configuration fields.

        Raises:
            ValueError: If the root or a section is not a mapping, a required
                section or field is missing, or a value is invalid.
        """
        # An empty YAML file parses to None, and a top-level list or scalar
        # is equally unusable; reject them before any key lookups.
        if not isinstance(self.config, dict):
            raise ValueError(
                "Configuration root must be a mapping of sections, "
                f"got {type(self.config).__name__}"
            )

        for section in self._REQUIRED_SECTIONS:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")
            if not isinstance(self.config[section], dict):
                raise ValueError(f"Configuration section must be a mapping: {section}")

        # Optional sections may be absent or empty (None) but not scalars/lists.
        for section in self._OPTIONAL_SECTIONS:
            value = self.config.get(section)
            if value is not None and not isinstance(value, dict):
                raise ValueError(f"Configuration section must be a mapping: {section}")

        self._require_fields('input', ['video_path'])
        self._require_fields('output', ['directory', 'filename'])
        self._require_fields('models', ['sam2_checkpoint', 'sam2_config'])
        self._validate_yolo_models()
        self._validate_detect_segments()

    def _require_fields(self, section: str, fields: List[str]) -> None:
        """Raise ValueError for the first of ``fields`` missing from ``section``."""
        for field in fields:
            if field not in self.config[section]:
                raise ValueError(f"Missing required field: {section}.{field}")

    def _validate_yolo_models(self) -> None:
        """Check ``yolo_mode`` and that a matching YOLO model is configured."""
        models = self.config['models']
        yolo_mode = models.get('yolo_mode', 'detection')
        if yolo_mode not in ['detection', 'segmentation']:
            raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")

        # Either the legacy single 'yolo_model' key or at least one of the
        # new per-mode keys must be present.
        has_legacy_yolo_model = 'yolo_model' in models
        has_new_yolo_models = 'yolo_detection_model' in models or 'yolo_segmentation_model' in models
        if not has_legacy_yolo_model and not has_new_yolo_models:
            raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")

        # When the new-style keys are used, the key for the active mode
        # must be one of them.
        mode_key = f'yolo_{yolo_mode}_model'
        if has_new_yolo_models and mode_key not in models:
            raise ValueError(f"yolo_mode is '{yolo_mode}' but {mode_key} not specified")

    def _validate_detect_segments(self) -> None:
        """Check that ``processing.detect_segments`` is 'all' or a list of ints."""
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if isinstance(detect_segments, str):
            # 'all' is the only string the error message allows.
            if detect_segments != 'all':
                raise ValueError("detect_segments must be 'all' or a list of integers")
            return
        if not isinstance(detect_segments, list):
            raise ValueError("detect_segments must be 'all' or a list of integers")
        if not all(isinstance(x, int) for x in detect_segments):
            raise ValueError("detect_segments list must contain only integers")

    def _optional_section(self, name: str) -> Dict[str, Any]:
        """Return an optional top-level section, or {} if missing or empty."""
        return self.config.get(name) or {}

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        value = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
        except (KeyError, TypeError):
            # KeyError: key absent; TypeError: an intermediate value is not
            # subscriptable by a string key.
            return default
        return value

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds (default 5)."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor (default 0.5)."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold (default 0.6)."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection: 'all' (default) or a list of ints."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path (legacy method for backward compatibility).

        The legacy 'yolo_model' key wins when present; otherwise the model
        for the configured 'yolo_mode' is returned, with a stock default.
        """
        models = self.config['models']
        if 'yolo_model' in models:
            return models['yolo_model']
        if models.get('yolo_mode', 'detection') == 'detection':
            return models.get('yolo_detection_model', 'yolov8n.pt')
        # Any other mode is treated as segmentation.
        return models.get('yolo_segmentation_model', 'yolov8n-seg.pt')

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding (default True)."""
        return self._optional_section('video').get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio (default True)."""
        return self._optional_section('video').get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate (default '50M')."""
        return self._optional_section('video').get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color (default [0, 255, 0])."""
        return self._optional_section('advanced').get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color (default [255, 0, 0])."""
        # NOTE(review): the default reads as blue only in BGR channel
        # order - confirm against the code that consumes it.
        return self._optional_section('advanced').get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID (default 0)."""
        return self._optional_section('advanced').get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level (default 'INFO')."""
        return self._optional_section('advanced').get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files (default True)."""
        return self._optional_section('advanced').get('cleanup_intermediate_files', True)

    def get_stereo_iou_threshold(self) -> float:
        """Get the IOU threshold for stereo mask agreement (default 0.5)."""
        return self.config['processing'].get('stereo_iou_threshold', 0.5)

    def get_confidence_reduction_factor(self) -> float:
        """Get the factor to reduce YOLO confidence by on retry (default 0.8)."""
        return self.config['processing'].get('confidence_reduction_factor', 0.8)