Files
test2/vr180_matting/detector.py
2025-07-26 15:18:01 -07:00

134 lines
4.9 KiB
Python

from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
class YOLODetector:
    """YOLOv8-based person detector for automatic SAM2 prompting.

    The ultralytics model is loaded lazily on first use, so constructing a
    detector does not import torch/ultralytics or read any weights.
    """

    def __init__(self, model_name: str = "yolov8n", confidence_threshold: float = 0.7, device: str = "cuda"):
        """Store configuration; the model itself is loaded on first use.

        Args:
            model_name: YOLO weights name; ``.pt`` is appended when missing.
            confidence_threshold: Minimum confidence for a person detection to be kept.
            device: Requested device. The model is moved to CUDA only when this is
                ``"cuda"`` and CUDA is available; for any other value ``.to()`` is not called.
        """
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.model: Any = None  # Populated by _load_model() on first use

    def _load_model(self) -> None:
        """Load the YOLO model on the first call; later calls are no-ops.

        The model is only cached once loading (and device placement) succeeded,
        so a failed attempt can be retried instead of leaving a half-set-up model.

        Raises:
            RuntimeError: If importing torch/ultralytics or loading the weights fails.
        """
        if self.model is not None:
            return  # Already loaded
        # Accept both "yolov8n" and "yolov8n.pt" without producing "yolov8n.pt.pt".
        weights = self.model_name if self.model_name.endswith(".pt") else f"{self.model_name}.pt"
        try:
            # Heavy dependencies are imported here so importing this module stays cheap.
            import torch
            from ultralytics import YOLO

            model = YOLO(weights)
            if self.device == "cuda" and torch.cuda.is_available():
                model.to("cuda")
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}") from e
        self.model = model
        print(f"🎯 Loaded YOLO model: {self.model_name}")

    def detect_persons(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """Detect persons in a frame and return their bounding boxes.

        Args:
            frame: Input image; presumably (H, W, 3) as accepted by ultralytics — TODO confirm colour order with callers.

        Returns:
            Detection dicts sorted by confidence (highest first), each with keys
            ``'bbox'`` ([x1, y1, x2, y2] ints), ``'confidence'`` (float),
            ``'class'`` (always ``'person'``) and ``'area'`` (float, from the
            unrounded box corners).
        """
        # Load model lazily on first use
        if self.model is None:
            self._load_model()

        # Hand our threshold and the person class (COCO id 0) to the predictor:
        # otherwise ultralytics applies its own default confidence cut-off
        # (0.25 per its docs), which silently overrides lower thresholds.
        results = self.model(frame, verbose=False, conf=self.confidence_threshold, classes=[0])

        detections: List[Dict[str, Any]] = []
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                # Re-check class and confidence so the filter holds regardless of predictor settings.
                if int(box.cls) != 0 or float(box.conf) < self.confidence_threshold:
                    continue
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                detections.append({
                    'bbox': [int(x1), int(y1), int(x2), int(y2)],
                    'confidence': float(box.conf),
                    'class': 'person',
                    # Plain Python float (not a numpy scalar) so results serialize cleanly.
                    'area': float((x2 - x1) * (y2 - y1)),
                })

        # Sort by confidence (highest first)
        detections.sort(key=lambda d: d['confidence'], reverse=True)
        return detections

    def convert_to_sam_prompts(self, detections: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
        """Convert YOLO detections to SAM2 box prompts.

        Args:
            detections: Detection dicts, each with a ``'bbox'`` of [x1, y1, x2, y2].

        Returns:
            ``(box_prompts, labels)``: an (N, 4) array of boxes and an (N,) array
            of labels, all 1 (positive prompts). For no detections the shapes are
            (0, 4) and (0,), so callers can rely on the box array being 2-D.
        """
        if not detections:
            return np.empty((0, 4), dtype=int), np.empty((0,), dtype=int)
        box_prompts = np.array([detection['bbox'] for detection in detections])
        labels = np.ones(len(detections), dtype=int)  # Every box is a positive prompt
        return box_prompts, labels

    def visualize_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
        """Draw detection boxes and confidence labels on a copy of the frame.

        Args:
            frame: Input frame (not modified).
            detections: Detection dicts with ``'bbox'`` and ``'confidence'``.

        Returns:
            A copy of ``frame`` with boxes and labels drawn.
        """
        vis_frame = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        box_color = (0, 255, 0)
        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']

            # Draw bounding box
            cv2.rectangle(vis_frame, (x1, y1), (x2, y2), box_color, 2)

            # Draw confidence score on a filled background
            label = f"Person: {detection['confidence']:.2f}"
            (text_w, text_h), _ = cv2.getTextSize(label, font, 0.5, 2)
            if y1 - text_h - 10 >= 0:
                # Label sits just above the box
                bg_top, bg_bottom, text_y = y1 - text_h - 10, y1, y1 - 5
            else:
                # Box touches the top edge: put the label just inside it so it stays visible
                bg_top, bg_bottom, text_y = y1, y1 + text_h + 10, y1 + text_h + 5
            cv2.rectangle(vis_frame, (x1, bg_top), (x1 + text_w, bg_bottom), box_color, -1)
            cv2.putText(vis_frame, label, (x1, text_y), font, 0.5, (0, 0, 0), 2)
        return vis_frame

    def get_largest_person(self, detections: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return the detection with the largest ``'area'``, or None if there are none."""
        if not detections:
            return None
        return max(detections, key=lambda d: d['area'])

    def filter_by_size(self, detections: List[Dict[str, Any]], min_area: int = 1000) -> List[Dict[str, Any]]:
        """Return the detections whose ``'area'`` is at least ``min_area``."""
        return [d for d in detections if d['area'] >= min_area]