mask to png

This commit is contained in:
2024-10-21 16:56:11 -07:00
parent 4a8a0753cc
commit 535a984cba

View File

@@ -42,8 +42,10 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]
INFERENCE_SCALE = 0.35
INFERENCE_SCALE = 0.25
FULL_SCALE = 1.0
def open_video(video_path):
@@ -68,12 +70,21 @@ def open_video(video_path):
cap.release()
def load_previous_segment_mask(prev_segment_dir):
mask_path = os.path.join(prev_segment_dir, "mask.jpg")
mask_path = os.path.join(prev_segment_dir, "mask.png")
mask_image = cv2.imread(mask_path)
if mask_image is None:
raise FileNotFoundError(f"Mask image not found at {mask_path}")
# Ensure the mask_image has three color channels
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
raise ValueError("Mask image does not have three color channels.")
mask_image = mask_image.astype(np.uint8)
# Extract Object A and Object B masks
mask_a = (mask_image[:, :, 1] == 255) # Green channel
mask_b = (mask_image[:, :, 0] == 254) # Blue channel
mask_a = np.all(mask_image == GREEN, axis=2)
mask_b = np.all(mask_image == BLUE, axis=2)
per_obj_input_mask = {1: mask_a, 2: mask_b}
input_palette = None # No palette needed for binary mask
@@ -110,6 +121,8 @@ def apply_green_mask(frame, masks):
interpolation=cv2.INTER_LINEAR
)
# Threshold the resized mask to obtain a boolean mask
# Apply a small Gaussian blur to the mask to smooth out the edges
# NOTE(review): cv2.GaussianBlur requires odd kernel dimensions — (50, 50) is even
# and will raise an error at runtime; this should be e.g. (51, 51).
resized_mask = cv2.GaussianBlur(resized_mask, (50, 50), 0)
mask = resized_mask > 0.5
else:
@@ -193,7 +206,7 @@ def select_points(first_frame):
points_a.append((x, y))
point_count += 1
print(f"Selected point {point_count} for Object A: ({x}, {y})")
if len(points_a) == 4: # Collect 4 points for Object A
if len(points_a) == 5: # Collect 5 points for Object A
current_object = 'B'
point_count = 0
print("Select point 1 for Object B")
@@ -201,7 +214,7 @@ def select_points(first_frame):
points_b.append((x, y))
point_count += 1
print(f"Selected point {point_count} for Object B: ({x}, {y})")
if len(points_b) == 4: # Collect 4 points for Object B
if len(points_b) == 5: # Collect 5 points for Object B
selection_complete = True
print("Select point 1 for Object A")
@@ -217,7 +230,7 @@ def select_points(first_frame):
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
def add_points_to_predictor(predictor, inference_state, points, obj_id):
labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points
labels = np.array([1, 1, 1, 1, 1], np.int32) # Update labels to match 5 points
points = np.array(points, dtype=np.float32) # Ensure points have shape (5, 2)
try:
print(f"Adding points for Object {obj_id}: {points}")
@@ -376,11 +389,11 @@ def save_final_masks(video_segments, mask_output_path):
# mask a
mask_a = mask_a.squeeze()
indices = np.where(mask_a)
black_frame[mask_a] = [0, 255, 0] # Green for Object A
black_frame[mask_a] = GREEN
# mask b
mask_b = mask_b.squeeze()
indices = np.where(mask_b)
black_frame[mask_b] = [255, 0, 0] # Green for Object B
black_frame[mask_b] = BLUE
# Save the mask image
cv2.imwrite(mask_output_path, black_frame)
@@ -410,11 +423,12 @@ def create_low_res_video(input_video_path, output_video_path, scale):
def main():
parser = argparse.ArgumentParser(description="Process video segments.")
# arg for setting base_dir
parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
args = parser.parse_args()
base_dir = "./freya_short_segments"
#base_dir = "./606-short_segments"
base_dir = args.base_dir
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1]))
scaled_frames_dir_name = "frames_scaled"
@@ -455,8 +469,8 @@ def main():
if os.path.exists(points_file):
logger.info(f"Using segment_points for segment {segment}")
points = np.loadtxt(points_file, comments="#")
points_a = points[:4]
points_b = points[4:]
points_a = points[:5]
points_b = points[5:]
else:
points_a = points_b = None
@@ -466,7 +480,7 @@ def main():
# Try to load previous segment mask
logger.info(f"Using previous segment mask for segment {segment}")
prev_segment_dir = os.path.join(base_dir, segments[i - 1])
prev_mask_path = os.path.join(prev_segment_dir, "mask.jpg")
prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
if os.path.exists(prev_mask_path):
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
# Add previous masks to predictor
@@ -499,7 +513,7 @@ def main():
)
# Save final masks
mask_output_path = os.path.join(segment_dir, "mask.jpg")
mask_output_path = os.path.join(segment_dir, "mask.png")
save_final_masks(video_segments, mask_output_path)
# Clean up