diff --git a/README.md b/README.md index 005589d..c2dc53f 100644 --- a/README.md +++ b/README.md @@ -32,19 +32,40 @@ git clone cd samyolo_on_segments # Install Python dependencies -pip install -r requirements.txt +uv venv && source .venv/bin/activate +uv pip install -r requirements.txt ``` -### Model Dependencies +### Download Models -You'll need to download the required model checkpoints: +Use the provided script to automatically download all required models: +```bash +# Download SAM2.1 and YOLO models +python download_models.py +``` + +This script will: +- Create a `models/` directory structure +- Download SAM2.1 configs and checkpoints (tiny, small, base+, large) +- Download common YOLO models (yolov8n, yolov8s, yolov8m) +- Update `config.yaml` to use local model paths + +**Manual Download (Alternative):** 1. **SAM2 Models**: Download from [Meta's SAM2 repository](https://github.com/facebookresearch/sam2) -2. **YOLO Models**: YOLOv8 models will be downloaded automatically or you can specify a custom path +2. **YOLO Models**: YOLOv8 models will be downloaded automatically on first use ## Quick Start -### 1. Configure the Pipeline +### 1. Download Models + +First, download the required SAM2.1 and YOLO models: + +```bash +python download_models.py +``` + +### 2. Configure the Pipeline Edit `config.yaml` to specify your input video and desired settings: @@ -63,18 +84,18 @@ processing: detect_segments: "all" models: - yolo_model: "yolov8n.pt" - sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt" - sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml" + yolo_model: "models/yolo/yolov8n.pt" + sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt" + sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml" ``` -### 2. Run the Pipeline +### 3. Run the Pipeline ```bash python main.py --config config.yaml ``` -### 3. Monitor Progress +### 4. Monitor Progress Check processing status: ```bash @@ -166,8 +187,25 @@ samyolo_on_segments/ ├── README.md # This documentation ├── config.yaml # Default configuration ├── main.py # Main entry point +├── download_models.py # Model download script ├── requirements.txt # Python dependencies ├── spec.md # Detailed specification +├── models/ # Downloaded models (created by script) +│ ├── sam2/ +│ │ ├── configs/sam2.1/ # SAM2.1 configuration files +│ │ │ ├── sam2.1_hiera_t.yaml +│ │ │ ├── sam2.1_hiera_s.yaml +│ │ │ ├── sam2.1_hiera_b+.yaml +│ │ │ └── sam2.1_hiera_l.yaml +│ │ └── checkpoints/ # SAM2.1 model weights +│ │ ├── sam2.1_hiera_tiny.pt +│ │ ├── sam2.1_hiera_small.pt +│ │ ├── sam2.1_hiera_base_plus.pt +│ │ └── sam2.1_hiera_large.pt +│ └── yolo/ # YOLO model weights +│ ├── yolov8n.pt +│ ├── yolov8s.pt +│ └── yolov8m.pt ├── core/ # Core processing modules │ ├── __init__.py │ ├── config_loader.py # Configuration management @@ -297,4 +335,4 @@ This project is under active development. The core detection pipeline is functio For issues and questions: 1. Check the troubleshooting section 2. Review the logs with `log_level: "DEBUG"` -3. Open an issue with your configuration and error details \ No newline at end of file +3. 
Open an issue with your configuration and error details diff --git a/config.yaml b/config.yaml index 0758405..16db3cd 100644 --- a/config.yaml +++ b/config.yaml @@ -23,11 +23,11 @@ processing: models: # YOLO model path - can be pretrained (yolov8n.pt) or custom path - yolo_model: "yolov8n.pt" + yolo_model: "models/yolo/yolov8n.pt" # SAM2 model configuration - sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt" - sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml" + sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt" + sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml" video: # Use NVIDIA hardware encoding (requires NVENC-capable GPU) diff --git a/core/video_splitter.py b/core/video_splitter.py index 7d49fba..97c6fa9 100644 --- a/core/video_splitter.py +++ b/core/video_splitter.py @@ -7,7 +7,7 @@ import os import subprocess import logging from typing import List, Tuple -from ..utils.file_utils import ensure_directory, get_video_file_name +from utils.file_utils import ensure_directory, get_video_file_name logger = logging.getLogger(__name__) diff --git a/core/yolo_detector.py b/core/yolo_detector.py index 78be285..028dd53 100644 --- a/core/yolo_detector.py +++ b/core/yolo_detector.py @@ -13,17 +13,17 @@ from ultralytics import YOLO logger = logging.getLogger(__name__) class YOLODetector: - \"\"\"Handles YOLO-based human detection for video segments.\"\"\" + """Handles YOLO-based human detection for video segments.""" def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0): - \"\"\" + """ Initialize YOLO detector. Args: model_path: Path to YOLO model weights confidence_threshold: Detection confidence threshold human_class_id: COCO class ID for humans (0 = person) - \"\"\" + """ self.model_path = model_path self.confidence_threshold = confidence_threshold self.human_class_id = human_class_id @@ -31,13 +31,13 @@ class YOLODetector: # Load YOLO model try: self.model = YOLO(model_path) - logger.info(f\"Loaded YOLO model from {model_path}\") + logger.info(f"Loaded YOLO model from {model_path}") except Exception as e: - logger.error(f\"Failed to load YOLO model: {e}\") + logger.error(f"Failed to load YOLO model: {e}") raise def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]: - \"\"\" + """ Detect humans in a single frame using YOLO. Args: @@ -45,7 +45,7 @@ class YOLODetector: Returns: List of human detection dictionaries with bbox and confidence - \"\"\" + """ # Run YOLO detection results = self.model(frame, conf=self.confidence_threshold, verbose=False) @@ -70,12 +70,12 @@ class YOLODetector: 'confidence': conf }) - logger.debug(f\"Detected human with confidence {conf:.2f} at {coords}\") + logger.debug(f"Detected human with confidence {conf:.2f} at {coords}") return human_detections def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]: - \"\"\" + """ Detect humans in the first frame of a video. 
Args: @@ -84,21 +84,21 @@ class YOLODetector: Returns: List of human detection dictionaries - \"\"\" + """ if not os.path.exists(video_path): - logger.error(f\"Video file not found: {video_path}\") + logger.error(f"Video file not found: {video_path}") return [] cap = cv2.VideoCapture(video_path) if not cap.isOpened(): - logger.error(f\"Could not open video: {video_path}\") + logger.error(f"Could not open video: {video_path}") return [] ret, frame = cap.read() cap.release() if not ret: - logger.error(f\"Could not read first frame from: {video_path}\") + logger.error(f"Could not read first frame from: {video_path}") return [] # Scale frame if needed @@ -108,7 +108,7 @@ class YOLODetector: return self.detect_humans_in_frame(frame) def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool: - \"\"\" + """ Save detection results to file. Args: @@ -117,26 +117,26 @@ class YOLODetector: Returns: True if saved successfully - \"\"\" + """ try: with open(output_path, 'w') as f: - f.write(\"# YOLO Human Detections\\n\") + f.write("# YOLO Human Detections\\n") if detections: for detection in detections: bbox = detection['bbox'] conf = detection['confidence'] - f.write(f\"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n\") - logger.info(f\"Saved {len(detections)} detections to {output_path}\") + f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n") + logger.info(f"Saved {len(detections)} detections to {output_path}") else: - f.write(\"# No humans detected\\n\") - logger.info(f\"Saved empty detection file to {output_path}\") + f.write("# No humans detected\\n") + logger.info(f"Saved empty detection file to {output_path}") return True except Exception as e: - logger.error(f\"Failed to save detections to {output_path}: {e}\") + logger.error(f"Failed to save detections to {output_path}: {e}") return False def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]: - \"\"\" + """ Load detection results from file. Args: @@ -144,11 +144,11 @@ class YOLODetector: Returns: List of detection dictionaries - \"\"\" + """ detections = [] if not os.path.exists(file_path): - logger.warning(f\"Detection file not found: {file_path}\") + logger.warning(f"Detection file not found: {file_path}") return detections try: @@ -170,18 +170,18 @@ class YOLODetector: 'confidence': conf }) except ValueError: - logger.warning(f\"Invalid detection line: {line}\") + logger.warning(f"Invalid detection line: {line}") continue - logger.info(f\"Loaded {len(detections)} detections from {file_path}\") + logger.info(f"Loaded {len(detections)} detections from {file_path}") except Exception as e: - logger.error(f\"Failed to load detections from {file_path}: {e}\") + logger.error(f"Failed to load detections from {file_path}: {e}") return detections def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int], scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]: - \"\"\" + """ Process multiple segments for human detection. 
Args: @@ -191,7 +191,7 @@ class YOLODetector: Returns: Dictionary mapping segment index to detection results - \"\"\" + """ results = {} for segment_info in segments_info: @@ -202,17 +202,17 @@ class YOLODetector: continue video_path = segment_info['video_file'] - detection_file = os.path.join(segment_info['directory'], \"yolo_detections\") + detection_file = os.path.join(segment_info['directory'], "yolo_detections") # Skip if already processed if os.path.exists(detection_file): - logger.info(f\"Segment {segment_idx} already has detections, skipping\") + logger.info(f"Segment {segment_idx} already has detections, skipping") detections = self.load_detections_from_file(detection_file) results[segment_idx] = detections continue # Run detection - logger.info(f\"Processing segment {segment_idx} for human detection\") + logger.info(f"Processing segment {segment_idx} for human detection") detections = self.detect_humans_in_video_first_frame(video_path, scale) # Save results @@ -223,7 +223,7 @@ class YOLODetector: def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]], frame_width: int) -> List[Dict[str, Any]]: - \"\"\" + """ Convert YOLO detections to SAM2-compatible prompts for stereo video. Args: @@ -232,7 +232,7 @@ class YOLODetector: Returns: List of SAM2 prompt dictionaries with obj_id and bbox - \"\"\" + """ if not detections: return [] @@ -282,5 +282,5 @@ class YOLODetector: 'confidence': detection['confidence'] }) - logger.debug(f\"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts\") + logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts") return prompts \ No newline at end of file diff --git a/download_models.py b/download_models.py new file mode 100755 index 0000000..d1a45c0 --- /dev/null +++ b/download_models.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +Model download script for YOLO + SAM2 video processing pipeline. +Downloads SAM2.1 models and organizes them in the models directory. 
+""" + +import os +import urllib.request +import urllib.error +from pathlib import Path +import sys + +def create_directory_structure(): + """Create the models directory structure.""" + base_dir = Path(__file__).parent + models_dir = base_dir / "models" + + # Create main models directory + models_dir.mkdir(exist_ok=True) + + # Create subdirectories + sam2_dir = models_dir / "sam2" + sam2_configs_dir = sam2_dir / "configs" / "sam2.1" + sam2_checkpoints_dir = sam2_dir / "checkpoints" + yolo_dir = models_dir / "yolo" + + sam2_dir.mkdir(exist_ok=True) + sam2_configs_dir.mkdir(parents=True, exist_ok=True) + sam2_checkpoints_dir.mkdir(exist_ok=True) + yolo_dir.mkdir(exist_ok=True) + + print(f"Created models directory structure in: {models_dir}") + return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir + +def download_file(url, destination, description="file"): + """Download a file with progress indication.""" + try: + print(f"Downloading {description}...") + print(f" URL: {url}") + print(f" Destination: {destination}") + + def progress_hook(block_num, block_size, total_size): + if total_size > 0: + percent = min(100, (block_num * block_size * 100) // total_size) + sys.stdout.write(f"\r Progress: {percent}%") + sys.stdout.flush() + + urllib.request.urlretrieve(url, destination, progress_hook) + print(f"\n ✓ Downloaded {description}") + return True + + except urllib.error.URLError as e: + print(f"\n ✗ Failed to download {description}: {e}") + return False + except Exception as e: + print(f"\n ✗ Error downloading {description}: {e}") + return False + +def download_sam2_models(): + """Download SAM2.1 model configurations and checkpoints.""" + print("Setting up SAM2.1 models...") + + # Create directory structure + models_dir, configs_dir, checkpoints_dir, yolo_dir = create_directory_structure() + + # SAM2.1 model definitions + sam2_models = { + "tiny": { + "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml", + "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "config_file": "sam2.1_hiera_t.yaml", + "checkpoint_file": "sam2.1_hiera_tiny.pt" + }, + "small": { + "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml", + "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "config_file": "sam2.1_hiera_s.yaml", + "checkpoint_file": "sam2.1_hiera_small.pt" + }, + "base_plus": { + "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml", + "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "config_file": "sam2.1_hiera_b+.yaml", + "checkpoint_file": "sam2.1_hiera_base_plus.pt" + }, + "large": { + "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml", + "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + "config_file": "sam2.1_hiera_l.yaml", + "checkpoint_file": "sam2.1_hiera_large.pt" + } + } + + success_count = 0 + total_downloads = len(sam2_models) * 2 # configs + checkpoints + + # Download each model's config and checkpoint + for model_name, model_info in sam2_models.items(): + print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---") + + # Download config file + config_path = configs_dir / model_info["config_file"] + if not 
config_path.exists(): + if download_file( + model_info["config_url"], + config_path, + f"SAM2.1 {model_name} config" + ): + success_count += 1 + else: + print(f" ✓ Config file already exists: {config_path}") + success_count += 1 + + # Download checkpoint file + checkpoint_path = checkpoints_dir / model_info["checkpoint_file"] + if not checkpoint_path.exists(): + if download_file( + model_info["checkpoint_url"], + checkpoint_path, + f"SAM2.1 {model_name} checkpoint" + ): + success_count += 1 + else: + print(f" ✓ Checkpoint file already exists: {checkpoint_path}") + success_count += 1 + + print(f"\n=== Download Summary ===") + print(f"Successfully downloaded: {success_count}/{total_downloads} files") + + if success_count == total_downloads: + print("✓ All SAM2.1 models downloaded successfully!") + return True + else: + print(f"⚠ Some downloads failed ({total_downloads - success_count} files)") + return False + +def download_yolo_models(): + """Download default YOLO models to models directory.""" + print("\n--- Setting up YOLO models ---") + + try: + from ultralytics import YOLO + import torch + + # Default YOLO models to download + yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"] + models_dir = Path(__file__).parent / "models" / "yolo" + + for model_name in yolo_models: + model_path = models_dir / model_name + if not model_path.exists(): + print(f"Downloading {model_name}...") + try: + # First try to download using the YOLO class with export + model = YOLO(model_name) + + # Export/save the model to our directory + # The model.ckpt is the internal checkpoint + if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'): + # Save the checkpoint directly + torch.save(model.ckpt, str(model_path)) + print(f" ✓ Saved {model_name} to models directory") + else: + # Alternative: try to find where YOLO downloaded the model + import shutil + + # Common locations where YOLO might store models + possible_paths = [ + Path.home() / ".cache" / "ultralytics" / "models" / model_name, + Path.home() / ".ultralytics" / "models" / model_name, + Path.home() / "runs" / "detect" / model_name, + Path.cwd() / model_name, # Current directory + ] + + found = False + for possible_path in possible_paths: + if possible_path.exists(): + shutil.copy2(possible_path, model_path) + print(f" ✓ Copied {model_name} from {possible_path}") + found = True + # Clean up if it was downloaded to current directory + if possible_path.parent == Path.cwd() and possible_path != model_path: + possible_path.unlink() + break + + if not found: + # Last resort: use urllib to download directly + yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}" + print(f" Downloading directly from {yolo_url}...") + download_file(yolo_url, str(model_path), f"YOLO {model_name}") + + except Exception as e: + print(f" ⚠ Error downloading {model_name}: {e}") + # Try direct download as fallback + try: + yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}" + print(f" Trying direct download from {yolo_url}...") + download_file(yolo_url, str(model_path), f"YOLO {model_name}") + except Exception as e2: + print(f" ✗ Failed to download {model_name}: {e2}") + else: + print(f" ✓ {model_name} already exists") + + # Verify all models exist + success = all((models_dir / model).exists() for model in yolo_models) + if success: + print("✓ YOLO models setup complete!") + else: + print("⚠ Some YOLO models may be missing") + return success + + except ImportError: + print("⚠ ultralytics not installed. 
YOLO models will be downloaded on first use.") + return False + except Exception as e: + print(f"⚠ Error setting up YOLO models: {e}") + return False + +def update_config_file(): + """Update config.yaml to use local model paths.""" + print("\n--- Updating config.yaml ---") + + config_path = Path(__file__).parent / "config.yaml" + if not config_path.exists(): + print("⚠ config.yaml not found, skipping update") + return False + + try: + # Read current config + with open(config_path, 'r') as f: + content = f.read() + + # Update model paths to use local models + updated_content = content.replace( + 'yolo_model: "yolov8n.pt"', + 'yolo_model: "models/yolo/yolov8n.pt"' + ).replace( + 'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"', + 'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"' + ).replace( + 'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"', + 'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"' + ) + + # Write updated config + with open(config_path, 'w') as f: + f.write(updated_content) + + print("✓ Updated config.yaml to use local model paths") + return True + + except Exception as e: + print(f"⚠ Error updating config.yaml: {e}") + return False + +def main(): + """Main function to download all models.""" + print("🤖 YOLO + SAM2 Model Download Script") + print("="*50) + + # Download SAM2 models + sam2_success = download_sam2_models() + + # Download YOLO models + yolo_success = download_yolo_models() + + # Update config file + config_success = update_config_file() + + print("\n" + "="*50) + print("📋 Final Summary:") + print(f" SAM2 models: {'✓' if sam2_success else '⚠'}") + print(f" YOLO models: {'✓' if yolo_success else '⚠'}") + print(f" Config update: {'✓' if config_success else '⚠'}") + + if sam2_success and config_success: + print("\n🎉 Setup complete! You can now run the pipeline with:") + print(" python main.py --config config.yaml") + else: + print("\n⚠ Some steps failed. Check the output above for details.") + + print("\n📁 Models are organized in:") + print(f" {Path(__file__).parent / 'models'}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/spec.md b/spec.md index b64792e..4f6788e 100644 --- a/spec.md +++ b/spec.md @@ -189,4 +189,622 @@ models: ### Model Improvements - **Fine-tuned YOLO**: Domain-specific human detection models - **SAM2 Optimization**: Custom SAM2 checkpoints for video content -- **Temporal Consistency**: Enhanced cross-segment mask propagation \ No newline at end of file +- **Temporal Consistency**: Enhanced cross-segment mask propagation + + +Here is the original monolithic script this repo is a refactor/modularization of. 
If something +doesn't work in this repo, then consult the following script becasue it does work so this can +be used to solve problems: + + +import os +import cv2 +import numpy as np +import cupy as cp +from concurrent.futures import ThreadPoolExecutor +import torch +import logging +import sys +import gc +from sam2.build_sam import build_sam2_video_predictor +import argparse +from ultralytics import YOLO + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +# Variables for input and output directories +SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt" +MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml" +GREEN = [0, 255, 0] +BLUE = [255, 0, 0] + +INFERENCE_SCALE = 0.50 +FULL_SCALE = 1.0 + +# YOLO model for human detection (class 0 = person) +YOLO_MODEL_PATH = "yolov8n.pt" # You can change this to a custom model +YOLO_CONFIDENCE = 0.6 +HUMAN_CLASS_ID = 0 # COCO class ID for person + +def open_video(video_path): + """ + Opens a video file and returns a generator that yields frames. + + Parameters: + - video_path: Path to the video file. + + Returns: + - A generator that yields frames from the video. + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + print(f"Error: Could not open video file {video_path}") + return + while True: + ret, frame = cap.read() + if not ret: + break + yield frame + cap.release() + +def load_previous_segment_mask(prev_segment_dir): + mask_path = os.path.join(prev_segment_dir, "mask.png") + mask_image = cv2.imread(mask_path) + + if mask_image is None: + raise FileNotFoundError(f"Mask image not found at {mask_path}") + + # Ensure the mask_image has three color channels + if len(mask_image.shape) != 3 or mask_image.shape[2] != 3: + raise ValueError("Mask image does not have three color channels.") + + mask_image = mask_image.astype(np.uint8) + + # Extract Object A and Object B masks + mask_a = np.all(mask_image == GREEN, axis=2) + mask_b = np.all(mask_image == BLUE, axis=2) + + per_obj_input_mask = {1: mask_a, 2: mask_b} + input_palette = None # No palette needed for binary mask + + return per_obj_input_mask, input_palette + + +def apply_green_mask(frame, masks): + # Convert frame and masks to CuPy arrays + frame_gpu = cp.asarray(frame) + combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_) + + for mask in masks: + mask_gpu = cp.asarray(mask.squeeze()) + if mask_gpu.shape != frame_gpu.shape[:2]: + resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32), + (frame_gpu.shape[1], frame_gpu.shape[0])) + mask_gpu = cp.asarray(resized_mask > 0.5) # Convert back to CuPy boolean array + else: + mask_gpu = mask_gpu.astype(cp.bool_) # Ensure boolean type + combined_mask |= mask_gpu # Perform the bitwise OR operation + + green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8) + result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background) + return cp.asnumpy(result_frame) # Convert back to NumPy + + +def initialize_predictor(): + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + print( + "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might " + "give numerically different outputs and sometimes degraded performance on MPS." 
+ ) + # Enable MPS fallback for operations not supported on MPS + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + else: + device = torch.device("cpu") + logger.info(f"Using device: {device}") + predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device) + return predictor + + +def load_first_frame(video_path, scale=1.0): + """ + Opens a video file and returns the first frame, scaled as specified. + + Parameters: + - video_path: Path to the video file. + - scale: Scaling factor for the frame (default is 1.0 for original size). + + Returns: + - first_frame: The first frame of the video, scaled accordingly. + """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + logger.error(f"Error: Could not open video file {video_path}") + return None + + ret, frame = cap.read() + cap.release() + + if not ret: + logger.error(f"Error: Could not read frame from video file {video_path}") + return None + + if scale != 1.0: + frame = cv2.resize( + frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR + ) + + return frame + +def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE): + """ + Detect humans in a frame using YOLO model. + + Parameters: + - frame: Input frame (BGR format) + - yolo_model: Loaded YOLO model + - confidence_threshold: Detection confidence threshold + + Returns: + - human_boxes: List of bounding boxes for detected humans + """ + # Run YOLO detection + results = yolo_model(frame, conf=confidence_threshold, verbose=False) + + human_boxes = [] + + # Process results + for result in results: + boxes = result.boxes + if boxes is not None: + for box in boxes: + # Get class ID + cls = int(box.cls.cpu().numpy()[0]) + + # Check if it's a person (class 0 in COCO) + if cls == HUMAN_CLASS_ID: + # Get bounding box coordinates (x1, y1, x2, y2) + coords = box.xyxy[0].cpu().numpy() + conf = float(box.conf.cpu().numpy()[0]) + + human_boxes.append({ + 'bbox': coords, + 'confidence': conf + }) + + logger.info(f"Detected human with confidence {conf:.2f} at {coords}") + + return human_boxes + +def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width): + """ + Add YOLO human detections as bounding boxes to SAM2 predictor. + For stereo videos, creates two objects (left and right humans). 
+ + Parameters: + - predictor: SAM2 video predictor + - inference_state: SAM2 inference state + - human_detections: List of human detection results + - frame_width: Width of the frame for stereo splitting + + Returns: + - out_mask_logits: SAM2 output mask logits + """ + half_frame_width = frame_width // 2 + + # Sort detections by x-coordinate to get left and right humans + human_detections.sort(key=lambda x: x['bbox'][0]) # Sort by x1 coordinate + + obj_id = 1 + out_mask_logits = None + + for i, detection in enumerate(human_detections[:2]): # Take up to 2 humans (left and right) + bbox = detection['bbox'] + + # For stereo videos, assign obj_id based on position + if len(human_detections) >= 2: + # If we have multiple humans, assign based on left/right position + center_x = (bbox[0] + bbox[2]) / 2 + if center_x < half_frame_width: + current_obj_id = 1 # Left human + else: + current_obj_id = 2 # Right human + else: + # If only one human, duplicate for both sides (as in original stereo logic) + current_obj_id = obj_id + obj_id += 1 + + # Also add the mirrored version for stereo + if obj_id <= 2: + mirrored_bbox = bbox.copy() + mirrored_bbox[0] += half_frame_width # Shift x1 + mirrored_bbox[2] += half_frame_width # Shift x2 + + # Ensure mirrored bbox is within frame bounds + mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1)) + mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1)) + + try: + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=0, + obj_id=obj_id, + box=mirrored_bbox.astype(np.float32), + ) + logger.info(f"Added mirrored human detection for Object {obj_id}") + obj_id += 1 + except Exception as e: + logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}") + + try: + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=0, + obj_id=current_obj_id, + box=bbox.astype(np.float32), + ) + logger.info(f"Added human detection for Object {current_obj_id}") + except Exception as e: + logger.error(f"Error adding human detection for Object {current_obj_id}: {e}") + + return out_mask_logits + +def propagate_masks(predictor, inference_state): + video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() + for i, out_obj_id in enumerate(out_obj_ids) + } + return video_segments + +def apply_colored_mask(frame, masks_a, masks_b): + colored_mask = np.zeros_like(frame) + + # Apply colors to the masks + for mask in masks_a: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + indices = np.where(mask) + colored_mask[mask] = [0, 255, 0] # Green for Object A + + for mask in masks_b: + mask = mask.squeeze() + if mask.shape != frame.shape[:2]: + mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + indices = np.where(mask) + colored_mask[mask] = [255, 0, 0] # Blue for Object B + + return colored_mask + + +def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False): + """ + Process high-resolution frames, apply upscaled masks, and save the output video. 
+ """ + cap = cv2.VideoCapture(video_path) + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) or 59.94 + + # Setup VideoWriter with desired settings + if use_nvenc: + # Use FFmpeg with NVENC offloading for H.265 encoding + import subprocess + + if sys.platform == 'darwin': + encoder = 'hevc_videotoolbox' + else: + encoder = 'hevc_nvenc' + + command = [ + 'ffmpeg', + '-y', # Overwrite output file if it exists + '-f', 'rawvideo', + '-vcodec', 'rawvideo', + '-pix_fmt', 'bgr24', + '-s', f'{frame_width}x{frame_height}', + '-r', str(fps), + '-i', '-', # Input from stdin + '-an', # No audio + '-vcodec', encoder, + '-pix_fmt', 'nv12', + '-preset', 'slow', + '-b:v', '50M', + output_video_path + ] + process = subprocess.Popen(command, stdin=subprocess.PIPE) + else: + # Use OpenCV VideoWriter + fourcc = cv2.VideoWriter_fourcc(*'HEVC') # H.265 + out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height)) + + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret or frame_idx >= len(video_segments): + break + + masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]] + upscaled_masks = [] + + for mask in masks: + mask = mask.squeeze() + upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) + upscaled_masks.append(upscaled_mask) + + result_frame = apply_green_mask(frame, upscaled_masks) + + # Write frame to output + if use_nvenc: + process.stdin.write(result_frame.tobytes()) + else: + out.write(result_frame) + + frame_idx += 1 + + cap.release() + if use_nvenc: + process.stdin.close() + process.wait() + else: + out.release() + +def get_video_file_name(index): + return f"segment_{str(index).zfill(3)}.mp4" + +def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH): + """ + Run YOLO detection on specified segments and save detection results. + """ + logger.info("Running YOLO detection on requested segments.") + + # Load YOLO model + yolo_model = YOLO(yolo_model_path) + + for i, segment in enumerate(segments): + segment_index = int(segment.split("_")[1]) + segment_dir = os.path.join(base_dir, segment) + detection_file = os.path.join(segment_dir, "yolo_detections") + video_file = os.path.join(segment_dir, get_video_file_name(i)) + + if segment_index in detect_segments and not os.path.exists(detection_file): + first_frame = load_first_frame(video_file, scale) + if first_frame is None: + continue + + # Convert BGR to RGB for YOLO (YOLO expects BGR, so keep as BGR) + human_detections = detect_humans_with_yolo(first_frame, yolo_model) + + if human_detections: + # Save detection results + with open(detection_file, 'w') as f: + f.write("# YOLO Human Detections\n") + for detection in human_detections: + bbox = detection['bbox'] + conf = detection['confidence'] + f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n") + logger.info(f"Saved {len(human_detections)} human detections for segment {segment}") + else: + logger.warning(f"No humans detected in segment {segment}") + # Create empty file to mark as processed + with open(detection_file, 'w') as f: + f.write("# No humans detected\n") + +def save_final_masks(video_segments, mask_output_path): + """ + Save the final masks as a colored image. 
+ """ + last_frame_idx = max(video_segments.keys()) + masks_dict = video_segments[last_frame_idx] + # Assuming you have two objects with IDs 1 and 2 + mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None + mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None + + if mask_a is None and mask_b is None: + logger.error("No masks found for objects.") + return + + # Use the first available mask to determine dimensions + reference_mask = mask_a if mask_a is not None else mask_b + black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8) + + if mask_a is not None: + mask_a = mask_a.astype(bool) + black_frame[mask_a] = GREEN + + if mask_b is not None: + mask_b = mask_b.astype(bool) + black_frame[mask_b] = BLUE + + # Save the mask image + cv2.imwrite(mask_output_path, black_frame) + logger.info(f"Saved final masks to {mask_output_path}") + +def create_low_res_video(input_video_path, output_video_path, scale): + """ + Creates a low-resolution version of the input video for inference. + """ + cap = cv2.VideoCapture(input_video_path) + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale) + fps = cap.get(cv2.CAP_PROP_FPS) or 59.94 + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height)) + + while True: + ret, frame = cap.read() + if not ret: + break + low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR) + out.write(low_res_frame) + + cap.release() + out.release() + + +def main(): + parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.") + parser.add_argument("--base-dir", type=str, help="Base directory for video segments.") + parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.") + parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.") + parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.") + args = parser.parse_args() + + base_dir = args.base_dir + segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")] + segments.sort(key=lambda x: int(x.split("_")[1])) + + # Handle different ways to specify segments for YOLO detection + if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0: + # Default: run YOLO on all segments + detect_segments = [int(seg.split("_")[1]) for seg in segments] + logger.info("No segments specified, running YOLO detection on ALL segments") + elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all': + # Explicit 'all' keyword + detect_segments = [int(seg.split("_")[1]) for seg in segments] + logger.info("Running YOLO detection on ALL segments") + else: + # Specific segment numbers provided + try: + detect_segments = [int(x) for x in args.segments_detect_humans] + logger.info(f"Running YOLO detection on segments: {detect_segments}") + except ValueError: + logger.error("Invalid segment numbers provided. 
Use integers or 'all'.") + return + + # Run YOLO detection on specified segments + do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model) + + # Load YOLO model for inference + yolo_model = YOLO(args.yolo_model) + + for i, segment in enumerate(segments): + segment_index = int(segment.split("_")[1]) + segment_dir = os.path.join(base_dir, segment) + video_file_name = get_video_file_name(i) + video_path = os.path.join(segment_dir, video_file_name) + output_done_file = os.path.join(segment_dir, "output_frames_done") + + if os.path.exists(output_done_file): + logger.info(f"Segment {segment} already processed. Skipping.") + continue + + logger.info(f"Processing segment {segment}") + + # Initialize predictor + predictor = initialize_predictor() + + # Prepare low-resolution video frames for inference + low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4") + if not os.path.exists(low_res_video_path): + create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE) + logger.info(f"Low-resolution video created for segment {segment}") + else: + logger.info(f"Low-resolution video already exists for segment {segment}, reuse") + + # Initialize inference state with low-resolution video + inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True) + + # Load YOLO detections or previous masks + detection_file = os.path.join(segment_dir, "yolo_detections") + use_detections = segment_index in detect_segments + + if i == 0 and not use_detections: + # First segment must use YOLO detection since there's no previous mask + logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.") + use_detections = True + + if i > 0 and not use_detections: + # Try to load previous segment mask - search backwards for the most recent successful mask + logger.info(f"Using previous segment mask for segment {segment}") + mask_found = False + + # Search backwards through previous segments to find a valid mask + for j in range(i - 1, -1, -1): + prev_segment_dir = os.path.join(base_dir, segments[j]) + prev_mask_path = os.path.join(prev_segment_dir, "mask.png") + + if os.path.exists(prev_mask_path): + try: + per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir) + # Add previous masks to predictor + for obj_id, mask in per_obj_input_mask.items(): + predictor.add_new_mask(inference_state, 0, obj_id, mask) + logger.info(f"Successfully loaded mask from segment {segments[j]}") + mask_found = True + break + except Exception as e: + logger.warning(f"Error loading mask from {segments[j]}: {e}") + continue + + if not mask_found: + logger.error(f"No valid previous mask found for segment {segment}. 
Consider running YOLO detection on this segment.") + continue + else: + # Load first frame for detection + first_frame = load_first_frame(low_res_video_path, scale=1.0) + if first_frame is None: + logger.error(f"Could not load first frame for segment {segment}") + continue + + # Run YOLO detection on first frame (either from file or on-the-fly) + if os.path.exists(detection_file): + logger.info(f"Using existing YOLO detections for segment {segment}") + else: + logger.info(f"Running YOLO detection on-the-fly for segment {segment}") + + human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence) + + if human_detections: + # Add YOLO detections to predictor + frame_width = first_frame.shape[1] + add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width) + else: + logger.warning(f"No humans detected in segment {segment}") + continue + + # Perform inference and collect masks per frame + video_segments = propagate_masks(predictor, inference_state) + + # Process high-resolution frames and save output video + output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4") + logger.info("Processing segment complete, attempting to save full video from low res masks") + process_and_save_output_video( + video_path, + output_video_path, + video_segments, + use_nvenc=True # Set to True to use NVENC offloading + ) + + # Save final masks + mask_output_path = os.path.join(segment_dir, "mask.png") + save_final_masks(video_segments, mask_output_path) + + # Clean up + predictor.reset_state(inference_state) + del inference_state + del video_segments + del predictor + gc.collect() + + try: + os.remove(low_res_video_path) + logger.info(f"Deleted low-resolution video for segment {segment}") + except Exception as e: + logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}") + + # Mark segment as completed + open(output_done_file, 'a').close() + + logger.info("Processing complete.") + +if __name__ == "__main__": + main()
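+
+A minimal invocation sketch for the reference script above, assuming it is saved as
+`original_pipeline.py` (the filename is illustrative; the script is pasted here unnamed) and
+that the base directory contains the usual `segment_XXX/` folders:
+
+```bash
+# Run YOLO human detection on all segments, then SAM2 mask propagation per segment
+python original_pipeline.py \
+    --base-dir /path/to/segments \
+    --segments-detect-humans all \
+    --yolo-model yolov8n.pt \
+    --yolo-confidence 0.6
+```
+
+Note that the SAM2 checkpoint and config paths are hardcoded at the top of the script
+(`SAM2_CHECKPOINT`, `MODEL_CFG`) rather than exposed as command-line flags.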