sbs working phase 1

2025-07-30 18:07:26 -07:00
parent 6617acb1c9
commit 70044e1b10
8 changed files with 2417 additions and 7 deletions


@@ -11,13 +11,15 @@ import logging
import gc
from typing import Dict, List, Any, Optional, Tuple
from sam2.build_sam import build_sam2_video_predictor
from .eye_processor import EyeProcessor

logger = logging.getLogger(__name__)


class SAM2Processor:
    """Handles SAM2-based video segmentation for human tracking."""

    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False,
                 separate_eye_processing: bool = False, eye_overlap_pixels: int = 0):
        """
        Initialize SAM2 processor.

@@ -25,11 +27,21 @@ class SAM2Processor:
            checkpoint_path: Path to SAM2 checkpoint
            config_path: Path to SAM2 config file
            vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
            separate_eye_processing: Enable VR180 separate eye processing mode
            eye_overlap_pixels: Pixel overlap between eyes for blending
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.vos_optimized = vos_optimized
        self.separate_eye_processing = separate_eye_processing
        self.predictor = None

        # Initialize eye processor if separate eye processing is enabled
        if separate_eye_processing:
            self.eye_processor = EyeProcessor(eye_overlap_pixels=eye_overlap_pixels)
        else:
            self.eye_processor = None

        self._initialize_predictor()
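
    # Usage sketch (illustrative; paths are placeholders, not files shipped with
    # this repo): constructing the processor with VR180 separate-eye mode enabled.
    #
    #   processor = SAM2Processor(
    #       checkpoint_path="checkpoints/sam2_hiera_large.pt",  # placeholder path
    #       config_path="sam2_hiera_l.yaml",                    # placeholder path
    #       vos_optimized=False,
    #       separate_eye_processing=True,
    #       eye_overlap_pixels=0,
    #   )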
    def _initialize_predictor(self):
@@ -650,3 +662,253 @@ class SAM2Processor:
        else:
            logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
            return False
    def process_single_eye_segment(self, segment_info: dict, eye_side: str,
                                   yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                                   previous_masks: Optional[Dict[int, np.ndarray]] = None,
                                   inference_scale: float = 0.5) -> Optional[Dict[int, np.ndarray]]:
        """
        Process a single eye of a VR180 segment with SAM2.

        Args:
            segment_info: Segment information dictionary
            eye_side: 'left' or 'right' eye
            yolo_prompts: Optional YOLO detection prompts for first frame
            previous_masks: Optional masks from previous segment
            inference_scale: Scale factor for inference

        Returns:
            Dictionary mapping frame indices to masks, or None if failed
        """
        if not self.eye_processor:
            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
            return None

        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        logger.info(f"Processing {eye_side} eye for segment {segment_idx}")

        # Use the video path directly (it should already be the eye-specific video)
        eye_video_path = video_path

        # Verify the eye video exists
        if not os.path.exists(eye_video_path):
            logger.error(f"Eye video not found: {eye_video_path}")
            return None

        # Create low-resolution eye video for inference
        low_res_eye_video_path = os.path.join(segment_dir, f"low_res_{eye_side}_eye_video.mp4")
        if not os.path.exists(low_res_eye_video_path):
            try:
                self.create_low_res_video(eye_video_path, low_res_eye_video_path, inference_scale)
            except Exception as e:
                logger.error(f"Failed to create low-res {eye_side} eye video for segment {segment_idx}: {e}")
                return None

        try:
            # Initialize inference state with eye-specific video
            inference_state = self.predictor.init_state(video_path=low_res_eye_video_path, async_loading_frames=True)

            # Add prompts or previous masks (always use obj_id=1 for single eye processing)
            if yolo_prompts:
                # Convert prompts to use obj_id=1 for single eye processing
                eye_prompts = []
                for prompt in yolo_prompts:
                    eye_prompt = prompt.copy()
                    eye_prompt['obj_id'] = 1  # Always use obj_id=1 for single eye
                    eye_prompts.append(eye_prompt)

                if not self.add_yolo_prompts_to_predictor(inference_state, eye_prompts):
                    logger.error(f"Failed to add prompts for {eye_side} eye")
                    return None
            elif previous_masks:
                # Convert previous masks to use obj_id=1 for single eye processing
                eye_masks = {1: list(previous_masks.values())[0]} if previous_masks else {}
                if not self.add_previous_masks_to_predictor(inference_state, eye_masks):
                    logger.error(f"Failed to add previous masks for {eye_side} eye")
                    return None
            else:
                logger.error(f"No prompts or previous masks available for {eye_side} eye of segment {segment_idx}")
                return None

            # Propagate masks
            logger.info(f"Propagating masks for {eye_side} eye")
            video_segments = self.propagate_masks(inference_state)

            # Extract just the masks (remove obj_id structure since we only use obj_id=1)
            eye_masks = {}
            for frame_idx, frame_masks in video_segments.items():
                if 1 in frame_masks:  # We always use obj_id=1 for single eye processing
                    eye_masks[frame_idx] = frame_masks[1]

            # Clean up
            self.predictor.reset_state(inference_state)
            del inference_state
            gc.collect()

            # Remove temporary low-res video
            try:
                os.remove(low_res_eye_video_path)
                logger.debug(f"Removed low-res {eye_side} eye video: {low_res_eye_video_path}")
            except Exception as e:
                logger.warning(f"Could not remove low-res {eye_side} eye video: {e}")

            logger.info(f"Successfully processed {eye_side} eye with {len(eye_masks)} frames")
            return eye_masks

        except Exception as e:
            logger.error(f"Error processing {eye_side} eye for segment {segment_idx}: {e}")
            return None
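
    # Usage sketch (illustrative): processing one eye of a segment whose per-eye
    # video has already been written to disk. The segment_info keys match what
    # this method reads ('directory', 'video_file', 'index'); the contents of the
    # YOLO prompt dicts beyond 'obj_id' are assumptions about the prompt format
    # used elsewhere in the pipeline.
    #
    #   left_masks = processor.process_single_eye_segment(
    #       segment_info={
    #           'directory': '/tmp/segments/segment_0000',                      # placeholder
    #           'video_file': '/tmp/segments/segment_0000/left_eye_video.mp4',  # placeholder
    #           'index': 0,
    #       },
    #       eye_side='left',
    #       yolo_prompts=first_frame_prompts,  # list of prompt dicts; obj_id is remapped to 1
    #       inference_scale=0.5,
    #   )
    #   # left_masks maps frame_idx -> mask for the single tracked object (obj_id=1)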
    def process_segment_with_separate_eyes(self, segment_info: dict,
                                           left_prompts: Optional[List[Dict[str, Any]]] = None,
                                           right_prompts: Optional[List[Dict[str, Any]]] = None,
                                           previous_left_masks: Optional[Dict[int, np.ndarray]] = None,
                                           previous_right_masks: Optional[Dict[int, np.ndarray]] = None,
                                           inference_scale: float = 0.5,
                                           full_frame_shape: Optional[Tuple[int, int]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
        """
        Process a VR180 segment with separate left and right eye processing.

        Args:
            segment_info: Segment information dictionary
            left_prompts: Optional YOLO prompts for left eye
            right_prompts: Optional YOLO prompts for right eye
            previous_left_masks: Optional previous masks for left eye
            previous_right_masks: Optional previous masks for right eye
            inference_scale: Scale factor for inference
            full_frame_shape: Shape of full VR180 frame (height, width)

        Returns:
            Combined video segments dictionary or None if failed
        """
        if not self.eye_processor:
            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
            return None

        segment_idx = segment_info['index']
        logger.info(f"Processing segment {segment_idx} with separate eye processing")

        # Get full frame shape if not provided
        if full_frame_shape is None:
            try:
                cap = cv2.VideoCapture(segment_info['video_file'])
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                cap.release()
                full_frame_shape = (height, width)
            except Exception as e:
                logger.error(f"Could not determine frame shape: {e}")
                return None

        # Process left eye if prompts or previous masks are available
        left_masks = None
        if left_prompts or previous_left_masks:
            logger.info(f"Processing left eye for segment {segment_idx}")
            left_masks = self.process_single_eye_segment(
                segment_info, 'left', left_prompts, previous_left_masks, inference_scale
            )

        # Process right eye if prompts or previous masks are available
        right_masks = None
        if right_prompts or previous_right_masks:
            logger.info(f"Processing right eye for segment {segment_idx}")
            right_masks = self.process_single_eye_segment(
                segment_info, 'right', right_prompts, previous_right_masks, inference_scale
            )

        # Combine masks back to full frame format
        if left_masks or right_masks:
            logger.info(f"Combining eye masks for segment {segment_idx}")
            combined_masks = self.eye_processor.combine_eye_masks(
                left_masks, right_masks, full_frame_shape
            )

            # Clean up eye-specific videos to save space
            try:
                left_eye_path = os.path.join(segment_info['directory'], "left_eye_video.mp4")
                right_eye_path = os.path.join(segment_info['directory'], "right_eye_video.mp4")
                if os.path.exists(left_eye_path):
                    os.remove(left_eye_path)
                    logger.debug(f"Removed left eye video: {left_eye_path}")
                if os.path.exists(right_eye_path):
                    os.remove(right_eye_path)
                    logger.debug(f"Removed right eye video: {right_eye_path}")
            except Exception as e:
                logger.warning(f"Could not clean up eye videos: {e}")

            logger.info(f"Successfully processed segment {segment_idx} with separate eyes")
            return combined_masks
        else:
            logger.warning(f"No masks generated for either eye in segment {segment_idx}")
            return None
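
    # Usage sketch (illustrative): running both eyes of a VR180 segment and
    # getting masks back in full-frame coordinates. 'left_prompts' and
    # 'right_prompts' would come from per-eye YOLO detections on the first frame;
    # the return value follows the annotation above: {frame_idx: {obj_id: mask}}.
    #
    #   combined = processor.process_segment_with_separate_eyes(
    #       segment_info=segment_info,
    #       left_prompts=left_prompts,
    #       right_prompts=right_prompts,
    #       inference_scale=0.5,
    #       full_frame_shape=(4096, 8192),  # placeholder VR180 resolution (height, width)
    #   )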
    def create_greenscreen_segment(self, segment_info: dict, green_color: List[int] = [0, 255, 0]) -> bool:
        """
        Create a full greenscreen segment when no humans are detected.

        Args:
            segment_info: Segment information dictionary
            green_color: RGB values for green screen color

        Returns:
            True if greenscreen segment was created successfully
        """
        if not self.eye_processor:
            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
            return False

        segment_dir = segment_info['directory']
        video_path = segment_info['video_file']
        segment_idx = segment_info['index']

        logger.info(f"Creating full greenscreen segment {segment_idx}")

        try:
            # Get video properties
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                logger.error(f"Could not open video: {video_path}")
                return False

            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.release()

            # Create output video path
            output_video_path = os.path.join(segment_dir, f"output_{segment_idx}.mp4")

            # Create greenscreen frame (repeated for every frame of the segment)
            greenscreen_frame = self.eye_processor.create_full_greenscreen_frame(
                (height, width, 3), green_color
            )

            # Write greenscreen video
            fourcc = cv2.VideoWriter_fourcc(*'HEVC')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
            for _ in range(frame_count):
                out.write(greenscreen_frame)
            out.release()

            # Create mask file (empty/black mask since no humans detected)
            mask_output_path = os.path.join(segment_dir, "mask.png")
            black_mask = np.zeros((height, width, 3), dtype=np.uint8)
            cv2.imwrite(mask_output_path, black_mask)

            # Mark segment as completed
            output_done_file = os.path.join(segment_dir, "output_frames_done")
            with open(output_done_file, 'w') as f:
                f.write(f"Greenscreen segment {segment_idx} completed successfully\n")

            logger.info(f"Successfully created greenscreen segment {segment_idx}")
            return True

        except Exception as e:
            logger.error(f"Error creating greenscreen segment {segment_idx}: {e}")
            return False
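
    # Usage sketch (illustrative): fallback path for a segment in which YOLO found
    # no humans. This writes a solid-green output_{index}.mp4, a black mask.png,
    # and an output_frames_done marker into the segment directory.
    #
    #   if not left_prompts and not right_prompts and not has_previous_masks:
    #       processor.create_greenscreen_segment(segment_info)  # uses default green [0, 255, 0]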