#!/usr/bin/env python3
"""Test script to verify installation and GPU setup.

Checks the Python / PyTorch / CUDA / OpenCV environment, the optional
dependencies (Ultralytics YOLO, SAM2, the ``vr180_matting`` package), SAM2
checkpoints and configs, ffmpeg and the example config files, then prints a
summary of every issue found.  Relative paths (checkpoints, YOLO weights,
config files) are resolved against the current working directory.
"""

import os
import sys
from pathlib import Path

import cv2
import numpy as np
import torch

print("VR180 Matting Installation Test")
print("=" * 50)

# Track all issues: every check below appends a human-readable message here,
# and the summary at the end of the script lists them.
issues = []

# --- Interpreter version ----------------------------------------------------
# The minimum supported interpreter is Python 3.8; only (major, minor) is
# compared.
print("Python version: " + sys.version)
if sys.version_info[:2] < (3, 8):
    issues.append("Python 3.8+ required")

# --- PyTorch / CUDA ---------------------------------------------------------
print(f"\nPyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    issues.append("No CUDA GPU detected - will run slowly on CPU")
else:
    # Only the first CUDA device (index 0) is reported.
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Memory: {total_gib:.1f} GB")

# --- OpenCV -----------------------------------------------------------------
opencv_version = cv2.__version__
print(f"\nOpenCV version: {opencv_version}")

# --- Optional dependency imports ------------------------------------------
print("\n🔍 Testing imports...")
try:
    from ultralytics import YOLO  # noqa: F401 - imported only to verify the install
except ImportError as exc:
    print(f"❌ YOLO import failed: {exc}")
    issues.append("Ultralytics YOLO not installed")
else:
    print("✅ YOLO import successful")
    # Pre-downloaded weights are looked up in the current working directory.
    available_yolo = [
        weights
        for weights in ("yolov8n.pt", "yolov8s.pt", "yolov8m.pt")
        if Path(weights).exists()
    ]
    if not available_yolo:
        issues.append("No YOLO models found - will download on first run")
    else:
        print(f" Found YOLO models: {', '.join(available_yolo)}")

try:
    # Imported only to prove the SAM2 package is importable.
    from sam2.build_sam import build_sam2_video_predictor  # noqa: F401
except ImportError as exc:
    print(f"❌ SAM2 import failed: {exc}")
    print(" Install with: pip install git+https://github.com/facebookresearch/segment-anything-2.git")
    issues.append("SAM2 not installed")
else:
    print("✅ SAM2 import successful")

try:
    # Import every project module the pipeline uses.  An ImportError here may
    # come from the package itself or from something one of its modules imports.
    from vr180_matting.config import VR180Config  # noqa: F401
    from vr180_matting.detector import YOLODetector  # noqa: F401
    from vr180_matting.sam2_wrapper import SAM2VideoMatting  # noqa: F401
    from vr180_matting.memory_manager import VRAMManager  # noqa: F401
except ImportError as exc:
    print(f"❌ VR180 matting import failed: {exc}")
    print(" Make sure to run: pip install -e .")
    issues.append("VR180 matting package not installed")
else:
    print("✅ VR180 matting modules import successful")

# Check SAM2 models
print("\n🔍 Checking SAM2 models...")

sam2_checkpoints_dir = Path("segment-anything-2/checkpoints")
models_dir = Path("models")  # Legacy location

# Checkpoint filename -> human-readable label.  Insertion order (SAM2.1 tiny
# -> large, then SAM2 tiny -> large) is the order ``found_models`` is built in.
sam2_models = {
    "sam2.1_hiera_tiny.pt": "SAM2.1 Tiny",
    "sam2.1_hiera_small.pt": "SAM2.1 Small",
    "sam2.1_hiera_base_plus.pt": "SAM2.1 Base+",
    "sam2.1_hiera_large.pt": "SAM2.1 Large (recommended)",
    "sam2_hiera_tiny.pt": "SAM2 Tiny",
    "sam2_hiera_small.pt": "SAM2 Small",
    "sam2_hiera_base_plus.pt": "SAM2 Base+",
    "sam2_hiera_large.pt": "SAM2 Large",
}

# (checkpoint filename, path string) for every checkpoint present on disk.
found_models = []
for model_file, model_name in sam2_models.items():
    sam2_path = sam2_checkpoints_dir / model_file
    legacy_path = models_dir / model_file
    # The SAM2 repository checkout takes precedence over the legacy models/ dir.
    if sam2_path.exists():
        location, suffix = sam2_path, ""
    elif legacy_path.exists():
        location, suffix = legacy_path, " [legacy location]"
    else:
        continue
    size_mb = location.stat().st_size / (1024 * 1024)
    print(f"✅ {model_name}: {model_file} ({size_mb:.1f} MB){suffix}")
    found_models.append((model_file, str(location)))

# Checkpoints we can recommend a config for, best first (SAM2.1 preferred).
# Exact filename matches are used instead of substring tests.
# NOTE(review): upstream SAM2 names its Base+ config ``sam2.1_hiera_b+`` -
# confirm that vr180_matting accepts 'sam2.1_hiera_base_plus'.
recommended_cfgs = (
    ("sam2.1_hiera_large.pt", "sam2.1_hiera_l"),
    ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_base_plus"),
    ("sam2_hiera_large.pt", "sam2_hiera_l"),
)

if not found_models:
    print("❌ No SAM2 models found!")
    issues.append("No SAM2 models found - run setup script or download manually")
else:
    print("\n💡 Recommended config for best model found:")
    found_paths = dict(found_models)
    for ckpt_file, cfg_name in recommended_cfgs:
        if ckpt_file in found_paths:
            best_model = (ckpt_file, found_paths[ckpt_file])
            print(f" sam2_model_cfg: '{cfg_name}'")
            print(f" sam2_checkpoint: '{best_model[1]}'")
            break
    else:
        # Only checkpoints without a known config mapping exist (e.g. Tiny/Small
        # or SAM2 Base+): say so instead of leaving the header with nothing under it.
        print(" No config recommendation for the checkpoints found - set sam2_model_cfg to match one of:")
        for _, ckpt_path in found_models:
            print(f" sam2_checkpoint: '{ckpt_path}'")

# Check SAM2 configs (shipped with the installed SAM2 package)
print("\n🔍 Checking SAM2 configuration...")

import importlib.util


def _sam2_config_dirs():
    """Return directories where an installed SAM2 may keep its YAML configs.

    ``sam2.sam2_configs`` (the module the old check imported) exists in
    neither upstream layout.  Per the upstream repository, SAM 2.1 ships its
    configs inside the package under ``sam2/configs/``, while the original
    SAM 2 release shipped a separate top-level ``sam2_configs`` package.
    Both locations are returned; missing packages are skipped.
    """
    dirs = []
    for package, subdir in (("sam2", "configs"), ("sam2_configs", "")):
        try:
            spec = importlib.util.find_spec(package)
        except (ImportError, ValueError):
            # ValueError: module already in sys.modules with __spec__ = None.
            continue
        if spec is None:
            continue
        for location in spec.submodule_search_locations or ():
            dirs.append(Path(location) / subdir)
    return dirs


# First candidate directory that actually contains at least one YAML config.
sam2_config_dir = next(
    (d for d in _sam2_config_dirs() if d.is_dir() and any(d.rglob("*.yaml"))),
    None,
)
if sam2_config_dir is not None:
    print("✅ SAM2 configs available in installed package")
    print(f" Config directory: {sam2_config_dir}")
else:
    print("❌ SAM2 configs not found")
    issues.append("SAM2 configs not available - SAM2 may not be properly installed")

# Test model loading if possible
if not any("SAM2 not installed" in issue for issue in issues):
    print("\n🧪 Testing SAM2 model loading...")
    try:
        from types import SimpleNamespace

        # Try to load the default model config
        from vr180_matting.config import VR180Config

        # (checkpoint filename, matching sam2_model_cfg), best first.
        # ``found_models`` follows the tiny -> large order of ``sam2_models``,
        # so its first entry is NOT the best model; choose by explicit priority.
        # NOTE(review): upstream SAM2 names its Base+ config ``sam2.1_hiera_b+``
        # - confirm vr180_matting accepts 'sam2.1_hiera_base_plus'.
        cfg_priority = (
            ("sam2.1_hiera_large.pt", "sam2.1_hiera_l"),
            ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_base_plus"),
            ("sam2_hiera_large.pt", "sam2_hiera_l"),
        )
        found_paths = dict(found_models)

        # Defaults, used when no SAM2 checkpoint was found on disk.
        sam2_cfg = "sam2.1_hiera_l"
        sam2_checkpoint = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
        for ckpt_file, cfg_name in cfg_priority:
            if ckpt_file in found_paths:
                sam2_cfg, sam2_checkpoint = cfg_name, found_paths[ckpt_file]
                break
        else:
            if found_models:
                # Only checkpoints without a known cfg mapping exist (e.g.
                # Tiny/Small); check the first one and keep the default cfg as
                # a placeholder - it is only stored, not used to build a model.
                sam2_checkpoint = found_models[0][1]

        # Minimal stand-in sections: exercises VR180Config construction with
        # the chosen cfg/checkpoint without needing a real config file.
        config = VR180Config(
            input=SimpleNamespace(video_path="test.mp4"),
            processing=SimpleNamespace(scale_factor=0.5, chunk_size=900, overlap_frames=60),
            detection=SimpleNamespace(confidence_threshold=0.7, model="yolov8n"),
            matting=SimpleNamespace(
                use_disparity_mapping=True,
                memory_offload=True,
                fp16=True,
                sam2_model_cfg=sam2_cfg,
                sam2_checkpoint=sam2_checkpoint,
            ),
            output=SimpleNamespace(
                path="output/",
                format="alpha",
                background_color=[0, 255, 0],
                maintain_sbs=True,
            ),
            hardware=SimpleNamespace(
                device="cuda" if torch.cuda.is_available() else "cpu",
                max_vram_gb=10,
            ),
        )

        # Try loading just the model config check
        model_path = Path(config.matting.sam2_checkpoint)
        if model_path.exists():
            print(f"✅ Found checkpoint: {model_path}")
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code.
            checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
            if isinstance(checkpoint, dict) and "model" in checkpoint:
                model_state = checkpoint["model"]
                n_tensors = len(model_state)
                n_params = sum(v.numel() for v in model_state.values() if torch.is_tensor(v))
                print(f" Model has {n_tensors} tensors ({n_params / 1e6:.1f}M parameters)")
                del model_state
            # Drop the CPU copy of the weights before the remaining checks.
            del checkpoint
        else:
            print(f"❌ Checkpoint not found: {model_path}")
            issues.append(f"SAM2 checkpoint missing: {model_path}")

    except Exception as e:
        # Best-effort check: report and continue to the remaining checks.
        print(f"⚠️ Could not test model loading: {e}")

# Check ffmpeg
print("\n🔍 Checking ffmpeg...")

import subprocess

try:
    # The timeout keeps a wedged binary from hanging the whole check;
    # errors="replace" keeps undecodable output from raising.
    result = subprocess.run(
        ["ffmpeg", "-version"],
        capture_output=True,
        text=True,
        errors="replace",
        timeout=30,
    )
except FileNotFoundError:
    print("❌ ffmpeg not found")
    issues.append("ffmpeg not installed - required for video processing")
except (OSError, subprocess.TimeoutExpired) as exc:
    # e.g. PermissionError: an ffmpeg is on PATH but cannot be executed.
    print(f"❌ ffmpeg error: {exc}")
    issues.append("ffmpeg not working properly")
else:
    if result.returncode == 0:
        # First line of `ffmpeg -version` output, or "" when nothing was printed.
        version_line = result.stdout.splitlines()[0] if result.stdout else ""
        print(f"✅ {version_line}")
    else:
        print("❌ ffmpeg error")
        issues.append("ffmpeg not working properly")

# Memory check
if torch.cuda.is_available():
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated(0)
    reserved = torch.cuda.memory_reserved(0)
    # ``total - reserved`` only sees this process's caching allocator, so on a
    # shared GPU memory held by other processes would be counted as free.
    # ``mem_get_info`` returns the device-wide free memory (cudaMemGetInfo).
    if hasattr(torch.cuda, "mem_get_info"):
        free_bytes = torch.cuda.mem_get_info(0)[0]
    else:  # very old PyTorch: fall back to the per-process estimate
        free_bytes = torch.cuda.get_device_properties(0).total_memory - reserved
    print("\n📊 Current GPU Memory Usage:")
    print(f" Allocated: {allocated / gib:.2f} GB")
    print(f" Reserved: {reserved / gib:.2f} GB")
    print(f" Free: {free_bytes / gib:.2f} GB")

# --- Configuration files ----------------------------------------------------
print("\n🔍 Checking configuration files...")
# Candidate config files, looked up relative to the current working directory.
candidate_configs = ("config_example.yaml", "config_runpod.yaml", "config.yaml")
found_configs = list(filter(lambda name: Path(name).exists(), candidate_configs))
if not found_configs:
    print("❌ No config files found")
    issues.append("No configuration files found - generate with: vr180-matting --generate-config config.yaml")
else:
    print("✅ Found configs: " + ", ".join(found_configs))

# --- Summary ------------------------------------------------------------------
# Either list every collected issue followed by remediation steps, or confirm
# success followed by quick-start steps.
print("\n" + "=" * 50)
if issues:
    print("❌ Issues found:")
    for number, issue in enumerate(issues, start=1):
        print(f" {number}. {issue}")
    next_heading = "\n📋 To fix issues:"
    next_steps = (
        " 1. Run: ./runpod_setup.sh",
        " 2. Make sure SAM2 is installed: pip install git+https://github.com/facebookresearch/segment-anything-2.git",
        " 3. Install package: pip install -e .",
    )
else:
    print("✅ All checks passed! Ready to process videos.")
    next_heading = "\n📋 Quick start:"
    next_steps = (
        " 1. Copy example config: cp config_runpod.yaml config.yaml",
        " 2. Edit config.yaml with your video path",
        " 3. Run: vr180-matting config.yaml",
    )
print(next_heading)
for step in next_steps:
    print(step)