stuff
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import cv2
|
||||
|
||||
@@ -13,14 +11,23 @@ class YOLODetector:
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.device = device
|
||||
self.model = None
|
||||
self._load_model()
|
||||
# Don't load model during init - load lazily when first used
|
||||
|
||||
def _load_model(self):
|
||||
"""Load YOLOv8 model"""
|
||||
"""Load YOLOv8 model lazily"""
|
||||
if self.model is not None:
|
||||
return # Already loaded
|
||||
|
||||
try:
|
||||
# Import heavy dependencies only when needed
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
self.model = YOLO(f"{self.model_name}.pt")
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.model.to("cuda")
|
||||
|
||||
print(f"🎯 Loaded YOLO model: {self.model_name}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
|
||||
|
||||
@@ -34,8 +41,9 @@ class YOLODetector:
|
||||
Returns:
|
||||
List of detection dictionaries with bbox, confidence, and class info
|
||||
"""
|
||||
# Load model lazily on first use
|
||||
if self.model is None:
|
||||
raise RuntimeError("YOLO model not loaded")
|
||||
self._load_model()
|
||||
|
||||
results = self.model(frame, verbose=False)
|
||||
detections = []
|
||||
|
||||
@@ -9,12 +9,16 @@ import tempfile
|
||||
import shutil
|
||||
import gc
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
SAM2_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAM2_AVAILABLE = False
|
||||
# Check SAM2 availability without importing heavy modules
|
||||
def _check_sam2_available():
|
||||
try:
|
||||
import sam2
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
SAM2_AVAILABLE = _check_sam2_available()
|
||||
if not SAM2_AVAILABLE:
|
||||
warnings.warn("SAM2 not available. Please install sam2 package.")
|
||||
|
||||
|
||||
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
|
||||
self.video_segments = {}
|
||||
self.temp_video_path = None
|
||||
|
||||
self._load_model(model_cfg, checkpoint_path)
|
||||
# Don't load model during init - load lazily when needed
|
||||
self._model_loaded = False
|
||||
|
||||
def _load_model(self, model_cfg: str, checkpoint_path: str):
|
||||
"""Load SAM2 video predictor with optimizations"""
|
||||
"""Load SAM2 video predictor lazily"""
|
||||
if self._model_loaded:
|
||||
return # Already loaded
|
||||
|
||||
try:
|
||||
# Import heavy SAM2 modules only when needed
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
# Check for checkpoint in SAM2 repo structure
|
||||
if not Path(checkpoint_path).exists():
|
||||
# Try in segment-anything-2/checkpoints/
|
||||
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
|
||||
if sam2_repo_path.exists():
|
||||
checkpoint_path = str(sam2_repo_path)
|
||||
|
||||
print(f"🎯 Loading SAM2 model: {model_cfg}")
|
||||
# Use SAM2's build_sam2_video_predictor which returns the predictor directly
|
||||
# The predictor IS the model - no .model attribute needed
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
|
||||
device=self.device
|
||||
)
|
||||
|
||||
self._model_loaded = True
|
||||
print(f"✅ SAM2 model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load SAM2 model: {e}")
|
||||
|
||||
def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
|
||||
"""Initialize video inference state"""
|
||||
if self.predictor is None:
|
||||
# Recreate predictor if it was cleaned up
|
||||
# Load model lazily on first use
|
||||
if not self._model_loaded:
|
||||
self._load_model(self.model_cfg, self.checkpoint_path)
|
||||
|
||||
if video_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user