mask to png

This commit is contained in:
2024-10-21 16:56:11 -07:00
parent 4a8a0753cc
commit 535a984cba

View File

@@ -42,8 +42,10 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Variables for input and output directories # Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml" MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]
INFERENCE_SCALE = 0.35 INFERENCE_SCALE = 0.25
FULL_SCALE = 1.0 FULL_SCALE = 1.0
def open_video(video_path): def open_video(video_path):
@@ -68,12 +70,21 @@ def open_video(video_path):
cap.release() cap.release()
def load_previous_segment_mask(prev_segment_dir): def load_previous_segment_mask(prev_segment_dir):
mask_path = os.path.join(prev_segment_dir, "mask.jpg") mask_path = os.path.join(prev_segment_dir, "mask.png")
mask_image = cv2.imread(mask_path) mask_image = cv2.imread(mask_path)
if mask_image is None:
raise FileNotFoundError(f"Mask image not found at {mask_path}")
# Ensure the mask_image has three color channels
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
raise ValueError("Mask image does not have three color channels.")
mask_image = mask_image.astype(np.uint8)
# Extract Object A and Object B masks # Extract Object A and Object B masks
mask_a = (mask_image[:, :, 1] == 255) # Green channel mask_a = np.all(mask_image == GREEN, axis=2)
mask_b = (mask_image[:, :, 0] == 254) # Blue channel mask_b = np.all(mask_image == BLUE, axis=2)
per_obj_input_mask = {1: mask_a, 2: mask_b} per_obj_input_mask = {1: mask_a, 2: mask_b}
input_palette = None # No palette needed for binary mask input_palette = None # No palette needed for binary mask
@@ -110,6 +121,8 @@ def apply_green_mask(frame, masks):
interpolation=cv2.INTER_LINEAR interpolation=cv2.INTER_LINEAR
) )
# Threshold the resized mask to obtain a boolean mask # Threshold the resized mask to obtain a boolean mask
# add a small gausian blur to the mask to smooth out the edges
resized_mask = cv2.GaussianBlur(resized_mask, (50, 50), 0)
mask = resized_mask > 0.5 mask = resized_mask > 0.5
else: else:
@@ -193,7 +206,7 @@ def select_points(first_frame):
points_a.append((x, y)) points_a.append((x, y))
point_count += 1 point_count += 1
print(f"Selected point {point_count} for Object A: ({x}, {y})") print(f"Selected point {point_count} for Object A: ({x}, {y})")
if len(points_a) == 4: # Collect 4 points for Object A if len(points_a) == 5: # Collect 4 points for Object A
current_object = 'B' current_object = 'B'
point_count = 0 point_count = 0
print("Select point 1 for Object B") print("Select point 1 for Object B")
@@ -201,7 +214,7 @@ def select_points(first_frame):
points_b.append((x, y)) points_b.append((x, y))
point_count += 1 point_count += 1
print(f"Selected point {point_count} for Object B: ({x}, {y})") print(f"Selected point {point_count} for Object B: ({x}, {y})")
if len(points_b) == 4: # Collect 4 points for Object B if len(points_b) == 5: # Collect 4 points for Object B
selection_complete = True selection_complete = True
print("Select point 1 for Object A") print("Select point 1 for Object A")
@@ -217,7 +230,7 @@ def select_points(first_frame):
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32) return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
def add_points_to_predictor(predictor, inference_state, points, obj_id): def add_points_to_predictor(predictor, inference_state, points, obj_id):
labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points labels = np.array([1, 1, 1, 1, 1], np.int32) # Update labels to match 4 points
points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2) points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2)
try: try:
print(f"Adding points for Object {obj_id}: {points}") print(f"Adding points for Object {obj_id}: {points}")
@@ -376,11 +389,11 @@ def save_final_masks(video_segments, mask_output_path):
# mask a # mask a
mask_a = mask_a.squeeze() mask_a = mask_a.squeeze()
indices = np.where(mask_a) indices = np.where(mask_a)
black_frame[mask_a] = [0, 255, 0] # Green for Object A black_frame[mask_a] = GREEN
# mask b # mask b
mask_b = mask_b.squeeze() mask_b = mask_b.squeeze()
indices = np.where(mask_b) indices = np.where(mask_b)
black_frame[mask_b] = [255, 0, 0] # Green for Object B black_frame[mask_b] = BLUE
# Save the mask image # Save the mask image
cv2.imwrite(mask_output_path, black_frame) cv2.imwrite(mask_output_path, black_frame)
@@ -410,11 +423,12 @@ def create_low_res_video(input_video_path, output_video_path, scale):
def main(): def main():
parser = argparse.ArgumentParser(description="Process video segments.") parser = argparse.ArgumentParser(description="Process video segments.")
# arg for setting base_dir
parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.") parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
args = parser.parse_args() args = parser.parse_args()
base_dir = "./freya_short_segments" base_dir = args.base_dir
#base_dir = "./606-short_segments"
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")] segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1])) segments.sort(key=lambda x: int(x.split("_")[1]))
scaled_frames_dir_name = "frames_scaled" scaled_frames_dir_name = "frames_scaled"
@@ -455,8 +469,8 @@ def main():
if os.path.exists(points_file): if os.path.exists(points_file):
logger.info(f"Using segment_points for segment {segment}") logger.info(f"Using segment_points for segment {segment}")
points = np.loadtxt(points_file, comments="#") points = np.loadtxt(points_file, comments="#")
points_a = points[:4] points_a = points[:5]
points_b = points[4:] points_b = points[5:]
else: else:
points_a = points_b = None points_a = points_b = None
@@ -466,7 +480,7 @@ def main():
# Try to load previous segment mask # Try to load previous segment mask
logger.info(f"Using previous segment mask for segment {segment}") logger.info(f"Using previous segment mask for segment {segment}")
prev_segment_dir = os.path.join(base_dir, segments[i - 1]) prev_segment_dir = os.path.join(base_dir, segments[i - 1])
prev_mask_path = os.path.join(prev_segment_dir, "mask.jpg") prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
if os.path.exists(prev_mask_path): if os.path.exists(prev_mask_path):
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir) per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
# Add previous masks to predictor # Add previous masks to predictor
@@ -499,7 +513,7 @@ def main():
) )
# Save final masks # Save final masks
mask_output_path = os.path.join(segment_dir, "mask.jpg") mask_output_path = os.path.join(segment_dir, "mask.png")
save_final_masks(video_segments, mask_output_path) save_final_masks(video_segments, mask_output_path)
# Clean up # Clean up