#!/usr/bin/env python3
"""
Model download script for YOLO + SAM2 video processing pipeline.

Downloads the SAM2.1 configs/checkpoints and the default YOLOv8 weights into a
``models`` directory next to this script, then points ``config.yaml`` at the
local copies.
"""

import os
import shutil
import sys
import urllib.error
import urllib.request
from pathlib import Path

# Base URLs the SAM2.1 config and checkpoint file names are appended to.
_SAM2_CONFIG_BASE_URL = (
    "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1"
)
_SAM2_CHECKPOINT_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824"

# model size -> (config file name, checkpoint file name)
_SAM2_MODELS = {
    "tiny": ("sam2.1_hiera_t.yaml", "sam2.1_hiera_tiny.pt"),
    "small": ("sam2.1_hiera_s.yaml", "sam2.1_hiera_small.pt"),
    "base_plus": ("sam2.1_hiera_b+.yaml", "sam2.1_hiera_base_plus.pt"),
    "large": ("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt"),
}

# Default YOLO weights: three detection models followed by three segmentation models.
_YOLO_MODELS = (
    "yolov8n.pt",
    "yolov8s.pt",
    "yolov8m.pt",
    "yolov8n-seg.pt",
    "yolov8s-seg.pt",
    "yolov8m-seg.pt",
)

# Release the weights are fetched from when the ultralytics download cannot be located.
_YOLO_RELEASE_URL = "https://github.com/ultralytics/assets/releases/download/v8.2.0"

# (text to find, replacement) pairs applied to config.yaml, in order.
_CONFIG_PATH_REWRITES = (
    ('yolo_model: "yolov8n.pt"', 'yolo_model: "models/yolo/yolov8n.pt"'),
    (
        'yolo_detection_model: "yolov8n.pt"',
        'yolo_detection_model: "models/yolo/yolov8n.pt"',
    ),
    (
        'yolo_segmentation_model: "yolov8n-seg.pt"',
        'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"',
    ),
    (
        'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
        'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"',
    ),
    (
        'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
        'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"',
    ),
)


def create_directory_structure(base_dir=None):
    """Create the models directory structure.

    Args:
        base_dir: Directory that will contain ``models``. Defaults to the
            directory holding this script.

    Returns:
        A tuple ``(models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir)``
        of ``Path`` objects; all four directories exist on return.
    """
    root = Path(__file__).parent if base_dir is None else Path(base_dir)
    models_dir = root / "models"

    sam2_dir = models_dir / "sam2"
    sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
    sam2_checkpoints_dir = sam2_dir / "checkpoints"
    yolo_dir = models_dir / "yolo"

    # The three leaves cover every intermediate directory (models/, sam2/, configs/).
    for directory in (sam2_configs_dir, sam2_checkpoints_dir, yolo_dir):
        directory.mkdir(parents=True, exist_ok=True)

    print(f"Created models directory structure in: {models_dir}")
    return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir


