initial commit
core/sam2_processor.py	362	Normal file
@@ -0,0 +1,362 @@
"""
SAM2 processor module for video segmentation.
Preserves the core SAM2 logic from the original implementation.
"""

import os
import cv2
import numpy as np
import torch
import logging
import gc
from typing import Dict, List, Any, Optional, Tuple
from sam2.build_sam import build_sam2_video_predictor

logger = logging.getLogger(__name__)


class SAM2Processor:
    """Handles SAM2-based video segmentation for human tracking."""

    def __init__(self, checkpoint_path: str, config_path: str):
        """
        Initialize SAM2 processor.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_path: Path to SAM2 config file
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.predictor = None
        self._initialize_predictor()

    def _initialize_predictor(self):
        """Initialize SAM2 video predictor with proper device setup."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
            logger.warning(
                "Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
                "give numerically different outputs and sometimes degraded performance on MPS."
            )
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        else:
            device = torch.device("cpu")

        logger.info(f"Using device: {device}")

        try:
            self.predictor = build_sam2_video_predictor(
                self.config_path,
                self.checkpoint_path,
                device=device
            )

            # Enable optimizations for CUDA
            if device.type == "cuda":
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
                # TF32 is supported on Ampere (compute capability 8.x) and newer
                if torch.cuda.get_device_properties(0).major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True

            logger.info("SAM2 predictor initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize SAM2 predictor: {e}")
            raise

    def create_low_res_video(self, input_video_path: str, output_video_path: str, scale: float):
        """
        Create a low-resolution version of the input video for inference.

        Args:
            input_video_path: Path to input video
            output_video_path: Path to output low-res video
            scale: Scale factor for resolution reduction
        """
        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {input_video_path}")

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
            out.write(low_res_frame)
            frame_count += 1

        cap.release()
        out.release()

        logger.info(f"Created low-res video with {frame_count} frames: {output_video_path}")

    def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
        """
        Add YOLO detection prompts to SAM2 predictor.

        Args:
            inference_state: SAM2 inference state
            prompts: List of prompt dictionaries with obj_id and bbox

        Returns:
            True if prompts were added successfully
        """
        if not prompts:
            logger.warning("No prompts provided to SAM2")
            return False

        try:
            for prompt in prompts:
                obj_id = prompt['obj_id']
                bbox = prompt['bbox']

                _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=obj_id,
                    box=bbox.astype(np.float32),
                )

                logger.debug(f"Added prompt for Object {obj_id}: {bbox}")

            logger.info(f"Successfully added {len(prompts)} prompts to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding prompts to SAM2: {e}")
            return False

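    # A minimal sketch of the prompt structure this method expects; the
    # coordinates below are hypothetical placeholders. SAM2 interprets the
    # box as [x1, y1, x2, y2] in pixel coordinates of frame 0:
    #
    #   prompts = [
    #       {'obj_id': 1, 'bbox': np.array([120, 40, 310, 460])},
    #       {'obj_id': 2, 'bbox': np.array([400, 60, 580, 470])},
    #   ]
    #   processor.add_yolo_prompts_to_predictor(inference_state, prompts)
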
    def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
        """
        Load masks from previous segment for continuity.

        Args:
            prev_segment_dir: Directory of previous segment

        Returns:
            Dictionary mapping object IDs to masks, or None if failed
        """
        mask_path = os.path.join(prev_segment_dir, "mask.png")

        if not os.path.exists(mask_path):
            logger.warning(f"Previous mask not found: {mask_path}")
            return None

        try:
            mask_image = cv2.imread(mask_path)
            if mask_image is None:
                logger.error(f"Could not read mask image: {mask_path}")
                return None

            if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
                logger.error("Mask image does not have three color channels")
                return None

            mask_image = mask_image.astype(np.uint8)

            # Extract Object A and Object B masks (preserving original logic).
            # Colors are in OpenCV's BGR channel order, matching save_final_masks.
            GREEN = [0, 255, 0]
            BLUE = [255, 0, 0]

            mask_a = np.all(mask_image == GREEN, axis=2)
            mask_b = np.all(mask_image == BLUE, axis=2)

            per_obj_input_mask = {}
            if np.any(mask_a):
                per_obj_input_mask[1] = mask_a
            if np.any(mask_b):
                per_obj_input_mask[2] = mask_b

            logger.info(f"Loaded masks for {len(per_obj_input_mask)} objects from {prev_segment_dir}")
            return per_obj_input_mask

        except Exception as e:
            logger.error(f"Error loading previous mask: {e}")
            return None

    def add_previous_masks_to_predictor(self, inference_state, masks: Dict[int, np.ndarray]) -> bool:
        """
        Add previous segment masks to predictor for continuity.

        Args:
            inference_state: SAM2 inference state
            masks: Dictionary mapping object IDs to masks

        Returns:
            True if masks were added successfully
        """
        try:
            for obj_id, mask in masks.items():
                self.predictor.add_new_mask(inference_state, 0, obj_id, mask)
                logger.debug(f"Added previous mask for Object {obj_id}")

            logger.info(f"Successfully added {len(masks)} previous masks to SAM2")
            return True

        except Exception as e:
            logger.error(f"Error adding previous masks to SAM2: {e}")
            return False

    def propagate_masks(self, inference_state) -> Dict[int, Dict[int, np.ndarray]]:
        """
        Propagate masks across all frames in the video.

        Args:
            inference_state: SAM2 inference state

        Returns:
            Dictionary mapping frame indices to object masks
        """
        video_segments = {}

        try:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }

            # Guard the log line: out_obj_ids is only bound if at least one
            # frame was yielded by the propagation loop.
            if video_segments:
                logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")

        except Exception as e:
            logger.error(f"Error during mask propagation: {e}")

        return video_segments

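    # Shape note (derived from the comprehension above): video_segments maps
    # frame_idx -> {obj_id: boolean np.ndarray}. Each mask is typically
    # (1, H, W) at inference resolution, which is why downstream consumers
    # call .squeeze() before indexing, e.g.:
    #
    #   mask = video_segments[0][1].squeeze()  # (H, W) bool array for object 1
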
    def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                               previous_masks: Optional[Dict[int, np.ndarray]] = None,
                               inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a single video segment with SAM2.

        Args:
            segment_info: Segment information dictionary
            yolo_prompts: Optional YOLO detection prompts
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference

        Returns:
            Video segments dictionary, or None if the segment was skipped or failed
        """
        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        # Check if segment is already processed (resume capability)
        output_done_file = os.path.join(segment_dir, "output_frames_done")
        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment_idx} already processed. Skipping.")
            return None  # Indicates skip, not failure

        logger.info(f"Processing segment {segment_idx} with SAM2")

        # Create low-resolution video for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            try:
                self.create_low_res_video(video_path, low_res_video_path, inference_scale)
            except Exception as e:
                logger.error(f"Failed to create low-res video for segment {segment_idx}: {e}")
                return None

        try:
            # Initialize inference state
            inference_state = self.predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

            # Add prompts or previous masks
            if yolo_prompts:
                if not self.add_yolo_prompts_to_predictor(inference_state, yolo_prompts):
                    return None
            elif previous_masks:
                if not self.add_previous_masks_to_predictor(inference_state, previous_masks):
                    return None
            else:
                logger.error(f"No prompts or previous masks available for segment {segment_idx}")
                return None

            # Propagate masks
            video_segments = self.propagate_masks(inference_state)

            # Clean up
            self.predictor.reset_state(inference_state)
            del inference_state
            gc.collect()

            # Remove low-res video to save space
            try:
                os.remove(low_res_video_path)
                logger.debug(f"Removed low-res video: {low_res_video_path}")
            except Exception as e:
                logger.warning(f"Could not remove low-res video: {e}")

            # Mark segment as completed (for resume capability)
            try:
                with open(output_done_file, 'w') as f:
                    f.write(f"Segment {segment_idx} completed successfully\n")
                logger.debug(f"Marked segment {segment_idx} as completed")
            except Exception as e:
                logger.warning(f"Could not create completion marker: {e}")

            return video_segments

        except Exception as e:
            logger.error(f"Error processing segment {segment_idx}: {e}")
            return None

    def save_final_masks(self, video_segments: Dict[int, Dict[int, np.ndarray]], output_path: str,
                         green_color: List[int] = [0, 255, 0], blue_color: List[int] = [255, 0, 0]):
        """
        Save the final masks as a colored image for continuity.

        Args:
            video_segments: Video segments dictionary
            output_path: Path to save the mask image
            green_color: BGR color for object 1
            blue_color: BGR color for object 2
        """
        if not video_segments:
            logger.error("No video segments to save")
            return

        try:
            last_frame_idx = max(video_segments.keys())
            masks_dict = video_segments[last_frame_idx]

            # Get masks for objects 1 and 2
            mask_a = masks_dict.get(1)
            mask_b = masks_dict.get(2)

            if mask_a is None and mask_b is None:
                logger.error("No masks found for objects")
                return

            # Use the first available mask to determine dimensions
            reference_mask = mask_a if mask_a is not None else mask_b
            reference_mask = reference_mask.squeeze()

            black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

            if mask_a is not None:
                mask_a = mask_a.squeeze().astype(bool)
                black_frame[mask_a] = green_color

            if mask_b is not None:
                mask_b = mask_b.squeeze().astype(bool)
                black_frame[mask_b] = blue_color

            # Save the mask image (cv2.imwrite expects BGR channel order)
            cv2.imwrite(output_path, black_frame)
            logger.info(f"Saved final masks to {output_path}")

        except Exception as e:
            logger.error(f"Error saving final masks: {e}")
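
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original pipeline). The paths,
# checkpoint/config names, and box coordinates below are hypothetical
# placeholders; substitute your own.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    processor = SAM2Processor(
        checkpoint_path="checkpoints/sam2_hiera_large.pt",  # hypothetical path
        config_path="sam2_hiera_l.yaml",                    # hypothetical config
    )

    # Minimal segment_info dict with the keys process_single_segment reads.
    segment_info = {
        'directory': "segments/segment_000",
        'video_file': "segments/segment_000/video.mp4",
        'index': 0,
    }

    # One xyxy box prompt for object 1 (placeholder coordinates).
    prompts = [{'obj_id': 1, 'bbox': np.array([100, 50, 300, 400])}]

    segments = processor.process_single_segment(segment_info, yolo_prompts=prompts)
    if segments:
        # Persist the last frame's masks so the next segment can resume from them.
        processor.save_final_masks(segments, os.path.join(segment_info['directory'], "mask.png"))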