import gc
import os
import sys

import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Paths to the SAM 2 checkpoint and model config
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"


def validate_points(base_dir):
    segments = [d for d in os.listdir(base_dir)
                if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    # sort the segment directories; they are named in the form segment_<N>
    segments.sort(key=lambda x: int(x.split("_")[1]))
    for i, segment in enumerate(segments):
        prev_segment = segments[i - 1] if i > 0 else None
        segment_dir = os.path.join(base_dir, segment)
        points_file = os.path.join(segment_dir, "segment_points")
        if not os.path.exists(points_file):
            print(f"Points file missing for {segment_dir}. Generating points...")
            input_frames_dir = os.path.join(segment_dir, "frames")
            first_frame, _ = load_first_frame(input_frames_dir)
            # If the previous segment already has points, ask the user whether to
            # reuse them; otherwise fall through to select_points().
            if prev_segment is not None:
                prev_points_file = os.path.join(base_dir, prev_segment, "segment_points")
                if os.path.exists(prev_points_file):
                    reuse = input(f"Would you like to reuse points from {prev_segment}? (yes/no): ").strip().lower()
                    if reuse == 'yes':
                        points = np.loadtxt(prev_points_file, comments="#")
                        points_a = points[:2]
                        points_b = points[2:]
                        with open(points_file, 'w') as f:
                            np.savetxt(f, points_a, header="Object A Points")
                            np.savetxt(f, points_b, header="Object B Points")
                        continue
            points_a, points_b = select_points(first_frame)
            with open(points_file, 'w') as f:
                np.savetxt(f, points_a, header="Object A Points")
                np.savetxt(f, points_b, header="Object B Points")
        points = np.loadtxt(points_file, comments="#")
        points_a = points[:2]
        points_b = points[2:]
        # Check X coordinate distance for Object A
        if abs(points_a[0][0] - points_a[1][0]) > 1500:
            print(f"Validation failed for Object A in {segment_dir}")
            continue
        # Check X coordinate distance for Object B
        if abs(points_b[0][0] - points_b[1][0]) > 1500:
            print(f"Validation failed for Object B in {segment_dir}")
            continue
    print("Validation complete.")


# Function to apply a green mask to everything outside the segmented objects
def apply_green_mask(frame, masks):
    green_mask = np.zeros_like(frame)
    green_mask[:, :] = [0, 255, 0]  # BGR format for green
    combined_mask = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize does not accept boolean arrays, so resize as uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        combined_mask = np.logical_or(combined_mask, mask)
    inverted_mask = np.logical_not(combined_mask)
    frame[inverted_mask] = green_mask[inverted_mask]
    return frame


# Function to initialize the predictor
def initialize_predictor():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    return predictor


# Function to load the first frame
def load_first_frame(input_frames_dir):
    frame_names = sorted([p for p in os.listdir(input_frames_dir)
                          if p.endswith(('.jpg', '.jpeg', '.png'))])
    first_frame_path = os.path.join(input_frames_dir, frame_names[0])
    first_frame = cv2.imread(first_frame_path)
    return first_frame, frame_names


# Function to select two points per object using OpenCV
def select_points(first_frame):
    points_a = []
    points_b = []
    current_object = 'A'
    point_count = 0
    selection_complete = False

    def mouse_callback(event, x, y, flags, param):
        nonlocal points_a, points_b, current_object, point_count, selection_complete
        if event == cv2.EVENT_LBUTTONDOWN:
            if current_object == 'A':
                points_a.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object A: ({x}, {y})")
                if len(points_a) == 2:
                    current_object = 'B'
                    point_count = 0
                    print("Select point 1 for Object B")
            elif current_object == 'B':
                points_b.append((x, y))
                point_count += 1
                print(f"Selected point {point_count} for Object B: ({x}, {y})")
                if len(points_b) == 2:
                    selection_complete = True
                    print("Point selection complete.")

    print("Select point 1 for Object A")
    cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
    cv2.imshow('Select Points', first_frame)
    cv2.setMouseCallback('Select Points', mouse_callback)
    while not selection_complete:
        cv2.waitKey(1)
    cv2.destroyAllWindows()
    return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)


# Function to add points to the predictor
def add_points_to_predictor(predictor, inference_state, points, obj_id):
    labels = np.array([1, 1], np.int32)  # both points are positive clicks
    try:
        print(f"Adding points for Object {obj_id}: {points}")
        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=obj_id,
            points=points,
            labels=labels,
        )
        print(f"Object {obj_id} added successfully: {out_obj_ids}")
        return out_mask_logits
    except Exception as e:
        print(f"Error adding points for Object {obj_id}: {e}")
        sys.exit(1)


# Function to show masks on a frame
def show_mask_on_frame(frame, masks):
    combined_mask = np.zeros(frame.shape[:2], dtype=bool)
    for mask in masks:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            # cv2.resize does not accept boolean arrays, so resize as uint8
            mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        combined_mask = np.logical_or(combined_mask, mask)
    color = (0, 255, 0)  # Green color for the mask
    frame[combined_mask] = color
    return frame


# Function to confirm masks with the user
def confirm_masks(first_frame, masks_a, masks_b):
    first_frame_with_masks = show_mask_on_frame(first_frame.copy(), masks_a + masks_b)
    cv2.namedWindow('First Frame with Masks', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('First Frame with Masks', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
    cv2.imshow('First Frame with Masks', first_frame_with_masks)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    confirmation = input("Are the masks correct? (yes/no): ").strip().lower()
    if confirmation != 'yes':
        print("Aborting process.")
        sys.exit(1)


# Function to propagate masks across video frames
def propagate_masks(predictor, inference_state):
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments


# Function to process and save frames
def process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments):
    for out_frame_idx, frame_name in enumerate(frame_names):
        frame_path = os.path.join(input_frames_dir, frame_name)
        frame = cv2.imread(frame_path)
        masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]]
        frame = apply_green_mask(frame, masks)
        output_path = os.path.join(output_frames_dir, frame_name)
        cv2.imwrite(output_path, frame)
    print("Processing complete. Frames saved in:", output_frames_dir)


# Main function
def main():
    base_dir = "./spirit_2min_segments"
    segments = [d for d in os.listdir(base_dir)
                if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    segments.sort(key=lambda x: int(x.split("_")[1]))

    if len(sys.argv) > 1 and sys.argv[1] == "--validate-points":
        validate_points(base_dir)
        return

    # Step 1: Select points for each segment
    old_points = None
    for segment in segments:
        segment_dir = os.path.join(base_dir, segment)
        points_file = os.path.join(segment_dir, "segment_points")
        if not os.path.exists(points_file):
            input_frames_dir = os.path.join(segment_dir, "frames")
            first_frame, _ = load_first_frame(input_frames_dir)
            # If old_points is not None, we could ask the user whether to reuse them
            # and skip selecting new points (currently disabled):
            #if old_points is not None:
            #    # show the frame
            #    cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL)
            #    cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
            #    cv2.imshow('Select Points', first_frame)
            #    reuse = input("Would you like to reuse points from the previous segment? (yes/no): ").strip().lower()
            #    if reuse == 'yes':
            #        points_a, points_b = old_points
            #        with open(points_file, 'w') as f:
            #            np.savetxt(f, points_a, header="Object A Points")
            #            np.savetxt(f, points_b, header="Object B Points")
            #        old_points = (points_a, points_b)
            #        cv2.destroyAllWindows()
            #        continue
            #    cv2.destroyAllWindows()
            points_a, points_b = select_points(first_frame)
            with open(points_file, 'w') as f:
                np.savetxt(f, points_a, header="Object A Points")
                np.savetxt(f, points_b, header="Object B Points")
            old_points = (points_a, points_b)

    # Step 2: Process frames for each segment
    for segment in segments:
        segment_dir = os.path.join(base_dir, segment)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if not os.path.exists(output_done_file):
            print(f"Processing segment {segment}")
            points_file = os.path.join(segment_dir, "segment_points")
            points = np.loadtxt(points_file, comments="#")
            points_a = points[:2]
            points_b = points[2:]
            input_frames_dir = os.path.join(segment_dir, "frames")
            output_frames_dir = os.path.join(segment_dir, "output_frames")
            os.makedirs(output_frames_dir, exist_ok=True)
            first_frame, frame_names = load_first_frame(input_frames_dir)
            predictor = initialize_predictor()
            inference_state = predictor.init_state(video_path=input_frames_dir, async_loading_frames=True)
            out_mask_logits_a = add_points_to_predictor(predictor, inference_state, points_a, obj_id=1)
            out_mask_logits_b = add_points_to_predictor(predictor, inference_state, points_b, obj_id=2)
            masks_a = [(out_mask_logits_a[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_a))]
            masks_b = [(out_mask_logits_b[i] > 0.0).cpu().numpy() for i in range(len(out_mask_logits_b))]
            #confirm_masks(first_frame, masks_a, masks_b)
            video_segments = propagate_masks(predictor, inference_state)
            predictor.reset_state(inference_state)
            process_and_save_frames(input_frames_dir, output_frames_dir, frame_names, video_segments)
            del inference_state
            del video_segments
            del predictor
            gc.collect()
            open(output_done_file, 'a').close()


if __name__ == "__main__":
    main()
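The segment_points file written above is simply two np.savetxt blocks appended to a single text file, one per object, each preceded by a "#" header line; reading it back with comments="#" therefore yields one (4, 2) array whose first two rows are Object A's clicks and last two rows are Object B's, which is what validate_points() and main() rely on. Below is a minimal round-trip sketch of that format; the literal coordinates and the example filename are only illustrative.

import numpy as np

# Two positive clicks per object, as produced by select_points() (illustrative values).
points_a = np.array([[100.0, 200.0], [150.0, 250.0]], dtype=np.float32)
points_b = np.array([[400.0, 220.0], [450.0, 270.0]], dtype=np.float32)

# Write both blocks into one file, exactly as the script does.
with open("segment_points_example", "w") as f:
    np.savetxt(f, points_a, header="Object A Points")
    np.savetxt(f, points_b, header="Object B Points")

# np.loadtxt skips the "# ..." header lines, leaving a (4, 2) float array.
points = np.loadtxt("segment_points_example", comments="#")
assert points.shape == (4, 2)
points_a_loaded, points_b_loaded = points[:2], points[2:]
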
--- second script start:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor


# the PNG palette for DAVIS 2017 dataset
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"


def load_ann_png(path):
    """Load a PNG file as a mask and its palette."""
    mask = Image.open(path)
    palette = mask.getpalette()
    mask = np.array(mask).astype(np.uint8)
    return mask, palette


def save_ann_png(path, mask, palette):
    """Save a mask as a PNG file with the given palette."""
    assert mask.dtype == np.uint8
    assert mask.ndim == 2
    output_mask = Image.fromarray(mask)
    output_mask.putpalette(palette)
    output_mask.save(path)


def get_per_obj_mask(mask):
    """Split a mask into per-object masks."""
    object_ids = np.unique(mask)
    object_ids = object_ids[object_ids > 0].tolist()
    per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
    return per_obj_mask


def put_per_obj_mask(per_obj_mask, height, width):
    """Combine per-object masks into a single mask."""
    mask = np.zeros((height, width), dtype=np.uint8)
    object_ids = sorted(per_obj_mask)[::-1]
    for object_id in object_ids:
        object_mask = per_obj_mask[object_id]
        object_mask = object_mask.reshape(height, width)
        mask[object_mask] = object_id
    return mask


def load_masks_from_dir(
    input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
):
    """Load masks from a directory as a dict of per-object masks."""
    if not per_obj_png_file:
        input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
        if allow_missing and not os.path.exists(input_mask_path):
            return {}, None
        input_mask, input_palette = load_ann_png(input_mask_path)
        per_obj_input_mask = get_per_obj_mask(input_mask)
    else:
        per_obj_input_mask = {}
        input_palette = None
        # each object is a directory in "{object_id:%03d}" format
        for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
            object_id = int(object_name)
            input_mask_path = os.path.join(
                input_mask_dir, video_name, object_name, f"{frame_name}.png"
            )
            if allow_missing and not os.path.exists(input_mask_path):
                continue
            input_mask, input_palette = load_ann_png(input_mask_path)
            per_obj_input_mask[object_id] = input_mask > 0

    return per_obj_input_mask, input_palette


def save_masks_to_dir(
    output_mask_dir,
    video_name,
    frame_name,
    per_obj_output_mask,
    height,
    width,
    per_obj_png_file,
    output_palette,
):
    """Save masks to a directory as PNG files."""
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    if not per_obj_png_file:
        output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
        output_mask_path = os.path.join(
            output_mask_dir, video_name, f"{frame_name}.png"
        )
        save_ann_png(output_mask_path, output_mask, output_palette)
    else:
        for object_id, object_mask in per_obj_output_mask.items():
            object_name = f"{object_id:03d}"
            os.makedirs(
                os.path.join(output_mask_dir, video_name, object_name),
                exist_ok=True,
            )
            output_mask = object_mask.reshape(height, width).astype(np.uint8)
            output_mask_path = os.path.join(
                output_mask_dir, video_name, object_name, f"{frame_name}.png"
            )
            save_ann_png(output_mask_path, output_mask, output_palette)


@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_inference(
    predictor,
    base_video_dir,
    input_mask_dir,
    output_mask_dir,
    video_name,
    score_thresh=0.0,
    use_all_masks=False,
    per_obj_png_file=False,
):
    """Run VOS inference on a single video with the given predictor."""
    # load the video frames and initialize the inference state on this video
    video_dir = os.path.join(base_video_dir, video_name)
    frame_names = [
        os.path.splitext(p)[0]
        for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    inference_state = predictor.init_state(
        video_path=video_dir, async_loading_frames=False
    )
    height = inference_state["video_height"]
    width = inference_state["video_width"]
    input_palette = None

    # fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
    if not use_all_masks:
        # use only the first video's ground-truth mask as the input mask
        input_frame_inds = [0]
    else:
        # use all mask files available in the input_mask_dir as the input masks
        if not per_obj_png_file:
            input_frame_inds = [
                idx
                for idx, name in enumerate(frame_names)
                if os.path.exists(
                    os.path.join(input_mask_dir, video_name, f"{name}.png")
                )
            ]
        else:
            input_frame_inds = [
                idx
                for object_name in os.listdir(os.path.join(input_mask_dir, video_name))
                for idx, name in enumerate(frame_names)
                if os.path.exists(
                    os.path.join(input_mask_dir, video_name, object_name, f"{name}.png")
                )
            ]
        # check and make sure we got at least one input frame
        if len(input_frame_inds) == 0:
            raise RuntimeError(
                f"In {video_name=}, got no input masks in {input_mask_dir=}. "
                "Please make sure the input masks are available in the correct format."
            )
        input_frame_inds = sorted(set(input_frame_inds))

    # add those input masks to SAM 2 inference state before propagation
    object_ids_set = None
    for input_frame_idx in input_frame_inds:
        try:
            per_obj_input_mask, input_palette = load_masks_from_dir(
                input_mask_dir=input_mask_dir,
                video_name=video_name,
                frame_name=frame_names[input_frame_idx],
                per_obj_png_file=per_obj_png_file,
            )
        except FileNotFoundError as e:
            raise RuntimeError(
                f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
                "Please add the `--track_object_appearing_later_in_video` flag "
                "for VOS datasets that don't have all objects to track appearing "
                "in the first frame (such as LVOS or YouTube-VOS)."
            ) from e
        # get the list of object ids to track from the first input frame
        if object_ids_set is None:
            object_ids_set = set(per_obj_input_mask)
        for object_id, object_mask in per_obj_input_mask.items():
            # check and make sure no new object ids appear only in later frames
            if object_id not in object_ids_set:
                raise RuntimeError(
                    f"In {video_name=}, got a new {object_id=} appearing only in a "
                    f"later {input_frame_idx=} (but not appearing in the first frame). "
                    "Please add the `--track_object_appearing_later_in_video` flag "
                    "for VOS datasets that don't have all objects to track appearing "
                    "in the first frame (such as LVOS or YouTube-VOS)."
                )
            predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=input_frame_idx,
                obj_id=object_id,
                mask=object_mask,
            )

    # check and make sure we have at least one object to track
    if object_ids_set is None or len(object_ids_set) == 0:
        raise RuntimeError(
            f"In {video_name=}, got no object ids on {input_frame_inds=}. "
            "Please add the `--track_object_appearing_later_in_video` flag "
            "for VOS datasets that don't have all objects to track appearing "
            "in the first frame (such as LVOS or YouTube-VOS)."
        )
    # run propagation throughout the video and collect the results in a dict
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    output_palette = input_palette or DAVIS_PALETTE
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        inference_state
    ):
        per_obj_output_mask = {
            out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
        video_segments[out_frame_idx] = per_obj_output_mask

    # write the output masks as palette PNG files to output_mask_dir
    for out_frame_idx, per_obj_output_mask in video_segments.items():
        save_masks_to_dir(
            output_mask_dir=output_mask_dir,
            video_name=video_name,
            frame_name=frame_names[out_frame_idx],
            per_obj_output_mask=per_obj_output_mask,
            height=height,
            width=width,
            per_obj_png_file=per_obj_png_file,
            output_palette=output_palette,
        )


@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_separate_inference_per_object(
    predictor,
    base_video_dir,
    input_mask_dir,
    output_mask_dir,
    video_name,
    score_thresh=0.0,
    use_all_masks=False,
    per_obj_png_file=False,
):
    """
    Run VOS inference on a single video with the given predictor.

    Unlike `vos_inference`, this function run inference separately for each object
    in a video, which could be applied to datasets like LVOS or YouTube-VOS that
    don't have all objects to track appearing in the first frame (i.e. some objects
    might appear only later in the video).
    """
    # load the video frames and initialize the inference state on this video
    video_dir = os.path.join(base_video_dir, video_name)
    frame_names = [
        os.path.splitext(p)[0]
        for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    inference_state = predictor.init_state(
        video_path=video_dir, async_loading_frames=False
    )
    height = inference_state["video_height"]
    width = inference_state["video_width"]
    input_palette = None

    # collect all the object ids and their input masks
    inputs_per_object = defaultdict(dict)
    for idx, name in enumerate(frame_names):
        if per_obj_png_file or os.path.exists(
            os.path.join(input_mask_dir, video_name, f"{name}.png")
        ):
            per_obj_input_mask, input_palette = load_masks_from_dir(
                input_mask_dir=input_mask_dir,
                video_name=video_name,
                frame_name=frame_names[idx],
                per_obj_png_file=per_obj_png_file,
                allow_missing=True,
            )
            for object_id, object_mask in per_obj_input_mask.items():
                # skip empty masks
                if not np.any(object_mask):
                    continue
                # if `use_all_masks=False`, we only use the first mask for each object
                if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
                    continue
                print(f"adding mask from frame {idx} as input for {object_id=}")
                inputs_per_object[object_id][idx] = object_mask

    # run inference separately for each object in the video
    object_ids = sorted(inputs_per_object)
    output_scores_per_object = defaultdict(dict)
    for object_id in object_ids:
        # add those input masks to SAM 2 inference state before propagation
        input_frame_inds = sorted(inputs_per_object[object_id])
        predictor.reset_state(inference_state)
        for input_frame_idx in input_frame_inds:
            predictor.add_new_mask(
                inference_state=inference_state,
                frame_idx=input_frame_idx,
                obj_id=object_id,
                mask=inputs_per_object[object_id][input_frame_idx],
            )

        # run propagation throughout the video and collect the results in a dict
        for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
            inference_state,
            start_frame_idx=min(input_frame_inds),
            reverse=False,
        ):
            obj_scores = out_mask_logits.cpu().numpy()
            output_scores_per_object[object_id][out_frame_idx] = obj_scores

    # post-processing: consolidate the per-object scores into per-frame masks
    os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
    output_palette = input_palette or DAVIS_PALETTE
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for frame_idx in range(len(frame_names)):
        scores = torch.full(
            size=(len(object_ids), 1, height, width),
            fill_value=-1024.0,
            dtype=torch.float32,
        )
        for i, object_id in enumerate(object_ids):
            if frame_idx in output_scores_per_object[object_id]:
                scores[i] = torch.from_numpy(
                    output_scores_per_object[object_id][frame_idx]
                )

        if not per_obj_png_file:
            scores = predictor._apply_non_overlapping_constraints(scores)
        per_obj_output_mask = {
            object_id: (scores[i] > score_thresh).cpu().numpy()
            for i, object_id in enumerate(object_ids)
        }
        video_segments[frame_idx] = per_obj_output_mask

    # write the output masks as palette PNG files to output_mask_dir
    for frame_idx, per_obj_output_mask in video_segments.items():
        save_masks_to_dir(
            output_mask_dir=output_mask_dir,
            video_name=video_name,
            frame_name=frame_names[frame_idx],
            per_obj_output_mask=per_obj_output_mask,
            height=height,
            width=width,
            per_obj_png_file=per_obj_png_file,
            output_palette=output_palette,
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sam2_cfg",
        type=str,
        default="configs/sam2.1/sam2.1_hiera_b+.yaml",
        help="SAM 2 model configuration file",
    )
    parser.add_argument(
        "--sam2_checkpoint",
        type=str,
        default="./checkpoints/sam2.1_hiera_b+.pt",
        help="path to the SAM 2 model checkpoint",
    )
    parser.add_argument(
        "--base_video_dir",
        type=str,
        required=True,
        help="directory containing videos (as JPEG files) to run VOS prediction on",
    )
    parser.add_argument(
        "--input_mask_dir",
        type=str,
        required=True,
        help="directory containing input masks (as PNG files) of each video",
    )
    parser.add_argument(
        "--video_list_file",
        type=str,
        default=None,
        help="text file containing the list of video names to run VOS prediction on",
    )
    parser.add_argument(
        "--output_mask_dir",
        type=str,
        required=True,
        help="directory to save the output masks (as PNG files)",
    )
    parser.add_argument(
        "--score_thresh",
        type=float,
        default=0.0,
        help="threshold for the output mask logits (default: 0.0)",
    )
    parser.add_argument(
        "--use_all_masks",
        action="store_true",
        help="whether to use all available PNG files in input_mask_dir "
        "(default without this flag: just the first PNG file as input to the SAM 2 model; "
        "usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
    )
    parser.add_argument(
        "--per_obj_png_file",
        action="store_true",
        help="whether use separate per-object PNG files for input and output masks "
        "(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
        "note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
    )
    parser.add_argument(
        "--apply_postprocessing",
        action="store_true",
        help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
        "(we don't apply such post-processing in the SAM 2 model evaluation)",
    )
    parser.add_argument(
        "--track_object_appearing_later_in_video",
        action="store_true",
        help="whether to track objects that appear later in the video (i.e. not on the first frame; "
        "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
    )
    args = parser.parse_args()

    # if we use per-object PNG files, they could possibly overlap in inputs and outputs
    hydra_overrides_extra = [
        "++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
    ]
    predictor = build_sam2_video_predictor(
        config_file=args.sam2_cfg,
        ckpt_path=args.sam2_checkpoint,
        apply_postprocessing=args.apply_postprocessing,
        hydra_overrides_extra=hydra_overrides_extra,
    )

    if args.use_all_masks:
        print("using all available masks in input_mask_dir as input to the SAM 2 model")
    else:
        print(
            "using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
        )
    # if a video list file is provided, read the video names from the file
    # (otherwise, we use all subdirectories in base_video_dir)
    if args.video_list_file is not None:
        with open(args.video_list_file, "r") as f:
            video_names = [v.strip() for v in f.readlines()]
    else:
        video_names = [
            p
            for p in os.listdir(args.base_video_dir)
            if os.path.isdir(os.path.join(args.base_video_dir, p))
        ]
    print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")

    for n_video, video_name in enumerate(video_names):
        print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
        if not args.track_object_appearing_later_in_video:
            vos_inference(
                predictor=predictor,
                base_video_dir=args.base_video_dir,
                input_mask_dir=args.input_mask_dir,
                output_mask_dir=args.output_mask_dir,
                video_name=video_name,
                score_thresh=args.score_thresh,
                use_all_masks=args.use_all_masks,
                per_obj_png_file=args.per_obj_png_file,
            )
        else:
            vos_separate_inference_per_object(
                predictor=predictor,
                base_video_dir=args.base_video_dir,
                input_mask_dir=args.input_mask_dir,
                output_mask_dir=args.output_mask_dir,
                video_name=video_name,
                score_thresh=args.score_thresh,
                use_all_masks=args.use_all_masks,
                per_obj_png_file=args.per_obj_png_file,
            )

    print(
        f"completed VOS prediction on {len(video_names)} videos -- "
        f"output masks saved to {args.output_mask_dir}"
    )


if __name__ == "__main__":
    main()
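
The second script is normally driven from the command line through the argparse options above, but its entry points can also be called directly. Below is a minimal sketch of programmatic use, assuming the script above is saved as vos_inference.py and run on a CUDA-capable machine (its decorators use CUDA autocast); the directory paths and video name are placeholders, not values from this document.

from sam2.build_sam import build_sam2_video_predictor
from vos_inference import vos_inference  # hypothetical filename for the script above

# Placeholder checkpoint/config paths -- substitute your own.
predictor = build_sam2_video_predictor(
    config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
    ckpt_path="./checkpoints/sam2.1_hiera_b+.pt",
    hydra_overrides_extra=["++model.non_overlap_masks=true"],
)

# Semi-supervised VOS on one video: JPEG frames in <base_video_dir>/<video_name>/
# with integer-parseable names (the script sorts them numerically), and a
# DAVIS-style palette PNG for the first frame in <input_mask_dir>/<video_name>/.
vos_inference(
    predictor=predictor,
    base_video_dir="./videos",        # placeholder
    input_mask_dir="./input_masks",   # placeholder
    output_mask_dir="./output_masks", # placeholder
    video_name="my_video",            # placeholder
    score_thresh=0.0,
)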