Compare commits

5 commits: ed08ef2b4b ... 6617acb1c9

| Author | SHA1 | Date |
|---|---|---|
| | 6617acb1c9 | |
| | 02ad4d87d2 | |
| | 97f12c79a4 | |
| | cd7bc54efe | |
| | 46363a8a11 | |
README.md (60 changed lines)

@@ -32,19 +32,40 @@ git clone <repository-url>
 cd samyolo_on_segments

 # Install Python dependencies
-pip install -r requirements.txt
+uv venv && source .venv/bin/activate
+uv pip install -r requirements.txt
 ```

-### Model Dependencies
+### Download Models

-You'll need to download the required model checkpoints:
+Use the provided script to automatically download all required models:
+
+```bash
+# Download SAM2.1 and YOLO models
+python download_models.py
+```
+
+This script will:
+- Create a `models/` directory structure
+- Download SAM2.1 configs and checkpoints (tiny, small, base+, large)
+- Download common YOLO models (yolov8n, yolov8s, yolov8m)
+- Update `config.yaml` to use local model paths
+
+**Manual Download (Alternative):**
 1. **SAM2 Models**: Download from [Meta's SAM2 repository](https://github.com/facebookresearch/sam2)
-2. **YOLO Models**: YOLOv8 models will be downloaded automatically or you can specify a custom path
+2. **YOLO Models**: YOLOv8 models will be downloaded automatically on first use

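To make the new install flow concrete, here is a minimal sketch of the kind of helper the `download_models.py` description above implies. It is not the script added in this commit, and the SAM2 checkpoint URL is a placeholder assumption; take the real links from Meta's SAM2 repository.

```python
# Illustrative sketch only - not the download_models.py added in this commit.
# SAM2_BASE_URL is a placeholder assumption; use the real checkpoint URLs from Meta's SAM2 repo.
import urllib.request
from pathlib import Path

from ultralytics import YOLO

SAM2_BASE_URL = "https://example.com/sam2.1-checkpoints"  # placeholder, not a real endpoint
SAM2_CHECKPOINTS = ["sam2.1_hiera_tiny.pt", "sam2.1_hiera_large.pt"]
YOLO_MODELS = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]


def download_all(models_dir: str = "models") -> None:
    sam2_dir = Path(models_dir) / "sam2" / "checkpoints"
    yolo_dir = Path(models_dir) / "yolo"
    sam2_dir.mkdir(parents=True, exist_ok=True)
    yolo_dir.mkdir(parents=True, exist_ok=True)

    # SAM2.1 checkpoints are plain files, so a simple HTTP download is enough.
    for name in SAM2_CHECKPOINTS:
        target = sam2_dir / name
        if not target.exists():
            urllib.request.urlretrieve(f"{SAM2_BASE_URL}/{name}", target)

    # Ultralytics downloads pretrained weights automatically on first use;
    # instantiating the model is enough to fetch them, after which the file
    # can be copied into models/yolo/ if local paths are preferred.
    for name in YOLO_MODELS:
        YOLO(name)


if __name__ == "__main__":
    download_all()
```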
 ## Quick Start

-### 1. Configure the Pipeline
+### 1. Download Models
+
+First, download the required SAM2.1 and YOLO models:
+
+```bash
+python download_models.py
+```
+
+### 2. Configure the Pipeline

 Edit `config.yaml` to specify your input video and desired settings:

@@ -63,18 +84,18 @@ processing:
   detect_segments: "all"

 models:
-  yolo_model: "yolov8n.pt"
+  yolo_model: "models/yolo/yolov8n.pt"
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
 ```

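As a quick sanity check before running the pipeline, a few lines of Python can confirm that the paths in `config.yaml` actually point at downloaded files. This is a sketch, assuming PyYAML is installed and the keys shown above.

```python
# Minimal sketch: verify the model paths from config.yaml exist before running main.py.
import os
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

models = config["models"]
for key in ("yolo_model", "sam2_checkpoint", "sam2_config"):
    path = models.get(key)
    status = "ok" if path and os.path.exists(path) else "MISSING"
    print(f"{key}: {path} [{status}]")
```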
-### 2. Run the Pipeline
+### 3. Run the Pipeline

 ```bash
 python main.py --config config.yaml
 ```

-### 3. Monitor Progress
+### 4. Monitor Progress

 Check processing status:
 ```bash

@@ -166,8 +187,25 @@ samyolo_on_segments/
 ├── README.md # This documentation
 ├── config.yaml # Default configuration
 ├── main.py # Main entry point
+├── download_models.py # Model download script
 ├── requirements.txt # Python dependencies
 ├── spec.md # Detailed specification
+├── models/ # Downloaded models (created by script)
+│   ├── sam2/
+│   │   ├── configs/sam2.1/ # SAM2.1 configuration files
+│   │   │   ├── sam2.1_hiera_t.yaml
+│   │   │   ├── sam2.1_hiera_s.yaml
+│   │   │   ├── sam2.1_hiera_b+.yaml
+│   │   │   └── sam2.1_hiera_l.yaml
+│   │   └── checkpoints/ # SAM2.1 model weights
+│   │       ├── sam2.1_hiera_tiny.pt
+│   │       ├── sam2.1_hiera_small.pt
+│   │       ├── sam2.1_hiera_base_plus.pt
+│   │       └── sam2.1_hiera_large.pt
+│   └── yolo/ # YOLO model weights
+│       ├── yolov8n.pt
+│       ├── yolov8s.pt
+│       └── yolov8m.pt
 ├── core/ # Core processing modules
 │   ├── __init__.py
 │   ├── config_loader.py # Configuration management

@@ -297,4 +335,4 @@ This project is under active development. The core detection pipeline is functio
 For issues and questions:
 1. Check the troubleshooting section
 2. Review the logs with `log_level: "DEBUG"`
 3. Open an issue with your configuration and error details

config.yaml (11 changed lines)

@@ -23,11 +23,11 @@ processing:

 models:
   # YOLO model path - can be pretrained (yolov8n.pt) or custom path
-  yolo_model: "yolov8n.pt"
+  yolo_model: "models/yolo/yolov8n.pt"

   # SAM2 model configuration
-  sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"
-  sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"
+  sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"
+  sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"

 video:
   # Use NVIDIA hardware encoding (requires NVENC-capable GPU)

@@ -56,4 +56,7 @@ advanced:
   cleanup_intermediate_files: true

   # Logging level (DEBUG, INFO, WARNING, ERROR)
   log_level: "INFO"
+
+  # Save debug frames with YOLO detections visualized
+  save_yolo_debug_frames: true

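For reference, the two `advanced` options added here could be consumed roughly like this; a sketch, not code from this commit.

```python
# Sketch: wire config.yaml's advanced options into Python's logging module.
import logging
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

advanced = config.get("advanced", {})
logging.basicConfig(level=getattr(logging, advanced.get("log_level", "INFO")))

if advanced.get("save_yolo_debug_frames", False):
    logging.getLogger(__name__).info("YOLO debug frames will be written for each segment")
```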
core/config_loader.py

@@ -50,11 +50,31 @@ class ConfigLoader:
                 raise ValueError(f"Missing required field: output.{field}")

         # Validate models section
-        required_model_fields = ['yolo_model', 'sam2_checkpoint', 'sam2_config']
+        required_model_fields = ['sam2_checkpoint', 'sam2_config']
         for field in required_model_fields:
             if field not in self.config['models']:
                 raise ValueError(f"Missing required field: models.{field}")

+        # Validate YOLO model configuration
+        yolo_mode = self.config['models'].get('yolo_mode', 'detection')
+        if yolo_mode not in ['detection', 'segmentation']:
+            raise ValueError(f"Invalid yolo_mode: {yolo_mode}. Must be 'detection' or 'segmentation'")
+
+        # Check for legacy yolo_model field vs new structure
+        has_legacy_yolo_model = 'yolo_model' in self.config['models']
+        has_new_yolo_models = 'yolo_detection_model' in self.config['models'] or 'yolo_segmentation_model' in self.config['models']
+
+        if not has_legacy_yolo_model and not has_new_yolo_models:
+            raise ValueError("Missing YOLO model configuration. Provide either 'yolo_model' (legacy) or 'yolo_detection_model'/'yolo_segmentation_model' (new)")
+
+        # Validate that the required model for the current mode exists
+        if yolo_mode == 'detection':
+            if has_new_yolo_models and 'yolo_detection_model' not in self.config['models']:
+                raise ValueError("yolo_mode is 'detection' but yolo_detection_model not specified")
+        elif yolo_mode == 'segmentation':
+            if has_new_yolo_models and 'yolo_segmentation_model' not in self.config['models']:
+                raise ValueError("yolo_mode is 'segmentation' but yolo_segmentation_model not specified")
+
         # Validate processing.detect_segments format
         detect_segments = self.config['processing'].get('detect_segments', 'all')
         if not isinstance(detect_segments, (str, list)):

@@ -114,8 +134,17 @@ class ConfigLoader:
         return self.config['processing'].get('detect_segments', 'all')

     def get_yolo_model_path(self) -> str:
-        """Get YOLO model path."""
-        return self.config['models']['yolo_model']
+        """Get YOLO model path (legacy method for backward compatibility)."""
+        # Check for legacy configuration first
+        if 'yolo_model' in self.config['models']:
+            return self.config['models']['yolo_model']
+
+        # Use new configuration based on mode
+        yolo_mode = self.config['models'].get('yolo_mode', 'detection')
+        if yolo_mode == 'detection':
+            return self.config['models'].get('yolo_detection_model', 'yolov8n.pt')
+        else:  # segmentation mode
+            return self.config['models'].get('yolo_segmentation_model', 'yolov8n-seg.pt')

     def get_sam2_checkpoint(self) -> str:
         """Get SAM2 checkpoint path."""

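The two configuration styles that the new validation accepts look roughly like this; the dicts below are illustrative, only the key names are taken from the loader code above.

```python
# Sketch of the two accepted shapes for the models section.

# Legacy style: a single yolo_model entry (still valid for backward compatibility).
legacy_models = {
    "yolo_model": "models/yolo/yolov8n.pt",
    "sam2_checkpoint": "models/sam2/checkpoints/sam2.1_hiera_large.pt",
    "sam2_config": "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
}

# New style: explicit mode plus a per-mode model path.
new_models = {
    "yolo_mode": "segmentation",
    "yolo_segmentation_model": "models/yolo/yolov8n-seg.pt",
    "sam2_checkpoint": "models/sam2/checkpoints/sam2.1_hiera_large.pt",
    "sam2_config": "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
}

# get_yolo_model_path() prefers the legacy key when present; otherwise it picks
# yolo_detection_model or yolo_segmentation_model according to yolo_mode.
```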
@@ -17,16 +17,18 @@ logger = logging.getLogger(__name__)
 class SAM2Processor:
     """Handles SAM2-based video segmentation for human tracking."""

-    def __init__(self, checkpoint_path: str, config_path: str):
+    def __init__(self, checkpoint_path: str, config_path: str, vos_optimized: bool = False):
         """
         Initialize SAM2 processor.

         Args:
             checkpoint_path: Path to SAM2 checkpoint
             config_path: Path to SAM2 config file
+            vos_optimized: Enable VOS optimization for speedup (requires PyTorch 2.5.1+)
         """
         self.checkpoint_path = checkpoint_path
         self.config_path = config_path
+        self.vos_optimized = vos_optimized
         self.predictor = None
         self._initialize_predictor()

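Constructing the processor with the new flag might look like this; a usage sketch based only on the signature above and the paths from `config.yaml`.

```python
# Usage sketch for the extended constructor (paths taken from config.yaml above).
processor = SAM2Processor(
    checkpoint_path="models/sam2/checkpoints/sam2.1_hiera_large.pt",
    config_path="models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
    vos_optimized=True,  # falls back to the standard predictor if unsupported
)
```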
@@ -47,11 +49,50 @@ class SAM2Processor:
         logger.info(f"Using device: {device}")

         try:
-            self.predictor = build_sam2_video_predictor(
-                self.config_path,
-                self.checkpoint_path,
-                device=device
-            )
+            # Extract just the config filename for SAM2's Hydra-based loader
+            # SAM2 expects a config name relative to its internal config directory
+            config_name = os.path.basename(self.config_path)
+            if config_name.endswith('.yaml'):
+                config_name = config_name[:-5]  # Remove .yaml extension
+
+            # SAM2 configs are in the format "sam2.1_hiera_X.yaml"
+            # and should be referenced as "configs/sam2.1/sam2.1_hiera_X"
+            if config_name.startswith("sam2.1_hiera"):
+                config_name = f"configs/sam2.1/{config_name}"
+            elif config_name.startswith("sam2_hiera"):
+                config_name = f"configs/sam2/{config_name}"
+
+            logger.info(f"Using SAM2 config: {config_name}")
+
+            # Use VOS optimization if enabled and supported
+            if self.vos_optimized:
+                try:
+                    self.predictor = build_sam2_video_predictor(
+                        config_name,  # Use just the config name, not full path
+                        self.checkpoint_path,
+                        device=device,
+                        vos_optimized=True  # New optimization for major speedup
+                    )
+                    logger.info("Using optimized SAM2 VOS predictor with full model compilation")
+                except Exception as e:
+                    logger.warning(f"Failed to use optimized VOS predictor: {e}")
+                    logger.info("Falling back to standard SAM2 predictor")
+                    # Fallback to standard predictor
+                    self.predictor = build_sam2_video_predictor(
+                        config_name,
+                        self.checkpoint_path,
+                        device=device,
+                        overrides=dict(conf=0.95)
+                    )
+            else:
+                # Use standard predictor
+                self.predictor = build_sam2_video_predictor(
+                    config_name,
+                    self.checkpoint_path,
+                    device=device,
+                    overrides=dict(conf=0.95)
+                )
+                logger.info("Using standard SAM2 predictor")

             # Enable optimizations for CUDA
             if device.type == "cuda":
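The path-to-config-name translation introduced in this hunk can be isolated as a small helper; the sketch below repeats the same logic, with the expected mappings spelled out in comments.

```python
import os


def to_hydra_config_name(config_path: str) -> str:
    """Translate a filesystem path into the config name SAM2's Hydra loader expects."""
    name = os.path.basename(config_path)
    if name.endswith(".yaml"):
        name = name[:-5]
    if name.startswith("sam2.1_hiera"):
        return f"configs/sam2.1/{name}"
    if name.startswith("sam2_hiera"):
        return f"configs/sam2/{name}"
    return name

# "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml" -> "configs/sam2.1/sam2.1_hiera_l"
# "sam2_hiera_l.yaml"                              -> "configs/sam2/sam2_hiera_l"
```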
@@ -103,6 +144,7 @@ class SAM2Processor:
     def add_yolo_prompts_to_predictor(self, inference_state, prompts: List[Dict[str, Any]]) -> bool:
         """
         Add YOLO detection prompts to SAM2 predictor.
+        Includes error handling matching the working spec.md implementation.

         Args:
             inference_state: SAM2 inference state

@@ -112,14 +154,21 @@ class SAM2Processor:
             True if prompts were added successfully
         """
         if not prompts:
-            logger.warning("No prompts provided to SAM2")
+            logger.warning("SAM2 Debug: No prompts provided to SAM2")
             return False

-        try:
-            for prompt in prompts:
-                obj_id = prompt['obj_id']
-                bbox = prompt['bbox']
+        logger.info(f"SAM2 Debug: Received {len(prompts)} prompts to add to predictor")
+
+        success_count = 0
+
+        for i, prompt in enumerate(prompts):
+            obj_id = prompt['obj_id']
+            bbox = prompt['bbox']
+            confidence = prompt.get('confidence', 'unknown')
+
+            logger.info(f"SAM2 Debug: Adding prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
+
+            try:
                 _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                     inference_state=inference_state,
                     frame_idx=0,

@@ -127,13 +176,19 @@ class SAM2Processor:
                     box=bbox.astype(np.float32),
                 )

-                logger.debug(f"Added prompt for Object {obj_id}: {bbox}")
-
-            logger.info(f"Successfully added {len(prompts)} prompts to SAM2")
+                logger.info(f"SAM2 Debug: ✓ Successfully added Object {obj_id} - returned obj_ids: {out_obj_ids}")
+                success_count += 1
+
+            except Exception as e:
+                logger.error(f"SAM2 Debug: ✗ Error adding Object {obj_id}: {e}")
+                # Continue processing other prompts even if one fails
+                continue
+
+        if success_count > 0:
+            logger.info(f"SAM2 Debug: Final result - {success_count}/{len(prompts)} prompts successfully added")
             return True
-
-        except Exception as e:
-            logger.error(f"Error adding prompts to SAM2: {e}")
+        else:
+            logger.error("SAM2 Debug: FAILED - No prompts were successfully added to SAM2")
             return False

     def load_previous_segment_mask(self, prev_segment_dir: str) -> Optional[Dict[int, np.ndarray]]:
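The prompts consumed by `add_yolo_prompts_to_predictor` are plain dictionaries; a minimal example of the expected shape (the coordinate values are invented):

```python
import numpy as np

# Each prompt carries a SAM2 object id, an (x1, y1, x2, y2) box, and an optional confidence.
prompts = [
    {"obj_id": 1, "bbox": np.array([120.0, 80.0, 480.0, 900.0]), "confidence": 0.91},
    {"obj_id": 2, "bbox": np.array([2050.0, 85.0, 2400.0, 910.0]), "confidence": 0.88},
]

# inference_state comes from predictor.init_state(...) on the segment video:
# processor.add_yolo_prompts_to_predictor(inference_state, prompts)
```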
@@ -218,32 +273,46 @@ class SAM2Processor:
             Dictionary mapping frame indices to object masks
         """
         video_segments = {}
+        frame_count = 0

         try:
+            logger.info("Starting SAM2 mask propagation...")
             for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
                 video_segments[out_frame_idx] = {
                     out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                     for i, out_obj_id in enumerate(out_obj_ids)
                 }
+                frame_count += 1
+
+                # Log progress every 50 frames
+                if frame_count % 50 == 0:
+                    logger.info(f"SAM2 propagation progress: {frame_count} frames processed")

-            logger.info(f"Propagated masks across {len(video_segments)} frames with {len(out_obj_ids)} objects")
+            logger.info(f"SAM2 propagation completed: {len(video_segments)} frames with {len(out_obj_ids) if 'out_obj_ids' in locals() else 0} objects")

         except Exception as e:
-            logger.error(f"Error during mask propagation: {e}")
+            logger.error(f"Error during mask propagation after {frame_count} frames: {e}")
+            logger.error("This may be due to VOS optimization issues or insufficient GPU memory")
+            if frame_count == 0:
+                logger.error("No frames were processed - propagation failed completely")
+            else:
+                logger.warning(f"Partial propagation completed: {frame_count} frames before failure")

         return video_segments

     def process_single_segment(self, segment_info: dict, yolo_prompts: Optional[List[Dict[str, Any]]] = None,
                                previous_masks: Optional[Dict[int, np.ndarray]] = None,
-                               inference_scale: float = 0.5) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
+                               inference_scale: float = 0.5,
+                               multi_frame_prompts: Optional[Dict[int, List[Dict[str, Any]]]] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
         """
         Process a single video segment with SAM2.

         Args:
             segment_info: Segment information dictionary
-            yolo_prompts: Optional YOLO detection prompts
+            yolo_prompts: Optional YOLO detection prompts for first frame
             previous_masks: Optional masks from previous segment
             inference_scale: Scale factor for inference
+            multi_frame_prompts: Optional prompts for multiple frames (mid-segment detection)

         Returns:
             Video segments dictionary or None if failed

@@ -284,6 +353,13 @@ class SAM2Processor:
             logger.error(f"No prompts or previous masks available for segment {segment_idx}")
             return None

+        # Add mid-segment prompts if provided
+        if multi_frame_prompts:
+            logger.info(f"Adding mid-segment prompts for segment {segment_idx}")
+            if not self.add_multi_frame_prompts_to_predictor(inference_state, multi_frame_prompts):
+                logger.warning(f"Failed to add mid-segment prompts for segment {segment_idx}")
+                # Don't return None here - continue with existing prompts
+
         # Propagate masks
         video_segments = self.propagate_masks(inference_state)

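The dictionary returned by `propagate_masks` (and ultimately by `process_single_segment`) maps frame index to object id to a boolean mask; a short sketch of iterating it:

```python
import numpy as np

# video_segments: Dict[int, Dict[int, np.ndarray]] as produced by propagate_masks().
def summarize_segments(video_segments):
    for frame_idx in sorted(video_segments):
        for obj_id, mask in video_segments[frame_idx].items():
            # Masks come back as boolean arrays (logits thresholded at 0.0).
            pixels = int(np.count_nonzero(mask))
            print(f"frame {frame_idx}: object {obj_id} covers {pixels} pixels")
```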
@@ -359,4 +435,218 @@ class SAM2Processor:
|
|||||||
logger.info(f"Saved final masks to {output_path}")
|
logger.info(f"Saved final masks to {output_path}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving final masks: {e}")
|
logger.error(f"Error saving final masks: {e}")
|
||||||
|
|
||||||
|
def generate_first_frame_debug_masks(self, video_path: str, prompts: List[Dict[str, Any]],
|
||||||
|
output_path: str, inference_scale: float = 0.5) -> bool:
|
||||||
|
"""
|
||||||
|
Generate SAM2 masks for just the first frame and save debug visualization.
|
||||||
|
This helps debug what SAM2 is producing for each detected object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to the video file
|
||||||
|
prompts: List of SAM2 prompt dictionaries
|
||||||
|
output_path: Path to save the debug image
|
||||||
|
inference_scale: Scale factor for SAM2 inference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if debug masks were generated successfully
|
||||||
|
"""
|
||||||
|
if not prompts:
|
||||||
|
logger.warning("No prompts provided for first frame debug")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"SAM2 Debug: Generating first frame masks for {len(prompts)} objects")
|
||||||
|
|
||||||
|
# Load the first frame
|
||||||
|
cap = cv2.VideoCapture(video_path)
|
||||||
|
ret, original_frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
logger.error("Could not read first frame for debug mask generation")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Scale frame for inference if needed
|
||||||
|
if inference_scale != 1.0:
|
||||||
|
inference_frame = cv2.resize(original_frame, None, fx=inference_scale, fy=inference_scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
else:
|
||||||
|
inference_frame = original_frame.copy()
|
||||||
|
|
||||||
|
# Create temporary low-res video with just first frame
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
temp_video_path = os.path.join(temp_dir, "first_frame.mp4")
|
||||||
|
|
||||||
|
# Write single frame to temporary video
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||||
|
out = cv2.VideoWriter(temp_video_path, fourcc, 1.0, (inference_frame.shape[1], inference_frame.shape[0]))
|
||||||
|
out.write(inference_frame)
|
||||||
|
out.release()
|
||||||
|
|
||||||
|
# Initialize SAM2 inference state with single frame
|
||||||
|
inference_state = self.predictor.init_state(video_path=temp_video_path, async_loading_frames=True)
|
||||||
|
|
||||||
|
# Add prompts
|
||||||
|
if not self.add_yolo_prompts_to_predictor(inference_state, prompts):
|
||||||
|
logger.error("Failed to add prompts for first frame debug")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Generate masks for first frame only
|
||||||
|
frame_masks = {}
|
||||||
|
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(inference_state):
|
||||||
|
if out_frame_idx == 0: # Only process first frame
|
||||||
|
frame_masks = {
|
||||||
|
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
||||||
|
for i, out_obj_id in enumerate(out_obj_ids)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
|
if not frame_masks:
|
||||||
|
logger.error("No masks generated for first frame debug")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create debug visualization
|
||||||
|
debug_frame = original_frame.copy()
|
||||||
|
|
||||||
|
# Define colors for each object
|
||||||
|
colors = {
|
||||||
|
1: (0, 255, 0), # Green for Object 1 (Left eye)
|
||||||
|
2: (255, 0, 0), # Blue for Object 2 (Right eye)
|
||||||
|
3: (0, 255, 255), # Yellow for Object 3
|
||||||
|
4: (255, 0, 255), # Magenta for Object 4
|
||||||
|
}
|
||||||
|
|
||||||
|
# Overlay masks with transparency
|
||||||
|
for obj_id, mask in frame_masks.items():
|
||||||
|
mask = mask.squeeze()
|
||||||
|
|
||||||
|
# Resize mask to match original frame if needed
|
||||||
|
if mask.shape != original_frame.shape[:2]:
|
||||||
|
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
# Apply colored overlay
|
||||||
|
color = colors.get(obj_id, (128, 128, 128))
|
||||||
|
overlay = debug_frame.copy()
|
||||||
|
overlay[mask] = color
|
||||||
|
|
||||||
|
# Blend with original (30% overlay, 70% original)
|
||||||
|
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
|
||||||
|
|
||||||
|
# Draw outline
|
||||||
|
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
cv2.drawContours(debug_frame, contours, -1, color, 2)
|
||||||
|
|
||||||
|
logger.info(f"SAM2 Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
|
||||||
|
|
||||||
|
# Add title
|
||||||
|
title = f"SAM2 First Frame Masks: {len(frame_masks)} objects detected"
|
||||||
|
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
|
||||||
|
|
||||||
|
# Add mask source information
|
||||||
|
source_info = "Mask Source: SAM2 (from YOLO bounding boxes)"
|
||||||
|
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
|
||||||
|
|
||||||
|
# Add object legend
|
||||||
|
y_offset = 90
|
||||||
|
for obj_id in sorted(frame_masks.keys()):
|
||||||
|
color = colors.get(obj_id, (128, 128, 128))
|
||||||
|
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye' if obj_id == 2 else f'Object {obj_id}'}"
|
||||||
|
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
|
||||||
|
y_offset += 30
|
||||||
|
|
||||||
|
# Save debug image
|
||||||
|
success = cv2.imwrite(output_path, debug_frame)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
self.predictor.reset_state(inference_state)
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"SAM2 Debug: Saved first frame masks to {output_path}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to save first frame masks to {output_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating first frame debug masks: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_multi_frame_prompts_to_predictor(self, inference_state, multi_frame_prompts: Dict[int, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Add YOLO prompts at multiple frame indices for mid-segment re-detection.
|
||||||
|
Supports both bounding box prompts (detection mode) and mask prompts (segmentation mode).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inference_state: SAM2 inference state
|
||||||
|
multi_frame_prompts: Dictionary mapping frame_index -> prompts (list of dicts for bbox, dict with 'masks' for segmentation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if prompts were added successfully
|
||||||
|
"""
|
||||||
|
if not multi_frame_prompts:
|
||||||
|
logger.warning("SAM2 Mid-segment: No multi-frame prompts provided")
|
||||||
|
return False
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
total_count = 0
|
||||||
|
|
||||||
|
for frame_idx, prompts_data in multi_frame_prompts.items():
|
||||||
|
# Check if this is segmentation mode (masks) or detection mode (bbox prompts)
|
||||||
|
if isinstance(prompts_data, dict) and 'masks' in prompts_data:
|
||||||
|
# Segmentation mode: add masks directly
|
||||||
|
masks_dict = prompts_data['masks']
|
||||||
|
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(masks_dict)} YOLO masks")
|
||||||
|
|
||||||
|
for obj_id, mask in masks_dict.items():
|
||||||
|
total_count += 1
|
||||||
|
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, adding mask for Object {obj_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.predictor.add_new_mask(inference_state, frame_idx, obj_id, mask)
|
||||||
|
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} mask added successfully")
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} mask failed: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Detection mode: add bounding box prompts (existing logic)
|
||||||
|
prompts = prompts_data
|
||||||
|
logger.info(f"SAM2 Mid-segment: Processing frame {frame_idx} with {len(prompts)} bbox prompts")
|
||||||
|
|
||||||
|
for i, prompt in enumerate(prompts):
|
||||||
|
obj_id = prompt['obj_id']
|
||||||
|
bbox = prompt['bbox']
|
||||||
|
confidence = prompt.get('confidence', 'unknown')
|
||||||
|
total_count += 1
|
||||||
|
|
||||||
|
logger.info(f"SAM2 Mid-segment: Frame {frame_idx}, Prompt {i+1}/{len(prompts)}: Object {obj_id}, bbox={bbox}, conf={confidence}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
|
||||||
|
inference_state=inference_state,
|
||||||
|
frame_idx=frame_idx, # Key: specify the exact frame index
|
||||||
|
obj_id=obj_id,
|
||||||
|
box=bbox.astype(np.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"SAM2 Mid-segment: ✓ Frame {frame_idx}, Object {obj_id} added successfully - returned obj_ids: {out_obj_ids}")
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SAM2 Mid-segment: ✗ Frame {frame_idx}, Object {obj_id} failed: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
logger.info(f"SAM2 Mid-segment: Final result - {success_count}/{total_count} prompts successfully added across {len(multi_frame_prompts)} frames")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error("SAM2 Mid-segment: FAILED - No prompts were successfully added")
|
||||||
|
return False
|
||||||
|
|||||||
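Both shapes accepted by `add_multi_frame_prompts_to_predictor` above can be summarized as follows; a sketch in which the frame indices, boxes, and mask sizes are invented.

```python
import numpy as np

# Detection mode: frame index -> list of bbox prompts (same shape as first-frame prompts).
bbox_prompts = {
    30: [{"obj_id": 1, "bbox": np.array([118.0, 82.0, 478.0, 902.0]), "confidence": 0.87}],
    60: [{"obj_id": 1, "bbox": np.array([121.0, 79.0, 482.0, 898.0]), "confidence": 0.90}],
}

# Segmentation mode: frame index -> {"masks": {obj_id: binary mask}} fed to add_new_mask().
mask_prompts = {
    30: {"masks": {1: np.zeros((540, 960), dtype=bool)}},
}

# Either dictionary can be passed as multi_frame_prompts to process_single_segment().
```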
@@ -7,7 +7,7 @@ import os
 import subprocess
 import logging
 from typing import List, Tuple
-from ..utils.file_utils import ensure_directory, get_video_file_name
+from utils.file_utils import ensure_directory, get_video_file_name

 logger = logging.getLogger(__name__)

@@ -7,50 +7,249 @@ import os
 import cv2
 import numpy as np
 import logging
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
 from ultralytics import YOLO

 logger = logging.getLogger(__name__)


 class YOLODetector:
-    """Handles YOLO-based human detection for video segments."""
+    """Handles YOLO-based human detection for video segments with support for both detection and segmentation modes."""

-    def __init__(self, model_path: str, confidence_threshold: float = 0.6, human_class_id: int = 0):
-        """
-        Initialize YOLO detector.
+    def __init__(self, detection_model_path: str = None, segmentation_model_path: str = None,
+                 mode: str = "detection", confidence_threshold: float = 0.6, human_class_id: int = 0):
+        """
+        Initialize YOLO detector with support for both detection and segmentation modes.

         Args:
-            model_path: Path to YOLO model weights
+            detection_model_path: Path to YOLO detection model weights (e.g., yolov8n.pt)
+            segmentation_model_path: Path to YOLO segmentation model weights (e.g., yolov8n-seg.pt)
+            mode: Detection mode - "detection" for bboxes, "segmentation" for masks
             confidence_threshold: Detection confidence threshold
             human_class_id: COCO class ID for humans (0 = person)
         """
-        self.model_path = model_path
+        self.mode = mode
         self.confidence_threshold = confidence_threshold
         self.human_class_id = human_class_id

+        # Select model path based on mode
+        if mode == "segmentation":
+            if not segmentation_model_path:
+                raise ValueError("segmentation_model_path required for segmentation mode")
+            self.model_path = segmentation_model_path
+            self.supports_segmentation = True
+        elif mode == "detection":
+            if not detection_model_path:
+                raise ValueError("detection_model_path required for detection mode")
+            self.model_path = detection_model_path
+            self.supports_segmentation = False
+        else:
+            raise ValueError(f"Invalid mode: {mode}. Must be 'detection' or 'segmentation'")
+
         # Load YOLO model
         try:
-            self.model = YOLO(model_path)
-            logger.info(f"Loaded YOLO model from {model_path}")
+            self.model = YOLO(self.model_path)
+            logger.info(f"Loaded YOLO model in {mode} mode from {self.model_path}")
+
+            # Verify model capabilities
+            if mode == "segmentation":
+                # Test if model actually supports segmentation
+                logger.info(f"YOLO Segmentation: Model loaded, will output direct masks")
+            else:
+                logger.info(f"YOLO Detection: Model loaded, will output bounding boxes")
+
         except Exception as e:
             logger.error(f"Failed to load YOLO model: {e}")
             raise

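Based on the new constructor, the detector would be created like this in each mode; a usage sketch in which the segmentation weights path is an assumption.

```python
# Usage sketch for the two modes of the reworked constructor.
detector = YOLODetector(
    detection_model_path="models/yolo/yolov8n.pt",
    mode="detection",
    confidence_threshold=0.6,
)

seg_detector = YOLODetector(
    segmentation_model_path="models/yolo/yolov8n-seg.pt",  # assumed local path
    mode="segmentation",
)
```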
     def detect_humans_in_frame(self, frame: np.ndarray) -> List[Dict[str, Any]]:
         """
         Detect humans in a single frame using YOLO.

         Args:
             frame: Input frame (BGR format from OpenCV)

         Returns:
-            List of human detection dictionaries with bbox and confidence
+            List of human detection dictionaries with bbox, confidence, and optionally masks
         """
-        # Run YOLO detection
+        # Run YOLO detection/segmentation
         results = self.model(frame, conf=self.confidence_threshold, verbose=False)

         human_detections = []

+        # Process results
+        for result in results:
+            boxes = result.boxes
+            masks = result.masks if hasattr(result, 'masks') and result.masks is not None else None
+
+            if boxes is not None:
+                for i, box in enumerate(boxes):
+                    # Get class ID
+                    cls = int(box.cls.cpu().numpy()[0])
+
+                    # Check if it's a person (human_class_id)
+                    if cls == self.human_class_id:
+                        # Get bounding box coordinates (x1, y1, x2, y2)
+                        coords = box.xyxy[0].cpu().numpy()
+                        conf = float(box.conf.cpu().numpy()[0])
+
+                        detection = {
+                            'bbox': coords,
+                            'confidence': conf,
+                            'has_mask': False,
+                            'mask': None
+                        }
+
+                        # Extract mask if available (segmentation mode)
+                        if masks is not None and i < len(masks.data):
+                            mask_data = masks.data[i].cpu().numpy()  # Get mask for this detection
+                            detection['has_mask'] = True
+                            detection['mask'] = mask_data
+                            logger.debug(f"YOLO Segmentation: Detected human with mask - conf={conf:.2f}, mask_shape={mask_data.shape}")
+                        else:
+                            logger.debug(f"YOLO Detection: Detected human with bbox - conf={conf:.2f}, bbox={coords}")
+
+                        human_detections.append(detection)
+
+        if self.supports_segmentation:
+            masks_found = sum(1 for d in human_detections if d['has_mask'])
+            logger.info(f"YOLO Segmentation: Found {len(human_detections)} humans, {masks_found} with masks")
+        else:
+            logger.debug(f"YOLO Detection: Found {len(human_detections)} humans with bounding boxes")
+
+        return human_detections

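Each entry returned by `detect_humans_in_frame` now carries optional mask data; a short sketch of consuming it, assuming `detector` and `frame` are already in scope as above.

```python
# Sketch: branch on the new has_mask flag in the detection dictionaries.
for det in detector.detect_humans_in_frame(frame):
    x1, y1, x2, y2 = det["bbox"]
    if det["has_mask"]:
        # Segmentation mode: det["mask"] is a mask array at the model's output resolution.
        area = float(det["mask"].sum())
        print(f"person conf={det['confidence']:.2f} mask_area={area:.0f}")
    else:
        print(f"person conf={det['confidence']:.2f} box=({x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f})")
```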
|
def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Detect humans in the first frame of a video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to video file
|
||||||
|
scale: Scale factor for frame processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of human detection dictionaries
|
||||||
|
"""
|
||||||
|
if not os.path.exists(video_path):
|
||||||
|
logger.error(f"Video file not found: {video_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(video_path)
|
||||||
|
if not cap.isOpened():
|
||||||
|
logger.error(f"Could not open video: {video_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
ret, frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
logger.error(f"Could not read first frame from: {video_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Scale frame if needed
|
||||||
|
if scale != 1.0:
|
||||||
|
frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
return self.detect_humans_in_frame(frame)
|
||||||
|
|
||||||
|
def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Save detection results to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detections: List of detection dictionaries
|
||||||
|
output_path: Path to save detections
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if saved successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
f.write("# YOLO Human Detections\\n")
|
||||||
|
if detections:
|
||||||
|
for detection in detections:
|
||||||
|
bbox = detection['bbox']
|
||||||
|
conf = detection['confidence']
|
||||||
|
f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n")
|
||||||
|
logger.info(f"Saved {len(detections)} detections to {output_path}")
|
||||||
|
else:
|
||||||
|
f.write("# No humans detected\\n")
|
||||||
|
logger.info(f"Saved empty detection file to {output_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save detections to {output_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Load detection results from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to detection file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of detection dictionaries
|
||||||
|
"""
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
logger.warning(f"Detection file not found: {file_path}")
|
||||||
|
return detections
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Handle files with literal \n characters
|
||||||
|
if '\\n' in content:
|
||||||
|
lines = content.split('\\n')
|
||||||
|
else:
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
# Skip comments and empty lines
|
||||||
|
if line.startswith('#') or not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse detection line: x1,y1,x2,y2,confidence
|
||||||
|
parts = line.split(',')
|
||||||
|
if len(parts) == 5:
|
||||||
|
try:
|
||||||
|
bbox = [float(x) for x in parts[:4]]
|
||||||
|
conf = float(parts[4])
|
||||||
|
detections.append({
|
||||||
|
'bbox': np.array(bbox),
|
||||||
|
'confidence': conf
|
||||||
|
})
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"Invalid detection line: {line}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(detections)} detections from {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load detections from {file_path}: {e}")
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def debug_detect_with_lower_confidence(self, frame: np.ndarray, debug_confidence: float = 0.3) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Run YOLO detection with a lower confidence threshold for debugging.
|
||||||
|
This helps identify if detections are being missed due to high confidence threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (BGR format from OpenCV)
|
||||||
|
debug_confidence: Lower confidence threshold for debugging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of human detection dictionaries with lower confidence threshold
|
||||||
|
"""
|
||||||
|
logger.info(f"VR180 Debug: Running YOLO with lower confidence {debug_confidence} (vs normal {self.confidence_threshold})")
|
||||||
|
|
||||||
|
# Run YOLO detection with lower confidence
|
||||||
|
results = self.model(frame, conf=debug_confidence, verbose=False)
|
||||||
|
|
||||||
|
debug_detections = []
|
||||||
|
|
||||||
# Process results
|
# Process results
|
||||||
for result in results:
|
for result in results:
|
||||||
boxes = result.boxes
|
boxes = result.boxes
|
||||||
@@ -65,123 +264,90 @@ class YOLODetector:
|
|||||||
coords = box.xyxy[0].cpu().numpy()
|
coords = box.xyxy[0].cpu().numpy()
|
||||||
conf = float(box.conf.cpu().numpy()[0])
|
conf = float(box.conf.cpu().numpy()[0])
|
||||||
|
|
||||||
human_detections.append({
|
debug_detections.append({
|
||||||
'bbox': coords,
|
'bbox': coords,
|
||||||
'confidence': conf
|
'confidence': conf
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug(f\"Detected human with confidence {conf:.2f} at {coords}\")
|
|
||||||
|
|
||||||
return human_detections
|
logger.info(f"VR180 Debug: Lower confidence detection found {len(debug_detections)} total detections")
|
||||||
|
return debug_detections
|
||||||
|
|
||||||
def detect_humans_in_video_first_frame(self, video_path: str, scale: float = 1.0) -> List[Dict[str, Any]]:
|
def detect_humans_multi_frame(self, video_path: str, frame_indices: List[int],
|
||||||
\"\"\"
|
scale: float = 1.0) -> Dict[int, List[Dict[str, Any]]]:
|
||||||
Detect humans in the first frame of a video.
|
"""
|
||||||
|
Detect humans at multiple specific frame indices in a video.
|
||||||
|
Used for mid-segment re-detection to improve SAM2 tracking.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_path: Path to video file
|
video_path: Path to video file
|
||||||
|
frame_indices: List of frame indices to run detection on (e.g., [0, 30, 60, 90])
|
||||||
scale: Scale factor for frame processing
|
scale: Scale factor for frame processing
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of human detection dictionaries
|
Dictionary mapping frame_index -> list of detection dictionaries
|
||||||
\"\"\"
|
"""
|
||||||
|
if not frame_indices:
|
||||||
|
logger.warning("No frame indices provided for multi-frame detection")
|
||||||
|
return {}
|
||||||
|
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
logger.error(f\"Video file not found: {video_path}\")
|
logger.error(f"Video file not found: {video_path}")
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
|
logger.info(f"Mid-segment Detection: Running YOLO on {len(frame_indices)} frames: {frame_indices}")
|
||||||
|
|
||||||
cap = cv2.VideoCapture(video_path)
|
cap = cv2.VideoCapture(video_path)
|
||||||
if not cap.isOpened():
|
if not cap.isOpened():
|
||||||
logger.error(f\"Could not open video: {video_path}\")
|
logger.error(f"Could not open video: {video_path}")
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
|
||||||
|
|
||||||
|
# Filter out frame indices that are beyond video length
|
||||||
|
valid_frame_indices = [idx for idx in frame_indices if 0 <= idx < total_frames]
|
||||||
|
if len(valid_frame_indices) != len(frame_indices):
|
||||||
|
invalid_frames = [idx for idx in frame_indices if idx not in valid_frame_indices]
|
||||||
|
logger.warning(f"Mid-segment Detection: Skipping invalid frame indices: {invalid_frames} (video has {total_frames} frames)")
|
||||||
|
|
||||||
|
multi_frame_detections = {}
|
||||||
|
|
||||||
|
for frame_idx in valid_frame_indices:
|
||||||
|
# Seek to specific frame
|
||||||
|
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
|
||||||
|
ret, frame = cap.read()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
logger.warning(f"Mid-segment Detection: Could not read frame {frame_idx}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Scale frame if needed
|
||||||
|
if scale != 1.0:
|
||||||
|
frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# Run YOLO detection on this frame
|
||||||
|
detections = self.detect_humans_in_frame(frame)
|
||||||
|
multi_frame_detections[frame_idx] = detections
|
||||||
|
|
||||||
|
# Log detection results
|
||||||
|
time_seconds = frame_idx / fps
|
||||||
|
logger.info(f"Mid-segment Detection: Frame {frame_idx} (t={time_seconds:.1f}s): {len(detections)} humans detected")
|
||||||
|
|
||||||
|
for i, detection in enumerate(detections):
|
||||||
|
bbox = detection['bbox']
|
||||||
|
conf = detection['confidence']
|
||||||
|
logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Human {i+1}: bbox={bbox}, conf={conf:.3f}")
|
||||||
|
|
||||||
ret, frame = cap.read()
|
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
if not ret:
|
total_detections = sum(len(dets) for dets in multi_frame_detections.values())
|
||||||
logger.error(f\"Could not read first frame from: {video_path}\")
|
logger.info(f"Mid-segment Detection: Complete - {total_detections} total detections across {len(valid_frame_indices)} frames")
|
||||||
return []
|
|
||||||
|
|
||||||
# Scale frame if needed
|
return multi_frame_detections
|
||||||
if scale != 1.0:
|
|
||||||
frame = cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
return self.detect_humans_in_frame(frame)
|
|
||||||
|
|
||||||
def save_detections_to_file(self, detections: List[Dict[str, Any]], output_path: str) -> bool:
|
|
||||||
\"\"\"
|
|
||||||
Save detection results to file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
detections: List of detection dictionaries
|
|
||||||
output_path: Path to save detections
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if saved successfully
|
|
||||||
\"\"\"
|
|
||||||
try:
|
|
||||||
with open(output_path, 'w') as f:
|
|
||||||
f.write(\"# YOLO Human Detections\\n\")
|
|
||||||
if detections:
|
|
||||||
for detection in detections:
|
|
||||||
bbox = detection['bbox']
|
|
||||||
conf = detection['confidence']
|
|
||||||
f.write(f\"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\\n\")
|
|
||||||
logger.info(f\"Saved {len(detections)} detections to {output_path}\")
|
|
||||||
else:
|
|
||||||
f.write(\"# No humans detected\\n\")
|
|
||||||
logger.info(f\"Saved empty detection file to {output_path}\")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f\"Failed to save detections to {output_path}: {e}\")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def load_detections_from_file(self, file_path: str) -> List[Dict[str, Any]]:
|
|
||||||
\"\"\"
|
|
||||||
Load detection results from file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to detection file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of detection dictionaries
|
|
||||||
\"\"\"
|
|
||||||
detections = []
|
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
logger.warning(f\"Detection file not found: {file_path}\")
|
|
||||||
return detections
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(file_path, 'r') as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
# Skip comments and empty lines
|
|
||||||
if line.startswith('#') or not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse detection line: x1,y1,x2,y2,confidence
|
|
||||||
parts = line.split(',')
|
|
||||||
if len(parts) == 5:
|
|
||||||
try:
|
|
||||||
bbox = [float(x) for x in parts[:4]]
|
|
||||||
conf = float(parts[4])
|
|
||||||
detections.append({
|
|
||||||
'bbox': np.array(bbox),
|
|
||||||
'confidence': conf
|
|
||||||
})
|
|
||||||
except ValueError:
|
|
||||||
logger.warning(f\"Invalid detection line: {line}\")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f\"Loaded {len(detections)} detections from {file_path}\")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f\"Failed to load detections from {file_path}: {e}\")
|
|
||||||
|
|
||||||
return detections
|
|
||||||
|
|
||||||
def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
|
def process_segments_batch(self, segments_info: List[dict], detect_segments: List[int],
|
||||||
scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
|
scale: float = 0.5) -> Dict[int, List[Dict[str, Any]]]:
|
||||||
\"\"\"
|
"""
|
||||||
Process multiple segments for human detection.
|
Process multiple segments for human detection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -191,7 +357,7 @@ class YOLODetector:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping segment index to detection results
|
Dictionary mapping segment index to detection results
|
||||||
\"\"\"
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
for segment_info in segments_info:
|
for segment_info in segments_info:
|
||||||
@@ -202,17 +368,17 @@ class YOLODetector:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
video_path = segment_info['video_file']
|
video_path = segment_info['video_file']
|
||||||
detection_file = os.path.join(segment_info['directory'], \"yolo_detections\")
|
detection_file = os.path.join(segment_info['directory'], "yolo_detections")
|
||||||
|
|
||||||
# Skip if already processed
|
# Skip if already processed
|
||||||
if os.path.exists(detection_file):
|
if os.path.exists(detection_file):
|
||||||
logger.info(f\"Segment {segment_idx} already has detections, skipping\")
|
logger.info(f"Segment {segment_idx} already has detections, skipping")
|
||||||
detections = self.load_detections_from_file(detection_file)
|
detections = self.load_detections_from_file(detection_file)
|
||||||
results[segment_idx] = detections
|
results[segment_idx] = detections
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Run detection
|
# Run detection
|
||||||
logger.info(f\"Processing segment {segment_idx} for human detection\")
|
logger.info(f"Processing segment {segment_idx} for human detection")
|
||||||
detections = self.detect_humans_in_video_first_frame(video_path, scale)
|
detections = self.detect_humans_in_video_first_frame(video_path, scale)
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
@@ -223,8 +389,9 @@ class YOLODetector:
|
|||||||
|
|
||||||
def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
|
def convert_detections_to_sam2_prompts(self, detections: List[Dict[str, Any]],
|
||||||
frame_width: int) -> List[Dict[str, Any]]:
|
frame_width: int) -> List[Dict[str, Any]]:
|
||||||
\"\"\"
|
"""
|
||||||
Convert YOLO detections to SAM2-compatible prompts for stereo video.
|
Convert YOLO detections to SAM2-compatible prompts for VR180 SBS video.
|
||||||
|
For VR180, we expect 2 real detections (left and right eye views), not mirrored ones.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections: List of YOLO detection results
|
detections: List of YOLO detection results
|
||||||
@@ -232,55 +399,337 @@ class YOLODetector:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of SAM2 prompt dictionaries with obj_id and bbox
|
List of SAM2 prompt dictionaries with obj_id and bbox
|
||||||
\"\"\"
|
"""
|
||||||
if not detections:
|
if not detections:
|
||||||
|
logger.warning("No detections provided for SAM2 prompt conversion")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
half_frame_width = frame_width // 2
|
half_frame_width = frame_width // 2
|
||||||
prompts = []
|
prompts = []
|
||||||
|
|
||||||
|
logger.info(f"VR180 SBS Debug: Converting {len(detections)} detections for frame width {frame_width}")
|
||||||
|
logger.info(f"VR180 SBS Debug: Half frame width = {half_frame_width}")
|
||||||
|
|
||||||
# Sort detections by x-coordinate to get consistent left/right assignment
sorted_detections = sorted(detections, key=lambda x: x['bbox'][0])

# Analyze detections by frame half
left_detections = []
right_detections = []

for i, detection in enumerate(sorted_detections):
    bbox = detection['bbox'].copy()
    center_x = (bbox[0] + bbox[2]) / 2
    pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"

    if center_x < half_frame_width:
        left_detections.append((detection, i, pixel_range))
        side = "LEFT"
    else:
        right_detections.append((detection, i, pixel_range))
        side = "RIGHT"

    logger.info(f"VR180 SBS Debug: Detection {i}: pixels {pixel_range}, center_x={center_x:.1f}, side={side}")

# VR180 SBS format validation
logger.info(f"VR180 SBS Debug: Found {len(left_detections)} LEFT detections, {len(right_detections)} RIGHT detections")

# Analyze confidence scores
if left_detections:
    left_confidences = [det[0]['confidence'] for det in left_detections]
    logger.info(f"VR180 SBS Debug: LEFT eye confidences: {[f'{c:.3f}' for c in left_confidences]}")

if right_detections:
    right_confidences = [det[0]['confidence'] for det in right_detections]
    logger.info(f"VR180 SBS Debug: RIGHT eye confidences: {[f'{c:.3f}' for c in right_confidences]}")

if len(right_detections) == 0:
    logger.warning(f"VR180 SBS Warning: No detections found in RIGHT eye view (pixels {half_frame_width}-{frame_width})")
    logger.warning("VR180 SBS Warning: This may indicate:")
    logger.warning("  1. Person not visible in right eye view")
    logger.warning(f"  2. YOLO confidence threshold ({self.confidence_threshold}) too high")
    logger.warning("  3. VR180 SBS format issue")
    logger.warning("  4. Right eye view quality/lighting problems")
    logger.warning("VR180 SBS Suggestion: Try lowering yolo_confidence to 0.3-0.4 in config")

if len(left_detections) == 0:
    logger.warning(f"VR180 SBS Warning: No detections found in LEFT eye view (pixels 0-{half_frame_width})")

# Additional validation for VR180 SBS expectations
total_detections = len(left_detections) + len(right_detections)
if total_detections == 1:
    logger.warning("VR180 SBS Warning: Only 1 detection found - expected 2 for proper VR180 SBS")
elif total_detections > 2:
    logger.warning(f"VR180 SBS Warning: {total_detections} detections found - will use only first 2")

# Assign object IDs sequentially, regardless of which half they're in.
# This ensures we always get Object 1 and Object 2 for up to 2 detections.
obj_id = 1

# Process up to 2 detections total (left + right combined)
all_detections = sorted_detections[:2]

for i, detection in enumerate(all_detections):
    bbox = detection['bbox'].copy()
    center_x = (bbox[0] + bbox[2]) / 2
    pixel_range = f"{bbox[0]:.0f}-{bbox[2]:.0f}"

    # Determine which eye view this detection is in
    if center_x < half_frame_width:
        eye_view = "LEFT"
    else:
        eye_view = "RIGHT"

    prompts.append({
        'obj_id': obj_id,
        'bbox': bbox,
        'confidence': detection['confidence']
    })

    logger.info(f"VR180 SBS Debug: Added {eye_view} eye detection as SAM2 Object {obj_id}")
    logger.info(f"VR180 SBS Debug: Object {obj_id} bbox: {bbox} (pixels {pixel_range})")

    obj_id += 1

logger.info(f"VR180 SBS Debug: Final result - {len(detections)} YOLO detections → {len(prompts)} SAM2 prompts")

# Verify we have the expected objects
obj_ids = [p['obj_id'] for p in prompts]
logger.info(f"VR180 SBS Debug: SAM2 Object IDs created: {obj_ids}")

return prompts
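# Illustrative usage sketch (argument values are examples only; both methods are
# defined in this pipeline and called from main.py):
#
#   detections = detector.detect_humans_in_video_first_frame(video_file, scale=0.5)
#   prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
#   # -> e.g. [{'obj_id': 1, 'bbox': [...], 'confidence': 0.87},
#   #          {'obj_id': 2, 'bbox': [...], 'confidence': 0.81}]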
def convert_yolo_masks_to_video_segments(self, detections: List[Dict[str, Any]],
                                         frame_width: int, target_frame_shape: Tuple[int, int] = None) -> Optional[Dict[int, Dict[int, np.ndarray]]]:
    """
    Convert YOLO segmentation masks to SAM2-compatible video segments format.
    This allows using YOLO masks directly without SAM2 processing.

    Args:
        detections: List of YOLO detection results with masks
        frame_width: Width of the video frame for VR180 object ID assignment
        target_frame_shape: Target shape (height, width) for mask resizing

    Returns:
        Video segments dictionary compatible with SAM2 output format, or None if no masks
    """
    if not detections:
        logger.warning("No detections provided for mask conversion")
        return None

    # Check if any detections have masks
    detections_with_masks = [d for d in detections if d.get('has_mask', False)]
    if not detections_with_masks:
        logger.warning("No detections have masks - YOLO segmentation may not be working")
        return None

    logger.info(f"YOLO Mask Conversion: Converting {len(detections_with_masks)} YOLO masks to video segments format")

    half_frame_width = frame_width // 2
    video_segments = {}

    # Create frame 0 with converted masks
    frame_masks = {}
    obj_id = 1

    # Sort detections by x-coordinate for consistent VR180 SBS assignment
    sorted_detections = sorted(detections_with_masks, key=lambda x: x['bbox'][0])

    for i, detection in enumerate(sorted_detections[:2]):  # Take up to 2 humans
        mask = detection['mask']
        bbox = detection['bbox']
        center_x = (bbox[0] + bbox[2]) / 2

        # Assign sequential object IDs (similar to prompt conversion logic)
        current_obj_id = obj_id

        # Determine which eye view for logging
        if center_x < half_frame_width:
            eye_view = "LEFT"
        else:
            eye_view = "RIGHT"

        # Resize mask to target frame shape if specified
        if target_frame_shape and mask.shape != target_frame_shape:
            mask_resized = cv2.resize(mask.astype(np.float32), (target_frame_shape[1], target_frame_shape[0]), interpolation=cv2.INTER_NEAREST)
            mask = (mask_resized > 0.5).astype(bool)
        else:
            mask = mask.astype(bool)

        frame_masks[current_obj_id] = mask

        logger.info(f"YOLO Mask Conversion: {eye_view} eye detection -> Object {current_obj_id}, mask_shape={mask.shape}, pixels={np.sum(mask)}")

        obj_id += 1  # Always increment for next detection

    # Store masks in video segments format (single frame)
    video_segments[0] = frame_masks

    total_objects = len(frame_masks)
    total_pixels = sum(np.sum(mask) for mask in frame_masks.values())
    logger.info(f"YOLO Mask Conversion: Created video segments with {total_objects} objects, {total_pixels} total mask pixels")

    return video_segments
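# Shape of the returned structure (a sketch of the single-frame result built above;
# frame index 0 maps object IDs to boolean masks, mirroring SAM2's per-frame output):
#
#   video_segments = {
#       0: {
#           1: <bool ndarray>,  # LEFT-eye person mask
#           2: <bool ndarray>,  # RIGHT-eye person mask
#       }
#   }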
|
|
||||||
|
def save_debug_frame_with_detections(self, frame: np.ndarray, detections: List[Dict[str, Any]],
|
||||||
|
output_path: str, prompts: List[Dict[str, Any]] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Save a debug frame with YOLO detections and SAM2 prompts overlaid as bounding boxes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (BGR format from OpenCV)
|
||||||
|
detections: List of detection dictionaries with bbox and confidence
|
||||||
|
output_path: Path to save the debug image
|
||||||
|
prompts: Optional list of SAM2 prompt dictionaries with obj_id and bbox
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if saved successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
debug_frame = frame.copy()
|
||||||
|
|
||||||
|
# Draw masks (if available) or bounding boxes for each detection
|
||||||
|
for i, detection in enumerate(detections):
|
||||||
|
bbox = detection['bbox']
|
||||||
|
confidence = detection['confidence']
|
||||||
|
has_mask = detection.get('has_mask', False)
|
||||||
|
|
||||||
|
# Extract coordinates
|
||||||
|
x1, y1, x2, y2 = map(int, bbox)
|
||||||
|
|
||||||
|
# Choose color based on confidence (green for high, yellow for medium, red for low)
|
||||||
|
if confidence >= 0.8:
|
||||||
|
color = (0, 255, 0) # Green
|
||||||
|
elif confidence >= 0.6:
|
||||||
|
color = (0, 255, 255) # Yellow
|
||||||
|
else:
|
||||||
|
color = (0, 0, 255) # Red
|
||||||
|
|
||||||
|
if has_mask and 'mask' in detection:
|
||||||
|
# Draw segmentation mask
|
||||||
|
mask = detection['mask']
|
||||||
|
|
||||||
|
# Resize mask to match frame if needed
|
||||||
|
if mask.shape != debug_frame.shape[:2]:
|
||||||
|
mask = cv2.resize(mask.astype(np.float32), (debug_frame.shape[1], debug_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
mask = mask.astype(bool)
|
||||||
|
|
||||||
|
# Apply colored overlay with transparency
|
||||||
|
overlay = debug_frame.copy()
|
||||||
|
overlay[mask] = color
|
||||||
|
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
|
||||||
|
|
||||||
|
# Draw mask outline
|
||||||
|
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
cv2.drawContours(debug_frame, contours, -1, color, 2)
|
||||||
|
|
||||||
|
# Prepare label text for segmentation
|
||||||
|
label = f"Person {i+1}: {confidence:.2f} (MASK)"
|
||||||
|
else:
|
||||||
|
# Draw bounding box (detection mode or no mask available)
|
||||||
|
cv2.rectangle(debug_frame, (x1, y1), (x2, y2), color, 2)
|
||||||
|
|
||||||
|
# Prepare label text for detection
|
||||||
|
label = f"Person {i+1}: {confidence:.2f} (BBOX)"
|
||||||
|
|
||||||
|
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
|
||||||
|
|
||||||
|
# Draw label background
|
||||||
|
cv2.rectangle(debug_frame,
|
||||||
|
(x1, y1 - label_size[1] - 10),
|
||||||
|
(x1 + label_size[0], y1),
|
||||||
|
color, -1)
|
||||||
|
|
||||||
|
# Draw label text
|
||||||
|
cv2.putText(debug_frame, label,
|
||||||
|
(x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
||||||
|
(255, 255, 255), 2)
|
||||||
|
|
||||||
|
# Draw SAM2 prompts if provided (with different colors/style)
|
||||||
|
if prompts:
|
||||||
|
for prompt in prompts:
|
||||||
|
obj_id = prompt['obj_id']
|
||||||
|
bbox = prompt['bbox']
|
||||||
|
|
||||||
|
# Extract coordinates
|
||||||
|
x1, y1, x2, y2 = map(int, bbox)
|
||||||
|
|
||||||
|
# Use different colors for each object ID
|
||||||
|
if obj_id == 1:
|
||||||
|
prompt_color = (0, 255, 0) # Green for Object 1
|
||||||
|
elif obj_id == 2:
|
||||||
|
prompt_color = (255, 0, 0) # Blue for Object 2
|
||||||
|
else:
|
||||||
|
prompt_color = (255, 255, 0) # Cyan for others
|
||||||
|
|
||||||
|
# Draw thicker, dashed-style border for SAM2 prompts
|
||||||
|
thickness = 3
|
||||||
|
cv2.rectangle(debug_frame, (x1-2, y1-2), (x2+2, y2+2), prompt_color, thickness)
|
||||||
|
|
||||||
|
# Add SAM2 object ID label
|
||||||
|
sam_label = f"SAM2 Obj {obj_id}"
|
||||||
|
label_size = cv2.getTextSize(sam_label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
|
||||||
|
|
||||||
|
# Draw label background
|
||||||
|
cv2.rectangle(debug_frame,
|
||||||
|
(x1-2, y2+5),
|
||||||
|
(x1-2 + label_size[0], y2+5 + label_size[1] + 5),
|
||||||
|
prompt_color, -1)
|
||||||
|
|
||||||
|
# Draw label text
|
||||||
|
cv2.putText(debug_frame, sam_label,
|
||||||
|
(x1-2, y2+5 + label_size[1]),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
||||||
|
(255, 255, 255), 2)
|
||||||
|
|
||||||
|
# Draw VR180 SBS boundary line (center line separating left and right eye views)
|
||||||
|
frame_height, frame_width = debug_frame.shape[:2]
|
||||||
|
center_x = frame_width // 2
|
||||||
|
cv2.line(debug_frame, (center_x, 0), (center_x, frame_height), (0, 255, 255), 3) # Yellow line
|
||||||
|
|
||||||
|
# Add VR180 SBS labels
|
||||||
|
cv2.putText(debug_frame, "LEFT EYE", (10, frame_height - 20),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
|
||||||
|
cv2.putText(debug_frame, "RIGHT EYE", (center_x + 10, frame_height - 20),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
|
||||||
|
|
||||||
|
# Add summary text at top with mode information
|
||||||
|
mode_text = f"YOLO Mode: {self.mode.upper()}"
|
||||||
|
masks_available = sum(1 for d in detections if d.get('has_mask', False))
|
||||||
|
|
||||||
|
if self.supports_segmentation and masks_available > 0:
|
||||||
|
summary = f"VR180 SBS: {len(detections)} detections → {masks_available} MASKS (for SAM2 propagation)"
|
||||||
|
else:
|
||||||
|
summary = f"VR180 SBS: {len(detections)} detections → {len(prompts) if prompts else 0} SAM2 prompts"
|
||||||
|
|
||||||
|
cv2.putText(debug_frame, mode_text,
|
||||||
|
(10, 30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.8,
|
||||||
|
(0, 255, 255), 2) # Yellow for mode
|
||||||
|
cv2.putText(debug_frame, summary,
|
||||||
|
(10, 60),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
|
||||||
|
(255, 255, 255), 2)
|
||||||
|
|
||||||
|
# Add frame dimensions info
|
||||||
|
dims_info = f"Frame: {frame_width}x{frame_height}, Center: {center_x}"
|
||||||
|
cv2.putText(debug_frame, dims_info,
|
||||||
|
(10, 90),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
||||||
|
(255, 255, 255), 2)
|
||||||
|
|
||||||
|
# Save debug frame
|
||||||
|
success = cv2.imwrite(output_path, debug_frame)
|
||||||
|
if success:
|
||||||
|
logger.info(f"Saved YOLO debug frame to {output_path}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to save debug frame to {output_path}")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating debug frame: {e}")
|
||||||
|
return False
|
||||||
317
download_models.py
Executable file
@@ -0,0 +1,317 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Model download script for YOLO + SAM2 video processing pipeline.
|
||||||
|
Downloads SAM2.1 models and organizes them in the models directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def create_directory_structure():
|
||||||
|
"""Create the models directory structure."""
|
||||||
|
base_dir = Path(__file__).parent
|
||||||
|
models_dir = base_dir / "models"
|
||||||
|
|
||||||
|
# Create main models directory
|
||||||
|
models_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Create subdirectories
|
||||||
|
sam2_dir = models_dir / "sam2"
|
||||||
|
sam2_configs_dir = sam2_dir / "configs" / "sam2.1"
|
||||||
|
sam2_checkpoints_dir = sam2_dir / "checkpoints"
|
||||||
|
yolo_dir = models_dir / "yolo"
|
||||||
|
|
||||||
|
sam2_dir.mkdir(exist_ok=True)
|
||||||
|
sam2_configs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
sam2_checkpoints_dir.mkdir(exist_ok=True)
|
||||||
|
yolo_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Created models directory structure in: {models_dir}")
|
||||||
|
return models_dir, sam2_configs_dir, sam2_checkpoints_dir, yolo_dir
|
||||||
|
|
||||||
|
def download_file(url, destination, description="file"):
|
||||||
|
"""Download a file with progress indication."""
|
||||||
|
try:
|
||||||
|
print(f"Downloading {description}...")
|
||||||
|
print(f" URL: {url}")
|
||||||
|
print(f" Destination: {destination}")
|
||||||
|
|
||||||
|
def progress_hook(block_num, block_size, total_size):
|
||||||
|
if total_size > 0:
|
||||||
|
percent = min(100, (block_num * block_size * 100) // total_size)
|
||||||
|
sys.stdout.write(f"\r Progress: {percent}%")
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
urllib.request.urlretrieve(url, destination, progress_hook)
|
||||||
|
print(f"\n ✓ Downloaded {description}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except urllib.error.URLError as e:
|
||||||
|
print(f"\n ✗ Failed to download {description}: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n ✗ Error downloading {description}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_sam2_models():
|
||||||
|
"""Download SAM2.1 model configurations and checkpoints."""
|
||||||
|
print("Setting up SAM2.1 models...")
|
||||||
|
|
||||||
|
# Create directory structure
|
||||||
|
models_dir, configs_dir, checkpoints_dir, yolo_dir = create_directory_structure()
|
||||||
|
|
||||||
|
# SAM2.1 model definitions
|
||||||
|
sam2_models = {
|
||||||
|
"tiny": {
|
||||||
|
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
|
||||||
|
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
|
||||||
|
"config_file": "sam2.1_hiera_t.yaml",
|
||||||
|
"checkpoint_file": "sam2.1_hiera_tiny.pt"
|
||||||
|
},
|
||||||
|
"small": {
|
||||||
|
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
|
||||||
|
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
|
||||||
|
"config_file": "sam2.1_hiera_s.yaml",
|
||||||
|
"checkpoint_file": "sam2.1_hiera_small.pt"
|
||||||
|
},
|
||||||
|
"base_plus": {
|
||||||
|
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
|
||||||
|
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
|
||||||
|
"config_file": "sam2.1_hiera_b+.yaml",
|
||||||
|
"checkpoint_file": "sam2.1_hiera_base_plus.pt"
|
||||||
|
},
|
||||||
|
"large": {
|
||||||
|
"config_url": "https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
|
||||||
|
"checkpoint_url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
|
||||||
|
"config_file": "sam2.1_hiera_l.yaml",
|
||||||
|
"checkpoint_file": "sam2.1_hiera_large.pt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
total_downloads = len(sam2_models) * 2 # configs + checkpoints
|
||||||
|
|
||||||
|
# Download each model's config and checkpoint
|
||||||
|
for model_name, model_info in sam2_models.items():
|
||||||
|
print(f"\n--- Downloading SAM2.1 {model_name.upper()} model ---")
|
||||||
|
|
||||||
|
# Download config file
|
||||||
|
config_path = configs_dir / model_info["config_file"]
|
||||||
|
if not config_path.exists():
|
||||||
|
if download_file(
|
||||||
|
model_info["config_url"],
|
||||||
|
config_path,
|
||||||
|
f"SAM2.1 {model_name} config"
|
||||||
|
):
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
print(f" ✓ Config file already exists: {config_path}")
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
# Download checkpoint file
|
||||||
|
checkpoint_path = checkpoints_dir / model_info["checkpoint_file"]
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
if download_file(
|
||||||
|
model_info["checkpoint_url"],
|
||||||
|
checkpoint_path,
|
||||||
|
f"SAM2.1 {model_name} checkpoint"
|
||||||
|
):
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
print(f" ✓ Checkpoint file already exists: {checkpoint_path}")
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
print(f"\n=== Download Summary ===")
|
||||||
|
print(f"Successfully downloaded: {success_count}/{total_downloads} files")
|
||||||
|
|
||||||
|
if success_count == total_downloads:
|
||||||
|
print("✓ All SAM2.1 models downloaded successfully!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"⚠ Some downloads failed ({total_downloads - success_count} files)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_yolo_models():
|
||||||
|
"""Download default YOLO models to models directory."""
|
||||||
|
print("\n--- Setting up YOLO models ---")
|
||||||
|
print(" Downloading both detection and segmentation models...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Default YOLO models to download (both detection and segmentation)
|
||||||
|
yolo_models = [
|
||||||
|
"yolov8n.pt", # Detection models
|
||||||
|
"yolov8s.pt",
|
||||||
|
"yolov8m.pt",
|
||||||
|
"yolo11l.pt", # YOLOv11 detection models
|
||||||
|
"yolo11x.pt",
|
||||||
|
"yolov8n-seg.pt", # Segmentation models
|
||||||
|
"yolov8s-seg.pt",
|
||||||
|
"yolov8m-seg.pt",
|
||||||
|
"yolo11l-seg.pt", # YOLOv11 segmentation models
|
||||||
|
"yolo11x-seg.pt"
|
||||||
|
]
|
||||||
|
models_dir = Path(__file__).parent / "models" / "yolo"
|
||||||
|
|
||||||
|
for model_name in yolo_models:
|
||||||
|
model_path = models_dir / model_name
|
||||||
|
if not model_path.exists():
|
||||||
|
print(f"Downloading {model_name}...")
|
||||||
|
try:
|
||||||
|
# First try to download using the YOLO class with export
|
||||||
|
model = YOLO(model_name)
|
||||||
|
|
||||||
|
# Export/save the model to our directory
|
||||||
|
# The model.ckpt is the internal checkpoint
|
||||||
|
if hasattr(model, 'ckpt') and hasattr(model.ckpt, 'save'):
|
||||||
|
# Save the checkpoint directly
|
||||||
|
torch.save(model.ckpt, str(model_path))
|
||||||
|
print(f" ✓ Saved {model_name} to models directory")
|
||||||
|
else:
|
||||||
|
# Alternative: try to find where YOLO downloaded the model
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Common locations where YOLO might store models
|
||||||
|
possible_paths = [
|
||||||
|
Path.home() / ".cache" / "ultralytics" / "models" / model_name,
|
||||||
|
Path.home() / ".ultralytics" / "models" / model_name,
|
||||||
|
Path.home() / "runs" / "detect" / model_name,
|
||||||
|
Path.cwd() / model_name, # Current directory
|
||||||
|
]
|
||||||
|
|
||||||
|
found = False
|
||||||
|
for possible_path in possible_paths:
|
||||||
|
if possible_path.exists():
|
||||||
|
shutil.copy2(possible_path, model_path)
|
||||||
|
print(f" ✓ Copied {model_name} from {possible_path}")
|
||||||
|
found = True
|
||||||
|
# Clean up if it was downloaded to current directory
|
||||||
|
if possible_path.parent == Path.cwd() and possible_path != model_path:
|
||||||
|
possible_path.unlink()
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
# Last resort: use urllib to download directly
|
||||||
|
# Use different release versions for different YOLO versions
|
||||||
|
if model_name.startswith("yolo11"):
|
||||||
|
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
|
||||||
|
else:
|
||||||
|
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
|
||||||
|
print(f" Downloading directly from {yolo_url}...")
|
||||||
|
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠ Error downloading {model_name}: {e}")
|
||||||
|
# Try direct download as fallback
|
||||||
|
try:
|
||||||
|
# Use different release versions for different YOLO versions
|
||||||
|
if model_name.startswith("yolo11"):
|
||||||
|
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.3.0/{model_name}"
|
||||||
|
else:
|
||||||
|
yolo_url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model_name}"
|
||||||
|
print(f" Trying direct download from {yolo_url}...")
|
||||||
|
download_file(yolo_url, str(model_path), f"YOLO {model_name}")
|
||||||
|
except Exception as e2:
|
||||||
|
print(f" ✗ Failed to download {model_name}: {e2}")
|
||||||
|
else:
|
||||||
|
print(f" ✓ {model_name} already exists")
|
||||||
|
|
||||||
|
# Verify all models exist
|
||||||
|
success = all((models_dir / model).exists() for model in yolo_models)
|
||||||
|
if success:
|
||||||
|
print("✓ YOLO models setup complete!")
|
||||||
|
print(" Available detection models: yolov8n.pt, yolov8s.pt, yolov8m.pt, yolov11l.pt, yolov11x.pt")
|
||||||
|
print(" Available segmentation models: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt, yolov11l-seg.pt, yolov11x-seg.pt")
|
||||||
|
else:
|
||||||
|
missing_models = [model for model in yolo_models if not (models_dir / model).exists()]
|
||||||
|
print("⚠ Some YOLO models may be missing:")
|
||||||
|
for model in missing_models:
|
||||||
|
print(f" - {model}")
|
||||||
|
return success
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("⚠ ultralytics not installed. YOLO models will be downloaded on first use.")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠ Error setting up YOLO models: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_config_file():
|
||||||
|
"""Update config.yaml to use local model paths."""
|
||||||
|
print("\n--- Updating config.yaml ---")
|
||||||
|
|
||||||
|
config_path = Path(__file__).parent / "config.yaml"
|
||||||
|
if not config_path.exists():
|
||||||
|
print("⚠ config.yaml not found, skipping update")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read current config
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Update model paths to use local models.
# (The yolo_detection_model / yolo_segmentation_model entries already point at
# models/yolo/, so only the paths that actually change are rewritten here.)
updated_content = content.replace(
    'yolo_model: "yolov8n.pt"',
    'yolo_model: "models/yolo/yolov8n.pt"'
).replace(
    'sam2_checkpoint: "../checkpoints/sam2.1_hiera_large.pt"',
    'sam2_checkpoint: "models/sam2/checkpoints/sam2.1_hiera_large.pt"'
).replace(
    'sam2_config: "configs/sam2.1/sam2.1_hiera_l.yaml"',
    'sam2_config: "models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"'
)
|
||||||
|
|
||||||
|
# Write updated config
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
f.write(updated_content)
|
||||||
|
|
||||||
|
print("✓ Updated config.yaml to use local model paths")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠ Error updating config.yaml: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to download all models."""
|
||||||
|
print("🤖 YOLO + SAM2 Model Download Script")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
# Download SAM2 models
|
||||||
|
sam2_success = download_sam2_models()
|
||||||
|
|
||||||
|
# Download YOLO models
|
||||||
|
yolo_success = download_yolo_models()
|
||||||
|
|
||||||
|
# Update config file
|
||||||
|
config_success = update_config_file()
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("📋 Final Summary:")
|
||||||
|
print(f" SAM2 models: {'✓' if sam2_success else '⚠'}")
|
||||||
|
print(f" YOLO models: {'✓' if yolo_success else '⚠'}")
|
||||||
|
print(f" Config update: {'✓' if config_success else '⚠'}")
|
||||||
|
|
||||||
|
if sam2_success and config_success:
|
||||||
|
print("\n🎉 Setup complete! You can now run the pipeline with:")
|
||||||
|
print(" python main.py --config config.yaml")
|
||||||
|
else:
|
||||||
|
print("\n⚠ Some steps failed. Check the output above for details.")
|
||||||
|
|
||||||
|
print("\n📁 Models are organized in:")
|
||||||
|
print(f" {Path(__file__).parent / 'models'}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
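# Expected layout after a successful run (a sketch based on the directories this
# script creates; the exact files present depend on which downloads succeed):
#
#   models/
#     sam2/
#       configs/sam2.1/   sam2.1_hiera_{t,s,b+,l}.yaml
#       checkpoints/      sam2.1_hiera_{tiny,small,base_plus,large}.pt
#     yolo/               yolov8{n,s,m}.pt, yolo11{l,x}.pt and the matching *-seg.pt files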
532
main.py
@@ -8,6 +8,8 @@ and creating green screen masks with SAM2.
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
# Add project root to path
|
# Add project root to path
|
||||||
@@ -16,6 +18,9 @@ sys.path.append(os.path.dirname(__file__))
|
|||||||
from core.config_loader import ConfigLoader
|
from core.config_loader import ConfigLoader
|
||||||
from core.video_splitter import VideoSplitter
|
from core.video_splitter import VideoSplitter
|
||||||
from core.yolo_detector import YOLODetector
|
from core.yolo_detector import YOLODetector
|
||||||
|
from core.sam2_processor import SAM2Processor
|
||||||
|
from core.mask_processor import MaskProcessor
|
||||||
|
from core.video_assembler import VideoAssembler
|
||||||
from utils.logging_utils import setup_logging, get_logger
|
from utils.logging_utils import setup_logging, get_logger
|
||||||
from utils.file_utils import ensure_directory
|
from utils.file_utils import ensure_directory
|
||||||
from utils.status_utils import print_processing_status, cleanup_incomplete_segment
|
from utils.status_utils import print_processing_status, cleanup_incomplete_segment
|
||||||
@@ -66,6 +71,100 @@ def validate_dependencies():
|
|||||||
logger.error("Please install requirements: pip install -r requirements.txt")
|
logger.error("Please install requirements: pip install -r requirements.txt")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def create_yolo_mask_debug_frame(detections: List[dict], video_path: str, output_path: str, scale: float = 1.0) -> bool:
|
||||||
|
"""
|
||||||
|
Create debug visualization for YOLO direct masks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detections: List of YOLO detections with masks
|
||||||
|
video_path: Path to video file
|
||||||
|
output_path: Path to save debug image
|
||||||
|
scale: Scale factor for frame processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if debug frame was created successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load first frame
|
||||||
|
cap = cv2.VideoCapture(video_path)
|
||||||
|
ret, original_frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
logger.error("Could not read first frame for YOLO mask debug")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Scale frame if needed
|
||||||
|
if scale != 1.0:
|
||||||
|
original_frame = cv2.resize(original_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
debug_frame = original_frame.copy()
|
||||||
|
|
||||||
|
# Define colors for each object
|
||||||
|
colors = {
|
||||||
|
1: (0, 255, 0), # Green for Object 1 (Left eye)
|
||||||
|
2: (255, 0, 0), # Blue for Object 2 (Right eye)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get detections with masks
|
||||||
|
detections_with_masks = [d for d in detections if d.get('has_mask', False)]
|
||||||
|
|
||||||
|
# Overlay masks with transparency
|
||||||
|
obj_id = 1
|
||||||
|
for detection in detections_with_masks[:2]: # Up to 2 objects
|
||||||
|
mask = detection['mask']
|
||||||
|
|
||||||
|
# Resize mask to match frame if needed
|
||||||
|
if mask.shape != original_frame.shape[:2]:
|
||||||
|
mask = cv2.resize(mask.astype(np.float32), (original_frame.shape[1], original_frame.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
mask = mask.astype(bool)
|
||||||
|
|
||||||
|
# Apply colored overlay
|
||||||
|
color = colors.get(obj_id, (128, 128, 128))
|
||||||
|
overlay = debug_frame.copy()
|
||||||
|
overlay[mask] = color
|
||||||
|
|
||||||
|
# Blend with original (30% overlay, 70% original)
|
||||||
|
cv2.addWeighted(overlay, 0.3, debug_frame, 0.7, 0, debug_frame)
|
||||||
|
|
||||||
|
# Draw outline
|
||||||
|
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
cv2.drawContours(debug_frame, contours, -1, color, 2)
|
||||||
|
|
||||||
|
logger.info(f"YOLO Mask Debug: Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
|
||||||
|
obj_id += 1
|
||||||
|
|
||||||
|
# Add title and source info
|
||||||
|
title = f"YOLO Direct Masks: {len(detections_with_masks)} objects detected"
|
||||||
|
cv2.putText(debug_frame, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
|
||||||
|
|
||||||
|
source_info = "Mask Source: YOLO Segmentation (DIRECT - No SAM2)"
|
||||||
|
cv2.putText(debug_frame, source_info, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) # Green for YOLO
|
||||||
|
|
||||||
|
# Add object legend
|
||||||
|
y_offset = 90
|
||||||
|
for i, detection in enumerate(detections_with_masks[:2]):
|
||||||
|
obj_id = i + 1
|
||||||
|
color = colors.get(obj_id, (128, 128, 128))
|
||||||
|
text = f"Object {obj_id}: {'Left Eye' if obj_id == 1 else 'Right Eye'} (YOLO Mask)"
|
||||||
|
cv2.putText(debug_frame, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
|
||||||
|
y_offset += 30
|
||||||
|
|
||||||
|
# Save debug image
|
||||||
|
success = cv2.imwrite(output_path, debug_frame)
|
||||||
|
if success:
|
||||||
|
logger.info(f"YOLO Mask Debug: Saved debug frame to {output_path}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to save YOLO mask debug frame to {output_path}")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating YOLO mask debug frame: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
|
def resolve_detect_segments(detect_segments, total_segments: int) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Resolve detect_segments configuration to list of segment indices.
|
Resolve detect_segments configuration to list of segment indices.
|
||||||
@@ -157,31 +256,432 @@ def main():
|
|||||||
detect_segments_config = config.get_detect_segments()
|
detect_segments_config = config.get_detect_segments()
|
||||||
detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))
|
detect_segments = resolve_detect_segments(detect_segments_config, len(segments_info))
|
||||||
|
|
||||||
# Step 2: Run YOLO detection on specified segments
|
# Initialize processors once
|
||||||
logger.info("Step 2: Running YOLO human detection")
|
logger.info("Step 2: Initializing YOLO detector")
|
||||||
|
|
||||||
|
# Get YOLO mode and model paths
|
||||||
|
yolo_mode = config.get('models.yolo_mode', 'detection')
|
||||||
|
detection_model = config.get('models.yolo_detection_model', config.get_yolo_model_path())
|
||||||
|
segmentation_model = config.get('models.yolo_segmentation_model', None)
|
||||||
|
|
||||||
|
logger.info(f"YOLO Mode: {yolo_mode}")
|
||||||
|
|
||||||
detector = YOLODetector(
|
detector = YOLODetector(
|
||||||
model_path=config.get_yolo_model_path(),
|
detection_model_path=detection_model,
|
||||||
|
segmentation_model_path=segmentation_model,
|
||||||
|
mode=yolo_mode,
|
||||||
confidence_threshold=config.get_yolo_confidence(),
|
confidence_threshold=config.get_yolo_confidence(),
|
||||||
human_class_id=config.get_human_class_id()
|
human_class_id=config.get_human_class_id()
|
||||||
)
|
)
|
||||||
|
|
||||||
detection_results = detector.process_segments_batch(
|
logger.info("Step 3: Initializing SAM2 processor")
|
||||||
segments_info,
|
sam2_processor = SAM2Processor(
|
||||||
detect_segments,
|
checkpoint_path=config.get_sam2_checkpoint(),
|
||||||
scale=config.get_inference_scale()
|
config_path=config.get_sam2_config(),
|
||||||
|
vos_optimized=config.get('models.sam2_vos_optimized', False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log detection summary
|
# Initialize mask processor with quality enhancements
|
||||||
total_humans = sum(len(detections) for detections in detection_results.values())
|
mask_quality_config = config.get('mask_processing', {})
|
||||||
logger.info(f"Detected {total_humans} humans across {len(detection_results)} segments")
|
mask_processor = MaskProcessor(
|
||||||
|
green_color=config.get_green_color(),
|
||||||
|
blue_color=config.get_blue_color(),
|
||||||
|
mask_quality_config=mask_quality_config
|
||||||
|
)
|
||||||
|
|
||||||
# Step 3: Process segments with SAM2 (placeholder for now)
|
# Process each segment sequentially (YOLO -> SAM2 -> Render)
|
||||||
logger.info("Step 3: SAM2 processing and green screen generation")
|
logger.info("Step 4: Processing segments sequentially")
|
||||||
logger.info("SAM2 processing module not yet implemented - this is where segment processing would occur")
|
total_humans_detected = 0
|
||||||
|
|
||||||
# Step 4: Assemble final video (placeholder for now)
|
for i, segment_info in enumerate(segments_info):
|
||||||
logger.info("Step 4: Assembling final video with audio")
|
segment_idx = segment_info['index']
|
||||||
logger.info("Video assembly module not yet implemented - this is where concatenation and audio copying would occur")
|
|
||||||
|
logger.info(f"Processing segment {segment_idx}/{len(segments_info)-1}")
|
||||||
|
|
||||||
|
# Reset temporal history for new segment
|
||||||
|
mask_processor.reset_temporal_history()
|
||||||
|
|
||||||
|
# Skip if segment output already exists
|
||||||
|
output_video = os.path.join(segment_info['directory'], f"output_{segment_idx}.mp4")
|
||||||
|
if os.path.exists(output_video):
|
||||||
|
logger.info(f"Segment {segment_idx} already processed, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine if we should use YOLO detections or previous masks
|
||||||
|
use_detections = segment_idx in detect_segments
|
||||||
|
|
||||||
|
# First segment must use detections
|
||||||
|
if segment_idx == 0 and not use_detections:
|
||||||
|
logger.warning(f"First segment must use YOLO detection")
|
||||||
|
use_detections = True
|
||||||
|
|
||||||
|
# Get YOLO prompts or previous masks
|
||||||
|
yolo_prompts = None
|
||||||
|
previous_masks = None
|
||||||
|
|
||||||
|
if use_detections:
|
||||||
|
# Run YOLO detection on current segment
|
||||||
|
logger.info(f"Running YOLO detection on segment {segment_idx}")
|
||||||
|
detection_file = os.path.join(segment_info['directory'], "yolo_detections")
|
||||||
|
|
||||||
|
# Check if detection already exists
|
||||||
|
if os.path.exists(detection_file):
|
||||||
|
logger.info(f"Loading existing YOLO detections for segment {segment_idx}")
|
||||||
|
detections = detector.load_detections_from_file(detection_file)
|
||||||
|
else:
|
||||||
|
# Run YOLO detection on first frame
|
||||||
|
detections = detector.detect_humans_in_video_first_frame(
|
||||||
|
segment_info['video_file'],
|
||||||
|
scale=config.get_inference_scale()
|
||||||
|
)
|
||||||
|
# Save detections for future runs
|
||||||
|
detector.save_detections_to_file(detections, detection_file)
|
||||||
|
|
||||||
|
if detections:
|
||||||
|
total_humans_detected += len(detections)
|
||||||
|
logger.info(f"Found {len(detections)} humans in segment {segment_idx}")
|
||||||
|
|
||||||
|
# Get frame width from video
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
yolo_prompts = detector.convert_detections_to_sam2_prompts(
|
||||||
|
detections, frame_width
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no right eye detections found, run debug analysis with lower confidence
|
||||||
|
half_frame_width = frame_width // 2
|
||||||
|
right_eye_detections = [d for d in detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
|
||||||
|
|
||||||
|
if len(right_eye_detections) == 0 and config.get('advanced.save_yolo_debug_frames', False):
|
||||||
|
logger.info(f"VR180 Debug: No right eye detections found, running lower confidence analysis...")
|
||||||
|
|
||||||
|
# Load first frame for debug analysis
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
ret, debug_frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if ret:
|
||||||
|
# Scale frame to match detection scale
|
||||||
|
if config.get_inference_scale() != 1.0:
|
||||||
|
scale = config.get_inference_scale()
|
||||||
|
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# Run debug detection with lower confidence
|
||||||
|
debug_detections = detector.debug_detect_with_lower_confidence(debug_frame, debug_confidence=0.3)
|
||||||
|
|
||||||
|
# Analyze where these lower confidence detections are
|
||||||
|
debug_right_eye = [d for d in debug_detections if (d['bbox'][0] + d['bbox'][2]) / 2 >= half_frame_width]
|
||||||
|
|
||||||
|
if len(debug_right_eye) > 0:
|
||||||
|
logger.warning(f"VR180 Debug: Found {len(debug_right_eye)} right eye detections with lower confidence!")
|
||||||
|
for i, det in enumerate(debug_right_eye):
|
||||||
|
logger.warning(f"VR180 Debug: Right eye detection {i+1}: conf={det['confidence']:.3f}, bbox={det['bbox']}")
|
||||||
|
logger.warning(f"VR180 Debug: Consider lowering yolo_confidence from {config.get_yolo_confidence()} to 0.3-0.4")
|
||||||
|
else:
|
||||||
|
logger.info(f"VR180 Debug: No right eye detections found even with confidence 0.3")
|
||||||
|
logger.info(f"VR180 Debug: This confirms person is not visible in right eye view")
|
||||||
|
|
||||||
|
logger.info(f"Pipeline Debug: Segment {segment_idx} - Generated {len(yolo_prompts)} SAM2 prompts from {len(detections)} YOLO detections")
|
||||||
|
|
||||||
|
# Save debug frame with detections visualized (if enabled)
|
||||||
|
if config.get('advanced.save_yolo_debug_frames', False):
|
||||||
|
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug.jpg")
|
||||||
|
|
||||||
|
# Load first frame for debug visualization
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
ret, debug_frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if ret:
|
||||||
|
# Scale frame to match detection scale
|
||||||
|
if config.get_inference_scale() != 1.0:
|
||||||
|
scale = config.get_inference_scale()
|
||||||
|
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
detector.save_debug_frame_with_detections(debug_frame, detections, debug_frame_path, yolo_prompts)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not load frame for debug visualization in segment {segment_idx}")
|
||||||
|
|
||||||
|
# Check if we have YOLO masks for debug visualization
|
||||||
|
has_yolo_masks = False
|
||||||
|
if detections and detector.supports_segmentation:
|
||||||
|
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
|
||||||
|
|
||||||
|
# Generate first frame masks debug (SAM2 or YOLO)
|
||||||
|
first_frame_debug_path = os.path.join(segment_info['directory'], "first_frame_detection.jpg")
|
||||||
|
|
||||||
|
if has_yolo_masks:
|
||||||
|
logger.info(f"Pipeline Debug: Generating YOLO first frame masks for segment {segment_idx}")
|
||||||
|
# Create YOLO mask debug visualization
|
||||||
|
create_yolo_mask_debug_frame(detections, segment_info['video_file'], first_frame_debug_path, config.get_inference_scale())
|
||||||
|
else:
|
||||||
|
logger.info(f"Pipeline Debug: Generating SAM2 first frame masks for segment {segment_idx}")
|
||||||
|
sam2_processor.generate_first_frame_debug_masks(
|
||||||
|
segment_info['video_file'],
|
||||||
|
yolo_prompts,
|
||||||
|
first_frame_debug_path,
|
||||||
|
config.get_inference_scale()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"No humans detected in segment {segment_idx}")
|
||||||
|
|
||||||
|
# Save debug frame even when no detections (if enabled)
|
||||||
|
if config.get('advanced.save_yolo_debug_frames', False):
|
||||||
|
debug_frame_path = os.path.join(segment_info['directory'], "yolo_debug_no_detections.jpg")
|
||||||
|
|
||||||
|
# Load first frame for debug visualization
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
ret, debug_frame = cap.read()
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
if ret:
|
||||||
|
# Scale frame to match detection scale
|
||||||
|
if config.get_inference_scale() != 1.0:
|
||||||
|
scale = config.get_inference_scale()
|
||||||
|
debug_frame = cv2.resize(debug_frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# Add "No detections" text overlay
|
||||||
|
cv2.putText(debug_frame, "YOLO: No humans detected",
|
||||||
|
(10, 30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 1.0,
|
||||||
|
(0, 0, 255), 2) # Red text
|
||||||
|
|
||||||
|
cv2.imwrite(debug_frame_path, debug_frame)
|
||||||
|
logger.info(f"Saved no-detection debug frame to {debug_frame_path}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not load frame for no-detection debug visualization in segment {segment_idx}")
|
||||||
|
elif segment_idx > 0:
|
||||||
|
# Try to load previous segment mask
|
||||||
|
for j in range(segment_idx - 1, -1, -1):
|
||||||
|
prev_segment_dir = segments_info[j]['directory']
|
||||||
|
previous_masks = sam2_processor.load_previous_segment_mask(prev_segment_dir)
|
||||||
|
if previous_masks:
|
||||||
|
logger.info(f"Using masks from segment {j} for segment {segment_idx}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not yolo_prompts and not previous_masks:
|
||||||
|
logger.error(f"No prompts or previous masks available for segment {segment_idx}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if we have YOLO masks and can skip SAM2 (recheck in case detections were loaded from file)
|
||||||
|
if 'has_yolo_masks' not in locals():
|
||||||
|
has_yolo_masks = False
|
||||||
|
if detections and detector.supports_segmentation:
|
||||||
|
has_yolo_masks = any(d.get('has_mask', False) for d in detections)
|
||||||
|
|
||||||
|
if has_yolo_masks:
|
||||||
|
logger.info(f"Pipeline Debug: YOLO segmentation provided masks - using as SAM2 initial masks for segment {segment_idx}")
|
||||||
|
|
||||||
|
# Convert YOLO masks to initial masks for SAM2
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
# Convert YOLO masks to the format expected by SAM2 add_previous_masks_to_predictor
|
||||||
|
yolo_masks_dict = {}
|
||||||
|
for i, detection in enumerate(detections[:2]): # Up to 2 objects
|
||||||
|
if detection.get('has_mask', False):
|
||||||
|
mask = detection['mask']
|
||||||
|
# Resize mask to match inference scale
|
||||||
|
if config.get_inference_scale() != 1.0:
|
||||||
|
scale = config.get_inference_scale()
|
||||||
|
scaled_height = int(frame_height * scale)
|
||||||
|
scaled_width = int(frame_width * scale)
|
||||||
|
mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
obj_id = i + 1 # Sequential object IDs
|
||||||
|
yolo_masks_dict[obj_id] = mask.astype(bool)
|
||||||
|
logger.info(f"Pipeline Debug: YOLO mask for Object {obj_id} - shape: {mask.shape}, pixels: {np.sum(mask)}")
|
||||||
|
|
||||||
|
logger.info(f"Pipeline Debug: Using YOLO masks as SAM2 initial masks - {len(yolo_masks_dict)} objects")
|
||||||
|
|
||||||
|
# Use traditional SAM2 pipeline with YOLO masks as initial masks
|
||||||
|
previous_masks = yolo_masks_dict
|
||||||
|
yolo_prompts = None # Don't use bounding box prompts when we have masks
|
||||||
|
|
||||||
|
# Debug what we're passing to SAM2
|
||||||
|
if yolo_prompts:
|
||||||
|
logger.info(f"Pipeline Debug: Passing {len(yolo_prompts)} YOLO prompts to SAM2 for segment {segment_idx}")
|
||||||
|
for i, prompt in enumerate(yolo_prompts):
|
||||||
|
logger.info(f"Pipeline Debug: Prompt {i+1}: Object {prompt['obj_id']}, bbox={prompt['bbox']}")
|
||||||
|
|
||||||
|
if previous_masks:
|
||||||
|
logger.info(f"Pipeline Debug: Using {len(previous_masks)} previous masks for segment {segment_idx}")
|
||||||
|
logger.info(f"Pipeline Debug: Previous mask object IDs: {list(previous_masks.keys())}")
|
||||||
|
|
||||||
|
# Handle mid-segment detection if enabled (works for both detection and segmentation modes)
|
||||||
|
multi_frame_prompts = None
|
||||||
|
if config.get('advanced.enable_mid_segment_detection', False) and (yolo_prompts or has_yolo_masks):
|
||||||
|
logger.info(f"Mid-segment Detection: Enabled for segment {segment_idx}")
|
||||||
|
|
||||||
|
# Calculate frame indices for re-detection
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
redetection_interval = config.get('advanced.redetection_interval', 30)
|
||||||
|
max_redetections = config.get('advanced.max_redetections_per_segment', 10)
|
||||||
|
|
||||||
|
# Generate frame indices: [30, 60, 90, ...] (skip frame 0 since we already have first frame prompts)
|
||||||
|
frame_indices = []
|
||||||
|
frame_idx = redetection_interval
|
||||||
|
while frame_idx < total_frames and len(frame_indices) < max_redetections:
|
||||||
|
frame_indices.append(frame_idx)
|
||||||
|
frame_idx += redetection_interval
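# Illustrative example (hypothetical numbers): for a 150-frame segment with
# redetection_interval=30 and max_redetections=10, this loop produces
# frame_indices == [30, 60, 90, 120]; frame 0 is skipped because first-frame
# prompts already exist.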
|
||||||
|
|
||||||
|
if frame_indices:
|
||||||
|
logger.info(f"Mid-segment Detection: Running YOLO on frames {frame_indices} (interval={redetection_interval})")
|
||||||
|
|
||||||
|
# Run multi-frame detection
|
||||||
|
multi_frame_detections = detector.detect_humans_multi_frame(
|
||||||
|
segment_info['video_file'],
|
||||||
|
frame_indices,
|
||||||
|
scale=config.get_inference_scale()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert detections to SAM2 prompts (different handling for segmentation vs detection mode)
|
||||||
|
multi_frame_prompts = {}
|
||||||
|
cap = cv2.VideoCapture(segment_info['video_file'])
|
||||||
|
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
for frame_idx, detections in multi_frame_detections.items():
|
||||||
|
if detections:
|
||||||
|
if has_yolo_masks:
|
||||||
|
# Segmentation mode: convert YOLO masks to SAM2 mask prompts
|
||||||
|
frame_masks = {}
|
||||||
|
for i, detection in enumerate(detections[:2]): # Up to 2 objects
|
||||||
|
if detection.get('has_mask', False):
|
||||||
|
mask = detection['mask']
|
||||||
|
# Resize mask to match inference scale
|
||||||
|
if config.get_inference_scale() != 1.0:
|
||||||
|
scale = config.get_inference_scale()
|
||||||
|
scaled_height = int(frame_height * scale)
|
||||||
|
scaled_width = int(frame_width * scale)
|
||||||
|
mask = cv2.resize(mask.astype(np.float32), (scaled_width, scaled_height), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
obj_id = i + 1 # Sequential object IDs
|
||||||
|
frame_masks[obj_id] = mask.astype(bool)
|
||||||
|
logger.debug(f"Mid-segment Detection: Frame {frame_idx}, Object {obj_id} mask - shape: {mask.shape}, pixels: {np.sum(mask)}")
|
||||||
|
|
||||||
|
if frame_masks:
|
||||||
|
# Store as mask prompts (different format than bbox prompts)
|
||||||
|
multi_frame_prompts[frame_idx] = {'masks': frame_masks}
|
||||||
|
logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(frame_masks)} YOLO masks")
|
||||||
|
else:
|
||||||
|
# Detection mode: convert to bounding box prompts (existing logic)
|
||||||
|
prompts = detector.convert_detections_to_sam2_prompts(detections, frame_width)
|
||||||
|
multi_frame_prompts[frame_idx] = prompts
|
||||||
|
logger.info(f"Mid-segment Detection: Frame {frame_idx} -> {len(prompts)} SAM2 prompts")
|
||||||
|
|
||||||
|
logger.info(f"Mid-segment Detection: Generated prompts for {len(multi_frame_prompts)} frames")
|
||||||
|
else:
|
||||||
|
logger.info(f"Mid-segment Detection: No additional frames to process (segment has {total_frames} frames)")
|
||||||
|
elif config.get('advanced.enable_mid_segment_detection', False):
|
||||||
|
logger.info(f"Mid-segment Detection: Skipped for segment {segment_idx} (no initial YOLO data)")
|
||||||
|
|
||||||
|
# Process segment with SAM2
|
||||||
|
logger.info(f"Pipeline Debug: Starting SAM2 processing for segment {segment_idx}")
|
||||||
|
video_segments = sam2_processor.process_single_segment(
|
||||||
|
segment_info,
|
||||||
|
yolo_prompts=yolo_prompts,
|
||||||
|
previous_masks=previous_masks,
|
||||||
|
inference_scale=config.get_inference_scale(),
|
||||||
|
multi_frame_prompts=multi_frame_prompts
|
||||||
|
)
|
||||||
|
|
||||||
|
if video_segments is None:
|
||||||
|
logger.error(f"SAM2 processing failed for segment {segment_idx}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if SAM2 produced adequate results
|
||||||
|
if len(video_segments) == 0:
|
||||||
|
logger.error(f"SAM2 produced no frames for segment {segment_idx}")
|
||||||
|
continue
|
||||||
|
elif len(video_segments) < 10: # Expected many frames for a 5-second segment
|
||||||
|
logger.warning(f"SAM2 produced very few frames ({len(video_segments)}) for segment {segment_idx} - this may indicate propagation failure")
|
||||||
|
|
||||||
|
# Debug what SAM2 produced
|
||||||
|
logger.info(f"Pipeline Debug: SAM2 completed for segment {segment_idx}")
|
||||||
|
logger.info(f"Pipeline Debug: Generated masks for {len(video_segments)} frames")
|
||||||
|
|
||||||
|
if video_segments:
|
||||||
|
# Check first frame to see what objects were tracked
|
||||||
|
first_frame_idx = min(video_segments.keys())
|
||||||
|
first_frame_objects = video_segments[first_frame_idx]
|
||||||
|
logger.info(f"Pipeline Debug: First frame contains {len(first_frame_objects)} tracked objects")
|
||||||
|
logger.info(f"Pipeline Debug: Tracked object IDs: {list(first_frame_objects.keys())}")
|
||||||
|
|
||||||
|
for obj_id, mask in first_frame_objects.items():
|
||||||
|
mask_pixels = np.sum(mask)
|
||||||
|
logger.info(f"Pipeline Debug: Object {obj_id} mask has {mask_pixels} pixels")
|
||||||
|
|
||||||
|
# Check last frame as well
|
||||||
|
last_frame_idx = max(video_segments.keys())
|
||||||
|
last_frame_objects = video_segments[last_frame_idx]
|
||||||
|
logger.info(f"Pipeline Debug: Last frame contains {len(last_frame_objects)} tracked objects")
|
||||||
|
logger.info(f"Pipeline Debug: Final object IDs: {list(last_frame_objects.keys())}")
|
||||||
|
|
||||||
|
# Save final masks for next segment
|
||||||
|
mask_path = os.path.join(segment_info['directory'], "mask.png")
|
||||||
|
sam2_processor.save_final_masks(
|
||||||
|
video_segments,
|
||||||
|
mask_path,
|
||||||
|
green_color=config.get_green_color(),
|
||||||
|
blue_color=config.get_blue_color()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply green screen and save output video
|
||||||
|
success = mask_processor.process_segment(
|
||||||
|
segment_info,
|
||||||
|
video_segments,
|
||||||
|
use_nvenc=config.get_use_nvenc(),
|
||||||
|
bitrate=config.get_output_bitrate()
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"Successfully processed segment {segment_idx}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to create green screen video for segment {segment_idx}")
|
||||||
|
|
||||||
|
# Log processing summary
|
||||||
|
logger.info(f"Sequential processing complete. Total humans detected: {total_humans_detected}")
|
||||||
|
|
||||||
|
# Step 3: Assemble final video
|
||||||
|
logger.info("Step 3: Assembling final video with audio")
|
||||||
|
|
||||||
|
# Initialize video assembler
|
||||||
|
assembler = VideoAssembler(
|
||||||
|
preserve_audio=config.get_preserve_audio(),
|
||||||
|
use_nvenc=config.get_use_nvenc()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify all segments are complete
|
||||||
|
all_complete, missing = assembler.verify_segment_completeness(segments_dir)
|
||||||
|
|
||||||
|
if not all_complete:
|
||||||
|
logger.error(f"Cannot assemble video - missing segments: {missing}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Assemble final video
|
||||||
|
final_output = os.path.join(output_dir, config.get_output_filename())
|
||||||
|
|
||||||
|
success = assembler.assemble_final_video(
|
||||||
|
segments_dir,
|
||||||
|
input_video,
|
||||||
|
final_output,
|
||||||
|
bitrate=config.get_output_bitrate()
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"Final video saved to: {final_output}")
|
||||||
|
|
||||||
logger.info("Pipeline completed successfully")
|
logger.info("Pipeline completed successfully")
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ opencv-python>=4.8.0
|
|||||||
numpy>=1.24.0
|
numpy>=1.24.0
|
||||||
|
|
||||||
# SAM2 - Segment Anything Model 2
|
# SAM2 - Segment Anything Model 2
|
||||||
|
# Note: Make sure to run download_models.py after installing to get model weights
|
||||||
git+https://github.com/facebookresearch/sam2.git
|
git+https://github.com/facebookresearch/sam2.git
|
||||||
|
|
||||||
# GPU acceleration (optional but recommended)
|
# GPU acceleration (optional but recommended)
|
||||||
@@ -17,6 +18,8 @@ tqdm>=4.65.0
|
|||||||
matplotlib>=3.7.0
|
matplotlib>=3.7.0
|
||||||
Pillow>=10.0.0
|
Pillow>=10.0.0
|
||||||
|
|
||||||
|
decord
|
||||||
|
|
||||||
# Optional: For advanced features
|
# Optional: For advanced features
|
||||||
psutil>=5.9.0 # Memory monitoring
|
psutil>=5.9.0 # Memory monitoring
|
||||||
pympler>=0.9 # Memory profiling (for debugging)
|
pympler>=0.9 # Memory profiling (for debugging)
|
||||||
@@ -27,4 +30,4 @@ ffmpeg-python>=0.2.0 # Python wrapper for FFmpeg (optional, shell ffmpeg still
|
|||||||
# Development dependencies (optional)
|
# Development dependencies (optional)
|
||||||
pytest>=7.0.0
|
pytest>=7.0.0
|
||||||
black>=23.0.0
|
black>=23.0.0
|
||||||
flake8>=6.0.0
|
flake8>=6.0.0
|
||||||
|
|||||||
620
spec.md
@@ -189,4 +189,622 @@ models:
|
|||||||
### Model Improvements
|
### Model Improvements
|
||||||
- **Fine-tuned YOLO**: Domain-specific human detection models
|
- **Fine-tuned YOLO**: Domain-specific human detection models
|
||||||
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
|
- **SAM2 Optimization**: Custom SAM2 checkpoints for video content
|
||||||
- **Temporal Consistency**: Enhanced cross-segment mask propagation
|
- **Temporal Consistency**: Enhanced cross-segment mask propagation
|
||||||
|
|
||||||
|
|
||||||
|
Here is the original monolithic script this repo is a refactor/modularization of. If something
|
||||||
|
doesn't work in this repo, then consult the following script becasue it does work so this can
|
||||||
|
be used to solve problems:
|
||||||
|
|
||||||
|
|
||||||
import os
import cv2
import numpy as np
import cupy as cp
from concurrent.futures import ThreadPoolExecutor
import torch
import logging
import sys
import gc
from sam2.build_sam import build_sam2_video_predictor
import argparse
from ultralytics import YOLO

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Variables for input and output directories
SAM2_CHECKPOINT = "../checkpoints/sam2.1_hiera_large.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GREEN = [0, 255, 0]
BLUE = [255, 0, 0]

INFERENCE_SCALE = 0.50
FULL_SCALE = 1.0

# YOLO model for human detection (class 0 = person)
YOLO_MODEL_PATH = "yolov8n.pt"  # You can change this to a custom model
YOLO_CONFIDENCE = 0.6
HUMAN_CLASS_ID = 0  # COCO class ID for person


def open_video(video_path):
    """
    Opens a video file and returns a generator that yields frames.

    Parameters:
    - video_path: Path to the video file.

    Returns:
    - A generator that yields frames from the video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        return
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        yield frame
    cap.release()


def load_previous_segment_mask(prev_segment_dir):
    mask_path = os.path.join(prev_segment_dir, "mask.png")
    mask_image = cv2.imread(mask_path)

    if mask_image is None:
        raise FileNotFoundError(f"Mask image not found at {mask_path}")

    # Ensure the mask_image has three color channels
    if len(mask_image.shape) != 3 or mask_image.shape[2] != 3:
        raise ValueError("Mask image does not have three color channels.")

    mask_image = mask_image.astype(np.uint8)

    # Extract Object A and Object B masks
    mask_a = np.all(mask_image == GREEN, axis=2)
    mask_b = np.all(mask_image == BLUE, axis=2)

    per_obj_input_mask = {1: mask_a, 2: mask_b}
    input_palette = None  # No palette needed for binary mask

    return per_obj_input_mask, input_palette


def apply_green_mask(frame, masks):
    # Convert frame and masks to CuPy arrays
    frame_gpu = cp.asarray(frame)
    combined_mask = cp.zeros(frame_gpu.shape[:2], dtype=cp.bool_)

    for mask in masks:
        mask_gpu = cp.asarray(mask.squeeze())
        if mask_gpu.shape != frame_gpu.shape[:2]:
            resized_mask = cv2.resize(cp.asnumpy(mask_gpu).astype(cp.float32),
                                      (frame_gpu.shape[1], frame_gpu.shape[0]))
            mask_gpu = cp.asarray(resized_mask > 0.5)  # Convert back to CuPy boolean array
        else:
            mask_gpu = mask_gpu.astype(cp.bool_)  # Ensure boolean type
        combined_mask |= mask_gpu  # Perform the bitwise OR operation

    green_background = cp.full(frame_gpu.shape, cp.array([0, 255, 0], dtype=cp.uint8), dtype=cp.uint8)
    result_frame = cp.where(combined_mask[..., None], frame_gpu, green_background)
    return cp.asnumpy(result_frame)  # Convert back to NumPy


def initialize_predictor():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print(
            "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
            "give numerically different outputs and sometimes degraded performance on MPS."
        )
        # Enable MPS fallback for operations not supported on MPS
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    return predictor


def load_first_frame(video_path, scale=1.0):
    """
    Opens a video file and returns the first frame, scaled as specified.

    Parameters:
    - video_path: Path to the video file.
    - scale: Scaling factor for the frame (default is 1.0 for original size).

    Returns:
    - first_frame: The first frame of the video, scaled accordingly.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Error: Could not open video file {video_path}")
        return None

    ret, frame = cap.read()
    cap.release()

    if not ret:
        logger.error(f"Error: Could not read frame from video file {video_path}")
        return None

    if scale != 1.0:
        frame = cv2.resize(
            frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
        )

    return frame


def detect_humans_with_yolo(frame, yolo_model, confidence_threshold=YOLO_CONFIDENCE):
    """
    Detect humans in a frame using YOLO model.

    Parameters:
    - frame: Input frame (BGR format)
    - yolo_model: Loaded YOLO model
    - confidence_threshold: Detection confidence threshold

    Returns:
    - human_boxes: List of bounding boxes for detected humans
    """
    # Run YOLO detection
    results = yolo_model(frame, conf=confidence_threshold, verbose=False)

    human_boxes = []

    # Process results
    for result in results:
        boxes = result.boxes
        if boxes is not None:
            for box in boxes:
                # Get class ID
                cls = int(box.cls.cpu().numpy()[0])

                # Check if it's a person (class 0 in COCO)
                if cls == HUMAN_CLASS_ID:
                    # Get bounding box coordinates (x1, y1, x2, y2)
                    coords = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf.cpu().numpy()[0])

                    human_boxes.append({
                        'bbox': coords,
                        'confidence': conf
                    })

                    logger.info(f"Detected human with confidence {conf:.2f} at {coords}")

    return human_boxes


def add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width):
    """
    Add YOLO human detections as bounding boxes to SAM2 predictor.
    For stereo videos, creates two objects (left and right humans).

    Parameters:
    - predictor: SAM2 video predictor
    - inference_state: SAM2 inference state
    - human_detections: List of human detection results
    - frame_width: Width of the frame for stereo splitting

    Returns:
    - out_mask_logits: SAM2 output mask logits
    """
    half_frame_width = frame_width // 2

    # Sort detections by x-coordinate to get left and right humans
    human_detections.sort(key=lambda x: x['bbox'][0])  # Sort by x1 coordinate

    obj_id = 1
    out_mask_logits = None

    for i, detection in enumerate(human_detections[:2]):  # Take up to 2 humans (left and right)
        bbox = detection['bbox']

        # For stereo videos, assign obj_id based on position
        if len(human_detections) >= 2:
            # If we have multiple humans, assign based on left/right position
            center_x = (bbox[0] + bbox[2]) / 2
            if center_x < half_frame_width:
                current_obj_id = 1  # Left human
            else:
                current_obj_id = 2  # Right human
        else:
            # If only one human, duplicate for both sides (as in original stereo logic)
            current_obj_id = obj_id
            obj_id += 1

            # Also add the mirrored version for stereo
            if obj_id <= 2:
                mirrored_bbox = bbox.copy()
                mirrored_bbox[0] += half_frame_width  # Shift x1
                mirrored_bbox[2] += half_frame_width  # Shift x2

                # Ensure mirrored bbox is within frame bounds
                mirrored_bbox[0] = max(0, min(mirrored_bbox[0], frame_width - 1))
                mirrored_bbox[2] = max(0, min(mirrored_bbox[2], frame_width - 1))

                try:
                    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                        inference_state=inference_state,
                        frame_idx=0,
                        obj_id=obj_id,
                        box=mirrored_bbox.astype(np.float32),
                    )
                    logger.info(f"Added mirrored human detection for Object {obj_id}")
                    obj_id += 1
                except Exception as e:
                    logger.error(f"Error adding mirrored human detection for Object {obj_id}: {e}")

        try:
            _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=current_obj_id,
                box=bbox.astype(np.float32),
            )
            logger.info(f"Added human detection for Object {current_obj_id}")
        except Exception as e:
            logger.error(f"Error adding human detection for Object {current_obj_id}: {e}")

    return out_mask_logits


def propagate_masks(predictor, inference_state):
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments


def apply_colored_mask(frame, masks_a, masks_b):
    colored_mask = np.zeros_like(frame)

    # Apply colors to the masks
    for mask in masks_a:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [0, 255, 0]  # Green for Object A

    for mask in masks_b:
        mask = mask.squeeze()
        if mask.shape != frame.shape[:2]:
            mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
        indices = np.where(mask)
        colored_mask[mask] = [255, 0, 0]  # Blue for Object B

    return colored_mask


def process_and_save_output_video(video_path, output_video_path, video_segments, use_nvenc=False):
    """
    Process high-resolution frames, apply upscaled masks, and save the output video.
    """
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94

    # Setup VideoWriter with desired settings
    if use_nvenc:
        # Use FFmpeg with NVENC offloading for H.265 encoding
        import subprocess

        if sys.platform == 'darwin':
            encoder = 'hevc_videotoolbox'
        else:
            encoder = 'hevc_nvenc'

        command = [
            'ffmpeg',
            '-y',  # Overwrite output file if it exists
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', 'bgr24',
            '-s', f'{frame_width}x{frame_height}',
            '-r', str(fps),
            '-i', '-',  # Input from stdin
            '-an',  # No audio
            '-vcodec', encoder,
            '-pix_fmt', 'nv12',
            '-preset', 'slow',
            '-b:v', '50M',
            output_video_path
        ]
        process = subprocess.Popen(command, stdin=subprocess.PIPE)
    else:
        # Use OpenCV VideoWriter
        fourcc = cv2.VideoWriter_fourcc(*'HEVC')  # H.265
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret or frame_idx >= len(video_segments):
            break

        masks = [video_segments[frame_idx][out_obj_id] for out_obj_id in video_segments[frame_idx]]
        upscaled_masks = []

        for mask in masks:
            mask = mask.squeeze()
            upscaled_mask = cv2.resize(mask.astype(np.uint8), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)
            upscaled_masks.append(upscaled_mask)

        result_frame = apply_green_mask(frame, upscaled_masks)

        # Write frame to output
        if use_nvenc:
            process.stdin.write(result_frame.tobytes())
        else:
            out.write(result_frame)

        frame_idx += 1

    cap.release()
    if use_nvenc:
        process.stdin.close()
        process.wait()
    else:
        out.release()


def get_video_file_name(index):
    return f"segment_{str(index).zfill(3)}.mp4"


def do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=1.0, yolo_model_path=YOLO_MODEL_PATH):
    """
    Run YOLO detection on specified segments and save detection results.
    """
    logger.info("Running YOLO detection on requested segments.")

    # Load YOLO model
    yolo_model = YOLO(yolo_model_path)

    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        detection_file = os.path.join(segment_dir, "yolo_detections")
        video_file = os.path.join(segment_dir, get_video_file_name(i))

        if segment_index in detect_segments and not os.path.exists(detection_file):
            first_frame = load_first_frame(video_file, scale)
            if first_frame is None:
                continue

            # YOLO accepts the BGR frame directly, so no color conversion is needed
            human_detections = detect_humans_with_yolo(first_frame, yolo_model)

            if human_detections:
                # Save detection results
                with open(detection_file, 'w') as f:
                    f.write("# YOLO Human Detections\n")
                    for detection in human_detections:
                        bbox = detection['bbox']
                        conf = detection['confidence']
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{conf}\n")
                logger.info(f"Saved {len(human_detections)} human detections for segment {segment}")
            else:
                logger.warning(f"No humans detected in segment {segment}")
                # Create empty file to mark as processed
                with open(detection_file, 'w') as f:
                    f.write("# No humans detected\n")


def save_final_masks(video_segments, mask_output_path):
    """
    Save the final masks as a colored image.
    """
    last_frame_idx = max(video_segments.keys())
    masks_dict = video_segments[last_frame_idx]
    # Assuming you have two objects with IDs 1 and 2
    mask_a = masks_dict.get(1).squeeze() if 1 in masks_dict else None
    mask_b = masks_dict.get(2).squeeze() if 2 in masks_dict else None

    if mask_a is None and mask_b is None:
        logger.error("No masks found for objects.")
        return

    # Use the first available mask to determine dimensions
    reference_mask = mask_a if mask_a is not None else mask_b
    black_frame = np.zeros((reference_mask.shape[0], reference_mask.shape[1], 3), dtype=np.uint8)

    if mask_a is not None:
        mask_a = mask_a.astype(bool)
        black_frame[mask_a] = GREEN

    if mask_b is not None:
        mask_b = mask_b.astype(bool)
        black_frame[mask_b] = BLUE

    # Save the mask image
    cv2.imwrite(mask_output_path, black_frame)
    logger.info(f"Saved final masks to {mask_output_path}")


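# Note: the mask.png written by save_final_masks() is the hand-off between segments.
# Object 1 is stored as pure green pixels and object 2 as pure blue pixels, which is
# the encoding that load_previous_segment_mask() expects when seeding the next segment.

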
def create_low_res_video(input_video_path, output_video_path, scale):
    """
    Creates a low-resolution version of the input video for inference.
    """
    cap = cv2.VideoCapture(input_video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) * scale)
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) * scale)
    fps = cap.get(cv2.CAP_PROP_FPS) or 59.94

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        low_res_frame = cv2.resize(frame, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR)
        out.write(low_res_frame)

    cap.release()
    out.release()


def main():
    parser = argparse.ArgumentParser(description="Process video segments with YOLO + SAM2.")
    parser.add_argument("--base-dir", type=str, help="Base directory for video segments.")
    parser.add_argument("--segments-detect-humans", nargs='*', help="Segments for which to run YOLO human detection. Use 'all' for all segments, or list specific segment numbers (e.g., 1 5 10). Default: all segments.")
    parser.add_argument("--yolo-model", type=str, default=YOLO_MODEL_PATH, help="Path to YOLO model.")
    parser.add_argument("--yolo-confidence", type=float, default=YOLO_CONFIDENCE, help="YOLO detection confidence threshold.")
    args = parser.parse_args()

    base_dir = args.base_dir
    segments = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("segment_")]
    segments.sort(key=lambda x: int(x.split("_")[1]))

    # Handle different ways to specify segments for YOLO detection
    if args.segments_detect_humans is None or len(args.segments_detect_humans) == 0:
        # Default: run YOLO on all segments
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("No segments specified, running YOLO detection on ALL segments")
    elif len(args.segments_detect_humans) == 1 and args.segments_detect_humans[0].lower() == 'all':
        # Explicit 'all' keyword
        detect_segments = [int(seg.split("_")[1]) for seg in segments]
        logger.info("Running YOLO detection on ALL segments")
    else:
        # Specific segment numbers provided
        try:
            detect_segments = [int(x) for x in args.segments_detect_humans]
            logger.info(f"Running YOLO detection on segments: {detect_segments}")
        except ValueError:
            logger.error("Invalid segment numbers provided. Use integers or 'all'.")
            return

    # Run YOLO detection on specified segments
    do_yolo_detection_on_segments(base_dir, segments, detect_segments, scale=INFERENCE_SCALE, yolo_model_path=args.yolo_model)

    # Load YOLO model for inference
    yolo_model = YOLO(args.yolo_model)

    for i, segment in enumerate(segments):
        segment_index = int(segment.split("_")[1])
        segment_dir = os.path.join(base_dir, segment)
        video_file_name = get_video_file_name(i)
        video_path = os.path.join(segment_dir, video_file_name)
        output_done_file = os.path.join(segment_dir, "output_frames_done")

        if os.path.exists(output_done_file):
            logger.info(f"Segment {segment} already processed. Skipping.")
            continue

        logger.info(f"Processing segment {segment}")

        # Initialize predictor
        predictor = initialize_predictor()

        # Prepare low-resolution video frames for inference
        low_res_video_path = os.path.join(segment_dir, "low_res_video.mp4")
        if not os.path.exists(low_res_video_path):
            create_low_res_video(video_path, low_res_video_path, INFERENCE_SCALE)
            logger.info(f"Low-resolution video created for segment {segment}")
        else:
            logger.info(f"Low-resolution video already exists for segment {segment}, reusing it")

        # Initialize inference state with low-resolution video
        inference_state = predictor.init_state(video_path=low_res_video_path, async_loading_frames=True)

        # Load YOLO detections or previous masks
        detection_file = os.path.join(segment_dir, "yolo_detections")
        use_detections = segment_index in detect_segments

        if i == 0 and not use_detections:
            # First segment must use YOLO detection since there's no previous mask
            logger.warning(f"First segment {segment} requires YOLO detection. Running YOLO detection.")
            use_detections = True

        if i > 0 and not use_detections:
            # Try to load previous segment mask - search backwards for the most recent successful mask
            logger.info(f"Using previous segment mask for segment {segment}")
            mask_found = False

            # Search backwards through previous segments to find a valid mask
            for j in range(i - 1, -1, -1):
                prev_segment_dir = os.path.join(base_dir, segments[j])
                prev_mask_path = os.path.join(prev_segment_dir, "mask.png")

                if os.path.exists(prev_mask_path):
                    try:
                        per_obj_input_mask, input_palette = load_previous_segment_mask(prev_segment_dir)
                        # Add previous masks to predictor
                        for obj_id, mask in per_obj_input_mask.items():
                            predictor.add_new_mask(inference_state, 0, obj_id, mask)
                        logger.info(f"Successfully loaded mask from segment {segments[j]}")
                        mask_found = True
                        break
                    except Exception as e:
                        logger.warning(f"Error loading mask from {segments[j]}: {e}")
                        continue

            if not mask_found:
                logger.error(f"No valid previous mask found for segment {segment}. Consider running YOLO detection on this segment.")
                continue
        else:
            # Load first frame for detection
            first_frame = load_first_frame(low_res_video_path, scale=1.0)
            if first_frame is None:
                logger.error(f"Could not load first frame for segment {segment}")
                continue

            # Run YOLO detection on first frame (either from file or on-the-fly)
            if os.path.exists(detection_file):
                logger.info(f"Using existing YOLO detections for segment {segment}")
            else:
                logger.info(f"Running YOLO detection on-the-fly for segment {segment}")

            human_detections = detect_humans_with_yolo(first_frame, yolo_model, args.yolo_confidence)

            if human_detections:
                # Add YOLO detections to predictor
                frame_width = first_frame.shape[1]
                add_yolo_detections_to_predictor(predictor, inference_state, human_detections, frame_width)
            else:
                logger.warning(f"No humans detected in segment {segment}")
                continue

        # Perform inference and collect masks per frame
        video_segments = propagate_masks(predictor, inference_state)

        # Process high-resolution frames and save output video
        output_video_path = os.path.join(segment_dir, f"output_{segment_index}.mp4")
        logger.info("Processing segment complete, attempting to save full video from low res masks")
        process_and_save_output_video(
            video_path,
            output_video_path,
            video_segments,
            use_nvenc=True  # Set to True to use NVENC offloading
        )

        # Save final masks
        mask_output_path = os.path.join(segment_dir, "mask.png")
        save_final_masks(video_segments, mask_output_path)

        # Clean up
        predictor.reset_state(inference_state)
        del inference_state
        del video_segments
        del predictor
        gc.collect()

        try:
            os.remove(low_res_video_path)
            logger.info(f"Deleted low-resolution video for segment {segment}")
        except Exception as e:
            logger.warning(f"Could not delete low-resolution video for segment {segment}: {e}")

        # Mark segment as completed
        open(output_done_file, 'a').close()

    logger.info("Processing complete.")


if __name__ == "__main__":
    main()
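
For reference, a typical invocation of this original script might look like the following sketch. The script filename and the segments directory are placeholders; the flags are the ones defined in `main()` above.

```bash
# Hypothetical example: process segment folders under ./segments,
# running YOLO human detection only on segments 1 and 5.
python original_pipeline.py \
    --base-dir ./segments \
    --segments-detect-humans 1 5 \
    --yolo-model yolov8n.pt \
    --yolo-confidence 0.6
```

Passing `--segments-detect-humans all` (or omitting the flag) runs YOLO detection on every segment; otherwise, segments not listed are seeded from the most recent previous segment's `mask.png`.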