Files
samyolo_on_segments/download_models.py
2025-07-27 12:11:36 -07:00

286 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Model download script for YOLO + SAM2 video processing pipeline.
Downloads SAM2.1 models and organizes them in the models directory.
"""
import os
import urllib.request
import urllib.error
from pathlib import Path
import sys
def create_directory_structure(base_dir=None):
    """Create the directory layout that holds the SAM2 and YOLO model files.

    The layout is::

        <base_dir>/models/
            sam2/configs/sam2.1/
            sam2/checkpoints/
            yolo/

    Args:
        base_dir: Directory under which ``models/`` is created. When omitted,
            the directory containing this script is used.

    Returns:
        A tuple ``(models_dir, sam2_configs_dir, sam2_checkpoints_dir,
        yolo_dir)`` of ``Path`` objects. All four directories exist on return.
    """
    base_dir = Path(__file__).parent if base_dir is None else Path(base_dir)

    models_dir = base_dir / "models"
    sam2_dir = models_dir / "sam2"
    sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
    sam2_checkpoints_dir = sam2_dir / "checkpoints"
    yolo_dir = models_dir / "yolo"

    # Creating each leaf with parents=True also creates models/ and sam2/,
    # so the result does not depend on the order of the mkdir calls and the
    # function is safe to call repeatedly.
    for directory in (sam2_configs_dir, sam2_checkpoints_dir, yolo_dir):
        directory.mkdir(parents=True, exist_ok=True)

    print(f"Created models directory structure in: {models_dir}")
    return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir
def download_file(url, destination, description="file"):
    """Download ``url`` to ``destination``, printing a progress percentage.

    The data is first written to a sibling ``<name>.part`` file and only
    moved onto ``destination`` once the transfer has completed. A failed or
    truncated transfer therefore never leaves a partial file at
    ``destination`` that a later ``exists()`` check would mistake for a
    finished download.

    Args:
        url: Source URL to fetch.
        destination: Target path (``str`` or ``Path``) for the finished file.
        description: Human-readable label used in the printed messages.

    Returns:
        True if the file was downloaded and moved into place, False if any
        error occurred (the error is printed, not raised).
    """
    destination = Path(destination)
    partial = destination.with_name(destination.name + ".part")

    def progress_hook(block_num, block_size, total_size):
        # total_size is not positive when the server reports no length; in
        # that case no percentage can be computed, so nothing is printed.
        if total_size > 0:
            percent = min(100, (block_num * block_size * 100) // total_size)
            sys.stdout.write(f"\r Progress: {percent}%")
            sys.stdout.flush()

    try:
        print(f"Downloading {description}...")
        print(f" URL: {url}")
        print(f" Destination: {destination}")
        urllib.request.urlretrieve(url, partial, progress_hook)
        # Publish the file only after the whole transfer succeeded.
        os.replace(partial, destination)
        print(f"\n ✓ Downloaded {description}")
        return True
    except urllib.error.URLError as e:
        print(f"\n ✗ Failed to download {description}: {e}")
        return False
    except Exception as e:
        print(f"\n ✗ Error downloading {description}: {e}")
        return False
    finally:
        # Remove a leftover partial file on any failure (including Ctrl-C).
        # After a successful os.replace it no longer exists, which is fine.
        try:
            partial.unlink()
        except OSError:
            pass
def download_sam2_models():
    """Fetch the config and checkpoint of every SAM2.1 model size.

    The models directory structure is created first. For each model size
    (tiny, small, base_plus, large) the YAML config and the ``.pt``
    checkpoint are downloaded unless a file with that name already exists.
    A file that already exists counts as a success.

    Returns:
        True when every config and checkpoint is present or was downloaded,
        False when at least one download failed.
    """
    print("Setting up SAM2.1 models...")

    # Only the config and checkpoint directories are needed here.
    _, configs_dir, checkpoints_dir, _ = create_directory_structure()

    sam2_models = {
        "tiny": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
            "config_file": "sam2.1_hiera_t.yaml",
            "checkpoint_file": "sam2.1_hiera_tiny.pt",
        },
        "small": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
            "config_file": "sam2.1_hiera_s.yaml",
            "checkpoint_file": "sam2.1_hiera_small.pt",
        },
        "base_plus": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
            "config_file": "sam2.1_hiera_b+.yaml",
            "checkpoint_file": "sam2.1_hiera_base_plus.pt",
        },
        "large": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "config_file": "sam2.1_hiera_l.yaml",
            "checkpoint_file": "sam2.1_hiera_large.pt",
        },
    }

    # (message title, description word, url key, file-name key, target dir);
    # the config is handled before the checkpoint for each model.
    artifacts = (
        ("Config", "config", "config_url", "config_file", configs_dir),
        ("Checkpoint", "checkpoint", "checkpoint_url", "checkpoint_file", checkpoints_dir),
    )

    total_downloads = len(sam2_models) * len(artifacts)
    success_count = 0

    for size, spec in sam2_models.items():
        print(f"\n--- Downloading SAM2.1 {size.upper()} model ---")
        for title, kind, url_key, file_key, target_dir in artifacts:
            target = target_dir / spec[file_key]
            if target.exists():
                print(f" ✓ {title} file already exists: {target}")
                success_count += 1
            elif download_file(spec[url_key], target, f"SAM2.1 {size} {kind}"):
                success_count += 1

    print("\n=== Download Summary ===")
    print(f"Successfully downloaded: {success_count}/{total_downloads} files")

    if success_count != total_downloads:
        print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
        return False

    print("✓ All SAM2.1 models downloaded successfully!")
    return True
def download_yolo_models():
    """Place the default YOLOv8 weight files in ``models/yolo``.

    For each of ``yolov8n.pt``, ``yolov8s.pt`` and ``yolov8m.pt`` that is not
    already present, the function tries in order:

    1. instantiate ``YOLO(name)`` and save its checkpoint, if the checkpoint
       object exposes a ``save`` attribute;
    2. copy the weights from one of several candidate locations;
    3. download the file directly from the ultralytics assets release.

    If step 1 or 2 raises, the direct download is attempted as a fallback.

    Returns:
        True if all three weight files exist in ``models/yolo`` afterwards,
        False otherwise (including when ultralytics/torch cannot be imported).
    """
    print("\n--- Setting up YOLO models ---")
    try:
        from ultralytics import YOLO
        import torch
        import shutil

        # Default YOLO models to download
        yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
        models_dir = Path(__file__).parent / "models" / "yolo"
        # Do not rely on another function having created the directory.
        models_dir.mkdir(parents=True, exist_ok=True)

        for model_name in yolo_models:
            model_path = models_dir / model_name
            if model_path.exists():
                print(f"{model_name} already exists")
                continue

            print(f"Downloading {model_name}...")
            yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"

            # Candidate locations to look for the weights after YOLO() ran.
            # NOTE(review): these are guesses at where ultralytics stores
            # downloaded weights — confirm against the installed version.
            possible_paths = [
                Path.home() / ".cache" / "ultralytics" / "models" / model_name,
                Path.home() / ".ultralytics" / "models" / model_name,
                Path.home() / "runs" / "detect" / model_name,
                Path.cwd() / model_name,  # Current directory
            ]
            # Remember which candidates existed before YOLO() ran so that a
            # file the user already had is never deleted during clean-up.
            preexisting = {p for p in possible_paths if p.exists()}

            try:
                model = YOLO(model_name)

                if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
                    # Save the checkpoint directly
                    torch.save(model.ckpt, str(model_path))
                    print(f" ✓ Saved {model_name} to models directory")
                    continue

                source = next((p for p in possible_paths if p.exists()), None)
                if source is not None:
                    shutil.copy2(source, model_path)
                    print(f" ✓ Copied {model_name} from {source}")
                    # Clean up only a copy this run dropped into the
                    # current directory.
                    if (
                        source.parent == Path.cwd()
                        and source != model_path
                        and source not in preexisting
                    ):
                        source.unlink()
                else:
                    # Last resort: use urllib to download directly
                    print(f" Downloading directly from {yolo_url}...")
                    download_file(yolo_url, str(model_path), f"YOLO {model_name}")
            except Exception as e:
                print(f" ⚠ Error downloading {model_name}: {e}")
                # Try direct download as fallback
                try:
                    print(f" Trying direct download from {yolo_url}...")
                    download_file(yolo_url, str(model_path), f"YOLO {model_name}")
                except Exception as e2:
                    print(f" ✗ Failed to download {model_name}: {e2}")

        # Success is judged by the files on disk, not by the return values
        # of the individual attempts above.
        success = all((models_dir / model).exists() for model in yolo_models)
        if success:
            print("✓ YOLO models setup complete!")
        else:
            print("⚠ Some YOLO models may be missing")
        return success
    except ImportError:
        print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
        return False
    except Exception as e:
        print(f"⚠ Error setting up YOLO models: {e}")
        return False
def update_config_file(config_path=None):
    """Rewrite model paths in config.yaml so they point at the local models.

    Three exact-text substitutions are applied: the YOLO model path, the SAM2
    checkpoint path and the SAM2 config path. Lines that do not match the
    expected text are left alone, so running the function again is a no-op.

    Args:
        config_path: Path of the config file to update. When omitted,
            ``config.yaml`` next to this script is used.

    Returns:
        True if the file was read (and rewritten when anything changed),
        False if it does not exist or an error occurred.
    """
    print("\n--- Updating config.yaml ---")
    if config_path is None:
        config_path = Path(__file__).parent / "config.yaml"
    else:
        config_path = Path(config_path)

    if not config_path.exists():
        print("⚠ config.yaml not found, skipping update")
        return False

    # (text to find, replacement) pairs, applied in order.
    replacements = (
        (
            'yolo_model: "yolov8n.pt"',
            'yolo_model: "models/yolo/yolov8n.pt"',
        ),
        (
            'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
            'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"',
        ),
        (
            'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
            'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"',
        ),
    )

    try:
        # Explicit UTF-8 keeps non-ASCII text in the config intact regardless
        # of the platform's default encoding.
        content = config_path.read_text(encoding="utf-8")

        updated_content = content
        for old, new in replacements:
            updated_content = updated_content.replace(old, new)

        if updated_content == content:
            # Nothing matched; leave the file untouched.
            print("✓ config.yaml required no changes")
            return True

        config_path.write_text(updated_content, encoding="utf-8")
        print("✓ Updated config.yaml to use local model paths")
        return True
    except Exception as e:
        print(f"⚠ Error updating config.yaml: {e}")
        return False
def main():
    """Download all models, update the config and print a final summary.

    Runs the SAM2 download, the YOLO download and the config update in that
    order. The run counts as complete when the SAM2 download and the config
    update succeeded; the YOLO result is reported but does not gate the
    success message.
    """
    def mark(ok):
        # Status marker for the summary, using the same symbols as the
        # progress messages printed elsewhere in this script.
        return "✓" if ok else "✗"

    print("🤖 YOLO + SAM2 Model Download Script")
    print("=" * 50)

    # Download SAM2 models
    sam2_success = download_sam2_models()
    # Download YOLO models
    yolo_success = download_yolo_models()
    # Update config file
    config_success = update_config_file()

    print("\n" + "=" * 50)
    print("📋 Final Summary:")
    print(f" SAM2 models: {mark(sam2_success)}")
    print(f" YOLO models: {mark(yolo_success)}")
    print(f" Config update: {mark(config_success)}")

    if sam2_success and config_success:
        print("\n🎉 Setup complete! You can now run the pipeline with:")
        print(" python main.py --config config.yaml")
    else:
        print("\n⚠ Some steps failed. Check the output above for details.")

    print("\n📁 Models are organized in:")
    print(f" {Path(__file__).parent / 'models'}")
# Run the full download workflow only when executed as a script, not when
# this module is imported.
if __name__ == "__main__":
    main()