This commit is contained in:
2025-07-26 15:18:01 -07:00
parent 9f572d4430
commit fb51e82fd4
2 changed files with 38 additions and 15 deletions

View File

@@ -1,6 +1,4 @@
import torch
import numpy as np import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any
import cv2 import cv2
@@ -13,14 +11,23 @@ class YOLODetector:
self.confidence_threshold = confidence_threshold self.confidence_threshold = confidence_threshold
self.device = device self.device = device
self.model = None self.model = None
self._load_model() # Don't load model during init - load lazily when first used
def _load_model(self):
    """Load the YOLOv8 model lazily, on first use.

    Does nothing when ``self.model`` is already set. Otherwise reads
    ``self.model_name`` and ``self.device``, builds the model from
    ``"<model_name>.pt"`` and stores it in ``self.model``.

    Raises:
        RuntimeError: If importing the dependencies, loading the weights or
            moving the model to the GPU fails. The underlying exception is
            chained as ``__cause__``.
    """
    if self.model is not None:
        return  # Already loaded

    try:
        # Import heavy dependencies only when needed
        from ultralytics import YOLO

        # Build into a local first so a failure below cannot leave a
        # half-initialised model on the instance that later calls would
        # mistake for a successfully loaded one.
        model = YOLO(f"{self.model_name}.pt")
        if self.device == "cuda":
            # torch is only needed to probe CUDA, so import it only here.
            import torch

            if torch.cuda.is_available():
                model.to("cuda")
        self.model = model
        print(f"🎯 Loaded YOLO model: {self.model_name}")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}") from e
@@ -34,8 +41,9 @@ class YOLODetector:
Returns: Returns:
List of detection dictionaries with bbox, confidence, and class info List of detection dictionaries with bbox, confidence, and class info
""" """
# Load model lazily on first use
if self.model is None: if self.model is None:
raise RuntimeError("YOLO model not loaded") self._load_model()
results = self.model(frame, verbose=False) results = self.model(frame, verbose=False)
detections = [] detections = []

View File

@@ -9,12 +9,16 @@ import tempfile
import shutil import shutil
import gc import gc
# Check SAM2 availability without importing heavy modules
def _check_sam2_available():
    """Return True when the ``sam2`` package can be imported.

    The package is located with ``importlib.util.find_spec`` instead of
    being imported, so probing for it does not execute ``sam2`` or pull in
    its dependencies.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec("sam2")
    except ValueError:
        # find_spec raises ValueError only when "sam2" is already present in
        # sys.modules with __spec__ set to None; a plain import would succeed.
        return True
    except ImportError:
        return False
    return spec is not None


SAM2_AVAILABLE = _check_sam2_available()
if not SAM2_AVAILABLE:
    warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
self.video_segments = {} self.video_segments = {}
self.temp_video_path = None self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path) # Don't load model during init - load lazily when needed
self._model_loaded = False
def _load_model(self, model_cfg: str, checkpoint_path: str): def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations""" """Load SAM2 video predictor lazily"""
if self._model_loaded:
return # Already loaded
try: try:
# Import heavy SAM2 modules only when needed
from sam2.build_sam import build_sam2_video_predictor
# Check for checkpoint in SAM2 repo structure # Check for checkpoint in SAM2 repo structure
if not Path(checkpoint_path).exists(): if not Path(checkpoint_path).exists():
# Try in segment-anything-2/checkpoints/ # Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
if sam2_repo_path.exists(): if sam2_repo_path.exists():
checkpoint_path = str(sam2_repo_path) checkpoint_path = str(sam2_repo_path)
print(f"🎯 Loading SAM2 model: {model_cfg}")
# Use SAM2's build_sam2_video_predictor which returns the predictor directly # Use SAM2's build_sam2_video_predictor which returns the predictor directly
# The predictor IS the model - no .model attribute needed # The predictor IS the model - no .model attribute needed
self.predictor = build_sam2_video_predictor( self.predictor = build_sam2_video_predictor(
@@ -70,14 +82,17 @@ class SAM2VideoMatting:
ckpt_path=checkpoint_path, ckpt_path=checkpoint_path,
device=self.device device=self.device
) )
self._model_loaded = True
print(f"✅ SAM2 model loaded successfully")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}") raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None: def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
"""Initialize video inference state""" """Initialize video inference state"""
if self.predictor is None: # Load model lazily on first use
# Recreate predictor if it was cleaned up if not self._model_loaded:
self._load_model(self.model_cfg, self.checkpoint_path) self._load_model(self.model_cfg, self.checkpoint_path)
if video_path is not None: if video_path is not None: