"""
|
|
YOLO detector module for human detection in video segments.
|
|
Preserves the core detection logic from the original implementation.
|
|
"""

import logging
import os
from typing import Any, Dict, List, Union

import cv2
import numpy as np
from ultralytics import YOLO

logger = logging.getLogger(__name__)


class YOLODetector:
    """Handles YOLO-based human detection for video segments."""

    def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
        """
        Initialize YOLO detector.

        Args:
            model_path: Path to YOLO model weights
            confidence_threshold: Detection confidence threshold
            human_class_id: COCO class ID for humans (0 = person)
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.human_class_id = human_class_id

        # Load YOLO model
        try:
            self.model = YOLO(model_path)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            raise

    def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect humans in a single frame using YOLO.

        Args:
            frame: Input frame (BGR format from OpenCV)

        Returns:
            List of human detection dictionaries with bbox and confidence
        """
        # Run YOLO detection
        results = self.model(frame, conf=self.confidence_threshold, verbose=False)

        human_detections = []

        # Process results
        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                # Get class ID
                cls = int(box.cls.cpu().numpy()[0])

                # Keep only person detections (human_class_id)
                if cls == self.human_class_id:
                    # Get bounding box coordinates (x1, y1, x2, y2)
                    coords = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf.cpu().numpy()[0])

                    human_detections.append({
                        'bbox': coords,
                        'confidence': conf
                    })

                    logger.debug(f"Detected human with confidence {conf:.2f} at {coords}")

        return human_detections
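
    # A returned detection has the following shape (values illustrative,
    # not from a real run); 'bbox' is an (x1, y1, x2, y2) array in pixel
    # coordinates of the input frame:
    #   {'bbox': np.array([412.3, 188.0, 640.5, 702.9]), 'confidence': 0.87}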

    def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
        """
        Detect humans in the first frame of a video.

        Args:
            video_path: Path to video file
            scale: Scale factor for frame processing

        Returns:
            List of human detection dictionaries; bounding boxes are in the
            coordinates of the (possibly scaled) frame
        """
        if not os.path.exists(video_path):
            logger.error(f"Video file not found: {video_path}")
            return []

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Could not open video: {video_path}")
            return []

        ret, frame = cap.read()
        cap.release()

        if not ret:
            logger.error(f"Could not read first frame from: {video_path}")
            return []

        # Scale frame if needed
        if scale != 1.0:
            frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        return self.detect_humans_in_frame(frame)

    def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
        """
        Save detection results to file, one 'x1,y1,x2,y2,confidence' line per detection.

        Args:
            detections: List of detection dictionaries
            output_path: Path to save detections

        Returns:
            True if saved successfully
        """
        try:
            with open(output_path, 'w') as f:
                f.write("# YOLO Human Detections\n")
                if detections:
                    for detection in detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                    logger.info(f"Saved {len(detections)} detections to {output_path}")
                else:
                    f.write("# No humans detected\n")
                    logger.info(f"Saved empty detection file to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save detections to {output_path}: {e}")
            return False
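
    # A file written by save_detections_to_file looks like this, one
    # 'x1,y1,x2,y2,confidence' line per detection (values illustrative):
    #   # YOLO Human Detections
    #   412.3,188.0,640.5,702.9,0.87
    #   1371.8,190.2,1598.4,705.1,0.84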

    def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load detection results from file.

        Args:
            file_path: Path to detection file

        Returns:
            List of detection dictionaries
        """
        detections = []

        if not os.path.exists(file_path):
            logger.warning(f"Detection file not found: {file_path}")
            return detections

        try:
            with open(file_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    # Skip comments and empty lines
                    if line.startswith('#') or not line:
                        continue

                    # Parse detection line: x1,y1,x2,y2,confidence
                    parts = line.split(',')
                    if len(parts) == 5:
                        try:
                            bbox = [float(x) for x in parts[:4]]
                            conf = float(parts[4])
                            detections.append({
                                'bbox': np.array(bbox),
                                'confidence': conf
                            })
                        except ValueError:
                            logger.warning(f"Invalid detection line: {line}")
                            continue

            logger.info(f"Loaded {len(detections)} detections from {file_path}")
        except Exception as e:
            logger.error(f"Failed to load detections from {file_path}: {e}")

        return detections

    def process_segments_batch(self, segments_info: List[dict], detect_segments: Union[List[int], str],
                               scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
        """
        Process multiple segments for human detection.

        Args:
            segments_info: List of segment information dictionaries
            detect_segments: List of segment indices to process, or the
                string 'all' to process every segment
            scale: Scale factor for processing

        Returns:
            Dictionary mapping segment index to detection results
        """
        results = {}

        for segment_info in segments_info:
            segment_idx = segment_info['index']

            # Skip if not in detect_segments list
            if detect_segments != 'all' and segment_idx not in detect_segments:
                continue

            video_path = segment_info['video_file']
            detection_file = os.path.join(segment_info['directory'], "yolo_detections")

            # Skip if already processed
            if os.path.exists(detection_file):
                logger.info(f"Segment {segment_idx} already has detections, skipping")
                detections = self.load_detections_from_file(detection_file)
                results[segment_idx] = detections
                continue

            # Run detection
            logger.info(f"Processing segment {segment_idx} for human detection")
            detections = self.detect_humans_in_video_first_frame(video_path, scale)

            # Save results
            self.save_detections_to_file(detections, detection_file)
            results[segment_idx] = detections

        return results
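
    # Each segments_info entry is expected to provide at least the keys used
    # above; the values here are illustrative:
    #   {'index': 3,
    #    'video_file': '/data/segments/003/segment.mp4',
    #    'directory': '/data/segments/003'}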

    def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
                                           frame_width: int) -> List[Dict[str, Any]]:
        """
        Convert YOLO detections to SAM2-compatible prompts for stereo video.

        Args:
            detections: List of YOLO detection results
            frame_width: Width of the video frame

        Returns:
            List of SAM2 prompt dictionaries with obj_id and bbox
        """
        if not detections:
            return []

        half_frame_width = frame_width // 2
        prompts = []

        # Sort detections by x-coordinate to get consistent left/right assignment
        sorted_detections = sorted(detections, key=lambda d: d['bbox'][0])

        obj_id = 1

        for detection in sorted_detections[:2]:  # Take up to 2 humans
            bbox = detection['bbox'].copy()

            if len(sorted_detections) >= 2:
                # Two humans: assign obj_id by which half of the side-by-side
                # frame the box center falls in
                center_x = (bbox[0] + bbox[2]) / 2
                if center_x < half_frame_width:
                    current_obj_id = 1  # Left human
                else:
                    current_obj_id = 2  # Right human
            else:
                # Only one human: create prompts for both sides
                current_obj_id = obj_id
                obj_id += 1

                # Create a mirrored version shifted into the other stereo view
                if obj_id <= 2:
                    mirrored_bbox = bbox.copy()
                    mirrored_bbox[0] += half_frame_width  # Shift x1
                    mirrored_bbox[2] += half_frame_width  # Shift x2

                    # Ensure mirrored bbox is within frame bounds
                    mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                    mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))

                    prompts.append({
                        'obj_id': obj_id,
                        'bbox': mirrored_bbox,
                        'confidence': detection['confidence']
                    })
                    obj_id += 1

            prompts.append({
                'obj_id': current_obj_id,
                'bbox': bbox,
                'confidence': detection['confidence']
            })

        logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts")
        return prompts
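

# Minimal usage sketch, not part of the original module: the weights file,
# video path, and frame width below are assumptions for illustration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    detector = YOLODetector("yolov8n.pt", confidence_threshold=0.6)  # assumed weights

    # Detect humans in the first frame of an (assumed) side-by-side stereo segment
    detections = detector.detect_humans_in_video_first_frame("segment.mp4")

    # Round-trip the results through the module's simple text format
    detector.save_detections_to_file(detections, "yolo_detections")
    reloaded = detector.load_detections_from_file("yolo_detections")

    # Convert to SAM2 prompts; frame_width must match the frame the detections
    # were computed on (1920 is an assumed side-by-side width)
    prompts = detector.convert_detections_to_sam2_prompts(reloaded, frame_width=1920)
    for prompt in prompts:
        print(prompt['obj_id'], prompt['bbox'], prompt['confidence'])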