Initial commit
This commit is contained in:
195
main.py
Normal file
195
main.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Main entry point for YOLO + SAM2 video processing pipeline.
|
||||
Processes long videos by splitting into segments, detecting humans with YOLO,
|
||||
and creating green screen masks with SAM2.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from typing import List
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from core.config_loader import ConfigLoader
|
||||
from core.video_splitter import VideoSplitter
|
||||
from core.yolo_detector import YOLODetector
|
||||
from utils.logging_utils import setup_logging, get_logger
|
||||
from utils.file_utils import ensure_directory
|
||||
from utils.status_utils import print_processing_status, cleanup_incomplete_segment
|
||||
|
||||
# Module-level logger shared by every function in this file; obtained from the
# project's get_logger helper using this module's name.
logger = get_logger(__name__)
|
||||
|
||||
def parse_arguments():
    """Build the command-line interface and parse the process arguments.

    Recognised options:
        --config           required path to the YAML configuration file
        --log-file         optional log file path
        --status           flag: show processing status and exit
        --cleanup-segment  integer index of a segment to clean up for restart

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser(
        description="YOLO + SAM2 Video Processing Pipeline"
    )

    # Options are declared as data so each flag's settings sit side by side.
    option_table = (
        (
            "--config",
            {
                "type": str,
                "required": True,
                "help": "Path to YAML configuration file",
            },
        ),
        (
            "--log-file",
            {
                "type": str,
                "help": "Optional log file path",
            },
        ),
        (
            "--status",
            {
                "action": "store_true",
                "help": "Show processing status and exit",
            },
        ),
        (
            "--cleanup-segment",
            {
                "type": int,
                "help": "Clean up a specific segment for restart (segment index)",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
|
||||
|
||||
def validate_dependencies():
    """Check that the third-party packages the pipeline needs can be imported.

    The imports below are an availability probe only; none of the imported
    names are used afterwards.

    Returns:
        True when every import succeeds; False after logging the first
        ImportError that is raised.
    """
    try:
        import torch  # noqa: F401
        import cv2  # noqa: F401
        import numpy as np  # noqa: F401
        import cupy as cp  # noqa: F401
        from ultralytics import YOLO  # noqa: F401
        from sam2.build_sam import build_sam2_video_predictor  # noqa: F401
    except ImportError as missing:
        logger.error(f"Missing dependency: {missing}")
        logger.error("Please install requirements: pip install -r requirements.txt")
        return False

    logger.info("All dependencies validated successfully")
    return True
|
||||
|
||||
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
    """
    Resolve detect_segments configuration to list of segment indices.

    Args:
        detect_segments: Configuration value ("all", list, or None)
        total_segments: Total number of segments

    Returns:
        List of segment indices to process. "all" and None select every
        segment. A list is filtered down to the entries that lie in
        ``[0, total_segments)``; a warning is logged when anything was
        dropped. Any other value logs a warning and selects every segment.
    """
    if detect_segments == "all" or detect_segments is None:
        return list(range(total_segments))

    if not isinstance(detect_segments, list):
        logger.warning(f"Invalid detect_segments format: {detect_segments}. Using all segments.")
        return list(range(total_segments))

    def _is_valid_index(candidate) -> bool:
        # Entries that cannot be ordered against an int (e.g. "1" or None
        # coming from a hand-written config) are invalid indices, not a
        # reason to abort with TypeError.
        try:
            return bool(0 <= candidate < total_segments)
        except TypeError:
            return False

    valid_segments = [s for s in detect_segments if _is_valid_index(s)]
    if len(valid_segments) != len(detect_segments):
        logger.warning(f"Some segment indices are invalid. Using: {valid_segments}")
    return valid_segments
|
||||
|
||||
def _segments_dir_for(config):
    """Return ``<output_dir>/<video_name>_segments`` as derived from config."""
    output_dir = config.get_output_directory()
    input_video = config.get_input_video_path()
    video_name = os.path.splitext(os.path.basename(input_video))[0]
    return os.path.join(output_dir, f"{video_name}_segments")


def _run_pipeline(config):
    """Run dependency checks, video splitting and YOLO detection; return an exit code."""
    logger.info("Starting YOLO + SAM2 video processing pipeline")

    # Validate dependencies
    if not validate_dependencies():
        return 1

    # Validate input video exists
    input_video = config.get_input_video_path()
    if not os.path.exists(input_video):
        logger.error(f"Input video not found: {input_video}")
        return 1

    # Setup output directory
    output_dir = config.get_output_directory()
    ensure_directory(output_dir)

    # Step 1: Split video into segments
    logger.info("Step 1: Splitting video into segments")
    splitter = VideoSplitter(
        segment_duration=config.get_segment_duration(),
        force_keyframes=config.get('video.force_keyframes', True)
    )

    segments_dir, segment_dirs = splitter.split_video(input_video, output_dir)
    logger.info(f"Created {len(segment_dirs)} segments in {segments_dir}")

    # Get detailed segment information
    segments_info = splitter.get_segment_info(segments_dir)

    # Resolve which segments to process with YOLO
    detect_segments_config = config.get_detect_segments()
    detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))

    # Step 2: Run YOLO detection on specified segments
    logger.info("Step 2: Running YOLO human detection")
    detector = YOLODetector(
        model_path=config.get_yolo_model_path(),
        confidence_threshold=config.get_yolo_confidence(),
        human_class_id=config.get_human_class_id()
    )

    detection_results = detector.process_segments_batch(
        segments_info,
        detect_segments,
        scale=config.get_inference_scale()
    )

    # Log detection summary
    total_humans = sum(len(detections) for detections in detection_results.values())
    logger.info(f"Detected {total_humans} humans across {len(detection_results)} segments")

    # Step 3: Process segments with SAM2 (placeholder for now)
    logger.info("Step 3: SAM2 processing and green screen generation")
    logger.info("SAM2 processing module not yet implemented - this is where segment processing would occur")

    # Step 4: Assemble final video (placeholder for now)
    logger.info("Step 4: Assembling final video with audio")
    logger.info("Video assembly module not yet implemented - this is where concatenation and audio copying would occur")

    logger.info("Pipeline completed successfully")
    return 0


def main():
    """Main processing pipeline.

    Parses the command line, loads the configuration, sets up logging and
    then does exactly one of the following:

    * ``--status``: print the processing status of the segments directory.
    * ``--cleanup-segment N``: clean up ``segment_N`` inside the segments
      directory so it can be restarted.
    * otherwise: run the full pipeline (steps 3 and 4 currently only log
      placeholder messages).

    Returns:
        0 on success, 1 on any failure. Any exception raised after argument
        parsing is logged with its traceback and turned into exit code 1.
    """
    args = parse_arguments()

    try:
        # Load configuration
        config = ConfigLoader(args.config)

        # Setup logging
        setup_logging(config.get_log_level(), args.log_file)

        # Handle status check
        if args.status:
            print_processing_status(_segments_dir_for(config))
            return 0

        # Handle segment cleanup
        if args.cleanup_segment is not None:
            segment_dir = os.path.join(
                _segments_dir_for(config),
                f"segment_{args.cleanup_segment}",
            )
            if cleanup_incomplete_segment(segment_dir):
                logger.info(f"Successfully cleaned up segment {args.cleanup_segment}")
                return 0
            logger.error(f"Failed to clean up segment {args.cleanup_segment}")
            return 1

        return _run_pipeline(config)

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        logger.error(f"Pipeline failed: {e}", exc_info=True)
        return 1
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user