From 999c6660e9bf23b278057e92acd21a248ce4811e Mon Sep 17 00:00:00 2001 From: Scott Register Date: Sat, 19 Oct 2024 14:34:17 -0700 Subject: [PATCH] add foo_points_priv --- notebooks/foo_points_prev.py | 318 +++++++++++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 notebooks/foo_points_prev.py diff --git a/notebooks/foo_points_prev.py b/notebooks/foo_points_prev.py new file mode 100644 index 0000000..f026abc --- /dev/null +++ b/notebooks/foo_points_prev.py @@ -0,0 +1,318 @@ +# this script will process multiple video segments (5 second clips) +# of a long (1 hour total video) and will greenscreen/key out everything +# except tracked objects. +# +# To specify which items are tracked, the user will call the script with +# --segments-collect-points 1 5 10 ... Which are segments to be treated +# as "keyframes", for these we will ask the user to select the point +# we want to track by hand, this is done for large changes in the objects +# position in the video. +# +# For other segments, since the object is mostly static in the frame, +# We will use the previous segments final frame (an image +# with the object we want to track visible, and everything else green) +# to determine an input mask and use add_new_mask() instead of selecting +# points. +# +# When the script finishes, each segment should have an output directory +# with the same object tracked throughout the every frame in all the segment directories +# +# I will then turn these back into a video using ffmpeg but that is outside the scope +# of this program +import os +import cv2 +import numpy as np +import torch +import sys +from sam2.build_sam import build_sam2_video_predictor +import argparse + +# Variables for input and output directories +SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt" +MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml" + +def load_previous_segment_mask(prev_segment_dir): + mask_path = os.path.join(prev_segment_dir, "mask.jpg") + mask_image = cv2.imread(mask_path) + + # Extract Object A and Object B masks + mask_a = (mask_image[:, :, 1] == 255) # Green channel + mask_b = (mask_image[:, :, 0] == 254) # Blue channel + + # show an image of mask a and mask b, resize the window to 300 pixels + #cv2.namedWindow('Mask A', cv2.WINDOW_NORMAL) + #cv2.resizeWindow('Select Points', int(mask_image.shape[1] * (500 / mask_image.shape[0])), 500) + ##cv2.imshow('Mask A', mask_a.astype(np.uint8) * 255) + #cv2.imshow('Mask A', mask_image) + + + per_obj_input_mask = {1: mask_a, 2: mask_b} + input_palette = None # No palette needed for binary mask + + return per_obj_input_mask, input_palette + +def convert_green_screen_to_mask(frame): + lower_green = np.array([0, 255, 0]) + upper_green = np.array([0, 255, 0]) + mask = cv2.inRange(frame, lower_green, upper_green) + mask = cv2.bitwise_not(mask) + return mask > 0 + +def get_per_obj_mask(mask): + object_ids = np.unique(mask) + object_ids = object_ids[object_ids > 0].tolist() + per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids} + return per_obj_mask + +def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False): + if not per_obj_png_file: + input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png") + if allow_missing and not os.path.exists(input_mask_path): + return {}, None + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask = get_per_obj_mask(input_mask) + else: + per_obj_input_mask = {} + input_palette = None + for object_name in os.listdir(os.path.join(input_mask_dir, video_name)): + object_id = int(object_name) + input_mask_path = os.path.join(input_mask_dir, video_name, object_name, f"{frame_name}.png") + if allow_missing and not os.path.exists(input_mask_path): + continue + input_mask, input_palette = load_ann_png(input_mask_path) + per_obj_input_mask[object_id] = input_mask > 0 + + if not per_obj_input_mask: + frame_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.jpg") + if os.path.exists(frame_path): + frame = cv2.imread(frame_path) + mask = convert_green_screen_to_mask(frame) + per_obj_input_mask = {1: mask} + + return per_obj_input_mask, input_palette + + +def apply_green_mask(frame, masks): + green_mask = np.zeros_like(frame) + green_mask[:, :] = [0, 255, 0] + + combined_mask = np.zeros(frame.shape[:2], dtype=bool) + for mask in masks: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + combined_mask = np.logical_or(combined_mask, mask) + + inverted_mask = np.logical_not(combined_mask) + frame[inverted_mask] = green_mask[inverted_mask] + return frame + +def initialize_predictor(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device) + return predictor + +def load_first_frame(input_frames_dir): + frame_names = sorted([p for p in os.listdir(input_frames_dir) if p.endswith(('.jpg', '.jpeg', '.png'))]) + first_frame_path = os.path.join(input_frames_dir, frame_names[0]) + first_frame = cv2.imread(first_frame_path) + return first_frame, frame_names + +def select_points(first_frame): + points_a = [] + points_b = [] + current_object = 'A' + point_count = 0 + selection_complete = False + + def mouse_callback(event, x, y, flags, param): + nonlocal points_a, points_b, current_object, point_count, selection_complete + if event == cv2.EVENT_LBUTTONDOWN: + if current_object == 'A': + points_a.append((x, y)) + point_count += 1 + print(f"Selected point {point_count} for Object A: ({x}, {y})") + if len(points_a) == 4: # Collect 4 points for Object A + current_object = 'B' + point_count = 0 + print("Select point 1 for Object B") + elif current_object == 'B': + points_b.append((x, y)) + point_count += 1 + print(f"Selected point {point_count} for Object B: ({x}, {y})") + if len(points_b) == 4: # Collect 4 points for Object B + selection_complete = True + + print("Select point 1 for Object A") + cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL) + cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500) + cv2.imshow('Select Points', first_frame) + cv2.setMouseCallback('Select Points', mouse_callback) + + while not selection_complete: + cv2.waitKey(1) + + cv2.destroyAllWindows() + return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32) + + +def add_points_to_predictor(predictor, inference_state, points, obj_id): + labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points + points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2) + try: + print(f"Adding points for Object {obj_id}: {points}") + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=0, + obj_id=obj_id, + points=points, + labels=labels, + ) + print(f"Object {obj_id} added successfully: {out_obj_ids}") + return out_mask_logits + except Exception as e: + print(f"Error adding points for Object {obj_id}: {e}") + exit() + +def show_mask_on_frame(frame, masks): + combined_mask = np.zeros(frame.shape[:2], dtype=bool) + for mask in masks: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + combined_mask = np.logical_or(combined_mask, mask) + color = (0, 255, 0) + frame[combined_mask] = color + return frame + +def confirm_masks(first_frame, masks_a, masks_b): + first_frame_with_masks = show_mask_on_frame(first_frame.copy(), masks_a + masks_b) + cv2.namedWindow('First Frame with Masks', cv2.WINDOW_NORMAL) + cv2.resizeWindow('First Frame with Masks', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500) + cv2.imshow('First Frame with Masks', first_frame_with_masks) + cv2.waitKey(0) + cv2.destroyAllWindows() + + confirmation = input("Are the masks correct? (yes/no): ").strip().lower() + if confirmation != 'yes': + print("Aborting process.") + exit() + +def propagate_masks(predictor, inference_state): + video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + return video_segments + +def apply_colored_mask(frame, masks_a, masks_b): + colored_mask = np.zeros_like(frame) + + # Apply colors to the masks + for mask in masks_a: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + colored_mask[mask] = [0, 255, 0] # Green for Object A + + for mask in masks_b: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + colored_mask[mask] = [255, 0, 0] # Blue for Object B + + return colored_mask + +def process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir): + for out_frame_idx, frame_name in enumerate(frame_names): + frame_path = os.path.join(input_frames_dir, frame_name) + frame = cv2.imread(frame_path) + masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]] + frame = apply_green_mask(frame, masks) + output_path = os.path.join(output_frames_dir, frame_name) + cv2.imwrite(output_path, frame) + + # Create and save mask.jpg + final_frame_path = os.path.join(input_frames_dir, frame_names[-1]) + final_frame = cv2.imread(final_frame_path) + masks_a = [video_segments[len(frame_names) - 1][1]] + masks_b = [video_segments[len(frame_names) - 1][2]] + + # Apply colored mask + mask_image = apply_colored_mask(final_frame, masks_a, masks_b) + + mask_output_path = os.path.join(segment_dir, "mask.jpg") + cv2.imwrite(mask_output_path, mask_image) + + print("Processing complete. Frames saved in:", output_frames_dir) + +def main(): + parser = argparse.ArgumentParser(description="Process video segments.") + parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.") + args = parser.parse_args() + + base_dir = "./spirit_2min_segments" + segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")] + segments.sort(key=lambda x: int(x.split("_")[1])) + + + collect_points_segments = args.segments_collect_points if args.segments_collect_points else [] + + for i, segment in enumerate(segments): + print("Processing segment", segment) + segment_index = int(segment.split("_")[1]) + segment_dir = os.path.join(base_dir, segment) + points_file = os.path.join(segment_dir, "segment_points") + if segment_index in collect_points_segments and not os.path.exists(points_file): + input_frames_dir = os.path.join(segment_dir, "frames") + first_frame, _ = load_first_frame(input_frames_dir) + points_a, points_b = select_points(first_frame) + with open(points_file, 'w') as f: + np.savetxt(f, points_a, header="Object A Points") + np.savetxt(f, points_b, header="Object B Points") + + for i, segment in enumerate(segments): + segment_index = int(segment.split("_")[1]) + segment_dir = os.path.join(base_dir, segment) + output_done_file = os.path.join(segment_dir, "output_frames_done") + if not os.path.exists(output_done_file): + print(f"Processing segment {segment}") + + points_file = os.path.join(segment_dir, "segment_points") + if os.path.exists(points_file): + points = np.loadtxt(points_file, comments="#") + points_a = points[:4] + points_b = points[4:] + input_frames_dir = os.path.join(segment_dir, "frames") + output_frames_dir = os.path.join(segment_dir, "output_frames") + os.makedirs(output_frames_dir, exist_ok=True) + first_frame, frame_names = load_first_frame(input_frames_dir) + predictor = initialize_predictor() + inference_state = predictor.init_state(video_path=input_frames_dir, async_loading_frames=True) + + if i > 0 and segment_index not in collect_points_segments and not os.path.exists(points_file): + prev_segment_dir = os.path.join(base_dir, segments[i-1]) + print(f"Loading previous segment masks from {prev_segment_dir}") + per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir) + for obj_id, mask in per_obj_input_mask.items(): + predictor.add_new_mask(inference_state, 0, obj_id, mask) + else: + print("Using points for segment") + out_mask_logits_a = add_points_to_predictor(predictor, inference_state, points_a, obj_id=1) + out_mask_logits_b = add_points_to_predictor(predictor, inference_state, points_b, obj_id=2) + masks_a = [(out_mask_logits_a[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_a))] + masks_b = [(out_mask_logits_b[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_b))] + video_segments = propagate_masks(predictor, inference_state) + predictor.reset_state(inference_state) + process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir) + del inference_state + del video_segments + del predictor + import gc + gc.collect() + open(output_done_file, 'a').close() + +if __name__ == "__main__": + main()