fix sam2 models
This commit is contained in:
@@ -14,6 +14,8 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: true
|
memory_offload: true
|
||||||
fp16: true
|
fp16: true
|
||||||
|
sam2_model_cfg: "sam2_hiera_l.yaml"
|
||||||
|
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "path/to/output/"
|
path: "path/to/output/"
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ matting:
|
|||||||
use_disparity_mapping: true
|
use_disparity_mapping: true
|
||||||
memory_offload: false # A40 has enough VRAM
|
memory_offload: false # A40 has enough VRAM
|
||||||
fp16: true
|
fp16: true
|
||||||
|
sam2_model_cfg: "sam2_hiera_l.yaml"
|
||||||
|
sam2_checkpoint: "models/sam2_hiera_large.pt"
|
||||||
|
|
||||||
output:
|
output:
|
||||||
path: "/workspace/output/matted_video.mp4"
|
path: "/workspace/output/matted_video.mp4"
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ class MattingConfig:
|
|||||||
use_disparity_mapping: bool = True
|
use_disparity_mapping: bool = True
|
||||||
memory_offload: bool = True
|
memory_offload: bool = True
|
||||||
fp16: bool = True
|
fp16: bool = True
|
||||||
|
sam2_model_cfg: str = "sam2_hiera_l.yaml"
|
||||||
|
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -89,7 +91,9 @@ class VR180Config:
|
|||||||
'matting': {
|
'matting': {
|
||||||
'use_disparity_mapping': self.matting.use_disparity_mapping,
|
'use_disparity_mapping': self.matting.use_disparity_mapping,
|
||||||
'memory_offload': self.matting.memory_offload,
|
'memory_offload': self.matting.memory_offload,
|
||||||
'fp16': self.matting.fp16
|
'fp16': self.matting.fp16,
|
||||||
|
'sam2_model_cfg': self.matting.sam2_model_cfg,
|
||||||
|
'sam2_checkpoint': self.matting.sam2_checkpoint
|
||||||
},
|
},
|
||||||
'output': {
|
'output': {
|
||||||
'path': self.output.path,
|
'path': self.output.path,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional, Tuple
|
|||||||
import cv2
|
import cv2
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sam2.build_sam import build_sam2_video_predictor
|
from sam2.build_sam import build_sam2_video_predictor
|
||||||
@@ -38,6 +39,19 @@ class SAM2VideoMatting:
|
|||||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||||
"""Load SAM2 video predictor with optimizations"""
|
"""Load SAM2 video predictor with optimizations"""
|
||||||
try:
|
try:
|
||||||
|
# Check for checkpoint in models directory if not found
|
||||||
|
if not Path(checkpoint_path).exists():
|
||||||
|
models_path = Path("models") / checkpoint_path
|
||||||
|
if models_path.exists():
|
||||||
|
checkpoint_path = str(models_path)
|
||||||
|
else:
|
||||||
|
# Try relative to package
|
||||||
|
import os
|
||||||
|
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
models_path = Path(package_dir) / "models" / checkpoint_path
|
||||||
|
if models_path.exists():
|
||||||
|
checkpoint_path = str(models_path)
|
||||||
|
|
||||||
self.predictor = build_sam2_video_predictor(
|
self.predictor = build_sam2_video_predictor(
|
||||||
model_cfg,
|
model_cfg,
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ class VideoProcessor:
|
|||||||
|
|
||||||
# Initialize SAM2 model
|
# Initialize SAM2 model
|
||||||
self.sam2_model = SAM2VideoMatting(
|
self.sam2_model = SAM2VideoMatting(
|
||||||
|
model_cfg=self.config.matting.sam2_model_cfg,
|
||||||
|
checkpoint_path=self.config.matting.sam2_checkpoint,
|
||||||
device=self.config.hardware.device,
|
device=self.config.hardware.device,
|
||||||
memory_offload=self.config.matting.memory_offload,
|
memory_offload=self.config.matting.memory_offload,
|
||||||
fp16=self.config.matting.fp16
|
fp16=self.config.matting.fp16
|
||||||
|
|||||||
Reference in New Issue
Block a user