diff --git a/config_runpod.yaml b/config_runpod.yaml index 4f8f010..2a52bff 100644 --- a/config_runpod.yaml +++ b/config_runpod.yaml @@ -14,7 +14,7 @@ matting: use_disparity_mapping: true memory_offload: false # A40 has enough VRAM fp16: true - sam2_model_cfg: "sam2.1_hiera_l" + sam2_model_cfg: "configs/sam2.1/sam2.1_hiera_l.yaml" sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt" output: diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py index 0c49bdf..320da35 100644 --- a/vr180_matting/sam2_wrapper.py +++ b/vr180_matting/sam2_wrapper.py @@ -57,23 +57,10 @@ class SAM2VideoMatting: if sam2_repo_path.exists(): checkpoint_path = str(sam2_repo_path) - # Handle config path - if it contains a dot, look for the actual file - config_path = model_cfg - if not model_cfg.endswith('.yaml'): - # Try to find the config file in SAM2 repo structure - sam2_config_paths = [ - Path("segment-anything-2/sam2/configs/sam2.1") / f"{model_cfg}.yaml", - Path("segment-anything-2/sam2/configs/sam2") / f"{model_cfg}.yaml", - Path("segment-anything-2/sam2") / f"{model_cfg}.yaml" - ] - - for config_file_path in sam2_config_paths: - if config_file_path.exists(): - config_path = str(config_file_path) - break - + # Use the config path as-is (should be relative to SAM2 package) + # Example: "configs/sam2.1/sam2.1_hiera_l.yaml" self.predictor = build_sam2_video_predictor( - config_path, + model_cfg, checkpoint_path, device=self.device )