install sam2 the way facebook says
This commit is contained in:
@@ -27,8 +27,8 @@ class MattingConfig:
|
||||
use_disparity_mapping: bool = True
|
||||
memory_offload: bool = True
|
||||
fp16: bool = True
|
||||
sam2_model_cfg: str = "sam2_hiera_l"
|
||||
sam2_checkpoint: str = "sam2_hiera_large.pt"
|
||||
sam2_model_cfg: str = "sam2.1_hiera_l"
|
||||
sam2_checkpoint: str = "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -111,6 +111,8 @@ matting:
|
||||
use_disparity_mapping: true
|
||||
memory_offload: true
|
||||
fp16: true
|
||||
sam2_model_cfg: "sam2.1_hiera_l"
|
||||
sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
|
||||
|
||||
output:
|
||||
path: "path/to/output/"
|
||||
|
||||
@@ -39,18 +39,23 @@ class SAM2VideoMatting:
|
||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||
"""Load SAM2 video predictor with optimizations"""
|
||||
try:
|
||||
# Check for checkpoint in models directory if not found
|
||||
# Check for checkpoint in SAM2 repo structure
|
||||
if not Path(checkpoint_path).exists():
|
||||
models_path = Path("models") / checkpoint_path
|
||||
if models_path.exists():
|
||||
checkpoint_path = str(models_path)
|
||||
# Try in segment-anything-2/checkpoints/
|
||||
sam2_path = Path("segment-anything-2/checkpoints") / Path(checkpoint_path).name
|
||||
if sam2_path.exists():
|
||||
checkpoint_path = str(sam2_path)
|
||||
else:
|
||||
# Try relative to package
|
||||
import os
|
||||
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
models_path = Path(package_dir) / "models" / checkpoint_path
|
||||
# Try legacy models/ directory
|
||||
models_path = Path("models") / Path(checkpoint_path).name
|
||||
if models_path.exists():
|
||||
checkpoint_path = str(models_path)
|
||||
else:
|
||||
# Try relative to package
|
||||
package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sam2_repo_path = Path(package_dir) / "segment-anything-2/checkpoints" / Path(checkpoint_path).name
|
||||
if sam2_repo_path.exists():
|
||||
checkpoint_path = str(sam2_repo_path)
|
||||
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
model_cfg,
|
||||
|
||||
Reference in New Issue
Block a user