working with segmentation

2025-07-27 13:55:52 -07:00
parent 46363a8a11
commit cd7bc54efe
7 changed files with 1302 additions and 105 deletions

View File

@@ -57,3 +57,6 @@ advanced:
# Logging level (DEBUG, INFO, WARNING, ERROR) # Logging level (DEBUG, INFO, WARNING, ERROR)
log_level: "INFO" log_level: "INFO"
# Save debug frames with YOLO detections visualized
save_yolo_debug_frames: true

View File

@@ -50,11 +50,31 @@ class ConfigLoader:
raise ValueError(f"Missing required field: output.{field}") raise ValueError(f"Missing required field: output.{field}")
# Validate models section # Validate models section
required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config'] required_model_fields = ['sam2_checkpoint', 'sam2_config']
for field in required_model_fields: for field in required_model_fields:
if field not in self.config['models']: if field not in self.config['models']:
raise ValueError(f"Missing required field: models.{field}") raise ValueError(f"Missing required field: models.{field}")
# Validate YOLO model configuration
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
if yolo_mode not in ['detection', 'segmentation']:
raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")
# Check for legacy yolo_model field vs new structure
has_legacy_yolo_model = 'yolo_model' in self.config['models']
has_new_yolo_models = 'yolo_detection_model' in self.config['models'] or 'yolo_segmentation_model' in self.config['models']
if not has_legacy_yolo_model and not has_new_yolo_models:
raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")
# Validate that the required model for the current mode exists
if yolo_mode == 'detection':
if has_new_yolo_models and 'yolo_detection_model' not in self.config['models']:
raise ValueError("yolo_mode is 'detection' but yolo_detection_model not specified")
elif yolo_mode == 'segmentation':
if has_new_yolo_models and 'yolo_segmentation_model' not in self.config['models']:
raise ValueError("yolo_mode is 'segmentation' but yolo_segmentation_model not specified")
# Validate processing.detect_segments format # Validate processing.detect_segments format
detect_segments = self.config['processing'].get('detect_segments', 'all') detect_segments = self.config['processing'].get('detect_segments', 'all')
if not isinstance(detect_segments, (str, list)): if not isinstance(detect_segments, (str, list)):
@@ -114,8 +134,17 @@ class ConfigLoader:
return self.config['processing'].get('detect_segments', 'all') return self.config['processing'].get('detect_segments', 'all')
def get_yolo_model_path(self) -> str: def get_yolo_model_path(self) -> str:
"""Get YOLO model path.""" """Get YOLO model path (legacy method for backward compatibility)."""
return self.config['models']['yolo_model'] # Check for legacy configuration first
if 'yolo_model' in self.config['models']:
return self.config['models']['yolo_model']
# Use new configuration based on mode
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
if yolo_mode == 'detection':
return self.config['models'].get('yolo_detection_model', 'yolov8n.pt')
else: # segmentation mode
return self.config['models'].get('yolo_segmentation_model', 'yolov8n-seg.pt')
def get_sam2_checkpoint(self) -> str: def get_sam2_checkpoint(self) -> str:
"""Get SAM2 checkpoint path.""" """Get SAM2 checkpoint path."""

View File

@@ -47,8 +47,23 @@ class SAM2Processor:
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
try: try:
# Extract just the config filename for SAM2's Hydra-based loader
# SAM2 expects a config name relative to its internal config directory
config_name = os.path.basename(self.config_path)
if config_name.endswith('.yaml'):
config_name = config_name[:-5] # Remove .yaml extension
# SAM2 configs are in the format "sam2.1_hiera_X.yaml"
# and should be referenced as "configs/sam2.1/sam2.1_hiera_X"
if config_name.startswith("sam2.1_hiera"):
config_name = f"configs/sam2.1/{config_name}"
elif config_name.startswith("sam2_hiera"):
config_name = f"configs/sam2/{config_name}"
logger.info(f"Using SAM2 config: {config_name}")
self.predictor = build_sam2_video_predictor( self.predictor = build_sam2_video_predictor(
self.config_path, config_name, # Use just the config name, not full path
self.checkpoint_path, self.checkpoint_path,
device=device device=device
) )
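The filename-to-Hydra-name mapping above can be exercised on its own; a small sketch, with a placeholder input path:

import os

def to_sam2_config_name(config_path: str) -> str:
    # Strip the directory and the .yaml suffix, then map the known SAM2 config
    # families onto the names SAM2's Hydra-based loader expects.
    name = os.path.basename(config_path)
    if name.endswith(".yaml"):
        name = name[:-5]
    if name.startswith("sam2.1_hiera"):
        return f"configs/sam2.1/{name}"
    if name.startswith("sam2_hiera"):
        return f"configs/sam2/{name}"
    return name

# "/some/dir/sam2.1_hiera_large.yaml" -> "configs/sam2.1/sam2.1_hiera_large"
print(to_sam2_config_name("/some/dir/sam2.1_hiera_large.yaml"))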
@@ -103,6 +118,7 @@ class SAM2Processor:
def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool: def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
""" """
Add YOLO detection prompts to SAM2 predictor. Add YOLO detection prompts to SAM2 predictor.
Includes error handling matching the working spec.md implementation.
Args: Args:
inference_state: SAM2 inference state inference_state: SAM2 inference state
@@ -112,14 +128,21 @@ class SAM2Processor:
True if prompts were added successfully True if prompts were added successfully
""" """
if not prompts: if not prompts:
logger.warning("No prompts provided to SAM2") logger.warning("SAM2 Debug: No prompts provided to SAM2")
return False return False
try: logger.info(f"SAM2 Debug: Received {len(prompts)} prompts to add to predictor")
for prompt in prompts:
obj_id = prompt['obj_id']
bbox = prompt['bbox']
success_count = 0
for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
logger.info(f"SAM2 Debug: Adding prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box( _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state, inference_state=inference_state,
frame_idx=0, frame_idx=0,
@@ -127,13 +150,19 @@ class SAM2Processor:
box=bbox.astype(np.float32), box=bbox.astype(np.float32),
) )
logger.debug(f"Added prompt for Object {obj_id}: {bbox}") logger.info(f"SAM2 Debug: ✓ Successfully added Object {obj_id} - returned obj_ids: {out_obj_ids}")
success_count += 1
logger.info(f"Successfully added {len(prompts)} prompts to SAM2") except Exception as e:
logger.error(f"SAM2 Debug: ✗ Error adding Object {obj_id}: {e}")
# Continue processing other prompts even if one fails
continue
if success_count > 0:
logger.info(f"SAM2 Debug: Final result - {success_count}/{len(prompts)} prompts successfully added")
return True return True
else:
except Exception as e: logger.error("SAM2 Debug: FAILED - No prompts were successfully added to SAM2")
logger.error(f"Error adding prompts to SAM2: {e}")
return False return False
def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]: def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
@@ -235,15 +264,17 @@ class SAM2Processor:
def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None, def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
previous_masks: Optional[Dict[int, np.ndarray]] = None, previous_masks: Optional[Dict[int, np.ndarray]] = None,
inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]: inference_scale: float = 0.5,
multi_frame_prompts: Optional[Dict[int, List[Dict[str, Any]]]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
""" """
Process a single video segment with SAM2. Process a single video segment with SAM2.
Args: Args:
segment_info: Segment information dictionary segment_info: Segment information dictionary
yolo_prompts: Optional YOLO detection prompts yolo_prompts: Optional YOLO detection prompts for first frame
previous_masks: Optional masks from previous segment previous_masks: Optional masks from previous segment
inference_scale: Scale factor for inference inference_scale: Scale factor for inference
multi_frame_prompts: Optional prompts for multiple frames (mid-segment detection)
Returns: Returns:
Video segments dictionary or None if failed Video segments dictionary or None if failed
@@ -284,6 +315,13 @@ class SAM2Processor:
logger.error(f"No prompts or previous masks available for segment {segment_idx}") logger.error(f"No prompts or previous masks available for segment {segment_idx}")
return None return None
# Add mid-segment prompts if provided
if multi_frame_prompts:
logger.info(f"Adding mid-segment prompts for segment {segment_idx}")
if not self.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
logger.warning(f"Failed to add mid-segment prompts for segment {segment_idx}")
# Don't return None here - continue with existing prompts
# Propagate masks # Propagate masks
video_segments = self.propagate_masks(inference_state) video_segments = self.propagate_masks(inference_state)
@@ -360,3 +398,197 @@ class SAM2Processor:
except Exception as e: except Exception as e:
logger.error(f"Error saving final masks: {e}") logger.error(f"Error saving final masks: {e}")
def generate_first_frame_debug_masks(self, video_path: str, prompts: List[Dict[str, Any]],
output_path: str, inference_scale: float = 0.5) -> bool:
"""
Generate SAM2 masks for just the first frame and save debug visualization.
This helps debug what SAM2 is producing for each detected object.
Args:
video_path: Path to the video file
prompts: List of SAM2 prompt dictionaries
output_path: Path to save the debug image
inference_scale: Scale factor for SAM2 inference
Returns:
True if debug masks were generated successfully
"""
if not prompts:
logger.warning("No prompts provided for first frame debug")
return False
try:
logger.info(f"SAM2 Debug: Generating first frame masks for {len(prompts)} objects")
# Load the first frame
cap = cv2.VideoCapture(video_path)
ret, original_frame = cap.read()
cap.release()
if not ret:
logger.error("Could not read first frame for debug mask generation")
return False
# Scale frame for inference if needed
if inference_scale != 1.0:
inference_frame = cv2.resize(original_frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
else:
inference_frame = original_frame.copy()
# Create temporary low-res video with just first frame
import tempfile
import os
temp_dir = tempfile.mkdtemp()
temp_video_path = os.path.join(temp_dir, "first_frame.mp4")
# Write single frame to temporary video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video_path, fourcc, 1.0, (inference_frame.shape[1], inference_frame.shape[0]))
out.write(inference_frame)
out.release()
# Initialize SAM2 inference state with single frame
inference_state = self.predictor.init_state(video_path=temp_video_path, async_loading_frames=True)
# Add prompts
if not self.add_yolo_prompts_to_predictor(inference_state, prompts):
logger.error("Failed to add prompts for first frame debug")
return False
# Generate masks for first frame only
frame_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
if out_frame_idx == 0: # Only process first frame
frame_masks = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
break
if not frame_masks:
logger.error("No masks generated for first frame debug")
return False
# Create debug visualization
debug_frame = original_frame.copy()
# Define colors for each object
colors = {
1: (0, 255, 0), # Green for Object 1 (Left eye)
2: (255, 0, 0), # Blue for Object 2 (Right eye)
3: (0, 255, 255), # Yellow for Object 3
4: (255, 0, 255), # Magenta for Object 4
}
# Overlay masks with transparency
for obj_id, mask in frame_masks.items():
mask = mask.squeeze()
# Resize mask to match original frame if needed
if mask.shape != original_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
# Apply colored overlay
color = colors.get(obj_id, (128, 128, 128))
overlay = debug_frame.copy()
overlay[mask] = color
# Blend with original (30% overlay, 70% original)
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
logger.info(f"SAM2 Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
# Add title
title = f"SAM2 First Frame Masks: {len(frame_masks)} objects detected"
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
# Add mask source information
source_info = "Mask Source: SAM2 (from YOLO bounding boxes)"
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
# Add object legend
y_offset = 90
for obj_id in sorted(frame_masks.keys()):
color = colors.get(obj_id, (128, 128, 128))
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye' if obj_id == 2 else f'Object {obj_id}'}"
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
y_offset += 30
# Save debug image
success = cv2.imwrite(output_path, debug_frame)
# Cleanup
self.predictor.reset_state(inference_state)
import shutil
shutil.rmtree(temp_dir)
if success:
logger.info(f"SAM2 Debug: Saved first frame masks to {output_path}")
return True
else:
logger.error(f"Failed to save first frame masks to {output_path}")
return False
except Exception as e:
logger.error(f"Error generating first frame debug masks: {e}")
return False
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, List[Dict[str, Any]]]) -> bool:
"""
Add YOLO detection prompts at multiple frame indices for mid-segment re-detection.
Args:
inference_state: SAM2 inference state
multi_frame_prompts: Dictionary mapping frame_index -> list of prompt dictionaries
Returns:
True if prompts were added successfully
"""
if not multi_frame_prompts:
logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
return False
total_prompts = sum(len(prompts) for prompts in multi_frame_prompts.values())
logger.info(f"SAM2 Mid-segment: Adding {total_prompts} prompts across {len(multi_frame_prompts)} frames")
success_count = 0
total_count = 0
for frame_idx, prompts in multi_frame_prompts.items():
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} prompts")
for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
total_count += 1
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx, # Key: specify the exact frame index
obj_id=obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
continue
if success_count > 0:
logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
return True
else:
logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
return False
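The multi_frame_prompts argument is a plain mapping from absolute frame index to the same prompt dictionaries used for the first frame. A sketch of the expected shape (object IDs follow the left/right convention used elsewhere in this commit; the boxes are made-up numbers):

import numpy as np

# frame_index -> list of prompt dicts: obj_id, float32 bbox [x1, y1, x2, y2],
# and an optional confidence carried through from YOLO.
multi_frame_prompts = {
    30: [
        {"obj_id": 1, "bbox": np.array([120.0, 80.0, 340.0, 560.0], dtype=np.float32), "confidence": 0.82},
        {"obj_id": 2, "bbox": np.array([1050.0, 85.0, 1270.0, 555.0], dtype=np.float32), "confidence": 0.79},
    ],
    60: [
        {"obj_id": 1, "bbox": np.array([130.0, 82.0, 350.0, 558.0], dtype=np.float32), "confidence": 0.80},
    ],
}

# processor.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts)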

View File

@@ -7,31 +7,56 @@ import os
import cv2 import cv2
import numpy as np import numpy as np
import logging import logging
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Tuple
from ultralytics import YOLO from ultralytics import YOLO
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class YOLODetector: class YOLODetector:
"""Handles YOLO-based human detection for video segments.""" """Handles YOLO-based human detection for video segments with support for both detection and segmentation modes."""
def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0): def __init__(self, detection_model_path: str = None, segmentation_model_path: str = None,
mode: str = "detection", confidence_threshold: float = 0.6, human_class_id: int = 0):
""" """
Initialize YOLO detector. Initialize YOLO detector with support for both detection and segmentation modes.
Args: Args:
model_path: Path to YOLO model weights detection_model_path: Path to YOLO detection model weights (e.g., yolov8n.pt)
segmentation_model_path: Path to YOLO segmentation model weights (e.g., yolov8n-seg.pt)
mode: Detection mode - "detection" for bboxes, "segmentation" for masks
confidence_threshold: Detection confidence threshold confidence_threshold: Detection confidence threshold
human_class_id: COCO class ID for humans (0 = person) human_class_id: COCO class ID for humans (0 = person)
""" """
self.model_path = model_path self.mode = mode
self.confidence_threshold = confidence_threshold self.confidence_threshold = confidence_threshold
self.human_class_id = human_class_id self.human_class_id = human_class_id
# Select model path based on mode
if mode == "segmentation":
if not segmentation_model_path:
raise ValueError("segmentation_model_path required for segmentation mode")
self.model_path = segmentation_model_path
self.supports_segmentation = True
elif mode == "detection":
if not detection_model_path:
raise ValueError("detection_model_path required for detection mode")
self.model_path = detection_model_path
self.supports_segmentation = False
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'detection' or 'segmentation'")
# Load YOLO model # Load YOLO model
try: try:
self.model = YOLO(model_path) self.model = YOLO(self.model_path)
logger.info(f"Loaded YOLO model from {model_path}") logger.info(f"Loaded YOLO model in {mode} mode from {self.model_path}")
# Verify model capabilities
if mode == "segmentation":
# Test if model actually supports segmentation
logger.info(f"YOLO Segmentation: Model loaded, will output direct masks")
else:
logger.info(f"YOLO Detection: Model loaded, will output bounding boxes")
except Exception as e: except Exception as e:
logger.error(f"Failed to load YOLO model: {e}") logger.error(f"Failed to load YOLO model: {e}")
raise raise
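With the reworked constructor, callers choose the mode explicitly and supply the matching weights; a short usage sketch (model paths are placeholders):

from core.yolo_detector import YOLODetector  # import path assumed from this repo's layout

# Bounding boxes only (legacy behaviour).
detector = YOLODetector(
    detection_model_path="models/yolo/yolov8n.pt",
    mode="detection",
    confidence_threshold=0.6,
)

# Direct instance masks from YOLO, later handed to SAM2 as initial masks.
seg_detector = YOLODetector(
    segmentation_model_path="models/yolo/yolov8n-seg.pt",
    mode="segmentation",
    confidence_threshold=0.6,
)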
@@ -44,9 +69,9 @@ class YOLODetector:
frame: Input frame (BGR format from OpenCV) frame: Input frame (BGR format from OpenCV)
Returns: Returns:
List of human detection dictionaries with bbox and confidence List of human detection dictionaries with bbox, confidence, and optionally masks
""" """
# Run YOLO detection # Run YOLO detection/segmentation
results = self.model(frame, conf=self.confidence_threshold, verbose=False) results = self.model(frame, conf=self.confidence_threshold, verbose=False)
human_detections = [] human_detections = []
@@ -54,8 +79,10 @@ class YOLODetector:
# Process results # Process results
for result in results: for result in results:
boxes = result.boxes boxes = result.boxes
masks = result.masks if hasattr(result, 'masks') and result.masks is not None else None
if boxes is not None: if boxes is not None:
for box in boxes: for i, box in enumerate(boxes):
# Get class ID # Get class ID
cls = int(box.cls.cpu().numpy()[0]) cls = int(box.cls.cpu().numpy()[0])
@@ -65,12 +92,29 @@ class YOLODetector:
coords = box.xyxy[0].cpu().numpy() coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0]) conf = float(box.conf.cpu().numpy()[0])
human_detections.append({ detection = {
'bbox': coords, 'bbox': coords,
'confidence': conf 'confidence': conf,
}) 'has_mask': False,
'mask': None
}
logger.debug(f"Detected human with confidence {conf:.2f} at {coords}") # Extract mask if available (segmentation mode)
if masks is not None and i < len(masks.data):
mask_data = masks.data[i].cpu().numpy() # Get mask for this detection
detection['has_mask'] = True
detection['mask'] = mask_data
logger.debug(f"YOLO Segmentation: Detected human with mask - conf={conf:.2f}, mask_shape={mask_data.shape}")
else:
logger.debug(f"YOLO Detection: Detected human with bbox - conf={conf:.2f}, bbox={coords}")
human_detections.append(detection)
if self.supports_segmentation:
masks_found = sum(1 for d in human_detections if d['has_mask'])
logger.info(f"YOLO Segmentation: Found {len(human_detections)} humans, {masks_found} with masks")
else:
logger.debug(f"YOLO Detection: Found {len(human_detections)} humans with bounding boxes")
return human_detections return human_detections
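Every returned detection now carries the mask fields even in detection mode, so downstream code can branch on has_mask without knowing which mode the detector was built with; roughly:

import numpy as np

# Shape of one detection dict (values are illustrative).
detection = {
    "bbox": np.array([120.0, 80.0, 340.0, 560.0]),   # x1, y1, x2, y2 in frame pixels
    "confidence": 0.87,
    "has_mask": True,                                 # False in detection mode
    "mask": np.zeros((360, 640), dtype=np.float32),   # per-detection mask, None when has_mask is False
}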
@@ -153,25 +197,33 @@ class YOLODetector:
try: try:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
for line in f: content = f.read()
line = line.strip()
# Skip comments and empty lines
if line.startswith('#') or not line:
continue
# Parse detection line: x1,y1,x2,y2,confidence # Handle files with literal \n characters
parts = line.split(',') if '\\n' in content:
if len(parts) == 5: lines = content.split('\\n')
try: else:
bbox = [float(x) for x in parts[:4]] lines = content.split('\n')
conf = float(parts[4])
detections.append({ for line in lines:
'bbox': np.array(bbox), line = line.strip()
'confidence': conf # Skip comments and empty lines
}) if line.startswith('#') or not line:
except ValueError: continue
logger.warning(f"Invalid detection line: {line}")
continue # Parse detection line: x1,y1,x2,y2,confidence
parts = line.split(',')
if len(parts) == 5:
try:
bbox = [float(x) for x in parts[:4]]
conf = float(parts[4])
detections.append({
'bbox': np.array(bbox),
'confidence': conf
})
except ValueError:
logger.warning(f"Invalid detection line: {line}")
continue
logger.info(f"Loaded {len(detections)} detections from {file_path}") logger.info(f"Loaded {len(detections)} detections from {file_path}")
except Exception as e: except Exception as e:
@@ -179,6 +231,120 @@ class YOLODetector:
return detections return detections
def debug_detect_with_lower_confidence(self, frame: np.ndarray, debug_confidence: float = 0.3) -> List[Dict[str, Any]]:
"""
Run YOLO detection with a lower confidence threshold for debugging.
This helps identify if detections are being missed due to high confidence threshold.
Args:
frame: Input frame (BGR format from OpenCV)
debug_confidence: Lower confidence threshold for debugging
Returns:
List of human detection dictionaries with lower confidence threshold
"""
logger.info(f"VR180 Debug: Running YOLO with lower confidence {debug_confidence} (vs normal {self.confidence_threshold})")
# Run YOLO detection with lower confidence
results = self.model(frame, conf=debug_confidence, verbose=False)
debug_detections = []
# Process results
for result in results:
boxes = result.boxes
if boxes is not None:
for box in boxes:
# Get class ID
cls = int(box.cls.cpu().numpy()[0])
# Check if it's a person (human_class_id)
if cls == self.human_class_id:
# Get bounding box coordinates (x1, y1, x2, y2)
coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0])
debug_detections.append({
'bbox': coords,
'confidence': conf
})
logger.info(f"VR180 Debug: Lower confidence detection found {len(debug_detections)} total detections")
return debug_detections
def detect_humans_multi_frame(self, video_path: str, frame_indices: List[int],
scale: float = 1.0) -> Dict[int, List[Dict[str, Any]]]:
"""
Detect humans at multiple specific frame indices in a video.
Used for mid-segment re-detection to improve SAM2 tracking.
Args:
video_path: Path to video file
frame_indices: List of frame indices to run detection on (e.g., [0, 30, 60, 90])
scale: Scale factor for frame processing
Returns:
Dictionary mapping frame_index -> list of detection dictionaries
"""
if not frame_indices:
logger.warning("No frame indices provided for multi-frame detection")
return {}
if not os.path.exists(video_path):
logger.error(f"Video file not found: {video_path}")
return {}
logger.info(f"Mid-segment Detection: Running YOLO on {len(frame_indices)} frames: {frame_indices}")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Could not open video: {video_path}")
return {}
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
# Filter out frame indices that are beyond video length
valid_frame_indices = [idx for idx in frame_indices if 0 <= idx < total_frames]
if len(valid_frame_indices) != len(frame_indices):
invalid_frames = [idx for idx in frame_indices if idx not in valid_frame_indices]
logger.warning(f"Mid-segment Detection: Skipping invalid frame indices: {invalid_frames} (video has {total_frames} frames)")
multi_frame_detections = {}
for frame_idx in valid_frame_indices:
# Seek to specific frame
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
ret, frame = cap.read()
if not ret:
logger.warning(f"Mid-segment Detection: Could not read frame {frame_idx}")
continue
# Scale frame if needed
if scale != 1.0:
frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Run YOLO detection on this frame
detections = self.detect_humans_in_frame(frame)
multi_frame_detections[frame_idx] = detections
# Log detection results
time_seconds = frame_idx / fps
logger.info(f"Mid-segment Detection: Frame {frame_idx} (t={time_seconds:.1f}s): {len(detections)} humans detected")
for i, detection in enumerate(detections):
bbox = detection['bbox']
conf = detection['confidence']
logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Human {i+1}: bbox={bbox}, conf={conf:.3f}")
cap.release()
total_detections = sum(len(dets) for dets in multi_frame_detections.values())
logger.info(f"Mid-segment Detection: Complete - {total_detections} total detections across {len(valid_frame_indices)} frames")
return multi_frame_detections
def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int], def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]: scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
""" """
@@ -224,7 +390,8 @@ class YOLODetector:
def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]], def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
frame_width: int) -> List[Dict[str, Any]]: frame_width: int) -> List[Dict[str, Any]]:
""" """
Convert YOLO detections to SAM2-compatible prompts for stereo video. Convert YOLO detections to SAM2-compatible prompts for VR180 SBS video.
For VR180, we expect 2 real detections (left and right eye views), not mirrored ones.
Args: Args:
detections: List of YOLO detection results detections: List of YOLO detection results
@@ -234,53 +401,335 @@ class YOLODetector:
List of SAM2 prompt dictionaries with obj_id and bbox List of SAM2 prompt dictionaries with obj_id and bbox
""" """
if not detections: if not detections:
logger.warning("No detections provided for SAM2 prompt conversion")
return [] return []
half_frame_width = frame_width // 2 half_frame_width = frame_width // 2
prompts = [] prompts = []
logger.info(f"VR180 SBS Debug: Converting {len(detections)} detections for frame width {frame_width}")
logger.info(f"VR180 SBS Debug: Half frame width = {half_frame_width}")
# Sort detections by x-coordinate to get consistent left/right assignment # Sort detections by x-coordinate to get consistent left/right assignment
sorted_detections = sorted(detections, key=lambda x: x['bbox'][0]) sorted_detections = sorted(detections, key=lambda x: x['bbox'][0])
# Analyze detections by frame half
left_detections = []
right_detections = []
for i, detection in enumerate(sorted_detections):
bbox = detection['bbox'].copy()
center_x = (bbox[0] + bbox[2]) / 2
pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"
if center_x < half_frame_width:
left_detections.append((detection, i, pixel_range))
side = "LEFT"
else:
right_detections.append((detection, i, pixel_range))
side = "RIGHT"
logger.info(f"VR180 SBS Debug: Detection {i}: pixels {pixel_range}, center_x={center_x:.1f}, side={side}")
# VR180 SBS Format Validation
logger.info(f"VR180 SBS Debug: Found {len(left_detections)} LEFT detections, {len(right_detections)} RIGHT detections")
# Analyze confidence scores
if left_detections:
left_confidences = [det[0]['confidence'] for det in left_detections]
logger.info(f"VR180 SBS Debug: LEFT eye confidences: {[f'{c:.3f}' for c in left_confidences]}")
if right_detections:
right_confidences = [det[0]['confidence'] for det in right_detections]
logger.info(f"VR180 SBS Debug: RIGHT eye confidences: {[f'{c:.3f}' for c in right_confidences]}")
if len(right_detections) == 0:
logger.warning(f"VR180 SBS Warning: No detections found in RIGHT eye view (pixels {half_frame_width}-{frame_width})")
logger.warning(f"VR180 SBS Warning: This may indicate:")
logger.warning(f" 1. Person not visible in right eye view")
logger.warning(f" 2. YOLO confidence threshold ({self.confidence_threshold}) too high")
logger.warning(f" 3. VR180 SBS format issue")
logger.warning(f" 4. Right eye view quality/lighting problems")
logger.warning(f"VR180 SBS Suggestion: Try lowering yolo_confidence to 0.3-0.4 in config")
if len(left_detections) == 0:
logger.warning(f"VR180 SBS Warning: No detections found in LEFT eye view (pixels 0-{half_frame_width})")
# Additional validation for VR180 SBS expectations
total_detections = len(left_detections) + len(right_detections)
if total_detections == 1:
logger.warning(f"VR180 SBS Warning: Only 1 detection found - expected 2 for proper VR180 SBS")
elif total_detections > 2:
logger.warning(f"VR180 SBS Warning: {total_detections} detections found - will use only first 2")
# Assign object IDs sequentially, regardless of which half they're in
# This ensures we always get Object 1 and Object 2 for up to 2 detections
obj_id = 1 obj_id = 1
for i, detection in enumerate(sorted_detections[:2]): # Take up to 2 humans # Process up to 2 detections total (left + right combined)
all_detections = sorted_detections[:2]
for i, detection in enumerate(all_detections):
bbox = detection['bbox'].copy() bbox = detection['bbox'].copy()
center_x = (bbox[0] + bbox[2]) / 2
pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"
# For stereo videos, assign obj_id based on position # Determine which eye view this detection is in
if len(sorted_detections) >= 2: if center_x < half_frame_width:
center_x = (bbox[0] + bbox[2]) / 2 eye_view = "LEFT"
if center_x < half_frame_width:
current_obj_id = 1 # Left human
else:
current_obj_id = 2 # Right human
else: else:
# If only one human, create prompts for both sides eye_view = "RIGHT"
current_obj_id = obj_id
obj_id += 1
# Create mirrored version for stereo
if obj_id <= 2:
mirrored_bbox = bbox.copy()
mirrored_bbox[0] += half_frame_width # Shift x1
mirrored_bbox[2] += half_frame_width # Shift x2
# Ensure mirrored bbox is within frame bounds
mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))
prompts.append({
'obj_id': obj_id,
'bbox': mirrored_bbox,
'confidence': detection['confidence']
})
obj_id += 1
prompts.append({ prompts.append({
'obj_id': current_obj_id, 'obj_id': obj_id,
'bbox': bbox, 'bbox': bbox,
'confidence': detection['confidence'] 'confidence': detection['confidence']
}) })
logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts") logger.info(f"VR180 SBS Debug: Added {eye_view} eye detection as SAM2 Object {obj_id}")
logger.info(f"VR180 SBS Debug: Object {obj_id} bbox: {bbox} (pixels {pixel_range})")
obj_id += 1
logger.info(f"VR180 SBS Debug: Final result - {len(detections)} YOLO detections → {len(prompts)} SAM2 prompts")
# Verify we have the expected objects
obj_ids = [p['obj_id'] for p in prompts]
logger.info(f"VR180 SBS Debug: SAM2 Object IDs created: {obj_ids}")
return prompts return prompts
def convert_yolo_masks_to_video_segments(self, detections: List[Dict[str, Any]],
frame_width: int, target_frame_shape: Tuple[int, int] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
"""
Convert YOLO segmentation masks to SAM2-compatible video segments format.
This allows using YOLO masks directly without SAM2 processing.
Args:
detections: List of YOLO detection results with masks
frame_width: Width of the video frame for VR180 object ID assignment
target_frame_shape: Target shape (height, width) for mask resizing
Returns:
Video segments dictionary compatible with SAM2 output format, or None if no masks
"""
if not detections:
logger.warning("No detections provided for mask conversion")
return None
# Check if any detections have masks
detections_with_masks = [d for d in detections if d.get('has_mask', False)]
if not detections_with_masks:
logger.warning("No detections have masks - YOLO segmentation may not be working")
return None
logger.info(f"YOLO Mask Conversion: Converting {len(detections_with_masks)} YOLO masks to video segments format")
half_frame_width = frame_width // 2
video_segments = {}
# Create frame 0 with converted masks
frame_masks = {}
obj_id = 1
# Sort detections by x-coordinate for consistent VR180 SBS assignment
sorted_detections = sorted(detections_with_masks, key=lambda x: x['bbox'][0])
for i, detection in enumerate(sorted_detections[:2]): # Take up to 2 humans
mask = detection['mask']
bbox = detection['bbox']
center_x = (bbox[0] + bbox[2]) / 2
# Assign sequential object IDs (similar to prompt conversion logic)
current_obj_id = obj_id
# Determine which eye view for logging
if center_x < half_frame_width:
eye_view = "LEFT"
else:
eye_view = "RIGHT"
# Resize mask to target frame shape if specified
if target_frame_shape and mask.shape != target_frame_shape:
mask_resized = cv2.resize(mask.astype(np.float32), (target_frame_shape[1], target_frame_shape[0]), interpolation=cv2.INTER_NEAREST)
mask = (mask_resized > 0.5).astype(bool)
else:
mask = mask.astype(bool)
frame_masks[current_obj_id] = mask
logger.info(f"YOLO Mask Conversion: {eye_view} eye detection -> Object {current_obj_id}, mask_shape={mask.shape}, pixels={np.sum(mask)}")
obj_id += 1 # Always increment for next detection
# Store masks in video segments format (single frame)
video_segments[0] = frame_masks
total_objects = len(frame_masks)
total_pixels = sum(np.sum(mask) for mask in frame_masks.values())
logger.info(f"YOLO Mask Conversion: Created video segments with {total_objects} objects, {total_pixels} total mask pixels")
return video_segments
def save_debug_frame_with_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]],
output_path: str, prompts: List[Dict[str, Any]] = None) -> bool:
"""
Save a debug frame with YOLO detections and SAM2 prompts overlaid as bounding boxes.
Args:
frame: Input frame (BGR format from OpenCV)
detections: List of detection dictionaries with bbox and confidence
output_path: Path to save the debug image
prompts: Optional list of SAM2 prompt dictionaries with obj_id and bbox
Returns:
True if saved successfully
"""
try:
debug_frame = frame.copy()
# Draw masks (if available) or bounding boxes for each detection
for i, detection in enumerate(detections):
bbox = detection['bbox']
confidence = detection['confidence']
has_mask = detection.get('has_mask', False)
# Extract coordinates
x1, y1, x2, y2 = map(int, bbox)
# Choose color based on confidence (green for high, yellow for medium, red for low)
if confidence >= 0.8:
color = (0, 255, 0) # Green
elif confidence >= 0.6:
color = (0, 255, 255) # Yellow
else:
color = (0, 0, 255) # Red
if has_mask and 'mask' in detection:
# Draw segmentation mask
mask = detection['mask']
# Resize mask to match frame if needed
if mask.shape != debug_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (debug_frame.shape[1], debug_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
mask = mask.astype(bool)
# Apply colored overlay with transparency
overlay = debug_frame.copy()
overlay[mask] = color
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw mask outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
# Prepare label text for segmentation
label = f"Person {i+1}: {confidence:.2f} (MASK)"
else:
# Draw bounding box (detection mode or no mask available)
cv2.rectangle(debug_frame, (x1, y1), (x2, y2), color, 2)
# Prepare label text for detection
label = f"Person {i+1}: {confidence:.2f} (BBOX)"
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
# Draw label background
cv2.rectangle(debug_frame,
(x1, y1 - label_size[1] - 10),
(x1 + label_size[0], y1),
color, -1)
# Draw label text
cv2.putText(debug_frame, label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
(255, 255, 255), 2)
# Draw SAM2 prompts if provided (with different colors/style)
if prompts:
for prompt in prompts:
obj_id = prompt['obj_id']
bbox = prompt['bbox']
# Extract coordinates
x1, y1, x2, y2 = map(int, bbox)
# Use different colors for each object ID
if obj_id == 1:
prompt_color = (0, 255, 0) # Green for Object 1
elif obj_id == 2:
prompt_color = (255, 0, 0) # Blue for Object 2
else:
prompt_color = (255, 255, 0) # Cyan for others
# Draw thicker, dashed-style border for SAM2 prompts
thickness = 3
cv2.rectangle(debug_frame, (x1-2, y1-2), (x2+2, y2+2), prompt_color, thickness)
# Add SAM2 object ID label
sam_label = f"SAM2 Obj {obj_id}"
label_size = cv2.getTextSize(sam_label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
# Draw label background
cv2.rectangle(debug_frame,
(x1-2, y2+5),
(x1-2 + label_size[0], y2+5 + label_size[1] + 5),
prompt_color, -1)
# Draw label text
cv2.putText(debug_frame, sam_label,
(x1-2, y2+5 + label_size[1]),
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(255, 255, 255), 2)
# Draw VR180 SBS boundary line (center line separating left and right eye views)
frame_height, frame_width = debug_frame.shape[:2]
center_x = frame_width // 2
cv2.line(debug_frame, (center_x, 0), (center_x, frame_height), (0, 255, 255), 3) # Yellow line
# Add VR180 SBS labels
cv2.putText(debug_frame, "LEFT EYE", (10, frame_height - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
cv2.putText(debug_frame, "RIGHT EYE", (center_x + 10, frame_height - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
# Add summary text at top with mode information
mode_text = f"YOLO Mode: {self.mode.upper()}"
masks_available = sum(1 for d in detections if d.get('has_mask', False))
if self.supports_segmentation and masks_available > 0:
summary = f"VR180 SBS: {len(detections)} detections → {masks_available} MASKS (for SAM2 propagation)"
else:
summary = f"VR180 SBS: {len(detections)} detections → {len(prompts) if prompts else 0} SAM2 prompts"
cv2.putText(debug_frame, mode_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8,
(0, 255, 255), 2) # Yellow for mode
cv2.putText(debug_frame, summary,
(10, 60),
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
(255, 255, 255), 2)
# Add frame dimensions info
dims_info = f"Frame: {frame_width}x{frame_height}, Center: {center_x}"
cv2.putText(debug_frame, dims_info,
(10, 90),
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
(255, 255, 255), 2)
# Save debug frame
success = cv2.imwrite(output_path, debug_frame)
if success:
logger.info(f"Saved YOLO debug frame to {output_path}")
else:
logger.error(f"Failed to save debug frame to {output_path}")
return success
except Exception as e:
logger.error(f"Error creating debug frame: {e}")
return False
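convert_yolo_masks_to_video_segments deliberately emits the same nested structure that SAM2 propagation produces, which is what lets the pipeline swap YOLO masks in without touching the downstream mask and render code. A sketch of that structure, assuming two tracked objects and boolean masks at inference scale:

import numpy as np

h, w = 360, 640  # inference-scale frame size (illustrative)

# video_segments: frame_index -> {object_id -> boolean mask of shape (h, w)}
video_segments = {
    0: {
        1: np.zeros((h, w), dtype=bool),  # Object 1 (left eye view)
        2: np.zeros((h, w), dtype=bool),  # Object 2 (right eye view)
    },
    # One entry per propagated frame when SAM2 generates the masks;
    # only frame 0 when the masks come straight from YOLO segmentation.
}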

View File

@@ -137,13 +137,21 @@ def download_sam2_models():
def download_yolo_models(): def download_yolo_models():
"""Download default YOLO models to models directory.""" """Download default YOLO models to models directory."""
print("\n--- Setting up YOLO models ---") print("\n--- Setting up YOLO models ---")
print(" Downloading both detection and segmentation models...")
try: try:
from ultralytics import YOLO from ultralytics import YOLO
import torch import torch
# Default YOLO models to download # Default YOLO models to download (both detection and segmentation)
yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"] yolo_models = [
"yolov8n.pt", # Detection models
"yolov8s.pt",
"yolov8m.pt",
"yolov8n-seg.pt", # Segmentation models
"yolov8s-seg.pt",
"yolov8m-seg.pt"
]
models_dir = Path(__file__).parent / "models" / "yolo" models_dir = Path(__file__).parent / "models" / "yolo"
for model_name in yolo_models: for model_name in yolo_models:
@@ -205,8 +213,13 @@ def download_yolo_models():
success = all((models_dir / model).exists() for model in yolo_models) success = all((models_dir / model).exists() for model in yolo_models)
if success: if success:
print("✓ YOLO models setup complete!") print("✓ YOLO models setup complete!")
print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt")
print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt")
else: else:
print("⚠ Some YOLO models may be missing") missing_models = [model for model in yolo_models if not (models_dir / model).exists()]
print("⚠ Some YOLO models may be missing:")
for model in missing_models:
print(f" - {model}")
return success return success
except ImportError: except ImportError:
@@ -234,6 +247,12 @@ def update_config_file():
updated_content = content.replace( updated_content = content.replace(
'yolo_model: "yolov8n.pt"', 'yolo_model: "yolov8n.pt"',
'yolo_model: "models/yolo/yolov8n.pt"' 'yolo_model: "models/yolo/yolov8n.pt"'
).replace(
'yolo_detection_model: "models/yolo/yolov8n.pt"',
'yolo_detection_model: "models/yolo/yolov8n.pt"'
).replace(
'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"',
'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"'
).replace( ).replace(
'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"', 'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"' 'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'

main.py
View File

@@ -8,6 +8,8 @@ and creating green screen masks with SAM2.
import os import os
import sys import sys
import argparse import argparse
import cv2
import numpy as np
from typing import List from typing import List
# Add project root to path # Add project root to path
@@ -16,6 +18,9 @@ sys.path.append(os.path.dirname(__file__))
from core.config_loader import ConfigLoader from core.config_loader import ConfigLoader
from core.video_splitter import VideoSplitter from core.video_splitter import VideoSplitter
from core.yolo_detector import YOLODetector from core.yolo_detector import YOLODetector
from core.sam2_processor import SAM2Processor
from core.mask_processor import MaskProcessor
from core.video_assembler import VideoAssembler
from utils.logging_utils import setup_logging, get_logger from utils.logging_utils import setup_logging, get_logger
from utils.file_utils import ensure_directory from utils.file_utils import ensure_directory
from utils.status_utils import print_processing_status, cleanup_incomplete_segment from utils.status_utils import print_processing_status, cleanup_incomplete_segment
@@ -66,6 +71,100 @@ def validate_dependencies():
logger.error("Please install requirements: pip install -r requirements.txt") logger.error("Please install requirements: pip install -r requirements.txt")
return False return False
def create_yolo_mask_debug_frame(detections: List[dict], video_path: str, output_path: str, scale: float = 1.0) -> bool:
"""
Create debug visualization for YOLO direct masks.
Args:
detections: List of YOLO detections with masks
video_path: Path to video file
output_path: Path to save debug image
scale: Scale factor for frame processing
Returns:
True if debug frame was created successfully
"""
try:
# Load first frame
cap = cv2.VideoCapture(video_path)
ret, original_frame = cap.read()
cap.release()
if not ret:
logger.error("Could not read first frame for YOLO mask debug")
return False
# Scale frame if needed
if scale != 1.0:
original_frame = cv2.resize(original_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
debug_frame = original_frame.copy()
# Define colors for each object
colors = {
1: (0, 255, 0), # Green for Object 1 (Left eye)
2: (255, 0, 0), # Blue for Object 2 (Right eye)
}
# Get detections with masks
detections_with_masks = [d for d in detections if d.get('has_mask', False)]
# Overlay masks with transparency
obj_id = 1
for detection in detections_with_masks[:2]: # Up to 2 objects
mask = detection['mask']
# Resize mask to match frame if needed
if mask.shape != original_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
mask = mask.astype(bool)
# Apply colored overlay
color = colors.get(obj_id, (128, 128, 128))
overlay = debug_frame.copy()
overlay[mask] = color
# Blend with original (30% overlay, 70% original)
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
logger.info(f"YOLO Mask Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
obj_id += 1
# Add title and source info
title = f"YOLO Direct Masks: {len(detections_with_masks)} objects detected"
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
source_info = "Mask Source: YOLO Segmentation (DIRECT - No SAM2)"
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) # Green for YOLO
# Add object legend
y_offset = 90
for i, detection in enumerate(detections_with_masks[:2]):
obj_id = i + 1
color = colors.get(obj_id, (128, 128, 128))
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye'} (YOLO Mask)"
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
y_offset += 30
# Save debug image
success = cv2.imwrite(output_path, debug_frame)
if success:
logger.info(f"YOLO Mask Debug: Saved debug frame to {output_path}")
else:
logger.error(f"Failed to save YOLO mask debug frame to {output_path}")
return success
except Exception as e:
logger.error(f"Error creating YOLO mask debug frame: {e}")
return False
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]: def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
""" """
Resolve detect_segments configuration to list of segment indices. Resolve detect_segments configuration to list of segment indices.
@@ -157,31 +256,394 @@ def main():
detect_segments_config = config.get_detect_segments() detect_segments_config = config.get_detect_segments()
detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info)) detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))
# Step 2: Run YOLO detection on specified segments # Initialize processors once
logger.info("Step 2: Running YOLO human detection") logger.info("Step 2: Initializing YOLO detector")
# Get YOLO mode and model paths
yolo_mode = config.get('models.yolo_mode', 'detection')
detection_model = config.get('models.yolo_detection_model', config.get_yolo_model_path())
segmentation_model = config.get('models.yolo_segmentation_model', None)
logger.info(f"YOLO Mode: {yolo_mode}")
detector = YOLODetector( detector = YOLODetector(
model_path=config.get_yolo_model_path(), detection_model_path=detection_model,
segmentation_model_path=segmentation_model,
mode=yolo_mode,
confidence_threshold=config.get_yolo_confidence(), confidence_threshold=config.get_yolo_confidence(),
human_class_id=config.get_human_class_id() human_class_id=config.get_human_class_id()
) )
detection_results = detector.process_segments_batch( logger.info("Step 3: Initializing SAM2 processor")
segments_info, sam2_processor = SAM2Processor(
detect_segments, checkpoint_path=config.get_sam2_checkpoint(),
scale=config.get_inference_scale() config_path=config.get_sam2_config()
) )
# Log detection summary # Initialize mask processor
total_humans = sum(len(detections) for detections in detection_results.values()) mask_processor = MaskProcessor(
logger.info(f"Detected {total_humans} humans across {len(detection_results)} segments") green_color=config.get_green_color(),
blue_color=config.get_blue_color()
)
# Step 3: Process segments with SAM2 (placeholder for now) # Process each segment sequentially (YOLO -> SAM2 -> Render)
logger.info("Step 3: SAM2 processing and green screen generation") logger.info("Step 4: Processing segments sequentially")
logger.info("SAM2 processing module not yet implemented - this is where segment processing would occur") total_humans_detected = 0
# Step 4: Assemble final video (placeholder for now) for i, segment_info in enumerate(segments_info):
logger.info("Step 4: Assembling final video with audio") segment_idx = segment_info['index']
logger.info("Video assembly module not yet implemented - this is where concatenation and audio copying would occur")
logger.info(f"Processing segment {segment_idx}/{len(segments_info)-1}")
# Skip if segment output already exists
output_video = os.path.join(segment_info['directory'], f"output_{segment_idx}.mp4")
if os.path.exists(output_video):
logger.info(f"Segment {segment_idx} already processed, skipping")
continue
# Determine if we should use YOLO detections or previous masks
use_detections = segment_idx in detect_segments
# First segment must use detections
if segment_idx == 0 and not use_detections:
logger.warning(f"First segment must use YOLO detection")
use_detections = True
# Get YOLO prompts or previous masks
yolo_prompts = None
previous_masks = None
if use_detections:
# Run YOLO detection on current segment
logger.info(f"Running YOLO detection on segment {segment_idx}")
detection_file = os.path.join(segment_info['directory'], "yolo_detections")
# Check if detection already exists
if os.path.exists(detection_file):
logger.info(f"Loading existing YOLO detections for segment {segment_idx}")
detections = detector.load_detections_from_file(detection_file)
else:
# Run YOLO detection on first frame
detections = detector.detect_humans_in_video_first_frame(
segment_info['video_file'],
scale=config.get_inference_scale()
)
# Save detections for future runs
detector.save_detections_to_file(detections, detection_file)
if detections:
total_humans_detected += len(detections)
logger.info(f"Found {len(detections)} humans in segment {segment_idx}")
# Get frame width from video
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
cap.release()
yolo_prompts = detector.convert_detections_to_sam2_prompts(
detections, frame_width
)
# If no right eye detections found, run debug analysis with lower confidence
half_frame_width = frame_width // 2
right_eye_detections = [d for d in detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
if len(right_eye_detections) == 0 and config.get('advanced.save_yolo_debug_frames', False):
logger.info(f"VR180 Debug: No right eye detections found, running lower confidence analysis...")
# Load first frame for debug analysis
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Run debug detection with lower confidence
debug_detections = detector.debug_detect_with_lower_confidence(debug_frame, debug_confidence=0.3)
# Analyze where these lower confidence detections are
debug_right_eye = [d for d in debug_detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
if len(debug_right_eye) > 0:
logger.warning(f"VR180 Debug: Found {len(debug_right_eye)} right eye detections with lower confidence!")
for i, det in enumerate(debug_right_eye):
logger.warning(f"VR180 Debug: Right eye detection {i+1}: conf={det['confidence']:.3f}, bbox={det['bbox']}")
logger.warning(f"VR180 Debug: Consider lowering yolo_confidence from {config.get_yolo_confidence()} to 0.3-0.4")
else:
logger.info(f"VR180 Debug: No right eye detections found even with confidence 0.3")
logger.info(f"VR180 Debug: This confirms person is not visible in right eye view")
logger.info(f"Pipeline Debug: Segment {segment_idx} - Generated {len(yolo_prompts)} SAM2 prompts from {len(detections)} YOLO detections")
# Save debug frame with detections visualized (if enabled)
if config.get('advanced.save_yolo_debug_frames', False):
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug.jpg")
# Load first frame for debug visualization
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
detector.save_debug_frame_with_detections(debug_frame, detections, debug_frame_path, yolo_prompts)
else:
logger.warning(f"Could not load frame for debug visualization in segment {segment_idx}")
# Check if we have YOLO masks for debug visualization
has_yolo_masks = False
if detections and detector.supports_segmentation:
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
# Generate first frame masks debug (SAM2 or YOLO)
first_frame_debug_path = os.path.join(segment_info['directory'], "first_frame_detection.jpg")
if has_yolo_masks:
logger.info(f"Pipeline Debug: Generating YOLO first frame masks for segment {segment_idx}")
# Create YOLO mask debug visualization
create_yolo_mask_debug_frame(detections, segment_info['video_file'], first_frame_debug_path, config.get_inference_scale())
else:
logger.info(f"Pipeline Debug: Generating SAM2 first frame masks for segment {segment_idx}")
sam2_processor.generate_first_frame_debug_masks(
segment_info['video_file'],
yolo_prompts,
first_frame_debug_path,
config.get_inference_scale()
)
else:
logger.warning(f"No humans detected in segment {segment_idx}")
# Save debug frame even when no detections (if enabled)
if config.get('advanced.save_yolo_debug_frames', False):
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug_no_detections.jpg")
# Load first frame for debug visualization
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Add "No detections" text overlay
cv2.putText(debug_frame, "YOLO: No humans detected",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
(0, 0, 255), 2) # Red text
cv2.imwrite(debug_frame_path, debug_frame)
logger.info(f"Saved no-detection debug frame to {debug_frame_path}")
else:
logger.warning(f"Could not load frame for no-detection debug visualization in segment {segment_idx}")
elif segment_idx > 0:
# Try to load previous segment mask
for j in range(segment_idx - 1, -1, -1):
prev_segment_dir = segments_info[j]['directory']
previous_masks = sam2_processor.load_previous_segment_mask(prev_segment_dir)
if previous_masks:
logger.info(f"Using masks from segment {j} for segment {segment_idx}")
break
if not yolo_prompts and not previous_masks:
logger.error(f"No prompts or previous masks available for segment {segment_idx}")
continue
# Check if we have YOLO masks and can skip SAM2 (recheck in case detections were loaded from file)
if not 'has_yolo_masks' in locals():
has_yolo_masks = False
if detections and detector.supports_segmentation:
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
if has_yolo_masks:
logger.info(f"Pipeline Debug: YOLO segmentation provided masks - using as SAM2 initial masks for segment {segment_idx}")
# Convert YOLO masks to initial masks for SAM2
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
# Convert YOLO masks to the format expected by SAM2 add_previous_masks_to_predictor
yolo_masks_dict = {}
for i, detection in enumerate(detections[:2]): # Up to 2 objects
if detection.get('has_mask', False):
mask = detection['mask']
# Resize mask to match inference scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
scaled_height = int(frame_height * scale)
scaled_width = int(frame_width * scale)
mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
obj_id = i + 1 # Sequential object IDs
yolo_masks_dict[obj_id] = mask.astype(bool)
logger.info(f"Pipeline Debug: YOLO mask for Object {obj_id} - shape: {mask.shape}, pixels: {np.sum(mask)}")
logger.info(f"Pipeline Debug: Using YOLO masks as SAM2 initial masks - {len(yolo_masks_dict)} objects")
# Use traditional SAM2 pipeline with YOLO masks as initial masks
previous_masks = yolo_masks_dict
yolo_prompts = None # Don't use bounding box prompts when we have masks
# Debug what we're passing to SAM2
if yolo_prompts:
logger.info(f"Pipeline Debug: Passing {len(yolo_prompts)} YOLO prompts to SAM2 for segment {segment_idx}")
for i, prompt in enumerate(yolo_prompts):
logger.info(f"Pipeline Debug: Prompt {i+1}: Object {prompt['obj_id']}, bbox={prompt['bbox']}")
if previous_masks:
logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}")
logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}")
# Handle mid-segment detection if enabled (only when using YOLO prompts, not masks)
multi_frame_prompts = None
if config.get('advanced.enable_mid_segment_detection', False) and yolo_prompts:
logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}")
# Calculate frame indices for re-detection
cap = cv2.VideoCapture(segment_info['video_file'])
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
cap.release()
redetection_interval = config.get('advanced.redetection_interval', 30)
max_redetections = config.get('advanced.max_redetections_per_segment', 10)
# Generate frame indices: [30, 60, 90, ...] (skip frame 0 since we already have first frame prompts)
frame_indices = []
frame_idx = redetection_interval
while frame_idx < total_frames and len(frame_indices) < max_redetections:
frame_indices.append(frame_idx)
frame_idx += redetection_interval
if frame_indices:
logger.info(f"Mid-segment Detection: Running YOLO on frames {frame_indices} (interval={redetection_interval})")
# Run multi-frame detection
multi_frame_detections = detector.detect_humans_multi_frame(
segment_info['video_file'],
frame_indices,
scale=config.get_inference_scale()
)
# Convert detections to SAM2 prompts
multi_frame_prompts = {}
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
cap.release()
for frame_idx, detections in multi_frame_detections.items():
if detections:
prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
multi_frame_prompts[frame_idx] = prompts
logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")
logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames")
else:
logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)")
elif config.get('advanced.enable_mid_segment_detection', False):
logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO prompts)")
# Process segment with SAM2
logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")
video_segments = sam2_processor.process_single_segment(
segment_info,
yolo_prompts=yolo_prompts,
previous_masks=previous_masks,
inference_scale=config.get_inference_scale(),
multi_frame_prompts=multi_frame_prompts
)
if video_segments is None:
logger.error(f"SAM2 processing failed for segment {segment_idx}")
continue
# Debug what SAM2 produced
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
if video_segments:
# Check first frame to see what objects were tracked
first_frame_idx = min(video_segments.keys())
first_frame_objects = video_segments[first_frame_idx]
logger.info(f"Pipeline Debug: First frame contains {len(first_frame_objects)} tracked objects")
logger.info(f"Pipeline Debug: Tracked object IDs: {list(first_frame_objects.keys())}")
for obj_id, mask in first_frame_objects.items():
mask_pixels = np.sum(mask)
logger.info(f"Pipeline Debug: Object {obj_id} mask has {mask_pixels} pixels")
# Check last frame as well
last_frame_idx = max(video_segments.keys())
last_frame_objects = video_segments[last_frame_idx]
logger.info(f"Pipeline Debug: Last frame contains {len(last_frame_objects)} tracked objects")
logger.info(f"Pipeline Debug: Final object IDs: {list(last_frame_objects.keys())}")
# Save final masks for next segment
mask_path = os.path.join(segment_info['directory'], "mask.png")
sam2_processor.save_final_masks(
video_segments,
mask_path,
green_color=config.get_green_color(),
blue_color=config.get_blue_color()
)
# Apply green screen and save output video
success = mask_processor.process_segment(
segment_info,
video_segments,
use_nvenc=config.get_use_nvenc(),
bitrate=config.get_output_bitrate()
)
if success:
logger.info(f"Successfully processed segment {segment_idx}")
else:
logger.error(f"Failed to create green screen video for segment {segment_idx}")
# Log processing summary
logger.info(f"Sequential processing complete. Total humans detected: {total_humans_detected}")
# Step 3: Assemble final video
logger.info("Step 3: Assembling final video with audio")
# Initialize video assembler
assembler = VideoAssembler(
preserve_audio=config.get_preserve_audio(),
use_nvenc=config.get_use_nvenc()
)
# Verify all segments are complete
all_complete, missing = assembler.verify_segment_completeness(segments_dir)
if not all_complete:
logger.error(f"Cannot assemble video - missing segments: {missing}")
return 1
# Assemble final video
final_output = os.path.join(output_dir, config.get_output_filename())
success = assembler.assemble_final_video(
segments_dir,
input_video,
final_output,
bitrate=config.get_output_bitrate()
)
if success:
logger.info(f"Final video saved to: {final_output}")
logger.info("Pipeline completed successfully") logger.info("Pipeline completed successfully")
return 0 return 0
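Stripped of the caching, debug frames, and logging, the per-segment loop added above reduces to the following order of operations; a condensed sketch using the method names introduced in this commit (yolo_masks_to_initial_masks stands in for the inline mask-conversion code and is hypothetical):

import os

def process_segments(segments_info, detector, sam2_processor, mask_processor, config, frame_width):
    """Condensed per-segment flow: YOLO detect -> SAM2 propagate -> green screen render."""
    for segment_info in segments_info:
        detections = detector.detect_humans_in_video_first_frame(
            segment_info["video_file"], scale=config.get_inference_scale()
        )
        if detector.supports_segmentation and any(d.get("has_mask") for d in detections):
            # YOLO masks seed SAM2 directly; no box prompts needed.
            yolo_prompts = None
            previous_masks = yolo_masks_to_initial_masks(detections)  # hypothetical helper
        else:
            yolo_prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
            previous_masks = None

        video_segments = sam2_processor.process_single_segment(
            segment_info,
            yolo_prompts=yolo_prompts,
            previous_masks=previous_masks,
            inference_scale=config.get_inference_scale(),
        )
        if video_segments is None:
            continue

        sam2_processor.save_final_masks(
            video_segments,
            os.path.join(segment_info["directory"], "mask.png"),
            green_color=config.get_green_color(),
            blue_color=config.get_blue_color(),
        )
        mask_processor.process_segment(
            segment_info,
            video_segments,
            use_nvenc=config.get_use_nvenc(),
            bitrate=config.get_output_bitrate(),
        )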

View File

@@ -6,6 +6,7 @@ opencv-python>=4.8.0
numpy>=1.24.0 numpy>=1.24.0
# SAM2 - Segment Anything Model 2 # SAM2 - Segment Anything Model 2
# Note: Make sure to run download_models.py after installing to get model weights
git+https://github.com/facebookresearch/sam2.git git+https://github.com/facebookresearch/sam2.git
# GPU acceleration (optional but recommended) # GPU acceleration (optional but recommended)
@@ -17,6 +18,8 @@ tqdm>=4.65.0
matplotlib>=3.7.0 matplotlib>=3.7.0
Pillow>=10.0.0 Pillow>=10.0.0
decord
# Optional: For advanced features # Optional: For advanced features
psutil>=5.9.0 # Memory monitoring psutil>=5.9.0 # Memory monitoring
pympler>=0.9 # Memory profiling (for debugging) pympler>=0.9 # Memory profiling (for debugging)