stage 1 working

2025-07-27 12:11:36 -07:00
parent ed08ef2b4b
commit 46363a8a11
6 changed files with 993 additions and 51 deletions

spec.md (620 changed lines)

@@ -189,4 +189,622 @@ models:
### Model Improvements
- **Fine-tuned YOLO**: Domain-specific human detection models
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
- **Temporal Consistency**: Enhanced cross-segment mask propagation
Below is the original monolithic script that this repo is a refactor/modularization of. If something doesn't work in this repo, consult this script, because it does work and can be used to track down problems:
```python
import os
import cv2
import numpy as np
import cupy as cp
from concurrent.futures import ThreadPoolExecutor
import torch
import logging
import sys
import gc
from sam2.build_sam import build_sam2_video_predictor
import argparse
from ultralytics import YOLO
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Model paths, mask colors, scale factors, and detection settings
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]
INFERENCE_SCALE = 0.50
FULL_SCALE = 1.0
# YOLO model for human detection (class 0 = person)
YOLO_MODEL_PATH = "yolov8n.pt" # You can change this to a custom model
YOLO_CONFIDENCE = 0.6
HUMAN_CLASS_ID = 0 # COCO class ID for person
def open_video(video_path):
"""
Opens a video file and returns a generator that yields frames.
Parameters:
- video_path: Path to the video file.
Returns:
- A generator that yields frames from the video.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}")
return
while True:
ret, frame = cap.read()
if not ret:
break
yield frame
cap.release()
def load_previous_segment_mask(prev_segment_dir):
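# Reload the colored mask.png written by save_final_masks() and split it into per-object boolean masks (GREEN -> object 1, BLUE -> object 2)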
mask_path = os.path.join(prev_segment_dir, "mask.png")
mask_image = cv2.imread(mask_path)
if mask_image is None:
raise FileNotFoundError(f"Mask image not found at {mask_path}")
# Ensure the mask_image has three color channels
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
raise ValueError("Mask image does not have three color channels.")
mask_image = mask_image.astype(np.uint8)
# Extract Object A and Object B masks
mask_a = np.all(mask_image == GREEN, axis=2)
mask_b = np.all(mask_image == BLUE, axis=2)
per_obj_input_mask = {1: mask_a, 2: mask_b}
input_palette = None # No palette needed for binary mask
return per_obj_input_mask, input_palette
def apply_green_mask(frame, masks):
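# Keep original pixels wherever any mask is True; fill everything else with a flat green background (chroma-key style), using CuPy on the GPU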
# Convert frame and masks to CuPy arrays
frame_gpu = cp.asarray(frame)
combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_)
for mask in masks:
mask_gpu = cp.asarray(mask.squeeze())
if mask_gpu.shape != frame_gpu.shape[:2]:
resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32),
(frame_gpu.shape[1], frame_gpu.shape[0]))
mask_gpu = cp.asarray(resized_mask > 0.5) # Convert back to CuPy boolean array
else:
mask_gpu = mask_gpu.astype(cp.bool_) # Ensure boolean type
combined_mask |= mask_gpu # Perform the bitwise OR operation
green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8)
result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background)
return cp.asnumpy(result_frame) # Convert back to NumPy
def initialize_predictor():
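# Pick the best available device (CUDA > MPS > CPU) and build the SAM2 video predictor from MODEL_CFG and SAM2_CHECKPOINT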
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
print(
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS."
)
# Enable MPS fallback for operations not supported on MPS
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
device = torch.device("cpu")
logger.info(f"Using device: {device}")
predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
return predictor
def load_first_frame(video_path, scale=1.0):
"""
Opens a video file and returns the first frame, scaled as specified.
Parameters:
- video_path: Path to the video file.
- scale: Scaling factor for the frame (default is 1.0 for original size).
Returns:
- first_frame: The first frame of the video, scaled accordingly.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Error: Could not open video file {video_path}")
return None
ret, frame = cap.read()
cap.release()
if not ret:
logger.error(f"Error: Could not read frame from video file {video_path}")
return None
if scale != 1.0:
frame = cv2.resize(
frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
)
return frame
def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE):
"""
Detect humans in a frame using YOLO model.
Parameters:
- frame: Input frame (BGR format)
- yolo_model: Loaded YOLO model
- confidence_threshold: Detection confidence threshold
Returns:
- human_boxes: List of bounding boxes for detected humans
"""
# Run YOLO detection
results = yolo_model(frame, conf=confidence_threshold, verbose=False)
human_boxes = []
# Process results
for result in results:
boxes = result.boxes
if boxes is not None:
for box in boxes:
# Get class ID
cls = int(box.cls.cpu().numpy()[0])
# Check if it's a person (class 0 in COCO)
if cls == HUMAN_CLASS_ID:
# Get bounding box coordinates (x1, y1, x2, y2)
coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0])
human_boxes.append({
'bbox': coords,
'confidence': conf
})
logger.info(f"Detected human with confidence {conf:.2f} at {coords}")
return human_boxes
def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width):
"""
Add YOLO human detections as bounding boxes to SAM2 predictor.
For stereo videos, creates two objects (left and right humans).
Parameters:
- predictor: SAM2 video predictor
- inference_state: SAM2 inference state
- human_detections: List of human detection results
- frame_width: Width of the frame for stereo splitting
Returns:
- out_mask_logits: SAM2 output mask logits
"""
half_frame_width = frame_width // 2
# Sort detections by x-coordinate to get left and right humans
human_detections.sort(key=lambda x: x['bbox'][0]) # Sort by x1 coordinate
obj_id = 1
out_mask_logits = None
for i, detection in enumerate(human_detections[:2]): # Take up to 2 humans (left and right)
bbox = detection['bbox']
# For stereo videos, assign obj_id based on position
if len(human_detections) >= 2:
# If we have multiple humans, assign based on left/right position
center_x = (bbox[0] + bbox[2]) / 2
if center_x < half_frame_width:
current_obj_id = 1 # Left human
else:
current_obj_id = 2 # Right human
else:
# If only one human, duplicate for both sides (as in original stereo logic)
current_obj_id = obj_id
obj_id += 1
# Also add the mirrored version for stereo
if obj_id <= 2:
mirrored_bbox = bbox.copy()
mirrored_bbox[0] += half_frame_width # Shift x1
mirrored_bbox[2] += half_frame_width # Shift x2
# Ensure mirrored bbox is within frame bounds
mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))
try:
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
box=mirrored_bbox.astype(np.float32),
)
logger.info(f"Added mirrored human detection for Object {obj_id}")
obj_id += 1
except Exception as e:
logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}")
try:
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=current_obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"Added human detection for Object {current_obj_id}")
except Exception as e:
logger.error(f"Error adding human detection for Object {current_obj_id}: {e}")
return out_mask_logits
def propagate_masks(predictor, inference_state):
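# Propagate the seeded objects through every frame; logits > 0.0 become boolean masks keyed by frame index and object id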
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
return video_segments
def apply_colored_mask(frame, masks_a, masks_b):
colored_mask = np.zeros_like(frame)
# Paint Object A masks green and Object B masks blue onto a black canvas
for mask in masks_a:
mask = mask.squeeze().astype(np.uint8)  # cast bool masks to uint8 so cv2.resize accepts them
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
colored_mask[mask.astype(bool)] = [0, 255, 0]  # Green for Object A
for mask in masks_b:
mask = mask.squeeze().astype(np.uint8)
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
colored_mask[mask.astype(bool)] = [255, 0, 0]  # Blue for Object B
return colored_mask
def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
"""
Process high-resolution frames, apply upscaled masks, and save the output video.
"""
cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
# Setup VideoWriter with desired settings
if use_nvenc:
# Use FFmpeg with NVENC offloading for H.265 encoding
import subprocess
if sys.platform == 'darwin':
encoder = 'hevc_videotoolbox'
else:
encoder = 'hevc_nvenc'
command = [
'ffmpeg',
'-y', # Overwrite output file if it exists
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-pix_fmt', 'bgr24',
'-s', f'{frame_width}x{frame_height}',
'-r', str(fps),
'-i', '-', # Input from stdin
'-an', # No audio
'-vcodec', encoder,
'-pix_fmt', 'nv12',
'-preset', 'slow',
'-b:v', '50M',
output_video_path
]
process = subprocess.Popen(command, stdin=subprocess.PIPE)
else:
# Use OpenCV VideoWriter
fourcc = cv2.VideoWriter_fourcc(*'HEVC') # H.265
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret or frame_idx >= len(video_segments):
break
masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
upscaled_masks.append(upscaled_mask)
result_frame = apply_green_mask(frame, upscaled_masks)
# Write frame to output
if use_nvenc:
process.stdin.write(result_frame.tobytes())
else:
out.write(result_frame)
frame_idx += 1
cap.release()
if use_nvenc:
process.stdin.close()
process.wait()
else:
out.release()
def get_video_file_name(index):
return f"segment_{str(index).zfill(3)}.mp4"
def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH):
"""
Run YOLO detection on specified segments and save detection results.
"""
logger.info("Running YOLO detection on requested segments.")
# Load YOLO model
yolo_model = YOLO(yolo_model_path)
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
detection_file = os.path.join(segment_dir, "yolo_detections")
video_file = os.path.join(segment_dir, get_video_file_name(i))
if segment_index in detect_segments and not os.path.exists(detection_file):
first_frame = load_first_frame(video_file, scale)
if first_frame is None:
continue
# Ultralytics YOLO accepts BGR frames from OpenCV directly, so no color conversion is needed
human_detections = detect_humans_with_yolo(first_frame, yolo_model)
if human_detections:
# Save detection results
with open(detection_file, 'w') as f:
f.write("# YOLO Human Detections\n")
for detection in human_detections:
bbox = detection['bbox']
conf = detection['confidence']
f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
logger.info(f"Saved {len(human_detections)} human detections for segment {segment}")
else:
logger.warning(f"No humans detected in segment {segment}")
# Create empty file to mark as processed
with open(detection_file, 'w') as f:
f.write("# No humans detected\n")
def save_final_masks(video_segments, mask_output_path):
"""
Save the final masks as a colored image.
"""
last_frame_idx = max(video_segments.keys())
masks_dict = video_segments[last_frame_idx]
# Assuming you have two objects with IDs 1 and 2
mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None
mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None
if mask_a is None and mask_b is None:
logger.error("No masks found for objects.")
return
# Use the first available mask to determine dimensions
reference_mask = mask_a if mask_a is not None else mask_b
black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)
if mask_a is not None:
mask_a = mask_a.astype(bool)
black_frame[mask_a] = GREEN
if mask_b is not None:
mask_b = mask_b.astype(bool)
black_frame[mask_b] = BLUE
# Save the mask image
cv2.imwrite(mask_output_path, black_frame)
logger.info(f"Saved final masks to {mask_output_path}")
def create_low_res_video(input_video_path, output_video_path, scale):
"""
Creates a low-resolution version of the input video for inference.
"""
cap = cv2.VideoCapture(input_video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
while True:
ret, frame = cap.read()
if not ret:
break
low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
out.write(low_res_frame)
cap.release()
out.release()
def main():
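# Per-segment pipeline: seed SAM2 with YOLO boxes or the previous segment's mask, propagate masks, export the keyed video and mask.png, then clean up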
parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.")
parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.")
parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.")
parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.")
args = parser.parse_args()
base_dir = args.base_dir
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1]))
# Handle different ways to specify segments for YOLO detection
if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0:
# Default: run YOLO on all segments
detect_segments = [int(seg.split("_")[1]) for seg in segments]
logger.info("No segments specified, running YOLO detection on ALL segments")
elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all':
# Explicit 'all' keyword
detect_segments = [int(seg.split("_")[1]) for seg in segments]
logger.info("Running YOLO detection on ALL segments")
else:
# Specific segment numbers provided
try:
detect_segments = [int(x) for x in args.segments_detect_humans]
logger.info(f"Running YOLO detection on segments: {detect_segments}")
except ValueError:
logger.error("Invalid segment numbers provided. Use integers or 'all'.")
return
# Run YOLO detection on specified segments
do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model)
# Load YOLO model for inference
yolo_model = YOLO(args.yolo_model)
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
video_file_name = get_video_file_name(i)
video_path = os.path.join(segment_dir, video_file_name)
output_done_file = os.path.join(segment_dir, "output_frames_done")
if os.path.exists(output_done_file):
logger.info(f"Segment {segment} already processed. Skipping.")
continue
logger.info(f"Processing segment {segment}")
# Initialize predictor
predictor = initialize_predictor()
# Prepare low-resolution video frames for inference
low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
if not os.path.exists(low_res_video_path):
create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
logger.info(f"Low-resolution video created for segment {segment}")
else:
logger.info(f"Low-resolution video already exists for segment {segment}, reuse")
# Initialize inference state with low-resolution video
inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)
# Load YOLO detections or previous masks
detection_file = os.path.join(segment_dir, "yolo_detections")
use_detections = segment_index in detect_segments
if i == 0 and not use_detections:
# First segment must use YOLO detection since there's no previous mask
logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.")
use_detections = True
if i > 0 and not use_detections:
# Try to load previous segment mask - search backwards for the most recent successful mask
logger.info(f"Using previous segment mask for segment {segment}")
mask_found = False
# Search backwards through previous segments to find a valid mask
for j in range(i - 1, -1, -1):
prev_segment_dir = os.path.join(base_dir, segments[j])
prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
if os.path.exists(prev_mask_path):
try:
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
# Add previous masks to predictor
for obj_id, mask in per_obj_input_mask.items():
predictor.add_new_mask(inference_state, 0, obj_id, mask)
logger.info(f"Successfully loaded mask from segment {segments[j]}")
mask_found = True
break
except Exception as e:
logger.warning(f"Error loading mask from {segments[j]}: {e}")
continue
if not mask_found:
logger.error(f"No valid previous mask found for segment {segment}. Consider running YOLO detection on this segment.")
continue
else:
# Load first frame for detection
first_frame = load_first_frame(low_res_video_path, scale=1.0)
if first_frame is None:
logger.error(f"Could not load first frame for segment {segment}")
continue
# Run YOLO detection on first frame (either from file or on-the-fly)
if os.path.exists(detection_file):
logger.info(f"Using existing YOLO detections for segment {segment}")
else:
logger.info(f"Running YOLO detection on-the-fly for segment {segment}")
human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence)
if human_detections:
# Add YOLO detections to predictor
frame_width = first_frame.shape[1]
add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width)
else:
logger.warning(f"No humans detected in segment {segment}")
continue
# Perform inference and collect masks per frame
video_segments = propagate_masks(predictor, inference_state)
# Process high-resolution frames and save output video
output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
logger.info("Processing segment complete, attempting to save full video from low res masks")
process_and_save_output_video(
video_path,
output_video_path,
video_segments,
use_nvenc=True # Set to True to use NVENC offloading
)
# Save final masks
mask_output_path = os.path.join(segment_dir, "mask.png")
save_final_masks(video_segments, mask_output_path)
# Clean up
predictor.reset_state(inference_state)
del inference_state
del video_segments
del predictor
gc.collect()
try:
os.remove(low_res_video_path)
logger.info(f"Deleted low-resolution video for segment {segment}")
except Exception as e:
logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")
# Mark segment as completed
open(output_done_file, 'a').close()
logger.info("Processing complete.")
if __name__ == "__main__":
main()
```
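
As a quick reference, here is a minimal invocation sketch. The filename `legacy_pipeline.py` is a placeholder (the script's actual filename isn't given above); the flags mirror the argparse definitions in `main()`.

```python
# Minimal invocation sketch, assuming the script above is saved as "legacy_pipeline.py" (placeholder name).
# The flags mirror the argparse definitions in main().
import subprocess

subprocess.run(
    [
        "python", "legacy_pipeline.py",
        "--base-dir", "/data/segments",        # directory containing segment_000, segment_001, ...
        "--segments-detect-humans", "all",     # or specific segment numbers, e.g. "1", "5", "10"
        "--yolo-model", "yolov8n.pt",
        "--yolo-confidence", "0.6",
    ],
    check=True,
)
```

The script expects `--base-dir` to contain `segment_*` subdirectories, each holding that segment's video; finished segments are marked with an `output_frames_done` file and skipped on re-runs.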