install sam2 the way facebook says
This commit is contained in:
@@ -14,8 +14,8 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: true
|
memory_offload: true
|
||||||
fp16: true
|
fp16: true
|
||||||
sam2_model_cfg: "sam2_hiera_l"
|
sam2_model_cfg: "sam2.1_hiera_l"
|
||||||
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "path/to/output/"
|
path: "path/to/output/"
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ matting:
|
|||||||
memory_offload: false # A40 has enough VRAM
|
memory_offload: false # A40 has enough VRAM
|
||||||
fp16: true
|
fp16: true
|
||||||
sam2_model_cfg: "sam2.1_hiera_l"
|
sam2_model_cfg: "sam2.1_hiera_l"
|
||||||
sam2_checkpoint: "models/sam2.1_hiera_large.pt"
|
sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "/workspace/output/matted_video.mp4"
|
path: "/workspace/output/matted_video.mp4"
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# Download SAM2 model configuration files
|
|
||||||
|
|
||||||
echo "📥 Downloading SAM2 configuration files..."
|
|
||||||
|
|
||||||
mkdir -p sam2_configs
|
|
||||||
cd sam2_configs
|
|
||||||
|
|
||||||
# Download SAM2 config files
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
echo "✅ SAM2 configs downloaded to sam2_configs/"
|
|
||||||
@@ -29,36 +29,22 @@ mkdir -p models
|
|||||||
# Download YOLOv8 models
|
# Download YOLOv8 models
|
||||||
python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')"
|
python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')"
|
||||||
|
|
||||||
# Download SAM2 checkpoints
|
# Clone SAM2 repo for checkpoints
|
||||||
cd models
|
echo "📥 Cloning SAM2 for model checkpoints..."
|
||||||
echo "📥 Downloading SAM2 models..."
|
if [ ! -d "segment-anything-2" ]; then
|
||||||
|
git clone https://github.com/facebookresearch/segment-anything-2.git
|
||||||
# Try different SAM2 model versions
|
|
||||||
if [ ! -f "sam2_hiera_large.pt" ]; then
|
|
||||||
echo "Trying SAM2 checkpoint version 1..."
|
|
||||||
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt || true
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Download SAM2 checkpoints using their official script
|
||||||
|
cd segment-anything-2
|
||||||
|
mkdir -p checkpoints
|
||||||
|
cd checkpoints
|
||||||
if [ ! -f "sam2.1_hiera_large.pt" ]; then
|
if [ ! -f "sam2.1_hiera_large.pt" ]; then
|
||||||
echo "Trying SAM2.1 checkpoint (latest)..."
|
echo "📥 Downloading SAM2 checkpoints..."
|
||||||
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt || true
|
chmod +x ../download_ckpts.sh 2>/dev/null || true
|
||||||
|
bash ../download_ckpts.sh || bash <(curl -s https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/download_ckpts.sh)
|
||||||
fi
|
fi
|
||||||
|
cd ../..
|
||||||
# Download SAM2 config files
|
|
||||||
cd ..
|
|
||||||
mkdir -p sam2_configs
|
|
||||||
cd sam2_configs
|
|
||||||
echo "📥 Downloading SAM2 configuration files..."
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_b+.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_l.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_s.yaml
|
|
||||||
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_t.yaml
|
|
||||||
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
# Create working directories
|
# Create working directories
|
||||||
mkdir -p /workspace/data /workspace/output
|
mkdir -p /workspace/data /workspace/output
|
||||||
@@ -68,16 +54,21 @@ echo ""
|
|||||||
echo "🧪 Testing installation..."
|
echo "🧪 Testing installation..."
|
||||||
python test_installation.py
|
python test_installation.py
|
||||||
|
|
||||||
# Check which SAM2 model is available
|
# Check which SAM2 models are available
|
||||||
echo ""
|
echo ""
|
||||||
echo "📊 SAM2 Models available:"
|
echo "📊 SAM2 Models available:"
|
||||||
if [ -f "models/sam2_hiera_large.pt" ]; then
|
if [ -f "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" ]; then
|
||||||
echo " ✅ sam2_hiera_large.pt"
|
|
||||||
echo " Use in config: sam2_checkpoint: 'models/sam2_hiera_large.pt'"
|
|
||||||
fi
|
|
||||||
if [ -f "models/sam2.1_hiera_large.pt" ]; then
|
|
||||||
echo " ✅ sam2.1_hiera_large.pt (recommended)"
|
echo " ✅ sam2.1_hiera_large.pt (recommended)"
|
||||||
echo " Use in config: sam2_checkpoint: 'models/sam2.1_hiera_large.pt'"
|
echo " Config: sam2_model_cfg: 'sam2.1_hiera_l'"
|
||||||
|
echo " Checkpoint: sam2_checkpoint: 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt'"
|
||||||
|
fi
|
||||||
|
if [ -f "segment-anything-2/checkpoints/sam2.1_hiera_base_plus.pt" ]; then
|
||||||
|
echo " ✅ sam2.1_hiera_base_plus.pt"
|
||||||
|
echo " Config: sam2_model_cfg: 'sam2.1_hiera_base_plus'"
|
||||||
|
fi
|
||||||
|
if [ -f "segment-anything-2/checkpoints/sam2_hiera_large.pt" ]; then
|
||||||
|
echo " ✅ sam2_hiera_large.pt (legacy)"
|
||||||
|
echo " Config: sam2_model_cfg: 'sam2_hiera_l'"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
|
|||||||
@@ -71,47 +71,62 @@ except ImportError as e:
|
|||||||
|
|
||||||
# Check SAM2 models
|
# Check SAM2 models
|
||||||
print("\n🔍 Checking SAM2 models...")
|
print("\n🔍 Checking SAM2 models...")
|
||||||
models_dir = Path("models")
|
sam2_checkpoints_dir = Path("segment-anything-2/checkpoints")
|
||||||
|
models_dir = Path("models") # Legacy location
|
||||||
|
|
||||||
sam2_models = {
|
sam2_models = {
|
||||||
"sam2_hiera_large.pt": "Original SAM2 Large",
|
"sam2.1_hiera_tiny.pt": "SAM2.1 Tiny",
|
||||||
|
"sam2.1_hiera_small.pt": "SAM2.1 Small",
|
||||||
|
"sam2.1_hiera_base_plus.pt": "SAM2.1 Base+",
|
||||||
"sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
|
"sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
|
||||||
"sam2_hiera_base.pt": "SAM2 Base",
|
"sam2_hiera_tiny.pt": "SAM2 Tiny",
|
||||||
"sam2.1_hiera_base.pt": "SAM2.1 Base"
|
"sam2_hiera_small.pt": "SAM2 Small",
|
||||||
|
"sam2_hiera_base_plus.pt": "SAM2 Base+",
|
||||||
|
"sam2_hiera_large.pt": "SAM2 Large"
|
||||||
}
|
}
|
||||||
|
|
||||||
found_models = []
|
found_models = []
|
||||||
for model_file, model_name in sam2_models.items():
|
for model_file, model_name in sam2_models.items():
|
||||||
model_path = models_dir / model_file
|
# Check SAM2 repo location first
|
||||||
if model_path.exists():
|
sam2_path = sam2_checkpoints_dir / model_file
|
||||||
size_mb = model_path.stat().st_size / (1024 * 1024)
|
legacy_path = models_dir / model_file
|
||||||
|
|
||||||
|
if sam2_path.exists():
|
||||||
|
size_mb = sam2_path.stat().st_size / (1024 * 1024)
|
||||||
print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB)")
|
print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB)")
|
||||||
found_models.append(model_file)
|
found_models.append((model_file, str(sam2_path)))
|
||||||
|
elif legacy_path.exists():
|
||||||
|
size_mb = legacy_path.stat().st_size / (1024 * 1024)
|
||||||
|
print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB) [legacy location]")
|
||||||
|
found_models.append((model_file, str(legacy_path)))
|
||||||
|
|
||||||
if not found_models:
|
if not found_models:
|
||||||
print("❌ No SAM2 models found!")
|
print("❌ No SAM2 models found!")
|
||||||
issues.append("No SAM2 models found - run setup script or download manually")
|
issues.append("No SAM2 models found - run setup script or download manually")
|
||||||
else:
|
else:
|
||||||
print(f"\n💡 Recommended config for best model found:")
|
print(f"\n💡 Recommended config for best model found:")
|
||||||
if "sam2.1_hiera_large.pt" in found_models:
|
# Prioritize SAM2.1 models
|
||||||
|
if any("sam2.1_hiera_large.pt" in model[0] for model in found_models):
|
||||||
|
best_model = next(model for model in found_models if "sam2.1_hiera_large.pt" in model[0])
|
||||||
print(" sam2_model_cfg: 'sam2.1_hiera_l'")
|
print(" sam2_model_cfg: 'sam2.1_hiera_l'")
|
||||||
print(" sam2_checkpoint: 'models/sam2.1_hiera_large.pt'")
|
print(f" sam2_checkpoint: '{best_model[1]}'")
|
||||||
elif "sam2_hiera_large.pt" in found_models:
|
elif any("sam2.1_hiera_base_plus.pt" in model[0] for model in found_models):
|
||||||
|
best_model = next(model for model in found_models if "sam2.1_hiera_base_plus.pt" in model[0])
|
||||||
|
print(" sam2_model_cfg: 'sam2.1_hiera_base_plus'")
|
||||||
|
print(f" sam2_checkpoint: '{best_model[1]}'")
|
||||||
|
elif any("sam2_hiera_large.pt" in model[0] for model in found_models):
|
||||||
|
best_model = next(model for model in found_models if "sam2_hiera_large.pt" in model[0])
|
||||||
print(" sam2_model_cfg: 'sam2_hiera_l'")
|
print(" sam2_model_cfg: 'sam2_hiera_l'")
|
||||||
print(" sam2_checkpoint: 'models/sam2_hiera_large.pt'")
|
print(f" sam2_checkpoint: '{best_model[1]}'")
|
||||||
|
|
||||||
# Check SAM2 configs
|
# Check SAM2 configs (now part of installed package)
|
||||||
print("\n🔍 Checking SAM2 config files...")
|
print("\n🔍 Checking SAM2 configuration...")
|
||||||
configs_dir = Path("sam2_configs")
|
try:
|
||||||
if configs_dir.exists():
|
import sam2.sam2_configs
|
||||||
config_files = list(configs_dir.glob("*.yaml"))
|
print("✅ SAM2 configs available in installed package")
|
||||||
if config_files:
|
except ImportError:
|
||||||
print(f"✅ Found {len(config_files)} SAM2 config files")
|
print("❌ SAM2 configs not found")
|
||||||
else:
|
issues.append("SAM2 configs not available - SAM2 may not be properly installed")
|
||||||
print("❌ No SAM2 config files found")
|
|
||||||
issues.append("SAM2 config files missing - may cause model loading errors")
|
|
||||||
else:
|
|
||||||
print("❌ sam2_configs directory not found")
|
|
||||||
issues.append("SAM2 configs directory missing")
|
|
||||||
|
|
||||||
# Test model loading if possible
|
# Test model loading if possible
|
||||||
if not any("SAM2 not installed" in issue for issue in issues):
|
if not any("SAM2 not installed" in issue for issue in issues):
|
||||||
@@ -119,6 +134,22 @@ if not any("SAM2 not installed" in issue for issue in issues):
|
|||||||
try:
|
try:
|
||||||
# Try to load the default model config
|
# Try to load the default model config
|
||||||
from vr180_matting.config import VR180Config
|
from vr180_matting.config import VR180Config
|
||||||
|
# Use the best available model
|
||||||
|
if found_models:
|
||||||
|
best_model = found_models[0] # Use first found model (prioritized)
|
||||||
|
sam2_checkpoint = best_model[1]
|
||||||
|
if "sam2.1_hiera_large.pt" in best_model[0]:
|
||||||
|
sam2_cfg = "sam2.1_hiera_l"
|
||||||
|
elif "sam2.1_hiera_base_plus.pt" in best_model[0]:
|
||||||
|
sam2_cfg = "sam2.1_hiera_base_plus"
|
||||||
|
elif "sam2_hiera_large.pt" in best_model[0]:
|
||||||
|
sam2_cfg = "sam2_hiera_l"
|
||||||
|
else:
|
||||||
|
sam2_cfg = "sam2.1_hiera_l"
|
||||||
|
else:
|
||||||
|
sam2_cfg = "sam2.1_hiera_l"
|
||||||
|
sam2_checkpoint = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
config = VR180Config(
|
config = VR180Config(
|
||||||
input=type('obj', (object,), {'video_path': 'test.mp4'})(),
|
input=type('obj', (object,), {'video_path': 'test.mp4'})(),
|
||||||
processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(),
|
processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(),
|
||||||
@@ -127,8 +158,8 @@ if not any("SAM2 not installed" in issue for issue in issues):
|
|||||||
'use_disparity_mapping': True,
|
'use_disparity_mapping': True,
|
||||||
'memory_offload': True,
|
'memory_offload': True,
|
||||||
'fp16': True,
|
'fp16': True,
|
||||||
'sam2_model_cfg': 'sam2.1_hiera_l' if Path('models/sam2.1_hiera_large.pt').exists() else 'sam2_hiera_l',
|
'sam2_model_cfg': sam2_cfg,
|
||||||
'sam2_checkpoint': 'models/sam2.1_hiera_large.pt' if Path('models/sam2.1_hiera_large.pt').exists() else 'models/sam2_hiera_large.pt'
|
'sam2_checkpoint': sam2_checkpoint
|
||||||
})(),
|
})(),
|
||||||
output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(),
|
output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(),
|
||||||
hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})()
|
hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})()
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class MattingConfig:
|
|||||||
use_disparity_mapping: bool = True
|
use_disparity_mapping: bool = True
|
||||||
memory_offload: bool = True
|
memory_offload: bool = True
|
||||||
fp16: bool = True
|
fp16: bool = True
|
||||||
sam2_model_cfg: str = "sam2_hiera_l"
|
sam2_model_cfg: str = "sam2.1_hiera_l"
|
||||||
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -111,6 +111,8 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: true
|
memory_offload: true
|
||||||
fp16: true
|
fp16: true
|
||||||
|
sam2_model_cfg: "sam2.1_hiera_l"
|
||||||
|
sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "path/to/output/"
|
path: "path/to/output/"
|
||||||
|
|||||||
@@ -39,18 +39,23 @@ class SAM2VideoMatting:
|
|||||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||||
"""Load SAM2 video predictor with optimizations"""
|
"""Load SAM2 video predictor with optimizations"""
|
||||||
try:
|
try:
|
||||||
# Check for checkpoint in models directory if not found
|
# Check for checkpoint in SAM2 repo structure
|
||||||
if not Path(checkpoint_path).exists():
|
if not Path(checkpoint_path).exists():
|
||||||
models_path = Path("models") / checkpoint_path
|
# Try in segment-anything-2/checkpoints/
|
||||||
if models_path.exists():
|
sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
|
||||||
checkpoint_path = str(models_path)
|
if sam2_path.exists():
|
||||||
|
checkpoint_path = str(sam2_path)
|
||||||
else:
|
else:
|
||||||
# Try relative to package
|
# Try legacy models/ directory
|
||||||
import os
|
models_path = Path("models") / Path(checkpoint_path).name
|
||||||
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
models_path = Path(package_dir) / "models" / checkpoint_path
|
|
||||||
if models_path.exists():
|
if models_path.exists():
|
||||||
checkpoint_path = str(models_path)
|
checkpoint_path = str(models_path)
|
||||||
|
else:
|
||||||
|
# Try relative to package
|
||||||
|
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name
|
||||||
|
if sam2_repo_path.exists():
|
||||||
|
checkpoint_path = str(sam2_repo_path)
|
||||||
|
|
||||||
self.predictor = build_sam2_video_predictor(
|
self.predictor = build_sam2_video_predictor(
|
||||||
model_cfg,
|
model_cfg,
|
||||||
|
|||||||
Reference in New Issue
Block a user