mask to png
This commit is contained in:
@@ -42,8 +42,10 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|||||||
# Variables for input and output directories
|
# Variables for input and output directories
|
||||||
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
|
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
|
||||||
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||||
|
GREEN = [0, 255, 0]
|
||||||
|
BLUE = [255, 0, 0]
|
||||||
|
|
||||||
INFERENCE_SCALE = 0.35
|
INFERENCE_SCALE = 0.25
|
||||||
FULL_SCALE = 1.0
|
FULL_SCALE = 1.0
|
||||||
|
|
||||||
def open_video(video_path):
|
def open_video(video_path):
|
||||||
@@ -68,12 +70,21 @@ def open_video(video_path):
|
|||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
def load_previous_segment_mask(prev_segment_dir):
|
def load_previous_segment_mask(prev_segment_dir):
|
||||||
mask_path = os.path.join(prev_segment_dir, "mask.jpg")
|
mask_path = os.path.join(prev_segment_dir, "mask.png")
|
||||||
mask_image = cv2.imread(mask_path)
|
mask_image = cv2.imread(mask_path)
|
||||||
|
|
||||||
|
if mask_image is None:
|
||||||
|
raise FileNotFoundError(f"Mask image not found at {mask_path}")
|
||||||
|
|
||||||
|
# Ensure the mask_image has three color channels
|
||||||
|
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
|
||||||
|
raise ValueError("Mask image does not have three color channels.")
|
||||||
|
|
||||||
|
mask_image = mask_image.astype(np.uint8)
|
||||||
|
|
||||||
# Extract Object A and Object B masks
|
# Extract Object A and Object B masks
|
||||||
mask_a = (mask_image[:, :, 1] == 255) # Green channel
|
mask_a = np.all(mask_image == GREEN, axis=2)
|
||||||
mask_b = (mask_image[:, :, 0] == 254) # Blue channel
|
mask_b = np.all(mask_image == BLUE, axis=2)
|
||||||
|
|
||||||
per_obj_input_mask = {1: mask_a, 2: mask_b}
|
per_obj_input_mask = {1: mask_a, 2: mask_b}
|
||||||
input_palette = None # No palette needed for binary mask
|
input_palette = None # No palette needed for binary mask
|
||||||
@@ -110,6 +121,8 @@ def apply_green_mask(frame, masks):
|
|||||||
interpolation=cv2.INTER_LINEAR
|
interpolation=cv2.INTER_LINEAR
|
||||||
)
|
)
|
||||||
# Threshold the resized mask to obtain a boolean mask
|
# Threshold the resized mask to obtain a boolean mask
|
||||||
|
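# Add a small Gaussian blur to the mask to smooth out the edges.
# NOTE(review): cv2.GaussianBlur requires positive, odd kernel dimensions; (50, 50) is even — confirm and use an odd size such as (51, 51).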
|
||||||
|
resized_mask = cv2.GaussianBlur(resized_mask, (50, 50), 0)
|
||||||
|
|
||||||
mask = resized_mask > 0.5
|
mask = resized_mask > 0.5
|
||||||
else:
|
else:
|
||||||
@@ -193,7 +206,7 @@ def select_points(first_frame):
|
|||||||
points_a.append((x, y))
|
points_a.append((x, y))
|
||||||
point_count += 1
|
point_count += 1
|
||||||
print(f"Selected point {point_count} for Object A: ({x}, {y})")
|
print(f"Selected point {point_count} for Object A: ({x}, {y})")
|
||||||
if len(points_a) == 4: # Collect 4 points for Object A
|
if len(points_a) == 5: # Collect 5 points for Object A
|
||||||
current_object = 'B'
|
current_object = 'B'
|
||||||
point_count = 0
|
point_count = 0
|
||||||
print("Select point 1 for Object B")
|
print("Select point 1 for Object B")
|
||||||
@@ -201,7 +214,7 @@ def select_points(first_frame):
|
|||||||
points_b.append((x, y))
|
points_b.append((x, y))
|
||||||
point_count += 1
|
point_count += 1
|
||||||
print(f"Selected point {point_count} for Object B: ({x}, {y})")
|
print(f"Selected point {point_count} for Object B: ({x}, {y})")
|
||||||
if len(points_b) == 4: # Collect 4 points for Object B
|
if len(points_b) == 5: # Collect 5 points for Object B
|
||||||
selection_complete = True
|
selection_complete = True
|
||||||
|
|
||||||
print("Select point 1 for Object A")
|
print("Select point 1 for Object A")
|
||||||
@@ -217,7 +230,7 @@ def select_points(first_frame):
|
|||||||
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
|
return np.array(points_a, dtype=np.float32), np.array(points_b, dtype=np.float32)
|
||||||
|
|
||||||
def add_points_to_predictor(predictor, inference_state, points, obj_id):
|
def add_points_to_predictor(predictor, inference_state, points, obj_id):
|
||||||
labels = np.array([1, 1, 1, 1], np.int32) # Update labels to match 4 points
|
labels = np.array([1, 1, 1, 1, 1], np.int32) # Update labels to match 5 points
|
||||||
points = np.array(points, dtype=np.float32) # Ensure points have shape (4, 2)
|
points = np.array(points, dtype=np.float32) # Ensure points have shape (5, 2)
|
||||||
try:
|
try:
|
||||||
print(f"Adding points for Object {obj_id}: {points}")
|
print(f"Adding points for Object {obj_id}: {points}")
|
||||||
@@ -376,11 +389,11 @@ def save_final_masks(video_segments, mask_output_path):
|
|||||||
# mask a
|
# mask a
|
||||||
mask_a = mask_a.squeeze()
|
mask_a = mask_a.squeeze()
|
||||||
indices = np.where(mask_a)
|
indices = np.where(mask_a)
|
||||||
black_frame[mask_a] = [0, 255, 0] # Green for Object A
|
black_frame[mask_a] = GREEN
|
||||||
# mask b
|
# mask b
|
||||||
mask_b = mask_b.squeeze()
|
mask_b = mask_b.squeeze()
|
||||||
indices = np.where(mask_b)
|
indices = np.where(mask_b)
|
||||||
black_frame[mask_b] = [255, 0, 0] # Blue (BGR order) for Object B
|
black_frame[mask_b] = BLUE
|
||||||
|
|
||||||
# Save the mask image
|
# Save the mask image
|
||||||
cv2.imwrite(mask_output_path, black_frame)
|
cv2.imwrite(mask_output_path, black_frame)
|
||||||
@@ -410,11 +423,12 @@ def create_low_res_video(input_video_path, output_video_path, scale):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Process video segments.")
|
parser = argparse.ArgumentParser(description="Process video segments.")
|
||||||
|
# arg for setting base_dir
|
||||||
|
parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
|
||||||
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
|
parser.add_argument("--segments-collect-points", nargs='+', type=int, help="Segments for which to collect points.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
base_dir = "./freya_short_segments"
|
base_dir = args.base_dir
|
||||||
#base_dir = "./606-short_segments"
|
|
||||||
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
|
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
|
||||||
segments.sort(key=lambda x: int(x.split("_")[1]))
|
segments.sort(key=lambda x: int(x.split("_")[1]))
|
||||||
scaled_frames_dir_name = "frames_scaled"
|
scaled_frames_dir_name = "frames_scaled"
|
||||||
@@ -455,8 +469,8 @@ def main():
|
|||||||
if os.path.exists(points_file):
|
if os.path.exists(points_file):
|
||||||
logger.info(f"Using segment_points for segment {segment}")
|
logger.info(f"Using segment_points for segment {segment}")
|
||||||
points = np.loadtxt(points_file, comments="#")
|
points = np.loadtxt(points_file, comments="#")
|
||||||
points_a = points[:4]
|
points_a = points[:5]
|
||||||
points_b = points[4:]
|
points_b = points[5:]
|
||||||
else:
|
else:
|
||||||
points_a = points_b = None
|
points_a = points_b = None
|
||||||
|
|
||||||
@@ -466,7 +480,7 @@ def main():
|
|||||||
# Try to load previous segment mask
|
# Try to load previous segment mask
|
||||||
logger.info(f"Using previous segment mask for segment {segment}")
|
logger.info(f"Using previous segment mask for segment {segment}")
|
||||||
prev_segment_dir = os.path.join(base_dir, segments[i - 1])
|
prev_segment_dir = os.path.join(base_dir, segments[i - 1])
|
||||||
prev_mask_path = os.path.join(prev_segment_dir, "mask.jpg")
|
prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
|
||||||
if os.path.exists(prev_mask_path):
|
if os.path.exists(prev_mask_path):
|
||||||
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
|
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
|
||||||
# Add previous masks to predictor
|
# Add previous masks to predictor
|
||||||
@@ -499,7 +513,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save final masks
|
# Save final masks
|
||||||
mask_output_path = os.path.join(segment_dir, "mask.jpg")
|
mask_output_path = os.path.join(segment_dir, "mask.png")
|
||||||
save_final_masks(video_segments, mask_output_path)
|
save_final_masks(video_segments, mask_output_path)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
|
|||||||
Reference in New Issue
Block a user