fix sam2 hopefully
This commit is contained in:
@@ -14,7 +14,7 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: true
|
memory_offload: true
|
||||||
fp16: true
|
fp16: true
|
||||||
sam2_model_cfg: "sam2_hiera_l.yaml"
|
sam2_model_cfg: "sam2_hiera_l"
|
||||||
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: false # A40 has enough VRAM
|
memory_offload: false # A40 has enough VRAM
|
||||||
fp16: true
|
fp16: true
|
||||||
sam2_model_cfg: "sam2_hiera_l.yaml"
|
sam2_model_cfg: "sam2.1_hiera_l"
|
||||||
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
sam2_checkpoint: "models/sam2.1_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "/workspace/output/matted_video.mp4"
|
path: "/workspace/output/matted_video.mp4"
|
||||||
|
|||||||
@@ -29,12 +29,35 @@ mkdir -p models
|
|||||||
# Download YOLOv8 models
|
# Download YOLOv8 models
|
||||||
python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')"
|
python -c "from ultralytics import YOLO; YOLO('yolov8n.pt'); YOLO('yolov8m.pt')"
|
||||||
|
|
||||||
# Download SAM2 checkpoint
|
# Download SAM2 checkpoints
|
||||||
cd models
|
cd models
|
||||||
|
echo "📥 Downloading SAM2 models..."
|
||||||
|
|
||||||
|
# Try different SAM2 model versions
|
||||||
if [ ! -f "sam2_hiera_large.pt" ]; then
|
if [ ! -f "sam2_hiera_large.pt" ]; then
|
||||||
echo "Downloading SAM2 model weights..."
|
echo "Trying SAM2 checkpoint version 1..."
|
||||||
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt
|
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/sam2_hiera_large.pt || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ ! -f "sam2.1_hiera_large.pt" ]; then
|
||||||
|
echo "Trying SAM2.1 checkpoint (latest)..."
|
||||||
|
wget -q --show-progress https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Download SAM2 config files
|
||||||
|
cd ..
|
||||||
|
mkdir -p sam2_configs
|
||||||
|
cd sam2_configs
|
||||||
|
echo "📥 Downloading SAM2 configuration files..."
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_b+.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_l.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_s.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2_hiera_t.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_b+.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_l.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_s.yaml
|
||||||
|
wget -q --show-progress https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/sam2_configs/sam2.1_hiera_t.yaml
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
# Create working directories
|
# Create working directories
|
||||||
@@ -45,6 +68,18 @@ echo ""
|
|||||||
echo "🧪 Testing installation..."
|
echo "🧪 Testing installation..."
|
||||||
python test_installation.py
|
python test_installation.py
|
||||||
|
|
||||||
|
# Check which SAM2 model is available
|
||||||
|
echo ""
|
||||||
|
echo "📊 SAM2 Models available:"
|
||||||
|
if [ -f "models/sam2_hiera_large.pt" ]; then
|
||||||
|
echo " ✅ sam2_hiera_large.pt"
|
||||||
|
echo " Use in config: sam2_checkpoint: 'models/sam2_hiera_large.pt'"
|
||||||
|
fi
|
||||||
|
if [ -f "models/sam2.1_hiera_large.pt" ]; then
|
||||||
|
echo " ✅ sam2.1_hiera_large.pt (recommended)"
|
||||||
|
echo " Use in config: sam2_checkpoint: 'models/sam2.1_hiera_large.pt'"
|
||||||
|
fi
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "✅ Setup complete!"
|
echo "✅ Setup complete!"
|
||||||
echo ""
|
echo ""
|
||||||
|
|||||||
@@ -5,12 +5,19 @@ import sys
|
|||||||
import torch
|
import torch
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
print("VR180 Matting Installation Test")
|
print("VR180 Matting Installation Test")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Track all issues
|
||||||
|
issues = []
|
||||||
|
|
||||||
# Check Python version
|
# Check Python version
|
||||||
print(f"Python version: {sys.version}")
|
print(f"Python version: {sys.version}")
|
||||||
|
if sys.version_info < (3, 8):
|
||||||
|
issues.append("Python 3.8+ required")
|
||||||
|
|
||||||
# Check PyTorch and CUDA
|
# Check PyTorch and CUDA
|
||||||
print(f"\nPyTorch version: {torch.__version__}")
|
print(f"\nPyTorch version: {torch.__version__}")
|
||||||
@@ -19,16 +26,29 @@ if torch.cuda.is_available():
|
|||||||
print(f"CUDA version: {torch.version.cuda}")
|
print(f"CUDA version: {torch.version.cuda}")
|
||||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||||
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
||||||
|
else:
|
||||||
|
issues.append("No CUDA GPU detected - will run slowly on CPU")
|
||||||
|
|
||||||
# Check OpenCV
|
# Check OpenCV
|
||||||
print(f"\nOpenCV version: {cv2.__version__}")
|
print(f"\nOpenCV version: {cv2.__version__}")
|
||||||
|
|
||||||
# Test imports
|
# Test imports
|
||||||
|
print("\n🔍 Testing imports...")
|
||||||
try:
|
try:
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
print("\n✅ YOLO import successful")
|
print("✅ YOLO import successful")
|
||||||
|
|
||||||
|
# Check if YOLO models exist
|
||||||
|
yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
|
||||||
|
available_yolo = [m for m in yolo_models if Path(m).exists()]
|
||||||
|
if available_yolo:
|
||||||
|
print(f" Found YOLO models: {', '.join(available_yolo)}")
|
||||||
|
else:
|
||||||
|
issues.append("No YOLO models found - will download on first run")
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"\n❌ YOLO import failed: {e}")
|
print(f"❌ YOLO import failed: {e}")
|
||||||
|
issues.append("Ultralytics YOLO not installed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sam2.build_sam import build_sam2_video_predictor
|
from sam2.build_sam import build_sam2_video_predictor
|
||||||
@@ -36,6 +56,7 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"❌ SAM2 import failed: {e}")
|
print(f"❌ SAM2 import failed: {e}")
|
||||||
print(" Install with: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
|
print(" Install with: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
|
||||||
|
issues.append("SAM2 not installed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vr180_matting.config import VR180Config
|
from vr180_matting.config import VR180Config
|
||||||
@@ -46,6 +67,102 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"❌ VR180 matting import failed: {e}")
|
print(f"❌ VR180 matting import failed: {e}")
|
||||||
print(" Make sure to run: pip install -e .")
|
print(" Make sure to run: pip install -e .")
|
||||||
|
issues.append("VR180 matting package not installed")
|
||||||
|
|
||||||
|
# Check SAM2 models
|
||||||
|
print("\n🔍 Checking SAM2 models...")
|
||||||
|
models_dir = Path("models")
|
||||||
|
sam2_models = {
|
||||||
|
"sam2_hiera_large.pt": "Original SAM2 Large",
|
||||||
|
"sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
|
||||||
|
"sam2_hiera_base.pt": "SAM2 Base",
|
||||||
|
"sam2.1_hiera_base.pt": "SAM2.1 Base"
|
||||||
|
}
|
||||||
|
|
||||||
|
found_models = []
|
||||||
|
for model_file, model_name in sam2_models.items():
|
||||||
|
model_path = models_dir / model_file
|
||||||
|
if model_path.exists():
|
||||||
|
size_mb = model_path.stat().st_size / (1024 * 1024)
|
||||||
|
print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB)")
|
||||||
|
found_models.append(model_file)
|
||||||
|
|
||||||
|
if not found_models:
|
||||||
|
print("❌ No SAM2 models found!")
|
||||||
|
issues.append("No SAM2 models found - run setup script or download manually")
|
||||||
|
else:
|
||||||
|
print(f"\n💡 Recommended config for best model found:")
|
||||||
|
if "sam2.1_hiera_large.pt" in found_models:
|
||||||
|
print(" sam2_model_cfg: 'sam2.1_hiera_l'")
|
||||||
|
print(" sam2_checkpoint: 'models/sam2.1_hiera_large.pt'")
|
||||||
|
elif "sam2_hiera_large.pt" in found_models:
|
||||||
|
print(" sam2_model_cfg: 'sam2_hiera_l'")
|
||||||
|
print(" sam2_checkpoint: 'models/sam2_hiera_large.pt'")
|
||||||
|
|
||||||
|
# Check SAM2 configs
|
||||||
|
print("\n🔍 Checking SAM2 config files...")
|
||||||
|
configs_dir = Path("sam2_configs")
|
||||||
|
if configs_dir.exists():
|
||||||
|
config_files = list(configs_dir.glob("*.yaml"))
|
||||||
|
if config_files:
|
||||||
|
print(f"✅ Found {len(config_files)} SAM2 config files")
|
||||||
|
else:
|
||||||
|
print("❌ No SAM2 config files found")
|
||||||
|
issues.append("SAM2 config files missing - may cause model loading errors")
|
||||||
|
else:
|
||||||
|
print("❌ sam2_configs directory not found")
|
||||||
|
issues.append("SAM2 configs directory missing")
|
||||||
|
|
||||||
|
# Test model loading if possible
|
||||||
|
if not any("SAM2 not installed" in issue for issue in issues):
|
||||||
|
print("\n🧪 Testing SAM2 model loading...")
|
||||||
|
try:
|
||||||
|
# Try to load the default model config
|
||||||
|
from vr180_matting.config import VR180Config
|
||||||
|
config = VR180Config(
|
||||||
|
input=type('obj', (object,), {'video_path': 'test.mp4'})(),
|
||||||
|
processing=type('obj', (object,), {'scale_factor': 0.5, 'chunk_size': 900, 'overlap_frames': 60})(),
|
||||||
|
detection=type('obj', (object,), {'confidence_threshold': 0.7, 'model': 'yolov8n'})(),
|
||||||
|
matting=type('obj', (object,), {
|
||||||
|
'use_disparity_mapping': True,
|
||||||
|
'memory_offload': True,
|
||||||
|
'fp16': True,
|
||||||
|
'sam2_model_cfg': 'sam2.1_hiera_l' if Path('models/sam2.1_hiera_large.pt').exists() else 'sam2_hiera_l',
|
||||||
|
'sam2_checkpoint': 'models/sam2.1_hiera_large.pt' if Path('models/sam2.1_hiera_large.pt').exists() else 'models/sam2_hiera_large.pt'
|
||||||
|
})(),
|
||||||
|
output=type('obj', (object,), {'path': 'output/', 'format': 'alpha', 'background_color': [0, 255, 0], 'maintain_sbs': True})(),
|
||||||
|
hardware=type('obj', (object,), {'device': 'cuda' if torch.cuda.is_available() else 'cpu', 'max_vram_gb': 10})()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try loading just the model config check
|
||||||
|
model_path = Path(config.matting.sam2_checkpoint)
|
||||||
|
if model_path.exists():
|
||||||
|
print(f"✅ Found checkpoint: {model_path}")
|
||||||
|
# Quick check of checkpoint structure
|
||||||
|
checkpoint = torch.load(model_path, map_location='cpu')
|
||||||
|
if 'model' in checkpoint:
|
||||||
|
print(f" Model has {len(checkpoint['model'])} parameters")
|
||||||
|
else:
|
||||||
|
print(f"❌ Checkpoint not found: {model_path}")
|
||||||
|
issues.append(f"SAM2 checkpoint missing: {model_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Could not test model loading: {e}")
|
||||||
|
|
||||||
|
# Check ffmpeg
|
||||||
|
print("\n🔍 Checking ffmpeg...")
|
||||||
|
import subprocess
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True)
|
||||||
|
if result.returncode == 0:
|
||||||
|
version_line = result.stdout.split('\n')[0]
|
||||||
|
print(f"✅ {version_line}")
|
||||||
|
else:
|
||||||
|
print("❌ ffmpeg error")
|
||||||
|
issues.append("ffmpeg not working properly")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("❌ ffmpeg not found")
|
||||||
|
issues.append("ffmpeg not installed - required for video processing")
|
||||||
|
|
||||||
# Memory check
|
# Memory check
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -54,11 +171,29 @@ if torch.cuda.is_available():
|
|||||||
print(f" Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
print(f" Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||||
print(f" Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**3:.2f} GB")
|
print(f" Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**3:.2f} GB")
|
||||||
|
|
||||||
|
# Check example config
|
||||||
|
print("\n🔍 Checking configuration files...")
|
||||||
|
example_configs = ["config_example.yaml", "config_runpod.yaml", "config.yaml"]
|
||||||
|
found_configs = [c for c in example_configs if Path(c).exists()]
|
||||||
|
if found_configs:
|
||||||
|
print(f"✅ Found configs: {', '.join(found_configs)}")
|
||||||
|
else:
|
||||||
|
print("❌ No config files found")
|
||||||
|
issues.append("No configuration files found - generate with: vr180-matting --generate-config config.yaml")
|
||||||
|
|
||||||
|
# Summary
|
||||||
print("\n" + "=" * 50)
|
print("\n" + "=" * 50)
|
||||||
print("Installation test complete!")
|
if issues:
|
||||||
print("\nNext steps:")
|
print("❌ Issues found:")
|
||||||
print("1. If any imports failed, install missing dependencies")
|
for i, issue in enumerate(issues, 1):
|
||||||
print("2. Download SAM2 model weights if needed")
|
print(f" {i}. {issue}")
|
||||||
print("3. Run: vr180-matting --generate-config config.yaml")
|
print("\n📋 To fix issues:")
|
||||||
print("4. Edit config.yaml with your video path")
|
print(" 1. Run: ./runpod_setup.sh")
|
||||||
print("5. Run: vr180-matting config.yaml")
|
print(" 2. Make sure SAM2 is installed: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
|
||||||
|
print(" 3. Install package: pip install -e .")
|
||||||
|
else:
|
||||||
|
print("✅ All checks passed! Ready to process videos.")
|
||||||
|
print("\n📋 Quick start:")
|
||||||
|
print(" 1. Copy example config: cp config_runpod.yaml config.yaml")
|
||||||
|
print(" 2. Edit config.yaml with your video path")
|
||||||
|
print(" 3. Run: vr180-matting config.yaml")
|
||||||
@@ -27,7 +27,7 @@ class MattingConfig:
|
|||||||
use_disparity_mapping: bool = True
|
use_disparity_mapping: bool = True
|
||||||
memory_offload: bool = True
|
memory_offload: bool = True
|
||||||
fp16: bool = True
|
fp16: bool = True
|
||||||
sam2_model_cfg: str = "sam2_hiera_l.yaml"
|
sam2_model_cfg: str = "sam2_hiera_l"
|
||||||
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class SAM2VideoMatting:
|
|||||||
"""SAM2-based video matting with memory optimization"""
|
"""SAM2-based video matting with memory optimization"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_cfg: str = "sam2_hiera_l.yaml",
|
model_cfg: str = "sam2_hiera_l",
|
||||||
checkpoint_path: str = "sam2_hiera_large.pt",
|
checkpoint_path: str = "sam2_hiera_large.pt",
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
memory_offload: bool = True,
|
memory_offload: bool = True,
|
||||||
|
|||||||
Reference in New Issue
Block a user