195 lines
7.1 KiB
Python
195 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Main entry point for YOLO + SAM2 video processing pipeline.
|
|
Processes long videos by splitting into segments, detecting humans with YOLO,
|
|
and creating green screen masks with SAM2.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
from typing import List
|
|
|
|
# Add project root to path
|
|
sys.path.append(os.path.dirname(__file__))
|
|
|
|
from core.config_loader import ConfigLoader
|
|
from core.video_splitter import VideoSplitter
|
|
from core.yolo_detector import YOLODetector
|
|
from utils.logging_utils import setup_logging, get_logger
|
|
from utils.file_utils import ensure_directory
|
|
from utils.status_utils import print_processing_status, cleanup_incomplete_segment
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
def parse_arguments():
    """Build the pipeline's command-line interface and parse ``sys.argv``.

    Options:
        --config           Path to the YAML configuration file (required).
        --log-file         Optional path of a log file.
        --status           Flag; show processing status and exit.
        --cleanup-segment  Integer index of a segment to clean up for restart.

    Returns:
        The ``argparse.Namespace`` holding ``config``, ``log_file``,
        ``status`` and ``cleanup_segment``.
    """
    cli = argparse.ArgumentParser(
        description="YOLO + SAM2 Video Processing Pipeline"
    )

    # (flag, add_argument keyword settings), registered in display order.
    option_table = (
        ("--config", {
            "type": str,
            "required": True,
            "help": "Path to YAML configuration file",
        }),
        ("--log-file", {
            "type": str,
            "help": "Optional log file path",
        }),
        ("--status", {
            "action": "store_true",
            "help": "Show processing status and exit",
        }),
        ("--cleanup-segment", {
            "type": int,
            "help": "Clean up a specific segment for restart (segment index)",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
|
|
|
|
def validate_dependencies():
    """Check that the third-party packages used by the pipeline import cleanly.

    The imports are attempted only to prove availability; the imported
    names are not used afterwards.

    Returns:
        True when every import succeeds. False when any of them raises
        ImportError, after logging the error and an installation hint.
    """
    try:
        import torch  # noqa: F401
        import cv2  # noqa: F401
        import numpy  # noqa: F401
        import cupy  # noqa: F401
        from ultralytics import YOLO  # noqa: F401
        from sam2.build_sam import build_sam2_video_predictor  # noqa: F401
    except ImportError as missing:
        logger.error(f"Missing dependency: {missing}")
        logger.error("Please install requirements: pip install -r requirements.txt")
        return False
    else:
        logger.info("All dependencies validated successfully")
        return True
|
|
|
|
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
    """
    Resolve detect_segments configuration to list of segment indices.

    Args:
        detect_segments: Configuration value. ``"all"`` or ``None`` selects
            every segment. A list selects the listed indices. Any other
            value is reported with a warning and treated like ``"all"``.
        total_segments: Total number of segments

    Returns:
        List of segment indices to process. For a list input, the result
        keeps the given order and contains only entries that are integers
        within ``range(total_segments)``; everything else is dropped and a
        warning is logged.
    """
    if detect_segments is None or detect_segments == "all":
        return list(range(total_segments))

    if not isinstance(detect_segments, list):
        logger.warning(f"Invalid detect_segments format: {detect_segments}. Using all segments.")
        return list(range(total_segments))

    def _is_valid_index(entry) -> bool:
        # Type is checked before the range comparison so that strings or
        # None in the configured list are filtered out instead of raising
        # TypeError. bool is a subclass of int, so it is rejected
        # explicitly: True/False are not meaningful segment indices.
        if isinstance(entry, bool) or not isinstance(entry, int):
            return False
        return 0 <= entry < total_segments

    valid_segments = [s for s in detect_segments if _is_valid_index(s)]
    if len(valid_segments) != len(detect_segments):
        logger.warning(f"Some segment indices are invalid. Using: {valid_segments}")
    return valid_segments
|
|
|
|
def main():
    """Run the YOLO + SAM2 pipeline and return a process exit code.

    After loading the configuration and setting up logging, the function
    handles the two maintenance modes (``--status`` and
    ``--cleanup-segment``) and returns early for either. Otherwise it
    validates dependencies and the input video, splits the video into
    segments, and runs YOLO human detection on the configured segments.
    The SAM2 and final-assembly steps are placeholders that only log.

    Returns:
        0 on success, 1 on any failure. Any exception raised after
        argument parsing is logged with its traceback and reported as 1.
    """
    args = parse_arguments()

    def segments_root(cfg):
        # <output_dir>/<input video basename without extension>_segments
        out_dir = cfg.get_output_directory()
        video_path = cfg.get_input_video_path()
        stem = os.path.splitext(os.path.basename(video_path))[0]
        return os.path.join(out_dir, f"{stem}_segments")

    try:
        config = ConfigLoader(args.config)
        setup_logging(config.get_log_level(), args.log_file)

        # Maintenance mode: report status and stop.
        if args.status:
            print_processing_status(segments_root(config))
            return 0

        # Maintenance mode: wipe one segment so it can be re-processed.
        if args.cleanup_segment is not None:
            target_dir = os.path.join(
                segments_root(config), f"segment_{args.cleanup_segment}"
            )
            if not cleanup_incomplete_segment(target_dir):
                logger.error(f"Failed to clean up segment {args.cleanup_segment}")
                return 1
            logger.info(f"Successfully cleaned up segment {args.cleanup_segment}")
            return 0

        logger.info("Starting YOLO + SAM2 video processing pipeline")

        if not validate_dependencies():
            return 1

        source_video = config.get_input_video_path()
        if not os.path.exists(source_video):
            logger.error(f"Input video not found: {source_video}")
            return 1

        out_root = config.get_output_directory()
        ensure_directory(out_root)

        # Step 1: split the input into segments.
        logger.info("Step 1: Splitting video into segments")
        splitter = VideoSplitter(
            segment_duration=config.get_segment_duration(),
            force_keyframes=config.get('video.force_keyframes', True),
        )
        segments_dir, segment_dirs = splitter.split_video(source_video, out_root)
        logger.info(f"Created {len(segment_dirs)} segments in {segments_dir}")

        segments_info = splitter.get_segment_info(segments_dir)

        # Decide which segments get YOLO detection.
        chosen_segments = resolve_detect_segments(
            config.get_detect_segments(), len(segments_info)
        )

        # Step 2: YOLO human detection on the chosen segments.
        logger.info("Step 2: Running YOLO human detection")
        detector = YOLODetector(
            model_path=config.get_yolo_model_path(),
            confidence_threshold=config.get_yolo_confidence(),
            human_class_id=config.get_human_class_id(),
        )
        detection_results = detector.process_segments_batch(
            segments_info,
            chosen_segments,
            scale=config.get_inference_scale(),
        )

        total_humans = sum(map(len, detection_results.values()))
        logger.info(f"Detected {total_humans} humans across {len(detection_results)} segments")

        # Step 3: SAM2 processing (placeholder for now).
        logger.info("Step 3: SAM2 processing and green screen generation")
        logger.info("SAM2 processing module not yet implemented - this is where segment processing would occur")

        # Step 4: final assembly (placeholder for now).
        logger.info("Step 4: Assembling final video with audio")
        logger.info("Video assembly module not yet implemented - this is where concatenation and audio copying would occur")

        logger.info("Pipeline completed successfully")
        return 0

    except Exception as e:
        # Top-level boundary: record the traceback and signal failure.
        logger.error(f"Pipeline failed: {e}", exc_info=True)
        return 1
|
|
|
|
if __name__ == "__main__":
    # Hand the pipeline's status code back to the shell.
    sys.exit(main())