fix sam2 models
This commit is contained in:
@@ -14,6 +14,8 @@ matting:
|
||||
use_disparity_mapping: true
|
||||
memory_offload: true
|
||||
fp16: true
|
||||
sam2_model_cfg: "sam2_hiera_l.yaml"
|
||||
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
||||
|
||||
output:
|
||||
path: "path/to/output/"
|
||||
|
||||
@@ -14,6 +14,8 @@ matting:
|
||||
use_disparity_mapping: true
|
||||
memory_offload: false # A40 has enough VRAM
|
||||
fp16: true
|
||||
sam2_model_cfg: "sam2_hiera_l.yaml"
|
||||
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
||||
|
||||
output:
|
||||
path: "/workspace/output/matted_video.mp4"
|
||||
|
||||
@@ -27,6 +27,8 @@ class MattingConfig:
|
||||
use_disparity_mapping: bool = True
|
||||
memory_offload: bool = True
|
||||
fp16: bool = True
|
||||
sam2_model_cfg: str = "sam2_hiera_l.yaml"
|
||||
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,7 +91,9 @@ class VR180Config:
|
||||
'matting': {
|
||||
'use_disparity_mapping': self.matting.use_disparity_mapping,
|
||||
'memory_offload': self.matting.memory_offload,
|
||||
'fp16': self.matting.fp16
|
||||
'fp16': self.matting.fp16,
|
||||
'sam2_model_cfg': self.matting.sam2_model_cfg,
|
||||
'sam2_checkpoint': self.matting.sam2_checkpoint
|
||||
},
|
||||
'output': {
|
||||
'path': self.output.path,
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional, Tuple
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import os
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
@@ -38,6 +39,19 @@ class SAM2VideoMatting:
|
||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||
"""Load SAM2 video predictor with optimizations"""
|
||||
try:
|
||||
# Check for checkpoint in models directory if not found
|
||||
if not Path(checkpoint_path).exists():
|
||||
models_path = Path("models") / checkpoint_path
|
||||
if models_path.exists():
|
||||
checkpoint_path = str(models_path)
|
||||
else:
|
||||
# Try relative to package
|
||||
import os
|
||||
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
models_path = Path(package_dir) / "models" / checkpoint_path
|
||||
if models_path.exists():
|
||||
checkpoint_path = str(models_path)
|
||||
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
model_cfg,
|
||||
checkpoint_path,
|
||||
|
||||
@@ -51,6 +51,8 @@ class VideoProcessor:
|
||||
|
||||
# Initialize SAM2 model
|
||||
self.sam2_model = SAM2VideoMatting(
|
||||
model_cfg=self.config.matting.sam2_model_cfg,
|
||||
checkpoint_path=self.config.matting.sam2_checkpoint,
|
||||
device=self.config.hardware.device,
|
||||
memory_offload=self.config.matting.memory_offload,
|
||||
fp16=self.config.matting.fp16
|
||||
|
||||
Reference in New Issue
Block a user