Initial commit
core/__init__.py (new file, +2 lines)
@@ -0,0 +1,2 @@
# YOLO + SAM2 Video Processing Pipeline
# Core modules for video processing with human detection and segmentation
core/config_loader.py (new file, +158 lines)
@@ -0,0 +1,158 @@
"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""

import yaml
import os
from typing import Dict, Any, List, Union
import logging

logger = logging.getLogger(__name__)


class ConfigLoader:
    """Loads and validates configuration from YAML files."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        try:
            with open(self.config_path, 'r') as file:
                config = yaml.safe_load(file)
                logger.info(f"Loaded configuration from {self.config_path}")
                return config
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file: {e}")

    def _validate_config(self):
        """Validate required configuration fields."""
        required_sections = ['input', 'output', 'processing', 'models']

        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")

        # Validate input section
        if 'video_path' not in self.config['input']:
            raise ValueError("Missing required field: input.video_path")

        # Validate output section
        required_output_fields = ['directory', 'filename']
        for field in required_output_fields:
            if field not in self.config['output']:
                raise ValueError(f"Missing required field: output.{field}")

        # Validate models section
        required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config']
        for field in required_model_fields:
            if field not in self.config['models']:
                raise ValueError(f"Missing required field: models.{field}")

        # Validate processing.detect_segments format
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if not isinstance(detect_segments, (str, list)):
            raise ValueError("detect_segments must be 'all' or a list of integers")

        if isinstance(detect_segments, list):
            if not all(isinstance(x, int) for x in detect_segments):
                raise ValueError("detect_segments list must contain only integers")

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        keys = key_path.split('.')
        value = self.config

        try:
            for key in keys:
                value = value[key]
            return value
        except (KeyError, TypeError):
            return default

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path."""
        return self.config['models']['yolo_model']

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding."""
        return self.config.get('video', {}).get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio."""
        return self.config.get('video', {}).get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate."""
        return self.config.get('video', {}).get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color."""
        return self.config.get('advanced', {}).get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color."""
        return self.config.get('advanced', {}).get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID."""
        return self.config.get('advanced', {}).get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level."""
        return self.config.get('advanced', {}).get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files."""
        return self.config.get('advanced', {}).get('cleanup_intermediate_files', True)
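For reference, a minimal config.yaml that would pass the validation above might look like the sketch below; the file is not part of this commit, all paths are placeholders, and only the keys checked or read with defaults in ConfigLoader are shown.

# Hypothetical example config.yaml (placeholder paths, defaults mirror the getters above)
input:
  video_path: /path/to/input.mp4
output:
  directory: /path/to/output
  filename: result.mp4
processing:
  segment_duration: 5        # seconds per segment
  inference_scale: 0.5       # downscale factor for SAM2 inference
  yolo_confidence: 0.6
  detect_segments: all       # or a list of segment indices, e.g. [0, 3, 7]
models:
  yolo_model: /path/to/yolo_weights.pt
  sam2_checkpoint: /path/to/sam2_checkpoint.pt
  sam2_config: /path/to/sam2_config.yaml
video:
  use_nvenc: true
  preserve_audio: true
  output_bitrate: 50M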
core/sam2_processor.py (new file, +362 lines)
@@ -0,0 +1,362 @@
"""
SAM2 processor module for video segmentation.
Preserves the core SAM2 logic from the original implementation.
"""

import os
import cv2
import numpy as np
import torch
import logging
import gc
from typing import Dict, List, Any, Optional, Tuple
from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)


class SAM2Processor:
    """Handles SAM2-based video segmentation for human tracking."""

    def __init__(self, checkpoint_path: str, config_path: str):
        """
        Initialize SAM2 processor.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_path: Path to SAM2 config file
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.predictor = None
        self._initialize_predictor()

    def _initialize_predictor(self):
        """Initialize SAM2 video predictor with proper device setup."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
            logger.warning(
                "Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS."
            )
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        else:
            device = torch.device("cpu")

        logger.info(f"Using device: {device}")

        try:
            self.predictor = build_sam2_video_predictor(
                self.config_path,
                self.checkpoint_path,
                device=device
            )

            # Enable optimizations for CUDA
            if device.type == "cuda":
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
                if torch.cuda.get_device_properties(0).major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True

            logger.info("SAM2 predictor initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize SAM2 predictor: {e}")
            raise

    def create_low_res_video(self, input_video_path: str, output_video_path: str, scale: float):
        """
        Create a low-resolution version of the input video for inference.

        Args:
            input_video_path: Path to input video
            output_video_path: Path to output low-res video
            scale: Scale factor for resolution reduction
        """
        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
            out.write(low_res_frame)
            frame_count += 1

        cap.release()
        out.release()

        logger.info(f"Created low-res video with {frame_count} frames: {output_video_path}")

    def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
        """
        Add YOLO detection prompts to SAM2 predictor.

        Args:
            inference_state: SAM2 inference state
            prompts: List of prompt dictionaries with obj_id and bbox

        Returns:
            True if prompts were added successfully
        """
        if not prompts:
            logger.warning("No prompts provided to SAM2")
            return False

        try:
            for prompt in prompts:
                obj_id = prompt['obj_id']
                bbox = prompt['bbox']

                _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=obj_id,
                    box=bbox.astype(np.float32),
                )

                logger.debug(f"Added prompt for Object {obj_id}: {bbox}")

            logger.info(f"Successfully added {len(prompts)} prompts to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding prompts to SAM2: {e}")
            return False

    def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
        """
        Load masks from previous segment for continuity.

        Args:
            prev_segment_dir: Directory of previous segment

        Returns:
            Dictionary mapping object IDs to masks, or None if failed
        """
        mask_path = os.path.join(prev_segment_dir, "mask.png")

        if not os.path.exists(mask_path):
            logger.warning(f"Previous mask not found: {mask_path}")
            return None

        try:
            mask_image = cv2.imread(mask_path)
            if mask_image is None:
                logger.error(f"Could not read mask image: {mask_path}")
                return None

            if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
                logger.error("Mask image does not have three color channels")
                return None

            mask_image = mask_image.astype(np.uint8)

            # Extract Object A and Object B masks (preserving original logic)
            GREEN = [0, 255, 0]
            BLUE = [255, 0, 0]

            mask_a = np.all(mask_image == GREEN, axis=2)
            mask_b = np.all(mask_image == BLUE, axis=2)

            per_obj_input_mask = {}
            if np.any(mask_a):
                per_obj_input_mask[1] = mask_a
            if np.any(mask_b):
                per_obj_input_mask[2] = mask_b

            logger.info(f"Loaded masks for {len(per_obj_input_mask)} objects from {prev_segment_dir}")
            return per_obj_input_mask

        except Exception as e:
            logger.error(f"Error loading previous mask: {e}")
            return None

    def add_previous_masks_to_predictor(self, inference_state, masks: Dict[int, np.ndarray]) -> bool:
        """
        Add previous segment masks to predictor for continuity.

        Args:
            inference_state: SAM2 inference state
            masks: Dictionary mapping object IDs to masks

        Returns:
            True if masks were added successfully
        """
        try:
            for obj_id, mask in masks.items():
                self.predictor.add_new_mask(inference_state, 0, obj_id, mask)
                logger.debug(f"Added previous mask for Object {obj_id}")

            logger.info(f"Successfully added {len(masks)} previous masks to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding previous masks to SAM2: {e}")
            return False

    def propagate_masks(self, inference_state) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks across all frames in the video.

        Args:
            inference_state: SAM2 inference state

        Returns:
            Dictionary mapping frame indices to object masks
        """
        video_segments = {}

        try:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

            logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")

        except Exception as e:
            logger.error(f"Error during mask propagation: {e}")

        return video_segments

    def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                               previous_masks: Optional[Dict[int, np.ndarray]] = None,
                               inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a single video segment with SAM2.

        Args:
            segment_info: Segment information dictionary
            yolo_prompts: Optional YOLO detection prompts
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference

        Returns:
            Video segments dictionary or None if failed
        """
        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        # Check if segment is already processed (resume capability)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment_idx} already processed. Skipping.")
            return None  # Indicate skip, not failure

        logger.info(f"Processing segment {segment_idx} with SAM2")

        # Create low-resolution video for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            try:
                self.create_low_res_video(video_path, low_res_video_path, inference_scale)
            except Exception as e:
                logger.error(f"Failed to create low-res video for segment {segment_idx}: {e}")
                return None

        try:
            # Initialize inference state
            inference_state = self.predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

            # Add prompts or previous masks
            if yolo_prompts:
                if not self.add_yolo_prompts_to_predictor(inference_state, yolo_prompts):
                    return None
            elif previous_masks:
                if not self.add_previous_masks_to_predictor(inference_state, previous_masks):
                    return None
            else:
                logger.error(f"No prompts or previous masks available for segment {segment_idx}")
                return None

            # Propagate masks
            video_segments = self.propagate_masks(inference_state)

            # Clean up
            self.predictor.reset_state(inference_state)
            del inference_state
            gc.collect()

            # Remove low-res video to save space
            try:
                os.remove(low_res_video_path)
                logger.debug(f"Removed low-res video: {low_res_video_path}")
            except Exception as e:
                logger.warning(f"Could not remove low-res video: {e}")

            # Mark segment as completed (for resume capability)
            try:
                with open(output_done_file, 'w') as f:
                    f.write(f"Segment {segment_idx} completed successfully\n")
                logger.debug(f"Marked segment {segment_idx} as completed")
            except Exception as e:
                logger.warning(f"Could not create completion marker: {e}")

            return video_segments

        except Exception as e:
            logger.error(f"Error processing segment {segment_idx}: {e}")
            return None

    def save_final_masks(self, video_segments: Dict[int, Dict[int, np.ndarray]], output_path: str,
                         green_color: List[int] = [0, 255, 0], blue_color: List[int] = [255, 0, 0]):
        """
        Save the final masks as a colored image for continuity.

        Args:
            video_segments: Video segments dictionary
            output_path: Path to save the mask image
            green_color: RGB color for object 1
            blue_color: RGB color for object 2
        """
        if not video_segments:
            logger.error("No video segments to save")
            return

        try:
            last_frame_idx = max(video_segments.keys())
            masks_dict = video_segments[last_frame_idx]

            # Get masks for objects 1 and 2
            mask_a = masks_dict.get(1)
            mask_b = masks_dict.get(2)

            if mask_a is None and mask_b is None:
                logger.error("No masks found for objects")
                return

            # Use the first available mask to determine dimensions
            reference_mask = mask_a if mask_a is not None else mask_b
            reference_mask = reference_mask.squeeze()

            black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

            if mask_a is not None:
                mask_a = mask_a.squeeze().astype(bool)
                black_frame[mask_a] = green_color

            if mask_b is not None:
                mask_b = mask_b.squeeze().astype(bool)
                black_frame[mask_b] = blue_color

            # Save the mask image
            cv2.imwrite(output_path, black_frame)
            logger.info(f"Saved final masks to {output_path}")

        except Exception as e:
            logger.error(f"Error saving final masks: {e}")
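Note on the propagation result: propagate_masks returns a dict mapping frame index to per-object boolean masks. A hypothetical consumer (not part of this commit) could turn one frame of that structure into the green/blue matte convention that save_final_masks and load_previous_segment_mask share; the output filename below is a placeholder.

# Illustrative sketch only: video_segments[frame_idx][obj_id] is a boolean mask array.
frame_masks = video_segments[0]
h, w = next(iter(frame_masks.values())).squeeze().shape
matte = np.zeros((h, w, 3), dtype=np.uint8)
if 1 in frame_masks:
    matte[frame_masks[1].squeeze().astype(bool)] = [0, 255, 0]   # object 1 -> green
if 2 in frame_masks:
    matte[frame_masks[2].squeeze().astype(bool)] = [255, 0, 0]   # object 2 -> blue (OpenCV BGR order)
cv2.imwrite("frame_0000_matte.png", matte)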
core/video_splitter.py (new file, +174 lines)
@@ -0,0 +1,174 @@
"""
Video splitter module for the YOLO + SAM2 processing pipeline.
Handles splitting long videos into manageable segments.
"""

import os
import subprocess
import logging
from typing import List, Tuple
from ..utils.file_utils import ensure_directory, get_video_file_name

logger = logging.getLogger(__name__)


class VideoSplitter:
    """Handles splitting videos into segments for processing."""

    def __init__(self, segment_duration: int = 5, force_keyframes: bool = True):
        """
        Initialize video splitter.

        Args:
            segment_duration: Duration of each segment in seconds
            force_keyframes: Whether to force keyframes for clean cuts
        """
        self.segment_duration = segment_duration
        self.force_keyframes = force_keyframes

    def split_video(self, input_video: str, output_dir: str) -> Tuple[str, List[str]]:
        """
        Split video into segments and organize into directory structure.

        Args:
            input_video: Path to input video file
            output_dir: Base output directory

        Returns:
            Tuple of (segments_directory, list_of_segment_directories)
        """
        if not os.path.exists(input_video):
            raise FileNotFoundError(f"Input video not found: {input_video}")

        # Create output directory structure
        video_name = os.path.splitext(os.path.basename(input_video))[0]
        segments_dir = os.path.join(output_dir, f"{video_name}_segments")
        ensure_directory(segments_dir)

        logger.info(f"Splitting video {input_video} into {self.segment_duration}s segments")

        # Split video using ffmpeg
        segment_pattern = os.path.join(segments_dir, "segment_%03d.mp4")

        # Build ffmpeg command
        cmd = [
            'ffmpeg', '-i', input_video,
            '-f', 'segment',
            '-segment_time', str(self.segment_duration),
            '-reset_timestamps', '1',
            '-c', 'copy'
        ]

        # Add keyframe forcing if enabled
        if self.force_keyframes:
            cmd.extend(['-force_key_frames', f'expr:gte(t,n_forced*{self.segment_duration})'])

        # Add copyts for timestamp preservation
        cmd.extend(['-copyts', segment_pattern])

        try:
            result = subprocess.run(
                cmd,
                check=True,
                capture_output=True,
                text=True
            )
            logger.debug(f"FFmpeg output: {result.stderr}")
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg failed: {e.stderr}")
            raise RuntimeError(f"Video splitting failed: {e}")

        # Organize segments into individual directories
        segment_dirs = self._organize_segments(segments_dir)

        # Create file list for later concatenation
        self._create_file_list(segments_dir, segment_dirs)

        logger.info(f"Successfully split video into {len(segment_dirs)} segments")
        return segments_dir, segment_dirs

    def _organize_segments(self, segments_dir: str) -> List[str]:
        """
        Move each segment into its own subdirectory.

        Args:
            segments_dir: Directory containing split segments

        Returns:
            List of created segment directory names
        """
        segment_files = []
        segment_dirs = []

        # Find all segment files
        for file in os.listdir(segments_dir):
            if file.startswith("segment_") and file.endswith(".mp4"):
                segment_files.append(file)

        # Sort segment files numerically
        segment_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

        # Move each segment to its own directory
        for i, segment_file in enumerate(segment_files):
            segment_dir_name = f"segment_{i}"
            segment_dir_path = os.path.join(segments_dir, segment_dir_name)
            ensure_directory(segment_dir_path)

            # Move segment file to subdirectory with standardized name
            old_path = os.path.join(segments_dir, segment_file)
            new_path = os.path.join(segment_dir_path, get_video_file_name(i))

            os.rename(old_path, new_path)
            segment_dirs.append(segment_dir_name)

            logger.debug(f"Organized segment {i}: {new_path}")

        return segment_dirs

    def _create_file_list(self, segments_dir: str, segment_dirs: List[str]):
        """
        Create a file list for future concatenation.

        Args:
            segments_dir: Base segments directory
            segment_dirs: List of segment directory names
        """
        file_list_path = os.path.join(segments_dir, "file_list.txt")

        with open(file_list_path, 'w') as f:
            for i, segment_dir in enumerate(segment_dirs):
                segment_path = os.path.join(segment_dir, get_video_file_name(i))
                f.write(f"file '{segment_path}'\n")

        logger.debug(f"Created file list: {file_list_path}")

    def get_segment_info(self, segments_dir: str) -> List[dict]:
        """
        Get information about all segments in a directory.

        Args:
            segments_dir: Directory containing segments

        Returns:
            List of segment information dictionaries
        """
        segment_info = []

        for item in os.listdir(segments_dir):
            item_path = os.path.join(segments_dir, item)

            if os.path.isdir(item_path) and item.startswith("segment_"):
                segment_index = int(item.split("_")[1])
                video_file = os.path.join(item_path, get_video_file_name(segment_index))

                info = {
                    'index': segment_index,
                    'directory': item_path,
                    'video_file': video_file,
                    'exists': os.path.exists(video_file)
                }
                segment_info.append(info)

        # Sort by index
        segment_info.sort(key=lambda x: x['index'])

        return segment_info
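As a rough illustration (not part of this commit), the splitter's file_list.txt is written in the format expected by ffmpeg's concat demuxer, so re-joining processed segments is a one-liner; the input/output paths below are placeholders.

# Illustrative sketch only; assumes a local input.mp4 and an ffmpeg binary on PATH.
splitter = VideoSplitter(segment_duration=5, force_keyframes=True)
segments_dir, segment_dirs = splitter.split_video("input.mp4", "work")

for info in splitter.get_segment_info(segments_dir):
    print(info['index'], info['video_file'], info['exists'])

# Later, the entries in file_list.txt can be re-joined with the concat demuxer, e.g.:
#   ffmpeg -f concat -safe 0 -i work/input_segments/file_list.txt -c copy output.mp4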
core/yolo_detector.py (new file, +286 lines)
@@ -0,0 +1,286 @@
"""
YOLO detector module for human detection in video segments.
Preserves the core detection logic from the original implementation.
"""

import os
import cv2
import numpy as np
import logging
from typing import List, Dict, Any, Optional
from ultralytics import YOLO

logger = logging.getLogger(__name__)


class YOLODetector:
    """Handles YOLO-based human detection for video segments."""

    def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
        """
        Initialize YOLO detector.

        Args:
            model_path: Path to YOLO model weights
            confidence_threshold: Detection confidence threshold
            human_class_id: COCO class ID for humans (0 = person)
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.human_class_id = human_class_id

        # Load YOLO model
        try:
            self.model = YOLO(model_path)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            raise

    def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect humans in a single frame using YOLO.

        Args:
            frame: Input frame (BGR format from OpenCV)

        Returns:
            List of human detection dictionaries with bbox and confidence
        """
        # Run YOLO detection
        results = self.model(frame, conf=self.confidence_threshold, verbose=False)

        human_detections = []

        # Process results
        for result in results:
            boxes = result.boxes
            if boxes is not None:
                for box in boxes:
                    # Get class ID
                    cls = int(box.cls.cpu().numpy()[0])

                    # Check if it's a person (human_class_id)
                    if cls == self.human_class_id:
                        # Get bounding box coordinates (x1, y1, x2, y2)
                        coords = box.xyxy[0].cpu().numpy()
                        conf = float(box.conf.cpu().numpy()[0])

                        human_detections.append({
                            'bbox': coords,
                            'confidence': conf
                        })

                        logger.debug(f"Detected human with confidence {conf:.2f} at {coords}")

        return human_detections

    def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
        """
        Detect humans in the first frame of a video.

        Args:
            video_path: Path to video file
            scale: Scale factor for frame processing

        Returns:
            List of human detection dictionaries
        """
        if not os.path.exists(video_path):
            logger.error(f"Video file not found: {video_path}")
            return []

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Could not open video: {video_path}")
            return []

        ret, frame = cap.read()
        cap.release()

        if not ret:
            logger.error(f"Could not read first frame from: {video_path}")
            return []

        # Scale frame if needed
        if scale != 1.0:
            frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        return self.detect_humans_in_frame(frame)

    def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
        """
        Save detection results to file.

        Args:
            detections: List of detection dictionaries
            output_path: Path to save detections

        Returns:
            True if saved successfully
        """
        try:
            with open(output_path, 'w') as f:
                f.write("# YOLO Human Detections\n")
                if detections:
                    for detection in detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                    logger.info(f"Saved {len(detections)} detections to {output_path}")
                else:
                    f.write("# No humans detected\n")
                    logger.info(f"Saved empty detection file to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save detections to {output_path}: {e}")
            return False

    def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load detection results from file.

        Args:
            file_path: Path to detection file

        Returns:
            List of detection dictionaries
        """
        detections = []

        if not os.path.exists(file_path):
            logger.warning(f"Detection file not found: {file_path}")
            return detections

        try:
            with open(file_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    # Skip comments and empty lines
                    if line.startswith('#') or not line:
                        continue

                    # Parse detection line: x1,y1,x2,y2,confidence
                    parts = line.split(',')
                    if len(parts) == 5:
                        try:
                            bbox = [float(x) for x in parts[:4]]
                            conf = float(parts[4])
                            detections.append({
                                'bbox': np.array(bbox),
                                'confidence': conf
                            })
                        except ValueError:
                            logger.warning(f"Invalid detection line: {line}")
                            continue

            logger.info(f"Loaded {len(detections)} detections from {file_path}")
        except Exception as e:
            logger.error(f"Failed to load detections from {file_path}: {e}")

        return detections

    def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
                               scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
        """
        Process multiple segments for human detection.

        Args:
            segments_info: List of segment information dictionaries
            detect_segments: List of segment indices to process, or 'all'
            scale: Scale factor for processing

        Returns:
            Dictionary mapping segment index to detection results
        """
        results = {}

        for segment_info in segments_info:
            segment_idx = segment_info['index']

            # Skip if not in detect_segments list
            if detect_segments != 'all' and segment_idx not in detect_segments:
                continue

            video_path = segment_info['video_file']
            detection_file = os.path.join(segment_info['directory'], "yolo_detections")

            # Skip if already processed
            if os.path.exists(detection_file):
                logger.info(f"Segment {segment_idx} already has detections, skipping")
                detections = self.load_detections_from_file(detection_file)
                results[segment_idx] = detections
                continue

            # Run detection
            logger.info(f"Processing segment {segment_idx} for human detection")
            detections = self.detect_humans_in_video_first_frame(video_path, scale)

            # Save results
            self.save_detections_to_file(detections, detection_file)
            results[segment_idx] = detections

        return results

    def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
                                           frame_width: int) -> List[Dict[str, Any]]:
        """
        Convert YOLO detections to SAM2-compatible prompts for stereo video.

        Args:
            detections: List of YOLO detection results
            frame_width: Width of the video frame

        Returns:
            List of SAM2 prompt dictionaries with obj_id and bbox
        """
        if not detections:
            return []

        half_frame_width = frame_width // 2
        prompts = []

        # Sort detections by x-coordinate to get consistent left/right assignment
        sorted_detections = sorted(detections, key=lambda x: x['bbox'][0])

        obj_id = 1

        for i, detection in enumerate(sorted_detections[:2]):  # Take up to 2 humans
            bbox = detection['bbox'].copy()

            # For stereo videos, assign obj_id based on position
            if len(sorted_detections) >= 2:
                center_x = (bbox[0] + bbox[2]) / 2
                if center_x < half_frame_width:
                    current_obj_id = 1  # Left human
                else:
                    current_obj_id = 2  # Right human
            else:
                # If only one human, create prompts for both sides
                current_obj_id = obj_id
                obj_id += 1

                # Create mirrored version for stereo
                if obj_id <= 2:
                    mirrored_bbox = bbox.copy()
                    mirrored_bbox[0] += half_frame_width  # Shift x1
                    mirrored_bbox[2] += half_frame_width  # Shift x2

                    # Ensure mirrored bbox is within frame bounds
                    mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                    mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))

                    prompts.append({
                        'obj_id': obj_id,
                        'bbox': mirrored_bbox,
                        'confidence': detection['confidence']
                    })
                    obj_id += 1

            prompts.append({
                'obj_id': current_obj_id,
                'bbox': bbox,
                'confidence': detection['confidence']
            })

        logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts")
        return prompts
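A hypothetical end-to-end wiring of the three modules might look like the sketch below; it is illustrative only, and the input path, model paths, frame width, and mask filename are placeholder assumptions rather than part of this commit.

# Illustrative sketch only: splitter -> detector -> SAM2, segment by segment.
splitter = VideoSplitter(segment_duration=5)
segments_dir, _ = splitter.split_video("input.mp4", "work")
segments = splitter.get_segment_info(segments_dir)

detector = YOLODetector("yolo_weights.pt", confidence_threshold=0.6)
detections_by_segment = detector.process_segments_batch(segments, detect_segments='all', scale=0.5)

processor = SAM2Processor("sam2_checkpoint.pt", "sam2_config.yaml")
previous_masks = None
for info in segments:
    dets = detections_by_segment.get(info['index'], [])
    prompts = detector.convert_detections_to_sam2_prompts(dets, frame_width=1920) if dets else None
    video_segments = processor.process_single_segment(info, yolo_prompts=prompts,
                                                      previous_masks=previous_masks,
                                                      inference_scale=0.5)
    if video_segments:
        mask_path = os.path.join(info['directory'], "mask.png")
        processor.save_final_masks(video_segments, mask_path)
        previous_masks = processor.load_previous_segment_mask(info['directory'])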