# samyolo_on_segments/core/sam2_processor.py
"""
SAM2 processor module for video segmentation.
Preserves the core SAM2 logic from the original implementation.
"""
import gc
import logging
import os
import shutil
import tempfile
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor
logger = logging.getLogger(__name__)
class SAM2Processor:
"""Handles SAM2-based video segmentation for human tracking."""
def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
"""
Initialize SAM2 processor.
Args:
checkpoint_path: Path to SAM2 checkpoint
config_path: Path to SAM2 config file
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
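
        Example (illustrative paths; adjust to your checkout):
            processor = SAM2Processor(
                checkpoint_path="checkpoints/sam2.1_hiera_large.pt",
                config_path="configs/sam2.1/sam2.1_hiera_l.yaml",
            )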
"""
self.checkpoint_path = checkpoint_path
self.config_path = config_path
self.vos_optimized = vos_optimized
self.predictor = None
self._initialize_predictor()
def _initialize_predictor(self):
"""Initialize SAM2 video predictor with proper device setup."""
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
logger.warning(
"Support for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS."
)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
device = torch.device("cpu")
logger.info(f"Using device: {device}")
try:
# Extract just the config filename for SAM2's Hydra-based loader
# SAM2 expects a config name relative to its internal config directory
config_name = os.path.basename(self.config_path)
if config_name.endswith('.yaml'):
config_name = config_name[:-5] # Remove .yaml extension
# SAM2 configs are in the format "sam2.1_hiera_X.yaml"
# and should be referenced as "configs/sam2.1/sam2.1_hiera_X"
if config_name.startswith("sam2.1_hiera"):
config_name = f"configs/sam2.1/{config_name}"
elif config_name.startswith("sam2_hiera"):
config_name = f"configs/sam2/{config_name}"
logger.info(f"Using SAM2 config: {config_name}")
# Use VOS optimization if enabled and supported
if self.vos_optimized:
try:
self.predictor = build_sam2_video_predictor(
config_name, # Use just the config name, not full path
self.checkpoint_path,
device=device,
vos_optimized=True # New optimization for major speedup
)
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
except Exception as e:
logger.warning(f"Failed to use optimized VOS predictor: {e}")
logger.info("Falling back to standard SAM2 predictor")
# Fallback to standard predictor
                    self.predictor = build_sam2_video_predictor(
                        config_name,
                        self.checkpoint_path,
                        device=device,
                    )
else:
# Use standard predictor
                self.predictor = build_sam2_video_predictor(
                    config_name,
                    self.checkpoint_path,
                    device=device,
                )
logger.info("Using standard SAM2 predictor")
            # Enable optimizations for CUDA (mirrors the SAM2 example notebooks:
            # the bfloat16 autocast context is deliberately left open for the
            # life of the process)
            if device.type == "cuda":
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
logger.info(f"SAM2 predictor initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize SAM2 predictor: {e}")
raise
def create_low_res_video(self, input_video_path: str, output_video_path: str, scale: float):
"""
Create a low-resolution version of the input video for inference.
Args:
input_video_path: Path to input video
output_video_path: Path to output low-res video
scale: Scale factor for resolution reduction
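
        Example (illustrative paths):
            processor.create_low_res_video("segment_000/video.mp4",
                                           "segment_000/low_res_video.mp4",
                                           scale=0.5)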
"""
cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
raise ValueError(f"Could not open video: {input_video_path}")
        frame_width = max(1, int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale))
        frame_height = max(1, int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
frame_count = 0
while True:
ret, frame = cap.read()
if not ret:
break
low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
out.write(low_res_frame)
frame_count += 1
cap.release()
out.release()
logger.info(f"Created low-res video with {frame_count} frames: {output_video_path}")
def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
"""
Add YOLO detection prompts to SAM2 predictor.
Includes error handling matching the working spec.md implementation.
Args:
inference_state: SAM2 inference state
prompts: List of prompt dictionaries with obj_id and bbox
Returns:
True if prompts were added successfully
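
        Example prompt (illustrative values; bbox is an XYXY-pixel
        np.ndarray, matching SAM2's box format):
            {"obj_id": 1, "bbox": np.array([120.0, 80.0, 260.0, 210.0]), "confidence": 0.91}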
"""
if not prompts:
logger.warning("SAM2 Debug: No prompts provided to SAM2")
return False
logger.info(f"SAM2 Debug: Received {len(prompts)} prompts to add to predictor")
success_count = 0
for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
logger.info(f"SAM2 Debug: Adding prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"SAM2 Debug: ✓ Successfully added Object {obj_id} - returned obj_ids: {out_obj_ids}")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Debug: ✗ Error adding Object {obj_id}: {e}")
# Continue processing other prompts even if one fails
continue
if success_count > 0:
logger.info(f"SAM2 Debug: Final result - {success_count}/{len(prompts)} prompts successfully added")
return True
else:
logger.error("SAM2 Debug: FAILED - No prompts were successfully added to SAM2")
return False
def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
"""
Load masks from previous segment for continuity.
Args:
prev_segment_dir: Directory of previous segment
Returns:
Dictionary mapping object IDs to masks, or None if failed
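
        The mask.png convention matches save_final_masks: green pixels mark
        object 1 and blue pixels mark object 2 on a black background (BGR).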
"""
mask_path = os.path.join(prev_segment_dir, "mask.png")
if not os.path.exists(mask_path):
logger.warning(f"Previous mask not found: {mask_path}")
return None
try:
mask_image = cv2.imread(mask_path)
if mask_image is None:
logger.error(f"Could not read mask image: {mask_path}")
return None
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
logger.error("Mask image does not have three color channels")
return None
mask_image = mask_image.astype(np.uint8)
# Extract Object A and Object B masks (preserving original logic)
            GREEN = [0, 255, 0]  # BGR (OpenCV channel order)
            BLUE = [255, 0, 0]   # BGR
mask_a = np.all(mask_image == GREEN, axis=2)
mask_b = np.all(mask_image == BLUE, axis=2)
per_obj_input_mask = {}
if np.any(mask_a):
per_obj_input_mask[1] = mask_a
if np.any(mask_b):
per_obj_input_mask[2] = mask_b
logger.info(f"Loaded masks for {len(per_obj_input_mask)} objects from {prev_segment_dir}")
return per_obj_input_mask
except Exception as e:
logger.error(f"Error loading previous mask: {e}")
return None
def add_previous_masks_to_predictor(self, inference_state, masks: Dict[int, np.ndarray]) -> bool:
"""
Add previous segment masks to predictor for continuity.
Args:
inference_state: SAM2 inference state
masks: Dictionary mapping object IDs to masks
Returns:
True if masks were added successfully
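
        Example (illustrative): masks = {1: mask_a, 2: mask_b}, where each
        value is a boolean np.ndarray at the inference video's resolution.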
"""
try:
for obj_id, mask in masks.items():
self.predictor.add_new_mask(inference_state, 0, obj_id, mask)
logger.debug(f"Added previous mask for Object {obj_id}")
logger.info(f"Successfully added {len(masks)} previous masks to SAM2")
return True
except Exception as e:
logger.error(f"Error adding previous masks to SAM2: {e}")
return False
def propagate_masks(self, inference_state) -> Dict[int, Dict[int, np.ndarray]]:
"""
Propagate masks across all frames in the video.
Args:
inference_state: SAM2 inference state
Returns:
Dictionary mapping frame indices to object masks
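
        Example of the returned structure (illustrative):
            {0: {1: mask_0_1, 2: mask_0_2}, 1: {1: mask_1_1, 2: mask_1_2}, ...}
            where each mask is a boolean np.ndarray from thresholded logits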
"""
        video_segments = {}
        frame_count = 0
        out_obj_ids: List[int] = []  # kept in scope so the summary log below is safe
try:
logger.info("Starting SAM2 mask propagation...")
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
frame_count += 1
# Log progress every 50 frames
if frame_count % 50 == 0:
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
except Exception as e:
logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
if frame_count == 0:
logger.error("No frames were processed - propagation failed completely")
else:
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
return video_segments
    def process_single_segment(self, segment_info: dict,
                               yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                               previous_masks: Optional[Dict[int, np.ndarray]] = None,
                               inference_scale: float = 0.5,
                               multi_frame_prompts: Optional[Dict[int, Any]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
"""
Process a single video segment with SAM2.
Args:
segment_info: Segment information dictionary
yolo_prompts: Optional YOLO detection prompts for first frame
previous_masks: Optional masks from previous segment
inference_scale: Scale factor for inference
            multi_frame_prompts: Optional per-frame prompts for mid-segment re-detection
                (bbox prompt lists or {'masks': ...}; see add_multi_frame_prompts_to_predictor)
Returns:
            Video segments dictionary, or None if the segment was skipped (already processed) or failed
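
        Example segment_info (illustrative layout; only these keys are read here):
            {"directory": "segments/segment_003",
             "video_file": "segments/segment_003/video.mp4",
             "index": 3}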
"""
segment_dir = segment_info['directory']
video_path = segment_info['video_file']
segment_idx = segment_info['index']
# Check if segment is already processed (resume capability)
output_done_file = os.path.join(segment_dir, "output_frames_done")
if os.path.exists(output_done_file):
logger.info(f"Segment {segment_idx} already processed. Skipping.")
return None # Indicate skip, not failure
logger.info(f"Processing segment {segment_idx} with SAM2")
# Create low-resolution video for inference
low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
if not os.path.exists(low_res_video_path):
try:
self.create_low_res_video(video_path, low_res_video_path, inference_scale)
except Exception as e:
logger.error(f"Failed to create low-res video for segment {segment_idx}: {e}")
return None
try:
# Initialize inference state
inference_state = self.predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)
# Add prompts or previous masks
if yolo_prompts:
if not self.add_yolo_prompts_to_predictor(inference_state, yolo_prompts):
return None
elif previous_masks:
if not self.add_previous_masks_to_predictor(inference_state, previous_masks):
return None
else:
logger.error(f"No prompts or previous masks available for segment {segment_idx}")
return None
# Add mid-segment prompts if provided
if multi_frame_prompts:
logger.info(f"Adding mid-segment prompts for segment {segment_idx}")
if not self.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
logger.warning(f"Failed to add mid-segment prompts for segment {segment_idx}")
# Don't return None here - continue with existing prompts
# Propagate masks
video_segments = self.propagate_masks(inference_state)
# Clean up
self.predictor.reset_state(inference_state)
del inference_state
gc.collect()
# Remove low-res video to save space
try:
os.remove(low_res_video_path)
logger.debug(f"Removed low-res video: {low_res_video_path}")
except Exception as e:
logger.warning(f"Could not remove low-res video: {e}")
# Mark segment as completed (for resume capability)
try:
with open(output_done_file, 'w') as f:
f.write(f"Segment {segment_idx} completed successfully\n")
logger.debug(f"Marked segment {segment_idx} as completed")
except Exception as e:
logger.warning(f"Could not create completion marker: {e}")
return video_segments
except Exception as e:
logger.error(f"Error processing segment {segment_idx}: {e}")
return None
def save_final_masks(self, video_segments: Dict[int, Dict[int, np.ndarray]], output_path: str,
green_color: List[int] = [0, 255, 0], blue_color: List[int] = [255, 0, 0]):
"""
Save the final masks as a colored image for continuity.
Args:
video_segments: Video segments dictionary
output_path: Path to save the mask image
            green_color: BGR color for object 1 (OpenCV channel order)
            blue_color: BGR color for object 2 (OpenCV channel order)
"""
if not video_segments:
logger.error("No video segments to save")
return
try:
last_frame_idx = max(video_segments.keys())
masks_dict = video_segments[last_frame_idx]
# Get masks for objects 1 and 2
mask_a = masks_dict.get(1)
mask_b = masks_dict.get(2)
if mask_a is None and mask_b is None:
logger.error("No masks found for objects")
return
# Use the first available mask to determine dimensions
reference_mask = mask_a if mask_a is not None else mask_b
reference_mask = reference_mask.squeeze()
black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)
if mask_a is not None:
mask_a = mask_a.squeeze().astype(bool)
black_frame[mask_a] = green_color
if mask_b is not None:
mask_b = mask_b.squeeze().astype(bool)
black_frame[mask_b] = blue_color
# Save the mask image
cv2.imwrite(output_path, black_frame)
logger.info(f"Saved final masks to {output_path}")
except Exception as e:
logger.error(f"Error saving final masks: {e}")
def generate_first_frame_debug_masks(self, video_path: str, prompts: List[Dict[str, Any]],
output_path: str, inference_scale: float = 0.5) -> bool:
"""
Generate SAM2 masks for just the first frame and save debug visualization.
This helps debug what SAM2 is producing for each detected object.
Args:
video_path: Path to the video file
prompts: List of SAM2 prompt dictionaries
output_path: Path to save the debug image
inference_scale: Scale factor for SAM2 inference
Returns:
True if debug masks were generated successfully
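
        Example (illustrative paths):
            processor.generate_first_frame_debug_masks(
                "segments/segment_000/video.mp4", prompts,
                "segments/segment_000/first_frame_debug.png")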
"""
if not prompts:
logger.warning("No prompts provided for first frame debug")
return False
try:
logger.info(f"SAM2 Debug: Generating first frame masks for {len(prompts)} objects")
# Load the first frame
cap = cv2.VideoCapture(video_path)
ret, original_frame = cap.read()
cap.release()
if not ret:
logger.error("Could not read first frame for debug mask generation")
return False
# Scale frame for inference if needed
if inference_scale != 1.0:
inference_frame = cv2.resize(original_frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
else:
inference_frame = original_frame.copy()
# Create temporary low-res video with just first frame
temp_dir = tempfile.mkdtemp()
temp_video_path = os.path.join(temp_dir, "first_frame.mp4")
# Write single frame to temporary video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video_path, fourcc, 1.0, (inference_frame.shape[1], inference_frame.shape[0]))
out.write(inference_frame)
out.release()
# Initialize SAM2 inference state with single frame
inference_state = self.predictor.init_state(video_path=temp_video_path, async_loading_frames=True)
# Add prompts
            if not self.add_yolo_prompts_to_predictor(inference_state, prompts):
                logger.error("Failed to add prompts for first frame debug")
                self.predictor.reset_state(inference_state)
                shutil.rmtree(temp_dir, ignore_errors=True)
                return False
# Generate masks for first frame only
frame_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
if out_frame_idx == 0: # Only process first frame
frame_masks = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
break
            if not frame_masks:
                logger.error("No masks generated for first frame debug")
                self.predictor.reset_state(inference_state)
                shutil.rmtree(temp_dir, ignore_errors=True)
                return False
# Create debug visualization
debug_frame = original_frame.copy()
            # Define BGR colors for each object (OpenCV channel order)
colors = {
1: (0, 255, 0), # Green for Object 1 (Left eye)
2: (255, 0, 0), # Blue for Object 2 (Right eye)
3: (0, 255, 255), # Yellow for Object 3
4: (255, 0, 255), # Magenta for Object 4
}
# Overlay masks with transparency
for obj_id, mask in frame_masks.items():
mask = mask.squeeze()
# Resize mask to match original frame if needed
if mask.shape != original_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
# Apply colored overlay
color = colors.get(obj_id, (128, 128, 128))
overlay = debug_frame.copy()
overlay[mask] = color
# Blend with original (30% overlay, 70% original)
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
logger.info(f"SAM2 Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
# Add title
title = f"SAM2 First Frame Masks: {len(frame_masks)} objects detected"
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
# Add mask source information
source_info = "Mask Source: SAM2 (from YOLO bounding boxes)"
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
# Add object legend
y_offset = 90
for obj_id in sorted(frame_masks.keys()):
color = colors.get(obj_id, (128, 128, 128))
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye' if obj_id == 2 else f'Object {obj_id}'}"
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
y_offset += 30
# Save debug image
success = cv2.imwrite(output_path, debug_frame)
# Cleanup
self.predictor.reset_state(inference_state)
            shutil.rmtree(temp_dir)
if success:
logger.info(f"SAM2 Debug: Saved first frame masks to {output_path}")
return True
else:
logger.error(f"Failed to save first frame masks to {output_path}")
return False
except Exception as e:
logger.error(f"Error generating first frame debug masks: {e}")
return False
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
"""
Add YOLO prompts at multiple frame indices for mid-segment re-detection.
Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode).
Args:
inference_state: SAM2 inference state
multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation)
Returns:
True if prompts were added successfully
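
        Example (illustrative values):
            # Detection mode: frame index -> list of bbox prompts
            {120: [{"obj_id": 1, "bbox": np.array([50, 40, 180, 200]), "confidence": 0.88}]}
            # Segmentation mode: frame index -> {"masks": {obj_id: boolean mask}}
            {240: {"masks": {1: mask_a, 2: mask_b}}}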
"""
if not multi_frame_prompts:
logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
return False
success_count = 0
total_count = 0
for frame_idx, prompts_data in multi_frame_prompts.items():
# Check if this is segmentation mode (masks) or detection mode (bbox prompts)
if isinstance(prompts_data, dict) and 'masks' in prompts_data:
# Segmentation mode: add masks directly
masks_dict = prompts_data['masks']
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")
for obj_id, mask in masks_dict.items():
total_count += 1
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
try:
self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
continue
else:
# Detection mode: add bounding box prompts (existing logic)
prompts = prompts_data
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")
for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
total_count += 1
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx, # Key: specify the exact frame index
obj_id=obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
continue
if success_count > 0:
logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
return True
else:
logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
return False
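

# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint/config paths and segment layout below
# are illustrative assumptions, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)

    processor = SAM2Processor(
        checkpoint_path="checkpoints/sam2.1_hiera_large.pt",  # illustrative path
        config_path="configs/sam2.1/sam2.1_hiera_l.yaml",     # illustrative path
    )

    # One first-frame box prompt for object 1 (XYXY pixel coordinates).
    prompts = [{"obj_id": 1, "bbox": np.array([100, 100, 300, 300]), "confidence": 0.9}]

    segment_info = {
        "directory": "segments/segment_000",  # illustrative segment layout
        "video_file": "segments/segment_000/video.mp4",
        "index": 0,
    }

    video_segments = processor.process_single_segment(segment_info, yolo_prompts=prompts)
    if video_segments is None:
        sys.exit("Segment was skipped or processing failed")

    processor.save_final_masks(
        video_segments, os.path.join(segment_info["directory"], "mask.png")
    )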