working with segemntation
This commit is contained in:
@@ -50,11 +50,31 @@ class ConfigLoader:
|
||||
raise ValueError(f"Missing required field: output.{field}")
|
||||
|
||||
# Validate models section
|
||||
required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config']
|
||||
required_model_fields = ['sam2_checkpoint', 'sam2_config']
|
||||
for field in required_model_fields:
|
||||
if field not in self.config['models']:
|
||||
raise ValueError(f"Missing required field: models.{field}")
|
||||
|
||||
# Validate YOLO model configuration
|
||||
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
|
||||
if yolo_mode not in ['detection', 'segmentation']:
|
||||
raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")
|
||||
|
||||
# Check for legacy yolo_model field vs new structure
|
||||
has_legacy_yolo_model = 'yolo_model' in self.config['models']
|
||||
has_new_yolo_models = 'yolo_detection_model' in self.config['models'] or 'yolo_segmentation_model' in self.config['models']
|
||||
|
||||
if not has_legacy_yolo_model and not has_new_yolo_models:
|
||||
raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")
|
||||
|
||||
# Validate that the required model for the current mode exists
|
||||
if yolo_mode == 'detection':
|
||||
if has_new_yolo_models and 'yolo_detection_model' not in self.config['models']:
|
||||
raise ValueError("yolo_mode is 'detection' but yolo_detection_model not specified")
|
||||
elif yolo_mode == 'segmentation':
|
||||
if has_new_yolo_models and 'yolo_segmentation_model' not in self.config['models']:
|
||||
raise ValueError("yolo_mode is 'segmentation' but yolo_segmentation_model not specified")
|
||||
|
||||
# Validate processing.detect_segments format
|
||||
detect_segments = self.config['processing'].get('detect_segments', 'all')
|
||||
if not isinstance(detect_segments, (str, list)):
|
||||
@@ -114,8 +134,17 @@ class ConfigLoader:
|
||||
return self.config['processing'].get('detect_segments', 'all')
|
||||
|
||||
def get_yolo_model_path(self) -> str:
|
||||
"""Get YOLO model path."""
|
||||
return self.config['models']['yolo_model']
|
||||
"""Get YOLO model path (legacy method for backward compatibility)."""
|
||||
# Check for legacy configuration first
|
||||
if 'yolo_model' in self.config['models']:
|
||||
return self.config['models']['yolo_model']
|
||||
|
||||
# Use new configuration based on mode
|
||||
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
|
||||
if yolo_mode == 'detection':
|
||||
return self.config['models'].get('yolo_detection_model', 'yolov8n.pt')
|
||||
else: # segmentation mode
|
||||
return self.config['models'].get('yolo_segmentation_model', 'yolov8n-seg.pt')
|
||||
|
||||
def get_sam2_checkpoint(self) -> str:
|
||||
"""Get SAM2 checkpoint path."""
|
||||
|
||||
Reference in New Issue
Block a user