# YOLO + SAM2 Video Processing Pipeline
## Overview
This project provides an automated video processing pipeline that uses YOLO for human detection and SAM2 for precise segmentation to create green screen videos. The system processes long videos by splitting them into manageable segments, detecting and tracking humans in each segment, and then reassembling the processed segments into a final output video with preserved audio.
## Core Functionality
### Input
- **Long video file** (MP4 format, any duration)
- **Configuration file** (YAML format) specifying processing parameters
### Output
- **Processed video file** with humans visible and background replaced with green screen
- **Preserved audio** from the original input video
- **Intermediate files** for debugging and quality control
## Processing Pipeline
### 1. Video Segmentation
- Splits input video into configurable-duration segments (default: 5 seconds)
- Creates organized directory structure: `segment_0/`, `segment_1/`, etc.
- Each segment folder contains the segment video file
- Generates force keyframes for consistent encoding
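A minimal sketch of this splitting step, assuming FFmpeg is available on `PATH` and using the default 5-second `segment_duration`; the helper name, flags, and the per-segment folder layout shown are illustrative rather than the project's actual implementation:

```python
# Hypothetical sketch of the splitting step (not the repo's video_splitter module).
import os
import shutil
import subprocess

def split_video(input_path: str, base_dir: str, segment_duration: int = 5) -> None:
    os.makedirs(base_dir, exist_ok=True)
    pattern = os.path.join(base_dir, "segment_%03d.mp4")
    subprocess.run([
        "ffmpeg", "-y", "-i", input_path,
        # Force a keyframe at every segment boundary so cuts land on clean frames.
        "-force_key_frames", f"expr:gte(t,n_forced*{segment_duration})",
        "-f", "segment", "-segment_time", str(segment_duration),
        "-reset_timestamps", "1",
        pattern,
    ], check=True)
    # Move each segment file into its own segment_N/ folder, as the pipeline expects.
    for name in sorted(os.listdir(base_dir)):
        if name.startswith("segment_") and name.endswith(".mp4"):
            index = int(name.split("_")[1].split(".")[0])
            seg_dir = os.path.join(base_dir, f"segment_{index}")
            os.makedirs(seg_dir, exist_ok=True)
            shutil.move(os.path.join(base_dir, name), os.path.join(seg_dir, name))
```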
### 2. Human Detection & Tracking
- **YOLO Detection**: Automatically detects humans in keyframe segments using YOLOv8
- **SAM2 Segmentation**: Uses detected bounding boxes as prompts for precise mask generation
- **Mask Propagation**: Propagates masks across all frames in each segment
- **Stereo Video Support**: Handles VR/stereo content with left/right human assignment
- **Continuity**: Non-keyframe segments use previous segment masks for consistency
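A condensed sketch of the detection-to-segmentation handoff, adapted from the reference script appended at the end of this document (the stereo left/right object assignment and the previous-mask fallback are omitted for brevity; the helper name is illustrative):

```python
# Sketch: seed SAM2 with YOLO person boxes, then propagate masks through one segment.
import cv2
import torch
from ultralytics import YOLO
from sam2.build_sam import build_sam2_video_predictor

def seed_segment_with_yolo(segment_video, yolo_weights="yolov8n.pt",
                           sam2_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
                           sam2_ckpt="../checkpoints/sam2.1_hiera_large.pt"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = build_sam2_video_predictor(sam2_cfg, sam2_ckpt, device=device)
    state = predictor.init_state(video_path=segment_video)

    # Detect people (COCO class 0) on the first frame and use each box as a SAM2 prompt.
    cap = cv2.VideoCapture(segment_video)
    ok, first_frame = cap.read()
    cap.release()
    if not ok:
        raise RuntimeError(f"Could not read first frame of {segment_video}")
    detections = YOLO(yolo_weights)(first_frame, conf=0.6, classes=[0], verbose=False)[0]

    for obj_id, box in enumerate(detections.boxes.xyxy.cpu().numpy(), start=1):
        predictor.add_new_points_or_box(inference_state=state, frame_idx=0,
                                        obj_id=obj_id, box=box)

    # Propagate the seeded masks across every frame of the segment.
    return {
        frame_idx: {oid: (logits[i] > 0.0).cpu().numpy() for i, oid in enumerate(obj_ids)}
        for frame_idx, obj_ids, logits in predictor.propagate_in_video(state)
    }
```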
### 3. Green Screen Processing
- **Mask Application**: Applies generated masks to isolate humans
- **Background Replacement**: Replaces non-human areas with green screen (RGB: 0,255,0)
- **GPU Acceleration**: Uses CuPy for fast mask processing
- **Multi-resolution**: Low-res inference for speed, full-res final rendering
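A minimal CPU-only (NumPy) sketch of the compositing step; the reference script's `apply_green_mask` performs the same operation on the GPU with CuPy and also upscales low-resolution masks to the frame size first:

```python
import numpy as np

def composite_green_screen(frame: np.ndarray, masks: list) -> np.ndarray:
    """Keep pixels covered by any human mask; paint everything else green."""
    keep = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        keep |= mask.astype(bool)
    out = np.empty_like(frame)
    out[:] = (0, 255, 0)       # green background (BGR order for OpenCV frames)
    out[keep] = frame[keep]    # copy the original pixels wherever a human was masked
    return out
```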
### 4. Video Assembly
- **Segment Concatenation**: Combines all processed segments into single video
- **Audio Preservation**: Copies original audio track to final output
- **Quality Maintenance**: Preserves original video quality and framerate
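A hedged sketch of the assembly step, assuming FFmpeg is on `PATH` and that each segment directory already contains the processed `output_<N>.mp4` produced by the pipeline; the helper name and the list-file layout are illustrative:

```python
# Hypothetical sketch of segment concatenation plus audio copy from the original input.
import os
import subprocess

def assemble(base_dir: str, original_video: str, final_output: str) -> None:
    # 1. List the processed segment files in order for ffmpeg's concat demuxer.
    list_path = os.path.join(base_dir, "segments.txt")
    seg_dirs = sorted(
        (d for d in os.listdir(base_dir) if d.startswith("segment_")),
        key=lambda d: int(d.split("_")[1]),
    )
    with open(list_path, "w") as f:
        for d in seg_dirs:
            index = d.split("_")[1]
            seg_file = os.path.join(base_dir, d, f"output_{index}.mp4")
            f.write(f"file '{seg_file}'\n")

    # 2. Concatenate video streams without re-encoding, then copy the original audio track.
    subprocess.run([
        "ffmpeg", "-y",
        "-f", "concat", "-safe", "0", "-i", list_path,   # concatenated green-screen video
        "-i", original_video,                            # original file, used only for audio
        "-map", "0:v:0", "-map", "1:a:0?",               # video from concat, audio from source
        "-c", "copy",
        final_output,
    ], check=True)
```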
## Key Features
### Automated Processing
- **No Manual Intervention**: Fully automated human detection eliminates manual point selection
- **Batch Processing**: Processes multiple segments efficiently
- **Smart Fallback**: Robust mask propagation with intelligent previous-segment loading
### Modular Architecture
- **Configuration-Driven**: YAML-based configuration for easy parameter adjustment
- **Extensible Design**: Modular structure allows for easy feature additions
- **Error Recovery**: Graceful handling of detection failures and missing segments
### Performance Optimizations
- **GPU Acceleration**: CUDA/NVENC support for faster processing
- **Memory Management**: Efficient handling of large videos through segmentation
- **Concurrent Processing**: Thread-safe operations where applicable
## Technical Stack
### Core Dependencies
- **SAM2**: Facebook's Segment Anything Model 2 for precise segmentation
- **YOLOv8 (Ultralytics)**: Human detection and bounding box generation
- **OpenCV**: Video processing and frame manipulation
- **CuPy**: GPU-accelerated array operations
- **FFmpeg**: Video encoding/decoding and audio handling
- **PyTorch**: Deep learning framework backend
### Supported Formats
- **Input Video**: MP4, AVI, MOV (any OpenCV-supported format)
- **Output Video**: MP4 with H.265/HEVC encoding
- **Audio**: Preserves original audio codec and quality
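Because input handling goes through OpenCV, a quick compatibility check is simply whether `cv2.VideoCapture` can open the file; this small probe (illustrative helper name) is one way to verify that and read the basic stream properties:

```python
import cv2

def probe_video(path: str) -> dict:
    """Return basic properties of a video, or raise if OpenCV cannot open it."""
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"OpenCV could not open {path}")
    info = {
        "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        "fps": cap.get(cv2.CAP_PROP_FPS),
        "frames": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
    }
    cap.release()
    return info
```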
## Configuration Options
### Video Processing
- `segment_duration`: Duration of each video segment (seconds)
- `inference_scale`: Scale factor for SAM2 inference (for speed)
- `output_scale`: Scale factor for final output
### Detection Parameters
- `yolo_model`: Path to YOLO model weights
- `yolo_confidence`: Detection confidence threshold
- `detect_segments`: Which segments to run YOLO detection on
### SAM2 Parameters
- `sam2_checkpoint`: Path to SAM2 model weights
- `sam2_config`: SAM2 model configuration file
### Output Options
- `use_nvenc`: Enable NVIDIA hardware encoding
- `output_bitrate`: Video bitrate for final output
- `preserve_audio`: Whether to copy audio track
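An illustrative sketch of how a configuration loader might merge these options with defaults (key names follow the example configuration later in this document; the default values shown are assumptions, except the 0.5 inference scale, 0.6 YOLO confidence, and 50M bitrate, which match the reference script):

```python
# Hypothetical config loader sketch; not necessarily core/config_loader.py's actual API.
import yaml

DEFAULTS = {
    "processing": {"segment_duration": 5, "inference_scale": 0.5,
                   "output_scale": 1.0, "yolo_confidence": 0.6,
                   "detect_segments": "all"},
    "output": {"use_nvenc": True, "output_bitrate": "50M", "preserve_audio": True},
}

def load_config(path: str) -> dict:
    with open(path) as f:
        user_cfg = yaml.safe_load(f) or {}
    # Shallow-merge user values over defaults, section by section.
    cfg = {section: {**values, **user_cfg.get(section, {})}
           for section, values in DEFAULTS.items()}
    # Sections without defaults (input, models) are passed through as-is.
    for section in ("input", "models"):
        if section in user_cfg:
            cfg[section] = user_cfg[section]
    return cfg
```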
## Directory Structure
```
new_yolo/
├── spec.md # This specification document
├── requirements.txt # Python dependencies
├── config.yaml # Default configuration file
├── main.py # Entry point script
├── core/
│ ├── __init__.py
│ ├── video_splitter.py # Video segmentation logic
│ ├── yolo_detector.py # YOLO human detection
│ ├── sam2_processor.py # SAM2 segmentation
│ ├── mask_processor.py # Mask application and green screen
│ ├── video_assembler.py # Final video assembly
│ └── config_loader.py # Configuration management
├── utils/
│ ├── __init__.py
│ ├── file_utils.py # File system operations
│ ├── video_utils.py # Video processing utilities
│ └── logging_utils.py # Logging configuration
└── examples/
├── basic_config.yaml # Example configuration
└── advanced_config.yaml # Advanced configuration options
```
## Usage Examples
### Basic Usage
```bash
python main.py --config config.yaml
```
### Custom Configuration
```bash
python main.py --config examples/advanced_config.yaml
```
### Configuration File Example
```yaml
input:
  video_path: "/path/to/input/video.mp4"
output:
  directory: "/path/to/output/"
  filename: "processed_video.mp4"
processing:
  segment_duration: 5
  inference_scale: 0.5
  yolo_confidence: 0.6
  detect_segments: "all"  # or [0, 5, 10]
models:
  yolo_model: "yolov8n.pt"
  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
```
## Use Cases
### Content Creation
- **VR/360 Video Processing**: Remove backgrounds from immersive content
- **Green Screen Production**: Automated background removal for video production
- **Social Media Content**: Quick background replacement for content creators
### Commercial Applications
- **Video Conferencing**: Real-time background replacement
- **E-learning**: Professional video production with clean backgrounds
- **Marketing**: Product demonstration videos with custom backgrounds
## Performance Considerations
### Hardware Requirements
- **GPU**: NVIDIA GPU with CUDA support (recommended)
- **RAM**: 16GB+ for processing large videos
- **Storage**: SSD recommended for temporary file operations
### Processing Time
- Approximately **1-2x real-time** on modern GPUs
- Scales with video resolution and segment count
- Memory usage remains constant regardless of input video length
## Future Enhancements
### Planned Features
- **Multi-object Tracking**: Support for multiple humans per frame
- **Custom Object Detection**: Configurable object classes beyond humans
- **Real-time Processing**: Live video stream support
- **Cloud Integration**: AWS/GCP processing support
- **Web Interface**: Browser-based configuration and monitoring
### Model Improvements
- **Fine-tuned YOLO**: Domain-specific human detection models
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
- **Temporal Consistency**: Enhanced cross-segment mask propagation
## Appendix: Original Monolithic Script
The following is the original monolithic script that this repository refactors into the modular structure described above. If something in this repository does not work, consult this script; it is known to work and can be used as a reference when debugging:
```python
import os
import cv2
import numpy as np
import cupy as cp
from concurrent.futures import ThreadPoolExecutor
import torch
import logging
import sys
import gc
from sam2.build_sam import build_sam2_video_predictor
import argparse
from ultralytics import YOLO

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]
INFERENCE_SCALE = 0.50
FULL_SCALE = 1.0

# YOLO model for human detection (class 0 = person)
YOLO_MODEL_PATH = "yolov8n.pt"  # You can change this to a custom model
YOLO_CONFIDENCE = 0.6
HUMAN_CLASS_ID = 0  # COCO class ID for person

def open_video(video_path):
    """
    Opens a video file and returns a generator that yields frames.
    Parameters:
    - video_path: Path to the video file.
    Returns:
    - A generator that yields frames from the video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        return
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        yield frame
    cap.release()


def load_previous_segment_mask(prev_segment_dir):
    mask_path = os.path.join(prev_segment_dir, "mask.png")
    mask_image = cv2.imread(mask_path)
    if mask_image is None:
        raise FileNotFoundError(f"Mask image not found at {mask_path}")
    # Ensure the mask_image has three color channels
    if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
        raise ValueError("Mask image does not have three color channels.")
    mask_image = mask_image.astype(np.uint8)
    # Extract Object A and Object B masks
    mask_a = np.all(mask_image == GREEN, axis=2)
    mask_b = np.all(mask_image == BLUE, axis=2)
    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # No palette needed for binary mask
    return per_obj_input_mask, input_palette


def apply_green_mask(frame, masks):
    # Convert frame and masks to CuPy arrays
    frame_gpu = cp.asarray(frame)
    combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_)
    for mask in masks:
        mask_gpu = cp.asarray(mask.squeeze())
        if mask_gpu.shape != frame_gpu.shape[:2]:
            resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32),
                                      (frame_gpu.shape[1], frame_gpu.shape[0]))
            mask_gpu = cp.asarray(resized_mask > 0.5)  # Convert back to CuPy boolean array
        else:
            mask_gpu = mask_gpu.astype(cp.bool_)  # Ensure boolean type
        combined_mask |= mask_gpu  # Perform the bitwise OR operation
    green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8)
    result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background)
    return cp.asnumpy(result_frame)  # Convert back to NumPy

def initialize_predictor():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print(
            "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
            "give numerically different outputs and sometimes degraded performance on MPS."
        )
        # Enable MPS fallback for operations not supported on MPS
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    return predictor


def load_first_frame(video_path, scale=1.0):
    """
    Opens a video file and returns the first frame, scaled as specified.
    Parameters:
    - video_path: Path to the video file.
    - scale: Scaling factor for the frame (default is 1.0 for original size).
    Returns:
    - first_frame: The first frame of the video, scaled accordingly.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Error: Could not open video file {video_path}")
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        logger.error(f"Error: Could not read frame from video file {video_path}")
        return None
    if scale != 1.0:
        frame = cv2.resize(
            frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
        )
    return frame

def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE):
    """
    Detect humans in a frame using YOLO model.
    Parameters:
    - frame: Input frame (BGR format)
    - yolo_model: Loaded YOLO model
    - confidence_threshold: Detection confidence threshold
    Returns:
    - human_boxes: List of bounding boxes for detected humans
    """
    # Run YOLO detection
    results = yolo_model(frame, conf=confidence_threshold, verbose=False)
    human_boxes = []
    # Process results
    for result in results:
        boxes = result.boxes
        if boxes is not None:
            for box in boxes:
                # Get class ID
                cls = int(box.cls.cpu().numpy()[0])
                # Check if it's a person (class 0 in COCO)
                if cls == HUMAN_CLASS_ID:
                    # Get bounding box coordinates (x1, y1, x2, y2)
                    coords = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf.cpu().numpy()[0])
                    human_boxes.append({
                        'bbox': coords,
                        'confidence': conf
                    })
                    logger.info(f"Detected human with confidence {conf:.2f} at {coords}")
    return human_boxes

def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width):
    """
    Add YOLO human detections as bounding boxes to SAM2 predictor.
    For stereo videos, creates two objects (left and right humans).
    Parameters:
    - predictor: SAM2 video predictor
    - inference_state: SAM2 inference state
    - human_detections: List of human detection results
    - frame_width: Width of the frame for stereo splitting
    Returns:
    - out_mask_logits: SAM2 output mask logits
    """
    half_frame_width = frame_width // 2
    # Sort detections by x-coordinate to get left and right humans
    human_detections.sort(key=lambda x: x['bbox'][0])  # Sort by x1 coordinate
    obj_id = 1
    out_mask_logits = None
    for i, detection in enumerate(human_detections[:2]):  # Take up to 2 humans (left and right)
        bbox = detection['bbox']
        # For stereo videos, assign obj_id based on position
        if len(human_detections) >= 2:
            # If we have multiple humans, assign based on left/right position
            center_x = (bbox[0] + bbox[2]) / 2
            if center_x < half_frame_width:
                current_obj_id = 1  # Left human
            else:
                current_obj_id = 2  # Right human
        else:
            # If only one human, duplicate for both sides (as in original stereo logic)
            current_obj_id = obj_id
            obj_id += 1
            # Also add the mirrored version for stereo
            if obj_id <= 2:
                mirrored_bbox = bbox.copy()
                mirrored_bbox[0] += half_frame_width  # Shift x1
                mirrored_bbox[2] += half_frame_width  # Shift x2
                # Ensure mirrored bbox is within frame bounds
                mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))
                try:
                    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                        inference_state=inference_state,
                        frame_idx=0,
                        obj_id=obj_id,
                        box=mirrored_bbox.astype(np.float32),
                    )
                    logger.info(f"Added mirrored human detection for Object {obj_id}")
                    obj_id += 1
                except Exception as e:
                    logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}")
        try:
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=current_obj_id,
                box=bbox.astype(np.float32),
            )
            logger.info(f"Added human detection for Object {current_obj_id}")
        except Exception as e:
            logger.error(f"Error adding human detection for Object {current_obj_id}: {e}")
    return out_mask_logits

def propagate_masks(predictor, inference_state):
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments


def apply_colored_mask(frame, masks_a, masks_b):
    colored_mask = np.zeros_like(frame)
    # Apply colors to the masks
    for mask in masks_a:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [0, 255, 0]  # Green for Object A
    for mask in masks_b:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [255, 0, 0]  # Blue for Object B
    return colored_mask

def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
    """
    Process high-resolution frames, apply upscaled masks, and save the output video.
    """
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
    # Setup VideoWriter with desired settings
    if use_nvenc:
        # Use FFmpeg with NVENC offloading for H.265 encoding
        import subprocess
        if sys.platform == 'darwin':
            encoder = 'hevc_videotoolbox'
        else:
            encoder = 'hevc_nvenc'
        command = [
            'ffmpeg',
            '-y',  # Overwrite output file if it exists
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{frame_width}x{frame_height}',
            '-r', str(fps),
            '-i', '-',  # Input from stdin
            '-an',  # No audio
            '-vcodec', encoder,
            '-pix_fmt', 'nv12',
            '-preset', 'slow',
            '-b:v', '50M',
            output_video_path
        ]
        process = subprocess.Popen(command, stdin=subprocess.PIPE)
    else:
        # Use OpenCV VideoWriter
        fourcc = cv2.VideoWriter_fourcc(*'HEVC')  # H.265
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret or frame_idx >= len(video_segments):
            break
        masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
        upscaled_masks = []
        for mask in masks:
            mask = mask.squeeze()
            upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
            upscaled_masks.append(upscaled_mask)
        result_frame = apply_green_mask(frame, upscaled_masks)
        # Write frame to output
        if use_nvenc:
            process.stdin.write(result_frame.tobytes())
        else:
            out.write(result_frame)
        frame_idx += 1
    cap.release()
    if use_nvenc:
        process.stdin.close()
        process.wait()
    else:
        out.release()


def get_video_file_name(index):
    return f"segment_{str(index).zfill(3)}.mp4"

def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH):
    """
    Run YOLO detection on specified segments and save detection results.
    """
    logger.info("Running YOLO detection on requested segments.")
    # Load YOLO model
    yolo_model = YOLO(yolo_model_path)
    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        detection_file = os.path.join(segment_dir, "yolo_detections")
        video_file = os.path.join(segment_dir, get_video_file_name(i))
        if segment_index in detect_segments and not os.path.exists(detection_file):
            first_frame = load_first_frame(video_file, scale)
            if first_frame is None:
                continue
            # Convert BGR to RGB for YOLO (YOLO expects BGR, so keep as BGR)
            human_detections = detect_humans_with_yolo(first_frame, yolo_model)
            if human_detections:
                # Save detection results
                with open(detection_file, 'w') as f:
                    f.write("# YOLO Human Detections\n")
                    for detection in human_detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                logger.info(f"Saved {len(human_detections)} human detections for segment {segment}")
            else:
                logger.warning(f"No humans detected in segment {segment}")
                # Create empty file to mark as processed
                with open(detection_file, 'w') as f:
                    f.write("# No humans detected\n")


def save_final_masks(video_segments, mask_output_path):
    """
    Save the final masks as a colored image.
    """
    last_frame_idx = max(video_segments.keys())
    masks_dict = video_segments[last_frame_idx]
    # Assuming you have two objects with IDs 1 and 2
    mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None
    mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None
    if mask_a is None and mask_b is None:
        logger.error("No masks found for objects.")
        return
    # Use the first available mask to determine dimensions
    reference_mask = mask_a if mask_a is not None else mask_b
    black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)
    if mask_a is not None:
        mask_a = mask_a.astype(bool)
        black_frame[mask_a] = GREEN
    if mask_b is not None:
        mask_b = mask_b.astype(bool)
        black_frame[mask_b] = BLUE
    # Save the mask image
    cv2.imwrite(mask_output_path, black_frame)
    logger.info(f"Saved final masks to {mask_output_path}")


def create_low_res_video(input_video_path, output_video_path, scale):
    """
    Creates a low-resolution version of the input video for inference.
    """
    cap = cv2.VideoCapture(input_video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
        out.write(low_res_frame)
    cap.release()
    out.release()

def main():
    parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.")
    parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
    parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.")
    parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.")
    parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.")
    args = parser.parse_args()
    base_dir = args.base_dir
    segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    segments.sort(key=lambda x: int(x.split("_")[1]))
    # Handle different ways to specify segments for YOLO detection
    if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0:
        # Default: run YOLO on all segments
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("No segments specified, running YOLO detection on ALL segments")
    elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all':
        # Explicit 'all' keyword
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("Running YOLO detection on ALL segments")
    else:
        # Specific segment numbers provided
        try:
            detect_segments = [int(x) for x in args.segments_detect_humans]
            logger.info(f"Running YOLO detection on segments: {detect_segments}")
        except ValueError:
            logger.error("Invalid segment numbers provided. Use integers or 'all'.")
            return
    # Run YOLO detection on specified segments
    do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model)
    # Load YOLO model for inference
    yolo_model = YOLO(args.yolo_model)
    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        video_file_name = get_video_file_name(i)
        video_path = os.path.join(segment_dir, video_file_name)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment} already processed. Skipping.")
            continue
        logger.info(f"Processing segment {segment}")
        # Initialize predictor
        predictor = initialize_predictor()
        # Prepare low-resolution video frames for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
            logger.info(f"Low-resolution video created for segment {segment}")
        else:
            logger.info(f"Low-resolution video already exists for segment {segment}, reuse")
        # Initialize inference state with low-resolution video
        inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)
        # Load YOLO detections or previous masks
        detection_file = os.path.join(segment_dir, "yolo_detections")
        use_detections = segment_index in detect_segments
        if i == 0 and not use_detections:
            # First segment must use YOLO detection since there's no previous mask
            logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.")
            use_detections = True
        if i > 0 and not use_detections:
            # Try to load previous segment mask - search backwards for the most recent successful mask
            logger.info(f"Using previous segment mask for segment {segment}")
            mask_found = False
            # Search backwards through previous segments to find a valid mask
            for j in range(i - 1, -1, -1):
                prev_segment_dir = os.path.join(base_dir, segments[j])
                prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
                if os.path.exists(prev_mask_path):
                    try:
                        per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
                        # Add previous masks to predictor
                        for obj_id, mask in per_obj_input_mask.items():
                            predictor.add_new_mask(inference_state, 0, obj_id, mask)
                        logger.info(f"Successfully loaded mask from segment {segments[j]}")
                        mask_found = True
                        break
                    except Exception as e:
                        logger.warning(f"Error loading mask from {segments[j]}: {e}")
                        continue
            if not mask_found:
                logger.error(f"No valid previous mask found for segment {segment}. Consider running YOLO detection on this segment.")
                continue
        else:
            # Load first frame for detection
            first_frame = load_first_frame(low_res_video_path, scale=1.0)
            if first_frame is None:
                logger.error(f"Could not load first frame for segment {segment}")
                continue
            # Run YOLO detection on first frame (either from file or on-the-fly)
            if os.path.exists(detection_file):
                logger.info(f"Using existing YOLO detections for segment {segment}")
            else:
                logger.info(f"Running YOLO detection on-the-fly for segment {segment}")
            human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence)
            if human_detections:
                # Add YOLO detections to predictor
                frame_width = first_frame.shape[1]
                add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width)
            else:
                logger.warning(f"No humans detected in segment {segment}")
                continue
        # Perform inference and collect masks per frame
        video_segments = propagate_masks(predictor, inference_state)
        # Process high-resolution frames and save output video
        output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
        logger.info("Processing segment complete, attempting to save full video from low res masks")
        process_and_save_output_video(
            video_path,
            output_video_path,
            video_segments,
            use_nvenc=True  # Set to True to use NVENC offloading
        )
        # Save final masks
        mask_output_path = os.path.join(segment_dir, "mask.png")
        save_final_masks(video_segments, mask_output_path)
        # Clean up
        predictor.reset_state(inference_state)
        del inference_state
        del video_segments
        del predictor
        gc.collect()
        try:
            os.remove(low_res_video_path)
            logger.info(f"Deleted low-resolution video for segment {segment}")
        except Exception as e:
            logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")
        # Mark segment as completed
        open(output_done_file, 'a').close()
    logger.info("Processing complete.")


if __name__ == "__main__":
    main()
```