This commit is contained in:
2025-07-29 10:13:29 -07:00
parent 02ad4d87d2
commit 6617acb1c9
2 changed files with 25 additions and 8 deletions

View File

@@ -148,9 +148,13 @@ def download_yolo_models():
"yolov8n.pt", # Detection models "yolov8n.pt", # Detection models
"yolov8s.pt", "yolov8s.pt",
"yolov8m.pt", "yolov8m.pt",
"yolo11l.pt", # YOLOv11 detection models
"yolo11x.pt",
"yolov8n-seg.pt", # Segmentation models "yolov8n-seg.pt", # Segmentation models
"yolov8s-seg.pt", "yolov8s-seg.pt",
"yolov8m-seg.pt" "yolov8m-seg.pt",
"yolo11l-seg.pt", # YOLOv11 segmentation models
"yolo11x-seg.pt"
] ]
models_dir = Path(__file__).parent / "models" / "yolo" models_dir = Path(__file__).parent / "models" / "yolo"
@@ -193,7 +197,11 @@ def download_yolo_models():
if not found: if not found:
# Last resort: use urllib to download directly # Last resort: use urllib to download directly
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}" # Use different release versions for different YOLO versions
if model_name.startswith("yolov11"):
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
else:
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
print(f" Downloading directly from {yolo_url}...") print(f" Downloading directly from {yolo_url}...")
download_file(yolo_url, str(model_path), f"YOLO {model_name}") download_file(yolo_url, str(model_path), f"YOLO {model_name}")
@@ -201,7 +209,11 @@ def download_yolo_models():
print(f" ⚠ Error downloading {model_name}: {e}") print(f" ⚠ Error downloading {model_name}: {e}")
# Try direct download as fallback # Try direct download as fallback
try: try:
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}" # Use different release versions for different YOLO versions
if model_name.startswith("yolov11"):
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
else:
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
print(f" Trying direct download from {yolo_url}...") print(f" Trying direct download from {yolo_url}...")
download_file(yolo_url, str(model_path), f"YOLO {model_name}") download_file(yolo_url, str(model_path), f"YOLO {model_name}")
except Exception as e2: except Exception as e2:
@@ -213,8 +225,8 @@ def download_yolo_models():
success = all((models_dir / model).exists() for model in yolo_models) success = all((models_dir / model).exists() for model in yolo_models)
if success: if success:
print("✓ YOLO models setup complete!") print("✓ YOLO models setup complete!")
print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt") print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt, yolov11l.pt, yolov11x.pt")
print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt") print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt, yolov11l-seg.pt, yolov11x-seg.pt")
else: else:
missing_models = [model for model in yolo_models if not (models_dir / model).exists()] missing_models = [model for model in yolo_models if not (models_dir / model).exists()]
print("⚠ Some YOLO models may be missing:") print("⚠ Some YOLO models may be missing:")

View File

@@ -281,10 +281,12 @@ def main():
vos_optimized=config.get('models.sam2_vos_optimized', False) vos_optimized=config.get('models.sam2_vos_optimized', False)
) )
# Initialize mask processor # Initialize mask processor with quality enhancements
mask_quality_config = config.get('mask_processing', {})
mask_processor = MaskProcessor( mask_processor = MaskProcessor(
green_color=config.get_green_color(), green_color=config.get_green_color(),
blue_color=config.get_blue_color() blue_color=config.get_blue_color(),
mask_quality_config=mask_quality_config
) )
# Process each segment sequentially (YOLO -> SAM2 -> Render) # Process each segment sequentially (YOLO -> SAM2 -> Render)
@@ -296,6 +298,9 @@ def main():
logger.info(f"Processing segment {segment_idx}/{len(segments_info)-1}") logger.info(f"Processing segment {segment_idx}/{len(segments_info)-1}")
# Reset temporal history for new segment
mask_processor.reset_temporal_history()
# Skip if segment output already exists # Skip if segment output already exists
output_video = os.path.join(segment_info['directory'], f"output_{segment_idx}.mp4") output_video = os.path.join(segment_info['directory'], f"output_{segment_idx}.mp4")
if os.path.exists(output_video): if os.path.exists(output_video):