From 535a984cba9404ce33292b507a132fd6faf882b6 Mon Sep 17 00:00:00 2001 From: Scott Register Date: Mon, 21 Oct 2024 16:56:11 -0700 Subject: [PATCH] mask to png --- notebooks/foo_points_prev.py | 44 ++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/notebooks/foo_points_prev.py b/notebooks/foo_points_prev.py index 4e1e56e..7c31c90 100644 --- a/notebooks/foo_points_prev.py +++ b/notebooks/foo_points_prev.py @@ -42,8 +42,10 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Variables for input and output directories SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt" MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml" +GREEN = [0, 255, 0] +BLUE = [255, 0, 0] -INFERENCE_SCALE = 0.35 +INFERENCE_SCALE = 0.25 FULL_SCALE = 1.0 def open_video(video_path): @@ -68,12 +70,21 @@ def open_video(video_path): cap.release() def load_previous_segment_mask(prev_segment_dir): - mask_path = os.path.join(prev_segment_dir, "mask.jpg") + mask_path = os.path.join(prev_segment_dir, "mask.png") mask_image = cv2.imread(mask_path) + + if mask_image is None: + raise FileNotFoundError(f"Mask image not found at {mask_path}") + + # Ensure the mask_image has three color channels + if len(mask_image.shape) != 3 or mask_image.shape[2] != 3: + raise ValueError("Mask image does not have three color channels.") + + mask_image = mask_image.astype(np.uint8) # Extract Object A and Object B masks - mask_a = (mask_image[:, :, 1] == 255) # Green channel - mask_b = (mask_image[:, :, 0] == 254) # Blue channel + mask_a = np.all(mask_image == GREEN, axis=2) + mask_b = np.all(mask_image == BLUE, axis=2) per_obj_input_mask = {1: mask_a, 2: mask_b} input_palette = None # No palette needed for binary mask @@ -110,6 +121,8 @@ def apply_green_mask(frame, masks): interpolation=cv2.INTER_LINEAR ) # Threshold the resized mask to obtain a boolean mask + # add a small gaussian blur to the mask to smooth out the edges (ksize must be odd for cv2.GaussianBlur) + resized_mask = cv2.GaussianBlur(resized_mask, (51, 51), 
0) mask = resized_mask > 0.5 else: @@ -193,7 +206,7 @@ def select_points(first_frame): points_a.append((x, y)) point_count += 1 print(f"Selected point {point_count} for Object A: ({x}, {y})") - if len(points_a) == 4: # Collect 4 points for Object A + if len(points_a) == 5: # Collect 5 points for Object A current_object = 'B' point_count = 0 print("Select point 1 for Object B") @@ -201,7 +214,7 @@ points_b.append((x, y)) point_count += 1 print(f"Selected point {point_count} for Object B: ({x}, {y})") - if len(points_b) == 4: # Collect 4 points for Object B + if len(points_b) == 5: # Collect 5 points for Object B selection_complete = True print("Select point 1 for Object A") @@ -217,7 +230,7 @@ return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32) def add_points_to_predictor(predictor, inference_state, points, obj_id): - labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points + labels = np.array([1, 1, 1, 1, 1], np.int32) # Update labels to match 5 points points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2) try: print(f"Adding points for Object {obj_id}: {points}") @@ -376,11 +389,11 @@ def save_final_masks(video_segments, mask_output_path): # mask a mask_a = mask_a.squeeze() indices = np.where(mask_a) - black_frame[mask_a] = [0, 255, 0] # Green for Object A + black_frame[mask_a] = GREEN # mask b mask_b = mask_b.squeeze() indices = np.where(mask_b) - black_frame[mask_b] = [255, 0, 0] # Green for Object B + black_frame[mask_b] = BLUE # Save the mask image cv2.imwrite(mask_output_path, black_frame) @@ -410,11 +423,12 @@ def create_low_res_video(input_video_path, output_video_path, scale): def main(): parser = argparse.ArgumentParser(description="Process video segments.") + # arg for setting base_dir (required: main() lists segments under it) + parser.add_argument("--base-dir", type=str, required=True, help="Base directory for video segments.") parser.add_argument("--segments-collect-points", 
nargs='+', type=int, help="Segments for which to collect points.") args = parser.parse_args() - base_dir = "./freya_short_segments" - #base_dir = "./606-short_segments" + base_dir = args.base_dir segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")] segments.sort(key=lambda x: int(x.split("_")[1])) scaled_frames_dir_name = "frames_scaled" @@ -455,8 +469,8 @@ def main(): if os.path.exists(points_file): logger.info(f"Using segment_points for segment {segment}") points = np.loadtxt(points_file, comments="#") - points_a = points[:4] - points_b = points[4:] + points_a = points[:5] + points_b = points[5:] else: points_a = points_b = None @@ -466,7 +480,7 @@ def main(): # Try to load previous segment mask logger.info(f"Using previous segment mask for segment {segment}") prev_segment_dir = os.path.join(base_dir, segments[i - 1]) - prev_mask_path = os.path.join(prev_segment_dir, "mask.jpg") + prev_mask_path = os.path.join(prev_segment_dir, "mask.png") if os.path.exists(prev_mask_path): per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir) # Add previous masks to predictor @@ -499,7 +513,7 @@ def main(): ) # Save final masks - mask_output_path = os.path.join(segment_dir, "mask.jpg") + mask_output_path = os.path.join(segment_dir, "mask.png") save_final_masks(video_segments, mask_output_path) # Clean up