add newest shit
This commit is contained in:
@@ -30,17 +30,20 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
import gc
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
import argparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
# Variables for input and output directories
|
||||
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
|
||||
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||
|
||||
INFERENCE_SCALE = 0.5
|
||||
INFERENCE_SCALE = 0.35
|
||||
FULL_SCALE = 1.0
|
||||
|
||||
def open_video(video_path):
|
||||
@@ -113,13 +116,11 @@ def apply_green_mask(frame, masks):
|
||||
# Ensure mask is boolean
|
||||
mask = mask.astype(bool)
|
||||
|
||||
|
||||
# Combine masks using logical OR
|
||||
combined_mask |= mask # Now both arrays are bool
|
||||
|
||||
# Create a green background image
|
||||
green_background = np.full_like(frame, [0, 255, 0])
|
||||
|
||||
# Use combined mask to overlay the original frame onto the green background
|
||||
result_frame = np.where(
|
||||
combined_mask[..., None],
|
||||
@@ -131,10 +132,23 @@ def apply_green_mask(frame, masks):
|
||||
return result_frame
|
||||
|
||||
def initialize_predictor():
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
print(
|
||||
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
|
||||
"give numerically different outputs and sometimes degraded performance on MPS."
|
||||
)
|
||||
# Enable MPS fallback for operations not supported on MPS
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
|
||||
return predictor
|
||||
|
||||
|
||||
def load_first_frame(video_path, scale=1.0):
|
||||
"""
|
||||
Opens a video file and returns the first frame, scaled as specified.
|
||||
@@ -220,30 +234,6 @@ def add_points_to_predictor(predictor, inference_state, points, obj_id):
|
||||
print(f"Error adding points for Object {obj_id}: {e}")
|
||||
exit()
|
||||
|
||||
def show_mask_on_frame(frame, masks):
|
||||
combined_mask = np.zeros(frame.shape[:2], dtype=bool)
|
||||
for mask in masks:
|
||||
mask = mask.squeeze()
|
||||
if mask.shape != frame.shape[:2]:
|
||||
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||
combined_mask = np.logical_or(combined_mask, mask)
|
||||
color = (0, 255, 0)
|
||||
frame[combined_mask] = color
|
||||
return frame
|
||||
|
||||
def confirm_masks(first_frame, masks_a, masks_b):
|
||||
first_frame_with_masks = show_mask_on_frame(first_frame.copy(), masks_a + masks_b)
|
||||
cv2.namedWindow('First Frame with Masks', cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow('First Frame with Masks', int(first_frame.shape[1] * (500 / first_frame.shape[0])), 500)
|
||||
cv2.imshow('First Frame with Masks', first_frame_with_masks)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
confirmation = input("Are the masks correct? (yes/no): ").strip().lower()
|
||||
if confirmation != 'yes':
|
||||
print("Aborting process.")
|
||||
exit()
|
||||
|
||||
def propagate_masks(predictor, inference_state):
|
||||
video_segments = {}
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
|
||||
@@ -274,52 +264,6 @@ def apply_colored_mask(frame, masks_a, masks_b):
|
||||
return colored_mask
|
||||
|
||||
|
||||
|
||||
def process_and_save_frames(input_frames_dir, fullres_frames_dir, output_frames_dir, frame_names, video_segments, segment_dir):
|
||||
def upscale_masks(masks, frame_shape):
|
||||
upscaled_masks = []
|
||||
for mask in masks:
|
||||
mask = mask.squeeze()
|
||||
upscaled_mask = cv2.resize(mask.astype(np.float32), (frame_shape[1], frame_shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
#convert_mask to bool
|
||||
upscaled_mask = (upscaled_mask > 0.5).astype(bool)
|
||||
upscaled_masks.append(upscaled_mask)
|
||||
return upscaled_masks
|
||||
|
||||
|
||||
for out_frame_idx, frame_name in enumerate(frame_names):
|
||||
frame_path = os.path.join(fullres_frames_dir, frame_name)
|
||||
frame = cv2.imread(frame_path)
|
||||
masks = [video_segments[out_frame_idx][out_obj_id] for out_obj_id in video_segments[out_frame_idx]]
|
||||
|
||||
# Upscale masks to match the full-resolution frame
|
||||
upscaled_masks = []
|
||||
for mask in masks:
|
||||
mask = mask.squeeze()
|
||||
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||
upscaled_masks.append(upscaled_mask)
|
||||
|
||||
frame = apply_green_mask(frame, upscaled_masks)
|
||||
output_path = os.path.join(output_frames_dir, frame_name)
|
||||
cv2.imwrite(output_path, frame)
|
||||
|
||||
# Create and save mask.jpg
|
||||
final_frame_path = os.path.join(fullres_frames_dir, frame_names[-1])
|
||||
final_frame = cv2.imread(final_frame_path)
|
||||
masks_a = [video_segments[len(frame_names) - 1][1]]
|
||||
masks_b = [video_segments[len(frame_names) - 1][2]]
|
||||
upscaled_masks_a = upscale_masks(masks_a, final_frame.shape)
|
||||
upscaled_masks_b = upscale_masks(masks_b, final_frame.shape)
|
||||
|
||||
|
||||
# Apply colored mask
|
||||
mask_image = apply_colored_mask(final_frame, upscaled_masks_a, upscaled_masks_b)
|
||||
|
||||
mask_output_path = os.path.join(segment_dir, "mask.jpg")
|
||||
cv2.imwrite(mask_output_path, mask_image)
|
||||
|
||||
print("Processing complete. Frames saved in:", output_frames_dir)
|
||||
|
||||
def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
|
||||
"""
|
||||
Process high-resolution frames, apply upscaled masks, and save the output video.
|
||||
@@ -334,6 +278,11 @@ def process_and_save_output_video(video_path, output_video_path, video_segments,
|
||||
# Use FFmpeg with NVENC offloading for H.265 encoding
|
||||
import subprocess
|
||||
|
||||
if sys.platform == 'darwin':
|
||||
encoder = 'hevc_videotoolbox'
|
||||
else:
|
||||
encoder = 'hevc_nvenc'
|
||||
|
||||
command = [
|
||||
'ffmpeg',
|
||||
'-y', # Overwrite output file if it exists
|
||||
@@ -344,10 +293,10 @@ def process_and_save_output_video(video_path, output_video_path, video_segments,
|
||||
'-r', str(fps),
|
||||
'-i', '-', # Input from stdin
|
||||
'-an', # No audio
|
||||
'-vcodec', 'hevc_nvenc',
|
||||
'-vcodec', encoder,
|
||||
'-pix_fmt', 'yuv420p',
|
||||
'-preset', 'slow',
|
||||
'-b:v', '60M',
|
||||
'-b:v', '50M',
|
||||
output_video_path
|
||||
]
|
||||
process = subprocess.Popen(command, stdin=subprocess.PIPE)
|
||||
@@ -362,22 +311,7 @@ def process_and_save_output_video(video_path, output_video_path, video_segments,
|
||||
if not ret or frame_idx >= len(video_segments):
|
||||
break
|
||||
|
||||
#masks_dict = video_segments[frame_idx]
|
||||
#upscaled_masks = []
|
||||
|
||||
## Upscale masks to full resolution
|
||||
#for mask in masks_dict.values():
|
||||
# mask = mask.squeeze()
|
||||
# upscaled_mask = cv2.resize(
|
||||
# mask.astype(np.uint8),
|
||||
# (frame.shape[1], frame.shape[0]),
|
||||
# interpolation=cv2.INTER_NEAREST
|
||||
# )
|
||||
# upscaled_masks.append(upscaled_mask)
|
||||
|
||||
masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
|
||||
|
||||
# Upscale masks to match the full-resolution frame
|
||||
upscaled_masks = []
|
||||
|
||||
for mask in masks:
|
||||
@@ -414,69 +348,11 @@ def do_collect_segment_points(base_dir, segments, collect_points_segments, scale
|
||||
video_file = os.path.join(segment_dir, get_video_file_name(i))
|
||||
if segment_index in collect_points_segments and not os.path.exists(points_file):
|
||||
first_frame = load_first_frame(video_file, scale)
|
||||
|
||||
|
||||
points_a, points_b = select_points(first_frame)
|
||||
with open(points_file, 'w') as f:
|
||||
|
||||
np.savetxt(f, points_a, header="Object A Points")
|
||||
np.savetxt(f, points_b, header="Object B Points")
|
||||
|
||||
def perform_inference(predictor, inference_state, video_path, inference_scale, collect_points, prev_segment_mask=None):
|
||||
"""
|
||||
Performs inference on the video frames at a specified scale.
|
||||
|
||||
Parameters:
|
||||
- predictor: The initialized predictor object.
|
||||
- inference_state: The predictor's inference state.
|
||||
- video_path: Path to the video file.
|
||||
- inference_scale: Scaling factor for inference frames.
|
||||
- collect_points: Boolean indicating whether to collect points.
|
||||
- prev_segment_mask: Previous segment's mask if available.
|
||||
|
||||
|
||||
Returns:
|
||||
- masks_per_frame: List of masks for each frame.
|
||||
"""
|
||||
masks_per_frame = []
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
frame_idx = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Resize frame for inference
|
||||
low_res_frame = cv2.resize(frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if frame_idx == 0:
|
||||
if collect_points:
|
||||
|
||||
# Collect points on the low-res frame
|
||||
points_a, points_b = select_points(low_res_frame)
|
||||
add_points_to_predictor(predictor, inference_state, points_a, obj_id=1)
|
||||
add_points_to_predictor(predictor, inference_state, points_b, obj_id=2)
|
||||
elif prev_segment_mask is not None:
|
||||
# Use the previous segment's mask
|
||||
per_obj_input_mask = prev_segment_mask
|
||||
for obj_id, mask in per_obj_input_mask.items():
|
||||
predictor.add_new_mask(inference_state, 0, obj_id, mask)
|
||||
else:
|
||||
print("Error: No points or previous mask provided for inference.")
|
||||
exit()
|
||||
|
||||
|
||||
# Perform inference
|
||||
predictor.predict(inference_state, low_res_frame)
|
||||
masks = predictor.get_masks(inference_state)
|
||||
masks_per_frame.append(masks)
|
||||
|
||||
frame_idx += 1
|
||||
|
||||
cap.release()
|
||||
return masks_per_frame
|
||||
|
||||
def save_final_masks(video_segments, mask_output_path):
|
||||
"""
|
||||
Save the final masks as a colored image.
|
||||
@@ -494,13 +370,20 @@ def save_final_masks(video_segments, mask_output_path):
|
||||
return
|
||||
|
||||
#convert mask to np.uint8
|
||||
mask_a = mask_a.astype(np.uint8)
|
||||
mask_b = mask_b.astype(np.uint8)
|
||||
mask_a = mask_a.astype(bool)
|
||||
mask_b = mask_b.astype(bool)
|
||||
|
||||
mask_image = apply_colored_mask(black_frame, [mask_a], [mask_b])
|
||||
# mask a
|
||||
mask_a = mask_a.squeeze()
|
||||
indices = np.where(mask_a)
|
||||
black_frame[mask_a] = [0, 255, 0] # Green for Object A
|
||||
# mask b
|
||||
mask_b = mask_b.squeeze()
|
||||
indices = np.where(mask_b)
|
||||
black_frame[mask_b] = [255, 0, 0] # Green for Object B
|
||||
|
||||
# Save the mask image
|
||||
cv2.imwrite(mask_output_path, mask_image)
|
||||
cv2.imwrite(mask_output_path, black_frame)
|
||||
|
||||
def create_low_res_video(input_video_path, output_video_path, scale):
|
||||
"""
|
||||
@@ -531,18 +414,15 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
base_dir = "./freya_short_segments"
|
||||
#base_dir = "./606-short_segments"
|
||||
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
|
||||
segments.sort(key=lambda x: int(x.split("_")[1]))
|
||||
scaled_frames_dir_name = "frames_scaled"
|
||||
fullres_frames_dir_name = "frames" # iwant to render the final video with these frames
|
||||
|
||||
collect_points_segments = args.segments_collect_points if args.segments_collect_points else []
|
||||
#inference_scale = 0.5 global, full_scale = 1.0 global
|
||||
#inference_scale for getting the mask, then use full scale when rendering the video
|
||||
do_collect_segment_points(base_dir, segments, collect_points_segments, scale=INFERENCE_SCALE)
|
||||
# i have validated do_collect_segment_points() is working!
|
||||
|
||||
# open_video(video_path) is defined above that returns
|
||||
# generator that yields frames from the video
|
||||
|
||||
for i, segment in enumerate(segments):
|
||||
segment_index = int(segment.split("_")[1])
|
||||
@@ -628,8 +508,14 @@ def main():
|
||||
del video_segments
|
||||
del predictor
|
||||
gc.collect()
|
||||
open(output_done_file, 'a').close()
|
||||
|
||||
try:
|
||||
os.remove(low_res_video_path)
|
||||
logger.info(f"Deleted low-resolution video for segment {segment}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")
|
||||
|
||||
open(output_done_file, 'a').close()
|
||||
print("Processing complete.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user