stuff
@@ -1,6 +1,4 @@
-import torch
 import numpy as np
-from ultralytics import YOLO
 from typing import List, Tuple, Dict, Any
 import cv2
 
@@ -13,14 +11,23 @@ class YOLODetector:
         self.confidence_threshold = confidence_threshold
         self.device = device
         self.model = None
-        self._load_model()
+        # Don't load model during init - load lazily when first used
 
     def _load_model(self):
-        """Load YOLOv8 model"""
+        """Load YOLOv8 model lazily"""
+        if self.model is not None:
+            return  # Already loaded
+
         try:
+            # Import heavy dependencies only when needed
+            import torch
+            from ultralytics import YOLO
+
             self.model = YOLO(f"{self.model_name}.pt")
             if self.device == "cuda" and torch.cuda.is_available():
                 self.model.to("cuda")
+
+            print(f"🎯 Loaded YOLO model: {self.model_name}")
         except Exception as e:
             raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
 
@@ -34,8 +41,9 @@ class YOLODetector:
         Returns:
             List of detection dictionaries with bbox, confidence, and class info
         """
+        # Load model lazily on first use
         if self.model is None:
-            raise RuntimeError("YOLO model not loaded")
+            self._load_model()
 
         results = self.model(frame, verbose=False)
         detections = []
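The three hunks above (the YOLO detector module) move the heavy torch and ultralytics imports out of module scope and defer model creation to the first detect() call. Below is a minimal sketch of the resulting pattern; the __init__ parameters and the result-parsing loop are assumptions filled in around the diff, not code taken from the repository.

# Sketch of the lazy-loading pattern in YOLODetector after this commit.
# The constructor defaults and the result-parsing loop are assumed; the
# lazy-load guard and the deferred imports mirror the diff above.
from typing import Any, Dict, List

import numpy as np


class YOLODetector:
    def __init__(self, model_name: str = "yolov8n",
                 confidence_threshold: float = 0.5, device: str = "cuda"):
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.model = None  # nothing heavy happens here anymore

    def _load_model(self) -> None:
        """Load the YOLOv8 model lazily, at most once."""
        if self.model is not None:
            return  # Already loaded
        try:
            # Import heavy dependencies only when needed
            import torch
            from ultralytics import YOLO

            self.model = YOLO(f"{self.model_name}.pt")
            if self.device == "cuda" and torch.cuda.is_available():
                self.model.to("cuda")
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")

    def detect(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """Run detection; the first call pays the model-loading cost."""
        if self.model is None:
            self._load_model()
        results = self.model(frame, verbose=False)
        detections: List[Dict[str, Any]] = []
        for result in results:        # standard Ultralytics results API,
            for box in result.boxes:  # not taken from the diff
                conf = float(box.conf[0])
                if conf >= self.confidence_threshold:
                    detections.append({
                        "bbox": box.xyxy[0].tolist(),
                        "confidence": conf,
                        "class_id": int(box.cls[0]),
                    })
        return detections

The remaining hunks apply the same idea to the SAM2 video matting module.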
@@ -9,12 +9,16 @@ import tempfile
 import shutil
 import gc
 
-try:
-    from sam2.build_sam import build_sam2_video_predictor
-    from sam2.sam2_image_predictor import SAM2ImagePredictor
-    SAM2_AVAILABLE = True
-except ImportError:
-    SAM2_AVAILABLE = False
+# Check SAM2 availability without importing heavy modules
+def _check_sam2_available():
+    try:
+        import sam2
+        return True
+    except ImportError:
+        return False
+
+SAM2_AVAILABLE = _check_sam2_available()
+if not SAM2_AVAILABLE:
     warnings.warn("SAM2 not available. Please install sam2 package.")
 
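The module-level probe above now imports only the top-level sam2 package, so a missing dependency no longer pulls in the heavy predictor modules at import time. An even lighter variant, shown purely for comparison and not as what the commit does, asks importlib whether the package exists without executing it at all:

# Hypothetical alternative availability probe - not part of this commit.
# importlib.util.find_spec() checks whether "sam2" is installed without
# importing or executing the package.
import importlib.util
import warnings

SAM2_AVAILABLE = importlib.util.find_spec("sam2") is not None
if not SAM2_AVAILABLE:
    warnings.warn("SAM2 not available. Please install sam2 package.")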
@@ -40,11 +44,18 @@ class SAM2VideoMatting:
         self.video_segments = {}
         self.temp_video_path = None
 
-        self._load_model(model_cfg, checkpoint_path)
+        # Don't load model during init - load lazily when needed
+        self._model_loaded = False
 
     def _load_model(self, model_cfg: str, checkpoint_path: str):
-        """Load SAM2 video predictor with optimizations"""
+        """Load SAM2 video predictor lazily"""
+        if self._model_loaded:
+            return  # Already loaded
+
         try:
+            # Import heavy SAM2 modules only when needed
+            from sam2.build_sam import build_sam2_video_predictor
+
             # Check for checkpoint in SAM2 repo structure
             if not Path(checkpoint_path).exists():
                 # Try in segment-anything-2/checkpoints/
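This hunk gives SAM2VideoMatting the same treatment: the constructor now only records self._model_loaded = False, and _load_model() returns early once the flag is set, importing build_sam2_video_predictor only on the first real call (the rest of the method continues in the hunks below). Reduced to its core, the guard looks like the sketch here; the predictor call and attribute names are mirrored from the diff, everything else (including the checkpoint-path fallback and logging) is omitted or assumed.

# Core of the flag-guarded lazy load in SAM2VideoMatting (a sketch; the
# checkpoint-path fallback and the log prints from the diff are omitted).
class SAM2LazyLoadSketch:
    def __init__(self, model_cfg: str, checkpoint_path: str, device: str = "cuda"):
        self.model_cfg = model_cfg
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.predictor = None
        # Don't load model during init - load lazily when needed
        self._model_loaded = False

    def _load_model(self, model_cfg: str, checkpoint_path: str) -> None:
        """Load SAM2 video predictor lazily."""
        if self._model_loaded:
            return  # Already loaded - safe to call from any entry point
        try:
            # Import heavy SAM2 modules only when needed
            from sam2.build_sam import build_sam2_video_predictor

            self.predictor = build_sam2_video_predictor(
                model_cfg, checkpoint_path, device=self.device
            )
            self._model_loaded = True
        except Exception as e:
            raise RuntimeError(f"Failed to load SAM2 model: {e}")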
@@ -63,6 +74,7 @@ class SAM2VideoMatting:
                 if sam2_repo_path.exists():
                     checkpoint_path = str(sam2_repo_path)
 
+            print(f"🎯 Loading SAM2 model: {model_cfg}")
             # Use SAM2's build_sam2_video_predictor which returns the predictor directly
             # The predictor IS the model - no .model attribute needed
             self.predictor = build_sam2_video_predictor(
@@ -71,13 +83,16 @@ class SAM2VideoMatting:
                 device=self.device
             )
 
+            self._model_loaded = True
+            print(f"✅ SAM2 model loaded successfully")
+
         except Exception as e:
             raise RuntimeError(f"Failed to load SAM2 model: {e}")
 
     def init_video_state(self, video_frames: List[np.ndarray] = None, video_path: str = None) -> None:
         """Initialize video inference state"""
-        if self.predictor is None:
-            # Recreate predictor if it was cleaned up
+        # Load model lazily on first use
+        if not self._model_loaded:
             self._load_model(self.model_cfg, self.checkpoint_path)
 
         if video_path is not None:
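Net effect of the commit: constructing either object is now essentially free, and the heavy imports plus weight loading happen on the first call that actually needs the model. A hypothetical usage sketch follows; constructor arguments and file paths are placeholders, not taken from the diff.

# Hypothetical usage after this commit; argument values and file names are
# placeholders. Construction is cheap - the cost moves to the first real call.
import numpy as np

detector = YOLODetector()                        # no torch/ultralytics import yet
frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy frame for illustration
detections = detector.detect(frame)              # first call runs _load_model()

matting = SAM2VideoMatting(
    model_cfg="sam2_hiera_l.yaml",                      # placeholder config
    checkpoint_path="checkpoints/sam2_hiera_large.pt",  # placeholder checkpoint
)
matting.init_video_state(video_path="input.mp4")  # predictor is built here, not in __init__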