add foo_points_priv
notebooks/foo_points_prev.py | 318 | Normal file
@@ -0,0 +1,318 @@
# This script processes multiple video segments (5-second clips) of a long
# video (1 hour total) and greenscreens/keys out everything except the
# tracked objects.
#
# To specify which items are tracked, the user calls the script with
# --segments-collect-points 1 5 10 ...  These are the segments treated as
# "keyframes": for them we ask the user to select the points we want to
# track by hand. This is done for large changes in the objects' position
# in the video.
#
# For the other segments, since the objects are mostly static in the frame,
# we use the previous segment's final frame (an image with the objects we
# want to track visible, and everything else green) to determine an input
# mask, and use add_new_mask() instead of selecting points.
#
# When the script finishes, each segment should have an output directory
# with the same objects tracked through every frame across all the segment
# directories.
#
# I will then turn these back into a video using ffmpeg, but that is outside
# the scope of this program.
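#
# Example invocation (the segment indices here are hypothetical, and the
# trailing ffmpeg step is only a sketch of the out-of-scope reassembly):
#
#   python foo_points_prev.py --segments-collect-points 0 5 10
#   ffmpeg -framerate 30 -pattern_type glob -i 'output_frames/*.jpg' out.mp4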

import os
import sys
import gc
import argparse

import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Paths to the SAM 2 checkpoint and model config
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"


def load_previous_segment_mask(prev_segment_dir):
    mask_path = os.path.join(prev_segment_dir, "mask.jpg")
    mask_image = cv2.imread(mask_path)
    if mask_image is None:
        sys.exit(f"Could not read previous segment mask: {mask_path}")

    # Extract the Object A (green) and Object B (blue) masks. mask.jpg is
    # JPEG-compressed, so exact equality against 255/254 is unreliable;
    # threshold the dominant channel instead.
    mask_a = (mask_image[:, :, 1] > 200) & (mask_image[:, :, 0] < 100)  # green
    mask_b = (mask_image[:, :, 0] > 200) & (mask_image[:, :, 1] < 100)  # blue

    # Debug: preview the loaded masks
    # cv2.imshow('Mask A', mask_a.astype(np.uint8) * 255)

    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # no palette needed for binary masks

    return per_obj_input_mask, input_palette


def convert_green_screen_to_mask(frame):
    # Everything close to pure green is background; the tolerance absorbs
    # JPEG compression noise around the exact (0, 255, 0) key color.
    lower_green = np.array([0, 200, 0])
    upper_green = np.array([100, 255, 100])
    mask = cv2.inRange(frame, lower_green, upper_green)
    mask = cv2.bitwise_not(mask)
    return mask > 0


def get_per_obj_mask(mask):
    object_ids = np.unique(mask)
    object_ids = object_ids[object_ids > 0].tolist()
    per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
    return per_obj_mask
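

# load_ann_png() is used by load_masks_from_dir() below but was never defined
# in this file. A minimal sketch, modeled on the helper of the same name in
# SAM 2's tools/vos_inference.py (assumes Pillow is installed):
def load_ann_png(path):
    from PIL import Image

    mask = Image.open(path)          # palettized PNG annotation
    palette = mask.getpalette()      # may be None for plain grayscale PNGs
    mask = np.array(mask).astype(np.uint8)
    return mask, palette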


def load_masks_from_dir(input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False):
    if not per_obj_png_file:
        # All objects share a single palettized PNG per frame
        input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
        if allow_missing and not os.path.exists(input_mask_path):
            return {}, None
        input_mask, input_palette = load_ann_png(input_mask_path)
        per_obj_input_mask = get_per_obj_mask(input_mask)
    else:
        # One PNG per object, in per-object subdirectories
        per_obj_input_mask = {}
        input_palette = None
        for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
            object_id = int(object_name)
            input_mask_path = os.path.join(input_mask_dir, video_name, object_name, f"{frame_name}.png")
            if allow_missing and not os.path.exists(input_mask_path):
                continue
            input_mask, input_palette = load_ann_png(input_mask_path)
            per_obj_input_mask[object_id] = input_mask > 0

    # Fall back to keying a green-screened JPEG frame if no PNG masks exist
    if not per_obj_input_mask:
        frame_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.jpg")
        if os.path.exists(frame_path):
            frame = cv2.imread(frame_path)
            mask = convert_green_screen_to_mask(frame)
            per_obj_input_mask = {1: mask}

    return per_obj_input_mask, input_palette


def apply_green_mask(frame, masks):
    # Replace everything outside the tracked objects with solid green
    green_mask = np.zeros_like(frame)
    green_mask[:, :] = [0, 255, 0]

    combined_mask = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize cannot handle bool arrays, so round-trip through uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        combined_mask = np.logical_or(combined_mask, mask)

    inverted_mask = np.logical_not(combined_mask)
    frame[inverted_mask] = green_mask[inverted_mask]
    return frame


def initialize_predictor():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    return predictor


def load_first_frame(input_frames_dir):
    frame_names = sorted([p for p in os.listdir(input_frames_dir) if p.endswith(('.jpg', '.jpeg', '.png'))])
    first_frame_path = os.path.join(input_frames_dir, frame_names[0])
    first_frame = cv2.imread(first_frame_path)
    return first_frame, frame_names


def select_points(first_frame):
    points_a = []
    points_b = []
    current_object = 'A'
    point_count = 0
    selection_complete = False

    def mouse_callback(event, x, y, flags, param):
        nonlocal current_object, point_count, selection_complete
        if event == cv2.EVENT_LBUTTONDOWN:
            if current_object == 'A':
                points_a.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object A: ({x}, {y})")
                if len(points_a) == 4:  # collect 4 points for Object A
                    current_object = 'B'
                    point_count = 0
                    print("Select point 1 for Object B")
            elif current_object == 'B':
                points_b.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object B: ({x}, {y})")
                if len(points_b) == 4:  # collect 4 points for Object B
                    selection_complete = True

    print("Select point 1 for Object A")
    cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
    cv2.imshow('Select Points', first_frame)
    cv2.setMouseCallback('Select Points', mouse_callback)

    while not selection_complete:
        cv2.waitKey(1)

    cv2.destroyAllWindows()
    return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)


def add_points_to_predictor(predictor, inference_state, points, obj_id):
    labels = np.array([1, 1, 1, 1], np.int32)  # all four points are positive clicks
    points = np.array(points, dtype=np.float32)  # ensure points have shape (4, 2)
    try:
        print(f"Adding points for Object {obj_id}: {points}")
        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=obj_id,
            points=points,
            labels=labels,
        )
        print(f"Object {obj_id} added successfully: {out_obj_ids}")
        return out_mask_logits
    except Exception as e:
        sys.exit(f"Error adding points for Object {obj_id}: {e}")


def show_mask_on_frame(frame, masks):
    combined_mask = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize cannot handle bool arrays, so round-trip through uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        combined_mask = np.logical_or(combined_mask, mask)
    color = (0, 255, 0)
    frame[combined_mask] = color
    return frame


def confirm_masks(first_frame, masks_a, masks_b):
    first_frame_with_masks = show_mask_on_frame(first_frame.copy(), masks_a + masks_b)
    cv2.namedWindow('First Frame with Masks', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('First Frame with Masks', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
    cv2.imshow('First Frame with Masks', first_frame_with_masks)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    confirmation = input("Are the masks correct? (yes/no): ").strip().lower()
    if confirmation != 'yes':
        sys.exit("Aborting process.")


def propagate_masks(predictor, inference_state):
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
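

# Note: video_segments maps frame index -> {obj_id: boolean mask}. With SAM 2's
# usual output shapes each per-object mask carries a leading singleton channel
# dimension, which is why the drawing helpers call squeeze() before resizing.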


def apply_colored_mask(frame, masks_a, masks_b):
    colored_mask = np.zeros_like(frame)

    # Paint the masks in the colors that load_previous_segment_mask() expects
    for mask in masks_a:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize cannot handle bool arrays, so round-trip through uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        colored_mask[mask] = [0, 255, 0]  # green for Object A

    for mask in masks_b:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        colored_mask[mask] = [255, 0, 0]  # blue for Object B (BGR)

    return colored_mask


def process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir):
    for out_frame_idx, frame_name in enumerate(frame_names):
        frame_path = os.path.join(input_frames_dir, frame_name)
        frame = cv2.imread(frame_path)
        masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]]
        frame = apply_green_mask(frame, masks)
        output_path = os.path.join(output_frames_dir, frame_name)
        cv2.imwrite(output_path, frame)

    # Create and save mask.jpg from the final frame; the next segment reads
    # it back in load_previous_segment_mask(), so the colors must match.
    final_frame_path = os.path.join(input_frames_dir, frame_names[-1])
    final_frame = cv2.imread(final_frame_path)
    masks_a = [video_segments[len(frame_names) - 1][1]]
    masks_b = [video_segments[len(frame_names) - 1][2]]

    mask_image = apply_colored_mask(final_frame, masks_a, masks_b)

    mask_output_path = os.path.join(segment_dir, "mask.jpg")
    cv2.imwrite(mask_output_path, mask_image)

    print("Processing complete. Frames saved in:", output_frames_dir)


def main():
    parser = argparse.ArgumentParser(description="Process video segments.")
    parser.add_argument("--segments-collect-points", nargs='+', type=int,
                        help="Segments for which to collect points.")
    args = parser.parse_args()

    base_dir = "./spirit_2min_segments"
    segments = [d for d in os.listdir(base_dir)
                if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    segments.sort(key=lambda x: int(x.split("_")[1]))

    collect_points_segments = args.segments_collect_points if args.segments_collect_points else []

    # First pass: collect points up front for all keyframe segments
    for segment in segments:
        print("Processing segment", segment)
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        points_file = os.path.join(segment_dir, "segment_points")
        if segment_index in collect_points_segments and not os.path.exists(points_file):
            input_frames_dir = os.path.join(segment_dir, "frames")
            first_frame, _ = load_first_frame(input_frames_dir)
            points_a, points_b = select_points(first_frame)
            with open(points_file, 'w') as f:
                np.savetxt(f, points_a, header="Object A Points")
                np.savetxt(f, points_b, header="Object B Points")

    # Second pass: track and green-screen each segment
    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if not os.path.exists(output_done_file):
            print(f"Processing segment {segment}")

            points_a = points_b = None
            points_file = os.path.join(segment_dir, "segment_points")
            if os.path.exists(points_file):
                points = np.loadtxt(points_file, comments="#")
                points_a = points[:4]
                points_b = points[4:]
            input_frames_dir = os.path.join(segment_dir, "frames")
            output_frames_dir = os.path.join(segment_dir, "output_frames")
            os.makedirs(output_frames_dir, exist_ok=True)
            first_frame, frame_names = load_first_frame(input_frames_dir)
            predictor = initialize_predictor()
            inference_state = predictor.init_state(video_path=input_frames_dir, async_loading_frames=True)

            if i > 0 and segment_index not in collect_points_segments and not os.path.exists(points_file):
                # Seed this segment from the previous segment's final-frame mask
                prev_segment_dir = os.path.join(base_dir, segments[i - 1])
                print(f"Loading previous segment masks from {prev_segment_dir}")
                per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
                for obj_id, mask in per_obj_input_mask.items():
                    predictor.add_new_mask(inference_state, 0, obj_id, mask)
            else:
                print("Using points for segment")
                if points_a is None:
                    sys.exit(f"No points file for segment {segment}; pass its index via --segments-collect-points.")
                out_mask_logits_a = add_points_to_predictor(predictor, inference_state, points_a, obj_id=1)
                out_mask_logits_b = add_points_to_predictor(predictor, inference_state, points_b, obj_id=2)
                # First-frame masks (can be passed to confirm_masks() for a sanity check)
                masks_a = [(out_mask_logits_a[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_a))]
                masks_b = [(out_mask_logits_b[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_b))]
            video_segments = propagate_masks(predictor, inference_state)
            predictor.reset_state(inference_state)
            process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir)
            # Free GPU/host memory before the next segment
            del inference_state
            del video_segments
            del predictor
            gc.collect()
            open(output_done_file, 'a').close()


if __name__ == "__main__":
    main()