# sam2/notebooks/foo_points_prev.py
# This script processes multiple video segments (5-second clips) of a long
# video (1 hour total) and greenscreens/keys out everything except the
# tracked objects.
#
# To specify which items are tracked, the user calls the script with
# --segments-collect-points 1 5 10 ...  These segments are treated as
# "keyframes": for each one we ask the user to hand-select the points to
# track. This is done for segments where the objects' positions in the
# video change significantly.
#
# For the other segments, since the objects are mostly static in the frame,
# we use the previous segment's final frame (an image with the tracked
# objects visible and everything else green) to derive an input mask and
# call add_new_mask() instead of selecting points.
#
# Each segment has two versions of every frame: one high-quality version
# used for final rendering and one low-quality version used to speed up
# inference.
#
# When the script finishes, each segment directory should contain output
# frames with the same objects tracked through every frame.
#
# I will then turn these back into a video using ffmpeg, but that is
# outside the scope of this program.
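#
# Example invocation (segment numbers are illustrative):
#   python foo_points_prev.py --segments-collect-points 1 5 10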
import argparse
import gc
import logging
import os
import sys

import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Model checkpoint, config, and processing scales
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
INFERENCE_SCALE = 0.5
FULL_SCALE = 1.0
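# INFERENCE_SCALE shrinks frames before they are fed to SAM 2 to speed up
# inference; the full-resolution (FULL_SCALE) frames are what the final
# greenscreened output is rendered from.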
def open_video(video_path):
"""
Opens a video file and returns a generator that yields frames.
Parameters:
- video_path: Path to the video file.
Returns:
- A generator that yields frames from the video.
"""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Could not open video file {video_path}")
        return
while True:
ret, frame = cap.read()
if not ret:
break
yield frame
cap.release()
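# NOTE: open_video() is a general-purpose frame generator; the processing
# functions below read their videos directly with cv2.VideoCapture, so it
# is kept here as a utility.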
def load_previous_segment_mask(prev_segment_dir):
    """Decode the previous segment's mask.jpg back into per-object boolean masks."""
    mask_path = os.path.join(prev_segment_dir, "mask.jpg")
    mask_image = cv2.imread(mask_path)
    # mask.jpg is written by apply_colored_mask(): Object A is green and
    # Object B is blue (BGR order). JPEG compression perturbs pixel values,
    # so threshold the channels instead of testing for exact 255/254.
    mask_a = mask_image[:, :, 1] > 200  # green channel -> Object A
    mask_b = mask_image[:, :, 0] > 200  # blue channel -> Object B
    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # no palette needed for binary masks
    return per_obj_input_mask, input_palette
def apply_green_mask(frame, masks):
"""
Applies masks to the frame, replacing the background with green.
Parameters:
- frame: numpy array representing the image frame.
- masks: list of numpy arrays representing the masks.
Returns:
- result_frame: numpy array with the green background applied.
"""
# Initialize combined mask as a boolean array
combined_mask = np.zeros(frame.shape[:2], dtype=bool)
for mask in masks:
mask = mask.squeeze()
# Resize the mask if necessary
if mask.shape != frame.shape[:2]:
# Resize the mask using bilinear interpolation
# and convert it to float32 for accurate interpolation
resized_mask = cv2.resize(
mask.astype(np.float32),
(frame.shape[1], frame.shape[0]),
interpolation=cv2.INTER_LINEAR
)
# Threshold the resized mask to obtain a boolean mask
mask = resized_mask > 0.5
else:
# Ensure mask is boolean
mask = mask.astype(bool)
# Combine masks using logical OR
combined_mask |= mask # Now both arrays are bool
# Create a green background image
green_background = np.full_like(frame, [0, 255, 0])
# Use combined mask to overlay the original frame onto the green background
result_frame = np.where(
combined_mask[..., None],
frame,
green_background
)
return result_frame
def initialize_predictor():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
return predictor
def load_first_frame(video_path, scale=1.0):
"""
Opens a video file and returns the first frame, scaled as specified.
Parameters:
- video_path: Path to the video file.
- scale: Scaling factor for the frame (default is 1.0 for original size).
Returns:
- first_frame: The first frame of the video, scaled accordingly.
"""
cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Could not open video file {video_path}")
        return None
    ret, frame = cap.read()
    cap.release()
    if not ret:
        logger.error(f"Could not read frame from video file {video_path}")
        return None
if scale != 1.0:
frame = cv2.resize(
frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
)
return frame
def select_points(first_frame):
points_a = []
points_b = []
current_object = 'A'
point_count = 0
selection_complete = False
    def mouse_callback(event, x, y, flags, param):
        nonlocal points_a, points_b, current_object, point_count, selection_complete
        if event == cv2.EVENT_LBUTTONDOWN:
            if current_object == 'A':
                points_a.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object A: ({x}, {y})")
                if len(points_a) == 4:  # collect 4 points for Object A
                    current_object = 'B'
                    point_count = 0
                    print("Select point 1 for Object B")
            elif current_object == 'B':
                points_b.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object B: ({x}, {y})")
                if len(points_b) == 4:  # collect 4 points for Object B
                    selection_complete = True
                    print("Point selection complete.")
    print("Select point 1 for Object A")
    cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
    cv2.imshow('Select Points', first_frame)
    cv2.setMouseCallback('Select Points', mouse_callback)
while not selection_complete:
cv2.waitKey(1)
cv2.destroyAllWindows()
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
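# select_points() returns two (4, 2) float32 arrays of (x, y) pixel
# coordinates, measured at the scale of the frame that was displayed
# (the inference scale, not the full-resolution frame).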
def add_points_to_predictor(predictor, inference_state, points, obj_id):
    points = np.array(points, dtype=np.float32)  # shape (N, 2)
    labels = np.ones(len(points), dtype=np.int32)  # all positive clicks
    try:
        print(f"Adding points for Object {obj_id}: {points}")
        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=obj_id,
            points=points,
            labels=labels,
        )
        print(f"Object {obj_id} added successfully: {out_obj_ids}")
        return out_mask_logits
    except Exception as e:
        print(f"Error adding points for Object {obj_id}: {e}")
        sys.exit(1)
def show_mask_on_frame(frame, masks):
    combined_mask = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize cannot handle boolean arrays, so resize as uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        combined_mask = np.logical_or(combined_mask, mask.astype(bool))
    color = (0, 255, 0)
    frame[combined_mask] = color
    return frame
def confirm_masks(first_frame, masks_a, masks_b):
first_frame_with_masks = show_mask_on_frame(first_frame.copy(), masks_a + masks_b)
cv2.namedWindow('First Frame with Masks', cv2.WINDOW_NORMAL)
cv2.resizeWindow('First Frame with Masks', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
cv2.imshow('First Frame with Masks', first_frame_with_masks)
cv2.waitKey(0)
cv2.destroyAllWindows()
    confirmation = input("Are the masks correct? (yes/no): ").strip().lower()
    if confirmation != 'yes':
        print("Aborting process.")
        sys.exit(1)
def propagate_masks(predictor, inference_state):
    """
    Propagate the first-frame prompts through the whole segment.

    Returns:
    - video_segments: dict mapping frame_idx -> {obj_id: boolean mask array}.
    """
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
def apply_colored_mask(frame, masks_a, masks_b):
    colored_mask = np.zeros_like(frame)
    # Paint each object's mask in a distinct color (BGR order)
    for mask in masks_a:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        colored_mask[mask.astype(bool)] = [0, 255, 0]  # green for Object A
    for mask in masks_b:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        colored_mask[mask.astype(bool)] = [255, 0, 0]  # blue for Object B
    return colored_mask
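# The colors above are in OpenCV's BGR order: Object A is pure green and
# Object B is pure blue. The mask.jpg written from this image is what
# load_previous_segment_mask() decodes at the start of the next segment.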
def process_and_save_frames(input_frames_dir, fullres_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir):
    def upscale_masks(masks, frame_shape):
        upscaled_masks = []
        for mask in masks:
            mask = mask.squeeze()
            upscaled_mask = cv2.resize(mask.astype(np.float32), (frame_shape[1], frame_shape[0]), interpolation=cv2.INTER_LINEAR)
            # threshold the interpolated values back to a boolean mask
            upscaled_mask = upscaled_mask > 0.5
            upscaled_masks.append(upscaled_mask)
        return upscaled_masks
for out_frame_idx, frame_name in enumerate(frame_names):
frame_path = os.path.join(fullres_frames_dir, frame_name)
frame = cv2.imread(frame_path)
masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]]
# Upscale masks to match the full-resolution frame
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
upscaled_masks.append(upscaled_mask)
frame = apply_green_mask(frame, upscaled_masks)
output_path = os.path.join(output_frames_dir, frame_name)
cv2.imwrite(output_path, frame)
# Create and save mask.jpg
final_frame_path = os.path.join(fullres_frames_dir, frame_names[-1])
final_frame = cv2.imread(final_frame_path)
masks_a = [video_segments[len(frame_names) - 1][1]]
masks_b = [video_segments[len(frame_names) - 1][2]]
upscaled_masks_a = upscale_masks(masks_a, final_frame.shape)
upscaled_masks_b = upscale_masks(masks_b, final_frame.shape)
# Apply colored mask
mask_image = apply_colored_mask(final_frame, upscaled_masks_a, upscaled_masks_b)
mask_output_path = os.path.join(segment_dir, "mask.jpg")
cv2.imwrite(mask_output_path, mask_image)
print("Processing complete. Frames saved in:", output_frames_dir)
def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
"""
Process high-resolution frames, apply upscaled masks, and save the output video.
"""
cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94  # fall back if FPS metadata is missing or zero
# Setup VideoWriter with desired settings
if use_nvenc:
# Use FFmpeg with NVENC offloading for H.265 encoding
import subprocess
command = [
'ffmpeg',
'-y', # Overwrite output file if it exists
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-pix_fmt', 'bgr24',
'-s', f'{frame_width}x{frame_height}',
'-r', str(fps),
'-i', '-', # Input from stdin
'-an', # No audio
'-vcodec', 'hevc_nvenc',
'-pix_fmt', 'yuv420p',
'-preset', 'slow',
'-b:v', '60M',
output_video_path
]
process = subprocess.Popen(command, stdin=subprocess.PIPE)
    else:
        # Use OpenCV VideoWriter. Note: the 'HEVC' fourcc requires an OpenCV
        # build with an H.265 encoder; if unavailable, a codec such as 'mp4v'
        # is a common fallback.
        fourcc = cv2.VideoWriter_fourcc(*'HEVC')  # H.265
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret or frame_idx >= len(video_segments):
break
masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
# Upscale masks to match the full-resolution frame
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
upscaled_masks.append(upscaled_mask)
result_frame = apply_green_mask(frame, upscaled_masks)
# Write frame to output
if use_nvenc:
process.stdin.write(result_frame.tobytes())
else:
out.write(result_frame)
frame_idx += 1
cap.release()
if use_nvenc:
process.stdin.close()
process.wait()
else:
out.release()
def get_video_file_name(index):
return f"segment_{str(index).zfill(3)}.mp4"
def do_collect_segment_points(base_dir, segments, collect_points_segments, scale=1.0):
logger.info("Collecting points for requested segments.")
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
points_file = os.path.join(segment_dir, "segment_points")
video_file = os.path.join(segment_dir, get_video_file_name(i))
if segment_index in collect_points_segments and not os.path.exists(points_file):
first_frame = load_first_frame(video_file, scale)
points_a, points_b = select_points(first_frame)
with open(points_file, 'w') as f:
np.savetxt(f, points_a, header="Object A Points")
np.savetxt(f, points_b, header="Object B Points")
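# The points file contains two np.savetxt blocks with '#'-prefixed headers
# ("Object A Points" then "Object B Points"); main() reads both back with a
# single np.loadtxt(..., comments="#") and splits the 8 rows in half.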
def perform_inference(predictor, inference_state, video_path, inference_scale, collect_points, prev_segment_mask=None):
    """
    Performs inference on the video frames at a specified scale.

    NOTE: This function is not called from main(), which uses
    propagate_masks() instead. The predictor.predict() and
    predictor.get_masks() calls below do not appear to match the SAM 2
    video predictor API and would need to be rewritten before use.

    Parameters:
    - predictor: The initialized predictor object.
    - inference_state: The predictor's inference state.
    - video_path: Path to the video file.
    - inference_scale: Scaling factor for inference frames.
    - collect_points: Boolean indicating whether to collect points.
    - prev_segment_mask: Previous segment's mask if available.

    Returns:
    - masks_per_frame: List of masks for each frame.
    """
masks_per_frame = []
cap = cv2.VideoCapture(video_path)
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Resize frame for inference
low_res_frame = cv2.resize(frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
if frame_idx == 0:
if collect_points:
# Collect points on the low-res frame
points_a, points_b = select_points(low_res_frame)
add_points_to_predictor(predictor, inference_state, points_a, obj_id=1)
add_points_to_predictor(predictor, inference_state, points_b, obj_id=2)
elif prev_segment_mask is not None:
# Use the previous segment's mask
per_obj_input_mask = prev_segment_mask
for obj_id, mask in per_obj_input_mask.items():
predictor.add_new_mask(inference_state, 0, obj_id, mask)
            else:
                print("Error: No points or previous mask provided for inference.")
                sys.exit(1)
# Perform inference
predictor.predict(inference_state, low_res_frame)
masks = predictor.get_masks(inference_state)
masks_per_frame.append(masks)
frame_idx += 1
cap.release()
return masks_per_frame
def save_final_masks(video_segments, mask_output_path):
    """
    Save the final frame's masks as a colored image (mask.jpg) so the next
    segment can be seeded from it.
    """
    last_frame_idx = max(video_segments.keys())
    masks_dict = video_segments[last_frame_idx]
    # Two tracked objects with IDs 1 and 2; check for missing masks before
    # dereferencing them.
    mask_a = masks_dict.get(1)
    mask_b = masks_dict.get(2)
    if mask_a is None or mask_b is None:
        print("Error: Masks for objects not found.")
        return
    mask_a = mask_a.squeeze().astype(np.uint8)
    mask_b = mask_b.squeeze().astype(np.uint8)
    # Black canvas with shape (H, W, 3) to paint the colored masks onto
    black_frame = np.zeros((mask_a.shape[0], mask_a.shape[1], 3), dtype=np.uint8)
    mask_image = apply_colored_mask(black_frame, [mask_a], [mask_b])
    cv2.imwrite(mask_output_path, mask_image)
def create_low_res_video(input_video_path, output_video_path, scale):
"""
Creates a low-resolution version of the input video for inference.
"""
cap = cv2.VideoCapture(input_video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
while True:
ret, frame = cap.read()
if not ret:
break
low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
out.write(low_res_frame)
cap.release()
out.release()
def main():
parser = argparse.ArgumentParser(description="Process video segments.")
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
args = parser.parse_args()
base_dir = "./freya_short_segments"
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1]))
    scaled_frames_dir_name = "frames_scaled"
    fullres_frames_dir_name = "frames"  # the final video is rendered from these frames
    collect_points_segments = args.segments_collect_points if args.segments_collect_points else []
    # Collect points up front for every keyframe segment (at INFERENCE_SCALE);
    # this interactive step has already been validated to work.
    do_collect_segment_points(base_dir, segments, collect_points_segments, scale=INFERENCE_SCALE)
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
video_file_name = get_video_file_name(i)
video_path = os.path.join(segment_dir, video_file_name)
output_done_file = os.path.join(segment_dir, "output_frames_done")
if os.path.exists(output_done_file):
print(f"Segment {segment} already processed. Skipping.")
continue
logger.info(f"Processing segment {segment}")
# Initialize predictor
predictor = initialize_predictor()
# Prepare low-resolution video frames for inference
low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
if not os.path.exists(low_res_video_path):
create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
logger.info(f"Low-resolution video created for segment {segment}")
        else:
            logger.info(f"Low-resolution video already exists for segment {segment}, reusing it")
# Initialize inference state with low-resolution video
inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)
# Load points or previous masks
points_file = os.path.join(segment_dir, "segment_points")
if os.path.exists(points_file):
logger.info(f"Using segment_points for segment {segment}")
points = np.loadtxt(points_file, comments="#")
points_a = points[:4]
points_b = points[4:]
else:
points_a = points_b = None
collect_points = segment_index in collect_points_segments
if i > 0 and not collect_points and points_a is None:
# Try to load previous segment mask
logger.info(f"Using previous segment mask for segment {segment}")
prev_segment_dir = os.path.join(base_dir, segments[i - 1])
prev_mask_path = os.path.join(prev_segment_dir, "mask.jpg")
if os.path.exists(prev_mask_path):
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
# Add previous masks to predictor
for obj_id, mask in per_obj_input_mask.items():
predictor.add_new_mask(inference_state, 0, obj_id, mask)
else:
print(f"Warning: Previous segment mask not found for segment {segment}.")
continue # Skip this segment or handle as needed
else:
if points_a is not None and points_b is not None:
print("Using points for segment")
# Add points to predictor
add_points_to_predictor(predictor, inference_state, points_a, obj_id=1)
add_points_to_predictor(predictor, inference_state, points_b, obj_id=2)
else:
print("Error: No points available for segment.")
continue # Skip this segment
# Perform inference and collect masks per frame
video_segments = propagate_masks(predictor, inference_state)
# Process high-resolution frames and save output video
output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
print("Processing of segment complete, attempting to save process full video from low res masks")
process_and_save_output_video(
video_path,
output_video_path,
video_segments,
use_nvenc=True # Set to True to use NVENC offloading
)
# Save final masks
mask_output_path = os.path.join(segment_dir, "mask.jpg")
save_final_masks(video_segments, mask_output_path)
# Clean up
predictor.reset_state(inference_state)
del inference_state
del video_segments
del predictor
gc.collect()
        open(output_done_file, 'a').close()  # touch the done marker for this segment
print("Processing complete.")
if __name__ == "__main__":
main()