fix sam2 hopefully

This commit is contained in:
2025-07-26 08:03:37 -07:00
parent 1bec8113de
commit eeed9ee578
6 changed files with 187 additions and 17 deletions

View File

@@ -14,7 +14,7 @@ matting:
use_disparity_mapping: true use_disparity_mapping: true
memory_offload: true memory_offload: true
fp16: true fp16: true
sam2_model_cfg: "sam2_hiera_l.yaml" sam2_model_cfg: "sam2_hiera_l"
sam2_checkpoint: "models/sam2_hiera_large.pt" sam2_checkpoint: "models/sam2_hiera_large.pt"
output: output:

View File

@@ -14,8 +14,8 @@ matting:
use_disparity_mapping: true use_disparity_mapping: true
memory_offload: false # A40 has enough VRAM memory_offload: false # A40 has enough VRAM
fp16: true fp16: true
sam2_model_cfg: "sam2_hiera_l.yaml" sam2_model_cfg: "sam2.1_hiera_l"
sam2_checkpoint: "models/sam2_hiera_large.pt" sam2_checkpoint: "models/sam2.1_hiera_large.pt"
output: output:
path: "/workspace/output/matted_video.mp4" path: "/workspace/output/matted_video.mp4"

View File

@@ -29,12 +29,35 @@ mkdir -p models
# Download YOLOv8 models # Download YOLOv8 models
python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')" python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')"
# Download SAM2 checkpoint # Download SAM2 checkpoints
cd models cd models
echo "📥 Downloading SAM2 models..."
# Try different SAM2 model versions
if [ ! -f "sam2_hiera_large.pt" ]; then if [ ! -f "sam2_hiera_large.pt" ]; then
echo "Downloading SAM2 model weights..." echo "Trying SAM2 checkpoint version 1..."
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt || true
fi fi
if [ ! -f "sam2.1_hiera_large.pt" ]; then
echo "Trying SAM2.1 checkpoint (latest)..."
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt || true
fi
# Download SAM2 config files
cd ..
mkdir -p sam2_configs
cd sam2_configs
echo "📥 Downloading SAM2 configuration files..."
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_b+.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_l.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_s.yaml
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_t.yaml
cd .. cd ..
# Create working directories # Create working directories
@@ -45,6 +68,18 @@ echo ""
echo "🧪 Testing installation..." echo "🧪 Testing installation..."
python test_installation.py python test_installation.py
# Check which SAM2 model is available
echo ""
echo "📊 SAM2 Models available:"
if [ -f "models/sam2_hiera_large.pt" ]; then
echo " ✅ sam2_hiera_large.pt"
echo " Use in config: sam2_checkpoint: 'models/sam2_hiera_large.pt'"
fi
if [ -f "models/sam2.1_hiera_large.pt" ]; then
echo " ✅ sam2.1_hiera_large.pt (recommended)"
echo " Use in config: sam2_checkpoint: 'models/sam2.1_hiera_large.pt'"
fi
echo "" echo ""
echo "✅ Setup complete!" echo "✅ Setup complete!"
echo "" echo ""

View File

