install sam2 the way facebook says
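In short: the checker now assumes SAM2 was installed the way the upstream facebookresearch/segment-anything-2 repository documents it (installing the cloned repo as a package, e.g. an editable pip install, with checkpoints downloaded into segment-anything-2/checkpoints/), and keeps the old models/ directory only as a legacy fallback. Below is a minimal sketch of that layout check, separate from the script in this diff; the helper name sam2_layout_ok is invented here, while the two directories are the ones the diff itself probes.

# Sketch only (not part of this commit): confirm the install layout the updated checker expects.
import importlib.util
from pathlib import Path

def sam2_layout_ok() -> bool:
    # The sam2 package should be importable after installing the cloned repo.
    has_package = importlib.util.find_spec("sam2") is not None
    # Checkpoints live in the repo's checkpoints/ dir; models/ is only a legacy fallback.
    repo_ckpts = list(Path("segment-anything-2/checkpoints").glob("*.pt"))
    legacy_ckpts = list(Path("models").glob("*.pt"))
    print(f"sam2 importable: {has_package}")
    print(f"repo checkpoints: {[p.name for p in repo_ckpts]}")
    print(f"legacy checkpoints: {[p.name for p in legacy_ckpts]}")
    return has_package and bool(repo_ckpts or legacy_ckpts)

if __name__ == "__main__":
    sam2_layout_ok()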
@@ -71,47 +71,62 @@ except ImportError as e:
 # Check SAM2 models
 print("\n🔍 Checking SAM2 models...")
-models_dir = Path("models")
+sam2_checkpoints_dir = Path("segment-anything-2/checkpoints")
+models_dir = Path("models") # Legacy location

 sam2_models = {
-    "sam2_hiera_tiny.pt": "SAM2 Tiny",
-    "sam2_hiera_small.pt": "SAM2 Small",
-    "sam2_hiera_base_plus.pt": "SAM2 Base+",
-    "sam2_hiera_large.pt": "SAM2 Large"
+    "sam2_hiera_large.pt": "Original SAM2 Large",
+    "sam2.1_hiera_tiny.pt": "SAM2.1 Tiny",
+    "sam2.1_hiera_small.pt": "SAM2.1 Small",
+    "sam2.1_hiera_base_plus.pt": "SAM2.1 Base+",
+    "sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
+    "sam2_hiera_base.pt": "SAM2 Base",
+    "sam2.1_hiera_base.pt": "SAM2.1 Base"
 }

 found_models = []
 for model_file, model_name in sam2_models.items():
-    model_path = models_dir / model_file
-    if model_path.exists():
-        size_mb = model_path.stat().st_size / (1024 * 1024)
+    # Check SAM2 repo location first
+    sam2_path = sam2_checkpoints_dir / model_file
+    legacy_path = models_dir / model_file
+
+    if sam2_path.exists():
+        size_mb = sam2_path.stat().st_size / (1024 * 1024)
         print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB)")
-        found_models.append(model_file)
+        found_models.append((model_file, str(sam2_path)))
+    elif legacy_path.exists():
+        size_mb = legacy_path.stat().st_size / (1024 * 1024)
+        print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB) [legacy location]")
+        found_models.append((model_file, str(legacy_path)))

 if not found_models:
     print("❌ No SAM2 models found!")
     issues.append("No SAM2 models found - run setup script or download manually")
 else:
     print(f"\n💡 Recommended config for best model found:")
-    if "sam2.1_hiera_large.pt" in found_models:
+    # Prioritize SAM2.1 models
+    if any("sam2.1_hiera_large.pt" in model[0] for model in found_models):
+        best_model = next(model for model in found_models if "sam2.1_hiera_large.pt" in model[0])
         print(" sam2_model_cfg: 'sam2.1_hiera_l'")
-        print(" sam2_checkpoint: 'models/sam2.1_hiera_large.pt'")
-    elif "sam2_hiera_large.pt" in found_models:
+        print(f" sam2_checkpoint: '{best_model[1]}'")
+    elif any("sam2.1_hiera_base_plus.pt" in model[0] for model in found_models):
+        best_model = next(model for model in found_models if "sam2.1_hiera_base_plus.pt" in model[0])
+        print(" sam2_model_cfg: 'sam2.1_hiera_base_plus'")
+        print(f" sam2_checkpoint: '{best_model[1]}'")
+    elif any("sam2_hiera_large.pt" in model[0] for model in found_models):
+        best_model = next(model for model in found_models if "sam2_hiera_large.pt" in model[0])
         print(" sam2_model_cfg: 'sam2_hiera_l'")
-        print(" sam2_checkpoint: 'models/sam2_hiera_large.pt'")
+        print(f" sam2_checkpoint: '{best_model[1]}'")

-# Check SAM2 configs
-print("\n🔍 Checking SAM2 config files...")
-configs_dir = Path("sam2_configs")
-if configs_dir.exists():
-    config_files = list(configs_dir.glob("*.yaml"))
-    if config_files:
-        print(f"✅ Found {len(config_files)} SAM2 config files")
-    else:
-        print("❌ No SAM2 config files found")
-        issues.append("SAM2 config files missing - may cause model loading errors")
-else:
-    print("❌ sam2_configs directory not found")
-    issues.append("SAM2 configs directory missing")
+# Check SAM2 configs (now part of installed package)
+print("\n🔍 Checking SAM2 configuration...")
+try:
+    import sam2.sam2_configs
+    print("✅ SAM2 configs available in installed package")
+except ImportError:
+    print("❌ SAM2 configs not found")
+    issues.append("SAM2 configs not available - SAM2 may not be properly installed")

 # Test model loading if possible
 if not any("SAM2 not installed" in issue for issue in issues):
@@ -119,6 +134,22 @@ if not any("SAM2 not installed" in issue for issue in issues):
     try:
         # Try to load the default model config
         from vr180_matting.config import VR180Config
+        # Use the best available model
+        if found_models:
+            best_model = found_models[0] # Use first found model (prioritized)
+            sam2_checkpoint = best_model[1]
+            if "sam2.1_hiera_large.pt" in best_model[0]:
+                sam2_cfg = "sam2.1_hiera_l"
+            elif "sam2.1_hiera_base_plus.pt" in best_model[0]:
+                sam2_cfg = "sam2.1_hiera_base_plus"
+            elif "sam2_hiera_large.pt" in best_model[0]:
+                sam2_cfg = "sam2_hiera_l"
+            else:
+                sam2_cfg = "sam2.1_hiera_l"
+        else:
+            sam2_cfg = "sam2.1_hiera_l"
+            sam2_checkpoint = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
+
         config = VR180Config(
             input=type('obj', (object,), {'video_path': 'test.mp4'})(),
             processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(),
@@ -127,8 +158,8 @@ if not any("SAM2 not installed" in issue for issue in issues):
                 'use_disparity_mapping': True,
                 'memory_offload': True,
                 'fp16': True,
-                'sam2_model_cfg': 'sam2.1_hiera_l' if Path('models/sam2.1_hiera_large.pt').exists() else 'sam2_hiera_l',
-                'sam2_checkpoint': 'models/sam2.1_hiera_large.pt' if Path('models/sam2.1_hiera_large.pt').exists() else 'models/sam2_hiera_large.pt'
+                'sam2_model_cfg': sam2_cfg,
+                'sam2_checkpoint': sam2_checkpoint
             })(),
             output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(),
             hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})()
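Taken together, the change is that found_models now holds (filename, path) tuples resolved against segment-anything-2/checkpoints/ first and models/ second, and the SAM2 config name is derived from whichever checkpoint filename was found. Below is an illustrative, self-contained sketch of that resolution order, not the repo's API; resolve_checkpoint, SEARCH_DIRS and PREFERENCE are names invented here, while the directories, filenames and config strings are taken from the diff.

# Sketch only: mirror the checker's checkpoint resolution and config-name mapping.
from pathlib import Path

SEARCH_DIRS = [Path("segment-anything-2/checkpoints"), Path("models")]  # repo location first, legacy second
PREFERENCE = [  # best first; checkpoint filename -> config name, as printed by the checker
    ("sam2.1_hiera_large.pt", "sam2.1_hiera_l"),
    ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_base_plus"),
    ("sam2_hiera_large.pt", "sam2_hiera_l"),
]

def resolve_checkpoint():
    """Return (config_name, checkpoint_path) for the best checkpoint on disk, or None."""
    for filename, cfg in PREFERENCE:
        for directory in SEARCH_DIRS:
            candidate = directory / filename
            if candidate.exists():
                return cfg, str(candidate)
    return None

if __name__ == "__main__":
    print(resolve_checkpoint())

Keeping the filename-to-config mapping in one table like this mirrors what the diff's if/elif chains do; the diff itself repeats the same mapping in the recommendation printout and again in the VR180Config loading test.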