first commit
This commit is contained in:
126
vr180_matting/detector.py
Normal file
126
vr180_matting/detector.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import cv2
|
||||
|
||||
|
||||
class YOLODetector:
|
||||
"""YOLOv8-based person detector for automatic SAM2 prompting"""
|
||||
|
||||
def __init__(self, model_name: str = "yolov8n", confidence_threshold: float = 0.7, device: str = "cuda"):
|
||||
self.model_name = model_name
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.device = device
|
||||
self.model = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load YOLOv8 model"""
|
||||
try:
|
||||
self.model = YOLO(f"{self.model_name}.pt")
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.model.to("cuda")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load YOLO model {self.model_name}: {e}")
|
||||
|
||||
def detect_persons(self, frame: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect persons in frame and return bounding boxes
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, 3)
|
||||
|
||||
Returns:
|
||||
List of detection dictionaries with bbox, confidence, and class info
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("YOLO model not loaded")
|
||||
|
||||
results = self.model(frame, verbose=False)
|
||||
detections = []
|
||||
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
if boxes is not None:
|
||||
for box in boxes:
|
||||
# Only keep person detections (class 0 in COCO)
|
||||
if int(box.cls) == 0 and float(box.conf) >= self.confidence_threshold:
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
confidence = float(box.conf)
|
||||
|
||||
detection = {
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'confidence': confidence,
|
||||
'class': 'person',
|
||||
'area': (x2 - x1) * (y2 - y1)
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
# Sort by confidence (highest first)
|
||||
detections.sort(key=lambda x: x['confidence'], reverse=True)
|
||||
return detections
|
||||
|
||||
def convert_to_sam_prompts(self, detections: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Convert YOLO detections to SAM2 box prompts
|
||||
|
||||
Args:
|
||||
detections: List of detection dictionaries
|
||||
|
||||
Returns:
|
||||
Tuple of (box_prompts, labels) for SAM2
|
||||
"""
|
||||
if not detections:
|
||||
return np.array([]), np.array([])
|
||||
|
||||
box_prompts = []
|
||||
labels = []
|
||||
|
||||
for detection in detections:
|
||||
bbox = detection['bbox']
|
||||
box_prompts.append(bbox)
|
||||
labels.append(1) # Positive prompt
|
||||
|
||||
return np.array(box_prompts), np.array(labels)
|
||||
|
||||
def visualize_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]]) -> np.ndarray:
|
||||
"""
|
||||
Draw detection boxes on frame for debugging
|
||||
|
||||
Args:
|
||||
frame: Input frame
|
||||
detections: List of detections
|
||||
|
||||
Returns:
|
||||
Frame with drawn bounding boxes
|
||||
"""
|
||||
vis_frame = frame.copy()
|
||||
|
||||
for detection in detections:
|
||||
x1, y1, x2, y2 = detection['bbox']
|
||||
confidence = detection['confidence']
|
||||
|
||||
# Draw bounding box
|
||||
cv2.rectangle(vis_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Draw confidence score
|
||||
label = f"Person: {confidence:.2f}"
|
||||
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||
cv2.rectangle(vis_frame, (x1, y1 - label_size[1] - 10),
|
||||
(x1 + label_size[0], y1), (0, 255, 0), -1)
|
||||
cv2.putText(vis_frame, label, (x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
|
||||
|
||||
return vis_frame
|
||||
|
||||
def get_largest_person(self, detections: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Get the largest detected person (by bounding box area)"""
|
||||
if not detections:
|
||||
return None
|
||||
|
||||
return max(detections, key=lambda x: x['area'])
|
||||
|
||||
def filter_by_size(self, detections: List[Dict[str, Any]], min_area: int = 1000) -> List[Dict[str, Any]]:
|
||||
"""Filter detections by minimum bounding box area"""
|
||||
return [d for d in detections if d['area'] >= min_area]
|
||||
Reference in New Issue
Block a user