stage 1 working
README.md (58 lines changed)
@@ -32,19 +32,40 @@ git clone <repository-url>
 cd samyolo_on_segments

 # Install Python dependencies
-pip install -r requirements.txt
+uv venv && source .venv/bin/activate
+uv pip install -r requirements.txt
 ```

-### Model Dependencies
+### Download Models

-You'll need to download the required model checkpoints:
+Use the provided script to automatically download all required models:
+
+```bash
+# Download SAM2.1 and YOLO models
+python download_models.py
+```
+
+This script will:
+- Create a `models/` directory structure
+- Download SAM2.1 configs and checkpoints (tiny, small, base+, large)
+- Download common YOLO models (yolov8n, yolov8s, yolov8m)
+- Update `config.yaml` to use local model paths
+
+**Manual Download (Alternative):**
 1. **SAM2 Models**: Download from [Meta's SAM2 repository](https://github.com/facebookresearch/sam2)
-2. **YOLO Models**: YOLOv8 models will be downloaded automatically or you can specify a custom path
+2. **YOLO Models**: YOLOv8 models will be downloaded automatically on first use
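If you prefer the manual route, the files can be fetched directly. A minimal sketch for the "large" variant, using the same URLs that `download_models.py` (shown further down in this commit) uses, and mirroring the `models/` layout the script creates:

```python
# Hedged sketch: manually fetch the SAM2.1 "large" config and checkpoint.
# URLs and destination paths are taken from download_models.py in this commit.
import urllib.request
from pathlib import Path

files = {
    "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml":
        "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
    "models/sam2/checkpoints/sam2.1_hiera_large.pt":
        "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
}
for dest, url in files.items():
    path = Path(dest)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():  # skip files that are already downloaded
        urllib.request.urlretrieve(url, str(path))
```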

 ## Quick Start

-### 1. Configure the Pipeline
+### 1. Download Models
+
+First, download the required SAM2.1 and YOLO models:
+
+```bash
+python download_models.py
+```
+
+### 2. Configure the Pipeline

 Edit `config.yaml` to specify your input video and desired settings:

@@ -63,18 +84,18 @@ processing:
   detect_segments: "all"

 models:
-  yolo_model: "yolov8n.pt"
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  yolo_model: "models/yolo/yolov8n.pt"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
 ```

-### 2. Run the Pipeline
+### 3. Run the Pipeline

 ```bash
 python main.py --config config.yaml
 ```

-### 3. Monitor Progress
+### 4. Monitor Progress

 Check processing status:
 ```bash
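For orientation, here is a minimal sketch of how the `models:` section above can be consumed. The function name `load_config` and its behavior are illustrative assumptions, not necessarily what the repo's `core/config_loader.py` does; it assumes PyYAML is installed:

```python
# Hedged sketch: read config.yaml and pull out the model paths shown above.
import yaml

def load_config(path: str = "config.yaml") -> dict:
    with open(path) as f:
        return yaml.safe_load(f)

config = load_config()
models = config["models"]
print(models["yolo_model"])       # e.g. models/yolo/yolov8n.pt
print(models["sam2_checkpoint"])  # e.g. models/sam2/checkpoints/sam2.1_hiera_large.pt
print(models["sam2_config"])      # e.g. models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
```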
@@ -166,8 +187,25 @@ samyolo_on_segments/
 ├── README.md                      # This documentation
 ├── config.yaml                    # Default configuration
 ├── main.py                        # Main entry point
+├── download_models.py             # Model download script
 ├── requirements.txt               # Python dependencies
 ├── spec.md                        # Detailed specification
+├── models/                        # Downloaded models (created by script)
+│   ├── sam2/
+│   │   ├── configs/sam2.1/        # SAM2.1 configuration files
+│   │   │   ├── sam2.1_hiera_t.yaml
+│   │   │   ├── sam2.1_hiera_s.yaml
+│   │   │   ├── sam2.1_hiera_b+.yaml
+│   │   │   └── sam2.1_hiera_l.yaml
+│   │   └── checkpoints/           # SAM2.1 model weights
+│   │       ├── sam2.1_hiera_tiny.pt
+│   │       ├── sam2.1_hiera_small.pt
+│   │       ├── sam2.1_hiera_base_plus.pt
+│   │       └── sam2.1_hiera_large.pt
+│   └── yolo/                      # YOLO model weights
+│       ├── yolov8n.pt
+│       ├── yolov8s.pt
+│       └── yolov8m.pt
 ├── core/                          # Core processing modules
 │   ├── __init__.py
 │   ├── config_loader.py           # Configuration management

config.yaml
@@ -23,11 +23,11 @@ processing:

 models:
   # YOLO model path - can be pretrained (yolov8n.pt) or custom path
-  yolo_model: "yolov8n.pt"
+  yolo_model: "models/yolo/yolov8n.pt"

   # SAM2 model configuration
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"

 video:
   # Use NVIDIA hardware encoding (requires NVENC-capable GPU)

@@ -7,7 +7,7 @@ import os
 import subprocess
 import logging
 from typing import List, Tuple
-from ..utils.file_utils import ensure_directory, get_video_file_name
+from utils.file_utils import ensure_directory, get_video_file_name

 logger = logging.getLogger(__name__)

@@ -13,17 +13,17 @@ from ultralytics import YOLO
 logger = logging.getLogger(__name__)

 class YOLODetector:
-    \"\"\"Handles YOLO-based human detection for video segments.\"\"\"
+    """Handles YOLO-based human detection for video segments."""

     def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
-        \"\"\"
+        """
         Initialize YOLO detector.

         Args:
             model_path: Path to YOLO model weights
             confidence_threshold: Detection confidence threshold
             human_class_id: COCO class ID for humans (0 = person)
-        \"\"\"
+        """
         self.model_path = model_path
         self.confidence_threshold = confidence_threshold
         self.human_class_id = human_class_id
@@ -31,13 +31,13 @@ class YOLODetector:
         # Load YOLO model
         try:
             self.model = YOLO(model_path)
-            logger.info(f\"Loaded YOLO model from {model_path}\")
+            logger.info(f"Loaded YOLO model from {model_path}")
         except Exception as e:
-            logger.error(f\"Failed to load YOLO model: {e}\")
+            logger.error(f"Failed to load YOLO model: {e}")
             raise

     def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
-        \"\"\"
+        """
         Detect humans in a single frame using YOLO.

         Args:
@@ -45,7 +45,7 @@ class YOLODetector:
         Returns:
             List of human detection dictionaries with bbox and confidence
-        \"\"\"
+        """
         # Run YOLO detection
         results = self.model(frame, conf=self.confidence_threshold, verbose=False)

@@ -70,12 +70,12 @@ class YOLODetector:
                     'confidence': conf
                 })

-                logger.debug(f\"Detected human with confidence {conf:.2f} at {coords}\")
+                logger.debug(f"Detected human with confidence {conf:.2f} at {coords}")

         return human_detections

     def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
-        \"\"\"
+        """
         Detect humans in the first frame of a video.

         Args:
@@ -84,21 +84,21 @@ class YOLODetector:
         Returns:
             List of human detection dictionaries
-        \"\"\"
+        """
         if not os.path.exists(video_path):
-            logger.error(f\"Video file not found: {video_path}\")
+            logger.error(f"Video file not found: {video_path}")
             return []

         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
-            logger.error(f\"Could not open video: {video_path}\")
+            logger.error(f"Could not open video: {video_path}")
             return []

         ret, frame = cap.read()
         cap.release()

         if not ret:
-            logger.error(f\"Could not read first frame from: {video_path}\")
+            logger.error(f"Could not read first frame from: {video_path}")
             return []

         # Scale frame if needed
@@ -108,7 +108,7 @@ class YOLODetector:
         return self.detect_humans_in_frame(frame)

     def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
-        \"\"\"
+        """
         Save detection results to file.

         Args:
@@ -117,26 +117,26 @@ class YOLODetector:
         Returns:
             True if saved successfully
-        \"\"\"
+        """
         try:
             with open(output_path, 'w') as f:
-                f.write(\"# YOLO Human Detections\\n\")
+                f.write("# YOLO Human Detections\n")
                 if detections:
                     for detection in detections:
                         bbox = detection['bbox']
                         conf = detection['confidence']
-                        f.write(f\"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n\")
-                    logger.info(f\"Saved {len(detections)} detections to {output_path}\")
+                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
+                    logger.info(f"Saved {len(detections)} detections to {output_path}")
                 else:
-                    f.write(\"# No humans detected\\n\")
-                    logger.info(f\"Saved empty detection file to {output_path}\")
+                    f.write("# No humans detected\n")
+                    logger.info(f"Saved empty detection file to {output_path}")
             return True
         except Exception as e:
-            logger.error(f\"Failed to save detections to {output_path}: {e}\")
+            logger.error(f"Failed to save detections to {output_path}: {e}")
             return False

     def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
-        \"\"\"
+        """
         Load detection results from file.

         Args:
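Between these hunks, for reference: a `yolo_detections` file written by `save_detections_to_file` is one comment header followed by one `x1,y1,x2,y2,confidence` line per detection. The values below are purely illustrative:

```
# YOLO Human Detections
412.3,188.0,655.9,1022.4,0.91
1403.7,195.2,1648.1,1017.8,0.88
```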
@@ -144,11 +144,11 @@ class YOLODetector:
         Returns:
             List of detection dictionaries
-        \"\"\"
+        """
         detections = []

         if not os.path.exists(file_path):
-            logger.warning(f\"Detection file not found: {file_path}\")
+            logger.warning(f"Detection file not found: {file_path}")
             return detections

         try:
@@ -170,18 +170,18 @@ class YOLODetector:
                         'confidence': conf
                     })
                 except ValueError:
-                    logger.warning(f\"Invalid detection line: {line}\")
+                    logger.warning(f"Invalid detection line: {line}")
                     continue

-            logger.info(f\"Loaded {len(detections)} detections from {file_path}\")
+            logger.info(f"Loaded {len(detections)} detections from {file_path}")
         except Exception as e:
-            logger.error(f\"Failed to load detections from {file_path}: {e}\")
+            logger.error(f"Failed to load detections from {file_path}: {e}")

         return detections

     def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
                                scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
-        \"\"\"
+        """
         Process multiple segments for human detection.

         Args:
@@ -191,7 +191,7 @@ class YOLODetector:
         Returns:
             Dictionary mapping segment index to detection results
-        \"\"\"
+        """
         results = {}

         for segment_info in segments_info:
@@ -202,17 +202,17 @@ class YOLODetector:
                 continue

             video_path = segment_info['video_file']
-            detection_file = os.path.join(segment_info['directory'], \"yolo_detections\")
+            detection_file = os.path.join(segment_info['directory'], "yolo_detections")

             # Skip if already processed
             if os.path.exists(detection_file):
-                logger.info(f\"Segment {segment_idx} already has detections, skipping\")
+                logger.info(f"Segment {segment_idx} already has detections, skipping")
                 detections = self.load_detections_from_file(detection_file)
                 results[segment_idx] = detections
                 continue

             # Run detection
-            logger.info(f\"Processing segment {segment_idx} for human detection\")
+            logger.info(f"Processing segment {segment_idx} for human detection")
             detections = self.detect_humans_in_video_first_frame(video_path, scale)

             # Save results
@@ -223,7 +223,7 @@ class YOLODetector:
     def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
                                            frame_width: int) -> List[Dict[str, Any]]:
-        \"\"\"
+        """
         Convert YOLO detections to SAM2-compatible prompts for stereo video.

         Args:
@@ -232,7 +232,7 @@ class YOLODetector:
         Returns:
             List of SAM2 prompt dictionaries with obj_id and bbox
-        \"\"\"
+        """
         if not detections:
             return []

@@ -282,5 +282,5 @@ class YOLODetector:
                 'confidence': detection['confidence']
             })

-        logger.debug(f\"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts\")
+        logger.debug(f"Converted {len(detections)} detections to {len(prompts)} SAM2 prompts")
         return prompts
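The body of `convert_detections_to_sam2_prompts` falls between the hunk boundaries above and is not shown. Based on the stereo logic in the monolithic script reproduced at the end of spec.md (left and right humans become obj_id 1 and 2, split at half the frame width), the core of the mapping plausibly looks like this sketch; it is illustrative, not the actual method body:

```python
# Hedged sketch of the detection -> SAM2-prompt mapping, assuming detection
# dicts of the shape this class produces ({'bbox': [x1,y1,x2,y2], 'confidence': c}).
from typing import Any, Dict, List

def detections_to_prompts(detections: List[Dict[str, Any]], frame_width: int) -> List[Dict[str, Any]]:
    half = frame_width // 2
    # Sort left-to-right so obj_id assignment is stable, as in the original script.
    ordered = sorted(detections, key=lambda d: d['bbox'][0])
    prompts = []
    for det in ordered[:2]:  # at most two humans: left-eye / right-eye view
        x1, y1, x2, y2 = det['bbox']
        center_x = (x1 + x2) / 2
        obj_id = 1 if center_x < half else 2  # 1 = left human, 2 = right human
        prompts.append({'obj_id': obj_id, 'bbox': [x1, y1, x2, y2],
                        'confidence': det['confidence']})
    return prompts
```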
download_models.py (new executable file, 286 lines)
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""
Model download script for YOLO + SAM2 video processing pipeline.
Downloads SAM2.1 models and organizes them in the models directory.
"""

import os
import urllib.request
import urllib.error
from pathlib import Path
import sys

def create_directory_structure():
    """Create the models directory structure."""
    base_dir = Path(__file__).parent
    models_dir = base_dir / "models"

    # Create main models directory
    models_dir.mkdir(exist_ok=True)

    # Create subdirectories
    sam2_dir = models_dir / "sam2"
    sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
    sam2_checkpoints_dir = sam2_dir / "checkpoints"
    yolo_dir = models_dir / "yolo"

    sam2_dir.mkdir(exist_ok=True)
    sam2_configs_dir.mkdir(parents=True, exist_ok=True)
    sam2_checkpoints_dir.mkdir(exist_ok=True)
    yolo_dir.mkdir(exist_ok=True)

    print(f"Created models directory structure in: {models_dir}")
    return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir

def download_file(url, destination, description="file"):
    """Download a file with progress indication."""
    try:
        print(f"Downloading {description}...")
        print(f"  URL: {url}")
        print(f"  Destination: {destination}")

        def progress_hook(block_num, block_size, total_size):
            if total_size > 0:
                percent = min(100, (block_num * block_size * 100) // total_size)
                sys.stdout.write(f"\r  Progress: {percent}%")
                sys.stdout.flush()

        urllib.request.urlretrieve(url, destination, progress_hook)
        print(f"\n  ✓ Downloaded {description}")
        return True

    except urllib.error.URLError as e:
        print(f"\n  ✗ Failed to download {description}: {e}")
        return False
    except Exception as e:
        print(f"\n  ✗ Error downloading {description}: {e}")
        return False

def download_sam2_models():
    """Download SAM2.1 model configurations and checkpoints."""
    print("Setting up SAM2.1 models...")

    # Create directory structure
    models_dir, configs_dir, checkpoints_dir, yolo_dir = create_directory_structure()

    # SAM2.1 model definitions
    sam2_models = {
        "tiny": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
            "config_file": "sam2.1_hiera_t.yaml",
            "checkpoint_file": "sam2.1_hiera_tiny.pt"
        },
        "small": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
            "config_file": "sam2.1_hiera_s.yaml",
            "checkpoint_file": "sam2.1_hiera_small.pt"
        },
        "base_plus": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
            "config_file": "sam2.1_hiera_b+.yaml",
            "checkpoint_file": "sam2.1_hiera_base_plus.pt"
        },
        "large": {
            "config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
            "checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "config_file": "sam2.1_hiera_l.yaml",
            "checkpoint_file": "sam2.1_hiera_large.pt"
        }
    }

    success_count = 0
    total_downloads = len(sam2_models) * 2  # configs + checkpoints

    # Download each model's config and checkpoint
    for model_name, model_info in sam2_models.items():
        print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---")

        # Download config file
        config_path = configs_dir / model_info["config_file"]
        if not config_path.exists():
            if download_file(
                model_info["config_url"],
                config_path,
                f"SAM2.1 {model_name} config"
            ):
                success_count += 1
        else:
            print(f"  ✓ Config file already exists: {config_path}")
            success_count += 1

        # Download checkpoint file
        checkpoint_path = checkpoints_dir / model_info["checkpoint_file"]
        if not checkpoint_path.exists():
            if download_file(
                model_info["checkpoint_url"],
                checkpoint_path,
                f"SAM2.1 {model_name} checkpoint"
            ):
                success_count += 1
        else:
            print(f"  ✓ Checkpoint file already exists: {checkpoint_path}")
            success_count += 1

    print(f"\n=== Download Summary ===")
    print(f"Successfully downloaded: {success_count}/{total_downloads} files")

    if success_count == total_downloads:
        print("✓ All SAM2.1 models downloaded successfully!")
        return True
    else:
        print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
        return False

def download_yolo_models():
    """Download default YOLO models to models directory."""
    print("\n--- Setting up YOLO models ---")

    try:
        from ultralytics import YOLO
        import torch

        # Default YOLO models to download
        yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
        models_dir = Path(__file__).parent / "models" / "yolo"

        for model_name in yolo_models:
            model_path = models_dir / model_name
            if not model_path.exists():
                print(f"Downloading {model_name}...")
                try:
                    # First try to download using the YOLO class with export
                    model = YOLO(model_name)

                    # Export/save the model to our directory
                    # The model.ckpt is the internal checkpoint
                    if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
                        # Save the checkpoint directly
                        torch.save(model.ckpt, str(model_path))
                        print(f"  ✓ Saved {model_name} to models directory")
                    else:
                        # Alternative: try to find where YOLO downloaded the model
                        import shutil

                        # Common locations where YOLO might store models
                        possible_paths = [
                            Path.home() / ".cache" / "ultralytics" / "models" / model_name,
                            Path.home() / ".ultralytics" / "models" / model_name,
                            Path.home() / "runs" / "detect" / model_name,
                            Path.cwd() / model_name,  # Current directory
                        ]

                        found = False
                        for possible_path in possible_paths:
                            if possible_path.exists():
                                shutil.copy2(possible_path, model_path)
                                print(f"  ✓ Copied {model_name} from {possible_path}")
                                found = True
                                # Clean up if it was downloaded to current directory
                                if possible_path.parent == Path.cwd() and possible_path != model_path:
                                    possible_path.unlink()
                                break

                        if not found:
                            # Last resort: use urllib to download directly
                            yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
                            print(f"  Downloading directly from {yolo_url}...")
                            download_file(yolo_url, str(model_path), f"YOLO {model_name}")

                except Exception as e:
                    print(f"  ⚠ Error downloading {model_name}: {e}")
                    # Try direct download as fallback
                    try:
                        yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
                        print(f"  Trying direct download from {yolo_url}...")
                        download_file(yolo_url, str(model_path), f"YOLO {model_name}")
                    except Exception as e2:
                        print(f"  ✗ Failed to download {model_name}: {e2}")
            else:
                print(f"  ✓ {model_name} already exists")

        # Verify all models exist
        success = all((models_dir / model).exists() for model in yolo_models)
        if success:
            print("✓ YOLO models setup complete!")
        else:
            print("⚠ Some YOLO models may be missing")
        return success

    except ImportError:
        print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
        return False
    except Exception as e:
        print(f"⚠ Error setting up YOLO models: {e}")
        return False

def update_config_file():
    """Update config.yaml to use local model paths."""
    print("\n--- Updating config.yaml ---")

    config_path = Path(__file__).parent / "config.yaml"
    if not config_path.exists():
        print("⚠ config.yaml not found, skipping update")
        return False

    try:
        # Read current config
        with open(config_path, 'r') as f:
            content = f.read()

        # Update model paths to use local models
        updated_content = content.replace(
            'yolo_model: "yolov8n.pt"',
            'yolo_model: "models/yolo/yolov8n.pt"'
        ).replace(
            'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
            'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'
        ).replace(
            'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
            'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"'
        )

        # Write updated config
        with open(config_path, 'w') as f:
            f.write(updated_content)

        print("✓ Updated config.yaml to use local model paths")
        return True

    except Exception as e:
        print(f"⚠ Error updating config.yaml: {e}")
        return False

def main():
    """Main function to download all models."""
    print("🤖 YOLO + SAM2 Model Download Script")
    print("="*50)

    # Download SAM2 models
    sam2_success = download_sam2_models()

    # Download YOLO models
    yolo_success = download_yolo_models()

    # Update config file
    config_success = update_config_file()

    print("\n" + "="*50)
    print("📋 Final Summary:")
    print(f"  SAM2 models: {'✓' if sam2_success else '⚠'}")
    print(f"  YOLO models: {'✓' if yolo_success else '⚠'}")
    print(f"  Config update: {'✓' if config_success else '⚠'}")

    if sam2_success and config_success:
        print("\n🎉 Setup complete! You can now run the pipeline with:")
        print("   python main.py --config config.yaml")
    else:
        print("\n⚠ Some steps failed. Check the output above for details.")

    print("\n📁 Models are organized in:")
    print(f"   {Path(__file__).parent / 'models'}")

if __name__ == "__main__":
    main()
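As a quick sanity check after running the script, the large checkpoint can be loaded the way the monolithic script in spec.md builds its predictor. A minimal sketch; note that whether the relocated config path resolves as-is depends on how SAM2's Hydra config search path is wired up in this repo:

```python
# Hedged sketch: verify the SAM2.1 "large" download by building the predictor,
# mirroring initialize_predictor() from the script reproduced in spec.md.
import torch
from sam2.build_sam import build_sam2_video_predictor

checkpoint = "models/sam2/checkpoints/sam2.1_hiera_large.pt"
config = "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
print("SAM2.1 predictor built OK on", device)
```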
spec.md (618 lines changed)
@@ -190,3 +190,621 @@ models:
 - **Fine-tuned YOLO**: Domain-specific human detection models
 - **SAM2 Optimization**: Custom SAM2 checkpoints for video content
 - **Temporal Consistency**: Enhanced cross-segment mask propagation


Here is the original monolithic script that this repo is a refactor/modularization of. If something doesn't work in this repo, consult the following script: it is known to work, so it can be used to track down problems:

import os
import cv2
import numpy as np
import cupy as cp
from concurrent.futures import ThreadPoolExecutor
import torch
import logging
import sys
import gc
from sam2.build_sam import build_sam2_video_predictor
import argparse
from ultralytics import YOLO

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]

INFERENCE_SCALE = 0.50
FULL_SCALE = 1.0

# YOLO model for human detection (class 0 = person)
YOLO_MODEL_PATH = "yolov8n.pt"  # You can change this to a custom model
YOLO_CONFIDENCE = 0.6
HUMAN_CLASS_ID = 0  # COCO class ID for person

def open_video(video_path):
    """
    Opens a video file and returns a generator that yields frames.

    Parameters:
    - video_path: Path to the video file.

    Returns:
    - A generator that yields frames from the video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        return
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        yield frame
    cap.release()

def load_previous_segment_mask(prev_segment_dir):
    mask_path = os.path.join(prev_segment_dir, "mask.png")
    mask_image = cv2.imread(mask_path)

    if mask_image is None:
        raise FileNotFoundError(f"Mask image not found at {mask_path}")

    # Ensure the mask_image has three color channels
    if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
        raise ValueError("Mask image does not have three color channels.")

    mask_image = mask_image.astype(np.uint8)

    # Extract Object A and Object B masks
    mask_a = np.all(mask_image == GREEN, axis=2)
    mask_b = np.all(mask_image == BLUE, axis=2)

    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # No palette needed for binary mask

    return per_obj_input_mask, input_palette

def apply_green_mask(frame, masks):
    # Convert frame and masks to CuPy arrays
    frame_gpu = cp.asarray(frame)
    combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_)

    for mask in masks:
        mask_gpu = cp.asarray(mask.squeeze())
        if mask_gpu.shape != frame_gpu.shape[:2]:
            resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32),
                                      (frame_gpu.shape[1], frame_gpu.shape[0]))
            mask_gpu = cp.asarray(resized_mask > 0.5)  # Convert back to CuPy boolean array
        else:
            mask_gpu = mask_gpu.astype(cp.bool_)  # Ensure boolean type
        combined_mask |= mask_gpu  # Perform the bitwise OR operation

    green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8)
    result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background)
    return cp.asnumpy(result_frame)  # Convert back to NumPy


def initialize_predictor():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print(
            "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
            "give numerically different outputs and sometimes degraded performance on MPS."
        )
        # Enable MPS fallback for operations not supported on MPS
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    return predictor

def load_first_frame(video_path, scale=1.0):
    """
    Opens a video file and returns the first frame, scaled as specified.

    Parameters:
    - video_path: Path to the video file.
    - scale: Scaling factor for the frame (default is 1.0 for original size).

    Returns:
    - first_frame: The first frame of the video, scaled accordingly.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Error: Could not open video file {video_path}")
        return None

    ret, frame = cap.read()
    cap.release()

    if not ret:
        logger.error(f"Error: Could not read frame from video file {video_path}")
        return None

    if scale != 1.0:
        frame = cv2.resize(
            frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
        )

    return frame

def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE):
    """
    Detect humans in a frame using YOLO model.

    Parameters:
    - frame: Input frame (BGR format)
    - yolo_model: Loaded YOLO model
    - confidence_threshold: Detection confidence threshold

    Returns:
    - human_boxes: List of bounding boxes for detected humans
    """
    # Run YOLO detection
    results = yolo_model(frame, conf=confidence_threshold, verbose=False)

    human_boxes = []

    # Process results
    for result in results:
        boxes = result.boxes
        if boxes is not None:
            for box in boxes:
                # Get class ID
                cls = int(box.cls.cpu().numpy()[0])

                # Check if it's a person (class 0 in COCO)
                if cls == HUMAN_CLASS_ID:
                    # Get bounding box coordinates (x1, y1, x2, y2)
                    coords = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf.cpu().numpy()[0])

                    human_boxes.append({
                        'bbox': coords,
                        'confidence': conf
                    })

                    logger.info(f"Detected human with confidence {conf:.2f} at {coords}")

    return human_boxes

def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width):
    """
    Add YOLO human detections as bounding boxes to SAM2 predictor.
    For stereo videos, creates two objects (left and right humans).

    Parameters:
    - predictor: SAM2 video predictor
    - inference_state: SAM2 inference state
    - human_detections: List of human detection results
    - frame_width: Width of the frame for stereo splitting

    Returns:
    - out_mask_logits: SAM2 output mask logits
    """
    half_frame_width = frame_width // 2

    # Sort detections by x-coordinate to get left and right humans
    human_detections.sort(key=lambda x: x['bbox'][0])  # Sort by x1 coordinate

    obj_id = 1
    out_mask_logits = None

    for i, detection in enumerate(human_detections[:2]):  # Take up to 2 humans (left and right)
        bbox = detection['bbox']

        # For stereo videos, assign obj_id based on position
        if len(human_detections) >= 2:
            # If we have multiple humans, assign based on left/right position
            center_x = (bbox[0] + bbox[2]) / 2
            if center_x < half_frame_width:
                current_obj_id = 1  # Left human
            else:
                current_obj_id = 2  # Right human
        else:
            # If only one human, duplicate for both sides (as in original stereo logic)
            current_obj_id = obj_id
            obj_id += 1

            # Also add the mirrored version for stereo
            if obj_id <= 2:
                mirrored_bbox = bbox.copy()
                mirrored_bbox[0] += half_frame_width  # Shift x1
                mirrored_bbox[2] += half_frame_width  # Shift x2

                # Ensure mirrored bbox is within frame bounds
                mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))

                try:
                    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                        inference_state=inference_state,
                        frame_idx=0,
                        obj_id=obj_id,
                        box=mirrored_bbox.astype(np.float32),
                    )
                    logger.info(f"Added mirrored human detection for Object {obj_id}")
                    obj_id += 1
                except Exception as e:
                    logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}")

        try:
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=current_obj_id,
                box=bbox.astype(np.float32),
            )
            logger.info(f"Added human detection for Object {current_obj_id}")
        except Exception as e:
            logger.error(f"Error adding human detection for Object {current_obj_id}: {e}")

    return out_mask_logits

def propagate_masks(predictor, inference_state):
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments

def apply_colored_mask(frame, masks_a, masks_b):
    colored_mask = np.zeros_like(frame)

    # Apply colors to the masks
    for mask in masks_a:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [0, 255, 0]  # Green for Object A

    for mask in masks_b:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [255, 0, 0]  # Blue for Object B

    return colored_mask


def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
    """
    Process high-resolution frames, apply upscaled masks, and save the output video.
    """
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94

    # Setup VideoWriter with desired settings
    if use_nvenc:
        # Use FFmpeg with NVENC offloading for H.265 encoding
        import subprocess

        if sys.platform == 'darwin':
            encoder = 'hevc_videotoolbox'
        else:
            encoder = 'hevc_nvenc'

        command = [
            'ffmpeg',
            '-y',  # Overwrite output file if it exists
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{frame_width}x{frame_height}',
            '-r', str(fps),
            '-i', '-',  # Input from stdin
            '-an',  # No audio
            '-vcodec', encoder,
            '-pix_fmt', 'nv12',
            '-preset', 'slow',
            '-b:v', '50M',
            output_video_path
        ]
        process = subprocess.Popen(command, stdin=subprocess.PIPE)
    else:
        # Use OpenCV VideoWriter
        fourcc = cv2.VideoWriter_fourcc(*'HEVC')  # H.265
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret or frame_idx >= len(video_segments):
            break

        masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
        upscaled_masks = []

        for mask in masks:
            mask = mask.squeeze()
            upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
            upscaled_masks.append(upscaled_mask)

        result_frame = apply_green_mask(frame, upscaled_masks)

        # Write frame to output
        if use_nvenc:
            process.stdin.write(result_frame.tobytes())
        else:
            out.write(result_frame)

        frame_idx += 1

    cap.release()
    if use_nvenc:
        process.stdin.close()
        process.wait()
    else:
        out.release()

def get_video_file_name(index):
    return f"segment_{str(index).zfill(3)}.mp4"

def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH):
    """
    Run YOLO detection on specified segments and save detection results.
    """
    logger.info("Running YOLO detection on requested segments.")

    # Load YOLO model
    yolo_model = YOLO(yolo_model_path)

    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        detection_file = os.path.join(segment_dir, "yolo_detections")
        video_file = os.path.join(segment_dir, get_video_file_name(i))

        if segment_index in detect_segments and not os.path.exists(detection_file):
            first_frame = load_first_frame(video_file, scale)
            if first_frame is None:
                continue

            # No color conversion needed: ultralytics YOLO accepts the BGR frame from OpenCV directly
            human_detections = detect_humans_with_yolo(first_frame, yolo_model)

            if human_detections:
                # Save detection results
                with open(detection_file, 'w') as f:
                    f.write("# YOLO Human Detections\n")
                    for detection in human_detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                logger.info(f"Saved {len(human_detections)} human detections for segment {segment}")
            else:
                logger.warning(f"No humans detected in segment {segment}")
                # Create empty file to mark as processed
                with open(detection_file, 'w') as f:
                    f.write("# No humans detected\n")

def save_final_masks(video_segments, mask_output_path):
    """
    Save the final masks as a colored image.
    """
    last_frame_idx = max(video_segments.keys())
    masks_dict = video_segments[last_frame_idx]
    # Assuming you have two objects with IDs 1 and 2
    mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None
    mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None

    if mask_a is None and mask_b is None:
        logger.error("No masks found for objects.")
        return

    # Use the first available mask to determine dimensions
    reference_mask = mask_a if mask_a is not None else mask_b
    black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

    if mask_a is not None:
        mask_a = mask_a.astype(bool)
        black_frame[mask_a] = GREEN

    if mask_b is not None:
        mask_b = mask_b.astype(bool)
        black_frame[mask_b] = BLUE

    # Save the mask image
    cv2.imwrite(mask_output_path, black_frame)
    logger.info(f"Saved final masks to {mask_output_path}")

def create_low_res_video(input_video_path, output_video_path, scale):
    """
    Creates a low-resolution version of the input video for inference.
    """
    cap = cv2.VideoCapture(input_video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
        out.write(low_res_frame)

    cap.release()
    out.release()


def main():
    parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.")
    parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
    parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.")
    parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.")
    parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.")
    args = parser.parse_args()

    base_dir = args.base_dir
    segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    segments.sort(key=lambda x: int(x.split("_")[1]))

    # Handle different ways to specify segments for YOLO detection
    if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0:
        # Default: run YOLO on all segments
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("No segments specified, running YOLO detection on ALL segments")
    elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all':
        # Explicit 'all' keyword
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("Running YOLO detection on ALL segments")
    else:
        # Specific segment numbers provided
        try:
            detect_segments = [int(x) for x in args.segments_detect_humans]
            logger.info(f"Running YOLO detection on segments: {detect_segments}")
        except ValueError:
            logger.error("Invalid segment numbers provided. Use integers or 'all'.")
            return

    # Run YOLO detection on specified segments
    do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model)

    # Load YOLO model for inference
    yolo_model = YOLO(args.yolo_model)

    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        video_file_name = get_video_file_name(i)
        video_path = os.path.join(segment_dir, video_file_name)
        output_done_file = os.path.join(segment_dir, "output_frames_done")

        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment} already processed. Skipping.")
            continue

        logger.info(f"Processing segment {segment}")

        # Initialize predictor
        predictor = initialize_predictor()

        # Prepare low-resolution video frames for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
            logger.info(f"Low-resolution video created for segment {segment}")
        else:
            logger.info(f"Low-resolution video already exists for segment {segment}, reusing it")

        # Initialize inference state with low-resolution video
        inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

        # Load YOLO detections or previous masks
        detection_file = os.path.join(segment_dir, "yolo_detections")
        use_detections = segment_index in detect_segments

        if i == 0 and not use_detections:
            # First segment must use YOLO detection since there's no previous mask
            logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.")
            use_detections = True

        if i > 0 and not use_detections:
            # Try to load previous segment mask - search backwards for the most recent successful mask
            logger.info(f"Using previous segment mask for segment {segment}")
            mask_found = False

            # Search backwards through previous segments to find a valid mask
            for j in range(i - 1, -1, -1):
                prev_segment_dir = os.path.join(base_dir, segments[j])
                prev_mask_path = os.path.join(prev_segment_dir, "mask.png")

                if os.path.exists(prev_mask_path):
                    try:
                        per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
                        # Add previous masks to predictor
                        for obj_id, mask in per_obj_input_mask.items():
                            predictor.add_new_mask(inference_state, 0, obj_id, mask)
                        logger.info(f"Successfully loaded mask from segment {segments[j]}")
                        mask_found = True
                        break
                    except Exception as e:
                        logger.warning(f"Error loading mask from {segments[j]}: {e}")
                        continue

            if not mask_found:
                logger.error(f"No valid previous mask found for segment {segment}. Consider running YOLO detection on this segment.")
                continue
        else:
            # Load first frame for detection
            first_frame = load_first_frame(low_res_video_path, scale=1.0)
            if first_frame is None:
                logger.error(f"Could not load first frame for segment {segment}")
                continue

            # Run YOLO detection on first frame (either from file or on-the-fly)
            if os.path.exists(detection_file):
                logger.info(f"Using existing YOLO detections for segment {segment}")
            else:
                logger.info(f"Running YOLO detection on-the-fly for segment {segment}")

            human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence)

            if human_detections:
                # Add YOLO detections to predictor
                frame_width = first_frame.shape[1]
                add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width)
            else:
                logger.warning(f"No humans detected in segment {segment}")
                continue

        # Perform inference and collect masks per frame
        video_segments = propagate_masks(predictor, inference_state)

        # Process high-resolution frames and save output video
        output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
        logger.info("Processing segment complete, attempting to save full video from low res masks")
        process_and_save_output_video(
            video_path,
            output_video_path,
            video_segments,
            use_nvenc=True  # Set to True to use NVENC offloading
        )

        # Save final masks
        mask_output_path = os.path.join(segment_dir, "mask.png")
        save_final_masks(video_segments, mask_output_path)

        # Clean up
        predictor.reset_state(inference_state)
        del inference_state
        del video_segments
        del predictor
        gc.collect()

        try:
            os.remove(low_res_video_path)
            logger.info(f"Deleted low-resolution video for segment {segment}")
        except Exception as e:
            logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")

        # Mark segment as completed
        open(output_done_file, 'a').close()

    logger.info("Processing complete.")

if __name__ == "__main__":
    main()