Compare commits

..

5 Commits

Author SHA1 Message Date
6617acb1c9 working 2025-07-29 10:13:29 -07:00
02ad4d87d2 not working 2025-07-27 14:26:20 -07:00
97f12c79a4 working still 2025-07-27 14:14:21 -07:00
cd7bc54efe working with segemntation 2025-07-27 13:55:52 -07:00
46363a8a11 stage 1 working 2025-07-27 12:11:36 -07:00
10 changed files with 2455 additions and 208 deletions

View File

@@ -32,19 +32,40 @@ git clone <repository-url>
cd samyolo_on_segments
# Install Python dependencies
-pip install -r requirements.txt
+uv venv && source .venv/bin/activate
+uv pip install -r requirements.txt
```
-### Model Dependencies
+### Download Models
-You'll need to download the required model checkpoints:
+Use the provided script to automatically download all required models:
```bash
# Download SAM2.1 and YOLO models
python download_models.py
```
This script will:
- Create a `models/` directory structure
- Download SAM2.1 configs and checkpoints (tiny, small, base+, large)
- Download common YOLO models (yolov8n, yolov8s, yolov8m)
- Update `config.yaml` to use local model paths
**Manual Download (Alternative):**
1. **SAM2 Models**: Download from [Meta's SAM2 repository](https://github.com/facebookresearch/sam2)
-2. **YOLO Models**: YOLOv8 models will be downloaded automatically or you can specify a custom path
+2. **YOLO Models**: YOLOv8 models will be downloaded automatically on first use
## Quick Start
-### 1. Configure the Pipeline
+### 1. Download Models
First, download the required SAM2.1 and YOLO models:
```bash
python download_models.py
```
### 2. Configure the Pipeline
Edit `config.yaml` to specify your input video and desired settings:
@@ -63,18 +84,18 @@ processing:
  detect_segments: "all"
models:
-  yolo_model: "yolov8n.pt"
+  yolo_model: "models/yolo/yolov8n.pt"
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
```
-### 2. Run the Pipeline
+### 3. Run the Pipeline
```bash
python main.py --config config.yaml
```
-### 3. Monitor Progress
+### 4. Monitor Progress
Check processing status:
```bash
@@ -166,8 +187,25 @@ samyolo_on_segments/
├── README.md # This documentation
├── config.yaml # Default configuration
├── main.py # Main entry point
├── download_models.py # Model download script
├── requirements.txt # Python dependencies
├── spec.md # Detailed specification
├── models/ # Downloaded models (created by script)
│ ├── sam2/
│ │ ├── configs/sam2.1/ # SAM2.1 configuration files
│ │ │ ├── sam2.1_hiera_t.yaml
│ │ │ ├── sam2.1_hiera_s.yaml
│ │ │ ├── sam2.1_hiera_b+.yaml
│ │ │ └── sam2.1_hiera_l.yaml
│ │ └── checkpoints/ # SAM2.1 model weights
│ │ ├── sam2.1_hiera_tiny.pt
│ │ ├── sam2.1_hiera_small.pt
│ │ ├── sam2.1_hiera_base_plus.pt
│ │ └── sam2.1_hiera_large.pt
│ └── yolo/ # YOLO model weights
│ ├── yolov8n.pt
│ ├── yolov8s.pt
│ └── yolov8m.pt
├── core/ # Core processing modules
│ ├── __init__.py
│ ├── config_loader.py # Configuration management

View File

@@ -23,11 +23,11 @@ processing:
models:
  # YOLO model path - can be pretrained (yolov8n.pt) or custom path
-  yolo_model: "yolov8n.pt"
+  yolo_model: "models/yolo/yolov8n.pt"
  # SAM2 model configuration
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
video:
  # Use NVIDIA hardware encoding (requires NVENC-capable GPU)
@@ -57,3 +57,6 @@ advanced:
  # Logging level (DEBUG, INFO, WARNING, ERROR)
  log_level: "INFO"
  # Save debug frames with YOLO detections visualized
  save_yolo_debug_frames: true
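The `config_loader.py` changes in this compare also accept a mode-aware YOLO configuration in place of the legacy `yolo_model` key. A minimal sketch of what that `models:` block could look like (the key names come from `config_loader.py`; the specific paths and values shown here are illustrative, not part of this diff):

```yaml
models:
  # "detection" prompts SAM2 with bounding boxes; "segmentation" feeds YOLO masks directly
  yolo_mode: "detection"
  yolo_detection_model: "models/yolo/yolov8n.pt"        # used when yolo_mode is "detection"
  yolo_segmentation_model: "models/yolo/yolov8n-seg.pt" # used when yolo_mode is "segmentation"
  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
```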

View File

@@ -50,11 +50,31 @@ class ConfigLoader:
raise ValueError(f"Missing required field: output.{field}") raise ValueError(f"Missing required field: output.{field}")
# Validate models section # Validate models section
required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config'] required_model_fields = ['sam2_checkpoint', 'sam2_config']
for field in required_model_fields: for field in required_model_fields:
if field not in self.config['models']: if field not in self.config['models']:
raise ValueError(f"Missing required field: models.{field}") raise ValueError(f"Missing required field: models.{field}")
# Validate YOLO model configuration
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
if yolo_mode not in ['detection', 'segmentation']:
raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")
# Check for legacy yolo_model field vs new structure
has_legacy_yolo_model = 'yolo_model' in self.config['models']
has_new_yolo_models = 'yolo_detection_model' in self.config['models'] or 'yolo_segmentation_model' in self.config['models']
if not has_legacy_yolo_model and not has_new_yolo_models:
raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")
# Validate that the required model for the current mode exists
if yolo_mode == 'detection':
if has_new_yolo_models and 'yolo_detection_model' not in self.config['models']:
raise ValueError("yolo_mode is 'detection' but yolo_detection_model not specified")
elif yolo_mode == 'segmentation':
if has_new_yolo_models and 'yolo_segmentation_model' not in self.config['models']:
raise ValueError("yolo_mode is 'segmentation' but yolo_segmentation_model not specified")
# Validate processing.detect_segments format
detect_segments = self.config['processing'].get('detect_segments', 'all')
if not isinstance(detect_segments, (str, list)):
@@ -114,9 +134,18 @@ class ConfigLoader:
return self.config['processing'].get('detect_segments', 'all')
def get_yolo_model_path(self) -> str:
-"""Get YOLO model path."""
+"""Get YOLO model path (legacy method for backward compatibility)."""
# Check for legacy configuration first
if 'yolo_model' in self.config['models']:
return self.config['models']['yolo_model']
# Use new configuration based on mode
yolo_mode = self.config['models'].get('yolo_mode', 'detection')
if yolo_mode == 'detection':
return self.config['models'].get('yolo_detection_model', 'yolov8n.pt')
else: # segmentation mode
return self.config['models'].get('yolo_segmentation_model', 'yolov8n-seg.pt')
def get_sam2_checkpoint(self) -> str:
"""Get SAM2 checkpoint path."""
return self.config['models']['sam2_checkpoint']
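As a usage sketch of the new resolution order (the legacy `yolo_model` key wins if present, otherwise the mode-specific key is used), assuming a `ConfigLoader(path)` constructor and the `core.config_loader` import path implied by the project structure:

```python
# Hypothetical usage of the updated ConfigLoader; paths are placeholders
from core.config_loader import ConfigLoader

loader = ConfigLoader("config.yaml")

# Legacy config   -> {'models': {'yolo_model': 'models/yolo/yolov8n.pt', ...}}
#   get_yolo_model_path() returns 'models/yolo/yolov8n.pt'
# New-style config -> {'models': {'yolo_mode': 'segmentation',
#                                 'yolo_segmentation_model': 'models/yolo/yolov8n-seg.pt', ...}}
#   get_yolo_model_path() returns 'models/yolo/yolov8n-seg.pt'
yolo_model_path = loader.get_yolo_model_path()
sam2_checkpoint = loader.get_sam2_checkpoint()
```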

View File

@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
class SAM2Processor:
"""Handles SAM2-based video segmentation for human tracking."""
-def __init__(self, checkpoint_path: str, config_path: str):
+def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
"""
Initialize SAM2 processor.
Args:
checkpoint_path: Path to SAM2 checkpoint
config_path: Path to SAM2 config file
vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
"""
self.checkpoint_path = checkpoint_path
self.config_path = config_path
self.vos_optimized = vos_optimized
self.predictor = None
self._initialize_predictor()
@@ -46,12 +48,51 @@ class SAM2Processor:
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
try:
# Extract just the config filename for SAM2's Hydra-based loader
# SAM2 expects a config name relative to its internal config directory
config_name = os.path.basename(self.config_path)
if config_name.endswith('.yaml'):
config_name = config_name[:-5] # Remove .yaml extension
# SAM2 configs are in the format "sam2.1_hiera_X.yaml"
# and should be referenced as "configs/sam2.1/sam2.1_hiera_X"
if config_name.startswith("sam2.1_hiera"):
config_name = f"configs/sam2.1/{config_name}"
elif config_name.startswith("sam2_hiera"):
config_name = f"configs/sam2/{config_name}"
logger.info(f"Using SAM2 config: {config_name}")
# Use VOS optimization if enabled and supported
if self.vos_optimized:
try:
self.predictor = build_sam2_video_predictor(
-self.config_path,
+config_name,  # Use just the config name, not full path
self.checkpoint_path,
-device=device
+device=device,
vos_optimized=True  # New optimization for major speedup
)
logger.info("Using optimized SAM2 VOS predictor with full model compilation")
except Exception as e:
logger.warning(f"Failed to use optimized VOS predictor: {e}")
logger.info("Falling back to standard SAM2 predictor")
# Fallback to standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
else:
# Use standard predictor
self.predictor = build_sam2_video_predictor(
config_name,
self.checkpoint_path,
device=device,
overrides=dict(conf=0.95)
)
logger.info("Using standard SAM2 predictor")
# Enable optimizations for CUDA
if device.type == "cuda":
@@ -103,6 +144,7 @@ class SAM2Processor:
def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
"""
Add YOLO detection prompts to SAM2 predictor.
Includes error handling matching the working spec.md implementation.
Args:
inference_state: SAM2 inference state inference_state: SAM2 inference state
@@ -112,14 +154,21 @@ class SAM2Processor:
True if prompts were added successfully
"""
if not prompts:
-logger.warning("No prompts provided to SAM2")
+logger.warning("SAM2 Debug: No prompts provided to SAM2")
return False
-try:
-for prompt in prompts:
+logger.info(f"SAM2 Debug: Received {len(prompts)} prompts to add to predictor")
+success_count = 0
+for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
logger.info(f"SAM2 Debug: Adding prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
@@ -127,13 +176,19 @@ class SAM2Processor:
box=bbox.astype(np.float32),
)
-logger.debug(f"Added prompt for Object {obj_id}: {bbox}")
+logger.info(f"SAM2 Debug: ✓ Successfully added Object {obj_id} - returned obj_ids: {out_obj_ids}")
+success_count += 1
-logger.info(f"Successfully added {len(prompts)} prompts to SAM2")
-return True
except Exception as e:
-logger.error(f"Error adding prompts to SAM2: {e}")
+logger.error(f"SAM2 Debug: ✗ Error adding Object {obj_id}: {e}")
# Continue processing other prompts even if one fails
continue
if success_count > 0:
logger.info(f"SAM2 Debug: Final result - {success_count}/{len(prompts)} prompts successfully added")
return True
else:
logger.error("SAM2 Debug: FAILED - No prompts were successfully added to SAM2")
return False
def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
@@ -218,32 +273,46 @@ class SAM2Processor:
Dictionary mapping frame indices to object masks
"""
video_segments = {}
frame_count = 0
try:
logger.info("Starting SAM2 mask propagation...")
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
frame_count += 1
-logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
+# Log progress every 50 frames
if frame_count % 50 == 0:
logger.info(f"SAM2 propagation progress: {frame_count} frames processed")
logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")
except Exception as e:
-logger.error(f"Error during mask propagation: {e}")
+logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
if frame_count == 0:
logger.error("No frames were processed - propagation failed completely")
else:
logger.warning(f"Partial propagation completed: {frame_count} frames before failure")
return video_segments
def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
previous_masks: Optional[Dict[int, np.ndarray]] = None,
-inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
+inference_scale: float = 0.5,
+multi_frame_prompts: Optional[Dict[int, List[Dict[str, Any]]]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
"""
Process a single video segment with SAM2.
Args:
segment_info: Segment information dictionary
-yolo_prompts: Optional YOLO detection prompts
+yolo_prompts: Optional YOLO detection prompts for first frame
previous_masks: Optional masks from previous segment
inference_scale: Scale factor for inference
multi_frame_prompts: Optional prompts for multiple frames (mid-segment detection)
Returns:
Video segments dictionary or None if failed
@@ -284,6 +353,13 @@ class SAM2Processor:
logger.error(f"No prompts or previous masks available for segment {segment_idx}") logger.error(f"No prompts or previous masks available for segment {segment_idx}")
return None return None
# Add mid-segment prompts if provided
if multi_frame_prompts:
logger.info(f"Adding mid-segment prompts for segment {segment_idx}")
if not self.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
logger.warning(f"Failed to add mid-segment prompts for segment {segment_idx}")
# Don't return None here - continue with existing prompts
# Propagate masks
video_segments = self.propagate_masks(inference_state)
@@ -360,3 +436,217 @@ class SAM2Processor:
except Exception as e:
logger.error(f"Error saving final masks: {e}")
def generate_first_frame_debug_masks(self, video_path: str, prompts: List[Dict[str, Any]],
output_path: str, inference_scale: float = 0.5) -> bool:
"""
Generate SAM2 masks for just the first frame and save debug visualization.
This helps debug what SAM2 is producing for each detected object.
Args:
video_path: Path to the video file
prompts: List of SAM2 prompt dictionaries
output_path: Path to save the debug image
inference_scale: Scale factor for SAM2 inference
Returns:
True if debug masks were generated successfully
"""
if not prompts:
logger.warning("No prompts provided for first frame debug")
return False
try:
logger.info(f"SAM2 Debug: Generating first frame masks for {len(prompts)} objects")
# Load the first frame
cap = cv2.VideoCapture(video_path)
ret, original_frame = cap.read()
cap.release()
if not ret:
logger.error("Could not read first frame for debug mask generation")
return False
# Scale frame for inference if needed
if inference_scale != 1.0:
inference_frame = cv2.resize(original_frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
else:
inference_frame = original_frame.copy()
# Create temporary low-res video with just first frame
import tempfile
import os
temp_dir = tempfile.mkdtemp()
temp_video_path = os.path.join(temp_dir, "first_frame.mp4")
# Write single frame to temporary video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video_path, fourcc, 1.0, (inference_frame.shape[1], inference_frame.shape[0]))
out.write(inference_frame)
out.release()
# Initialize SAM2 inference state with single frame
inference_state = self.predictor.init_state(video_path=temp_video_path, async_loading_frames=True)
# Add prompts
if not self.add_yolo_prompts_to_predictor(inference_state, prompts):
logger.error("Failed to add prompts for first frame debug")
return False
# Generate masks for first frame only
frame_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
if out_frame_idx == 0: # Only process first frame
frame_masks = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
break
if not frame_masks:
logger.error("No masks generated for first frame debug")
return False
# Create debug visualization
debug_frame = original_frame.copy()
# Define colors for each object
colors = {
1: (0, 255, 0), # Green for Object 1 (Left eye)
2: (255, 0, 0), # Blue for Object 2 (Right eye)
3: (0, 255, 255), # Yellow for Object 3
4: (255, 0, 255), # Magenta for Object 4
}
# Overlay masks with transparency
for obj_id, mask in frame_masks.items():
mask = mask.squeeze()
# Resize mask to match original frame if needed
if mask.shape != original_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
# Apply colored overlay
color = colors.get(obj_id, (128, 128, 128))
overlay = debug_frame.copy()
overlay[mask] = color
# Blend with original (30% overlay, 70% original)
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
logger.info(f"SAM2 Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
# Add title
title = f"SAM2 First Frame Masks: {len(frame_masks)} objects detected"
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
# Add mask source information
source_info = "Mask Source: SAM2 (from YOLO bounding boxes)"
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
# Add object legend
y_offset = 90
for obj_id in sorted(frame_masks.keys()):
color = colors.get(obj_id, (128, 128, 128))
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye' if obj_id == 2 else f'Object {obj_id}'}"
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
y_offset += 30
# Save debug image
success = cv2.imwrite(output_path, debug_frame)
# Cleanup
self.predictor.reset_state(inference_state)
import shutil
shutil.rmtree(temp_dir)
if success:
logger.info(f"SAM2 Debug: Saved first frame masks to {output_path}")
return True
else:
logger.error(f"Failed to save first frame masks to {output_path}")
return False
except Exception as e:
logger.error(f"Error generating first frame debug masks: {e}")
return False
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
"""
Add YOLO prompts at multiple frame indices for mid-segment re-detection.
Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode).
Args:
inference_state: SAM2 inference state
multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation)
Returns:
True if prompts were added successfully
"""
if not multi_frame_prompts:
logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
return False
success_count = 0
total_count = 0
for frame_idx, prompts_data in multi_frame_prompts.items():
# Check if this is segmentation mode (masks) or detection mode (bbox prompts)
if isinstance(prompts_data, dict) and 'masks' in prompts_data:
# Segmentation mode: add masks directly
masks_dict = prompts_data['masks']
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")
for obj_id, mask in masks_dict.items():
total_count += 1
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
try:
self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
continue
else:
# Detection mode: add bounding box prompts (existing logic)
prompts = prompts_data
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")
for i, prompt in enumerate(prompts):
obj_id = prompt['obj_id']
bbox = prompt['bbox']
confidence = prompt.get('confidence', 'unknown')
total_count += 1
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx, # Key: specify the exact frame index
obj_id=obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
success_count += 1
except Exception as e:
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
continue
if success_count > 0:
logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
return True
else:
logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
return False
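Taken together, the new constructor flag and the multi-frame prompt path could be wired up roughly as follows. This is a minimal sketch: the import path, the `segment_info` keys, and all file paths are assumptions based on the project structure and the docstrings above, not code from this diff.

```python
import numpy as np
from core.sam2_processor import SAM2Processor  # import path assumed from the repo layout

# vos_optimized=True tries the compiled VOS predictor and falls back to the standard one
sam2 = SAM2Processor(
    checkpoint_path="models/sam2/checkpoints/sam2.1_hiera_large.pt",
    config_path="models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
    vos_optimized=True,
)

# First-frame prompts: one bbox per tracked object (obj_id 1 = left eye view, 2 = right eye view)
yolo_prompts = [
    {"obj_id": 1, "bbox": np.array([100, 200, 400, 800], dtype=np.float32), "confidence": 0.91},
    {"obj_id": 2, "bbox": np.array([2100, 210, 2400, 810], dtype=np.float32), "confidence": 0.88},
]

# Mid-segment re-detection: frame index -> list of bbox prompts (detection mode)
# or {"masks": {obj_id: boolean_mask}} (segmentation mode)
multi_frame_prompts = {
    30: [{"obj_id": 1, "bbox": np.array([110, 205, 410, 805], dtype=np.float32), "confidence": 0.87}],
}

segment_info = {  # illustrative keys only
    "index": 1,
    "directory": "segments/segment_001",
    "video_file": "segments/segment_001/segment.mp4",
}

video_segments = sam2.process_single_segment(
    segment_info,
    yolo_prompts=yolo_prompts,
    inference_scale=0.5,
    multi_frame_prompts=multi_frame_prompts,
)
```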

View File

@@ -7,7 +7,7 @@ import os
import subprocess
import logging
from typing import List, Tuple
-from ..utils.file_utils import ensure_directory, get_video_file_name
+from utils.file_utils import ensure_directory, get_video_file_name
logger = logging.getLogger(__name__)

View File

@@ -7,46 +7,71 @@ import os
import cv2
import numpy as np
import logging
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
from ultralytics import YOLO
logger = logging.getLogger(__name__)
class YOLODetector:
-"""Handles YOLO-based human detection for video segments."""
+"""Handles YOLO-based human detection for video segments with support for both detection and segmentation modes."""
-def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
+def __init__(self, detection_model_path: str = None, segmentation_model_path: str = None,
+             mode: str = "detection", confidence_threshold: float = 0.6, human_class_id: int = 0):
"""
-Initialize YOLO detector.
+Initialize YOLO detector with support for both detection and segmentation modes.
Args:
-model_path: Path to YOLO model weights
+detection_model_path: Path to YOLO detection model weights (e.g., yolov8n.pt)
+segmentation_model_path: Path to YOLO segmentation model weights (e.g., yolov8n-seg.pt)
+mode: Detection mode - "detection" for bboxes, "segmentation" for masks
confidence_threshold: Detection confidence threshold
human_class_id: COCO class ID for humans (0 = person)
"""
-self.model_path = model_path
+self.mode = mode
self.confidence_threshold = confidence_threshold
self.human_class_id = human_class_id
# Select model path based on mode
if mode == "segmentation":
if not segmentation_model_path:
raise ValueError("segmentation_model_path required for segmentation mode")
self.model_path = segmentation_model_path
self.supports_segmentation = True
elif mode == "detection":
if not detection_model_path:
raise ValueError("detection_model_path required for detection mode")
self.model_path = detection_model_path
self.supports_segmentation = False
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'detection' or 'segmentation'")
# Load YOLO model
try:
-self.model = YOLO(model_path)
+self.model = YOLO(self.model_path)
-logger.info(f"Loaded YOLO model from {model_path}")
+logger.info(f"Loaded YOLO model in {mode} mode from {self.model_path}")
# Verify model capabilities
if mode == "segmentation":
# Test if model actually supports segmentation
logger.info(f"YOLO Segmentation: Model loaded, will output direct masks")
else:
logger.info(f"YOLO Detection: Model loaded, will output bounding boxes")
except Exception as e:
logger.error(f"Failed to load YOLO model: {e}")
raise
def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
"""
Detect humans in a single frame using YOLO.
Args:
frame: Input frame (BGR format from OpenCV)
Returns:
-List of human detection dictionaries with bbox and confidence
+List of human detection dictionaries with bbox, confidence, and optionally masks
"""
-# Run YOLO detection
+# Run YOLO detection/segmentation
results = self.model(frame, conf=self.confidence_threshold, verbose=False)
human_detections = []
@@ -54,8 +79,10 @@ class YOLODetector:
# Process results
for result in results:
boxes = result.boxes
masks = result.masks if hasattr(result, 'masks') and result.masks is not None else None
if boxes is not None:
-for box in boxes:
+for i, box in enumerate(boxes):
# Get class ID
cls = int(box.cls.cpu().numpy()[0])
@@ -65,17 +92,34 @@ class YOLODetector:
coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0])
-human_detections.append({
+detection = {
'bbox': coords,
-'confidence': conf
+'confidence': conf,
-})
+'has_mask': False,
+'mask': None
+}
-logger.debug(f"Detected human with confidence {conf:.2f} at {coords}")
+# Extract mask if available (segmentation mode)
if masks is not None and i < len(masks.data):
mask_data = masks.data[i].cpu().numpy() # Get mask for this detection
detection['has_mask'] = True
detection['mask'] = mask_data
logger.debug(f"YOLO Segmentation: Detected human with mask - conf={conf:.2f}, mask_shape={mask_data.shape}")
else:
logger.debug(f"YOLO Detection: Detected human with bbox - conf={conf:.2f}, bbox={coords}")
human_detections.append(detection)
if self.supports_segmentation:
masks_found = sum(1 for d in human_detections if d['has_mask'])
logger.info(f"YOLO Segmentation: Found {len(human_detections)} humans, {masks_found} with masks")
else:
logger.debug(f"YOLO Detection: Found {len(human_detections)} humans with bounding boxes")
return human_detections
def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
"""
Detect humans in the first frame of a video.
Args:
@@ -84,21 +128,21 @@
Returns:
List of human detection dictionaries
"""
if not os.path.exists(video_path):
logger.error(f"Video file not found: {video_path}")
return []
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Could not open video: {video_path}")
return []
ret, frame = cap.read()
cap.release()
if not ret:
logger.error(f"Could not read first frame from: {video_path}")
return []
# Scale frame if needed
@@ -108,7 +152,7 @@ class YOLODetector:
return self.detect_humans_in_frame(frame)
def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
"""
Save detection results to file.
Args:
@@ -117,26 +161,26 @@
Returns:
True if saved successfully
"""
try:
with open(output_path, 'w') as f:
f.write("# YOLO Human Detections\\n")
if detections:
for detection in detections:
bbox = detection['bbox']
conf = detection['confidence']
f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n")
logger.info(f"Saved {len(detections)} detections to {output_path}")
else:
f.write("# No humans detected\\n")
logger.info(f"Saved empty detection file to {output_path}")
return True
except Exception as e:
logger.error(f"Failed to save detections to {output_path}: {e}")
return False
def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
"""
Load detection results from file.
Args:
@@ -144,16 +188,24 @@
Returns:
List of detection dictionaries
"""
detections = []
if not os.path.exists(file_path):
logger.warning(f"Detection file not found: {file_path}")
return detections
try:
with open(file_path, 'r') as f:
-for line in f:
+content = f.read()
# Handle files with literal \n characters
if '\\n' in content:
lines = content.split('\\n')
else:
lines = content.split('\n')
for line in lines:
line = line.strip()
# Skip comments and empty lines
if line.startswith('#') or not line:
@@ -170,18 +222,132 @@
'confidence': conf
})
except ValueError:
logger.warning(f"Invalid detection line: {line}")
continue
logger.info(f"Loaded {len(detections)} detections from {file_path}")
except Exception as e:
logger.error(f"Failed to load detections from {file_path}: {e}")
return detections
def debug_detect_with_lower_confidence(self, frame: np.ndarray, debug_confidence: float = 0.3) -> List[Dict[str, Any]]:
"""
Run YOLO detection with a lower confidence threshold for debugging.
This helps identify if detections are being missed due to high confidence threshold.
Args:
frame: Input frame (BGR format from OpenCV)
debug_confidence: Lower confidence threshold for debugging
Returns:
List of human detection dictionaries with lower confidence threshold
"""
logger.info(f"VR180 Debug: Running YOLO with lower confidence {debug_confidence} (vs normal {self.confidence_threshold})")
# Run YOLO detection with lower confidence
results = self.model(frame, conf=debug_confidence, verbose=False)
debug_detections = []
# Process results
for result in results:
boxes = result.boxes
if boxes is not None:
for box in boxes:
# Get class ID
cls = int(box.cls.cpu().numpy()[0])
# Check if it's a person (human_class_id)
if cls == self.human_class_id:
# Get bounding box coordinates (x1, y1, x2, y2)
coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0])
debug_detections.append({
'bbox': coords,
'confidence': conf
})
logger.info(f"VR180 Debug: Lower confidence detection found {len(debug_detections)} total detections")
return debug_detections
def detect_humans_multi_frame(self, video_path: str, frame_indices: List[int],
scale: float = 1.0) -> Dict[int, List[Dict[str, Any]]]:
"""
Detect humans at multiple specific frame indices in a video.
Used for mid-segment re-detection to improve SAM2 tracking.
Args:
video_path: Path to video file
frame_indices: List of frame indices to run detection on (e.g., [0, 30, 60, 90])
scale: Scale factor for frame processing
Returns:
Dictionary mapping frame_index -> list of detection dictionaries
"""
if not frame_indices:
logger.warning("No frame indices provided for multi-frame detection")
return {}
if not os.path.exists(video_path):
logger.error(f"Video file not found: {video_path}")
return {}
logger.info(f"Mid-segment Detection: Running YOLO on {len(frame_indices)} frames: {frame_indices}")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Could not open video: {video_path}")
return {}
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
# Filter out frame indices that are beyond video length
valid_frame_indices = [idx for idx in frame_indices if 0 <= idx < total_frames]
if len(valid_frame_indices) != len(frame_indices):
invalid_frames = [idx for idx in frame_indices if idx not in valid_frame_indices]
logger.warning(f"Mid-segment Detection: Skipping invalid frame indices: {invalid_frames} (video has {total_frames} frames)")
multi_frame_detections = {}
for frame_idx in valid_frame_indices:
# Seek to specific frame
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
ret, frame = cap.read()
if not ret:
logger.warning(f"Mid-segment Detection: Could not read frame {frame_idx}")
continue
# Scale frame if needed
if scale != 1.0:
frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Run YOLO detection on this frame
detections = self.detect_humans_in_frame(frame)
multi_frame_detections[frame_idx] = detections
# Log detection results
time_seconds = frame_idx / fps
logger.info(f"Mid-segment Detection: Frame {frame_idx} (t={time_seconds:.1f}s): {len(detections)} humans detected")
for i, detection in enumerate(detections):
bbox = detection['bbox']
conf = detection['confidence']
logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Human {i+1}: bbox={bbox}, conf={conf:.3f}")
cap.release()
total_detections = sum(len(dets) for dets in multi_frame_detections.values())
logger.info(f"Mid-segment Detection: Complete - {total_detections} total detections across {len(valid_frame_indices)} frames")
return multi_frame_detections
def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
"""
Process multiple segments for human detection.
Args:
@@ -191,7 +357,7 @@
Returns:
Dictionary mapping segment index to detection results
"""
results = {}
for segment_info in segments_info:
@@ -202,17 +368,17 @@
continue
video_path = segment_info['video_file']
detection_file = os.path.join(segment_info['directory'], "yolo_detections")
# Skip if already processed
if os.path.exists(detection_file):
logger.info(f"Segment {segment_idx} already has detections, skipping")
detections = self.load_detections_from_file(detection_file)
results[segment_idx] = detections
continue
# Run detection
logger.info(f"Processing segment {segment_idx} for human detection")
detections = self.detect_humans_in_video_first_frame(video_path, scale)
# Save results
@@ -223,8 +389,9 @@
def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
frame_width: int) -> List[Dict[str, Any]]:
"""
-Convert YOLO detections to SAM2-compatible prompts for stereo video.
+Convert YOLO detections to SAM2-compatible prompts for VR180 SBS video.
For VR180, we expect 2 real detections (left and right eye views), not mirrored ones.
Args:
detections: List of YOLO detection results
@@ -232,55 +399,337 @@
Returns:
List of SAM2 prompt dictionaries with obj_id and bbox
"""
if not detections:
logger.warning("No detections provided for SAM2 prompt conversion")
return []
half_frame_width = frame_width // 2
prompts = []
logger.info(f"VR180 SBS Debug: Converting {len(detections)} detections for frame width {frame_width}")
logger.info(f"VR180 SBS Debug: Half frame width = {half_frame_width}")
# Sort detections by x-coordinate to get consistent left/right assignment
sorted_detections = sorted(detections, key=lambda x: x['bbox'][0])
# Analyze detections by frame half
left_detections = []
right_detections = []
for i, detection in enumerate(sorted_detections):
bbox = detection['bbox'].copy()
center_x = (bbox[0] + bbox[2]) / 2
pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"
if center_x < half_frame_width:
left_detections.append((detection, i, pixel_range))
side = "LEFT"
else:
right_detections.append((detection, i, pixel_range))
side = "RIGHT"
logger.info(f"VR180 SBS Debug: Detection {i}: pixels {pixel_range}, center_x={center_x:.1f}, side={side}")
# VR180 SBS Format Validation
logger.info(f"VR180 SBS Debug: Found {len(left_detections)} LEFT detections, {len(right_detections)} RIGHT detections")
# Analyze confidence scores
if left_detections:
left_confidences = [det[0]['confidence'] for det in left_detections]
logger.info(f"VR180 SBS Debug: LEFT eye confidences: {[f'{c:.3f}' for c in left_confidences]}")
if right_detections:
right_confidences = [det[0]['confidence'] for det in right_detections]
logger.info(f"VR180 SBS Debug: RIGHT eye confidences: {[f'{c:.3f}' for c in right_confidences]}")
if len(right_detections) == 0:
logger.warning(f"VR180 SBS Warning: No detections found in RIGHT eye view (pixels {half_frame_width}-{frame_width})")
logger.warning(f"VR180 SBS Warning: This may indicate:")
logger.warning(f" 1. Person not visible in right eye view")
logger.warning(f" 2. YOLO confidence threshold ({self.confidence_threshold}) too high")
logger.warning(f" 3. VR180 SBS format issue")
logger.warning(f" 4. Right eye view quality/lighting problems")
logger.warning(f"VR180 SBS Suggestion: Try lowering yolo_confidence to 0.3-0.4 in config")
if len(left_detections) == 0:
logger.warning(f"VR180 SBS Warning: No detections found in LEFT eye view (pixels 0-{half_frame_width})")
# Additional validation for VR180 SBS expectations
total_detections = len(left_detections) + len(right_detections)
if total_detections == 1:
logger.warning(f"VR180 SBS Warning: Only 1 detection found - expected 2 for proper VR180 SBS")
elif total_detections > 2:
logger.warning(f"VR180 SBS Warning: {total_detections} detections found - will use only first 2")
# Assign object IDs sequentially, regardless of which half they're in
# This ensures we always get Object 1 and Object 2 for up to 2 detections
obj_id = 1
-for i, detection in enumerate(sorted_detections[:2]):  # Take up to 2 humans
+# Process up to 2 detections total (left + right combined)
+all_detections = sorted_detections[:2]
+for i, detection in enumerate(all_detections):
bbox = detection['bbox'].copy()
-# For stereo videos, assign obj_id based on position
-if len(sorted_detections) >= 2:
center_x = (bbox[0] + bbox[2]) / 2
pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"
# Determine which eye view this detection is in
if center_x < half_frame_width:
-current_obj_id = 1  # Left human
+eye_view = "LEFT"
else:
-current_obj_id = 2  # Right human
+eye_view = "RIGHT"
-else:
-# If only one human, create prompts for both sides
-current_obj_id = obj_id
-obj_id += 1
-# Create mirrored version for stereo
-if obj_id <= 2:
-mirrored_bbox = bbox.copy()
-mirrored_bbox[0] += half_frame_width  # Shift x1
-mirrored_bbox[2] += half_frame_width  # Shift x2
-# Ensure mirrored bbox is within frame bounds
-mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
-mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))
prompts.append({
'obj_id': obj_id,
-'bbox': mirrored_bbox,
-'confidence': detection['confidence']
-})
-obj_id += 1
-prompts.append({
-'obj_id': current_obj_id,
'bbox': bbox,
'confidence': detection['confidence']
})
-logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts")
+logger.info(f"VR180 SBS Debug: Added {eye_view} eye detection as SAM2 Object {obj_id}")
logger.info(f"VR180 SBS Debug: Object {obj_id} bbox: {bbox} (pixels {pixel_range})")
obj_id += 1
logger.info(f"VR180 SBS Debug: Final result - {len(detections)} YOLO detections → {len(prompts)} SAM2 prompts")
# Verify we have the expected objects
obj_ids = [p['obj_id'] for p in prompts]
logger.info(f"VR180 SBS Debug: SAM2 Object IDs created: {obj_ids}")
return prompts
def convert_yolo_masks_to_video_segments(self, detections: List[Dict[str, Any]],
frame_width: int, target_frame_shape: Tuple[int, int] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
"""
Convert YOLO segmentation masks to SAM2-compatible video segments format.
This allows using YOLO masks directly without SAM2 processing.
Args:
detections: List of YOLO detection results with masks
frame_width: Width of the video frame for VR180 object ID assignment
target_frame_shape: Target shape (height, width) for mask resizing
Returns:
Video segments dictionary compatible with SAM2 output format, or None if no masks
"""
if not detections:
logger.warning("No detections provided for mask conversion")
return None
# Check if any detections have masks
detections_with_masks = [d for d in detections if d.get('has_mask', False)]
if not detections_with_masks:
logger.warning("No detections have masks - YOLO segmentation may not be working")
return None
logger.info(f"YOLO Mask Conversion: Converting {len(detections_with_masks)} YOLO masks to video segments format")
half_frame_width = frame_width // 2
video_segments = {}
# Create frame 0 with converted masks
frame_masks = {}
obj_id = 1
# Sort detections by x-coordinate for consistent VR180 SBS assignment
sorted_detections = sorted(detections_with_masks, key=lambda x: x['bbox'][0])
for i, detection in enumerate(sorted_detections[:2]): # Take up to 2 humans
mask = detection['mask']
bbox = detection['bbox']
center_x = (bbox[0] + bbox[2]) / 2
# Assign sequential object IDs (similar to prompt conversion logic)
current_obj_id = obj_id
# Determine which eye view for logging
if center_x < half_frame_width:
eye_view = "LEFT"
else:
eye_view = "RIGHT"
# Resize mask to target frame shape if specified
if target_frame_shape and mask.shape != target_frame_shape:
mask_resized = cv2.resize(mask.astype(np.float32), (target_frame_shape[1], target_frame_shape[0]), interpolation=cv2.INTER_NEAREST)
mask = (mask_resized > 0.5).astype(bool)
else:
mask = mask.astype(bool)
frame_masks[current_obj_id] = mask
logger.info(f"YOLO Mask Conversion: {eye_view} eye detection -> Object {current_obj_id}, mask_shape={mask.shape}, pixels={np.sum(mask)}")
obj_id += 1 # Always increment for next detection
# Store masks in video segments format (single frame)
video_segments[0] = frame_masks
total_objects = len(frame_masks)
total_pixels = sum(np.sum(mask) for mask in frame_masks.values())
logger.info(f"YOLO Mask Conversion: Created video segments with {total_objects} objects, {total_pixels} total mask pixels")
return video_segments
def save_debug_frame_with_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]],
output_path: str, prompts: List[Dict[str, Any]] = None) -> bool:
"""
Save a debug frame with YOLO detections and SAM2 prompts overlaid as bounding boxes.
Args:
frame: Input frame (BGR format from OpenCV)
detections: List of detection dictionaries with bbox and confidence
output_path: Path to save the debug image
prompts: Optional list of SAM2 prompt dictionaries with obj_id and bbox
Returns:
True if saved successfully
"""
try:
debug_frame = frame.copy()
# Draw masks (if available) or bounding boxes for each detection
for i, detection in enumerate(detections):
bbox = detection['bbox']
confidence = detection['confidence']
has_mask = detection.get('has_mask', False)
# Extract coordinates
x1, y1, x2, y2 = map(int, bbox)
# Choose color based on confidence (green for high, yellow for medium, red for low)
if confidence >= 0.8:
color = (0, 255, 0) # Green
elif confidence >= 0.6:
color = (0, 255, 255) # Yellow
else:
color = (0, 0, 255) # Red
if has_mask and 'mask' in detection:
# Draw segmentation mask
mask = detection['mask']
# Resize mask to match frame if needed
if mask.shape != debug_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (debug_frame.shape[1], debug_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
mask = mask.astype(bool)
# Apply colored overlay with transparency
overlay = debug_frame.copy()
overlay[mask] = color
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw mask outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
# Prepare label text for segmentation
label = f"Person {i+1}: {confidence:.2f} (MASK)"
else:
# Draw bounding box (detection mode or no mask available)
cv2.rectangle(debug_frame, (x1, y1), (x2, y2), color, 2)
# Prepare label text for detection
label = f"Person {i+1}: {confidence:.2f} (BBOX)"
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
# Draw label background
cv2.rectangle(debug_frame,
(x1, y1 - label_size[1] - 10),
(x1 + label_size[0], y1),
color, -1)
# Draw label text
cv2.putText(debug_frame, label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
(255, 255, 255), 2)
# Draw SAM2 prompts if provided (with different colors/style)
if prompts:
for prompt in prompts:
obj_id = prompt['obj_id']
bbox = prompt['bbox']
# Extract coordinates
x1, y1, x2, y2 = map(int, bbox)
# Use different colors for each object ID
if obj_id == 1:
prompt_color = (0, 255, 0) # Green for Object 1
elif obj_id == 2:
prompt_color = (255, 0, 0) # Blue for Object 2
else:
prompt_color = (255, 255, 0) # Cyan for others
# Draw thicker, dashed-style border for SAM2 prompts
thickness = 3
cv2.rectangle(debug_frame, (x1-2, y1-2), (x2+2, y2+2), prompt_color, thickness)
# Add SAM2 object ID label
sam_label = f"SAM2 Obj {obj_id}"
label_size = cv2.getTextSize(sam_label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
# Draw label background
cv2.rectangle(debug_frame,
(x1-2, y2+5),
(x1-2 + label_size[0], y2+5 + label_size[1] + 5),
prompt_color, -1)
# Draw label text
cv2.putText(debug_frame, sam_label,
(x1-2, y2+5 + label_size[1]),
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(255, 255, 255), 2)
# Draw VR180 SBS boundary line (center line separating left and right eye views)
frame_height, frame_width = debug_frame.shape[:2]
center_x = frame_width // 2
cv2.line(debug_frame, (center_x, 0), (center_x, frame_height), (0, 255, 255), 3) # Yellow line
# Add VR180 SBS labels
cv2.putText(debug_frame, "LEFT EYE", (10, frame_height - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
cv2.putText(debug_frame, "RIGHT EYE", (center_x + 10, frame_height - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
# Add summary text at top with mode information
mode_text = f"YOLO Mode: {self.mode.upper()}"
masks_available = sum(1 for d in detections if d.get('has_mask', False))
if self.supports_segmentation and masks_available > 0:
summary = f"VR180 SBS: {len(detections)} detections → {masks_available} MASKS (for SAM2 propagation)"
else:
summary = f"VR180 SBS: {len(detections)} detections → {len(prompts) if prompts else 0} SAM2 prompts"
cv2.putText(debug_frame, mode_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.8,
(0, 255, 255), 2) # Yellow for mode
cv2.putText(debug_frame, summary,
(10, 60),
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
(255, 255, 255), 2)
# Add frame dimensions info
dims_info = f"Frame: {frame_width}x{frame_height}, Center: {center_x}"
cv2.putText(debug_frame, dims_info,
(10, 90),
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
(255, 255, 255), 2)
# Save debug frame
success = cv2.imwrite(output_path, debug_frame)
if success:
logger.info(f"Saved YOLO debug frame to {output_path}")
else:
logger.error(f"Failed to save debug frame to {output_path}")
return success
except Exception as e:
logger.error(f"Error creating debug frame: {e}")
return False
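For the segmentation path, the detector's own masks can stand in for SAM2 output on the first frame. A rough sketch of that flow, under the same layout assumptions (the import path and file paths are placeholders, not values from this diff):

```python
import cv2
from core.yolo_detector import YOLODetector  # import path assumed from the repo layout

detector = YOLODetector(
    segmentation_model_path="models/yolo/yolov8n-seg.pt",
    mode="segmentation",
    confidence_threshold=0.6,
)

video_path = "segments/segment_001/segment.mp4"  # placeholder path
scale = 0.5
detections = detector.detect_humans_in_video_first_frame(video_path, scale=scale)

# Frame size at inference scale, needed to split the VR180 SBS frame into eye halves
cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
cap.release()

# Either build bbox prompts for SAM2...
prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
# ...or use the YOLO masks directly in SAM2's video_segments format
video_segments = detector.convert_yolo_masks_to_video_segments(
    detections, frame_width, target_frame_shape=(frame_height, frame_width)
)
```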

download_models.py (new executable file, 317 lines)
View File

@@ -0,0 +1,317 @@
#!/usr/bin/env python3
"""
Model download script for YOLO + SAM2 video processing pipeline.
Downloads SAM2.1 models and organizes them in the models directory.
"""
import os
import urllib.request
import urllib.error
from pathlib import Path
import sys
def create_directory_structure():
"""Create the models directory structure."""
base_dir = Path(__file__).parent
models_dir = base_dir / "models"
# Create main models directory
models_dir.mkdir(exist_ok=True)
# Create subdirectories
sam2_dir = models_dir / "sam2"
sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
sam2_checkpoints_dir = sam2_dir / "checkpoints"
yolo_dir = models_dir / "yolo"
sam2_dir.mkdir(exist_ok=True)
sam2_configs_dir.mkdir(parents=True, exist_ok=True)
sam2_checkpoints_dir.mkdir(exist_ok=True)
yolo_dir.mkdir(exist_ok=True)
print(f"Created models directory structure in: {models_dir}")
return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir
def download_file(url, destination, description="file"):
"""Download a file with progress indication."""
try:
print(f"Downloading {description}...")
print(f" URL: {url}")
print(f" Destination: {destination}")
def progress_hook(block_num, block_size, total_size):
if total_size > 0:
percent = min(100, (block_num * block_size * 100) // total_size)
sys.stdout.write(f"\r Progress: {percent}%")
sys.stdout.flush()
urllib.request.urlretrieve(url, destination, progress_hook)
print(f"\n ✓ Downloaded {description}")
return True
except urllib.error.URLError as e:
print(f"\n ✗ Failed to download {description}: {e}")
return False
except Exception as e:
print(f"\n ✗ Error downloading {description}: {e}")
return False
def download_sam2_models():
"""Download SAM2.1 model configurations and checkpoints."""
print("Setting up SAM2.1 models...")
# Create directory structure
models_dir, configs_dir, checkpoints_dir, yolo_dir = create_directory_structure()
# SAM2.1 model definitions
sam2_models = {
"tiny": {
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"config_file": "sam2.1_hiera_t.yaml",
"checkpoint_file": "sam2.1_hiera_tiny.pt"
},
"small": {
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"config_file": "sam2.1_hiera_s.yaml",
"checkpoint_file": "sam2.1_hiera_small.pt"
},
"base_plus": {
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"config_file": "sam2.1_hiera_b+.yaml",
"checkpoint_file": "sam2.1_hiera_base_plus.pt"
},
"large": {
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
"config_file": "sam2.1_hiera_l.yaml",
"checkpoint_file": "sam2.1_hiera_large.pt"
}
}
success_count = 0
total_downloads = len(sam2_models) * 2 # configs + checkpoints
# Download each model's config and checkpoint
for model_name, model_info in sam2_models.items():
print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---")
# Download config file
config_path = configs_dir / model_info["config_file"]
if not config_path.exists():
if download_file(
model_info["config_url"],
config_path,
f"SAM2.1 {model_name} config"
):
success_count += 1
else:
print(f" ✓ Config file already exists: {config_path}")
success_count += 1
# Download checkpoint file
checkpoint_path = checkpoints_dir / model_info["checkpoint_file"]
if not checkpoint_path.exists():
if download_file(
model_info["checkpoint_url"],
checkpoint_path,
f"SAM2.1 {model_name} checkpoint"
):
success_count += 1
else:
print(f" ✓ Checkpoint file already exists: {checkpoint_path}")
success_count += 1
print(f"\n=== Download Summary ===")
print(f"Successfully downloaded: {success_count}/{total_downloads} files")
if success_count == total_downloads:
print("✓ All SAM2.1 models downloaded successfully!")
return True
else:
print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
return False
def download_yolo_models():
"""Download default YOLO models to models directory."""
print("\n--- Setting up YOLO models ---")
print(" Downloading both detection and segmentation models...")
try:
from ultralytics import YOLO
import torch
# Default YOLO models to download (both detection and segmentation)
yolo_models = [
"yolov8n.pt", # Detection models
"yolov8s.pt",
"yolov8m.pt",
"yolo11l.pt", # YOLOv11 detection models
"yolo11x.pt",
"yolov8n-seg.pt", # Segmentation models
"yolov8s-seg.pt",
"yolov8m-seg.pt",
"yolo11l-seg.pt", # YOLOv11 segmentation models
"yolo11x-seg.pt"
]
models_dir = Path(__file__).parent / "models" / "yolo"
for model_name in yolo_models:
model_path = models_dir / model_name
if not model_path.exists():
print(f"Downloading {model_name}...")
try:
# First try to download using the YOLO class with export
model = YOLO(model_name)
# Export/save the model to our directory
# The model.ckpt is the internal checkpoint
if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
# Save the checkpoint directly
torch.save(model.ckpt, str(model_path))
print(f" ✓ Saved {model_name} to models directory")
else:
# Alternative: try to find where YOLO downloaded the model
import shutil
# Common locations where YOLO might store models
possible_paths = [
Path.home() / ".cache" / "ultralytics" / "models" / model_name,
Path.home() / ".ultralytics" / "models" / model_name,
Path.home() / "runs" / "detect" / model_name,
Path.cwd() / model_name, # Current directory
]
found = False
for possible_path in possible_paths:
if possible_path.exists():
shutil.copy2(possible_path, model_path)
print(f" ✓ Copied {model_name} from {possible_path}")
found = True
# Clean up if it was downloaded to current directory
if possible_path.parent == Path.cwd() and possible_path != model_path:
possible_path.unlink()
break
if not found:
# Last resort: use urllib to download directly
# Use different release versions for different YOLO versions
if model_name.startswith("yolo11"):
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
else:
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
print(f" Downloading directly from {yolo_url}...")
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
except Exception as e:
print(f" ⚠ Error downloading {model_name}: {e}")
# Try direct download as fallback
try:
# Use different release versions for different YOLO versions
if model_name.startswith("yolo11"):
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
else:
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
print(f" Trying direct download from {yolo_url}...")
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
except Exception as e2:
print(f" ✗ Failed to download {model_name}: {e2}")
else:
print(f"{model_name} already exists")
# Verify all models exist
success = all((models_dir / model).exists() for model in yolo_models)
if success:
print("✓ YOLO models setup complete!")
print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt, yolov11l.pt, yolov11x.pt")
print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt, yolov11l-seg.pt, yolov11x-seg.pt")
else:
missing_models = [model for model in yolo_models if not (models_dir / model).exists()]
print("⚠ Some YOLO models may be missing:")
for model in missing_models:
print(f" - {model}")
return success
except ImportError:
print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
return False
except Exception as e:
print(f"⚠ Error setting up YOLO models: {e}")
return False
def update_config_file():
"""Update config.yaml to use local model paths."""
print("\n--- Updating config.yaml ---")
config_path = Path(__file__).parent / "config.yaml"
if not config_path.exists():
print("⚠ config.yaml not found, skipping update")
return False
try:
# Read current config
with open(config_path, 'r') as f:
content = f.read()
# Update model paths to use local models
updated_content = content.replace(
'yolo_model: "yolov8n.pt"',
'yolo_model: "models/yolo/yolov8n.pt"'
).replace(
'yolo_detection_model: "models/yolo/yolov8n.pt"',
'yolo_detection_model: "models/yolo/yolov8n.pt"'
).replace(
'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"',
'yolo_segmentation_model: "models/yolo/yolov8n-seg.pt"'
).replace(
'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'
).replace(
'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"'
)
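# str.replace is a no-op when the stock default strings are absent, so a customized config.yaml is left untouched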
# Write updated config
with open(config_path, 'w') as f:
f.write(updated_content)
print("✓ Updated config.yaml to use local model paths")
return True
except Exception as e:
print(f"⚠ Error updating config.yaml: {e}")
return False
def main():
"""Main function to download all models."""
print("🤖 YOLO + SAM2 Model Download Script")
print("="*50)
# Download SAM2 models
sam2_success = download_sam2_models()
# Download YOLO models
yolo_success = download_yolo_models()
# Update config file
config_success = update_config_file()
print("\n" + "="*50)
print("📋 Final Summary:")
print(f" SAM2 models: {'' if sam2_success else ''}")
print(f" YOLO models: {'' if yolo_success else ''}")
print(f" Config update: {'' if config_success else ''}")
if sam2_success and config_success:
print("\n🎉 Setup complete! You can now run the pipeline with:")
print(" python main.py --config config.yaml")
else:
print("\n⚠ Some steps failed. Check the output above for details.")
print("\n📁 Models are organized in:")
print(f" {Path(__file__).parent / 'models'}")
if __name__ == "__main__":
main()

530
main.py
View File

@@ -8,6 +8,8 @@ and creating green screen masks with SAM2.
import os
import sys
import argparse
import cv2
import numpy as np
from typing import List
# Add project root to path
@@ -16,6 +18,9 @@ sys.path.append(os.path.dirname(__file__))
from core.config_loader import ConfigLoader
from core.video_splitter import VideoSplitter
from core.yolo_detector import YOLODetector
from core.sam2_processor import SAM2Processor
from core.mask_processor import MaskProcessor
from core.video_assembler import VideoAssembler
from utils.logging_utils import setup_logging, get_logger
from utils.file_utils import ensure_directory
from utils.status_utils import print_processing_status, cleanup_incomplete_segment
@@ -66,6 +71,100 @@ def validate_dependencies():
logger.error("Please install requirements: pip install -r requirements.txt") logger.error("Please install requirements: pip install -r requirements.txt")
return False return False
def create_yolo_mask_debug_frame(detections: List[dict], video_path: str, output_path: str, scale: float = 1.0) -> bool:
"""
Create debug visualization for YOLO direct masks.
Args:
detections: List of YOLO detections with masks
video_path: Path to video file
output_path: Path to save debug image
scale: Scale factor for frame processing
Returns:
True if debug frame was created successfully
"""
try:
# Load first frame
cap = cv2.VideoCapture(video_path)
ret, original_frame = cap.read()
cap.release()
if not ret:
logger.error("Could not read first frame for YOLO mask debug")
return False
# Scale frame if needed
if scale != 1.0:
original_frame = cv2.resize(original_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
debug_frame = original_frame.copy()
# Define colors for each object
colors = {
1: (0, 255, 0), # Green for Object 1 (Left eye)
2: (255, 0, 0), # Blue for Object 2 (Right eye)
}
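# Colors are in OpenCV's BGR order: (0, 255, 0) renders green, (255, 0, 0) renders blue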
# Get detections with masks
detections_with_masks = [d for d in detections if d.get('has_mask', False)]
# Overlay masks with transparency
obj_id = 1
for detection in detections_with_masks[:2]: # Up to 2 objects
mask = detection['mask']
# Resize mask to match frame if needed
if mask.shape != original_frame.shape[:2]:
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
mask = mask.astype(bool)
# Apply colored overlay
color = colors.get(obj_id, (128, 128, 128))
overlay = debug_frame.copy()
overlay[mask] = color
# Blend with original (30% overlay, 70% original)
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
# Draw outline
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cv2.drawContours(debug_frame, contours, -1, color, 2)
logger.info(f"YOLO Mask Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
obj_id += 1
# Add title and source info
title = f"YOLO Direct Masks: {len(detections_with_masks)} objects detected"
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
source_info = "Mask Source: YOLO Segmentation (DIRECT - No SAM2)"
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) # Green for YOLO
# Add object legend
y_offset = 90
for i, detection in enumerate(detections_with_masks[:2]):
obj_id = i + 1
color = colors.get(obj_id, (128, 128, 128))
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye'} (YOLO Mask)"
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
y_offset += 30
# Save debug image
success = cv2.imwrite(output_path, debug_frame)
if success:
logger.info(f"YOLO Mask Debug: Saved debug frame to {output_path}")
else:
logger.error(f"Failed to save YOLO mask debug frame to {output_path}")
return success
except Exception as e:
logger.error(f"Error creating YOLO mask debug frame: {e}")
return False
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
"""
Resolve detect_segments configuration to list of segment indices.
@@ -157,31 +256,432 @@ def main():
detect_segments_config = config.get_detect_segments()
detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))
# Initialize processors once
logger.info("Step 2: Initializing YOLO detector")
# Get YOLO mode and model paths
yolo_mode = config.get('models.yolo_mode', 'detection')
detection_model = config.get('models.yolo_detection_model', config.get_yolo_model_path())
segmentation_model = config.get('models.yolo_segmentation_model', None)
logger.info(f"YOLO Mode: {yolo_mode}")
detector = YOLODetector(
detection_model_path=detection_model,
segmentation_model_path=segmentation_model,
mode=yolo_mode,
confidence_threshold=config.get_yolo_confidence(),
human_class_id=config.get_human_class_id()
)
logger.info("Step 3: Initializing SAM2 processor")
sam2_processor = SAM2Processor(
checkpoint_path=config.get_sam2_checkpoint(),
config_path=config.get_sam2_config(),
vos_optimized=config.get('models.sam2_vos_optimized', False)
)
# Initialize mask processor with quality enhancements
mask_quality_config = config.get('mask_processing', {})
mask_processor = MaskProcessor(
green_color=config.get_green_color(),
blue_color=config.get_blue_color(),
mask_quality_config=mask_quality_config
)
# Process each segment sequentially (YOLO -> SAM2 -> Render)
logger.info("Step 4: Processing segments sequentially")
total_humans_detected = 0
for i, segment_info in enumerate(segments_info):
segment_idx = segment_info['index']
logger.info(f"Processing segment {segment_idx}/{len(segments_info)-1}")
# Reset temporal history for new segment
mask_processor.reset_temporal_history()
# Skip if segment output already exists
output_video = os.path.join(segment_info['directory'], f"output_{segment_idx}.mp4")
if os.path.exists(output_video):
logger.info(f"Segment {segment_idx} already processed, skipping")
continue
# Determine if we should use YOLO detections or previous masks
use_detections = segment_idx in detect_segments
# First segment must use detections
if segment_idx == 0 and not use_detections:
logger.warning(f"First segment must use YOLO detection")
use_detections = True
# Get YOLO prompts or previous masks
yolo_prompts = None
previous_masks = None
if use_detections:
# Run YOLO detection on current segment
logger.info(f"Running YOLO detection on segment {segment_idx}")
detection_file = os.path.join(segment_info['directory'], "yolo_detections")
# Check if detection already exists
if os.path.exists(detection_file):
logger.info(f"Loading existing YOLO detections for segment {segment_idx}")
detections = detector.load_detections_from_file(detection_file)
else:
# Run YOLO detection on first frame
detections = detector.detect_humans_in_video_first_frame(
segment_info['video_file'],
scale=config.get_inference_scale()
)
# Save detections for future runs
detector.save_detections_to_file(detections, detection_file)
if detections:
total_humans_detected += len(detections)
logger.info(f"Found {len(detections)} humans in segment {segment_idx}")
# Get frame width from video
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
cap.release()
yolo_prompts = detector.convert_detections_to_sam2_prompts(
detections, frame_width
)
# If no right eye detections found, run debug analysis with lower confidence
half_frame_width = frame_width // 2
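# VR180 side-by-side frames: the left half of the frame is the left-eye view, the right half is the right-eye view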
right_eye_detections = [d for d in detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
if len(right_eye_detections) == 0 and config.get('advanced.save_yolo_debug_frames', False):
logger.info(f"VR180 Debug: No right eye detections found, running lower confidence analysis...")
# Load first frame for debug analysis
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Run debug detection with lower confidence
debug_detections = detector.debug_detect_with_lower_confidence(debug_frame, debug_confidence=0.3)
# Analyze where these lower confidence detections are
debug_right_eye = [d for d in debug_detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
if len(debug_right_eye) > 0:
logger.warning(f"VR180 Debug: Found {len(debug_right_eye)} right eye detections with lower confidence!")
for i, det in enumerate(debug_right_eye):
logger.warning(f"VR180 Debug: Right eye detection {i+1}: conf={det['confidence']:.3f}, bbox={det['bbox']}")
logger.warning(f"VR180 Debug: Consider lowering yolo_confidence from {config.get_yolo_confidence()} to 0.3-0.4")
else:
logger.info(f"VR180 Debug: No right eye detections found even with confidence 0.3")
logger.info(f"VR180 Debug: This confirms person is not visible in right eye view")
logger.info(f"Pipeline Debug: Segment {segment_idx} - Generated {len(yolo_prompts)} SAM2 prompts from {len(detections)} YOLO detections")
# Save debug frame with detections visualized (if enabled)
if config.get('advanced.save_yolo_debug_frames', False):
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug.jpg")
# Load first frame for debug visualization
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
detector.save_debug_frame_with_detections(debug_frame, detections, debug_frame_path, yolo_prompts)
else:
logger.warning(f"Could not load frame for debug visualization in segment {segment_idx}")
# Check if we have YOLO masks for debug visualization
has_yolo_masks = False
if detections and detector.supports_segmentation:
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
# Generate first frame masks debug (SAM2 or YOLO)
first_frame_debug_path = os.path.join(segment_info['directory'], "first_frame_detection.jpg")
if has_yolo_masks:
logger.info(f"Pipeline Debug: Generating YOLO first frame masks for segment {segment_idx}")
# Create YOLO mask debug visualization
create_yolo_mask_debug_frame(detections, segment_info['video_file'], first_frame_debug_path, config.get_inference_scale())
else:
logger.info(f"Pipeline Debug: Generating SAM2 first frame masks for segment {segment_idx}")
sam2_processor.generate_first_frame_debug_masks(
segment_info['video_file'],
yolo_prompts,
first_frame_debug_path,
config.get_inference_scale()
)
else:
logger.warning(f"No humans detected in segment {segment_idx}")
# Save debug frame even when no detections (if enabled)
if config.get('advanced.save_yolo_debug_frames', False):
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug_no_detections.jpg")
# Load first frame for debug visualization
cap = cv2.VideoCapture(segment_info['video_file'])
ret, debug_frame = cap.read()
cap.release()
if ret:
# Scale frame to match detection scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
# Add "No detections" text overlay
cv2.putText(debug_frame, "YOLO: No humans detected",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
(0, 0, 255), 2) # Red text
cv2.imwrite(debug_frame_path, debug_frame)
logger.info(f"Saved no-detection debug frame to {debug_frame_path}")
else:
logger.warning(f"Could not load frame for no-detection debug visualization in segment {segment_idx}")
elif segment_idx > 0:
# Try to load previous segment mask
for j in range(segment_idx - 1, -1, -1):
prev_segment_dir = segments_info[j]['directory']
previous_masks = sam2_processor.load_previous_segment_mask(prev_segment_dir)
if previous_masks:
logger.info(f"Using masks from segment {j} for segment {segment_idx}")
break
if not yolo_prompts and not previous_masks:
logger.error(f"No prompts or previous masks available for segment {segment_idx}")
continue
# Check if we have YOLO masks and can skip SAM2 (recheck in case detections were loaded from file)
if 'has_yolo_masks' not in locals():
has_yolo_masks = False
if detections and detector.supports_segmentation:
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
if has_yolo_masks:
logger.info(f"Pipeline Debug: YOLO segmentation provided masks - using as SAM2 initial masks for segment {segment_idx}")
# Convert YOLO masks to initial masks for SAM2
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
# Convert YOLO masks to the format expected by SAM2 add_previous_masks_to_predictor
yolo_masks_dict = {}
for i, detection in enumerate(detections[:2]): # Up to 2 objects
if detection.get('has_mask', False):
mask = detection['mask']
# Resize mask to match inference scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
scaled_height = int(frame_height * scale)
scaled_width = int(frame_width * scale)
mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
obj_id = i + 1 # Sequential object IDs
yolo_masks_dict[obj_id] = mask.astype(bool)
logger.info(f"Pipeline Debug: YOLO mask for Object {obj_id} - shape: {mask.shape}, pixels: {np.sum(mask)}")
logger.info(f"Pipeline Debug: Using YOLO masks as SAM2 initial masks - {len(yolo_masks_dict)} objects")
# Use traditional SAM2 pipeline with YOLO masks as initial masks
previous_masks = yolo_masks_dict
yolo_prompts = None # Don't use bounding box prompts when we have masks
# Debug what we're passing to SAM2
if yolo_prompts:
logger.info(f"Pipeline Debug: Passing {len(yolo_prompts)} YOLO prompts to SAM2 for segment {segment_idx}")
for i, prompt in enumerate(yolo_prompts):
logger.info(f"Pipeline Debug: Prompt {i+1}: Object {prompt['obj_id']}, bbox={prompt['bbox']}")
if previous_masks:
logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}")
logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}")
# Handle mid-segment detection if enabled (works for both detection and segmentation modes)
multi_frame_prompts = None
if config.get('advanced.enable_mid_segment_detection', False) and (yolo_prompts or has_yolo_masks):
logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}")
# Calculate frame indices for re-detection
cap = cv2.VideoCapture(segment_info['video_file'])
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
cap.release()
redetection_interval = config.get('advanced.redetection_interval', 30)
max_redetections = config.get('advanced.max_redetections_per_segment', 10)
# Generate frame indices: [30, 60, 90, ...] (skip frame 0 since we already have first frame prompts)
frame_indices = []
frame_idx = redetection_interval
while frame_idx < total_frames and len(frame_indices) < max_redetections:
frame_indices.append(frame_idx)
frame_idx += redetection_interval
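# Example: a 150-frame segment with interval 30 yields frame_indices = [30, 60, 90, 120]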
if frame_indices:
logger.info(f"Mid-segment Detection: Running YOLO on frames {frame_indices} (interval={redetection_interval})")
# Run multi-frame detection
multi_frame_detections = detector.detect_humans_multi_frame(
segment_info['video_file'],
frame_indices,
scale=config.get_inference_scale()
)
# Convert detections to SAM2 prompts (different handling for segmentation vs detection mode)
multi_frame_prompts = {}
cap = cv2.VideoCapture(segment_info['video_file'])
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
for frame_idx, detections in multi_frame_detections.items():
if detections:
if has_yolo_masks:
# Segmentation mode: convert YOLO masks to SAM2 mask prompts
frame_masks = {}
for i, detection in enumerate(detections[:2]): # Up to 2 objects
if detection.get('has_mask', False):
mask = detection['mask']
# Resize mask to match inference scale
if config.get_inference_scale() != 1.0:
scale = config.get_inference_scale()
scaled_height = int(frame_height * scale)
scaled_width = int(frame_width * scale)
mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
mask = mask > 0.5
obj_id = i + 1 # Sequential object IDs
frame_masks[obj_id] = mask.astype(bool)
logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
if frame_masks:
# Store as mask prompts (different format than bbox prompts)
multi_frame_prompts[frame_idx] = {'masks': frame_masks}
logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(frame_masks)} YOLO masks")
else:
# Detection mode: convert to bounding box prompts (existing logic)
prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
multi_frame_prompts[frame_idx] = prompts
logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")
logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames")
else:
logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)")
elif config.get('advanced.enable_mid_segment_detection', False):
logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO data)")
# Process segment with SAM2
logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")
video_segments = sam2_processor.process_single_segment(
segment_info,
yolo_prompts=yolo_prompts,
previous_masks=previous_masks,
inference_scale=config.get_inference_scale(),
multi_frame_prompts=multi_frame_prompts
)
if video_segments is None:
logger.error(f"SAM2 processing failed for segment {segment_idx}")
continue
# Check if SAM2 produced adequate results
if len(video_segments) == 0:
logger.error(f"SAM2 produced no frames for segment {segment_idx}")
continue
elif len(video_segments) < 10: # Expected many frames for a 5-second segment
logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
# Debug what SAM2 produced
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
if video_segments:
# Check first frame to see what objects were tracked
first_frame_idx = min(video_segments.keys())
first_frame_objects = video_segments[first_frame_idx]
logger.info(f"Pipeline Debug: First frame contains {len(first_frame_objects)} tracked objects")
logger.info(f"Pipeline Debug: Tracked object IDs: {list(first_frame_objects.keys())}")
for obj_id, mask in first_frame_objects.items():
mask_pixels = np.sum(mask)
logger.info(f"Pipeline Debug: Object {obj_id} mask has {mask_pixels} pixels")
# Check last frame as well
last_frame_idx = max(video_segments.keys())
last_frame_objects = video_segments[last_frame_idx]
logger.info(f"Pipeline Debug: Last frame contains {len(last_frame_objects)} tracked objects")
logger.info(f"Pipeline Debug: Final object IDs: {list(last_frame_objects.keys())}")
# Save final masks for next segment
mask_path = os.path.join(segment_info['directory'], "mask.png")
sam2_processor.save_final_masks(
video_segments,
mask_path,
green_color=config.get_green_color(),
blue_color=config.get_blue_color()
)
# Apply green screen and save output video
success = mask_processor.process_segment(
segment_info,
video_segments,
use_nvenc=config.get_use_nvenc(),
bitrate=config.get_output_bitrate()
)
if success:
logger.info(f"Successfully processed segment {segment_idx}")
else:
logger.error(f"Failed to create green screen video for segment {segment_idx}")
# Log processing summary
logger.info(f"Sequential processing complete. Total humans detected: {total_humans_detected}")
# Step 3: Assemble final video
logger.info("Step 3: Assembling final video with audio")
# Initialize video assembler
assembler = VideoAssembler(
preserve_audio=config.get_preserve_audio(),
use_nvenc=config.get_use_nvenc()
)
# Verify all segments are complete
all_complete, missing = assembler.verify_segment_completeness(segments_dir)
if not all_complete:
logger.error(f"Cannot assemble video - missing segments: {missing}")
return 1
# Assemble final video
final_output = os.path.join(output_dir, config.get_output_filename())
success = assembler.assemble_final_video(
segments_dir,
input_video,
final_output,
bitrate=config.get_output_bitrate()
)
if success:
logger.info(f"Final video saved to: {final_output}")
logger.info("Pipeline completed successfully") logger.info("Pipeline completed successfully")
return 0 return 0

requirements.txt
View File

@@ -6,6 +6,7 @@ opencv-python>=4.8.0
numpy>=1.24.0
# SAM2 - Segment Anything Model 2
# Note: Make sure to run download_models.py after installing to get model weights
git+https://github.com/facebookresearch/sam2.git
# GPU acceleration (optional but recommended)
@@ -17,6 +18,8 @@ tqdm>=4.65.0
matplotlib>=3.7.0
Pillow>=10.0.0
decord
# Optional: For advanced features
psutil>=5.9.0 # Memory monitoring
pympler>=0.9 # Memory profiling (for debugging)

618
spec.md
View File

@@ -190,3 +190,621 @@ models:
- **Fine-tuned YOLO**: Domain-specific human detection models
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
- **Temporal Consistency**: Enhanced cross-segment mask propagation
Here is the original monolithic script that this repo refactors into modules. If something
doesn't work in this repo, consult the following script, because it is known to work and can
be used to troubleshoot problems:
import os
import cv2
import numpy as np
import cupy as cp
from concurrent.futures import ThreadPoolExecutor
import torch
import logging
import sys
import gc
from sam2.build_sam import build_sam2_video_predictor
import argparse
from ultralytics import YOLO
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]
INFERENCE_SCALE = 0.50
FULL_SCALE = 1.0
# YOLO model for human detection (class 0 = person)
YOLO_MODEL_PATH = "yolov8n.pt" # You can change this to a custom model
YOLO_CONFIDENCE = 0.6
HUMAN_CLASS_ID = 0 # COCO class ID for person
def open_video(video_path):
"""
Opens a video file and returns a generator that yields frames.
Parameters:
- video_path: Path to the video file.
Returns:
- A generator that yields frames from the video.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}")
return
while True:
ret, frame = cap.read()
if not ret:
break
yield frame
cap.release()
def load_previous_segment_mask(prev_segment_dir):
mask_path = os.path.join(prev_segment_dir, "mask.png")
mask_image = cv2.imread(mask_path)
if mask_image is None:
raise FileNotFoundError(f"Mask image not found at {mask_path}")
# Ensure the mask_image has three color channels
if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
raise ValueError("Mask image does not have three color channels.")
mask_image = mask_image.astype(np.uint8)
# Extract Object A and Object B masks
mask_a = np.all(mask_image == GREEN, axis=2)
mask_b = np.all(mask_image == BLUE, axis=2)
per_obj_input_mask = {1: mask_a, 2: mask_b}
input_palette = None # No palette needed for binary mask
return per_obj_input_mask, input_palette
def apply_green_mask(frame, masks):
# Convert frame and masks to CuPy arrays
frame_gpu = cp.asarray(frame)
combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_)
for mask in masks:
mask_gpu = cp.asarray(mask.squeeze())
if mask_gpu.shape != frame_gpu.shape[:2]:
resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32),
(frame_gpu.shape[1], frame_gpu.shape[0]))
mask_gpu = cp.asarray(resized_mask > 0.5) # Convert back to CuPy boolean array
else:
mask_gpu = mask_gpu.astype(cp.bool_) # Ensure boolean type
combined_mask |= mask_gpu # Perform the bitwise OR operation
green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8)
result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background)
return cp.asnumpy(result_frame) # Convert back to NumPy
def initialize_predictor():
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
print(
"\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
"give numerically different outputs and sometimes degraded performance on MPS."
)
# Enable MPS fallback for operations not supported on MPS
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
device = torch.device("cpu")
logger.info(f"Using device: {device}")
predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
return predictor
def load_first_frame(video_path, scale=1.0):
"""
Opens a video file and returns the first frame, scaled as specified.
Parameters:
- video_path: Path to the video file.
- scale: Scaling factor for the frame (default is 1.0 for original size).
Returns:
- first_frame: The first frame of the video, scaled accordingly.
"""
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
logger.error(f"Error: Could not open video file {video_path}")
return None
ret, frame = cap.read()
cap.release()
if not ret:
logger.error(f"Error: Could not read frame from video file {video_path}")
return None
if scale != 1.0:
frame = cv2.resize(
frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
)
return frame
def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE):
"""
Detect humans in a frame using YOLO model.
Parameters:
- frame: Input frame (BGR format)
- yolo_model: Loaded YOLO model
- confidence_threshold: Detection confidence threshold
Returns:
- human_boxes: List of bounding boxes for detected humans
"""
# Run YOLO detection
results = yolo_model(frame, conf=confidence_threshold, verbose=False)
human_boxes = []
# Process results
for result in results:
boxes = result.boxes
if boxes is not None:
for box in boxes:
# Get class ID
cls = int(box.cls.cpu().numpy()[0])
# Check if it's a person (class 0 in COCO)
if cls == HUMAN_CLASS_ID:
# Get bounding box coordinates (x1, y1, x2, y2)
coords = box.xyxy[0].cpu().numpy()
conf = float(box.conf.cpu().numpy()[0])
human_boxes.append({
'bbox': coords,
'confidence': conf
})
logger.info(f"Detected human with confidence {conf:.2f} at {coords}")
return human_boxes
def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width):
"""
Add YOLO human detections as bounding boxes to SAM2 predictor.
For stereo videos, creates two objects (left and right humans).
Parameters:
- predictor: SAM2 video predictor
- inference_state: SAM2 inference state
- human_detections: List of human detection results
- frame_width: Width of the frame for stereo splitting
Returns:
- out_mask_logits: SAM2 output mask logits
"""
half_frame_width = frame_width // 2
# Sort detections by x-coordinate to get left and right humans
human_detections.sort(key=lambda x: x['bbox'][0]) # Sort by x1 coordinate
obj_id = 1
out_mask_logits = None
for i, detection in enumerate(human_detections[:2]): # Take up to 2 humans (left and right)
bbox = detection['bbox']
# For stereo videos, assign obj_id based on position
if len(human_detections) >= 2:
# If we have multiple humans, assign based on left/right position
center_x = (bbox[0] + bbox[2]) / 2
if center_x < half_frame_width:
current_obj_id = 1 # Left human
else:
current_obj_id = 2 # Right human
else:
# If only one human, duplicate for both sides (as in original stereo logic)
current_obj_id = obj_id
obj_id += 1
# Also add the mirrored version for stereo
if obj_id <= 2:
mirrored_bbox = bbox.copy()
mirrored_bbox[0] += half_frame_width # Shift x1
mirrored_bbox[2] += half_frame_width # Shift x2
# Ensure mirrored bbox is within frame bounds
mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))
try:
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
box=mirrored_bbox.astype(np.float32),
)
logger.info(f"Added mirrored human detection for Object {obj_id}")
obj_id += 1
except Exception as e:
logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}")
try:
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=current_obj_id,
box=bbox.astype(np.float32),
)
logger.info(f"Added human detection for Object {current_obj_id}")
except Exception as e:
logger.error(f"Error adding human detection for Object {current_obj_id}: {e}")
return out_mask_logits
def propagate_masks(predictor, inference_state):
video_segments = {}
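# propagate_in_video yields (frame_idx, obj_ids, mask_logits); thresholding the logits at 0.0 gives a boolean mask per object ID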
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
return video_segments
def apply_colored_mask(frame, masks_a, masks_b):
colored_mask = np.zeros_like(frame)
# Apply colors to the masks
for mask in masks_a:
mask = mask.squeeze()
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
indices = np.where(mask)
colored_mask[mask] = [0, 255, 0] # Green for Object A
for mask in masks_b:
mask = mask.squeeze()
if mask.shape != frame.shape[:2]:
mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
indices = np.where(mask)
colored_mask[mask] = [255, 0, 0] # Blue for Object B
return colored_mask
def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
"""
Process high-resolution frames, apply upscaled masks, and save the output video.
"""
cap = cv2.VideoCapture(video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
# Setup VideoWriter with desired settings
if use_nvenc:
# Use FFmpeg with NVENC offloading for H.265 encoding
import subprocess
if sys.platform == 'darwin':
encoder = 'hevc_videotoolbox'
else:
encoder = 'hevc_nvenc'
command = [
'ffmpeg',
'-y', # Overwrite output file if it exists
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-pix_fmt', 'bgr24',
'-s', f'{frame_width}x{frame_height}',
'-r', str(fps),
'-i', '-', # Input from stdin
'-an', # No audio
'-vcodec', encoder,
'-pix_fmt', 'nv12',
'-preset', 'slow',
'-b:v', '50M',
output_video_path
]
process = subprocess.Popen(command, stdin=subprocess.PIPE)
else:
# Use OpenCV VideoWriter
fourcc = cv2.VideoWriter_fourcc(*'HEVC') # H.265
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret or frame_idx >= len(video_segments):
break
masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
upscaled_masks = []
for mask in masks:
mask = mask.squeeze()
upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
upscaled_masks.append(upscaled_mask)
result_frame = apply_green_mask(frame, upscaled_masks)
# Write frame to output
if use_nvenc:
process.stdin.write(result_frame.tobytes())
else:
out.write(result_frame)
frame_idx += 1
cap.release()
if use_nvenc:
process.stdin.close()
process.wait()
else:
out.release()
def get_video_file_name(index):
return f"segment_{str(index).zfill(3)}.mp4"
def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH):
"""
Run YOLO detection on specified segments and save detection results.
"""
logger.info("Running YOLO detection on requested segments.")
# Load YOLO model
yolo_model = YOLO(yolo_model_path)
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
detection_file = os.path.join(segment_dir, "yolo_detections")
video_file = os.path.join(segment_dir, get_video_file_name(i))
if segment_index in detect_segments and not os.path.exists(detection_file):
first_frame = load_first_frame(video_file, scale)
if first_frame is None:
continue
# Ultralytics YOLO accepts BGR numpy frames directly, so no color conversion is needed here
human_detections = detect_humans_with_yolo(first_frame, yolo_model)
if human_detections:
# Save detection results
with open(detection_file, 'w') as f:
f.write("# YOLO Human Detections\n")
for detection in human_detections:
bbox = detection['bbox']
conf = detection['confidence']
f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
logger.info(f"Saved {len(human_detections)} human detections for segment {segment}")
else:
logger.warning(f"No humans detected in segment {segment}")
# Create empty file to mark as processed
with open(detection_file, 'w') as f:
f.write("# No humans detected\n")
def save_final_masks(video_segments, mask_output_path):
"""
Save the final masks as a colored image.
"""
last_frame_idx = max(video_segments.keys())
masks_dict = video_segments[last_frame_idx]
# Assuming you have two objects with IDs 1 and 2
mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None
mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None
if mask_a is None and mask_b is None:
logger.error("No masks found for objects.")
return
# Use the first available mask to determine dimensions
reference_mask = mask_a if mask_a is not None else mask_b
black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)
if mask_a is not None:
mask_a = mask_a.astype(bool)
black_frame[mask_a] = GREEN
if mask_b is not None:
mask_b = mask_b.astype(bool)
black_frame[mask_b] = BLUE
# Save the mask image
cv2.imwrite(mask_output_path, black_frame)
logger.info(f"Saved final masks to {mask_output_path}")
def create_low_res_video(input_video_path, output_video_path, scale):
"""
Creates a low-resolution version of the input video for inference.
"""
cap = cv2.VideoCapture(input_video_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
fps = cap.get(cv2.CAP_PROP_FPS) or 59.94
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
while True:
ret, frame = cap.read()
if not ret:
break
low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
out.write(low_res_frame)
cap.release()
out.release()
def main():
parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.")
parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.")
parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.")
parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.")
args = parser.parse_args()
base_dir = args.base_dir
segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
segments.sort(key=lambda x: int(x.split("_")[1]))
# Handle different ways to specify segments for YOLO detection
if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0:
# Default: run YOLO on all segments
detect_segments = [int(seg.split("_")[1]) for seg in segments]
logger.info("No segments specified, running YOLO detection on ALL segments")
elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all':
# Explicit 'all' keyword
detect_segments = [int(seg.split("_")[1]) for seg in segments]
logger.info("Running YOLO detection on ALL segments")
else:
# Specific segment numbers provided
try:
detect_segments = [int(x) for x in args.segments_detect_humans]
logger.info(f"Running YOLO detection on segments: {detect_segments}")
except ValueError:
logger.error("Invalid segment numbers provided. Use integers or 'all'.")
return
# Run YOLO detection on specified segments
do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model)
# Load YOLO model for inference
yolo_model = YOLO(args.yolo_model)
for i, segment in enumerate(segments):
segment_index = int(segment.split("_")[1])
segment_dir = os.path.join(base_dir, segment)
video_file_name = get_video_file_name(i)
video_path = os.path.join(segment_dir, video_file_name)
output_done_file = os.path.join(segment_dir, "output_frames_done")
if os.path.exists(output_done_file):
logger.info(f"Segment {segment} already processed. Skipping.")
continue
logger.info(f"Processing segment {segment}")
# Initialize predictor
predictor = initialize_predictor()
# Prepare low-resolution video frames for inference
low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
if not os.path.exists(low_res_video_path):
create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
logger.info(f"Low-resolution video created for segment {segment}")
else:
logger.info(f"Low-resolution video already exists for segment {segment}, reuse")
# Initialize inference state with low-resolution video
inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)
# Load YOLO detections or previous masks
detection_file = os.path.join(segment_dir, "yolo_detections")
use_detections = segment_index in detect_segments
if i == 0 and not use_detections:
# First segment must use YOLO detection since there's no previous mask
logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.")
use_detections = True
if i > 0 and not use_detections:
# Try to load previous segment mask - search backwards for the most recent successful mask
logger.info(f"Using previous segment mask for segment {segment}")
mask_found = False
# Search backwards through previous segments to find a valid mask
for j in range(i - 1, -1, -1):
prev_segment_dir = os.path.join(base_dir, segments[j])
prev_mask_path = os.path.join(prev_segment_dir, "mask.png")
if os.path.exists(prev_mask_path):
try:
per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
# Add previous masks to predictor
for obj_id, mask in per_obj_input_mask.items():
predictor.add_new_mask(inference_state, 0, obj_id, mask)
logger.info(f"Successfully loaded mask from segment {segments[j]}")
mask_found = True
break
except Exception as e:
logger.warning(f"Error loading mask from {segments[j]}: {e}")
continue
if not mask_found:
logger.error(f"No valid previous mask found for segment {segment}. Consider running YOLO detection on this segment.")
continue
else:
# Load first frame for detection
first_frame = load_first_frame(low_res_video_path, scale=1.0)
if first_frame is None:
logger.error(f"Could not load first frame for segment {segment}")
continue
# Run YOLO detection on first frame (either from file or on-the-fly)
if os.path.exists(detection_file):
logger.info(f"Using existing YOLO detections for segment {segment}")
else:
logger.info(f"Running YOLO detection on-the-fly for segment {segment}")
human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence)
if human_detections:
# Add YOLO detections to predictor
frame_width = first_frame.shape[1]
add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width)
else:
logger.warning(f"No humans detected in segment {segment}")
continue
# Perform inference and collect masks per frame
video_segments = propagate_masks(predictor, inference_state)
# Process high-resolution frames and save output video
output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
logger.info("Processing segment complete, attempting to save full video from low res masks")
process_and_save_output_video(
video_path,
output_video_path,
video_segments,
use_nvenc=True # Set to True to use NVENC offloading
)
# Save final masks
mask_output_path = os.path.join(segment_dir, "mask.png")
save_final_masks(video_segments, mask_output_path)
# Clean up
predictor.reset_state(inference_state)
del inference_state
del video_segments
del predictor
gc.collect()
try:
os.remove(low_res_video_path)
logger.info(f"Deleted low-resolution video for segment {segment}")
except Exception as e:
logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")
# Mark segment as completed
open(output_done_file, 'a').close()
logger.info("Processing complete.")
if __name__ == "__main__":
main()
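For reference, a typical invocation of this original script (the script name and paths are illustrative; the flags come from the argparse definitions above) would look like:

```bash
# Process all segment_### folders under the base directory with YOLO + SAM2
python legacy_pipeline.py \
    --base-dir /path/to/video_segments \
    --segments-detect-humans all \
    --yolo-model yolov8n.pt \
    --yolo-confidence 0.6
```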