fix sam2 models

This commit is contained in:
2025-07-26 07:57:01 -07:00
parent aa6ee4d40e
commit 1bec8113de
5 changed files with 25 additions and 1 deletions

View File

@@ -27,6 +27,8 @@ class MattingConfig:
use_disparity_mapping: bool = True
memory_offload: bool = True
fp16: bool = True
sam2_model_cfg: str = "sam2_hiera_l.yaml"
sam2_checkpoint: str = "sam2_hiera_large.pt"
@dataclass
@@ -89,7 +91,9 @@ class VR180Config:
'matting': {
'use_disparity_mapping': self.matting.use_disparity_mapping,
'memory_offload': self.matting.memory_offload,
'fp16': self.matting.fp16
'fp16': self.matting.fp16,
'sam2_model_cfg': self.matting.sam2_model_cfg,
'sam2_checkpoint': self.matting.sam2_checkpoint
},
'output': {
'path': self.output.path,

View File

@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional, Tuple
import cv2
from pathlib import Path
import warnings
import os
try:
from sam2.build_sam import build_sam2_video_predictor
@@ -38,6 +39,19 @@ class SAM2VideoMatting:
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations"""
try:
# Check for checkpoint in models directory if not found
if not Path(checkpoint_path).exists():
models_path = Path("models") / checkpoint_path
if models_path.exists():
checkpoint_path = str(models_path)
else:
# Try relative to package
import os
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
models_path = Path(package_dir) / "models" / checkpoint_path
if models_path.exists():
checkpoint_path = str(models_path)
self.predictor = build_sam2_video_predictor(
model_cfg,
checkpoint_path,

View File

@@ -51,6 +51,8 @@ class VideoProcessor:
# Initialize SAM2 model
self.sam2_model = SAM2VideoMatting(
model_cfg=self.config.matting.sam2_model_cfg,
checkpoint_path=self.config.matting.sam2_checkpoint,
device=self.config.hardware.device,
memory_offload=self.config.matting.memory_offload,
fp16=self.config.matting.fp16