305 lines
12 KiB
Python
Executable File
305 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Model download script for YOLO + SAM2 video processing pipeline.
|
|
Downloads SAM2.1 and YOLOv8 model weights into the models directory and updates config.yaml to point at the local copies.
|
|
"""
|
|
|
|
import os
|
|
import urllib.request
|
|
import urllib.error
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
def create_directory_structure():
    """Ensure the ``models`` directory tree exists next to this script.

    Layout created (existing directories are left as they are)::

        models/
            sam2/
                configs/sam2.1/
                checkpoints/
            yolo/

    Returns:
        Tuple ``(models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir)``
        of ``pathlib.Path`` objects.
    """
    root = Path(__file__).parent / "models"
    sam2_root = root / "sam2"
    configs = sam2_root / "configs" / "sam2.1"
    checkpoints = sam2_root / "checkpoints"
    yolo = root / "yolo"

    # Parents are created first, so only the two-level configs path needs
    # parents=True.
    for directory in (root, sam2_root):
        directory.mkdir(exist_ok=True)
    configs.mkdir(parents=True, exist_ok=True)
    for directory in (checkpoints, yolo):
        directory.mkdir(exist_ok=True)

    print(f"Created models directory structure in: {root}")
    return root, configs, checkpoints, yolo
|
|
|
|
def download_file(url, destination, description="file"):
    """Download ``url`` to ``destination``, printing a percentage progress line.

    The data is first written to a sibling ``<destination>.part`` file and is
    moved into place only after the transfer completes. A failed or
    interrupted download therefore never leaves a truncated file at
    ``destination``. This matters because callers skip files that already
    exist, so a truncated file would otherwise be kept forever.

    Args:
        url: Source URL (any scheme ``urllib.request.urlretrieve`` accepts).
        destination: Target path, as ``str`` or ``pathlib.Path``.
        description: Human-readable name used in status messages.

    Returns:
        True on success, False if the download failed (network error,
        incomplete transfer, or any other exception).
    """
    destination = Path(destination)
    partial = destination.with_name(destination.name + ".part")

    print(f"Downloading {description}...")
    print(f" URL: {url}")
    print(f" Destination: {destination}")

    def progress_hook(block_num, block_size, total_size):
        # total_size is not positive when the size is unknown; show nothing then.
        if total_size > 0:
            percent = min(100, (block_num * block_size * 100) // total_size)
            sys.stdout.write(f"\r Progress: {percent}%")
            sys.stdout.flush()

    try:
        # urlretrieve raises ContentTooShortError (a URLError) on a truncated
        # transfer; either way only the .part file holds incomplete data.
        urllib.request.urlretrieve(url, str(partial), progress_hook)
        # Publish the file only once it is complete.
        os.replace(partial, destination)
    except urllib.error.URLError as e:
        print(f"\n ✗ Failed to download {description}: {e}")
        return False
    except Exception as e:
        print(f"\n ✗ Error downloading {description}: {e}")
        return False
    finally:
        # Remove leftovers after a failure or an interrupt such as Ctrl-C.
        # After a successful download the file has already been renamed.
        if partial.exists():
            partial.unlink()

    print(f"\n ✓ Downloaded {description}")
    return True
|
|
|
|
def download_sam2_models():
    """Fetch the SAM2.1 config YAML and checkpoint for every model size.

    Files already present on disk are not downloaded again, but they still
    count towards the success total.

    Returns:
        True when every config and checkpoint file is present or was
        downloaded, otherwise False.
    """
    print("Setting up SAM2.1 models...")

    # Only the SAM2 sub-directories are needed here.
    _, configs_dir, checkpoints_dir, _ = create_directory_structure()

    config_base = "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/"
    checkpoint_base = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/"
    # (model size, config file name, checkpoint file name)
    variants = (
        ("tiny", "sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
        ("small", "sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
        ("base_plus", "sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt"),
        ("large", "sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt"),
    )

    expected = 2 * len(variants)  # one config + one checkpoint per size
    ok = 0

    for size, config_name, checkpoint_name in variants:
        print(f"\n--- Downloading SAM2.1 {size.upper()} model ---")
        # (url, target path, description suffix, "already exists" label)
        jobs = (
            (config_base + config_name, configs_dir / config_name,
             "config", "Config file"),
            (checkpoint_base + checkpoint_name, checkpoints_dir / checkpoint_name,
             "checkpoint", "Checkpoint file"),
        )
        for url, target, kind, label in jobs:
            if target.exists():
                print(f" ✓ {label} already exists: {target}")
                ok += 1
            elif download_file(url, target, f"SAM2.1 {size} {kind}"):
                ok += 1

    print("\n=== Download Summary ===")
    print(f"Successfully downloaded: {ok}/{expected} files")

    if ok != expected:
        print(f"⚠ Some downloads failed ({expected - ok} files)")
        return False

    print("✓ All SAM2.1 models downloaded successfully!")
    return True
|
|
|
|
def _download_yolo_direct(model_name, model_path, action):
    """Download ``model_name`` from the ultralytics assets v8.2.0 release.

    ``action`` is the phrase printed before the URL. Returns the success flag
    reported by ``download_file``.
    """
    yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
    print(f" {action} {yolo_url}...")
    return download_file(yolo_url, str(model_path), f"YOLO {model_name}")


def _fetch_yolo_with_ultralytics(model_name, model_path, yolo_cls, torch_module):
    """Load ``model_name`` via ultralytics and place its weights at ``model_path``.

    Falls back to a direct download when the weights cannot be located.
    Exceptions raised by ultralytics propagate to the caller.
    """
    import shutil

    cwd_copy = Path.cwd() / model_name
    # Remember whether a file of this name was already in the working
    # directory, so the cleanup below never deletes a pre-existing user file.
    cwd_copy_preexisting = cwd_copy.exists()

    model = yolo_cls(model_name)

    # NOTE(review): ultralytics appears to keep ``model.ckpt`` as a plain dict
    # (no ``save`` attribute), in which case this branch never runs. Confirm.
    if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
        torch_module.save(model.ckpt, str(model_path))
        print(f" ✓ Saved {model_name} to models directory")
        return

    # Common locations where ultralytics might store downloaded weights.
    possible_paths = [
        Path.home() / ".cache" / "ultralytics" / "models" / model_name,
        Path.home() / ".ultralytics" / "models" / model_name,
        Path.home() / "runs" / "detect" / model_name,
        cwd_copy,  # Current directory
    ]
    for possible_path in possible_paths:
        if possible_path.exists():
            shutil.copy2(possible_path, model_path)
            print(f" ✓ Copied {model_name} from {possible_path}")
            # Only remove a CWD copy that this call itself created.
            if possible_path == cwd_copy and not cwd_copy_preexisting:
                possible_path.unlink()
            return

    # Last resort: download directly.
    _download_yolo_direct(model_name, model_path, "Downloading directly from")


def download_yolo_models():
    """Download the default YOLOv8 detection and segmentation weights.

    Weights go to ``models/yolo`` next to this script. For each missing file,
    ultralytics' ``YOLO(name)`` is tried first and the weights are saved or
    copied into place. If that raises, the file is downloaded directly from
    the ultralytics assets release.

    Returns:
        True if every expected weight file exists afterwards. False if any is
        missing, if ``ultralytics``/``torch`` cannot be imported, or on an
        unexpected error.
    """
    print("\n--- Setting up YOLO models ---")
    print(" Downloading both detection and segmentation models...")

    try:
        from ultralytics import YOLO
        import torch

        # Default YOLO models to download (both detection and segmentation)
        yolo_models = [
            "yolov8n.pt",  # Detection models
            "yolov8s.pt",
            "yolov8m.pt",
            "yolov8n-seg.pt",  # Segmentation models
            "yolov8s-seg.pt",
            "yolov8m-seg.pt",
        ]
        models_dir = Path(__file__).parent / "models" / "yolo"
        # Create the target directory so this also works when called on its
        # own, not only after the SAM2 step has built the tree.
        models_dir.mkdir(parents=True, exist_ok=True)

        for model_name in yolo_models:
            model_path = models_dir / model_name
            if model_path.exists():
                print(f" ✓ {model_name} already exists")
                continue

            print(f"Downloading {model_name}...")
            try:
                _fetch_yolo_with_ultralytics(model_name, model_path, YOLO, torch)
            except Exception as e:
                print(f" ⚠ Error downloading {model_name}: {e}")
                # download_file catches its own errors and reports failure via
                # its return value; the existence check below covers it.
                _download_yolo_direct(model_name, model_path, "Trying direct download from")

        # Verify all models exist
        missing_models = [m for m in yolo_models if not (models_dir / m).exists()]
        if missing_models:
            print("⚠ Some YOLO models may be missing:")
            for model in missing_models:
                print(f" - {model}")
            return False

        print("✓ YOLO models setup complete!")
        print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt")
        print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt")
        return True

    except ImportError:
        print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
        return False
    except Exception as e:
        print(f"⚠ Error setting up YOLO models: {e}")
        return False
|
|
|
|
def update_config_file():
    """Point the model paths in config.yaml (next to this script) at ./models.

    This is a plain text substitution of known default values. Keys that
    already hold some other value are left untouched, and running it again is
    a no-op (the file is not rewritten when nothing changes).

    Returns:
        True if config.yaml was updated or already uses the local paths.
        False if it is missing or could not be read/written.
    """
    print("\n--- Updating config.yaml ---")

    config_path = Path(__file__).parent / "config.yaml"
    if not config_path.exists():
        print("⚠ config.yaml not found, skipping update")
        return False

    # (default value, local-model value). Each bare default file name is
    # rewritten to its downloaded copy under models/.
    replacements = [
        ('yolo_model: "yolov8n.pt"',
         'yolo_model: "models/yolo/yolov8n.pt"'),
        ('yolo_detection_model: "yolov8n.pt"',
         'yolo_detection_model: "models/yolo/yolov8n.pt"'),
        ('yolo_segmentation_model: "yolov8n-seg.pt"',
         'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"'),
        ('sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
         'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'),
        ('sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
         'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"'),
    ]

    try:
        # YAML is UTF-8; be explicit rather than relying on the platform default.
        content = config_path.read_text(encoding="utf-8")

        updated_content = content
        for old, new in replacements:
            updated_content = updated_content.replace(old, new)

        if updated_content == content:
            print("✓ config.yaml already uses local model paths")
            return True

        config_path.write_text(updated_content, encoding="utf-8")
        print("✓ Updated config.yaml to use local model paths")
        return True

    except (OSError, UnicodeError) as e:
        print(f"⚠ Error updating config.yaml: {e}")
        return False
|
|
|
|
def main():
    """Run every setup step in order and print a status summary.

    The steps are, in order: SAM2.1 downloads, YOLO downloads, and the
    config.yaml rewrite. The final "Setup complete" banner depends only on
    the SAM2 and config steps; the YOLO result is reported but does not gate
    it.
    """
    rule = "=" * 50
    print("🤖 YOLO + SAM2 Model Download Script")
    print(rule)

    # Dict displays evaluate left to right, so the steps run in this order.
    results = {
        "SAM2 models": download_sam2_models(),
        "YOLO models": download_yolo_models(),
        "Config update": update_config_file(),
    }

    print("\n" + rule)
    print("📋 Final Summary:")
    for label, ok in results.items():
        print(f" {label}: {'✓' if ok else '⚠'}")

    if results["SAM2 models"] and results["Config update"]:
        print("\n🎉 Setup complete! You can now run the pipeline with:")
        print(" python main.py --config config.yaml")
    else:
        print("\n⚠ Some steps failed. Check the output above for details.")

    print("\n📁 Models are organized in:")
    print(f" {Path(__file__).parent / 'models'}")
|
|
|
|
if __name__ == "__main__":
    # Run the full setup only when executed as a script, not when imported.
    main()