stage 1 working
This commit is contained in:
286
download_models.py
Executable file
286
download_models.py
Executable file
@@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Model download script for YOLO + SAM2 video processing pipeline.
|
||||
Downloads SAM2.1 models and organizes them in the models directory.
|
||||
"""
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
def create_directory_structure():
|
||||
"""Create the models directory structure."""
|
||||
base_dir = Path(__file__).parent
|
||||
models_dir = base_dir / "models"
|
||||
|
||||
# Create main models directory
|
||||
models_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create subdirectories
|
||||
sam2_dir = models_dir / "sam2"
|
||||
sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
|
||||
sam2_checkpoints_dir = sam2_dir / "checkpoints"
|
||||
yolo_dir = models_dir / "yolo"
|
||||
|
||||
sam2_dir.mkdir(exist_ok=True)
|
||||
sam2_configs_dir.mkdir(parents=True, exist_ok=True)
|
||||
sam2_checkpoints_dir.mkdir(exist_ok=True)
|
||||
yolo_dir.mkdir(exist_ok=True)
|
||||
|
||||
print(f"Created models directory structure in: {models_dir}")
|
||||
return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir
|
||||
|
||||
def download_file(url, destination, description="file"):
|
||||
"""Download a file with progress indication."""
|
||||
try:
|
||||
print(f"Downloading {description}...")
|
||||
print(f" URL: {url}")
|
||||
print(f" Destination: {destination}")
|
||||
|
||||
def progress_hook(block_num, block_size, total_size):
|
||||
if total_size > 0:
|
||||
percent = min(100, (block_num * block_size * 100) // total_size)
|
||||
sys.stdout.write(f"\r Progress: {percent}%")
|
||||
sys.stdout.flush()
|
||||
|
||||
urllib.request.urlretrieve(url, destination, progress_hook)
|
||||
print(f"\n ✓ Downloaded {description}")
|
||||
return True
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
print(f"\n ✗ Failed to download {description}: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n ✗ Error downloading {description}: {e}")
|
||||
return False
|
||||
|
||||
def download_sam2_models():
|
||||
"""Download SAM2.1 model configurations and checkpoints."""
|
||||
print("Setting up SAM2.1 models...")
|
||||
|
||||
# Create directory structure
|
||||
models_dir, configs_dir, checkpoints_dir, yolo_dir = create_directory_structure()
|
||||
|
||||
# SAM2.1 model definitions
|
||||
sam2_models = {
|
||||
"tiny": {
|
||||
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
|
||||
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
|
||||
"config_file": "sam2.1_hiera_t.yaml",
|
||||
"checkpoint_file": "sam2.1_hiera_tiny.pt"
|
||||
},
|
||||
"small": {
|
||||
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
|
||||
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
|
||||
"config_file": "sam2.1_hiera_s.yaml",
|
||||
"checkpoint_file": "sam2.1_hiera_small.pt"
|
||||
},
|
||||
"base_plus": {
|
||||
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
|
||||
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
|
||||
"config_file": "sam2.1_hiera_b+.yaml",
|
||||
"checkpoint_file": "sam2.1_hiera_base_plus.pt"
|
||||
},
|
||||
"large": {
|
||||
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
|
||||
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
|
||||
"config_file": "sam2.1_hiera_l.yaml",
|
||||
"checkpoint_file": "sam2.1_hiera_large.pt"
|
||||
}
|
||||
}
|
||||
|
||||
success_count = 0
|
||||
total_downloads = len(sam2_models) * 2 # configs + checkpoints
|
||||
|
||||
# Download each model's config and checkpoint
|
||||
for model_name, model_info in sam2_models.items():
|
||||
print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---")
|
||||
|
||||
# Download config file
|
||||
config_path = configs_dir / model_info["config_file"]
|
||||
if not config_path.exists():
|
||||
if download_file(
|
||||
model_info["config_url"],
|
||||
config_path,
|
||||
f"SAM2.1 {model_name} config"
|
||||
):
|
||||
success_count += 1
|
||||
else:
|
||||
print(f" ✓ Config file already exists: {config_path}")
|
||||
success_count += 1
|
||||
|
||||
# Download checkpoint file
|
||||
checkpoint_path = checkpoints_dir / model_info["checkpoint_file"]
|
||||
if not checkpoint_path.exists():
|
||||
if download_file(
|
||||
model_info["checkpoint_url"],
|
||||
checkpoint_path,
|
||||
f"SAM2.1 {model_name} checkpoint"
|
||||
):
|
||||
success_count += 1
|
||||
else:
|
||||
print(f" ✓ Checkpoint file already exists: {checkpoint_path}")
|
||||
success_count += 1
|
||||
|
||||
print(f"\n=== Download Summary ===")
|
||||
print(f"Successfully downloaded: {success_count}/{total_downloads} files")
|
||||
|
||||
if success_count == total_downloads:
|
||||
print("✓ All SAM2.1 models downloaded successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
|
||||
return False
|
||||
|
||||
def download_yolo_models():
|
||||
"""Download default YOLO models to models directory."""
|
||||
print("\n--- Setting up YOLO models ---")
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
# Default YOLO models to download
|
||||
yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
|
||||
models_dir = Path(__file__).parent / "models" / "yolo"
|
||||
|
||||
for model_name in yolo_models:
|
||||
model_path = models_dir / model_name
|
||||
if not model_path.exists():
|
||||
print(f"Downloading {model_name}...")
|
||||
try:
|
||||
# First try to download using the YOLO class with export
|
||||
model = YOLO(model_name)
|
||||
|
||||
# Export/save the model to our directory
|
||||
# The model.ckpt is the internal checkpoint
|
||||
if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
|
||||
# Save the checkpoint directly
|
||||
torch.save(model.ckpt, str(model_path))
|
||||
print(f" ✓ Saved {model_name} to models directory")
|
||||
else:
|
||||
# Alternative: try to find where YOLO downloaded the model
|
||||
import shutil
|
||||
|
||||
# Common locations where YOLO might store models
|
||||
possible_paths = [
|
||||
Path.home() / ".cache" / "ultralytics" / "models" / model_name,
|
||||
Path.home() / ".ultralytics" / "models" / model_name,
|
||||
Path.home() / "runs" / "detect" / model_name,
|
||||
Path.cwd() / model_name, # Current directory
|
||||
]
|
||||
|
||||
found = False
|
||||
for possible_path in possible_paths:
|
||||
if possible_path.exists():
|
||||
shutil.copy2(possible_path, model_path)
|
||||
print(f" ✓ Copied {model_name} from {possible_path}")
|
||||
found = True
|
||||
# Clean up if it was downloaded to current directory
|
||||
if possible_path.parent == Path.cwd() and possible_path != model_path:
|
||||
possible_path.unlink()
|
||||
break
|
||||
|
||||
if not found:
|
||||
# Last resort: use urllib to download directly
|
||||
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
|
||||
print(f" Downloading directly from {yolo_url}...")
|
||||
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠ Error downloading {model_name}: {e}")
|
||||
# Try direct download as fallback
|
||||
try:
|
||||
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
|
||||
print(f" Trying direct download from {yolo_url}...")
|
||||
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
|
||||
except Exception as e2:
|
||||
print(f" ✗ Failed to download {model_name}: {e2}")
|
||||
else:
|
||||
print(f" ✓ {model_name} already exists")
|
||||
|
||||
# Verify all models exist
|
||||
success = all((models_dir / model).exists() for model in yolo_models)
|
||||
if success:
|
||||
print("✓ YOLO models setup complete!")
|
||||
else:
|
||||
print("⚠ Some YOLO models may be missing")
|
||||
return success
|
||||
|
||||
except ImportError:
|
||||
print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠ Error setting up YOLO models: {e}")
|
||||
return False
|
||||
|
||||
def update_config_file():
|
||||
"""Update config.yaml to use local model paths."""
|
||||
print("\n--- Updating config.yaml ---")
|
||||
|
||||
config_path = Path(__file__).parent / "config.yaml"
|
||||
if not config_path.exists():
|
||||
print("⚠ config.yaml not found, skipping update")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read current config
|
||||
with open(config_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Update model paths to use local models
|
||||
updated_content = content.replace(
|
||||
'yolo_model: "yolov8n.pt"',
|
||||
'yolo_model: "models/yolo/yolov8n.pt"'
|
||||
).replace(
|
||||
'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
|
||||
'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'
|
||||
).replace(
|
||||
'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
|
||||
'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"'
|
||||
)
|
||||
|
||||
# Write updated config
|
||||
with open(config_path, 'w') as f:
|
||||
f.write(updated_content)
|
||||
|
||||
print("✓ Updated config.yaml to use local model paths")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠ Error updating config.yaml: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main function to download all models."""
|
||||
print("🤖 YOLO + SAM2 Model Download Script")
|
||||
print("="*50)
|
||||
|
||||
# Download SAM2 models
|
||||
sam2_success = download_sam2_models()
|
||||
|
||||
# Download YOLO models
|
||||
yolo_success = download_yolo_models()
|
||||
|
||||
# Update config file
|
||||
config_success = update_config_file()
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("📋 Final Summary:")
|
||||
print(f" SAM2 models: {'✓' if sam2_success else '⚠'}")
|
||||
print(f" YOLO models: {'✓' if yolo_success else '⚠'}")
|
||||
print(f" Config update: {'✓' if config_success else '⚠'}")
|
||||
|
||||
if sam2_success and config_success:
|
||||
print("\n🎉 Setup complete! You can now run the pipeline with:")
|
||||
print(" python main.py --config config.yaml")
|
||||
else:
|
||||
print("\n⚠ Some steps failed. Check the output above for details.")
|
||||
|
||||
print("\n📁 Models are organized in:")
|
||||
print(f" {Path(__file__).parent / 'models'}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user