diff --git a/vr180_matting/detector.py b/vr180_matting/detector.py
index 343b559..a510aa9 100644
--- a/vr180_matting/detector.py
+++ b/vr180_matting/detector.py
@@ -1,6 +1,4 @@
-import torch
 import numpy as np
-from ultralytics import YOLO
 from typing import List, Tuple, Dict, Any
 import cv2
 
@@ -13,14 +11,23 @@ class YOLODetector:
         self.confidence_threshold = confidence_threshold
         self.device = device
         self.model = None
-        self._load_model()
+        # Don't load model during init - load lazily when first used
 
     def _load_model(self):
-        """Load YOLOv8 model"""
+        """Load YOLOv8 model lazily"""
+        if self.model is not None:
+            return  # Already loaded
+
         try:
+            # Import heavy dependencies only when needed
+            import torch
+            from ultralytics import YOLO
+
             self.model = YOLO(f"{self.model_name}.pt")
             if self.device == "cuda" and torch.cuda.is_available():
                 self.model.to("cuda")
+
+            print(f"🎯 Loaded YOLO model: {self.model_name}")
         except Exception as e:
             raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
 
@@ -34,8 +41,9 @@
         Returns:
             List of detection dictionaries with bbox, confidence, and class info
         """
+        # Load model lazily on first use
         if self.model is None:
-            raise RuntimeError("YOLO model not loaded")
+            self._load_model()
 
         results = self.model(frame, verbose=False)
         detections = []
diff --git a/vr180_matting/sam2_wrapper.py b/vr180_matting/sam2_wrapper.py
index b406bbe..64a37d0 100644
--- a/vr180_matting/sam2_wrapper.py
+++ b/vr180_matting/sam2_wrapper.py
@@ -9,12 +9,16 @@
 import tempfile
 import shutil
 import gc
 
-try:
-    from sam2.build_sam import build_sam2_video_predictor
-    from sam2.sam2_image_predictor import SAM2ImagePredictor
-    SAM2_AVAILABLE = True
-except ImportError:
-    SAM2_AVAILABLE = False
+# Check SAM2 availability without importing heavy modules
+def _check_sam2_available():
+    try:
+        import sam2
+        return True
+    except ImportError:
+        return False
+
+SAM2_AVAILABLE = _check_sam2_available()
+if not SAM2_AVAILABLE:
     warnings.warn("SAM2 not available. Please install sam2 package.")
 
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
         self.video_segments = {}
         self.temp_video_path = None
 
-        self._load_model(model_cfg, checkpoint_path)
+        # Don't load model during init - load lazily when needed
+        self._model_loaded = False
 
     def _load_model(self, model_cfg: str, checkpoint_path: str):
-        """Load SAM2 video predictor with optimizations"""
+        """Load SAM2 video predictor lazily"""
+        if self._model_loaded:
+            return  # Already loaded
+
         try:
+            # Import heavy SAM2 modules only when needed
+            from sam2.build_sam import build_sam2_video_predictor
+
             # Check for checkpoint in SAM2 repo structure
             if not Path(checkpoint_path).exists():
                 # Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
                 if sam2_repo_path.exists():
                     checkpoint_path = str(sam2_repo_path)
 
+            print(f"🎯 Loading SAM2 model: {model_cfg}")
             # Use SAM2's build_sam2_video_predictor which returns the predictor directly
             # The predictor IS the model - no .model attribute needed
             self.predictor = build_sam2_video_predictor(
@@ -70,14 +82,17 @@
                 ckpt_path=checkpoint_path,
                 device=self.device
             )
+
+            self._model_loaded = True
+            print(f"✅ SAM2 model loaded successfully")
         except Exception as e:
             raise RuntimeError(f"Failed to load SAM2 model: {e}")
 
     def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
         """Initialize video inference state"""
 
-        if self.predictor is None:
-            # Recreate predictor if it was cleaned up
+        # Load model lazily on first use
+        if not self._model_loaded:
             self._load_model(self.model_cfg, self.checkpoint_path)
 
         if video_path is not None: