Initial commit
config.yaml (new file, 59 lines)
@@ -0,0 +1,59 @@
# YOLO + SAM2 Video Processing Configuration

input:
  video_path: "/path/to/input/video.mp4"

output:
  directory: "/path/to/output/"
  filename: "processed_video.mp4"

processing:
  # Duration of each video segment in seconds
  segment_duration: 5

  # Scale factor for SAM2 inference (0.5 = half resolution)
  inference_scale: 0.5

  # YOLO detection confidence threshold
  yolo_confidence: 0.6

  # Which segments to run YOLO detection on
  # Options: "all", [0, 5, 10], or [] for default (all)
  detect_segments: "all"

models:
  # YOLO model path - can be pretrained (yolov8n.pt) or custom path
  yolo_model: "yolov8n.pt"

  # SAM2 model configuration
  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"

video:
  # Use NVIDIA hardware encoding (requires NVENC-capable GPU)
  use_nvenc: true

  # Output video bitrate
  output_bitrate: "50M"

  # Preserve original audio track
  preserve_audio: true

  # Force keyframes for better segment boundaries
  force_keyframes: true

advanced:
  # Green screen color for the first object (BGR order, as used by OpenCV)
  green_color: [0, 255, 0]

  # Blue screen color for the second object (BGR order, as used by OpenCV)
  blue_color: [255, 0, 0]

  # YOLO human class ID (0 for COCO person class)
  human_class_id: 0

  # GPU memory management
  cleanup_intermediate_files: true

  # Logging level (DEBUG, INFO, WARNING, ERROR)
  log_level: "INFO"
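
For reference, this file is consumed by the `ConfigLoader` added later in this commit; a minimal usage sketch, assuming config.yaml sits in the working directory:

```python
from core.config_loader import ConfigLoader

config = ConfigLoader("config.yaml")   # loads the YAML and validates required sections
config.get_input_video_path()          # "/path/to/input/video.mp4"
config.get_segment_duration()          # 5 (this getter falls back to 5 if the key is absent)
config.get_use_nvenc()                 # True
```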
core/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# YOLO + SAM2 Video Processing Pipeline
# Core modules for video processing with human detection and segmentation
core/config_loader.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""
Configuration loader for YOLO + SAM2 video processing pipeline.
Handles loading and validation of YAML configuration files.
"""

import yaml
import os
from typing import Dict, Any, List, Union
import logging

logger = logging.getLogger(__name__)


class ConfigLoader:
    """Loads and validates configuration from YAML files."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from YAML file."""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        try:
            with open(self.config_path, 'r') as file:
                config = yaml.safe_load(file)
                logger.info(f"Loaded configuration from {self.config_path}")
                return config
        except yaml.YAMLError as e:
            raise ValueError(f"Error parsing YAML file: {e}")

    def _validate_config(self):
        """Validate required configuration fields."""
        required_sections = ['input', 'output', 'processing', 'models']

        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")

        # Validate input section
        if 'video_path' not in self.config['input']:
            raise ValueError("Missing required field: input.video_path")

        # Validate output section
        required_output_fields = ['directory', 'filename']
        for field in required_output_fields:
            if field not in self.config['output']:
                raise ValueError(f"Missing required field: output.{field}")

        # Validate models section
        required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config']
        for field in required_model_fields:
            if field not in self.config['models']:
                raise ValueError(f"Missing required field: models.{field}")

        # Validate processing.detect_segments format
        detect_segments = self.config['processing'].get('detect_segments', 'all')
        if not isinstance(detect_segments, (str, list)):
            raise ValueError("detect_segments must be 'all' or a list of integers")

        if isinstance(detect_segments, list):
            if not all(isinstance(x, int) for x in detect_segments):
                raise ValueError("detect_segments list must contain only integers")

    def get(self, key_path: str, default=None):
        """
        Get configuration value using dot notation.

        Args:
            key_path: Dot-separated key path (e.g., 'processing.yolo_confidence')
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        keys = key_path.split('.')
        value = self.config

        try:
            for key in keys:
                value = value[key]
            return value
        except (KeyError, TypeError):
            return default

    def get_input_video_path(self) -> str:
        """Get input video path."""
        return self.config['input']['video_path']

    def get_output_directory(self) -> str:
        """Get output directory path."""
        return self.config['output']['directory']

    def get_output_filename(self) -> str:
        """Get output filename."""
        return self.config['output']['filename']

    def get_segment_duration(self) -> int:
        """Get segment duration in seconds."""
        return self.config['processing'].get('segment_duration', 5)

    def get_inference_scale(self) -> float:
        """Get inference scale factor."""
        return self.config['processing'].get('inference_scale', 0.5)

    def get_yolo_confidence(self) -> float:
        """Get YOLO confidence threshold."""
        return self.config['processing'].get('yolo_confidence', 0.6)

    def get_detect_segments(self) -> Union[str, List[int]]:
        """Get segments for YOLO detection."""
        return self.config['processing'].get('detect_segments', 'all')

    def get_yolo_model_path(self) -> str:
        """Get YOLO model path."""
        return self.config['models']['yolo_model']

    def get_sam2_checkpoint(self) -> str:
        """Get SAM2 checkpoint path."""
        return self.config['models']['sam2_checkpoint']

    def get_sam2_config(self) -> str:
        """Get SAM2 config path."""
        return self.config['models']['sam2_config']

    def get_use_nvenc(self) -> bool:
        """Get whether to use NVIDIA encoding."""
        return self.config.get('video', {}).get('use_nvenc', True)

    def get_preserve_audio(self) -> bool:
        """Get whether to preserve audio."""
        return self.config.get('video', {}).get('preserve_audio', True)

    def get_output_bitrate(self) -> str:
        """Get output video bitrate."""
        return self.config.get('video', {}).get('output_bitrate', '50M')

    def get_green_color(self) -> List[int]:
        """Get green screen color."""
        return self.config.get('advanced', {}).get('green_color', [0, 255, 0])

    def get_blue_color(self) -> List[int]:
        """Get blue screen color."""
        return self.config.get('advanced', {}).get('blue_color', [255, 0, 0])

    def get_human_class_id(self) -> int:
        """Get YOLO human class ID."""
        return self.config.get('advanced', {}).get('human_class_id', 0)

    def get_log_level(self) -> str:
        """Get logging level."""
        return self.config.get('advanced', {}).get('log_level', 'INFO')

    def should_cleanup_intermediate_files(self) -> bool:
        """Get whether to cleanup intermediate files."""
        return self.config.get('advanced', {}).get('cleanup_intermediate_files', True)
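
Beyond the typed getters, `get()` reaches any nested key by dot-separated path; a short sketch against the config.yaml in this commit:

```python
from core.config_loader import ConfigLoader

config = ConfigLoader("config.yaml")
config.get('processing.yolo_confidence')    # 0.6
config.get('video.output_bitrate', '50M')   # '50M'; the default is returned if the key is missing
config.get('advanced.no_such_key', 'n/a')   # 'n/a'; KeyError/TypeError are caught internally
```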
core/sam2_processor.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""
SAM2 processor module for video segmentation.
Preserves the core SAM2 logic from the original implementation.
"""

import os
import cv2
import numpy as np
import torch
import logging
import gc
from typing import Dict, List, Any, Optional, Tuple
from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)


class SAM2Processor:
    """Handles SAM2-based video segmentation for human tracking."""

    def __init__(self, checkpoint_path: str, config_path: str):
        """
        Initialize SAM2 processor.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_path: Path to SAM2 config file
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.predictor = None
        self._initialize_predictor()

    def _initialize_predictor(self):
        """Initialize SAM2 video predictor with proper device setup."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
            logger.warning(
                "Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS."
            )
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        else:
            device = torch.device("cpu")

        logger.info(f"Using device: {device}")

        try:
            self.predictor = build_sam2_video_predictor(
                self.config_path,
                self.checkpoint_path,
                device=device
            )

            # Enable optimizations for CUDA
            if device.type == "cuda":
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
                if torch.cuda.get_device_properties(0).major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True

            logger.info("SAM2 predictor initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize SAM2 predictor: {e}")
            raise

    def create_low_res_video(self, input_video_path: str, output_video_path: str, scale: float):
        """
        Create a low-resolution version of the input video for inference.

        Args:
            input_video_path: Path to input video
            output_video_path: Path to output low-res video
            scale: Scale factor for resolution reduction
        """
        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
            out.write(low_res_frame)
            frame_count += 1

        cap.release()
        out.release()

        logger.info(f"Created low-res video with {frame_count} frames: {output_video_path}")

    def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
        """
        Add YOLO detection prompts to SAM2 predictor.

        Args:
            inference_state: SAM2 inference state
            prompts: List of prompt dictionaries with obj_id and bbox

        Returns:
            True if prompts were added successfully
        """
        if not prompts:
            logger.warning("No prompts provided to SAM2")
            return False

        try:
            for prompt in prompts:
                obj_id = prompt['obj_id']
                bbox = prompt['bbox']

                _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=obj_id,
                    box=bbox.astype(np.float32),
                )

                logger.debug(f"Added prompt for Object {obj_id}: {bbox}")

            logger.info(f"Successfully added {len(prompts)} prompts to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding prompts to SAM2: {e}")
            return False

    def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
        """
        Load masks from previous segment for continuity.

        Args:
            prev_segment_dir: Directory of previous segment

        Returns:
            Dictionary mapping object IDs to masks, or None if failed
        """
        mask_path = os.path.join(prev_segment_dir, "mask.png")

        if not os.path.exists(mask_path):
            logger.warning(f"Previous mask not found: {mask_path}")
            return None

        try:
            mask_image = cv2.imread(mask_path)
            if mask_image is None:
                logger.error(f"Could not read mask image: {mask_path}")
                return None

            if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
                logger.error("Mask image does not have three color channels")
                return None

            mask_image = mask_image.astype(np.uint8)

            # Extract Object A and Object B masks (preserving original logic);
            # colors are in BGR order because cv2.imread returns BGR
            GREEN = [0, 255, 0]
            BLUE = [255, 0, 0]

            mask_a = np.all(mask_image == GREEN, axis=2)
            mask_b = np.all(mask_image == BLUE, axis=2)

            per_obj_input_mask = {}
            if np.any(mask_a):
                per_obj_input_mask[1] = mask_a
            if np.any(mask_b):
                per_obj_input_mask[2] = mask_b

            logger.info(f"Loaded masks for {len(per_obj_input_mask)} objects from {prev_segment_dir}")
            return per_obj_input_mask

        except Exception as e:
            logger.error(f"Error loading previous mask: {e}")
            return None

    def add_previous_masks_to_predictor(self, inference_state, masks: Dict[int, np.ndarray]) -> bool:
        """
        Add previous segment masks to predictor for continuity.

        Args:
            inference_state: SAM2 inference state
            masks: Dictionary mapping object IDs to masks

        Returns:
            True if masks were added successfully
        """
        try:
            for obj_id, mask in masks.items():
                self.predictor.add_new_mask(inference_state, 0, obj_id, mask)
                logger.debug(f"Added previous mask for Object {obj_id}")

            logger.info(f"Successfully added {len(masks)} previous masks to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding previous masks to SAM2: {e}")
            return False

    def propagate_masks(self, inference_state) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks across all frames in the video.

        Args:
            inference_state: SAM2 inference state

        Returns:
            Dictionary mapping frame indices to object masks
        """
        video_segments = {}

        try:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

            logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")

        except Exception as e:
            logger.error(f"Error during mask propagation: {e}")

        return video_segments

    def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                               previous_masks: Optional[Dict[int, np.ndarray]] = None,
                               inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a single video segment with SAM2.

        Args:
            segment_info: Segment information dictionary
            yolo_prompts: Optional YOLO detection prompts
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference

        Returns:
            Video segments dictionary, or None if the segment was skipped or failed
        """
        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        # Check if segment is already processed (resume capability)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment_idx} already processed. Skipping.")
            return None  # Indicates skip, not failure

        logger.info(f"Processing segment {segment_idx} with SAM2")

        # Create low-resolution video for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            try:
                self.create_low_res_video(video_path, low_res_video_path, inference_scale)
            except Exception as e:
                logger.error(f"Failed to create low-res video for segment {segment_idx}: {e}")
                return None

        try:
            # Initialize inference state
            inference_state = self.predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

            # Add prompts or previous masks
            if yolo_prompts:
                if not self.add_yolo_prompts_to_predictor(inference_state, yolo_prompts):
                    return None
            elif previous_masks:
                if not self.add_previous_masks_to_predictor(inference_state, previous_masks):
                    return None
            else:
                logger.error(f"No prompts or previous masks available for segment {segment_idx}")
                return None

            # Propagate masks
            video_segments = self.propagate_masks(inference_state)

            # Clean up
            self.predictor.reset_state(inference_state)
            del inference_state
            gc.collect()

            # Remove low-res video to save space
            try:
                os.remove(low_res_video_path)
                logger.debug(f"Removed low-res video: {low_res_video_path}")
            except Exception as e:
                logger.warning(f"Could not remove low-res video: {e}")

            # Mark segment as completed (for resume capability)
            try:
                with open(output_done_file, 'w') as f:
                    f.write(f"Segment {segment_idx} completed successfully\n")
                logger.debug(f"Marked segment {segment_idx} as completed")
            except Exception as e:
                logger.warning(f"Could not create completion marker: {e}")

            return video_segments

        except Exception as e:
            logger.error(f"Error processing segment {segment_idx}: {e}")
            return None

    def save_final_masks(self, video_segments: Dict[int, Dict[int, np.ndarray]], output_path: str,
                         green_color: List[int] = [0, 255, 0], blue_color: List[int] = [255, 0, 0]):
        """
        Save the final masks as a colored image for continuity.

        Args:
            video_segments: Video segments dictionary
            output_path: Path to save the mask image
            green_color: BGR color for object 1 (OpenCV channel order)
            blue_color: BGR color for object 2 (OpenCV channel order)
        """
        if not video_segments:
            logger.error("No video segments to save")
            return

        try:
            last_frame_idx = max(video_segments.keys())
            masks_dict = video_segments[last_frame_idx]

            # Get masks for objects 1 and 2
            mask_a = masks_dict.get(1)
            mask_b = masks_dict.get(2)

            if mask_a is None and mask_b is None:
                logger.error("No masks found for objects")
                return

            # Use the first available mask to determine dimensions
            reference_mask = mask_a if mask_a is not None else mask_b
            reference_mask = reference_mask.squeeze()

            black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

            if mask_a is not None:
                mask_a = mask_a.squeeze().astype(bool)
                black_frame[mask_a] = green_color

            if mask_b is not None:
                mask_b = mask_b.squeeze().astype(bool)
                black_frame[mask_b] = blue_color

            # Save the mask image
            cv2.imwrite(output_path, black_frame)
            logger.info(f"Saved final masks to {output_path}")

        except Exception as e:
            logger.error(f"Error saving final masks: {e}")
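
For orientation, a hedged driver sketch chaining two segments: the first is seeded with YOLO box prompts and its last-frame masks are persisted as mask.png; the second is seeded from that image via the continuity path. The segment dicts and the box below are illustrative stand-ins for what `VideoSplitter.get_segment_info` and the YOLO detector produce.

```python
import os
import numpy as np
from core.sam2_processor import SAM2Processor

processor = SAM2Processor(
    checkpoint_path="../checkpoints/sam2.1_hiera_large.pt",
    config_path="configs/sam2.1/sam2.1_hiera_l.yaml",
)

# Illustrative segment dicts (normally produced by VideoSplitter.get_segment_info).
seg0 = {"index": 0, "directory": "out/clip_segments/segment_0",
        "video_file": "out/clip_segments/segment_0/0.mp4"}
seg1 = {"index": 1, "directory": "out/clip_segments/segment_1",
        "video_file": "out/clip_segments/segment_1/1.mp4"}

# Illustrative YOLO box prompt, in low-res (inference_scale) coordinates.
prompts0 = [{"obj_id": 1, "bbox": np.array([120.0, 80.0, 420.0, 700.0])}]

# Keyframe segment: seed SAM2 with boxes, then persist the last frame's masks.
result0 = processor.process_single_segment(seg0, yolo_prompts=prompts0, inference_scale=0.5)
if result0:
    processor.save_final_masks(result0, os.path.join(seg0["directory"], "mask.png"))

# Continuity segment: reload those masks as the seed for the next segment.
prev_masks = processor.load_previous_segment_mask(seg0["directory"])
result1 = processor.process_single_segment(seg1, previous_masks=prev_masks, inference_scale=0.5)
```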
core/video_splitter.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""
Video splitter module for the YOLO + SAM2 processing pipeline.
Handles splitting long videos into manageable segments.
"""

import os
import subprocess
import logging
from typing import List, Tuple

# Absolute import: main.py puts the project root on sys.path
from utils.file_utils import ensure_directory, get_video_file_name

logger = logging.getLogger(__name__)


class VideoSplitter:
    """Handles splitting videos into segments for processing."""

    def __init__(self, segment_duration: int = 5, force_keyframes: bool = True):
        """
        Initialize video splitter.

        Args:
            segment_duration: Duration of each segment in seconds
            force_keyframes: Whether to force keyframes for clean cuts
        """
        self.segment_duration = segment_duration
        self.force_keyframes = force_keyframes

    def split_video(self, input_video: str, output_dir: str) -> Tuple[str, List[str]]:
        """
        Split video into segments and organize into directory structure.

        Args:
            input_video: Path to input video file
            output_dir: Base output directory

        Returns:
            Tuple of (segments_directory, list_of_segment_directories)
        """
        if not os.path.exists(input_video):
            raise FileNotFoundError(f"Input video not found: {input_video}")

        # Create output directory structure
        video_name = os.path.splitext(os.path.basename(input_video))[0]
        segments_dir = os.path.join(output_dir, f"{video_name}_segments")
        ensure_directory(segments_dir)

        logger.info(f"Splitting video {input_video} into {self.segment_duration}s segments")

        # Split video using ffmpeg
        segment_pattern = os.path.join(segments_dir, "segment_%03d.mp4")

        # Build ffmpeg command
        cmd = [
            'ffmpeg', '-i', input_video,
            '-f', 'segment',
            '-segment_time', str(self.segment_duration),
            '-reset_timestamps', '1',
            '-c', 'copy'
        ]

        # Add keyframe forcing if enabled.
        # NOTE: -force_key_frames only takes effect when re-encoding; with
        # '-c copy' above, ffmpeg can only cut on keyframes that already exist.
        if self.force_keyframes:
            cmd.extend(['-force_key_frames', f'expr:gte(t,n_forced*{self.segment_duration})'])

        # Add copyts for timestamp preservation
        cmd.extend(['-copyts', segment_pattern])

        try:
            result = subprocess.run(
                cmd,
                check=True,
                capture_output=True,
                text=True
            )
            logger.debug(f"FFmpeg output: {result.stderr}")
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg failed: {e.stderr}")
            raise RuntimeError(f"Video splitting failed: {e}")

        # Organize segments into individual directories
        segment_dirs = self._organize_segments(segments_dir)

        # Create file list for later concatenation
        self._create_file_list(segments_dir, segment_dirs)

        logger.info(f"Successfully split video into {len(segment_dirs)} segments")
        return segments_dir, segment_dirs

    def _organize_segments(self, segments_dir: str) -> List[str]:
        """
        Move each segment into its own subdirectory.

        Args:
            segments_dir: Directory containing split segments

        Returns:
            List of created segment directory names
        """
        segment_files = []
        segment_dirs = []

        # Find all segment files
        for file in os.listdir(segments_dir):
            if file.startswith("segment_") and file.endswith(".mp4"):
                segment_files.append(file)

        # Sort segment files numerically
        segment_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

        # Move each segment to its own directory
        for i, segment_file in enumerate(segment_files):
            segment_dir_name = f"segment_{i}"
            segment_dir_path = os.path.join(segments_dir, segment_dir_name)
            ensure_directory(segment_dir_path)

            # Move segment file to subdirectory with standardized name
            old_path = os.path.join(segments_dir, segment_file)
            new_path = os.path.join(segment_dir_path, get_video_file_name(i))

            os.rename(old_path, new_path)
            segment_dirs.append(segment_dir_name)

            logger.debug(f"Organized segment {i}: {new_path}")

        return segment_dirs

    def _create_file_list(self, segments_dir: str, segment_dirs: List[str]):
        """
        Create a file list for future concatenation.

        Args:
            segments_dir: Base segments directory
            segment_dirs: List of segment directory names
        """
        file_list_path = os.path.join(segments_dir, "file_list.txt")

        with open(file_list_path, 'w') as f:
            for i, segment_dir in enumerate(segment_dirs):
                segment_path = os.path.join(segment_dir, get_video_file_name(i))
                f.write(f"file '{segment_path}'\n")

        logger.debug(f"Created file list: {file_list_path}")

    def get_segment_info(self, segments_dir: str) -> List[dict]:
        """
        Get information about all segments in a directory.

        Args:
            segments_dir: Directory containing segments

        Returns:
            List of segment information dictionaries
        """
        segment_info = []

        for item in os.listdir(segments_dir):
            item_path = os.path.join(segments_dir, item)

            if os.path.isdir(item_path) and item.startswith("segment_"):
                segment_index = int(item.split("_")[1])
                video_file = os.path.join(item_path, get_video_file_name(segment_index))

                info = {
                    'index': segment_index,
                    'directory': item_path,
                    'video_file': video_file,
                    'exists': os.path.exists(video_file)
                }
                segment_info.append(info)

        # Sort by index
        segment_info.sort(key=lambda x: x['index'])

        return segment_info
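
A usage sketch with illustrative paths. Note that file_list.txt is written in ffmpeg's concat-demuxer syntax (one `file '<path>'` line per segment), ready for the assembly step that main.py currently leaves as a placeholder:

```python
from core.video_splitter import VideoSplitter

splitter = VideoSplitter(segment_duration=5, force_keyframes=True)
segments_dir, segment_dirs = splitter.split_video("clip.mp4", "out/")
# -> out/clip_segments/segment_0/, segment_1/, ... plus out/clip_segments/file_list.txt

for info in splitter.get_segment_info(segments_dir):
    print(info["index"], info["video_file"], info["exists"])
```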
core/yolo_detector.py (new file, 286 lines)
@@ -0,0 +1,286 @@
"""
YOLO detector module for human detection in video segments.
Preserves the core detection logic from the original implementation.
"""

import os
import cv2
import numpy as np
import logging
from typing import List, Dict, Any, Optional
from ultralytics import YOLO

logger = logging.getLogger(__name__)


class YOLODetector:
    """Handles YOLO-based human detection for video segments."""

    def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
        """
        Initialize YOLO detector.

        Args:
            model_path: Path to YOLO model weights
            confidence_threshold: Detection confidence threshold
            human_class_id: COCO class ID for humans (0 = person)
        """
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.human_class_id = human_class_id

        # Load YOLO model
        try:
            self.model = YOLO(model_path)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load YOLO model: {e}")
            raise

    def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect humans in a single frame using YOLO.

        Args:
            frame: Input frame (BGR format from OpenCV)

        Returns:
            List of human detection dictionaries with bbox and confidence
        """
        # Run YOLO detection
        results = self.model(frame, conf=self.confidence_threshold, verbose=False)

        human_detections = []

        # Process results
        for result in results:
            boxes = result.boxes
            if boxes is not None:
                for box in boxes:
                    # Get class ID
                    cls = int(box.cls.cpu().numpy()[0])

                    # Check if it's a person (human_class_id)
                    if cls == self.human_class_id:
                        # Get bounding box coordinates (x1, y1, x2, y2)
                        coords = box.xyxy[0].cpu().numpy()
                        conf = float(box.conf.cpu().numpy()[0])

                        human_detections.append({
                            'bbox': coords,
                            'confidence': conf
                        })

                        logger.debug(f"Detected human with confidence {conf:.2f} at {coords}")

        return human_detections

    def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
        """
        Detect humans in the first frame of a video.

        Args:
            video_path: Path to video file
            scale: Scale factor for frame processing

        Returns:
            List of human detection dictionaries
        """
        if not os.path.exists(video_path):
            logger.error(f"Video file not found: {video_path}")
            return []

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Could not open video: {video_path}")
            return []

        ret, frame = cap.read()
        cap.release()

        if not ret:
            logger.error(f"Could not read first frame from: {video_path}")
            return []

        # Scale frame if needed
        if scale != 1.0:
            frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

        return self.detect_humans_in_frame(frame)

    def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
        """
        Save detection results to file.

        Args:
            detections: List of detection dictionaries
            output_path: Path to save detections

        Returns:
            True if saved successfully
        """
        try:
            with open(output_path, 'w') as f:
                f.write("# YOLO Human Detections\n")
                if detections:
                    for detection in detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                    logger.info(f"Saved {len(detections)} detections to {output_path}")
                else:
                    f.write("# No humans detected\n")
                    logger.info(f"Saved empty detection file to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save detections to {output_path}: {e}")
            return False

    def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load detection results from file.

        Args:
            file_path: Path to detection file

        Returns:
            List of detection dictionaries
        """
        detections = []

        if not os.path.exists(file_path):
            logger.warning(f"Detection file not found: {file_path}")
            return detections

        try:
            with open(file_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    # Skip comments and empty lines
                    if line.startswith('#') or not line:
                        continue

                    # Parse detection line: x1,y1,x2,y2,confidence
                    parts = line.split(',')
                    if len(parts) == 5:
                        try:
                            bbox = [float(x) for x in parts[:4]]
                            conf = float(parts[4])
                            detections.append({
                                'bbox': np.array(bbox),
                                'confidence': conf
                            })
                        except ValueError:
                            logger.warning(f"Invalid detection line: {line}")
                            continue

            logger.info(f"Loaded {len(detections)} detections from {file_path}")
        except Exception as e:
            logger.error(f"Failed to load detections from {file_path}: {e}")

        return detections

    def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
                               scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
        """
        Process multiple segments for human detection.

        Args:
            segments_info: List of segment information dictionaries
            detect_segments: List of segment indices to process
            scale: Scale factor for processing

        Returns:
            Dictionary mapping segment index to detection results
        """
        results = {}

        for segment_info in segments_info:
            segment_idx = segment_info['index']

            # Skip if not in detect_segments list
            if detect_segments != 'all' and segment_idx not in detect_segments:
                continue

            video_path = segment_info['video_file']
            detection_file = os.path.join(segment_info['directory'], "yolo_detections")

            # Skip if already processed
            if os.path.exists(detection_file):
                logger.info(f"Segment {segment_idx} already has detections, skipping")
                detections = self.load_detections_from_file(detection_file)
                results[segment_idx] = detections
                continue

            # Run detection
            logger.info(f"Processing segment {segment_idx} for human detection")
            detections = self.detect_humans_in_video_first_frame(video_path, scale)

            # Save results
            self.save_detections_to_file(detections, detection_file)
            results[segment_idx] = detections

        return results

    def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
                                           frame_width: int) -> List[Dict[str, Any]]:
        """
        Convert YOLO detections to SAM2-compatible prompts for stereo video.

        Args:
            detections: List of YOLO detection results
            frame_width: Width of the video frame

        Returns:
            List of SAM2 prompt dictionaries with obj_id and bbox
        """
        if not detections:
            return []

        half_frame_width = frame_width // 2
        prompts = []

        # Sort detections by x-coordinate to get consistent left/right assignment
        sorted_detections = sorted(detections, key=lambda x: x['bbox'][0])

        obj_id = 1

        for i, detection in enumerate(sorted_detections[:2]):  # Take up to 2 humans
            bbox = detection['bbox'].copy()

            # For stereo videos, assign obj_id based on position
            if len(sorted_detections) >= 2:
                center_x = (bbox[0] + bbox[2]) / 2
                if center_x < half_frame_width:
                    current_obj_id = 1  # Left human
                else:
                    current_obj_id = 2  # Right human
            else:
                # If only one human, create prompts for both sides
                current_obj_id = obj_id
                obj_id += 1

                # Create mirrored version for stereo
                if obj_id <= 2:
                    mirrored_bbox = bbox.copy()
                    mirrored_bbox[0] += half_frame_width  # Shift x1
                    mirrored_bbox[2] += half_frame_width  # Shift x2

                    # Ensure mirrored bbox is within frame bounds
                    mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                    mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))

                    prompts.append({
                        'obj_id': obj_id,
                        'bbox': mirrored_bbox,
                        'confidence': detection['confidence']
                    })
                    obj_id += 1

            prompts.append({
                'obj_id': current_obj_id,
                'bbox': bbox,
                'confidence': detection['confidence']
            })

        logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts")
        return prompts
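
A hedged end-to-end sketch for one segment (the video path and the 1920-pixel source width are illustrative). Since detection runs on a frame downscaled by inference_scale, and SAM2 later runs on a low-res video at the same scale, the boxes can be passed through without coordinate conversion:

```python
from core.yolo_detector import YOLODetector

detector = YOLODetector("yolov8n.pt", confidence_threshold=0.6, human_class_id=0)

# Detect on the first frame, downscaled by the same factor SAM2 will use.
detections = detector.detect_humans_in_video_first_frame("segment_0/0.mp4", scale=0.5)

# frame_width must match the scaled frame the boxes were computed on.
scaled_width = int(1920 * 0.5)  # assuming a 1920 px wide source, for illustration
prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width=scaled_width)
```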
main.py (new file, 195 lines)
@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""
Main entry point for YOLO + SAM2 video processing pipeline.
Processes long videos by splitting into segments, detecting humans with YOLO,
and creating green screen masks with SAM2.
"""

import os
import sys
import argparse
from typing import List

# Add project root to path
sys.path.append(os.path.dirname(__file__))

from core.config_loader import ConfigLoader
from core.video_splitter import VideoSplitter
from core.yolo_detector import YOLODetector
from utils.logging_utils import setup_logging, get_logger
from utils.file_utils import ensure_directory
from utils.status_utils import print_processing_status, cleanup_incomplete_segment

logger = get_logger(__name__)


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="YOLO + SAM2 Video Processing Pipeline"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML configuration file"
    )
    parser.add_argument(
        "--log-file",
        type=str,
        help="Optional log file path"
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="Show processing status and exit"
    )
    parser.add_argument(
        "--cleanup-segment",
        type=int,
        help="Clean up a specific segment for restart (segment index)"
    )
    return parser.parse_args()


def validate_dependencies():
    """Validate that required dependencies are available."""
    try:
        import torch
        import cv2
        import numpy as np
        import cupy as cp
        from ultralytics import YOLO
        from sam2.build_sam import build_sam2_video_predictor
        logger.info("All dependencies validated successfully")
        return True
    except ImportError as e:
        logger.error(f"Missing dependency: {e}")
        logger.error("Please install requirements: pip install -r requirements.txt")
        return False


def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
    """
    Resolve detect_segments configuration to list of segment indices.

    Args:
        detect_segments: Configuration value ("all", list, or None)
        total_segments: Total number of segments

    Returns:
        List of segment indices to process
    """
    if detect_segments == "all" or detect_segments is None:
        return list(range(total_segments))
    elif isinstance(detect_segments, list):
        # Filter out invalid segment indices
        valid_segments = [s for s in detect_segments if 0 <= s < total_segments]
        if len(valid_segments) != len(detect_segments):
            logger.warning(f"Some segment indices are invalid. Using: {valid_segments}")
        return valid_segments
    else:
        logger.warning(f"Invalid detect_segments format: {detect_segments}. Using all segments.")
        return list(range(total_segments))


def main():
    """Main processing pipeline."""
    args = parse_arguments()

    try:
        # Load configuration
        config = ConfigLoader(args.config)

        # Setup logging
        setup_logging(config.get_log_level(), args.log_file)

        # Handle status check
        if args.status:
            output_dir = config.get_output_directory()
            input_video = config.get_input_video_path()
            video_name = os.path.splitext(os.path.basename(input_video))[0]
            segments_dir = os.path.join(output_dir, f"{video_name}_segments")
            print_processing_status(segments_dir)
            return 0

        # Handle segment cleanup
        if args.cleanup_segment is not None:
            output_dir = config.get_output_directory()
            input_video = config.get_input_video_path()
            video_name = os.path.splitext(os.path.basename(input_video))[0]
            segments_dir = os.path.join(output_dir, f"{video_name}_segments")
            segment_dir = os.path.join(segments_dir, f"segment_{args.cleanup_segment}")

            if cleanup_incomplete_segment(segment_dir):
                logger.info(f"Successfully cleaned up segment {args.cleanup_segment}")
                return 0
            else:
                logger.error(f"Failed to clean up segment {args.cleanup_segment}")
                return 1

        logger.info("Starting YOLO + SAM2 video processing pipeline")

        # Validate dependencies
        if not validate_dependencies():
            return 1

        # Validate input video exists
        input_video = config.get_input_video_path()
        if not os.path.exists(input_video):
            logger.error(f"Input video not found: {input_video}")
            return 1

        # Setup output directory
        output_dir = config.get_output_directory()
        ensure_directory(output_dir)

        # Step 1: Split video into segments
        logger.info("Step 1: Splitting video into segments")
        splitter = VideoSplitter(
            segment_duration=config.get_segment_duration(),
            force_keyframes=config.get('video.force_keyframes', True)
        )

        segments_dir, segment_dirs = splitter.split_video(input_video, output_dir)
        logger.info(f"Created {len(segment_dirs)} segments in {segments_dir}")

        # Get detailed segment information
        segments_info = splitter.get_segment_info(segments_dir)

        # Resolve which segments to process with YOLO
        detect_segments_config = config.get_detect_segments()
        detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))

        # Step 2: Run YOLO detection on specified segments
        logger.info("Step 2: Running YOLO human detection")
        detector = YOLODetector(
            model_path=config.get_yolo_model_path(),
            confidence_threshold=config.get_yolo_confidence(),
            human_class_id=config.get_human_class_id()
        )

        detection_results = detector.process_segments_batch(
            segments_info,
            detect_segments,
            scale=config.get_inference_scale()
        )

        # Log detection summary
        total_humans = sum(len(detections) for detections in detection_results.values())
        logger.info(f"Detected {total_humans} humans across {len(detection_results)} segments")

        # Step 3: Process segments with SAM2 (placeholder for now)
        logger.info("Step 3: SAM2 processing and green screen generation")
        logger.info("SAM2 processing module not yet implemented - this is where segment processing would occur")

        # Step 4: Assemble final video (placeholder for now)
        logger.info("Step 4: Assembling final video with audio")
        logger.info("Video assembly module not yet implemented - this is where concatenation and audio copying would occur")

        logger.info("Pipeline completed successfully")
        return 0

    except Exception as e:
        logger.error(f"Pipeline failed: {e}", exc_info=True)
        return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)
requirements.txt (new file, 30 lines)
@@ -0,0 +1,30 @@
# Core deep learning and computer vision
torch>=2.0.0
torchvision>=0.15.0
ultralytics>=8.0.0
opencv-python>=4.8.0
numpy>=1.24.0

# SAM2 - Segment Anything Model 2
git+https://github.com/facebookresearch/sam2.git

# GPU acceleration (optional but recommended)
cupy-cuda12x>=12.0.0  # For CUDA 12.x, adjust version as needed

# Configuration and utilities
PyYAML>=6.0
tqdm>=4.65.0
matplotlib>=3.7.0
Pillow>=10.0.0

# Optional: For advanced features
psutil>=5.9.0  # Memory monitoring
pympler>=0.9   # Memory profiling (for debugging)

# Video processing
ffmpeg-python>=0.2.0  # Python wrapper for FFmpeg (optional, shell ffmpeg still needed)

# Development dependencies (optional)
pytest>=7.0.0
black>=23.0.0
flake8>=6.0.0
spec.md (new file, 192 lines)
@@ -0,0 +1,192 @@
# YOLO + SAM2 Video Processing Pipeline

## Overview

This project provides an automated video processing pipeline that uses YOLO for human detection and SAM2 for precise segmentation to create green screen videos. The system processes long videos by splitting them into manageable segments, detecting and tracking humans in each segment, and then reassembling the processed segments into a final output video with preserved audio.

## Core Functionality

### Input
- **Long video file** (MP4 format, any duration)
- **Configuration file** (YAML format) specifying processing parameters

### Output
- **Processed video file** with humans visible and background replaced with green screen
- **Preserved audio** from the original input video
- **Intermediate files** for debugging and quality control

## Processing Pipeline

### 1. Video Segmentation
- Splits input video into configurable-duration segments (default: 5 seconds)
- Creates organized directory structure: `segment_0/`, `segment_1/`, etc.
- Each segment folder contains the segment video file
- Forces keyframes at segment boundaries for consistent encoding

### 2. Human Detection & Tracking
- **YOLO Detection**: Automatically detects humans in keyframe segments using YOLOv8
- **SAM2 Segmentation**: Uses detected bounding boxes as prompts for precise mask generation
- **Mask Propagation**: Propagates masks across all frames in each segment
- **Stereo Video Support**: Handles VR/stereo content with left/right human assignment
- **Continuity**: Non-keyframe segments use previous segment masks for consistency

### 3. Green Screen Processing
- **Mask Application**: Applies generated masks to isolate humans (see the sketch after this section)
- **Background Replacement**: Replaces non-human areas with green screen (RGB: 0,255,0)
- **GPU Acceleration**: Uses CuPy for fast mask processing
- **Multi-resolution**: Low-res inference for speed, full-res final rendering

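The mask-application step belongs to `core/mask_processor.py`, which this commit does not yet include; as a rough illustration (an assumption about the eventual implementation, not shipped code), the core compositing operation reduces to:

```python
import numpy as np

def apply_green_screen(frame: np.ndarray, human_mask: np.ndarray) -> np.ndarray:
    """frame: HxWx3 uint8 (OpenCV BGR); human_mask: HxW bool, True on humans."""
    out = np.empty_like(frame)
    out[:] = (0, 255, 0)                 # flood with green ((0,255,0) in RGB and BGR alike)
    out[human_mask] = frame[human_mask]  # keep only the human pixels from the source
    return out
```
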
### 4. Video Assembly
|
||||||
|
- **Segment Concatenation**: Combines all processed segments into single video
|
||||||
|
- **Audio Preservation**: Copies original audio track to final output
|
||||||
|
- **Quality Maintenance**: Preserves original video quality and framerate
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
### Automated Processing
|
||||||
|
- **No Manual Intervention**: Fully automated human detection eliminates manual point selection
|
||||||
|
- **Batch Processing**: Processes multiple segments efficiently
|
||||||
|
- **Smart Fallback**: Robust mask propagation with intelligent previous-segment loading
|
||||||
|
|
||||||
|
### Modular Architecture
|
||||||
|
- **Configuration-Driven**: YAML-based configuration for easy parameter adjustment
|
||||||
|
- **Extensible Design**: Modular structure allows for easy feature additions
|
||||||
|
- **Error Recovery**: Graceful handling of detection failures and missing segments
|
||||||
|
|
||||||
|
### Performance Optimizations
|
||||||
|
- **GPU Acceleration**: CUDA/NVENC support for faster processing
|
||||||
|
- **Memory Management**: Efficient handling of large videos through segmentation
|
||||||
|
- **Concurrent Processing**: Thread-safe operations where applicable
|
||||||
|
|
||||||
|
## Technical Stack
|
||||||
|
|
||||||
|
### Core Dependencies
|
||||||
|
- **SAM2**: Facebook's Segment Anything Model 2 for precise segmentation
|
||||||
|
- **YOLOv8 (Ultralytics)**: Human detection and bounding box generation
|
||||||
|
- **OpenCV**: Video processing and frame manipulation
|
||||||
|
- **CuPy**: GPU-accelerated array operations
|
||||||
|
- **FFmpeg**: Video encoding/decoding and audio handling
|
||||||
|
- **PyTorch**: Deep learning framework backend
|
||||||
|
|
||||||
|
### Supported Formats
|
||||||
|
- **Input Video**: MP4, AVI, MOV (any OpenCV-supported format)
|
||||||
|
- **Output Video**: MP4 with H.265/HEVC encoding
|
||||||
|
- **Audio**: Preserves original audio codec and quality
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Video Processing
|
||||||
|
- `segment_duration`: Duration of each video segment (seconds)
|
||||||
|
- `inference_scale`: Scale factor for SAM2 inference (for speed)
|
||||||
|
- `output_scale`: Scale factor for final output
|
||||||
|
|
||||||
|
### Detection Parameters
|
||||||
|
- `yolo_model`: Path to YOLO model weights
|
||||||
|
- `yolo_confidence`: Detection confidence threshold
|
||||||
|
- `detect_segments`: Which segments to run YOLO detection on
|
||||||
|
|
||||||
|
### SAM2 Parameters
|
||||||
|
- `sam2_checkpoint`: Path to SAM2 model weights
|
||||||
|
- `sam2_config`: SAM2 model configuration file
|
||||||
|
|
||||||
|
### Output Options
|
||||||
|
- `use_nvenc`: Enable NVIDIA hardware encoding
|
||||||
|
- `output_bitrate`: Video bitrate for final output
|
||||||
|
- `preserve_audio`: Whether to copy audio track
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
new_yolo/
|
||||||
|
├── spec.md # This specification document
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
├── config.yaml # Default configuration file
|
||||||
|
├── main.py # Entry point script
|
||||||
|
├── core/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── video_splitter.py # Video segmentation logic
|
||||||
|
│ ├── yolo_detector.py # YOLO human detection
|
||||||
|
│ ├── sam2_processor.py # SAM2 segmentation
|
||||||
|
│ ├── mask_processor.py # Mask application and green screen
|
||||||
|
│ ├── video_assembler.py # Final video assembly
|
||||||
|
│ └── config_loader.py # Configuration management
|
||||||
|
├── utils/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── file_utils.py # File system operations
|
||||||
|
│ ├── video_utils.py # Video processing utilities
|
||||||
|
│ └── logging_utils.py # Logging configuration
|
||||||
|
└── examples/
|
||||||
|
├── basic_config.yaml # Example configuration
|
||||||
|
└── advanced_config.yaml # Advanced configuration options
|
||||||
|
```

## Usage Examples

### Basic Usage
```bash
python main.py --config config.yaml
```

### Custom Configuration
```bash
python main.py --config examples/advanced_config.yaml
```

### Configuration File Example
```yaml
input:
  video_path: "/path/to/input/video.mp4"

output:
  directory: "/path/to/output/"
  filename: "processed_video.mp4"

processing:
  segment_duration: 5
  inference_scale: 0.5
  yolo_confidence: 0.6
  detect_segments: "all"  # or [0, 5, 10]

models:
  yolo_model: "yolov8n.pt"
  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
```
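
Since the file is plain YAML, it can also be read directly when scripting around the pipeline. A minimal sketch with PyYAML (the pipeline's own loader in `core/config_loader.py` adds validation on top of this):

```python
import yaml

# Keys mirror the example configuration above.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

video_path = cfg["input"]["video_path"]
segment_duration = cfg["processing"]["segment_duration"]
print(f"Splitting {video_path} into {segment_duration}s segments")
```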

## Use Cases

### Content Creation
- **VR/360 Video Processing**: Remove backgrounds from immersive content
- **Green Screen Production**: Automated background removal for video production
- **Social Media Content**: Quick background replacement for content creators

### Commercial Applications
- **Video Conferencing**: Real-time background replacement
- **E-learning**: Professional video production with clean backgrounds
- **Marketing**: Product demonstration videos with custom backgrounds

## Performance Considerations

### Hardware Requirements
- **GPU**: NVIDIA GPU with CUDA support (recommended)
- **RAM**: 16 GB+ for processing large videos
- **Storage**: SSD recommended for temporary file operations

### Processing Time
- Approximately **1-2x real-time** on modern GPUs
- Scales with video resolution and segment count
- Memory usage stays roughly constant regardless of input video length, since segments are processed one at a time

## Future Enhancements

### Planned Features
- **Multi-object Tracking**: Support for multiple humans per frame
- **Custom Object Detection**: Configurable object classes beyond humans
- **Real-time Processing**: Live video stream support
- **Cloud Integration**: AWS/GCP processing support
- **Web Interface**: Browser-based configuration and monitoring

### Model Improvements
- **Fine-tuned YOLO**: Domain-specific human detection models
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
- **Temporal Consistency**: Enhanced cross-segment mask propagation
1
utils/__init__.py
Normal file
@@ -0,0 +1 @@
# Utility modules for the YOLO + SAM2 processing pipeline
168
utils/file_utils.py
Normal file
@@ -0,0 +1,168 @@
"""
|
||||||
|
File system utilities for the YOLO + SAM2 video processing pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import glob
|
||||||
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def ensure_directory(path: str) -> str:
|
||||||
|
"""
|
||||||
|
Ensure directory exists, create if it doesn't.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory path to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created directory path
|
||||||
|
"""
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
logger.debug(f"Ensured directory exists: {path}")
|
||||||
|
return path
|
||||||
|
|
||||||
|
def cleanup_directory(path: str, pattern: str = "*") -> int:
    """
    Clean up files matching pattern in directory.

    Args:
        path: Directory path to clean
        pattern: File pattern to match (default: all files)

    Returns:
        Number of files removed
    """
    if not os.path.exists(path):
        return 0

    files_to_remove = glob.glob(os.path.join(path, pattern))
    removed_count = 0

    for file_path in files_to_remove:
        try:
            if os.path.isfile(file_path):
                os.remove(file_path)
                removed_count += 1
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
                removed_count += 1
        except OSError as e:
            logger.warning(f"Failed to remove {file_path}: {e}")

    if removed_count > 0:
        logger.info(f"Cleaned up {removed_count} files/directories from {path}")

    return removed_count

def get_segments_directories(base_dir: str) -> List[str]:
    """
    Get list of segment directories sorted by segment number.

    Args:
        base_dir: Base directory containing segments

    Returns:
        Sorted list of segment directory names
    """
    if not os.path.exists(base_dir):
        return []

    segments = [d for d in os.listdir(base_dir)
                if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]

    # Sort numerically by segment index ("segment_10" sorts after "segment_2")
    segments.sort(key=lambda x: int(x.split("_")[1]))

    logger.debug(f"Found {len(segments)} segment directories in {base_dir}")
    return segments

def get_video_file_name(segment_index: int) -> str:
    """
    Get standardized video filename for a segment.

    Args:
        segment_index: Index of the segment

    Returns:
        Formatted filename (e.g. "segment_003.mp4")
    """
    return f"segment_{str(segment_index).zfill(3)}.mp4"


def file_exists(file_path: str) -> bool:
    """
    Check if file exists and is readable.

    Args:
        file_path: Path to file

    Returns:
        True if file exists and is readable
    """
    return os.path.isfile(file_path) and os.access(file_path, os.R_OK)

def create_file_list(segments_dir: str, output_path: str) -> str:
    """
    Create an ffmpeg-compatible file list for concatenation.

    Args:
        segments_dir: Directory containing segment subdirectories
        output_path: Path to write the file list

    Returns:
        Path to the created file list
    """
    segments = get_segments_directories(segments_dir)

    with open(output_path, 'w') as f:
        for i, segment in enumerate(segments):
            segment_dir = os.path.join(segments_dir, segment)
            output_video = os.path.join(segment_dir, f"output_{i}.mp4")

            if file_exists(output_video):
                # Use a path relative to the list file, since ffmpeg resolves
                # concat entries against the list's own directory.
                relative_path = os.path.relpath(output_video, os.path.dirname(output_path))
                f.write(f"file '{relative_path}'\n")

    logger.info(f"Created file list at {output_path}")
    return output_path
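
# Illustrative usage (hypothetical paths): the list written above is intended
# for ffmpeg's concat demuxer, e.g.
#   ffmpeg -f concat -safe 0 -i file_list.txt -c copy final.mp4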

def safe_remove_file(file_path: str) -> bool:
    """
    Safely remove a file with error handling.

    Args:
        file_path: Path to file to remove

    Returns:
        True if file was removed successfully
    """
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            logger.debug(f"Removed file: {file_path}")
            return True
        return False
    except OSError as e:
        logger.warning(f"Failed to remove {file_path}: {e}")
        return False

def get_file_size_mb(file_path: str) -> float:
    """
    Get file size in megabytes.

    Args:
        file_path: Path to file

    Returns:
        File size in MB, or 0 if file doesn't exist
    """
    try:
        if os.path.exists(file_path):
            size_bytes = os.path.getsize(file_path)
            return size_bytes / (1024 * 1024)
        return 0.0
    except OSError:
        return 0.0
52
utils/logging_utils.py
Normal file
@@ -0,0 +1,52 @@
"""
|
||||||
|
Logging utilities for the YOLO + SAM2 video processing pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def setup_logging(level: str = "INFO", log_file: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Setup logging configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||||
|
log_file: Optional log file path
|
||||||
|
"""
|
||||||
|
# Convert string level to logging constant
|
||||||
|
numeric_level = getattr(logging, level.upper(), None)
|
||||||
|
if not isinstance(numeric_level, int):
|
||||||
|
raise ValueError(f'Invalid log level: {level}')
|
||||||
|
|
||||||
|
# Create formatter
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup console handler
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(numeric_level)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# Setup root logger
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.setLevel(numeric_level)
|
||||||
|
root_logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# Setup file handler if specified
|
||||||
|
if log_file:
|
||||||
|
file_handler = logging.FileHandler(log_file)
|
||||||
|
file_handler.setLevel(numeric_level)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
root_logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# Reduce noise from some libraries
|
||||||
|
logging.getLogger('ultralytics').setLevel(logging.WARNING)
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
logging.info(f"Logging setup complete - Level: {level}")
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
"""Get a logger instance with the given name."""
|
||||||
|
return logging.getLogger(name)
|
||||||
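
# Illustrative setup at program start (hypothetical call site):
#   setup_logging("INFO", log_file="pipeline.log")
#   logger = get_logger(__name__)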
198
utils/status_utils.py
Normal file
@@ -0,0 +1,198 @@
"""
|
||||||
|
Status utilities for tracking processing progress and resume capability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def get_processing_status(segments_dir: str) -> Dict[str, Any]:
    """
    Get detailed processing status for all segments.

    Args:
        segments_dir: Directory containing video segments

    Returns:
        Dictionary with processing status information
    """
    if not os.path.exists(segments_dir):
        # Include completion_percentage here too, so callers such as
        # print_processing_status never hit a missing key.
        return {
            'total_segments': 0,
            'segments_split': 0,
            'yolo_completed': 0,
            'sam2_completed': 0,
            'can_resume': False,
            'next_step': 'split_video',
            'completion_percentage': 0.0
        }

    # Find all segment directories
    segments = []
    for item in os.listdir(segments_dir):
        item_path = os.path.join(segments_dir, item)
        if os.path.isdir(item_path) and item.startswith("segment_"):
            segments.append(item)

    segments.sort(key=lambda x: int(x.split("_")[1]))

    # Check status of each segment
    segments_split = 0
    yolo_completed = 0
    sam2_completed = 0

    for segment in segments:
        segment_path = os.path.join(segments_dir, segment)
        segment_idx = int(segment.split("_")[1])

        # Check if segment video exists
        video_file = os.path.join(segment_path, f"segment_{str(segment_idx).zfill(3)}.mp4")
        if os.path.exists(video_file):
            segments_split += 1

        # Check if YOLO detection completed
        yolo_file = os.path.join(segment_path, "yolo_detections")
        if os.path.exists(yolo_file):
            yolo_completed += 1

        # Check if SAM2 processing completed
        done_file = os.path.join(segment_path, "output_frames_done")
        if os.path.exists(done_file):
            sam2_completed += 1

    # Determine next step, checking the earliest pipeline stage first
    # (splitting precedes detection, which precedes segmentation)
    next_step = "complete"
    if segments_split < len(segments):
        next_step = "split_video"
    elif yolo_completed < len(segments):
        next_step = "yolo_detection"
    elif sam2_completed < len(segments):
        next_step = "sam2_processing"

    return {
        'total_segments': len(segments),
        'segments_split': segments_split,
        'yolo_completed': yolo_completed,
        'sam2_completed': sam2_completed,
        'can_resume': segments_split > 0,
        'next_step': next_step,
        'completion_percentage': (sam2_completed / len(segments) * 100) if segments else 0
    }
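
# Example of the returned status for a partially processed run
# (illustrative values; next_step follows the stage order above):
#   {'total_segments': 12, 'segments_split': 12, 'yolo_completed': 8,
#    'sam2_completed': 5, 'can_resume': True, 'next_step': 'yolo_detection',
#    'completion_percentage': 41.7}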

def print_processing_status(segments_dir: str):
    """
    Print a human-readable processing status report.

    Args:
        segments_dir: Directory containing video segments
    """
    status = get_processing_status(segments_dir)

    print("\n" + "=" * 50)
    print("PROCESSING STATUS REPORT")
    print("=" * 50)
    print(f"Total Segments: {status['total_segments']}")
    print(f"Video Splitting: {status['segments_split']}/{status['total_segments']} completed")
    print(f"YOLO Detection: {status['yolo_completed']}/{status['total_segments']} completed")
    print(f"SAM2 Processing: {status['sam2_completed']}/{status['total_segments']} completed")
    print(f"Overall Progress: {status['completion_percentage']:.1f}%")
    print(f"Next Step: {status['next_step']}")
    print(f"Can Resume: {'Yes' if status['can_resume'] else 'No'}")
    print("=" * 50 + "\n")

def get_incomplete_segments(segments_dir: str) -> List[Tuple[int, str]]:
    """
    Get list of segments that still need processing.

    Args:
        segments_dir: Directory containing video segments

    Returns:
        List of tuples (segment_index, reason)
    """
    incomplete = []

    if not os.path.exists(segments_dir):
        return incomplete

    segments = []
    for item in os.listdir(segments_dir):
        item_path = os.path.join(segments_dir, item)
        if os.path.isdir(item_path) and item.startswith("segment_"):
            segments.append(item)

    segments.sort(key=lambda x: int(x.split("_")[1]))

    for segment in segments:
        segment_path = os.path.join(segments_dir, segment)
        segment_idx = int(segment.split("_")[1])

        # Check SAM2 completion first (final step)
        done_file = os.path.join(segment_path, "output_frames_done")
        if not os.path.exists(done_file):
            # Work out which earlier step is missing
            yolo_file = os.path.join(segment_path, "yolo_detections")
            video_file = os.path.join(segment_path, f"segment_{str(segment_idx).zfill(3)}.mp4")

            if not os.path.exists(video_file):
                incomplete.append((segment_idx, "video_splitting"))
            elif not os.path.exists(yolo_file):
                incomplete.append((segment_idx, "yolo_detection"))
            else:
                incomplete.append((segment_idx, "sam2_processing"))

    return incomplete
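
# Illustrative resume loop (hypothetical driver code):
#   for idx, reason in get_incomplete_segments("/path/to/segments"):
#       print(f"Segment {idx} still needs {reason}")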

def cleanup_incomplete_segment(segment_dir: str) -> bool:
    """
    Clean up a partially processed segment for restart.

    Args:
        segment_dir: Path to segment directory

    Returns:
        True if cleanup was successful
    """
    try:
        # Remove temporary files that might cause issues on restart
        temp_files = [
            "low_res_video.mp4",
            "output_frames_done"
        ]

        removed_count = 0
        for temp_file in temp_files:
            temp_path = os.path.join(segment_dir, temp_file)
            if os.path.exists(temp_path):
                os.remove(temp_path)
                removed_count += 1

        if removed_count > 0:
            logger.info(f"Cleaned up {removed_count} temporary files from {segment_dir}")

        return True

    except Exception as e:
        logger.error(f"Failed to cleanup segment {segment_dir}: {e}")
        return False

def find_last_valid_mask(segments_dir: str, before_segment: int) -> str:
    """
    Find the most recent segment with a valid mask file.

    Args:
        segments_dir: Directory containing segments
        before_segment: Look for masks before this segment index

    Returns:
        Path to the segment directory holding the most recent valid mask,
        or an empty string if none is found
    """
    # Walk backwards from the previous segment; note this assumes
    # unpadded directory names (segment_0, segment_1, ...)
    for i in range(before_segment - 1, -1, -1):
        segment_path = os.path.join(segments_dir, f"segment_{i}")
        mask_path = os.path.join(segment_path, "mask.png")

        if os.path.exists(mask_path):
            return segment_path

    return ""