"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""

# Standard library
import logging
import os
from typing import Any, Dict, List, Union

# Third-party
import yaml

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ConfigLoader:
    """Loads and validates configuration from YAML files.

    The file is read and validated when the instance is constructed; the
    parsed mapping is then available as ``self.config`` and through the
    typed accessor methods below.

    Required sections: ``input``, ``output``, ``processing``, ``models``.
    Optional sections read by the accessors: ``video``, ``advanced``.
    """

    def __init__(self, config_path: str):
        """Load and validate the configuration at *config_path*.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            FileNotFoundError: If *config_path* does not exist.
            ValueError: If the file is not valid YAML or fails validation.
        """
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Returns:
            Whatever ``yaml.safe_load`` produced for the file. The shape is
            checked afterwards by ``_validate_config``.

        Raises:
            FileNotFoundError: If the configuration path does not exist.
            ValueError: If the file cannot be parsed as YAML.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        try:
            # Explicit encoding so parsing does not depend on the platform locale.
            with open(self.config_path, 'r', encoding='utf-8') as file:
                config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            # Chain the parser error so the original traceback is preserved.
            raise ValueError(f"Error parsing YAML file: {e}") from e

        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _require_fields(self, section: str, fields: List[str]) -> None:
        """Raise ValueError if any of *fields* is missing from *section*."""
        for field in fields:
            if field not in self.config[section]:
                raise ValueError(f"Missing required field: {section}.{field}")

    def _optional_section(self, name: str) -> Dict[str, Any]:
        """Return optional top-level section *name*, or ``{}`` if unusable.

        A section that is absent, null (``video:`` with no body) or not a
        mapping is treated as empty so accessors fall back to their defaults.
        """
        section = self.config.get(name)
        return section if isinstance(section, dict) else {}

    def _validate_config(self):
        """Validate required configuration fields.

        Raises:
            ValueError: If the document root or a required section is not a
                mapping, a required section/field is missing, ``yolo_mode``
                is invalid, no YOLO model is configured for the selected
                mode, or ``processing.detect_segments`` is malformed.
        """
        # An empty YAML file parses to None; a top-level list or scalar is
        # equally unusable. Report it instead of failing with a TypeError.
        if not isinstance(self.config, dict):
            raise ValueError(
                "Configuration root must be a mapping of sections, "
                f"got {type(self.config).__name__}"
            )

        required_sections = ['input', 'output', 'processing', 'models']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")

        # A section written as a bare key (null) or a scalar cannot be
        # queried for fields; reject it explicitly.
        for section in required_sections:
            if not isinstance(self.config[section], dict):
                raise ValueError(f"Configuration section must be a mapping: {section}")

        self._require_fields('input', ['video_path'])
        self._require_fields('output', ['directory', 'filename'])
        self._require_fields('models', ['sam2_checkpoint', 'sam2_config'])

        # Validate YOLO model configuration
        models = self.config['models']
        yolo_mode = models.get('yolo_mode', 'detection')
        if yolo_mode not in ('detection', 'segmentation'):
            raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")

        # Check for legacy yolo_model field vs new structure
        has_legacy_yolo_model = 'yolo_model' in models
        has_new_yolo_models = (
            'yolo_detection_model' in models
            or 'yolo_segmentation_model' in models
        )
        if not has_legacy_yolo_model and not has_new_yolo_models:
            raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")

        # When the new-style keys are used, the one matching the current
        # mode must be present.
        mode_model_key = f"yolo_{yolo_mode}_model"
        if has_new_yolo_models and mode_model_key not in models:
            raise ValueError(f"yolo_mode is '{yolo_mode}' but {mode_model_key} not specified")

        # Validate processing.detect_segments format. Any string is accepted
        # here (not only 'all'), matching the previous behaviour.
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if not isinstance(detect_segments, (str, list)):
            raise ValueError("detect_segments must be 'all' or a list of integers")

        if isinstance(detect_segments, list):
            # bool is a subclass of int; exclude it so `[true]` is rejected.
            if not all(
                isinstance(x, int) and not isinstance(x, bool)
                for x in detect_segments
            ):
                raise ValueError("detect_segments list must contain only integers")

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        value = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
        except (KeyError, TypeError):
            # Missing key, or an intermediate value that is not subscriptable
            # by this key.
            return default
        return value

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds (default 5)."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor (default 0.5)."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold (default 0.6)."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection (default 'all')."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path (legacy method for backward compatibility).

        The legacy ``models.yolo_model`` value wins when present. Otherwise
        the model for the configured ``yolo_mode`` is returned, falling back
        to 'yolov8n.pt' (detection) or 'yolov8n-seg.pt' (segmentation).
        """
        models = self.config['models']

        # Check for legacy configuration first
        if 'yolo_model' in models:
            return models['yolo_model']

        # Use new configuration based on mode
        if models.get('yolo_mode', 'detection') == 'detection':
            return models.get('yolo_detection_model', 'yolov8n.pt')
        # segmentation mode
        return models.get('yolo_segmentation_model', 'yolov8n-seg.pt')

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding (default True)."""
        return self._optional_section('video').get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio (default True)."""
        return self._optional_section('video').get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate (default '50M')."""
        return self._optional_section('video').get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color (default [0, 255, 0])."""
        return self._optional_section('advanced').get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color (default [255, 0, 0])."""
        return self._optional_section('advanced').get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID (default 0)."""
        return self._optional_section('advanced').get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level (default 'INFO')."""
        return self._optional_section('advanced').get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files (default True)."""
        return self._optional_section('advanced').get('cleanup_intermediate_files', True)

    def get_stereo_iou_threshold(self) -> float:
        """Get the IOU threshold for stereo mask agreement (default 0.5)."""
        return self.config['processing'].get('stereo_iou_threshold', 0.5)

    def get_confidence_reduction_factor(self) -> float:
        """Get the factor to reduce YOLO confidence by on retry (default 0.8)."""
        return self.config['processing'].get('confidence_reduction_factor', 0.8)