Compare commits
2 Commits
e195d23584
...
d933d6b606
| Author | SHA1 | Date | |
|---|---|---|---|
| d933d6b606 | |||
| 7852303b40 |
@@ -57,21 +57,14 @@ class SAM2VideoMatting:
|
|||||||
if sam2_repo_path.exists():
|
if sam2_repo_path.exists():
|
||||||
checkpoint_path = str(sam2_repo_path)
|
checkpoint_path = str(sam2_repo_path)
|
||||||
|
|
||||||
# Use the config path as-is (should be relative to SAM2 package)
|
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
|
||||||
# Example: "configs/sam2.1/sam2.1_hiera_l.yaml"
|
# The predictor IS the model - no .model attribute needed
|
||||||
self.predictor = build_sam2_video_predictor(
|
self.predictor = build_sam2_video_predictor(
|
||||||
model_cfg,
|
config_file=model_cfg,
|
||||||
checkpoint_path,
|
ckpt_path=checkpoint_path,
|
||||||
device=self.device
|
device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enable memory optimizations
|
|
||||||
if self.memory_offload:
|
|
||||||
self.predictor.fill_hole_area = 8
|
|
||||||
|
|
||||||
if self.fp16 and self.device == "cuda":
|
|
||||||
self.predictor.model.half()
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user