sbs working phase 1
@@ -11,13 +11,15 @@ import logging
 import gc
 from typing import Dict, List, Any, Optional, Tuple
 from sam2.build_sam import build_sam2_video_predictor
+from .eye_processor import EyeProcessor

 logger = logging.getLogger(__name__)

 class SAM2Processor:
     """Handles SAM2-based video segmentation for human tracking."""

-    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
+    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False,
+                 separate_eye_processing: bool = False, eye_overlap_pixels: int = 0):
         """
         Initialize SAM2 processor.

@@ -25,11 +27,21 @@ class SAM2Processor:
             checkpoint_path: Path to SAM2 checkpoint
             config_path: Path to SAM2 config file
             vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
+            separate_eye_processing: Enable VR180 separate eye processing mode
+            eye_overlap_pixels: Pixel overlap between eyes for blending
         """
         self.checkpoint_path = checkpoint_path
         self.config_path = config_path
         self.vos_optimized = vos_optimized
+        self.separate_eye_processing = separate_eye_processing
         self.predictor = None
+
+        # Initialize eye processor if separate eye processing is enabled
+        if separate_eye_processing:
+            self.eye_processor = EyeProcessor(eye_overlap_pixels=eye_overlap_pixels)
+        else:
+            self.eye_processor = None

         self._initialize_predictor()

     def _initialize_predictor(self):
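Editorial note, not part of the commit: a minimal construction sketch of the new mode, assuming placeholder checkpoint and config paths.

    # Hypothetical usage -- the paths below are placeholders, not from this repository
    processor = SAM2Processor(
        checkpoint_path="checkpoints/sam2_checkpoint.pt",  # placeholder
        config_path="configs/sam2_config.yaml",            # placeholder
        vos_optimized=False,
        separate_eye_processing=True,  # enables the EyeProcessor wiring added above
        eye_overlap_pixels=0,
    )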
@@ -650,3 +662,253 @@ class SAM2Processor:
         else:
             logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
             return False
+
+    def process_single_eye_segment(self, segment_info: dict, eye_side: str,
+                                   yolo_prompts: Optional[List[Dict[str, Any]]] = None,
+                                   previous_masks: Optional[Dict[int, np.ndarray]] = None,
+                                   inference_scale: float = 0.5) -> Optional[Dict[int, np.ndarray]]:
+        """
+        Process a single eye of a VR180 segment with SAM2.
+
+        Args:
+            segment_info: Segment information dictionary
+            eye_side: 'left' or 'right' eye
+            yolo_prompts: Optional YOLO detection prompts for first frame
+            previous_masks: Optional masks from previous segment
+            inference_scale: Scale factor for inference
+
+        Returns:
+            Dictionary mapping frame indices to masks, or None if failed
+        """
+        if not self.eye_processor:
+            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
+            return None
+
+        segment_dir = segment_info['directory']
+        video_path = segment_info['video_file']
+        segment_idx = segment_info['index']
+
+        logger.info(f"Processing {eye_side} eye for segment {segment_idx}")
+
+        # Use the video path directly (it should already be the eye-specific video)
+        eye_video_path = video_path
+
+        # Verify the eye video exists
+        if not os.path.exists(eye_video_path):
+            logger.error(f"Eye video not found: {eye_video_path}")
+            return None
+
+        # Create low-resolution eye video for inference
+        low_res_eye_video_path = os.path.join(segment_dir, f"low_res_{eye_side}_eye_video.mp4")
+        if not os.path.exists(low_res_eye_video_path):
+            try:
+                self.create_low_res_video(eye_video_path, low_res_eye_video_path, inference_scale)
+            except Exception as e:
+                logger.error(f"Failed to create low-res {eye_side} eye video for segment {segment_idx}: {e}")
+                return None
+
+        try:
+            # Initialize inference state with eye-specific video
+            inference_state = self.predictor.init_state(video_path=low_res_eye_video_path, async_loading_frames=True)
+
+            # Add prompts or previous masks (always use obj_id=1 for single eye processing)
+            if yolo_prompts:
+                # Convert prompts to use obj_id=1 for single eye processing
+                eye_prompts = []
+                for prompt in yolo_prompts:
+                    eye_prompt = prompt.copy()
+                    eye_prompt['obj_id'] = 1  # Always use obj_id=1 for single eye
+                    eye_prompts.append(eye_prompt)
+
+                if not self.add_yolo_prompts_to_predictor(inference_state, eye_prompts):
+                    logger.error(f"Failed to add prompts for {eye_side} eye")
+                    return None
+
+            elif previous_masks:
+                # Convert previous masks to use obj_id=1 for single eye processing
+                eye_masks = {1: list(previous_masks.values())[0]} if previous_masks else {}
+                if not self.add_previous_masks_to_predictor(inference_state, eye_masks):
+                    logger.error(f"Failed to add previous masks for {eye_side} eye")
+                    return None
+            else:
+                logger.error(f"No prompts or previous masks available for {eye_side} eye of segment {segment_idx}")
+                return None
+
+            # Propagate masks
+            logger.info(f"Propagating masks for {eye_side} eye")
+            video_segments = self.propagate_masks(inference_state)
+
+            # Extract just the masks (remove obj_id structure since we only use obj_id=1)
+            eye_masks = {}
+            for frame_idx, frame_masks in video_segments.items():
+                if 1 in frame_masks:  # We always use obj_id=1 for single eye processing
+                    eye_masks[frame_idx] = frame_masks[1]
+
+            # Clean up
+            self.predictor.reset_state(inference_state)
+            del inference_state
+            gc.collect()
+
+            # Remove temporary low-res video
+            try:
+                os.remove(low_res_eye_video_path)
+                logger.debug(f"Removed low-res {eye_side} eye video: {low_res_eye_video_path}")
+            except Exception as e:
+                logger.warning(f"Could not remove low-res {eye_side} eye video: {e}")
+
+            logger.info(f"Successfully processed {eye_side} eye with {len(eye_masks)} frames")
+            return eye_masks
+
+        except Exception as e:
+            logger.error(f"Error processing {eye_side} eye for segment {segment_idx}: {e}")
+            return None
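Editorial note, not part of the commit: a hypothetical call sketch for the method above. The segment_info keys ('directory', 'video_file', 'index') come from this diff; the paths and the prompt keys besides 'obj_id' are assumptions.

    # Hypothetical usage -- paths and the 'box' prompt key are assumptions
    segment_info = {
        'directory': '/tmp/segments/segment_003',                      # placeholder
        'video_file': '/tmp/segments/segment_003/left_eye_video.mp4',  # placeholder
        'index': 3,
    }
    left_masks = processor.process_single_eye_segment(
        segment_info,
        eye_side='left',
        yolo_prompts=[{'obj_id': 1, 'box': [100, 200, 400, 800]}],  # prompt format assumed
        inference_scale=0.5,
    )
    # On success: dict mapping frame index -> single-eye mask (np.ndarray); None on failure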
+
+    def process_segment_with_separate_eyes(self, segment_info: dict,
+                                           left_prompts: Optional[List[Dict[str, Any]]] = None,
+                                           right_prompts: Optional[List[Dict[str, Any]]] = None,
+                                           previous_left_masks: Optional[Dict[int, np.ndarray]] = None,
+                                           previous_right_masks: Optional[Dict[int, np.ndarray]] = None,
+                                           inference_scale: float = 0.5,
+                                           full_frame_shape: Optional[Tuple[int, int]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
+        """
+        Process a VR180 segment with separate left and right eye processing.
+
+        Args:
+            segment_info: Segment information dictionary
+            left_prompts: Optional YOLO prompts for left eye
+            right_prompts: Optional YOLO prompts for right eye
+            previous_left_masks: Optional previous masks for left eye
+            previous_right_masks: Optional previous masks for right eye
+            inference_scale: Scale factor for inference
+            full_frame_shape: Shape of full VR180 frame (height, width)
+
+        Returns:
+            Combined video segments dictionary or None if failed
+        """
+        if not self.eye_processor:
+            logger.error("Eye processor not initialized - separate_eye_processing must be enabled")
+            return None
+
+        segment_idx = segment_info['index']
+        logger.info(f"Processing segment {segment_idx} with separate eye processing")
+
+        # Get full frame shape if not provided
+        if full_frame_shape is None:
+            try:
+                cap = cv2.VideoCapture(segment_info['video_file'])
+                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                cap.release()
+                full_frame_shape = (height, width)
+            except Exception as e:
+                logger.error(f"Could not determine frame shape: {e}")
+                return None
+
+        # Process left eye if prompts or previous masks are available
+        left_masks = None
+        if left_prompts or previous_left_masks:
+            logger.info(f"Processing left eye for segment {segment_idx}")
+            left_masks = self.process_single_eye_segment(
+                segment_info, 'left', left_prompts, previous_left_masks, inference_scale
+            )
+
+        # Process right eye if prompts or previous masks are available
+        right_masks = None
+        if right_prompts or previous_right_masks:
+            logger.info(f"Processing right eye for segment {segment_idx}")
+            right_masks = self.process_single_eye_segment(
+                segment_info, 'right', right_prompts, previous_right_masks, inference_scale
+            )
+
+        # Combine masks back to full frame format
+        if left_masks or right_masks:
+            logger.info(f"Combining eye masks for segment {segment_idx}")
+            combined_masks = self.eye_processor.combine_eye_masks(
+                left_masks, right_masks, full_frame_shape
+            )
+
+            # Clean up eye-specific videos to save space
+            try:
+                left_eye_path = os.path.join(segment_info['directory'], "left_eye_video.mp4")
+                right_eye_path = os.path.join(segment_info['directory'], "right_eye_video.mp4")
+
+                if os.path.exists(left_eye_path):
+                    os.remove(left_eye_path)
+                    logger.debug(f"Removed left eye video: {left_eye_path}")
+
+                if os.path.exists(right_eye_path):
+                    os.remove(right_eye_path)
+                    logger.debug(f"Removed right eye video: {right_eye_path}")
+
+            except Exception as e:
+                logger.warning(f"Could not clean up eye videos: {e}")
+
+            logger.info(f"Successfully processed segment {segment_idx} with separate eyes")
+            return combined_masks
+        else:
+            logger.warning(f"No masks generated for either eye in segment {segment_idx}")
+            return None
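Editorial note, not part of the commit: a hypothetical end-to-end call for the combined method; the prompt contents and frame size are illustrative assumptions.

    # Hypothetical usage -- prompt format and frame size are assumptions
    combined = processor.process_segment_with_separate_eyes(
        segment_info,
        left_prompts=[{'obj_id': 1, 'box': [100, 200, 400, 800]}],     # assumed prompt format
        right_prompts=[{'obj_id': 1, 'box': [2150, 200, 2450, 800]}],  # assumed prompt format
        inference_scale=0.5,
        full_frame_shape=(2048, 4096),  # example side-by-side VR180 frame (height, width)
    )
    # On success: dict mapping frame index -> {obj_id: full-frame mask},
    # stitched back together by EyeProcessor.combine_eye_masks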
+
+    def create_greenscreen_segment(self, segment_info: dict, green_color: List[int] = [0, 255, 0]) -> bool:
+        """
+        Create a full greenscreen segment when no humans are detected.
+
+        Args:
+            segment_info: Segment information dictionary
+            green_color: RGB values for green screen color
+
+        Returns:
+            True if greenscreen segment was created successfully
+        """
+        segment_dir = segment_info['directory']
+        video_path = segment_info['video_file']
+        segment_idx = segment_info['index']
+
+        logger.info(f"Creating full greenscreen segment {segment_idx}")
+
+        try:
+            # Get video properties
+            cap = cv2.VideoCapture(video_path)
+            if not cap.isOpened():
+                logger.error(f"Could not open video: {video_path}")
+                return False
+
+            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            fps = cap.get(cv2.CAP_PROP_FPS)
+            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            cap.release()
+
+            # Create output video path
+            output_video_path = os.path.join(segment_dir, f"output_{segment_idx}.mp4")
+
+            # Create greenscreen frames
+            greenscreen_frame = self.eye_processor.create_full_greenscreen_frame(
+                (height, width, 3), green_color
+            )
+
+            # Write greenscreen video
+            fourcc = cv2.VideoWriter_fourcc(*'HEVC')
+            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+
+            for _ in range(frame_count):
+                out.write(greenscreen_frame)
+
+            out.release()
+
+            # Create mask file (empty/black mask since no humans detected)
+            mask_output_path = os.path.join(segment_dir, "mask.png")
+            black_mask = np.zeros((height, width, 3), dtype=np.uint8)
+            cv2.imwrite(mask_output_path, black_mask)
+
+            # Mark segment as completed
+            output_done_file = os.path.join(segment_dir, "output_frames_done")
+            with open(output_done_file, 'w') as f:
+                f.write(f"Greenscreen segment {segment_idx} completed successfully\n")
+
+            logger.info(f"Successfully created greenscreen segment {segment_idx}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error creating greenscreen segment {segment_idx}: {e}")
+            return False
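Editorial caveat, not part of the commit: cv2.VideoWriter support for the 'HEVC' fourcc depends on the local OpenCV/FFmpeg build, so the greenscreen writer above may silently fail to open on some systems. A guarded fallback could look like this sketch, reusing the variables from create_greenscreen_segment.

    # Sketch only: fall back to the widely available MPEG-4 encoder if HEVC is unavailable
    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'HEVC'), fps, (width, height))
    if not out.isOpened():
        out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))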