This commit is contained in:
2025-07-26 15:18:01 -07:00
parent 9f572d4430
commit fb51e82fd4
2 changed files with 38 additions and 15 deletions

View File

@@ -1,6 +1,4 @@
import torch
import numpy as np
from ultralytics import YOLO
from typing import List, Tuple, Dict, Any
import cv2
@@ -13,14 +11,23 @@ class YOLODetector:
self.confidence_threshold = confidence_threshold
self.device = device
self.model = None
self._load_model()
# Don't load model during init - load lazily when first used
def _load_model(self):
"""Load YOLOv8 model"""
"""Load YOLOv8 model lazily"""
if self.model is not None:
return # Already loaded
try:
# Import heavy dependencies only when needed
import torch
from ultralytics import YOLO
self.model = YOLO(f"{self.model_name}.pt")
if self.device == "cuda" and torch.cuda.is_available():
self.model.to("cuda")
print(f"🎯 Loaded YOLO model: {self.model_name}")
except Exception as e:
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
@@ -34,8 +41,9 @@ class YOLODetector:
Returns:
List of detection dictionaries with bbox, confidence, and class info
"""
# Load model lazily on first use
if self.model is None:
raise RuntimeError("YOLO model not loaded")
self._load_model()
results = self.model(frame, verbose=False)
detections = []

View File

@@ -9,12 +9,16 @@ import tempfile
import shutil
import gc
try:
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
except ImportError:
SAM2_AVAILABLE = False
# Check SAM2 availability without importing heavy modules
def _check_sam2_available():
try:
import sam2
return True
except ImportError:
return False
SAM2_AVAILABLE = _check_sam2_available()
if not SAM2_AVAILABLE:
warnings.warn("SAM2 not available. Please install sam2 package.")
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
self.video_segments = {}
self.temp_video_path = None
self._load_model(model_cfg, checkpoint_path)
# Don't load model during init - load lazily when needed
self._model_loaded = False
def _load_model(self, model_cfg: str, checkpoint_path: str):
"""Load SAM2 video predictor with optimizations"""
"""Load SAM2 video predictor lazily"""
if self._model_loaded:
return # Already loaded
try:
# Import heavy SAM2 modules only when needed
from sam2.build_sam import build_sam2_video_predictor
# Check for checkpoint in SAM2 repo structure
if not Path(checkpoint_path).exists():
# Try in segment-anything-2/checkpoints/
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
if sam2_repo_path.exists():
checkpoint_path = str(sam2_repo_path)
print(f"🎯 Loading SAM2 model: {model_cfg}")
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
# The predictor IS the model - no .model attribute needed
self.predictor = build_sam2_video_predictor(
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
device=self.device
)
self._model_loaded = True
print(f"✅ SAM2 model loaded successfully")
except Exception as e:
raise RuntimeError(f"Failed to load SAM2 model: {e}")
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
"""Initialize video inference state"""
if self.predictor is None:
# Recreate predictor if it was cleaned up
# Load model lazily on first use
if not self._model_loaded:
self._load_model(self.model_cfg, self.checkpoint_path)
if video_path is not None: