Compare commits: 7431954482...streaming (20 commits)

| SHA1 |
|---|
| c1aa11e5a0 |
| f0cf3341af |
| ee330fa322 |
| 1e9c42adbd |
| 9cc755b5c7 |
| 300ae5613e |
| a479d6a5f0 |
| e38f63f539 |
| 66895a87a0 |
| 43be574729 |
| 9b7f36fec2 |
| 7b3ffb7830 |
| 1d15fb5bc8 |
| 2e5ded7dbf |
| 3a59e87f3e |
| abc48604a1 |
| ee80ed28b6 |
| b5eae7b41d |
| 4cc14bc0a9 |
| 9faaf4ed57 |
@@ -27,9 +27,9 @@ matting:
  sam2_model_cfg: "sam2.1_hiera_l"  # Use large model for best quality
  sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
  memory_offload: true  # Critical for streaming - offload to CPU when needed
  fp16: true  # Use half precision for memory efficiency
  fp16: false  # Disable FP16 to avoid type mismatches with compiled models
  continuous_correction: true  # Periodically refine tracking
  correction_interval: 300  # Correct every 5 seconds at 60fps
  correction_interval: 30  # Correct every 0.5 seconds at 60fps (for testing)

stereo:
  mode: "master_slave"  # Left eye detects, right eye follows
@@ -43,14 +43,14 @@ output:
  path: "/workspace/output_video.mp4"  # Update with your output path
  format: "greenscreen"  # "greenscreen" or "alpha"
  background_color: [0, 255, 0]  # RGB for green screen
  video_codec: "h264_nvenc"  # GPU encoding (or "hevc_nvenc" for better compression)
  quality_preset: "p4"  # NVENC preset (p1-p7, higher = better quality)
  video_codec: "h264_nvenc"  # GPU encoding for L40 (fallback to CPU if not available)
  quality_preset: "p4"  # NVENC preset (p1=fastest, p7=slowest/best quality)
  crf: 18  # Quality (0-51, lower = better, 18 = high quality)
  maintain_sbs: true  # Keep side-by-side format with audio

hardware:
  device: "cuda"
  max_vram_gb: 40.0  # Conservative limit for 48GB GPU
  max_vram_gb: 44.0  # Conservative limit for L40 48GB VRAM
  max_ram_gb: 48.0  # RunPod container RAM limit

recovery:

@@ -12,4 +12,4 @@ ffmpeg-python>=0.2.0
decord>=0.6.0
# GPU acceleration (optional but recommended for stereo validation speedup)
# cupy-cuda11x>=12.0.0  # For CUDA 11.x
# cupy-cuda12x>=12.0.0  # For CUDA 12.x - uncomment appropriate version
cupy-cuda12x>=12.0.0  # For CUDA 12.x (most common on modern systems)
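A quick way to confirm the installed CuPy wheel matches the CUDA runtime is a short check like the sketch below (an editor's illustration, not part of the repo; assumes an NVIDIA GPU is visible to the container):

```python
# Minimal CuPy sanity check (illustrative; assumes cupy-cuda12x on a CUDA 12 host).
import cupy as cp

def check_cupy() -> None:
    # CUDA runtime version CuPy was built against, e.g. 12020 for 12.2
    print("CUDA runtime:", cp.cuda.runtime.runtimeGetVersion())
    props = cp.cuda.runtime.getDeviceProperties(0)
    print("GPU:", props["name"].decode())
    # A tiny reduction proves kernels actually launch on the device
    x = cp.arange(1_000_000, dtype=cp.float32)
    print("sum:", float(x.sum()))

if __name__ == "__main__":
    check_cupy()
```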
@@ -1,6 +1,7 @@
#!/bin/bash
# VR180 Matting Unified Setup Script for RunPod
# Supports both chunked and streaming implementations
# Optimized for L40, A6000, and other NVENC-capable GPUs

set -e  # Exit on error

@@ -58,20 +59,11 @@ pip install -r requirements.txt
print_status "Installing video processing libraries..."
pip install decord ffmpeg-python

# Install CuPy for GPU acceleration of stereo validation
# Install CuPy for GPU acceleration (CUDA 12 is standard on modern RunPod)
print_status "Installing CuPy for GPU acceleration..."
# Auto-detect CUDA version and install appropriate CuPy
if command -v nvidia-smi &> /dev/null; then
    CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | awk '{print $9}' | cut -d. -f1-2)
    echo "Detected CUDA version: $CUDA_VERSION"

    if [[ "$CUDA_VERSION" == "11."* ]]; then
        pip install "cupy-cuda11x>=12.0.0" && print_success "Installed CuPy for CUDA 11.x"
    elif [[ "$CUDA_VERSION" == "12."* ]]; then
        pip install "cupy-cuda12x>=12.0.0" && print_success "Installed CuPy for CUDA 12.x"
    else
        print_error "Unknown CUDA version, skipping CuPy installation"
    fi
    print_status "Installing CuPy for CUDA 12.x (standard on RunPod)..."
    pip install "cupy-cuda12x>=12.0.0" && print_success "Installed CuPy for CUDA 12.x"
else
    print_error "NVIDIA GPU not detected, skipping CuPy installation"
fi
@@ -91,6 +83,10 @@ else
    cd ..
fi

# Fix PyTorch version conflicts after SAM2 installation
print_status "Fixing PyTorch version conflicts..."
pip install torchaudio --upgrade --no-deps || print_error "Failed to upgrade torchaudio"

# Download SAM2 checkpoints
print_status "Downloading SAM2 checkpoints..."
cd segment-anything-2/checkpoints
@@ -244,7 +240,7 @@ echo "==================="
echo "- Streaming: Best for long videos, uses ~50GB RAM constant"
echo "- Chunked: More stable but uses 100GB+ RAM in spikes"
echo "- Scale factor: 0.25 (fast) → 0.5 (balanced) → 1.0 (quality)"
echo "- A6000/A100: Can handle 0.5-0.75 scale easily"
echo "- L40/A6000: Can handle 0.5-0.75 scale easily with NVENC GPU encoding"
echo "- Monitor VRAM with: nvidia-smi -l 1"
echo
echo "🎯 Example Commands:"

@@ -11,6 +11,28 @@ import atexit
import warnings


def test_nvenc_support() -> bool:
    """Test if NVENC encoding is available"""
    try:
        # Quick test with a 1-frame video
        cmd = [
            'ffmpeg', '-f', 'lavfi', '-i', 'testsrc=duration=0.1:size=320x240:rate=1',
            '-c:v', 'h264_nvenc', '-t', '0.1', '-f', 'null', '-'
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            timeout=10,
            text=True
        )

        return result.returncode == 0

    except (subprocess.TimeoutExpired, FileNotFoundError):
        return False


class StreamingFrameWriter:
    """Write frames directly to ffmpeg via pipe for memory-efficient output"""

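For orientation, the pipe pattern the writer is built around looks roughly like the standalone sketch below (hypothetical helper names; it reuses `test_nvenc_support()` from this module and assumes ffmpeg is on PATH and frames arrive as BGR uint8 arrays):

```python
# Illustrative sketch only; StreamingFrameWriter's real command is built by
# _build_ffmpeg_command(), which is not shown in this hunk.
import subprocess
import numpy as np

def open_raw_writer(path: str, width: int, height: int, fps: float) -> subprocess.Popen:
    codec = 'h264_nvenc' if test_nvenc_support() else 'libx264'
    cmd = [
        'ffmpeg', '-y',
        '-f', 'rawvideo', '-pix_fmt', 'bgr24',
        '-s', f'{width}x{height}', '-r', str(fps),
        '-i', '-',            # raw frames arrive on stdin
        '-c:v', codec, path,
    ]
    return subprocess.Popen(cmd, stdin=subprocess.PIPE)

# proc = open_raw_writer('out.mp4', 1920, 1080, 30.0)
# proc.stdin.write(np.zeros((1080, 1920, 3), np.uint8).tobytes())
# proc.stdin.close(); proc.wait()
```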
@@ -36,6 +58,16 @@ class StreamingFrameWriter:
        self.frames_written = 0
        self.ffmpeg_process = None

        # Test NVENC support if GPU codec requested
        if video_codec in ['h264_nvenc', 'hevc_nvenc']:
            print(f"🔍 Testing NVENC support...")
            if not test_nvenc_support():
                print(f"❌ NVENC not available, switching to CPU encoding")
                video_codec = 'libx264'
                quality_preset = 'medium'
            else:
                print(f"✅ NVENC available")

        # Build ffmpeg command
        self.ffmpeg_cmd = self._build_ffmpeg_command(
            video_codec, quality_preset, crf
@@ -132,16 +164,41 @@ class StreamingFrameWriter:
                bufsize=10**8  # Large buffer for performance
            )

            # Test if ffmpeg starts successfully (quick check)
            import time
            time.sleep(0.2)  # Give ffmpeg time to fail if it's going to

            if self.ffmpeg_process.poll() is not None:
                # Process already died - read error
                stderr = self.ffmpeg_process.stderr.read().decode()

                # Check for specific NVENC errors and provide better feedback
                if 'nvenc' in ' '.join(self.ffmpeg_cmd):
                    if 'unsupported device' in stderr.lower():
                        print(f"❌ NVENC not available on this GPU - switching to CPU encoding")
                    elif 'cannot load' in stderr.lower() or 'not found' in stderr.lower():
                        print(f"❌ NVENC drivers not available - switching to CPU encoding")
                    else:
                        print(f"❌ NVENC encoding failed: {stderr}")

                    # Try CPU fallback
                    print(f"🔄 Falling back to CPU encoding (libx264)...")
                    self.ffmpeg_cmd = self._build_ffmpeg_command('libx264', 'medium', 18)
                    return self._start_ffmpeg()
                else:
                    raise RuntimeError(f"FFmpeg failed: {stderr}")

            # Set process to ignore SIGINT (Ctrl+C) - we'll handle it
            if hasattr(signal, 'pthread_sigmask'):
                signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT])

        except Exception as e:
            # Try CPU fallback if GPU encoding fails
            if 'nvenc' in self.ffmpeg_cmd:
                print(f"⚠️ GPU encoding failed, trying CPU fallback...")
            # Final fallback if everything fails
            if 'nvenc' in ' '.join(self.ffmpeg_cmd):
                print(f"⚠️ GPU encoding failed with error: {e}")
                print(f"🔄 Falling back to CPU encoding...")
                self.ffmpeg_cmd = self._build_ffmpeg_command('libx264', 'medium', 18)
                self._start_ffmpeg()
                return self._start_ffmpeg()
            else:
                raise RuntimeError(f"Failed to start ffmpeg: {e}")

@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:

import torch
import numpy as np
import cv2
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
@@ -34,6 +35,11 @@ class SAM2StreamingProcessor:
        self.config = config
        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))

        # Processing parameters (set before _init_predictor)
        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
        self.fp16 = config.get('matting', {}).get('fp16', True)
        self.correction_interval = config.get('matting', {}).get('correction_interval', 300)

        # SAM2 model configuration
        model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
        checkpoint = config.get('matting', {}).get('sam2_checkpoint',
@@ -43,11 +49,6 @@ class SAM2StreamingProcessor:
        self.predictor = None
        self._init_predictor(model_cfg, checkpoint)

        # Processing parameters
        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
        self.fp16 = config.get('matting', {}).get('fp16', True)
        self.correction_interval = config.get('matting', {}).get('correction_interval', 300)

        # State management
        self.states = {}  # eye -> inference state
        self.object_ids = []
@@ -80,52 +81,165 @@ class SAM2StreamingProcessor:
|
||||
vos_optimized=True # Enable full model compilation for speed
|
||||
)
|
||||
|
||||
# Set to eval mode
|
||||
# Set to eval mode and ensure all model components are on GPU
|
||||
self.predictor.eval()
|
||||
|
||||
# Enable FP16 if requested
|
||||
# Force all predictor components to GPU
|
||||
self.predictor = self.predictor.to(self.device)
|
||||
|
||||
# Force move all internal components that might be on CPU
|
||||
if hasattr(self.predictor, 'image_encoder'):
|
||||
self.predictor.image_encoder = self.predictor.image_encoder.to(self.device)
|
||||
if hasattr(self.predictor, 'memory_attention'):
|
||||
self.predictor.memory_attention = self.predictor.memory_attention.to(self.device)
|
||||
if hasattr(self.predictor, 'memory_encoder'):
|
||||
self.predictor.memory_encoder = self.predictor.memory_encoder.to(self.device)
|
||||
if hasattr(self.predictor, 'sam_mask_decoder'):
|
||||
self.predictor.sam_mask_decoder = self.predictor.sam_mask_decoder.to(self.device)
|
||||
if hasattr(self.predictor, 'sam_prompt_encoder'):
|
||||
self.predictor.sam_prompt_encoder = self.predictor.sam_prompt_encoder.to(self.device)
|
||||
|
||||
# Note: FP16 conversion can cause type mismatches with compiled models
|
||||
# Let SAM2 handle precision internally via build_sam2_video_predictor options
|
||||
if self.fp16 and self.device.type == 'cuda':
|
||||
self.predictor = self.predictor.half()
|
||||
print(" FP16 enabled via SAM2 internal settings")
|
||||
|
||||
print(f" All SAM2 components moved to {self.device}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
|
||||
|
||||
def init_state(self,
|
||||
video_path: str,
|
||||
video_info: Dict[str, Any],
|
||||
eye: str = 'full') -> Dict[str, Any]:
|
||||
"""
|
||||
Initialize inference state for streaming
|
||||
Initialize inference state for streaming (NO VIDEO LOADING)
|
||||
|
||||
Args:
|
||||
video_path: Path to video file
|
||||
video_info: Video metadata dict with width, height, frame_count
|
||||
eye: Eye identifier ('left', 'right', or 'full')
|
||||
|
||||
Returns:
|
||||
Inference state dictionary
|
||||
"""
|
||||
# Initialize state with memory offloading enabled
|
||||
with torch.inference_mode():
|
||||
state = self.predictor.init_state(
|
||||
video_path=video_path,
|
||||
offload_video_to_cpu=self.memory_offload,
|
||||
offload_state_to_cpu=self.memory_offload,
|
||||
async_loading_frames=False # We'll provide frames directly
|
||||
)
|
||||
print(f" Initializing streaming state for {eye} eye...")
|
||||
|
||||
# Monitor memory before initialization
|
||||
if torch.cuda.is_available():
|
||||
before_mem = torch.cuda.memory_allocated() / 1e9
|
||||
print(f" 📊 GPU memory before init: {before_mem:.1f}GB")
|
||||
|
||||
# Create streaming state WITHOUT loading video frames
|
||||
state = self._create_streaming_state(video_info)
|
||||
|
||||
# Monitor memory after initialization
|
||||
if torch.cuda.is_available():
|
||||
after_mem = torch.cuda.memory_allocated() / 1e9
|
||||
print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
|
||||
|
||||
self.states[eye] = state
|
||||
print(f" Initialized state for {eye} eye")
|
||||
print(f" ✅ Streaming state initialized for {eye} eye")
|
||||
|
||||
return state
|
||||
|
||||
def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create streaming state for frame-by-frame processing"""
|
||||
# Create a streaming-compatible inference state
|
||||
# This mirrors SAM2's internal state structure but without video frames
|
||||
|
||||
# Create streaming-compatible state without loading video
|
||||
# This approach avoids the dummy video complexity
|
||||
|
||||
with torch.inference_mode():
|
||||
# Initialize minimal state that mimics SAM2's structure
|
||||
inference_state = {
|
||||
'point_inputs_per_obj': {},
|
||||
'mask_inputs_per_obj': {},
|
||||
'cached_features': {},
|
||||
'constants': {},
|
||||
'obj_id_to_idx': {},
|
||||
'obj_idx_to_id': {},
|
||||
'obj_ids': [],
|
||||
'click_inputs_per_obj': {},
|
||||
'temp_output_dict_per_obj': {},
|
||||
'consolidated_frame_inds': {},
|
||||
'tracking_has_started': False,
|
||||
'num_frames': video_info.get('total_frames', video_info.get('frame_count', 0)),
|
||||
'video_height': video_info['height'],
|
||||
'video_width': video_info['width'],
|
||||
'device': self.device,
|
||||
'storage_device': self.device, # Keep everything on GPU
|
||||
'offload_video_to_cpu': False,
|
||||
'offload_state_to_cpu': False,
|
||||
# Add required SAM2 internal structures
|
||||
'output_dict_per_obj': {},
|
||||
'temp_output_dict_per_obj': {},
|
||||
'frames': None, # We provide frames manually
|
||||
'images': None, # We provide images manually
|
||||
# Additional SAM2 tracking fields
|
||||
'frames_tracked_per_obj': {},
|
||||
'obj_idx_to_id': {},
|
||||
'obj_id_to_idx': {},
|
||||
'click_inputs_per_obj': {},
|
||||
'point_inputs_per_obj': {},
|
||||
'mask_inputs_per_obj': {},
|
||||
'output_dict': {},
|
||||
'memory_bank': {},
|
||||
'num_obj_tokens': 0,
|
||||
'max_obj_ptr_num': 16, # SAM2 default
|
||||
'multimask_output_in_sam': False,
|
||||
'use_multimask_token_for_obj_ptr': True,
|
||||
'max_inference_state_frames': -1, # No limit for streaming
|
||||
'image_feature_cache': {},
|
||||
'cached_features': {},
|
||||
'consolidated_frame_inds': {},
|
||||
}
|
||||
|
||||
# Initialize some constants that SAM2 expects
|
||||
inference_state['constants'] = {
|
||||
'image_size': max(video_info['height'], video_info['width']),
|
||||
'backbone_stride': 16, # Standard SAM2 backbone stride
|
||||
'sam_mask_decoder_extra_args': {},
|
||||
'sam_prompt_embed_dim': 256,
|
||||
'sam_image_embedding_size': video_info['height'] // 16, # Assuming 16x downsampling
|
||||
}
|
||||
|
||||
print(f" Created streaming-compatible state")
|
||||
|
||||
return inference_state
|
||||
|
||||
def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
|
||||
"""Move all tensors in state to the specified device"""
|
||||
def move_to_device(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.to(device)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: move_to_device(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [move_to_device(item) for item in obj]
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(move_to_device(item) for item in obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
# Move all state components to device
|
||||
for key, value in state.items():
|
||||
if key not in ['video_path', 'num_frames', 'video_height', 'video_width']: # Skip metadata
|
||||
state[key] = move_to_device(value)
|
||||
|
||||
print(f" Moved state tensors to {device}")
|
||||
|
||||
def add_detections(self,
|
||||
state: Dict[str, Any],
|
||||
frame: np.ndarray,
|
||||
detections: List[Dict[str, Any]],
|
||||
frame_idx: int = 0) -> List[int]:
|
||||
"""
|
||||
Add detection boxes as prompts to SAM2
|
||||
Add detection boxes as prompts to SAM2 with frame data
|
||||
|
||||
Args:
|
||||
state: Inference state
|
||||
frame: Frame image (RGB numpy array)
|
||||
detections: List of detections with 'box' key
|
||||
frame_idx: Frame index to add prompts
|
||||
|
||||
@@ -136,6 +250,23 @@ class SAM2StreamingProcessor:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []

        # Convert frame to tensor (ensure proper format and device)
        if isinstance(frame, np.ndarray):
            # Convert BGR to RGB if needed (OpenCV uses BGR)
            if frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_tensor = torch.from_numpy(frame).float().to(self.device)
        else:
            frame_tensor = frame.float().to(self.device)

        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
            frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension

        # Normalize to [0, 1] range if needed
        if frame_tensor.max() > 1.0:
            frame_tensor = frame_tensor / 255.0

        # Convert detections to SAM2 format
        boxes = []
        for det in detections:
@@ -144,43 +275,160 @@ class SAM2StreamingProcessor:
|
||||
|
||||
boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Add boxes as prompts
|
||||
# Manually process frame and add prompts (streaming approach)
|
||||
with torch.inference_mode():
|
||||
_, object_ids, _ = self.predictor.add_new_points_or_box(
|
||||
inference_state=state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=0, # SAM2 will auto-increment
|
||||
box=boxes_tensor
|
||||
)
|
||||
# Process frame through SAM2's image encoder
|
||||
backbone_out = self.predictor.forward_image(frame_tensor)
|
||||
|
||||
# Store features in state for this frame
|
||||
state['cached_features'][frame_idx] = backbone_out
|
||||
|
||||
# Convert boxes to points for manual implementation
|
||||
# SAM2 expects corner points from boxes with labels 2,3
|
||||
points = []
|
||||
labels = []
|
||||
for box in boxes:
|
||||
# Convert box [x1, y1, x2, y2] to corner points
|
||||
x1, y1, x2, y2 = box
|
||||
points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners
|
||||
labels.extend([2, 3]) # SAM2 standard labels for box corners
|
||||
|
||||
points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
|
||||
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
|
||||
|
||||
try:
|
||||
# Use add_new_points instead of add_new_points_or_box to avoid device issues
|
||||
_, object_ids, masks = self.predictor.add_new_points(
|
||||
inference_state=state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=None, # Let SAM2 auto-assign
|
||||
points=points_tensor,
|
||||
labels=labels_tensor,
|
||||
clear_old_points=True,
|
||||
normalize_coords=True
|
||||
)
|
||||
|
||||
# Update state with object tracking info
|
||||
state['obj_ids'] = object_ids
|
||||
state['tracking_has_started'] = True
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error in add_new_points: {e}")
|
||||
print(f" Points tensor device: {points_tensor.device}")
|
||||
print(f" Labels tensor device: {labels_tensor.device}")
|
||||
print(f" Frame tensor device: {frame_tensor.device}")
|
||||
|
||||
# Fallback: manually initialize object tracking
|
||||
print(f" Using fallback manual object initialization")
|
||||
object_ids = [i for i in range(len(detections))]
|
||||
state['obj_ids'] = object_ids
|
||||
state['tracking_has_started'] = True
|
||||
|
||||
# Store detection info for later use
|
||||
for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
|
||||
state['point_inputs_per_obj'][i] = {
|
||||
frame_idx: {
|
||||
'points': points_tensor[i*2:(i+1)*2],
|
||||
'labels': labels_tensor[i*2:(i+1)*2]
|
||||
}
|
||||
}
|
||||
|
||||
self.object_ids = object_ids
|
||||
print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
|
||||
|
||||
return object_ids
|
||||
|
||||
def propagate_in_video_simple(self,
|
||||
state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
|
||||
def propagate_single_frame(self,
|
||||
state: Dict[str, Any],
|
||||
frame: np.ndarray,
|
||||
frame_idx: int) -> np.ndarray:
|
||||
"""
|
||||
Simple propagation for single eye processing
|
||||
Propagate masks for a single frame (true streaming)
|
||||
|
||||
Yields:
|
||||
(frame_idx, object_ids, masks) tuples
|
||||
Args:
|
||||
state: Inference state
|
||||
frame: Frame image (RGB numpy array)
|
||||
frame_idx: Frame index
|
||||
|
||||
Returns:
|
||||
Combined mask for all objects
|
||||
"""
|
||||
# Convert frame to tensor (ensure proper format and device)
|
||||
if isinstance(frame, np.ndarray):
|
||||
# Convert BGR to RGB if needed (OpenCV uses BGR)
|
||||
if frame.shape[-1] == 3:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_tensor = torch.from_numpy(frame).float().to(self.device)
|
||||
else:
|
||||
frame_tensor = frame.float().to(self.device)
|
||||
|
||||
if frame_tensor.ndim == 3:
|
||||
frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
|
||||
frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# Normalize to [0, 1] range if needed
|
||||
if frame_tensor.max() > 1.0:
|
||||
frame_tensor = frame_tensor / 255.0
|
||||
|
||||
with torch.inference_mode():
|
||||
for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
|
||||
# Convert masks to numpy
|
||||
if isinstance(masks, torch.Tensor):
|
||||
masks_np = masks.cpu().numpy()
|
||||
# Process frame through SAM2's image encoder
|
||||
backbone_out = self.predictor.forward_image(frame_tensor)
|
||||
|
||||
# Store features in state for this frame
|
||||
state['cached_features'][frame_idx] = backbone_out
|
||||
|
||||
# Use SAM2's single frame inference for propagation
|
||||
try:
|
||||
# Run single frame inference for all tracked objects
|
||||
output_dict = {}
|
||||
self.predictor._run_single_frame_inference(
|
||||
inference_state=state,
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1,
|
||||
is_init_cond_frame=False, # Not initialization frame
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=False,
|
||||
run_mem_encoder=True
|
||||
)
|
||||
|
||||
# Extract masks from output
|
||||
if output_dict and 'pred_masks' in output_dict:
|
||||
pred_masks = output_dict['pred_masks']
|
||||
# Combine all object masks
|
||||
if pred_masks.shape[0] > 0:
|
||||
combined_mask = pred_masks.max(dim=0)[0]
|
||||
combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
|
||||
else:
|
||||
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
else:
|
||||
masks_np = masks
|
||||
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
|
||||
yield frame_idx, object_ids, masks_np
|
||||
except Exception as e:
|
||||
print(f" Warning: Single frame inference failed: {e}")
|
||||
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
|
||||
# Periodic memory cleanup
|
||||
if frame_idx % 100 == 0:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
# Cleanup old features to prevent memory accumulation
|
||||
self._cleanup_old_features(state, frame_idx, keep_frames=10)
|
||||
|
||||
return combined_mask_np
|
||||
|
||||
def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
|
||||
"""Remove old cached features to prevent memory accumulation"""
|
||||
features_to_remove = []
|
||||
for frame_idx in state.get('cached_features', {}):
|
||||
if frame_idx < current_frame - keep_frames:
|
||||
features_to_remove.append(frame_idx)
|
||||
|
||||
for frame_idx in features_to_remove:
|
||||
del state['cached_features'][frame_idx]
|
||||
|
||||
# Periodic GPU memory cleanup
|
||||
if current_frame % 50 == 0:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def propagate_frame_pair(self,
|
||||
left_state: Dict[str, Any],
|
||||
|
||||
vr180_streaming/sam2_streaming_simple.py (new file, 407 lines)
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Simple SAM2 streaming processor based on det-sam2 pattern
|
||||
Adapted for current segment-anything-2 API
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
import warnings
|
||||
import gc
|
||||
|
||||
# Import SAM2 components
|
||||
try:
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
except ImportError:
|
||||
warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
|
||||
|
||||
|
||||
class SAM2StreamingProcessor:
|
||||
"""Simple streaming integration with SAM2 following det-sam2 pattern"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
|
||||
|
||||
# SAM2 model configuration
|
||||
model_cfg_name = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
|
||||
checkpoint = config.get('matting', {}).get('sam2_checkpoint',
|
||||
'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
|
||||
|
||||
# Map config name to Hydra path (like the examples show)
|
||||
config_mapping = {
|
||||
'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
|
||||
'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
|
||||
'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
|
||||
'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
|
||||
}
|
||||
|
||||
model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)
|
||||
|
||||
# Build predictor (disable compilation to fix CUDA graph issues)
|
||||
self.predictor = build_sam2_video_predictor(
|
||||
model_cfg, # Relative path from sam2 package
|
||||
checkpoint,
|
||||
device=self.device,
|
||||
vos_optimized=False, # Disable to avoid CUDA graph issues
|
||||
hydra_overrides_extra=[
|
||||
"++model.compile_image_encoder=false", # Disable compilation
|
||||
]
|
||||
)
|
||||
|
||||
# Frame buffer for streaming (like det-sam2)
|
||||
self.frame_buffer = []
|
||||
self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10)
|
||||
|
||||
# State management (simple)
|
||||
self.inference_state = None
|
||||
self.temp_dir = None
|
||||
self.object_ids = []
|
||||
|
||||
# Memory management
|
||||
self.memory_offload = config.get('matting', {}).get('memory_offload', True)
|
||||
self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300)
|
||||
|
||||
print(f"🎯 Simple SAM2 streaming processor initialized:")
|
||||
print(f" Model: {model_cfg}")
|
||||
print(f" Device: {self.device}")
|
||||
print(f" Buffer size: {self.frame_buffer_size}")
|
||||
print(f" Memory offload: {self.memory_offload}")
|
||||
|
||||
def add_frame_and_detections(self,
|
||||
frame: np.ndarray,
|
||||
detections: List[Dict[str, Any]],
|
||||
frame_idx: int) -> np.ndarray:
|
||||
"""
|
||||
Add frame to buffer and process detections (det-sam2 pattern)
|
||||
|
||||
Args:
|
||||
frame: Input frame (BGR)
|
||||
detections: List of detections with 'box' key
|
||||
frame_idx: Global frame index
|
||||
|
||||
Returns:
|
||||
Mask for current frame
|
||||
"""
|
||||
# Add frame to buffer
|
||||
self.frame_buffer.append({
|
||||
'frame': frame,
|
||||
'frame_idx': frame_idx,
|
||||
'detections': detections
|
||||
})
|
||||
|
||||
# Process when buffer is full or when we have detections
|
||||
if len(self.frame_buffer) >= self.frame_buffer_size or detections:
|
||||
return self._process_buffer()
|
||||
else:
|
||||
# For frames without detections, still try to propagate if we have existing objects
|
||||
if self.inference_state is not None and self.object_ids:
|
||||
return self._propagate_existing_objects()
|
||||
else:
|
||||
# Return empty mask if no processing yet
|
||||
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
|
||||
|
||||
def _process_buffer(self) -> np.ndarray:
|
||||
"""Process current frame buffer (adapted det-sam2 approach)"""
|
||||
if not self.frame_buffer:
|
||||
return np.zeros((480, 640), dtype=np.uint8)
|
||||
|
||||
try:
|
||||
# Create temporary directory for frames (current SAM2 API requirement)
|
||||
self._create_temp_frames()
|
||||
|
||||
# Initialize or update SAM2 state
|
||||
if self.inference_state is None:
|
||||
# First time: initialize state with temp directory
|
||||
self.inference_state = self.predictor.init_state(
|
||||
video_path=self.temp_dir,
|
||||
offload_video_to_cpu=self.memory_offload,
|
||||
offload_state_to_cpu=self.memory_offload
|
||||
)
|
||||
print(f" Initialized SAM2 state with {len(self.frame_buffer)} frames")
|
||||
else:
|
||||
# Subsequent times: we need to reinitialize since current SAM2 lacks update_state
|
||||
# This is the key difference from det-sam2 reference
|
||||
self._cleanup_temp_frames()
|
||||
self._create_temp_frames()
|
||||
self.inference_state = self.predictor.init_state(
|
||||
video_path=self.temp_dir,
|
||||
offload_video_to_cpu=self.memory_offload,
|
||||
offload_state_to_cpu=self.memory_offload
|
||||
)
|
||||
print(f" Reinitialized SAM2 state with {len(self.frame_buffer)} frames")
|
||||
|
||||
# Add detections as prompts (standard SAM2 API)
|
||||
self._add_detection_prompts()
|
||||
|
||||
# Get masks via propagation
|
||||
masks = self._get_current_masks()
|
||||
|
||||
# Clean up old frames to prevent memory accumulation
|
||||
self._cleanup_old_frames()
|
||||
|
||||
return masks
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Buffer processing failed: {e}")
|
||||
return np.zeros((480, 640), dtype=np.uint8)
|
||||
|
||||
def _create_temp_frames(self):
|
||||
"""Create temporary directory with frame images for SAM2"""
|
||||
if self.temp_dir:
|
||||
self._cleanup_temp_frames()
|
||||
|
||||
self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')
|
||||
|
||||
for i, buffer_item in enumerate(self.frame_buffer):
|
||||
frame = buffer_item['frame']
|
||||
# Convert BGR to RGB for SAM2
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Save as JPEG (SAM2 expects JPEG images in directory)
|
||||
frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
|
||||
cv2.imwrite(frame_path, cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
|
||||
|
||||
def _add_detection_prompts(self):
|
||||
"""Add detection boxes as prompts to SAM2 (standard API)"""
|
||||
for buffer_idx, buffer_item in enumerate(self.frame_buffer):
|
||||
detections = buffer_item.get('detections', [])
|
||||
|
||||
for det_idx, detection in enumerate(detections):
|
||||
box = detection['box'] # [x1, y1, x2, y2]
|
||||
|
||||
# Use standard SAM2 API
|
||||
try:
|
||||
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
|
||||
inference_state=self.inference_state,
|
||||
frame_idx=buffer_idx, # Frame index within buffer
|
||||
obj_id=det_idx, # Simple object ID
|
||||
box=np.array(box, dtype=np.float32)
|
||||
)
|
||||
|
||||
# Track object IDs
|
||||
if det_idx not in self.object_ids:
|
||||
self.object_ids.append(det_idx)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to add detection: {e}")
|
||||
continue
|
||||
|
||||
def _get_current_masks(self) -> np.ndarray:
|
||||
"""Get masks for current frame via propagation"""
|
||||
if not self.object_ids:
|
||||
# No objects to track
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
try:
|
||||
# Use SAM2's propagate_in_video (standard API)
|
||||
latest_frame_idx = len(self.frame_buffer) - 1
|
||||
masks_for_frame = []
|
||||
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
|
||||
self.inference_state,
|
||||
start_frame_idx=latest_frame_idx,
|
||||
max_frame_num_to_track=1, # Just current frame
|
||||
reverse=False
|
||||
):
|
||||
if out_frame_idx == latest_frame_idx:
|
||||
# Combine all object masks
|
||||
if len(out_mask_logits) > 0:
|
||||
combined_mask = None
|
||||
for mask_logit in out_mask_logits:
|
||||
mask = (mask_logit > 0.0).cpu().numpy()
|
||||
if combined_mask is None:
|
||||
combined_mask = mask.astype(bool)
|
||||
else:
|
||||
combined_mask = combined_mask | mask.astype(bool)
|
||||
|
||||
return (combined_mask * 255).astype(np.uint8)
|
||||
|
||||
# If no masks found, return empty
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Mask propagation failed: {e}")
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
def _propagate_existing_objects(self) -> np.ndarray:
|
||||
"""Propagate existing objects without adding new detections"""
|
||||
if not self.object_ids or not self.frame_buffer:
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape if self.frame_buffer else (480, 640)
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
try:
|
||||
# Update temp frames with current buffer
|
||||
self._create_temp_frames()
|
||||
|
||||
# Reinitialize state (since we can't incrementally update)
|
||||
self.inference_state = self.predictor.init_state(
|
||||
video_path=self.temp_dir,
|
||||
offload_video_to_cpu=self.memory_offload,
|
||||
offload_state_to_cpu=self.memory_offload
|
||||
)
|
||||
|
||||
# Re-add all previous detections from buffer
|
||||
for buffer_idx, buffer_item in enumerate(self.frame_buffer):
|
||||
detections = buffer_item.get('detections', [])
|
||||
if detections: # Only add frames that had detections
|
||||
for det_idx, detection in enumerate(detections):
|
||||
box = detection['box']
|
||||
try:
|
||||
self.predictor.add_new_points_or_box(
|
||||
inference_state=self.inference_state,
|
||||
frame_idx=buffer_idx,
|
||||
obj_id=det_idx,
|
||||
box=np.array(box, dtype=np.float32)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" Warning: Failed to re-add detection: {e}")
|
||||
|
||||
# Get masks for latest frame
|
||||
latest_frame_idx = len(self.frame_buffer) - 1
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
|
||||
self.inference_state,
|
||||
start_frame_idx=latest_frame_idx,
|
||||
max_frame_num_to_track=1,
|
||||
reverse=False
|
||||
):
|
||||
if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
|
||||
combined_mask = None
|
||||
for mask_logit in out_mask_logits:
|
||||
mask = (mask_logit > 0.0).cpu().numpy()
|
||||
if combined_mask is None:
|
||||
combined_mask = mask.astype(bool)
|
||||
else:
|
||||
combined_mask = combined_mask | mask.astype(bool)
|
||||
|
||||
return (combined_mask * 255).astype(np.uint8)
|
||||
|
||||
# If no masks, return empty
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Object propagation failed: {e}")
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Warning: Mask propagation failed: {e}")
|
||||
frame_shape = self.frame_buffer[-1]['frame'].shape
|
||||
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
|
||||
|
||||
def _cleanup_old_frames(self):
|
||||
"""Clean up old frames from buffer (det-sam2 pattern)"""
|
||||
# Keep only recent frames to prevent memory accumulation
|
||||
if len(self.frame_buffer) > self.frame_buffer_size:
|
||||
self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]
|
||||
|
||||
# Periodic GPU memory cleanup
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def _cleanup_temp_frames(self):
|
||||
"""Clean up temporary frame directory"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir)
|
||||
self.temp_dir = None
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up all resources"""
|
||||
self._cleanup_temp_frames()
|
||||
self.frame_buffer.clear()
|
||||
self.object_ids.clear()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
|
||||
print("🧹 Simple SAM2 streaming processor cleaned up")
|
||||
|
||||
def apply_mask_to_frame(self,
|
||||
frame: np.ndarray,
|
||||
mask: np.ndarray,
|
||||
output_format: str = "alpha",
|
||||
background_color: tuple = (0, 255, 0)) -> np.ndarray:
|
||||
"""
|
||||
Apply mask to frame with specified output format (matches chunked implementation)
|
||||
|
||||
Args:
|
||||
frame: Input frame (BGR)
|
||||
mask: Binary mask (0-255 or boolean)
|
||||
output_format: "alpha" or "greenscreen"
|
||||
background_color: RGB background color for greenscreen mode
|
||||
|
||||
Returns:
|
||||
Processed frame
|
||||
"""
|
||||
if mask is None:
|
||||
return frame
|
||||
|
||||
# Ensure mask is 2D (handle 3D masks properly)
|
||||
if mask.ndim == 3:
|
||||
mask = mask.squeeze()
|
||||
|
||||
# Resize mask to match frame if needed (use INTER_NEAREST for binary masks)
|
||||
if mask.shape[:2] != frame.shape[:2]:
|
||||
import cv2
|
||||
# Convert to uint8 for resizing, then back to bool
|
||||
if mask.dtype == bool:
|
||||
mask_uint8 = mask.astype(np.uint8) * 255
|
||||
else:
|
||||
mask_uint8 = mask.astype(np.uint8)
|
||||
|
||||
mask_resized = cv2.resize(mask_uint8,
|
||||
(frame.shape[1], frame.shape[0]),
|
||||
interpolation=cv2.INTER_NEAREST)
|
||||
mask = mask_resized.astype(bool) if mask.dtype == bool else mask_resized
|
||||
|
||||
if output_format == "alpha":
|
||||
# Create RGBA output (matches chunked implementation)
|
||||
output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
|
||||
output[:, :, :3] = frame
|
||||
if mask.dtype == bool:
|
||||
output[:, :, 3] = mask.astype(np.uint8) * 255
|
||||
else:
|
||||
output[:, :, 3] = mask.astype(np.uint8)
|
||||
return output
|
||||
|
||||
elif output_format == "greenscreen":
|
||||
# Create RGB output with background (matches chunked implementation)
|
||||
output = np.full_like(frame, background_color, dtype=np.uint8)
|
||||
if mask.dtype == bool:
|
||||
output[mask] = frame[mask]
|
||||
else:
|
||||
mask_bool = mask.astype(bool)
|
||||
output[mask_bool] = frame[mask_bool]
|
||||
return output
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported output format: {output_format}. Use 'alpha' or 'greenscreen'")
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, float]:
|
||||
"""
|
||||
Get current memory usage statistics
|
||||
|
||||
Returns:
|
||||
Dictionary with memory usage info
|
||||
"""
|
||||
stats = {}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# GPU memory stats
|
||||
stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
|
||||
stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
|
||||
stats['cuda_max_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
|
||||
return stats
|
||||
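To make the intended call pattern concrete, here is a hedged driver sketch for the simple processor (editor's illustration, not part of the committed code; assumes a `config` dict shaped like the YAML above and a `detector` exposing `detect_persons()`, as used elsewhere in this PR):

```python
# Hypothetical single-eye driver loop for SAM2StreamingProcessor (sketch only).
import cv2

def run_single_eye(video_path: str, config: dict, detector) -> None:
    proc = SAM2StreamingProcessor(config)
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Prompt on the first frame; later frames rely on buffered propagation.
            detections = detector.detect_persons(frame) if frame_idx == 0 else []
            mask = proc.add_frame_and_detections(frame, detections, frame_idx)
            out = proc.apply_mask_to_frame(frame, mask, output_format="greenscreen")
            # ... hand `out` to a StreamingFrameWriter or cv2.VideoWriter here ...
            frame_idx += 1
    finally:
        cap.release()
        proc.cleanup()
```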
@@ -15,7 +15,7 @@ import warnings
from .frame_reader import StreamingFrameReader
from .frame_writer import StreamingFrameWriter
from .stereo_manager import StereoConsistencyManager
from .sam2_streaming import SAM2StreamingProcessor
from .sam2_streaming_simple import SAM2StreamingProcessor
from .detector import PersonDetector
from .config import StreamingConfig

@@ -102,25 +102,17 @@ class VR180StreamingProcessor:
            self.initialize()
            self.start_time = time.time()

            # Initialize SAM2 states for both eyes
            print("🎯 Initializing SAM2 streaming states...")
            left_state = self.sam2_processor.init_state(
                self.config.input.video_path,
                eye='left'
            )
            right_state = self.sam2_processor.init_state(
                self.config.input.video_path,
                eye='right'
            )
            # Simple SAM2 initialization (no complex state management needed)
            print("🎯 SAM2 streaming processor ready...")

            # Process first frame to establish detections
            print("🔍 Processing first frame for initial detection...")
            if not self._initialize_tracking(left_state, right_state):
            if not self._initialize_tracking():
                raise RuntimeError("Failed to initialize tracking - no persons detected")

            # Main streaming loop
            print("\n🎬 Starting streaming processing loop...")
            self._streaming_loop(left_state, right_state)
            self._streaming_loop()

        except KeyboardInterrupt:
            print("\n⚠️ Processing interrupted by user")
@@ -134,7 +126,7 @@ class VR180StreamingProcessor:
        finally:
            self._finalize()

    def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
    def _initialize_tracking(self) -> bool:
        """Initialize tracking with first frame detection"""
        # Read and process first frame
        first_frame = self.frame_reader.read_frame()
@@ -158,19 +150,15 @@ class VR180StreamingProcessor:

        print(f"   Detected {len(detections)} person(s) in first frame")

        # Add detections to both eyes
        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
        # Process with simple SAM2 approach
        left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, 0)

        # Transfer detections to slave eye
        # Transfer detections to right eye
        transferred_detections = self.stereo_manager.transfer_detections(
            detections,
            'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
        )
        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)

        # Process and write first frame
        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
        right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, 0)

        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -179,7 +167,7 @@ class VR180StreamingProcessor:
        self.frames_processed = 1
        return True

    def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
    def _streaming_loop(self) -> None:
        """Main streaming processing loop"""
        frame_times = []
        last_log_time = time.time()
@@ -195,23 +183,36 @@ class VR180StreamingProcessor:
            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)

            # Propagate masks for both eyes
            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
                left_state, right_state, left_eye, right_eye, frame_idx
            )
            # Check if we need to run detection for continuous correction
            detections = []
            if (self.config.matting.continuous_correction and
                frame_idx % self.config.matting.correction_interval == 0):
                print(f"\n🔄 Running YOLO detection for correction at frame {frame_idx}")
                master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
                detections = self.detector.detect_persons(master_eye)
                if detections:
                    print(f"   Detected {len(detections)} person(s) for correction")
                else:
                    print(f"   No persons detected for correction")

            # Process frames (with detections if this is a correction frame)
            left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, frame_idx)

            # For right eye, transfer detections if we have them
            if detections:
                transferred_detections = self.stereo_manager.transfer_detections(
                    detections,
                    'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
                )
                right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, frame_idx)
            else:
                right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx)

            # Validate stereo consistency
            right_masks = self.stereo_manager.validate_masks(
                left_masks, right_masks, frame_idx
            )

            # Apply continuous correction if enabled
            if (self.config.matting.continuous_correction and
                frame_idx % self.config.matting.correction_interval == 0):
                self._apply_continuous_correction(
                    left_state, right_state, left_eye, right_eye, frame_idx
                )

            # Apply masks and write frame
            processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
            self.frame_writer.write_frame(processed_frame)
@@ -282,21 +283,20 @@ class VR180StreamingProcessor:
        return left_processed

    def _apply_continuous_correction(self,
                                     left_state: Dict,
                                     right_state: Dict,
                                     left_eye: np.ndarray,
                                     right_eye: np.ndarray,
                                     frame_idx: int) -> None:
        """Apply continuous correction to maintain tracking accuracy"""
        print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

        # Detect on master eye
        # Detect on master eye and add fresh detections
        master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
        master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state
        detections = self.detector.detect_persons(master_eye)

        self.sam2_processor.apply_continuous_correction(
            master_state, master_eye, frame_idx, self.detector
        )
        if detections:
            print(f"   Adding {len(detections)} fresh detection(s) for correction")
            # Add fresh detections to help correct drift
            self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx)

        # Transfer corrections to slave eye
        # Note: This is simplified - actual implementation would transfer the refined prompts

vr180_streaming/timeout_init.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""
|
||||
Timeout wrapper for SAM2 initialization to prevent hanging
|
||||
"""
|
||||
|
||||
import signal
|
||||
import functools
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def timeout(seconds: int = 120):
|
||||
"""Decorator to add timeout to function calls"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
# Define signal handler
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
|
||||
|
||||
# Set signal handler
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(seconds)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
finally:
|
||||
# Restore old handler
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
@timeout(120) # 2 minute timeout
|
||||
def safe_init_state(predictor, video_path: str, **kwargs) -> Any:
|
||||
"""Safely initialize SAM2 state with timeout"""
|
||||
return predictor.init_state(
|
||||
video_path=video_path,
|
||||
**kwargs
|
||||
)
|
||||
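A brief usage sketch for the wrapper (editor's illustration; assumes a built `predictor`, and note that `signal.SIGALRM` only works in the main thread on POSIX systems):

```python
# Hypothetical caller for the timeout-guarded initializer (sketch only).
from vr180_streaming.timeout_init import TimeoutError, safe_init_state

def init_with_timeout(predictor, video_path: str):
    try:
        return safe_init_state(
            predictor,
            video_path,
            offload_video_to_cpu=True,
            offload_state_to_cpu=True,
        )
    except TimeoutError as e:
        # init_state can hang on very long videos; bail out instead of blocking forever
        print(f"SAM2 init_state timed out: {e}")
        return None
```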