# Source listing metadata: 134 lines, 4.9 KiB, Python
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
class YOLODetector:
    """YOLOv8-based person detector for automatic SAM2 prompting.

    The ultralytics model is loaded lazily on the first call to
    ``detect_persons`` so that constructing a detector does not import
    torch/ultralytics or read any weights.
    """

    # Index of the "person" class in COCO-pretrained YOLO models.
    PERSON_CLASS_ID = 0

    def __init__(self, model_name: str = "yolov8n", confidence_threshold: float = 0.7, device: str = "cuda"):
        """Configure the detector without loading any weights.

        Args:
            model_name: Weights name; ``f"{model_name}.pt"`` is passed to
                ``ultralytics.YOLO`` when the model is first needed.
            confidence_threshold: Minimum confidence (inclusive) for a person
                detection to be kept.
            device: Requested device. The model is moved to CUDA only when this
                is exactly ``"cuda"`` and ``torch.cuda.is_available()``.
        """
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.device = device
        # Don't load model during init - populated lazily by _load_model().
        self.model = None

    def _load_model(self) -> None:
        """Load the YOLOv8 model on first use; no-op if already loaded.

        Raises:
            RuntimeError: If torch/ultralytics cannot be imported or the model
                cannot be created/moved. The original exception is chained as
                ``__cause__``.
        """
        if self.model is not None:
            return  # Already loaded

        try:
            # Import heavy dependencies only when needed
            import torch
            from ultralytics import YOLO

            model = YOLO(f"{self.model_name}.pt")
            if self.device == "cuda" and torch.cuda.is_available():
                # NOTE(review): ultralytics may select its own device inside
                # predict(); confirm this .to() call controls inference placement.
                model.to("cuda")
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}") from e

        # Publish the model only once it is fully set up, so a failure above
        # never leaves a half-initialised model cached on the instance.
        self.model = model
        print(f"🎯 Loaded YOLO model: {self.model_name}")

    def detect_persons(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """Detect persons in a frame and return their bounding boxes.

        Args:
            frame: Input frame (H, W, 3), passed straight to the YOLO model.

        Returns:
            Detections sorted by confidence (highest first). Each is a dict with
            ``'bbox'`` ([x1, y1, x2, y2] as truncated ints), ``'confidence'``
            (float), ``'class'`` (always ``'person'``) and ``'area'`` (float,
            computed from the un-truncated corners).
        """
        # Load model lazily on first use
        if self.model is None:
            self._load_model()

        results = self.model(frame, verbose=False)
        detections: List[Dict[str, Any]] = []

        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue

            # Copy each attribute to host memory once per result instead of
            # once per box, avoiding a device->host transfer per detection.
            class_ids = boxes.cls.cpu().numpy().astype(int)
            confidences = boxes.conf.cpu().numpy()
            corners = boxes.xyxy.cpu().numpy()

            for class_id, conf, (x1, y1, x2, y2) in zip(class_ids, confidences, corners):
                confidence = float(conf)
                # Only keep person detections (class 0 in COCO) at/above threshold
                if class_id != self.PERSON_CLASS_ID or confidence < self.confidence_threshold:
                    continue
                detections.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'confidence': confidence,
                    'class': 'person',
                    # Plain float so the dict is JSON-serialisable.
                    'area': float((x2 - x1) * (y2 - y1)),
                })

        # Sort by confidence (highest first)
        detections.sort(key=lambda d: d['confidence'], reverse=True)
        return detections

    def convert_to_sam_prompts(self, detections: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
        """Convert detections to SAM2 box prompts.

        Args:
            detections: Detection dicts, each carrying a ``'bbox'`` entry.

        Returns:
            ``(box_prompts, labels)``: box_prompts has shape (N, 4), and labels
            has shape (N,) and is all ones (positive prompts). When
            ``detections`` is empty, both arrays are empty integer arrays with
            shapes (0, 4) and (0,).
        """
        if not detections:
            # Keep the same rank as the non-empty case so callers can rely on
            # box_prompts.shape == (N, 4) even when N == 0.
            return np.empty((0, 4), dtype=int), np.empty((0,), dtype=int)

        box_prompts = np.array([detection['bbox'] for detection in detections])
        labels = np.ones(len(detections), dtype=int)  # 1 = positive prompt
        return box_prompts, labels

    def visualize_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
        """Draw detection boxes and confidence labels on a copy of the frame.

        Args:
            frame: Input frame; it is not modified.
            detections: Detections with ``'bbox'`` and ``'confidence'`` keys.

        Returns:
            A copy of ``frame`` with green boxes and "Person: <conf>" labels.
        """
        vis_frame = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX

        for detection in detections:
            # cv2 drawing functions require integer pixel coordinates.
            x1, y1, x2, y2 = (int(v) for v in detection['bbox'])
            confidence = detection['confidence']

            # Draw bounding box
            cv2.rectangle(vis_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Draw confidence score on a filled background above the box,
            # clamped to row 0 so the label stays visible for boxes at the top.
            label = f"Person: {confidence:.2f}"
            (label_w, label_h), _ = cv2.getTextSize(label, font, 0.5, 2)
            label_top = max(y1 - label_h - 10, 0)
            cv2.rectangle(vis_frame, (x1, label_top),
                          (x1 + label_w, label_top + label_h + 10), (0, 255, 0), -1)
            cv2.putText(vis_frame, label, (x1, label_top + label_h + 5),
                        font, 0.5, (0, 0, 0), 2)

        return vis_frame

    def get_largest_person(self, detections: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return the detection with the largest ``'area'``, or None if empty.

        On ties, the earliest detection in ``detections`` is returned.
        """
        if not detections:
            return None

        return max(detections, key=lambda d: d['area'])

    def filter_by_size(self, detections: List[Dict[str, Any]], min_area: int = 1000) -> List[Dict[str, Any]]:
        """Keep only detections whose ``'area'`` is at least ``min_area`` (order preserved)."""
        return [d for d in detections if d['area'] >= min_area]