51 hours to 31 hours

This commit is contained in:
2024-10-19 22:21:50 -07:00
parent 999c6660e9
commit fecb7f5f04
5 changed files with 389 additions and 72 deletions

View File

@@ -14,14 +14,19 @@
# to determine an input mask and use add_new_mask() instead of selecting
# points.
#
# Each segment has 2 versions of each frame, on high quality used for
# final rendering and 1 low quality used to speed up inference
#
# When the script finishes, each segment should have an output directory
# with the same object tracked throughout the every frame in all the segment directories
#
# I will then turn these back into a video using ffmpeg but that is outside the scope
# of this program
import os
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import torch
import sys
from sam2.build_sam import build_sam2_video_predictor
@@ -32,67 +37,19 @@ SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
def load_previous_segment_mask(prev_segment_dir):
mask_path = os.path.join(prev_segment_dir, "mask.jpg")
mask_path = os.path.join(prev_segment_dir, "mask.jpg")
mask_image = cv2.imread(mask_path)
# Extract Object A and Object B masks
mask_a = (mask_image[:, :, 1] == 255) # Green channel
mask_b = (mask_image[:, :, 0] == 254) # Blue channel
# show an image of mask a and mask b, resize the window to 300 pixels
#cv2.namedWindow('Mask A', cv2.WINDOW_NORMAL)
#cv2.resizeWindow('Select Points', int(mask_image.shape[1] * (500 / mask_image.shape[0])), 500)
##cv2.imshow('Mask A', mask_a.astype(np.uint8) * 255)
#cv2.imshow('Mask A', mask_image)
per_obj_input_mask = {1: mask_a, 2: mask_b}
input_palette = None # No palette needed for binary mask
return per_obj_input_mask, input_palette
def convert_green_screen_to_mask(frame):
lower_green = np.array([0, 255, 0])
upper_green = np.array([0, 255, 0])
mask = cv2.inRange(frame, lower_green, upper_green)
mask = cv2.bitwise_not(mask)
return mask > 0
def get_per_obj_mask(mask):
object_ids = np.unique(mask)
object_ids = object_ids[object_ids > 0].tolist()
per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
return per_obj_mask
def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False):
if not per_obj_png_file:
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
return {}, None
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask = get_per_obj_mask(input_mask)
else:
per_obj_input_mask = {}
input_palette = None
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
object_id = int(object_name)
input_mask_path = os.path.join(input_mask_dir, video_name, object_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
continue
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask[object_id] = input_mask > 0
if not per_obj_input_mask:
frame_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.jpg")
if os.path.exists(frame_path):
frame = cv2.imread(frame_path)
mask = convert_green_screen_to_mask(frame)
per_obj_input_mask = {1: mask}
return per_obj_input_mask, input_palette
def apply_green_mask(frame, masks):
def apply_green_mask_oldest(frame, masks):
green_mask = np.zeros_like(frame)
green_mask[:, :] = [0, 255, 0]
@@ -104,9 +61,96 @@ def apply_green_mask(frame, masks):
combined_mask = np.logical_or(combined_mask, mask)
inverted_mask = np.logical_not(combined_mask)
frame[inverted_mask] = green_mask[inverted_mask]
def apply_mask_part(start_row, end_row):
frame[start_row:end_row][inverted_mask[start_row:end_row]] = green_mask[start_row:end_row][inverted_mask[start_row:end_row]]
num_threads = 4
rows_per_thread = frame.shape[0] // num_threads
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [
executor.submit(apply_mask_part, i * rows_per_thread, (i + 1) * rows_per_thread)
for i in range(num_threads)
]
for future in futures:
future.result()
return frame
def apply_green_mask_old_good(frame, masks):
# Initialize combined mask as a boolean array
combined_mask = np.zeros(frame.shape[:2], dtype=bool)
for mask in masks:
mask = mask.squeeze()
# Resize mask if necessary
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
# Ensure mask is boolean
mask = mask.astype(bool)
# Combine masks using in-place logical OR
combined_mask |= mask
# Invert the combined mask to get background regions
inverted_mask = ~combined_mask
# Apply green color to background regions directly
frame[inverted_mask] = [0, 255, 0]
return frame
def apply_green_mask(frame, masks):
"""
Applies masks to the frame, replacing the background with green.
Parameters:
- frame: numpy array representing the image frame.
- masks: list of numpy arrays representing the masks.
Returns:
- result_frame: numpy array with the green background applied.
"""
# Initialize combined mask as a boolean array
combined_mask = np.zeros(frame.shape[:2], dtype=bool)
for mask in masks:
mask = mask.squeeze()
# Resize the mask if necessary
if mask.shape != frame.shape[:2]:
# Resize the mask using bilinear interpolation
# and convert it to float32 for accurate interpolation
resized_mask = cv2.resize(
mask.astype(np.float32),
(frame.shape[1], frame.shape[0]),
interpolation=cv2.INTER_LINEAR
)
# Threshold the resized mask to obtain a boolean mask
mask = resized_mask > 0.5
else:
# Ensure mask is boolean
mask = mask.astype(bool)
# Combine masks using logical OR
combined_mask |= mask # Now both arrays are bool
# Create a green background image
green_background = np.full_like(frame, [0, 255, 0])
# Use combined mask to overlay the original frame onto the green background
result_frame = np.where(
combined_mask[..., None],
frame,
green_background
)
return result_frame
def initialize_predictor():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
@@ -155,7 +199,6 @@ def select_points(first_frame):
cv2.destroyAllWindows()
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
def add_points_to_predictor(predictor, inference_state, points, obj_id):
labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points
points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2)
@@ -215,33 +258,57 @@ def apply_colored_mask(frame, masks_a, masks_b):
mask = mask.squeeze()
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
indices = np.where(mask)
colored_mask[mask] = [0, 255, 0] # Green for Object A
for mask in masks_b:
mask = mask.squeeze()
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
indices = np.where(mask)
colored_mask[mask] = [255, 0, 0] # Blue for Object B
return colored_mask
def process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir):
def process_and_save_frames(input_frames_dir, fullres_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir):
def upscale_masks(masks, frame_shape):
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.float32), (frame_shape[1], frame_shape[0]), interpolation=cv2.INTER_LINEAR)
#convert_mask to bool
upscaled_mask = (upscaled_mask > 0.5).astype(bool)
upscaled_masks.append(upscaled_mask)
return upscaled_masks
for out_frame_idx, frame_name in enumerate(frame_names):
frame_path = os.path.join(input_frames_dir, frame_name)
frame_path = os.path.join(fullres_frames_dir, frame_name)
frame = cv2.imread(frame_path)
masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]]
frame = apply_green_mask(frame, masks)
# Upscale masks to match the full-resolution frame
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
upscaled_masks.append(upscaled_mask)
frame = apply_green_mask(frame, upscaled_masks)
output_path = os.path.join(output_frames_dir, frame_name)
cv2.imwrite(output_path, frame)
# Create and save mask.jpg
final_frame_path = os.path.join(input_frames_dir, frame_names[-1])
final_frame_path = os.path.join(fullres_frames_dir, frame_names[-1])
final_frame = cv2.imread(final_frame_path)
masks_a = [video_segments[len(frame_names) - 1][1]]
masks_b = [video_segments[len(frame_names) - 1][2]]
upscaled_masks_a = upscale_masks(masks_a, final_frame.shape)
upscaled_masks_b = upscale_masks(masks_b, final_frame.shape)
# Apply colored mask
mask_image = apply_colored_mask(final_frame, masks_a, masks_b)
mask_image = apply_colored_mask(final_frame, upscaled_masks_a, upscaled_masks_b)
mask_output_path = os.path.join(segment_dir, "mask.jpg")
cv2.imwrite(mask_output_path, mask_image)
@@ -253,10 +320,11 @@ def main():
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
args = parser.parse_args()
base_dir = "./spirit_2min_segments"
base_dir = "./606-short_segments"
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1]))
scaled_frames_dir_name = "frames_scaled"
fullres_frames_dir_name = "frames" # iwant to render the final video with these frames
collect_points_segments = args.segments_collect_points if args.segments_collect_points else []
@@ -266,7 +334,7 @@ def main():
segment_dir = os.path.join(base_dir, segment)
points_file = os.path.join(segment_dir, "segment_points")
if segment_index in collect_points_segments and not os.path.exists(points_file):
input_frames_dir = os.path.join(segment_dir, "frames")
input_frames_dir = os.path.join(segment_dir, f"{scaled_frames_dir_name}")
first_frame, _ = load_first_frame(input_frames_dir)
points_a, points_b = select_points(first_frame)
with open(points_file, 'w') as f:
@@ -285,7 +353,8 @@ def main():
points = np.loadtxt(points_file, comments="#")
points_a = points[:4]
points_b = points[4:]
input_frames_dir = os.path.join(segment_dir, "frames")
input_frames_dir = os.path.join(segment_dir, f"{scaled_frames_dir_name}")
fullres_frames_dir = os.path.join(segment_dir, f"{fullres_frames_dir_name}")
output_frames_dir = os.path.join(segment_dir, "output_frames")
os.makedirs(output_frames_dir, exist_ok=True)
first_frame, frame_names = load_first_frame(input_frames_dir)
@@ -306,7 +375,7 @@ def main():
masks_b = [(out_mask_logits_b[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_b))]
video_segments = propagate_masks(predictor, inference_state)
predictor.reset_state(inference_state)
process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir)
process_and_save_frames(input_frames_dir, fullres_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir)
del inference_state
del video_segments
del predictor