stuff
This commit is contained in:
@@ -9,12 +9,16 @@ import tempfile
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
SAM2_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAM2_AVAILABLE = False
|
||||
# Check SAM2 availability without importing heavy modules
|
||||
def _check_sam2_available():
|
||||
try:
|
||||
import sam2
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
SAM2_AVAILABLE = _check_sam2_available()
|
||||
if not SAM2_AVAILABLE:
|
||||
warnings.warn("SAM2 not available. Please install sam2 package.")
|
||||
|
||||
|
||||
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
|
||||
self.video_segments = {}
|
||||
self.temp_video_path = None
|
||||
|
||||
self._load_model(model_cfg, checkpoint_path)
|
||||
# Don't load model during init - load lazily when needed
|
||||
self._model_loaded = False
|
||||
|
||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||
"""Load SAM2 video predictor with optimizations"""
|
||||
"""Load SAM2 video predictor lazily"""
|
||||
if self._model_loaded:
|
||||
return # Already loaded
|
||||
|
||||
try:
|
||||
# Import heavy SAM2 modules only when needed
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
# Check for checkpoint in SAM2 repo structure
|
||||
if not Path(checkpoint_path).exists():
|
||||
# Try in segment-anything-2/checkpoints/
|
||||
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
|
||||
if sam2_repo_path.exists():
|
||||
checkpoint_path = str(sam2_repo_path)
|
||||
|
||||
print(f"🎯 Loading SAM2 model: {model_cfg}")
|
||||
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
|
||||
# The predictor IS the model - no .model attribute needed
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
@@ -70,14 +82,17 @@ class SAM2VideoMatting:
|
||||
ckpt_path=checkpoint_path,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
self._model_loaded = True
|
||||
print(f"✅ SAM2 model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
||||
|
||||
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
|
||||
"""Initialize video inference state"""
|
||||
if self.predictor is None:
|
||||
# Recreate predictor if it was cleaned up
|
||||
# Load model lazily on first use
|
||||
if not self._model_loaded:
|
||||
self._load_model(self.model_cfg, self.checkpoint_path)
|
||||
|
||||
if video_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user