def download_file(url, destination, description="file"):
    """Download a file with progress indication.

    The payload is written to a sibling ``<name>.part`` file and only moved to
    ``destination`` once the transfer has finished, so a failed or interrupted
    download never leaves a truncated file that later ``exists()`` checks
    would mistake for a complete one.

    Args:
        url: Source URL.
        destination: Target path (``str`` or ``Path``).
        description: Human-readable label used in progress messages.

    Returns:
        True if the file was downloaded and moved into place, False otherwise.
        Errors are reported on stdout rather than raised.
    """
    target = Path(destination)
    partial = target.with_name(target.name + ".part")

    def progress_hook(block_num, block_size, total_size):
        # total_size is not positive when the server sends no Content-Length.
        if total_size > 0:
            percent = min(100, (block_num * block_size * 100) // total_size)
            sys.stdout.write(f"\r Progress: {percent}%")
            sys.stdout.flush()

    try:
        print(f"Downloading {description}...")
        print(f" URL: {url}")
        print(f" Destination: {destination}")

        target.parent.mkdir(parents=True, exist_ok=True)
        urllib.request.urlretrieve(url, str(partial), progress_hook)
        # Same directory, so this is a rename rather than a copy.
        os.replace(str(partial), str(target))
    except urllib.error.URLError as e:
        print(f"\n ✗ Failed to download {description}: {e}")
        return False
    except Exception as e:
        print(f"\n ✗ Error downloading {description}: {e}")
        return False
    else:
        print(f"\n ✓ Downloaded {description}")
        return True
    finally:
        # Also runs on KeyboardInterrupt; after a successful rename there is
        # nothing left to remove and the error is ignored.
        try:
            partial.unlink()
        except OSError:
            pass


def _ensure_downloaded(url, path, description, label):
    """Return True if ``path`` already exists or was downloaded from ``url``."""
    if path.exists():
        print(f" ✓ {label} file already exists: {path}")
        return True
    return download_file(url, path, description)


def download_sam2_models():
    """Download SAM2.1 model configurations and checkpoints.

    Files that already exist are counted as successes and not fetched again.

    Returns:
        True if every config and checkpoint is present afterwards, else False.
    """
    print("Setting up SAM2.1 models...")

    _, configs_dir, checkpoints_dir, _ = create_directory_structure()

    success_count = 0
    total_downloads = len(_SAM2_MODELS) * 2  # configs + checkpoints

    for model_name, (config_file, checkpoint_file) in _SAM2_MODELS.items():
        print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---")

        if _ensure_downloaded(
            f"{_SAM2_CONFIG_BASE_URL}/{config_file}",
            configs_dir / config_file,
            f"SAM2.1 {model_name} config",
            "Config",
        ):
            success_count += 1

        if _ensure_downloaded(
            f"{_SAM2_CHECKPOINT_BASE_URL}/{checkpoint_file}",
            checkpoints_dir / checkpoint_file,
            f"SAM2.1 {model_name} checkpoint",
            "Checkpoint",
        ):
            success_count += 1

    print("\n=== Download Summary ===")
    print(f"Successfully downloaded: {success_count}/{total_downloads} files")

    if success_count == total_downloads:
        print("✓ All SAM2.1 models downloaded successfully!")
        return True

    print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
    return False


def _fetch_yolo_weights(model_name, model_path, yolo_cls, torch_module):
    """Place the weights for ``model_name`` at ``model_path``.

    Tries, in order: saving the checkpoint held by the instantiated model,
    copying the file from a location ultralytics may have downloaded it to,
    and a direct download from the ultralytics assets release. Nothing is
    returned; the caller verifies the result by checking ``model_path``.
    """
    direct_url = f"{_YOLO_RELEASE_URL}/{model_name}"
    try:
        # Instantiating the model is what triggers the ultralytics download.
        model = yolo_cls(model_name)

        # NOTE(review): in current ultralytics releases `ckpt` appears to be a
        # plain dict (no `save` attribute), in which case this branch is
        # skipped and the copy below is used instead -- confirm.
        if hasattr(model, "ckpt") and hasattr(model.ckpt, "save"):
            torch_module.save(model.ckpt, str(model_path))
            print(f" ✓ Saved {model_name} to models directory")
            return

        candidates = (
            Path.home() / ".cache" / "ultralytics" / "models" / model_name,
            Path.home() / ".ultralytics" / "models" / model_name,
            Path.home() / "runs" / "detect" / model_name,
            Path.cwd() / model_name,  # Current directory
        )
        for candidate in candidates:
            if not candidate.exists():
                continue
            shutil.copy2(candidate, model_path)
            print(f" ✓ Copied {model_name} from {candidate}")
            # Do not leave a stray copy behind in the working directory.
            if candidate.parent == Path.cwd() and candidate != model_path:
                candidate.unlink()
            return

        # Last resort: fetch the release asset ourselves.
        print(f" Downloading directly from {direct_url}...")
        download_file(direct_url, str(model_path), f"YOLO {model_name}")
    except Exception as e:
        print(f" ⚠ Error downloading {model_name}: {e}")
        # download_file reports its own failures and does not raise them.
        print(f" Trying direct download from {direct_url}...")
        download_file(direct_url, str(model_path), f"YOLO {model_name}")


def download_yolo_models():
    """Download default YOLO models to models directory.

    Requires ``ultralytics`` and ``torch`` to be importable; when they are
    not, nothing is downloaded and False is returned.

    Returns:
        True if every default model file exists in ``models/yolo`` afterwards.
    """
    print("\n--- Setting up YOLO models ---")
    print(" Downloading both detection and segmentation models...")

    try:
        from ultralytics import YOLO
        import torch

        models_dir = Path(__file__).parent / "models" / "yolo"

        for model_name in _YOLO_MODELS:
            model_path = models_dir / model_name
            if model_path.exists():
                print(f" ✓ {model_name} already exists")
                continue
            print(f"Downloading {model_name}...")
            _fetch_yolo_weights(model_name, model_path, YOLO, torch)

        # Verify by presence on disk rather than trusting each step's outcome.
        missing_models = [
            name for name in _YOLO_MODELS if not (models_dir / name).exists()
        ]
        success = not missing_models

        if success:
            print("✓ YOLO models setup complete!")
            print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt")
            print(
                " Available segmentation models: "
                "yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt"
            )
        else:
            print("⚠ Some YOLO models may be missing:")
            for name in missing_models:
                print(f" - {name}")

        return success

    except ImportError:
        print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
        return False
    except Exception as e:
        print(f"⚠ Error setting up YOLO models: {e}")
        return False


def update_config_file(config_path=None):
    """Update config.yaml to use local model paths.

    Applies the literal text substitutions in ``_CONFIG_PATH_REWRITES``. The
    substitutions are idempotent, and the file is only rewritten when its
    content actually changes.

    Args:
        config_path: Config file to update. Defaults to ``config.yaml`` next
            to this script.

    Returns:
        True if the file was read (and rewritten if needed), False if it does
        not exist or could not be processed.
    """
    print("\n--- Updating config.yaml ---")

    if config_path is None:
        config_path = Path(__file__).parent / "config.yaml"
    config_path = Path(config_path)

    if not config_path.exists():
        print("⚠ config.yaml not found, skipping update")
        return False

    try:
        # newline="" keeps the file's existing line endings untouched.
        with open(config_path, "r", encoding="utf-8", newline="") as f:
            content = f.read()

        updated_content = content
        for old, new in _CONFIG_PATH_REWRITES:
            updated_content = updated_content.replace(old, new)

        if updated_content != content:
            with open(config_path, "w", encoding="utf-8", newline="") as f:
                f.write(updated_content)

        print("✓ Updated config.yaml to use local model paths")
        return True

    except Exception as e:
        print(f"⚠ Error updating config.yaml: {e}")
        return False


def main():
    """Main function to download all models."""
    print("🤖 YOLO + SAM2 Model Download Script")
    print("=" * 50)

    sam2_success = download_sam2_models()
    yolo_success = download_yolo_models()
    config_success = update_config_file()

    print("\n" + "=" * 50)
    print("📋 Final Summary:")
    print(f" SAM2 models: {'✓' if sam2_success else '⚠'}")
    print(f" YOLO models: {'✓' if yolo_success else '⚠'}")
    print(f" Config update: {'✓' if config_success else '⚠'}")

    # The YOLO result is reported above but does not gate the success message.
    if sam2_success and config_success:
        print("\n🎉 Setup complete! You can now run the pipeline with:")
        print(" python main.py --config config.yaml")
    else:
        print("\n⚠ Some steps failed. Check the output above for details.")

    print("\n📁 Models are organized in:")
    print(f" {Path(__file__).parent / 'models'}")


if __name__ == "__main__":
    main()