From 032ea9da4b7686a9f35e65f09c6d07e5395172c3 Mon Sep 17 00:00:00 2001 From: Scott Register Date: Sat, 26 Jul 2025 08:14:35 -0700 Subject: [PATCH] install sam2 the way facebook says --- config_example.yaml | 4 +- config_runpod.yaml | 2 +- download_sam2_configs.sh | 17 ------- runpod_setup.sh | 57 ++++++++++------------- test_installation.py | 85 ++++++++++++++++++++++++----------- vr180_matting/config.py | 4 +- vr180_matting/main.py | 2 + vr180_matting/sam2_wrapper.py | 21 +++++---- 8 files changed, 102 insertions(+), 90 deletions(-) delete mode 100644 download_sam2_configs.sh diff --git a/config_example.yaml b/config_example.yaml index c5dc48c..079ca6c 100644 --- a/config_example.yaml +++ b/config_example.yaml @@ -14,8 +14,8 @@ matting: use_disparity_mapping: true memory_offload: true fp16: true - sam2_model_cfg: "sam2_hiera_l" - sam2_checkpoint: "models/sam2_hiera_large.pt" + sam2_model_cfg: "sam2.1_hiera_l" + sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" output: path: "path/to/output/" diff --git a/config_runpod.yaml b/config_runpod.yaml index 27dacd9..4f8f010 100644 --- a/config_runpod.yaml +++ b/config_runpod.yaml @@ -15,7 +15,7 @@ matting: memory_offload: false # A40 has enough VRAM fp16: true sam2_model_cfg: "sam2.1_hiera_l" - sam2_checkpoint: "models/sam2.1_hiera_large.pt" + sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" output: path: "/workspace/output/matted_video.mp4" diff --git a/download_sam2_configs.sh b/download_sam2_configs.sh deleted file mode 100644 index 984f66a..0000000 --- a/download_sam2_configs.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash -# Download SAM2 model configuration files - -echo "๐Ÿ“ฅ Downloading SAM2 configuration files..." - -mkdir -p sam2_configs -cd sam2_configs - -# Download SAM2 config files -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml - -cd .. - -echo "โœ… SAM2 configs downloaded to sam2_configs/" \ No newline at end of file diff --git a/runpod_setup.sh b/runpod_setup.sh index 54a409d..8fc460b 100644 --- a/runpod_setup.sh +++ b/runpod_setup.sh @@ -29,36 +29,22 @@ mkdir -p models # Download YOLOv8 models python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')" -# Download SAM2 checkpoints -cd models -echo "๐Ÿ“ฅ Downloading SAM2 models..." - -# Try different SAM2 model versions -if [ ! -f "sam2_hiera_large.pt" ]; then - echo "Trying SAM2 checkpoint version 1..." - wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt || true +# Clone SAM2 repo for checkpoints +echo "๐Ÿ“ฅ Cloning SAM2 for model checkpoints..." +if [ ! -d "segment-anything-2" ]; then + git clone https://github.com/facebookresearch/segment-anything-2.git fi +# Download SAM2 checkpoints using their official script +cd segment-anything-2 +mkdir -p checkpoints +cd checkpoints if [ ! -f "sam2.1_hiera_large.pt" ]; then - echo "Trying SAM2.1 checkpoint (latest)..." - wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt || true + echo "๐Ÿ“ฅ Downloading SAM2 checkpoints..." + chmod +x ../download_ckpts.sh 2>/dev/null || true + bash ../download_ckpts.sh || bash <(curl -s https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/download_ckpts.sh) fi - -# Download SAM2 config files -cd .. -mkdir -p sam2_configs -cd sam2_configs -echo "๐Ÿ“ฅ Downloading SAM2 configuration files..." -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_b+.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_l.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_s.yaml -wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_t.yaml - -cd .. +cd ../.. # Create working directories mkdir -p /workspace/data /workspace/output @@ -68,16 +54,21 @@ echo "" echo "๐Ÿงช Testing installation..." python test_installation.py -# Check which SAM2 model is available +# Check which SAM2 models are available echo "" echo "๐Ÿ“Š SAM2 Models available:" -if [ -f "models/sam2_hiera_large.pt" ]; then - echo " โœ… sam2_hiera_large.pt" - echo " Use in config: sam2_checkpoint: 'models/sam2_hiera_large.pt'" -fi -if [ -f "models/sam2.1_hiera_large.pt" ]; then +if [ -f "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" ]; then echo " โœ… sam2.1_hiera_large.pt (recommended)" - echo " Use in config: sam2_checkpoint: 'models/sam2.1_hiera_large.pt'" + echo " Config: sam2_model_cfg: 'sam2.1_hiera_l'" + echo " Checkpoint: sam2_checkpoint: 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt'" +fi +if [ -f "segment-anything-2/checkpoints/sam2.1_hiera_base_plus.pt" ]; then + echo " โœ… sam2.1_hiera_base_plus.pt" + echo " Config: sam2_model_cfg: 'sam2.1_hiera_base_plus'" +fi +if [ -f "segment-anything-2/checkpoints/sam2_hiera_large.pt" ]; then + echo " โœ… sam2_hiera_large.pt (legacy)" + echo " Config: sam2_model_cfg: 'sam2_hiera_l'" fi echo "" diff --git a/test_installation.py b/test_installation.py index ced745f..494a0ea 100644 --- a/test_installation.py +++ b/test_installation.py @@ -71,47 +71,62 @@ except ImportError as e: # Check SAM2 models print("\n๐Ÿ” Checking SAM2 models...") -models_dir = Path("models") +sam2_checkpoints_dir = Path("segment-anything-2/checkpoints") +models_dir = Path("models") # Legacy location + sam2_models = { - "sam2_hiera_large.pt": "Original SAM2 Large", + "sam2.1_hiera_tiny.pt": "SAM2.1 Tiny", + "sam2.1_hiera_small.pt": "SAM2.1 Small", + "sam2.1_hiera_base_plus.pt": "SAM2.1 Base+", "sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)", - "sam2_hiera_base.pt": "SAM2 Base", - "sam2.1_hiera_base.pt": "SAM2.1 Base" + "sam2_hiera_tiny.pt": "SAM2 Tiny", + "sam2_hiera_small.pt": "SAM2 Small", + "sam2_hiera_base_plus.pt": "SAM2 Base+", + "sam2_hiera_large.pt": "SAM2 Large" } found_models = [] for model_file, model_name in sam2_models.items(): - model_path = models_dir / model_file - if model_path.exists(): - size_mb = model_path.stat().st_size / (1024 * 1024) + # Check SAM2 repo location first + sam2_path = sam2_checkpoints_dir / model_file + legacy_path = models_dir / model_file + + if sam2_path.exists(): + size_mb = sam2_path.stat().st_size / (1024 * 1024) print(f"โœ… {model_name}: {model_file} ({size_mb:.1f} MB)") - found_models.append(model_file) + found_models.append((model_file, str(sam2_path))) + elif legacy_path.exists(): + size_mb = legacy_path.stat().st_size / (1024 * 1024) + print(f"โœ… {model_name}: {model_file} ({size_mb:.1f} MB) [legacy location]") + found_models.append((model_file, str(legacy_path))) if not found_models: print("โŒ No SAM2 models found!") issues.append("No SAM2 models found - run setup script or download manually") else: print(f"\n๐Ÿ’ก Recommended config for best model found:") - if "sam2.1_hiera_large.pt" in found_models: + # Prioritize SAM2.1 models + if any("sam2.1_hiera_large.pt" in model[0] for model in found_models): + best_model = next(model for model in found_models if "sam2.1_hiera_large.pt" in model[0]) print(" sam2_model_cfg: 'sam2.1_hiera_l'") - print(" sam2_checkpoint: 'models/sam2.1_hiera_large.pt'") - elif "sam2_hiera_large.pt" in found_models: + print(f" sam2_checkpoint: '{best_model[1]}'") + elif any("sam2.1_hiera_base_plus.pt" in model[0] for model in found_models): + best_model = next(model for model in found_models if "sam2.1_hiera_base_plus.pt" in model[0]) + print(" sam2_model_cfg: 'sam2.1_hiera_base_plus'") + print(f" sam2_checkpoint: '{best_model[1]}'") + elif any("sam2_hiera_large.pt" in model[0] for model in found_models): + best_model = next(model for model in found_models if "sam2_hiera_large.pt" in model[0]) print(" sam2_model_cfg: 'sam2_hiera_l'") - print(" sam2_checkpoint: 'models/sam2_hiera_large.pt'") + print(f" sam2_checkpoint: '{best_model[1]}'") -# Check SAM2 configs -print("\n๐Ÿ” Checking SAM2 config files...") -configs_dir = Path("sam2_configs") -if configs_dir.exists(): - config_files = list(configs_dir.glob("*.yaml")) - if config_files: - print(f"โœ… Found {len(config_files)} SAM2 config files") - else: - print("โŒ No SAM2 config files found") - issues.append("SAM2 config files missing - may cause model loading errors") -else: - print("โŒ sam2_configs directory not found") - issues.append("SAM2 configs directory missing") +# Check SAM2 configs (now part of installed package) +print("\n๐Ÿ” Checking SAM2 configuration...") +try: + import sam2.sam2_configs + print("โœ… SAM2 configs available in installed package") +except ImportError: + print("โŒ SAM2 configs not found") + issues.append("SAM2 configs not available - SAM2 may not be properly installed") # Test model loading if possible if not any("SAM2 not installed" in issue for issue in issues): @@ -119,6 +134,22 @@ if not any("SAM2 not installed" in issue for issue in issues): try: # Try to load the default model config from vr180_matting.config import VR180Config + # Use the best available model + if found_models: + best_model = found_models[0] # Use first found model (prioritized) + sam2_checkpoint = best_model[1] + if "sam2.1_hiera_large.pt" in best_model[0]: + sam2_cfg = "sam2.1_hiera_l" + elif "sam2.1_hiera_base_plus.pt" in best_model[0]: + sam2_cfg = "sam2.1_hiera_base_plus" + elif "sam2_hiera_large.pt" in best_model[0]: + sam2_cfg = "sam2_hiera_l" + else: + sam2_cfg = "sam2.1_hiera_l" + else: + sam2_cfg = "sam2.1_hiera_l" + sam2_checkpoint = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" + config = VR180Config( input=type('obj', (object,), {'video_path': 'test.mp4'})(), processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(), @@ -127,8 +158,8 @@ if not any("SAM2 not installed" in issue for issue in issues): 'use_disparity_mapping': True, 'memory_offload': True, 'fp16': True, - 'sam2_model_cfg': 'sam2.1_hiera_l' if Path('models/sam2.1_hiera_large.pt').exists() else 'sam2_hiera_l', - 'sam2_checkpoint': 'models/sam2.1_hiera_large.pt' if Path('models/sam2.1_hiera_large.pt').exists() else 'models/sam2_hiera_large.pt' + 'sam2_model_cfg': sam2_cfg, + 'sam2_checkpoint': sam2_checkpoint })(), output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(), hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})() diff --git a/vr180_matting/config.py b/vr180_matting/config.py index 23f1f9f..6a01946 100644 --- a/vr180_matting/config.py +++ b/vr180_matting/config.py @@ -27,8 +27,8 @@ class MattingConfig: use_disparity_mapping: bool = True memory_offload: bool = True fp16: bool = True - sam2_model_cfg: str = "sam2_hiera_l" - sam2_checkpoint: str = "sam2_hiera_large.pt" + sam2_model_cfg: str = "sam2.1_hiera_l" + sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" @dataclass diff --git a/vr180_matting/main.py b/vr180_matting/main.py index 7b6a785..7a5cd8d 100644 --- a/vr180_matting/main.py +++ b/vr180_matting/main.py @@ -111,6 +111,8 @@ matting: use_disparity_mapping: true memory_offload: true fp16: true + sam2_model_cfg: "sam2.1_hiera_l" + sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" output: path: "path/to/output/" diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py index 964008a..dbc69a8 100644 --- a/vr180_matting/sam2_wrapper.py +++ b/vr180_matting/sam2_wrapper.py @@ -39,18 +39,23 @@ class SAM2VideoMatting: def _load_model(self, model_cfg: str, checkpoint_path: str): """Load SAM2 video predictor with optimizations""" try: - # Check for checkpoint in models directory if not found + # Check for checkpoint in SAM2 repo structure if not Path(checkpoint_path).exists(): - models_path = Path("models") / checkpoint_path - if models_path.exists(): - checkpoint_path = str(models_path) + # Try in segment-anything-2/checkpoints/ + sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name + if sam2_path.exists(): + checkpoint_path = str(sam2_path) else: - # Try relative to package - import os - package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - models_path = Path(package_dir) / "models" / checkpoint_path + # Try legacy models/ directory + models_path = Path("models") / Path(checkpoint_path).name if models_path.exists(): checkpoint_path = str(models_path) + else: + # Try relative to package + package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name + if sam2_repo_path.exists(): + checkpoint_path = str(sam2_repo_path) self.predictor = build_sam2_video_predictor( model_cfg,