@@ -5,12 +5,19 @@ import sys
import torch import torch
import cv2 import cv2
import numpy as np import numpy as np
from pathlib import Path
import os
print("VR180 Matting Installation Test") print("VR180 Matting Installation Test")
print("=" * 50) print("=" * 50)
# Track all issues
issues = []
# Check Python version # Check Python version
print(f"Python version: {sys.version}") print(f"Python version: {sys.version}")
if sys.version_info < (3, 8):
issues.append("Python 3.8+ required")
# Check PyTorch and CUDA # Check PyTorch and CUDA
print(f"\nPyTorch version: {torch.__version__}") print(f"\nPyTorch version: {torch.__version__}")
@@ -19,16 +26,29 @@ if torch.cuda.is_available():
print(f"CUDA version: {torch.version.cuda}") print(f"CUDA version: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
issues.append("No CUDA GPU detected - will run slowly on CPU")
# Check OpenCV # Check OpenCV
print(f"\nOpenCV version: {cv2.__version__}") print(f"\nOpenCV version: {cv2.__version__}")
# Test imports # Test imports
print("\n🔍 Testing imports...")
try: try:
from ultralytics import YOLO from ultralytics import YOLO
print("\n✅ YOLO import successful") print("✅ YOLO import successful")
# Check if YOLO models exist
yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
available_yolo = [m for m in yolo_models if Path(m).exists()]
if available_yolo:
print(f" Found YOLO models: {', '.join(available_yolo)}")
else:
issues.append("No YOLO models found - will download on first run")
except ImportError as e: except ImportError as e:
print(f"\n❌ YOLO import failed: {e}") print(f"❌ YOLO import failed: {e}")
issues.append("Ultralytics YOLO not installed")
try: try:
from sam2.build_sam import build_sam2_video_predictor from sam2.build_sam import build_sam2_video_predictor
@@ -36,6 +56,7 @@ try:
except ImportError as e: except ImportError as e:
print(f"❌ SAM2 import failed: {e}") print(f"❌ SAM2 import failed: {e}")
print(" Install with: pip install git+https://github.com/facebookresearch/segment-anything-2.git") print(" Install with: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
issues.append("SAM2 not installed")
try: try:
from vr180_matting.config import VR180Config from vr180_matting.config import VR180Config
@@ -46,6 +67,102 @@ try:
except ImportError as e: except ImportError as e:
print(f"❌ VR180 matting import failed: {e}") print(f"❌ VR180 matting import failed: {e}")
print(" Make sure to run: pip install -e .") print(" Make sure to run: pip install -e .")
issues.append("VR180 matting package not installed")
# Check SAM2 models
print("\n🔍 Checking SAM2 models...")
models_dir = Path("models")
sam2_models = {
"sam2_hiera_large.pt": "Original SAM2 Large",
"sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
"sam2_hiera_base.pt": "SAM2 Base",
"sam2.1_hiera_base.pt": "SAM2.1 Base"
}
found_models = []
for model_file, model_name in sam2_models.items():
model_path = models_dir / model_file
if model_path.exists():
size_mb = model_path.stat().st_size / (1024 * 1024)
print(f"{model_name}: {model_file} ({size_mb:.1f} MB)")
found_models.append(model_file)
if not found_models:
print("❌ No SAM2 models found!")
issues.append("No SAM2 models found - run setup script or download manually")
else:
print(f"\n💡 Recommended config for best model found:")
if "sam2.1_hiera_large.pt" in found_models:
print(" sam2_model_cfg: 'sam2.1_hiera_l'")
print(" sam2_checkpoint: 'models/sam2.1_hiera_large.pt'")
elif "sam2_hiera_large.pt" in found_models:
print(" sam2_model_cfg: 'sam2_hiera_l'")
print(" sam2_checkpoint: 'models/sam2_hiera_large.pt'")
# Check SAM2 configs
print("\n🔍 Checking SAM2 config files...")
configs_dir = Path("sam2_configs")
if configs_dir.exists():
config_files = list(configs_dir.glob("*.yaml"))
if config_files:
print(f"✅ Found {len(config_files)} SAM2 config files")
else:
print("❌ No SAM2 config files found")
issues.append("SAM2 config files missing - may cause model loading errors")
else:
print("❌ sam2_configs directory not found")
issues.append("SAM2 configs directory missing")
# Test model loading if possible
if not any("SAM2 not installed" in issue for issue in issues):
print("\n🧪 Testing SAM2 model loading...")
try:
# Try to load the default model config
from vr180_matting.config import VR180Config
config = VR180Config(
input=type('obj', (object,), {'video_path': 'test.mp4'})(),
processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(),
detection=type('obj', (object,), {'confidence_threshold': 0.7, 'model': 'yolov8n'})(),
matting=type('obj', (object,), {
'use_disparity_mapping': True,
'memory_offload': True,
'fp16': True,
'sam2_model_cfg': 'sam2.1_hiera_l' if Path('models/sam2.1_hiera_large.pt').exists() else 'sam2_hiera_l',
'sam2_checkpoint': 'models/sam2.1_hiera_large.pt' if Path('models/sam2.1_hiera_large.pt').exists() else 'models/sam2_hiera_large.pt'
})(),
output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(),
hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})()
)
# Try loading just the model config check
model_path = Path(config.matting.sam2_checkpoint)
if model_path.exists():
print(f"✅ Found checkpoint: {model_path}")
# Quick check of checkpoint structure
checkpoint = torch.load(model_path, map_location='cpu')
if 'model' in checkpoint:
print(f" Model has {len(checkpoint['model'])} parameters")
else:
print(f"❌ Checkpoint not found: {model_path}")
issues.append(f"SAM2 checkpoint missing: {model_path}")
except Exception as e:
print(f"⚠️ Could not test model loading: {e}")
# Check ffmpeg
print("\n🔍 Checking ffmpeg...")
import subprocess
try:
result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True)
if result.returncode == 0:
version_line = result.stdout.split('\n')[0]
print(f"{version_line}")
else:
print("❌ ffmpeg error")
issues.append("ffmpeg not working properly")
except FileNotFoundError:
print("❌ ffmpeg not found")
issues.append("ffmpeg not installed - required for video processing")
# Memory check # Memory check
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -54,11 +171,29 @@ if torch.cuda.is_available():
print(f" Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") print(f" Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
print(f" Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**3:.2f} GB") print(f" Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**3:.2f} GB")
# Check example config
print("\n🔍 Checking configuration files...")
example_configs = ["config_example.yaml", "config_runpod.yaml", "config.yaml"]
found_configs = [c for c in example_configs if Path(c).exists()]
if found_configs:
print(f"✅ Found configs: {', '.join(found_configs)}")
else:
print("❌ No config files found")
issues.append("No configuration files found - generate with: vr180-matting --generate-config config.yaml")
# Summary
print("\n" + "=" * 50) print("\n" + "=" * 50)
print("Installation test complete!") if issues:
print("\nNext steps:") print("❌ Issues found:")
print("1. If any imports failed, install missing dependencies") for i, issue in enumerate(issues, 1):
print("2. Download SAM2 model weights if needed") print(f" {i}. {issue}")
print("3. Run: vr180-matting --generate-config config.yaml") print("\n📋 To fix issues:")
print("4. Edit config.yaml with your video path") print(" 1. Run: ./runpod_setup.sh")
print("5. Run: vr180-matting config.yaml") print(" 2. Make sure SAM2 is installed: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
print(" 3. Install package: pip install -e .")
else:
print("✅ All checks passed! Ready to process videos.")
print("\n📋 Quick start:")
print(" 1. Copy example config: cp config_runpod.yaml config.yaml")
print(" 2. Edit config.yaml with your video path")
print(" 3. Run: vr180-matting config.yaml")

View File

@@ -27,7 +27,7 @@ class MattingConfig:
use_disparity_mapping: bool = True use_disparity_mapping: bool = True
memory_offload: bool = True memory_offload: bool = True
fp16: bool = True fp16: bool = True
sam2_model_cfg: str = "sam2_hiera_l.yaml" sam2_model_cfg: str = "sam2_hiera_l"
sam2_checkpoint: str = "sam2_hiera_large.pt" sam2_checkpoint: str = "sam2_hiera_large.pt"

View File

@@ -19,7 +19,7 @@ class SAM2VideoMatting:
"""SAM2-based video matting with memory optimization""" """SAM2-based video matting with memory optimization"""
def __init__(self, def __init__(self,
model_cfg: str = "sam2_hiera_l.yaml", model_cfg: str = "sam2_hiera_l",
checkpoint_path: str = "sam2_hiera_large.pt", checkpoint_path: str = "sam2_hiera_large.pt",
device: str = "cuda", device: str = "cuda",
memory_offload: bool = True, memory_offload: bool = True,