# This script processes multiple video segments (5-second clips) of a long
# (1 hour total) video and greenscreens/keys out everything except the
# tracked objects.
#
# To specify which items are tracked, the user calls the script with
# --segments-collect-points 1 5 10 ...  These segments are treated as
# "keyframes": for each one we ask the user to select the points to track
# by hand. This is done for segments with large changes in the objects'
# position in the video.
#
# For the other segments, since the objects are mostly static in the frame,
# we use the previous segment's final frame (an image with the tracked
# objects visible and everything else green) to derive an input mask and
# call add_new_mask() instead of selecting points.
#
# Each segment has two versions of every frame: one high quality, used for
# the final render, and one low quality, used to speed up inference.
#
# When the script finishes, each segment directory should contain an output
# video with the same objects tracked throughout every frame.
#
# I will then turn these back into a video using ffmpeg, but that is outside
# the scope of this program.

import argparse
import gc
import logging
import os
import subprocess
import sys

import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Model paths and rendering constants
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]  # BGR
BLUE = [255, 0, 0]   # BGR
INFERENCE_SCALE = 0.25
FULL_SCALE = 1.0


def open_video(video_path):
    """
    Opens a video file and returns a generator that yields frames.

    Parameters:
    - video_path: Path to the video file.

    Returns:
    - A generator that yields frames from the video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Error: Could not open video file {video_path}")
        return
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        yield frame
    cap.release()


def load_previous_segment_mask(prev_segment_dir):
    """
    Loads the previous segment's final mask image (mask.png) and decodes it
    into per-object boolean masks keyed by object ID.
    """
    mask_path = os.path.join(prev_segment_dir, "mask.png")
    mask_image = cv2.imread(mask_path)
    if mask_image is None:
        raise FileNotFoundError(f"Mask image not found at {mask_path}")
    # Ensure the mask image has three color channels
    if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
        raise ValueError("Mask image does not have three color channels.")
    mask_image = mask_image.astype(np.uint8)
    # Extract Object A (green pixels) and Object B (blue pixels)
    mask_a = np.all(mask_image == GREEN, axis=2)
    mask_b = np.all(mask_image == BLUE, axis=2)
    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # No palette needed for a binary mask
    return per_obj_input_mask, input_palette
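
# Note on the exact-color test above: it only works because mask.png is
# written by save_final_masks() as a PNG, which is lossless, so the GREEN and
# BLUE pixel values survive the disk round-trip exactly. A lossy format such
# as JPEG would perturb pixel values and silently produce empty masks.
# A minimal sketch of the decode step on a synthetic image (values made up
# for illustration):
#
#   img = np.zeros((2, 2, 3), dtype=np.uint8)
#   img[0, 0] = GREEN
#   img[1, 1] = BLUE
#   mask_a = np.all(img == GREEN, axis=2)  # [[True, False], [False, False]]
#   mask_b = np.all(img == BLUE, axis=2)   # [[False, False], [False, True]]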
""" # Initialize combined mask as a boolean array combined_mask = np.zeros(frame.shape[:2], dtype=bool) for mask in masks: mask = mask.squeeze() # Resize the mask if necessary if mask.shape != frame.shape[:2]: # Resize the mask using bilinear interpolation # and convert it to float32 for accurate interpolation resized_mask = cv2.resize( mask.astype(np.float32), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR ) # Threshold the resized mask to obtain a boolean mask # add a small gausian blur to the mask to smooth out the edges resized_mask = cv2.GaussianBlur(resized_mask, (50, 50), 0) mask = resized_mask > 0.5 else: # Ensure mask is boolean mask = mask.astype(bool) # Combine masks using logical OR combined_mask |= mask # Now both arrays are bool # Create a green background image green_background = np.full_like(frame, [0, 255, 0]) # Use combined mask to overlay the original frame onto the green background result_frame = np.where( combined_mask[..., None], frame, green_background ) return result_frame def initialize_predictor(): if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") print( "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might " "give numerically different outputs and sometimes degraded performance on MPS." ) # Enable MPS fallback for operations not supported on MPS os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" else: device = torch.device("cpu") logger.info(f"Using device: {device}") predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device) return predictor def load_first_frame(video_path, scale=1.0): """ Opens a video file and returns the first frame, scaled as specified. Parameters: - video_path: Path to the video file. - scale: Scaling factor for the frame (default is 1.0 for original size). Returns: - first_frame: The first frame of the video, scaled accordingly. 
""" cap = cv2.VideoCapture(video_path) if not cap.isOpened(): logger.error(f"Error: Could not open video file {video_path}") return None ret, frame = cap.read() cap.release() if not ret: logger.error(f"Error: Could not read frame from video file {video_path}") return None if scale != 1.0: frame = cv2.resize( frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR ) return frame def select_points(first_frame): points_a = [] points_b = [] current_object = 'A' point_count = 0 selection_complete = False def mouse_callback(event, x, y, flags, param): nonlocal points_a, points_b, current_object, point_count, selection_complete if event == cv2.EVENT_LBUTTONDOWN: if current_object == 'A': points_a.append((x, y)) point_count += 1 print(f"Selected point {point_count} for Object A: ({x}, {y})") if len(points_a) == 5: # Collect 4 points for Object A current_object = 'B' point_count = 0 print("Select point 1 for Object B") elif current_object == 'B': points_b.append((x, y)) point_count += 1 print(f"Selected point {point_count} for Object B: ({x}, {y})") if len(points_b) == 5: # Collect 4 points for Object B selection_complete = True print("Select point 1 for Object A") cv2.namedWindow('Select Points', cv2.WINDOW_NORMAL) cv2.resizeWindow('Select Points', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500) cv2.imshow('Select Points', first_frame) cv2.setMouseCallback('Select Points', mouse_callback) while not selection_complete: cv2.waitKey(1) cv2.destroyAllWindows() return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32) def add_points_to_predictor(predictor, inference_state, points, obj_id): labels = np.array([1, 1, 1, 1, 1], np.int32) # Update labels to match 4 points points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2) try: print(f"Adding points for Object {obj_id}: {points}") _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( inference_state=inference_state, frame_idx=0, obj_id=obj_id, points=points, labels=labels, ) print(f"Object {obj_id} added successfully: {out_obj_ids}") return out_mask_logits except Exception as e: print(f"Error adding points for Object {obj_id}: {e}") exit() def propagate_masks(predictor, inference_state): video_segments = {} for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } return video_segments def apply_colored_mask(frame, masks_a, masks_b): colored_mask = np.zeros_like(frame) # Apply colors to the masks for mask in masks_a: mask = mask.squeeze() if mask.shape != frame.shape[:2]: mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) indices = np.where(mask) colored_mask[mask] = [0, 255, 0] # Green for Object A for mask in masks_b: mask = mask.squeeze() if mask.shape != frame.shape[:2]: mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) indices = np.where(mask) colored_mask[mask] = [255, 0, 0] # Blue for Object B return colored_mask def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False): """ Process high-resolution frames, apply upscaled masks, and save the output video. 
""" cap = cv2.VideoCapture(video_path) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) or 59.94 # Setup VideoWriter with desired settings if use_nvenc: # Use FFmpeg with NVENC offloading for H.265 encoding import subprocess if sys.platform == 'darwin': encoder = 'hevc_videotoolbox' else: encoder = 'hevc_nvenc' command = [ 'ffmpeg', '-y', # Overwrite output file if it exists '-f', 'rawvideo', '-vcodec', 'rawvideo', '-pix_fmt', 'bgr24', '-s', f'{frame_width}x{frame_height}', '-r', str(fps), '-i', '-', # Input from stdin '-an', # No audio '-vcodec', encoder, '-pix_fmt', 'yuv420p', '-preset', 'slow', '-b:v', '50M', output_video_path ] process = subprocess.Popen(command, stdin=subprocess.PIPE) else: # Use OpenCV VideoWriter fourcc = cv2.VideoWriter_fourcc(*'HEVC') # H.265 out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height)) frame_idx = 0 while True: ret, frame = cap.read() if not ret or frame_idx >= len(video_segments): break masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]] upscaled_masks = [] for mask in masks: mask = mask.squeeze() upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) upscaled_masks.append(upscaled_mask) result_frame = apply_green_mask(frame, upscaled_masks) # Write frame to output if use_nvenc: process.stdin.write(result_frame.tobytes()) else: out.write(result_frame) frame_idx += 1 cap.release() if use_nvenc: process.stdin.close() process.wait() else: out.release() def get_video_file_name(index): return f"segment_{str(index).zfill(3)}.mp4" def do_collect_segment_points(base_dir, segments, collect_points_segments, scale=1.0): logger.info("Collecting points for requested segments.") for i, segment in enumerate(segments): segment_index = int(segment.split("_")[1]) segment_dir = os.path.join(base_dir, segment) points_file = os.path.join(segment_dir, "segment_points") video_file = os.path.join(segment_dir, get_video_file_name(i)) if segment_index in collect_points_segments and not os.path.exists(points_file): first_frame = load_first_frame(video_file, scale) points_a, points_b = select_points(first_frame) with open(points_file, 'w') as f: np.savetxt(f, points_a, header="Object A Points") np.savetxt(f, points_b, header="Object B Points") def save_final_masks(video_segments, mask_output_path): """ Save the final masks as a colored image. """ last_frame_idx = max(video_segments.keys()) masks_dict = video_segments[last_frame_idx] # Assuming you have two objects with IDs 1 and 2 mask_a = masks_dict.get(1).squeeze() mask_b = masks_dict.get(2).squeeze() #create a black image with dimensions with shape (mask_a.y, mask_a.x, 3) black_frame = np.zeros((mask_a.shape[0], mask_a.shape[1], 3), dtype=np.uint8) if mask_a is None or mask_b is None: print("Error: Masks for objects not found.") return #convert mask to np.uint8 mask_a = mask_a.astype(bool) mask_b = mask_b.astype(bool) # mask a mask_a = mask_a.squeeze() indices = np.where(mask_a) black_frame[mask_a] = GREEN # mask b mask_b = mask_b.squeeze() indices = np.where(mask_b) black_frame[mask_b] = BLUE # Save the mask image cv2.imwrite(mask_output_path, black_frame) def create_low_res_video(input_video_path, output_video_path, scale): """ Creates a low-resolution version of the input video for inference. 
""" cap = cv2.VideoCapture(input_video_path) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale) fps = cap.get(cv2.CAP_PROP_FPS) or 59.94 fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height)) while True: ret, frame = cap.read() if not ret: break low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR) out.write(low_res_frame) cap.release() out.release() def main(): parser = argparse.ArgumentParser(description="Process video segments.") # arg for setting base_dir parser.add_argument("--base-dir", type=str, help="Base directory for video segments.") parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.") args = parser.parse_args() base_dir = args.base_dir segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")] segments.sort(key=lambda x: int(x.split("_")[1])) scaled_frames_dir_name = "frames_scaled" fullres_frames_dir_name = "frames" # iwant to render the final video with these frames collect_points_segments = args.segments_collect_points if args.segments_collect_points else [] #inference_scale for getting the mask, then use full scale when rendering the video do_collect_segment_points(base_dir, segments, collect_points_segments, scale=INFERENCE_SCALE) for i, segment in enumerate(segments): segment_index = int(segment.split("_")[1]) segment_dir = os.path.join(base_dir, segment) video_file_name = get_video_file_name(i) video_path = os.path.join(segment_dir, video_file_name) output_done_file = os.path.join(segment_dir, "output_frames_done") if os.path.exists(output_done_file): print(f"Segment {segment} already processed. 
Skipping.") continue logger.info(f"Processing segment {segment}") # Initialize predictor predictor = initialize_predictor() # Prepare low-resolution video frames for inference low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4") if not os.path.exists(low_res_video_path): create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE) logger.info(f"Low-resolution video created for segment {segment}") else: logger.info(f"Low-resolution video already exists for segment {segment}, reuse") # Initialize inference state with low-resolution video inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True) # Load points or previous masks points_file = os.path.join(segment_dir, "segment_points") if os.path.exists(points_file): logger.info(f"Using segment_points for segment {segment}") points = np.loadtxt(points_file, comments="#") points_a = points[:5] points_b = points[5:] else: points_a = points_b = None collect_points = segment_index in collect_points_segments if i > 0 and not collect_points and points_a is None: # Try to load previous segment mask logger.info(f"Using previous segment mask for segment {segment}") prev_segment_dir = os.path.join(base_dir, segments[i - 1]) prev_mask_path = os.path.join(prev_segment_dir, "mask.png") if os.path.exists(prev_mask_path): per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir) # Add previous masks to predictor for obj_id, mask in per_obj_input_mask.items(): predictor.add_new_mask(inference_state, 0, obj_id, mask) else: print(f"Warning: Previous segment mask not found for segment {segment}.") continue # Skip this segment or handle as needed else: if points_a is not None and points_b is not None: print("Using points for segment") # Add points to predictor add_points_to_predictor(predictor, inference_state, points_a, obj_id=1) add_points_to_predictor(predictor, inference_state, points_b, obj_id=2) else: print("Error: No points available for segment.") continue # Skip this segment # Perform inference and collect masks per frame video_segments = propagate_masks(predictor, inference_state) # Process high-resolution frames and save output video output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4") print("Processing of segment complete, attempting to save process full video from low res masks") process_and_save_output_video( video_path, output_video_path, video_segments, use_nvenc=True # Set to True to use NVENC offloading ) # Save final masks mask_output_path = os.path.join(segment_dir, "mask.png") save_final_masks(video_segments, mask_output_path) # Clean up predictor.reset_state(inference_state) del inference_state del video_segments del predictor gc.collect() try: os.remove(low_res_video_path) logger.info(f"Deleted low-resolution video for segment {segment}") except Exception as e: logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}") open(output_done_file, 'a').close() print("Processing complete.") if __name__ == "__main__": main()