Compare commits


22 Commits

SHA1 Message Date
c1aa11e5a0 idk 2025-07-27 10:37:40 -07:00
f0cf3341af amp 2025-07-27 10:23:25 -07:00
ee330fa322 exccept 2025-07-27 10:20:25 -07:00
1e9c42adbd fix streaming 2025-07-27 10:16:39 -07:00
9cc755b5c7 cupy and mask 2025-07-27 10:10:00 -07:00
300ae5613e fucking llms 2025-07-27 10:01:12 -07:00
a479d6a5f0 wtf 2025-07-27 09:57:42 -07:00
e38f63f539 simplify2 2025-07-27 09:55:52 -07:00
66895a87a0 simplify 2025-07-27 09:52:56 -07:00
43be574729 debug 2025-07-27 09:26:47 -07:00
9b7f36fec2 bullshit 2025-07-27 09:23:15 -07:00
7b3ffb7830 idk 2025-07-27 09:20:42 -07:00
1d15fb5bc8 please fucking work 2025-07-27 09:15:48 -07:00
2e5ded7dbf fix api 2025-07-27 09:04:40 -07:00
3a59e87f3e fix something 2025-07-27 08:58:43 -07:00
abc48604a1 timeout init 2025-07-27 08:55:42 -07:00
ee80ed28b6 add stuff true streaming 2025-07-27 08:54:19 -07:00
b5eae7b41d pytorch shit 2025-07-27 08:40:59 -07:00
4cc14bc0a9 nvenc 2025-07-27 08:34:57 -07:00
9faaf4ed57 more stuff 2025-07-27 08:19:42 -07:00
7431954482 add main 2025-07-27 08:15:56 -07:00
f0208f0983 fixup some running stuff 2025-07-27 08:10:20 -07:00
10 changed files with 882 additions and 144 deletions

View File

@@ -93,16 +93,15 @@ git clone <repository-url>
 cd sam2e
 ./runpod_setup.sh
-# Then use the convenience scripts:
-./run_streaming.sh    # For streaming approach (recommended)
-./run_chunked.sh      # For chunked approach
+# Then run with Python directly:
+python -m vr180_streaming config-streaming-runpod.yaml   # Streaming (recommended)
+python -m vr180_matting config-chunked-runpod.yaml       # Chunked (original)
 ```

 The setup script will:
 - Install all dependencies
 - Download SAM2 models
-- Create example configs
-- Set up convenience scripts
+- Create example configs for both approaches

 ## Configuration
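After the setup script finishes, a quick sanity check along these lines can confirm the environment before launching a long job. This is an illustrative sketch; the checkpoint path mirrors the RunPod configs in this change and may differ in your layout.

```python
# Post-setup sanity check (sketch): CUDA visibility and SAM2 checkpoint presence.
# The checkpoint path mirrors config-streaming-runpod.yaml; adjust as needed.
from pathlib import Path

import torch

checkpoint = Path("segment-anything-2/checkpoints/sam2.1_hiera_large.pt")

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, {props.total_memory / 1e9:.1f} GB VRAM")

print(f"SAM2 checkpoint present: {checkpoint.exists()}")
```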

View File

@@ -27,9 +27,9 @@ matting:
   sam2_model_cfg: "sam2.1_hiera_l"    # Use large model for best quality
   sam2_checkpoint: "segment-anything-2/checkpoints/sam2.1_hiera_large.pt"
   memory_offload: true                # Critical for streaming - offload to CPU when needed
-  fp16: true                          # Use half precision for memory efficiency
+  fp16: false                         # Disable FP16 to avoid type mismatch with compiled models
   continuous_correction: true         # Periodically refine tracking
-  correction_interval: 300            # Correct every 5 seconds at 60fps
+  correction_interval: 30             # Correct every 0.5 seconds at 60fps (for testing)

 stereo:
   mode: "master_slave"                # Left eye detects, right eye follows
@@ -43,14 +43,14 @@ output:
path: "/workspace/output_video.mp4" # Update with your output path path: "/workspace/output_video.mp4" # Update with your output path
format: "greenscreen" # "greenscreen" or "alpha" format: "greenscreen" # "greenscreen" or "alpha"
background_color: [0, 255, 0] # RGB for green screen background_color: [0, 255, 0] # RGB for green screen
video_codec: "h264_nvenc" # GPU encoding (or "hevc_nvenc" for better compression) video_codec: "h264_nvenc" # GPU encoding for L40 (fallback to CPU if not available)
quality_preset: "p4" # NVENC preset (p1-p7, higher = better quality) quality_preset: "p4" # NVENC preset (p1=fastest, p7=slowest/best quality)
crf: 18 # Quality (0-51, lower = better, 18 = high quality) crf: 18 # Quality (0-51, lower = better, 18 = high quality)
maintain_sbs: true # Keep side-by-side format with audio maintain_sbs: true # Keep side-by-side format with audio
hardware: hardware:
device: "cuda" device: "cuda"
max_vram_gb: 40.0 # Conservative limit for 48GB GPU max_vram_gb: 44.0 # Conservative limit for L40 48GB VRAM
max_ram_gb: 48.0 # RunPod container RAM limit max_ram_gb: 48.0 # RunPod container RAM limit
recovery: recovery:
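Since max_vram_gb is now pinned at 44 GB for the L40, it is worth confirming the configured limit actually fits the detected GPU before a long run. A minimal sketch, assuming the config file name and key layout shown above:

```python
# Minimal check that hardware.max_vram_gb fits the detected GPU.
# Assumes the key layout from config-streaming-runpod.yaml shown above.
import torch
import yaml

with open("config-streaming-runpod.yaml") as f:
    cfg = yaml.safe_load(f)

limit_gb = float(cfg["hardware"]["max_vram_gb"])
total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

if limit_gb > total_gb:
    print(f"⚠️ max_vram_gb={limit_gb:.0f} exceeds detected VRAM ({total_gb:.1f} GB)")
else:
    print(f"✅ VRAM limit {limit_gb:.0f} GB fits within {total_gb:.1f} GB")
```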

View File

@@ -12,4 +12,4 @@ ffmpeg-python>=0.2.0
 decord>=0.6.0
 # GPU acceleration (optional but recommended for stereo validation speedup)
 # cupy-cuda11x>=12.0.0  # For CUDA 11.x
-# cupy-cuda12x>=12.0.0  # For CUDA 12.x - uncomment appropriate version
+cupy-cuda12x>=12.0.0  # For CUDA 12.x (most common on modern systems)
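With the CUDA 12.x wheel now installed unconditionally, a quick check that the CuPy build matches the local CUDA runtime can save a confusing import error later. A small sketch; the fallback message is an assumption about how stereo validation degrades without CuPy.

```python
# Check that the installed CuPy wheel matches the CUDA runtime on this machine.
import torch

try:
    import cupy
    runtime = cupy.cuda.runtime.runtimeGetVersion()  # e.g. 12020 for CUDA 12.2
    major, minor = runtime // 1000, (runtime % 1000) // 10
    print(f"CuPy {cupy.__version__} sees CUDA runtime {major}.{minor}")
except ImportError as exc:
    # Assumption: without CuPy the stereo validation simply runs on the CPU path.
    print(f"CuPy not importable ({exc}); stereo validation speedup disabled")

print(f"PyTorch built against CUDA {torch.version.cuda}")
```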

View File

@@ -1,6 +1,7 @@
 #!/bin/bash
 # VR180 Matting Unified Setup Script for RunPod
 # Supports both chunked and streaming implementations
+# Optimized for L40, A6000, and other NVENC-capable GPUs

 set -e  # Exit on error
@@ -58,20 +59,11 @@ pip install -r requirements.txt
print_status "Installing video processing libraries..." print_status "Installing video processing libraries..."
pip install decord ffmpeg-python pip install decord ffmpeg-python
# Install CuPy for GPU acceleration of stereo validation # Install CuPy for GPU acceleration (CUDA 12 is standard on modern RunPod)
print_status "Installing CuPy for GPU acceleration..." print_status "Installing CuPy for GPU acceleration..."
# Auto-detect CUDA version and install appropriate CuPy
if command -v nvidia-smi &> /dev/null; then if command -v nvidia-smi &> /dev/null; then
CUDA_VERSION=$(nvidia-smi | grep "CUDA Version" | awk '{print $9}' | cut -d. -f1-2) print_status "Installing CuPy for CUDA 12.x (standard on RunPod)..."
echo "Detected CUDA version: $CUDA_VERSION"
if [[ "$CUDA_VERSION" == "11."* ]]; then
pip install cupy-cuda11x>=12.0.0 && print_success "Installed CuPy for CUDA 11.x"
elif [[ "$CUDA_VERSION" == "12."* ]]; then
pip install cupy-cuda12x>=12.0.0 && print_success "Installed CuPy for CUDA 12.x" pip install cupy-cuda12x>=12.0.0 && print_success "Installed CuPy for CUDA 12.x"
else
print_error "Unknown CUDA version, skipping CuPy installation"
fi
else else
print_error "NVIDIA GPU not detected, skipping CuPy installation" print_error "NVIDIA GPU not detected, skipping CuPy installation"
fi fi
@@ -91,6 +83,10 @@ else
     cd ..
 fi

+# Fix PyTorch version conflicts after SAM2 installation
+print_status "Fixing PyTorch version conflicts..."
+pip install torchaudio --upgrade --no-deps || print_error "Failed to upgrade torchaudio"
+
 # Download SAM2 checkpoints
 print_status "Downloading SAM2 checkpoints..."
 cd segment-anything-2/checkpoints
@@ -159,28 +155,7 @@ if [ ! -f "config-streaming-runpod.yaml" ]; then
print_error "config-streaming-runpod.yaml not found - please check the repository" print_error "config-streaming-runpod.yaml not found - please check the repository"
fi fi
# Create convenience run scripts # Skip creating convenience scripts - use Python directly
print_status "Creating run scripts..."
# Chunked approach
cat > run_chunked.sh << 'EOF'
#!/bin/bash
# Run VR180 matting with chunked approach (original)
echo "🎬 Running VR180 matting - Chunked Approach"
echo "==========================================="
python -m vr180_matting.main config-chunked-runpod.yaml "$@"
EOF
chmod +x run_chunked.sh
# Streaming approach
cat > run_streaming.sh << 'EOF'
#!/bin/bash
# Run VR180 matting with streaming approach (optimized)
echo "🎬 Running VR180 matting - Streaming Approach"
echo "============================================="
python -m vr180_streaming.main config-streaming-runpod.yaml "$@"
EOF
chmod +x run_streaming.sh
# Test installation # Test installation
print_status "Testing installation..." print_status "Testing installation..."
@@ -246,12 +221,10 @@ echo
echo "2. Choose your processing approach:" echo "2. Choose your processing approach:"
echo echo
echo " a) STREAMING (Recommended - 2-3x faster, constant memory):" echo " a) STREAMING (Recommended - 2-3x faster, constant memory):"
echo " ./run_streaming.sh" echo " python -m vr180_streaming config-streaming-runpod.yaml"
echo " # Or: python -m vr180_streaming config-streaming-runpod.yaml"
echo echo
echo " b) CHUNKED (Original - more stable, higher memory):" echo " b) CHUNKED (Original - more stable, higher memory):"
echo " ./run_chunked.sh" echo " python -m vr180_matting config-chunked-runpod.yaml"
echo " # Or: python -m vr180_matting config-chunked-runpod.yaml"
echo echo
echo "3. Optional: Edit configs first:" echo "3. Optional: Edit configs first:"
echo " nano config-streaming-runpod.yaml # For streaming" echo " nano config-streaming-runpod.yaml # For streaming"
@@ -267,18 +240,18 @@ echo "==================="
echo "- Streaming: Best for long videos, uses ~50GB RAM constant" echo "- Streaming: Best for long videos, uses ~50GB RAM constant"
echo "- Chunked: More stable but uses 100GB+ RAM in spikes" echo "- Chunked: More stable but uses 100GB+ RAM in spikes"
echo "- Scale factor: 0.25 (fast) → 0.5 (balanced) → 1.0 (quality)" echo "- Scale factor: 0.25 (fast) → 0.5 (balanced) → 1.0 (quality)"
echo "- A6000/A100: Can handle 0.5-0.75 scale easily" echo "- L40/A6000: Can handle 0.5-0.75 scale easily with NVENC GPU encoding"
echo "- Monitor VRAM with: nvidia-smi -l 1" echo "- Monitor VRAM with: nvidia-smi -l 1"
echo echo
echo "🎯 Example Commands:" echo "🎯 Example Commands:"
echo "===================" echo "==================="
echo "# Process with custom output path:" echo "# Process with custom output path:"
echo "./run_streaming.sh --output /workspace/output/my_video.mp4" echo "python -m vr180_streaming config-streaming-runpod.yaml --output /workspace/output/my_video.mp4"
echo echo
echo "# Process specific frame range:" echo "# Process specific frame range:"
echo "./run_streaming.sh --start-frame 1000 --max-frames 5000" echo "python -m vr180_streaming config-streaming-runpod.yaml --start-frame 1000 --max-frames 5000"
echo echo
echo "# Override scale for quality:" echo "# Override scale for quality:"
echo "./run_streaming.sh --scale 0.75" echo "python -m vr180_streaming config-streaming-runpod.yaml --scale 0.75"
echo echo
echo "Happy matting! 🎬" echo "Happy matting! 🎬"

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env python3
"""
VR180 Streaming entry point for python -m vr180_streaming
"""
from .main import main
if __name__ == "__main__":
main()

View File

@@ -11,6 +11,28 @@ import atexit
import warnings
def test_nvenc_support() -> bool:
"""Test if NVENC encoding is available"""
try:
# Quick test with a 1-frame video
cmd = [
'ffmpeg', '-f', 'lavfi', '-i', 'testsrc=duration=0.1:size=320x240:rate=1',
'-c:v', 'h264_nvenc', '-t', '0.1', '-f', 'null', '-'
]
result = subprocess.run(
cmd,
capture_output=True,
timeout=10,
text=True
)
return result.returncode == 0
except (subprocess.TimeoutExpired, FileNotFoundError):
return False
class StreamingFrameWriter:
"""Write frames directly to ffmpeg via pipe for memory-efficient output"""
@@ -36,6 +58,16 @@ class StreamingFrameWriter:
self.frames_written = 0
self.ffmpeg_process = None
# Test NVENC support if GPU codec requested
if video_codec in ['h264_nvenc', 'hevc_nvenc']:
print(f"🔍 Testing NVENC support...")
if not test_nvenc_support():
print(f"❌ NVENC not available, switching to CPU encoding")
video_codec = 'libx264'
quality_preset = 'medium'
else:
print(f"✅ NVENC available")
# Build ffmpeg command
self.ffmpeg_cmd = self._build_ffmpeg_command(
video_codec, quality_preset, crf
@@ -132,16 +164,41 @@ class StreamingFrameWriter:
bufsize=10**8  # Large buffer for performance
)
# Test if ffmpeg starts successfully (quick check)
import time
time.sleep(0.2) # Give ffmpeg time to fail if it's going to
if self.ffmpeg_process.poll() is not None:
# Process already died - read error
stderr = self.ffmpeg_process.stderr.read().decode()
# Check for specific NVENC errors and provide better feedback
if 'nvenc' in ' '.join(self.ffmpeg_cmd):
if 'unsupported device' in stderr.lower():
print(f"❌ NVENC not available on this GPU - switching to CPU encoding")
elif 'cannot load' in stderr.lower() or 'not found' in stderr.lower():
print(f"❌ NVENC drivers not available - switching to CPU encoding")
else:
print(f"❌ NVENC encoding failed: {stderr}")
# Try CPU fallback
print(f"🔄 Falling back to CPU encoding (libx264)...")
self.ffmpeg_cmd = self._build_ffmpeg_command('libx264', 'medium', 18)
return self._start_ffmpeg()
else:
raise RuntimeError(f"FFmpeg failed: {stderr}")
 # Set process to ignore SIGINT (Ctrl+C) - we'll handle it
 if hasattr(signal, 'pthread_sigmask'):
 signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT])

 except Exception as e:
-# Try CPU fallback if GPU encoding fails
-if 'nvenc' in self.ffmpeg_cmd:
-print(f"⚠️ GPU encoding failed, trying CPU fallback...")
+# Final fallback if everything fails
+if 'nvenc' in ' '.join(self.ffmpeg_cmd):
+print(f"⚠️ GPU encoding failed with error: {e}")
+print(f"🔄 Falling back to CPU encoding...")
 self.ffmpeg_cmd = self._build_ffmpeg_command('libx264', 'medium', 18)
-self._start_ffmpeg()
+return self._start_ffmpeg()
 else:
 raise RuntimeError(f"Failed to start ffmpeg: {e}")
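The NVENC probe that StreamingFrameWriter now runs before building its ffmpeg command can also be run standalone before launching a long job. The sketch below mirrors the testsrc command from test_nvenc_support() above and simply picks a codec string to drop into the output config:

```python
# Standalone NVENC probe mirroring test_nvenc_support() above: encode a tiny
# synthetic clip with h264_nvenc and fall back to libx264 if that fails.
import subprocess

def nvenc_available() -> bool:
    cmd = [
        "ffmpeg", "-f", "lavfi", "-i", "testsrc=duration=0.1:size=320x240:rate=1",
        "-c:v", "h264_nvenc", "-t", "0.1", "-f", "null", "-",
    ]
    try:
        return subprocess.run(cmd, capture_output=True, timeout=10).returncode == 0
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return False

codec = "h264_nvenc" if nvenc_available() else "libx264"
print(f"video_codec to use: {codec}")
```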

View File

@@ -14,6 +14,7 @@ For a true streaming implementation, you may need to:
 import torch
 import numpy as np
+import cv2
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
@@ -34,6 +35,11 @@ class SAM2StreamingProcessor:
 self.config = config
 self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))

+# Processing parameters (set before _init_predictor)
+self.memory_offload = config.get('matting', {}).get('memory_offload', True)
+self.fp16 = config.get('matting', {}).get('fp16', True)
+self.correction_interval = config.get('matting', {}).get('correction_interval', 300)
+
 # SAM2 model configuration
 model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
 checkpoint = config.get('matting', {}).get('sam2_checkpoint',
@@ -43,11 +49,6 @@ class SAM2StreamingProcessor:
 self.predictor = None
 self._init_predictor(model_cfg, checkpoint)

-# Processing parameters
-self.memory_offload = config.get('matting', {}).get('memory_offload', True)
-self.fp16 = config.get('matting', {}).get('fp16', True)
-self.correction_interval = config.get('matting', {}).get('correction_interval', 300)
-
 # State management
 self.states = {}  # eye -> inference state
 self.object_ids = []
@@ -80,52 +81,165 @@ class SAM2StreamingProcessor:
 vos_optimized=True  # Enable full model compilation for speed
 )

-# Set to eval mode
+# Set to eval mode and ensure all model components are on GPU
 self.predictor.eval()

-# Enable FP16 if requested
+# Force all predictor components to GPU
+self.predictor = self.predictor.to(self.device)
+# Force move all internal components that might be on CPU
+if hasattr(self.predictor, 'image_encoder'):
+self.predictor.image_encoder = self.predictor.image_encoder.to(self.device)
+if hasattr(self.predictor, 'memory_attention'):
+self.predictor.memory_attention = self.predictor.memory_attention.to(self.device)
+if hasattr(self.predictor, 'memory_encoder'):
+self.predictor.memory_encoder = self.predictor.memory_encoder.to(self.device)
+if hasattr(self.predictor, 'sam_mask_decoder'):
+self.predictor.sam_mask_decoder = self.predictor.sam_mask_decoder.to(self.device)
+if hasattr(self.predictor, 'sam_prompt_encoder'):
+self.predictor.sam_prompt_encoder = self.predictor.sam_prompt_encoder.to(self.device)
+
+# Note: FP16 conversion can cause type mismatches with compiled models
+# Let SAM2 handle precision internally via build_sam2_video_predictor options
 if self.fp16 and self.device.type == 'cuda':
-self.predictor = self.predictor.half()
+print(" FP16 enabled via SAM2 internal settings")
+
+print(f" All SAM2 components moved to {self.device}")

 except Exception as e:
 raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")

 def init_state(self,
-video_path: str,
+video_info: Dict[str, Any],
 eye: str = 'full') -> Dict[str, Any]:
 """
-Initialize inference state for streaming
+Initialize inference state for streaming (NO VIDEO LOADING)

 Args:
-video_path: Path to video file
+video_info: Video metadata dict with width, height, frame_count
 eye: Eye identifier ('left', 'right', or 'full')

 Returns:
 Inference state dictionary
 """
-# Initialize state with memory offloading enabled
-with torch.inference_mode():
-state = self.predictor.init_state(
-video_path=video_path,
-offload_video_to_cpu=self.memory_offload,
-offload_state_to_cpu=self.memory_offload,
-async_loading_frames=False  # We'll provide frames directly
-)
+print(f" Initializing streaming state for {eye} eye...")
+
+# Monitor memory before initialization
+if torch.cuda.is_available():
+before_mem = torch.cuda.memory_allocated() / 1e9
+print(f" 📊 GPU memory before init: {before_mem:.1f}GB")
+
+# Create streaming state WITHOUT loading video frames
+state = self._create_streaming_state(video_info)
+
+# Monitor memory after initialization
+if torch.cuda.is_available():
+after_mem = torch.cuda.memory_allocated() / 1e9
+print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")

 self.states[eye] = state
-print(f" Initialized state for {eye} eye")
+print(f" ✅ Streaming state initialized for {eye} eye")
 return state
def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
"""Create streaming state for frame-by-frame processing"""
# Create a streaming-compatible inference state
# This mirrors SAM2's internal state structure but without video frames
# Create streaming-compatible state without loading video
# This approach avoids the dummy video complexity
with torch.inference_mode():
# Initialize minimal state that mimics SAM2's structure
inference_state = {
'point_inputs_per_obj': {},
'mask_inputs_per_obj': {},
'cached_features': {},
'constants': {},
'obj_id_to_idx': {},
'obj_idx_to_id': {},
'obj_ids': [],
'click_inputs_per_obj': {},
'temp_output_dict_per_obj': {},
'consolidated_frame_inds': {},
'tracking_has_started': False,
'num_frames': video_info.get('total_frames', video_info.get('frame_count', 0)),
'video_height': video_info['height'],
'video_width': video_info['width'],
'device': self.device,
'storage_device': self.device, # Keep everything on GPU
'offload_video_to_cpu': False,
'offload_state_to_cpu': False,
# Add required SAM2 internal structures
'output_dict_per_obj': {},
'frames': None,  # We provide frames manually
'images': None,  # We provide images manually
# Additional SAM2 tracking fields
'frames_tracked_per_obj': {},
'output_dict': {},
'memory_bank': {},
'num_obj_tokens': 0,
'max_obj_ptr_num': 16,  # SAM2 default
'multimask_output_in_sam': False,
'use_multimask_token_for_obj_ptr': True,
'max_inference_state_frames': -1,  # No limit for streaming
'image_feature_cache': {},
}
# Initialize some constants that SAM2 expects
inference_state['constants'] = {
'image_size': max(video_info['height'], video_info['width']),
'backbone_stride': 16, # Standard SAM2 backbone stride
'sam_mask_decoder_extra_args': {},
'sam_prompt_embed_dim': 256,
'sam_image_embedding_size': video_info['height'] // 16, # Assuming 16x downsampling
}
print(f" Created streaming-compatible state")
return inference_state
def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
"""Move all tensors in state to the specified device"""
def move_to_device(obj):
if isinstance(obj, torch.Tensor):
return obj.to(device)
elif isinstance(obj, dict):
return {k: move_to_device(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [move_to_device(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(move_to_device(item) for item in obj)
else:
return obj
# Move all state components to device
for key, value in state.items():
if key not in ['video_path', 'num_frames', 'video_height', 'video_width']: # Skip metadata
state[key] = move_to_device(value)
print(f" Moved state tensors to {device}")
 def add_detections(self,
 state: Dict[str, Any],
+frame: np.ndarray,
 detections: List[Dict[str, Any]],
 frame_idx: int = 0) -> List[int]:
 """
-Add detection boxes as prompts to SAM2
+Add detection boxes as prompts to SAM2 with frame data

 Args:
 state: Inference state
+frame: Frame image (RGB numpy array)
 detections: List of detections with 'box' key
 frame_idx: Frame index to add prompts
@@ -136,6 +250,23 @@ class SAM2StreamingProcessor:
warnings.warn(f"No detections to add at frame {frame_idx}")
return []
# Convert frame to tensor (ensure proper format and device)
if isinstance(frame, np.ndarray):
# Convert BGR to RGB if needed (OpenCV uses BGR)
if frame.shape[-1] == 3:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_tensor = torch.from_numpy(frame).float().to(self.device)
else:
frame_tensor = frame.float().to(self.device)
if frame_tensor.ndim == 3:
frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension
# Normalize to [0, 1] range if needed
if frame_tensor.max() > 1.0:
frame_tensor = frame_tensor / 255.0
# Convert detections to SAM2 format
boxes = []
for det in detections:
@@ -144,40 +275,157 @@ class SAM2StreamingProcessor:
 boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)

-# Add boxes as prompts
+# Manually process frame and add prompts (streaming approach)
 with torch.inference_mode():
-_, object_ids, _ = self.predictor.add_new_points_or_box(
+# Process frame through SAM2's image encoder
+backbone_out = self.predictor.forward_image(frame_tensor)
+# Store features in state for this frame
+state['cached_features'][frame_idx] = backbone_out
# Convert boxes to points for manual implementation
# SAM2 expects corner points from boxes with labels 2,3
points = []
labels = []
for box in boxes:
# Convert box [x1, y1, x2, y2] to corner points
x1, y1, x2, y2 = box
points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners
labels.extend([2, 3]) # SAM2 standard labels for box corners
points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
try:
# Use add_new_points instead of add_new_points_or_box to avoid device issues
_, object_ids, masks = self.predictor.add_new_points(
 inference_state=state,
 frame_idx=frame_idx,
-obj_id=0,  # SAM2 will auto-increment
-box=boxes_tensor
+obj_id=None,  # Let SAM2 auto-assign
+points=points_tensor,
+labels=labels_tensor,
+clear_old_points=True,
+normalize_coords=True
 )
# Update state with object tracking info
state['obj_ids'] = object_ids
state['tracking_has_started'] = True
except Exception as e:
print(f" Error in add_new_points: {e}")
print(f" Points tensor device: {points_tensor.device}")
print(f" Labels tensor device: {labels_tensor.device}")
print(f" Frame tensor device: {frame_tensor.device}")
# Fallback: manually initialize object tracking
print(f" Using fallback manual object initialization")
object_ids = [i for i in range(len(detections))]
state['obj_ids'] = object_ids
state['tracking_has_started'] = True
# Store detection info for later use
for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
state['point_inputs_per_obj'][i] = {
frame_idx: {
'points': points_tensor[i*2:(i+1)*2],
'labels': labels_tensor[i*2:(i+1)*2]
}
}
self.object_ids = object_ids
print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
return object_ids
-def propagate_in_video_simple(self,
-state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
+def propagate_single_frame(self,
+state: Dict[str, Any],
+frame: np.ndarray,
+frame_idx: int) -> np.ndarray:
 """
-Simple propagation for single eye processing
-
-Yields:
-(frame_idx, object_ids, masks) tuples
+Propagate masks for a single frame (true streaming)
+
+Args:
+state: Inference state
+frame: Frame image (RGB numpy array)
+frame_idx: Frame index
+
+Returns:
+Combined mask for all objects
 """
-with torch.inference_mode():
-for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
-# Convert masks to numpy
-if isinstance(masks, torch.Tensor):
-masks_np = masks.cpu().numpy()
-else:
-masks_np = masks
-yield frame_idx, object_ids, masks_np
-
-# Periodic memory cleanup
-if frame_idx % 100 == 0:
+# Convert frame to tensor (ensure proper format and device)
+if isinstance(frame, np.ndarray):
+# Convert BGR to RGB if needed (OpenCV uses BGR)
+if frame.shape[-1] == 3:
+frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+frame_tensor = torch.from_numpy(frame).float().to(self.device)
+else:
+frame_tensor = frame.float().to(self.device)
+
+if frame_tensor.ndim == 3:
+frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
+# Normalize to [0, 1] range if needed
+if frame_tensor.max() > 1.0:
+frame_tensor = frame_tensor / 255.0
with torch.inference_mode():
# Process frame through SAM2's image encoder
backbone_out = self.predictor.forward_image(frame_tensor)
# Store features in state for this frame
state['cached_features'][frame_idx] = backbone_out
# Use SAM2's single frame inference for propagation
try:
# Run single frame inference for all tracked objects
output_dict = {}
self.predictor._run_single_frame_inference(
inference_state=state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=1,
is_init_cond_frame=False, # Not initialization frame
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True
)
# Extract masks from output
if output_dict and 'pred_masks' in output_dict:
pred_masks = output_dict['pred_masks']
# Combine all object masks
if pred_masks.shape[0] > 0:
combined_mask = pred_masks.max(dim=0)[0]
combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
except Exception as e:
print(f" Warning: Single frame inference failed: {e}")
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
# Cleanup old features to prevent memory accumulation
self._cleanup_old_features(state, frame_idx, keep_frames=10)
return combined_mask_np
def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
"""Remove old cached features to prevent memory accumulation"""
features_to_remove = []
for frame_idx in state.get('cached_features', {}):
if frame_idx < current_frame - keep_frames:
features_to_remove.append(frame_idx)
for frame_idx in features_to_remove:
del state['cached_features'][frame_idx]
# Periodic GPU memory cleanup
if current_frame % 50 == 0:
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
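The box handling above relies on SAM2's convention of feeding a box as its two corner points with labels 2 (top-left) and 3 (bottom-right), as noted in add_detections(). A small self-contained sketch of just that conversion:

```python
# Convert [x1, y1, x2, y2] boxes to SAM2-style corner-point prompts with
# labels 2/3, matching the convention used in add_detections() above.
import torch

def boxes_to_corner_prompts(boxes, device="cpu"):
    points, labels = [], []
    for x1, y1, x2, y2 in boxes:
        points.extend([[x1, y1], [x2, y2]])  # top-left, bottom-right corners
        labels.extend([2, 3])                # SAM2 labels for box corners
    return (
        torch.tensor(points, dtype=torch.float32, device=device),
        torch.tensor(labels, dtype=torch.int32, device=device),
    )

pts, lbls = boxes_to_corner_prompts([[100.0, 50.0, 400.0, 600.0]])
print(pts.shape, lbls.tolist())  # torch.Size([2, 2]) [2, 3]
```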

View File

@@ -0,0 +1,407 @@
"""
Simple SAM2 streaming processor based on det-sam2 pattern
Adapted for current segment-anything-2 API
"""
import torch
import numpy as np
import cv2
import tempfile
import os
from pathlib import Path
from typing import Dict, Any, List, Optional
import warnings
import gc
# Import SAM2 components
try:
from sam2.build_sam import build_sam2_video_predictor
except ImportError:
warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
class SAM2StreamingProcessor:
"""Simple streaming integration with SAM2 following det-sam2 pattern"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
# SAM2 model configuration
model_cfg_name = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
checkpoint = config.get('matting', {}).get('sam2_checkpoint',
'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
# Map config name to Hydra path (like the examples show)
config_mapping = {
'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
}
model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)
# Build predictor (disable compilation to fix CUDA graph issues)
self.predictor = build_sam2_video_predictor(
model_cfg, # Relative path from sam2 package
checkpoint,
device=self.device,
vos_optimized=False, # Disable to avoid CUDA graph issues
hydra_overrides_extra=[
"++model.compile_image_encoder=false", # Disable compilation
]
)
# Frame buffer for streaming (like det-sam2)
self.frame_buffer = []
self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10)
# State management (simple)
self.inference_state = None
self.temp_dir = None
self.object_ids = []
# Memory management
self.memory_offload = config.get('matting', {}).get('memory_offload', True)
self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300)
print(f"🎯 Simple SAM2 streaming processor initialized:")
print(f" Model: {model_cfg}")
print(f" Device: {self.device}")
print(f" Buffer size: {self.frame_buffer_size}")
print(f" Memory offload: {self.memory_offload}")
def add_frame_and_detections(self,
frame: np.ndarray,
detections: List[Dict[str, Any]],
frame_idx: int) -> np.ndarray:
"""
Add frame to buffer and process detections (det-sam2 pattern)
Args:
frame: Input frame (BGR)
detections: List of detections with 'box' key
frame_idx: Global frame index
Returns:
Mask for current frame
"""
# Add frame to buffer
self.frame_buffer.append({
'frame': frame,
'frame_idx': frame_idx,
'detections': detections
})
# Process when buffer is full or when we have detections
if len(self.frame_buffer) >= self.frame_buffer_size or detections:
return self._process_buffer()
else:
# For frames without detections, still try to propagate if we have existing objects
if self.inference_state is not None and self.object_ids:
return self._propagate_existing_objects()
else:
# Return empty mask if no processing yet
return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
def _process_buffer(self) -> np.ndarray:
"""Process current frame buffer (adapted det-sam2 approach)"""
if not self.frame_buffer:
return np.zeros((480, 640), dtype=np.uint8)
try:
# Create temporary directory for frames (current SAM2 API requirement)
self._create_temp_frames()
# Initialize or update SAM2 state
if self.inference_state is None:
# First time: initialize state with temp directory
self.inference_state = self.predictor.init_state(
video_path=self.temp_dir,
offload_video_to_cpu=self.memory_offload,
offload_state_to_cpu=self.memory_offload
)
print(f" Initialized SAM2 state with {len(self.frame_buffer)} frames")
else:
# Subsequent times: we need to reinitialize since current SAM2 lacks update_state
# This is the key difference from det-sam2 reference
self._cleanup_temp_frames()
self._create_temp_frames()
self.inference_state = self.predictor.init_state(
video_path=self.temp_dir,
offload_video_to_cpu=self.memory_offload,
offload_state_to_cpu=self.memory_offload
)
print(f" Reinitialized SAM2 state with {len(self.frame_buffer)} frames")
# Add detections as prompts (standard SAM2 API)
self._add_detection_prompts()
# Get masks via propagation
masks = self._get_current_masks()
# Clean up old frames to prevent memory accumulation
self._cleanup_old_frames()
return masks
except Exception as e:
print(f" Warning: Buffer processing failed: {e}")
return np.zeros((480, 640), dtype=np.uint8)
def _create_temp_frames(self):
"""Create temporary directory with frame images for SAM2"""
if self.temp_dir:
self._cleanup_temp_frames()
self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')
for i, buffer_item in enumerate(self.frame_buffer):
frame = buffer_item['frame']
# Convert BGR to RGB for SAM2
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Save as JPEG (SAM2 expects JPEG images in directory)
frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
cv2.imwrite(frame_path, cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
def _add_detection_prompts(self):
"""Add detection boxes as prompts to SAM2 (standard API)"""
for buffer_idx, buffer_item in enumerate(self.frame_buffer):
detections = buffer_item.get('detections', [])
for det_idx, detection in enumerate(detections):
box = detection['box'] # [x1, y1, x2, y2]
# Use standard SAM2 API
try:
_, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=buffer_idx, # Frame index within buffer
obj_id=det_idx, # Simple object ID
box=np.array(box, dtype=np.float32)
)
# Track object IDs
if det_idx not in self.object_ids:
self.object_ids.append(det_idx)
except Exception as e:
print(f" Warning: Failed to add detection: {e}")
continue
def _get_current_masks(self) -> np.ndarray:
"""Get masks for current frame via propagation"""
if not self.object_ids:
# No objects to track
frame_shape = self.frame_buffer[-1]['frame'].shape
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
try:
# Use SAM2's propagate_in_video (standard API)
latest_frame_idx = len(self.frame_buffer) - 1
masks_for_frame = []
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=latest_frame_idx,
max_frame_num_to_track=1, # Just current frame
reverse=False
):
if out_frame_idx == latest_frame_idx:
# Combine all object masks
if len(out_mask_logits) > 0:
combined_mask = None
for mask_logit in out_mask_logits:
mask = (mask_logit > 0.0).cpu().numpy()
if combined_mask is None:
combined_mask = mask.astype(bool)
else:
combined_mask = combined_mask | mask.astype(bool)
return (combined_mask * 255).astype(np.uint8)
# If no masks found, return empty
frame_shape = self.frame_buffer[-1]['frame'].shape
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
except Exception as e:
print(f" Warning: Mask propagation failed: {e}")
frame_shape = self.frame_buffer[-1]['frame'].shape
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
def _propagate_existing_objects(self) -> np.ndarray:
"""Propagate existing objects without adding new detections"""
if not self.object_ids or not self.frame_buffer:
frame_shape = self.frame_buffer[-1]['frame'].shape if self.frame_buffer else (480, 640)
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
try:
# Update temp frames with current buffer
self._create_temp_frames()
# Reinitialize state (since we can't incrementally update)
self.inference_state = self.predictor.init_state(
video_path=self.temp_dir,
offload_video_to_cpu=self.memory_offload,
offload_state_to_cpu=self.memory_offload
)
# Re-add all previous detections from buffer
for buffer_idx, buffer_item in enumerate(self.frame_buffer):
detections = buffer_item.get('detections', [])
if detections: # Only add frames that had detections
for det_idx, detection in enumerate(detections):
box = detection['box']
try:
self.predictor.add_new_points_or_box(
inference_state=self.inference_state,
frame_idx=buffer_idx,
obj_id=det_idx,
box=np.array(box, dtype=np.float32)
)
except Exception as e:
print(f" Warning: Failed to re-add detection: {e}")
# Get masks for latest frame
latest_frame_idx = len(self.frame_buffer) - 1
for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
self.inference_state,
start_frame_idx=latest_frame_idx,
max_frame_num_to_track=1,
reverse=False
):
if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
combined_mask = None
for mask_logit in out_mask_logits:
mask = (mask_logit > 0.0).cpu().numpy()
if combined_mask is None:
combined_mask = mask.astype(bool)
else:
combined_mask = combined_mask | mask.astype(bool)
return (combined_mask * 255).astype(np.uint8)
# If no masks, return empty
frame_shape = self.frame_buffer[-1]['frame'].shape
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
except Exception as e:
print(f" Warning: Object propagation failed: {e}")
frame_shape = self.frame_buffer[-1]['frame'].shape
return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
def _cleanup_old_frames(self):
"""Clean up old frames from buffer (det-sam2 pattern)"""
# Keep only recent frames to prevent memory accumulation
if len(self.frame_buffer) > self.frame_buffer_size:
self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]
# Periodic GPU memory cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def _cleanup_temp_frames(self):
"""Clean up temporary frame directory"""
if self.temp_dir and os.path.exists(self.temp_dir):
import shutil
shutil.rmtree(self.temp_dir)
self.temp_dir = None
def cleanup(self):
"""Clean up all resources"""
self._cleanup_temp_frames()
self.frame_buffer.clear()
self.object_ids.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
print("🧹 Simple SAM2 streaming processor cleaned up")
def apply_mask_to_frame(self,
frame: np.ndarray,
mask: np.ndarray,
output_format: str = "alpha",
background_color: tuple = (0, 255, 0)) -> np.ndarray:
"""
Apply mask to frame with specified output format (matches chunked implementation)
Args:
frame: Input frame (BGR)
mask: Binary mask (0-255 or boolean)
output_format: "alpha" or "greenscreen"
background_color: RGB background color for greenscreen mode
Returns:
Processed frame
"""
if mask is None:
return frame
# Ensure mask is 2D (handle 3D masks properly)
if mask.ndim == 3:
mask = mask.squeeze()
# Resize mask to match frame if needed (use INTER_NEAREST for binary masks)
if mask.shape[:2] != frame.shape[:2]:
import cv2
# Convert to uint8 for resizing, then back to bool
if mask.dtype == bool:
mask_uint8 = mask.astype(np.uint8) * 255
else:
mask_uint8 = mask.astype(np.uint8)
mask_resized = cv2.resize(mask_uint8,
(frame.shape[1], frame.shape[0]),
interpolation=cv2.INTER_NEAREST)
mask = mask_resized.astype(bool) if mask.dtype == bool else mask_resized
if output_format == "alpha":
# Create RGBA output (matches chunked implementation)
output = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
output[:, :, :3] = frame
if mask.dtype == bool:
output[:, :, 3] = mask.astype(np.uint8) * 255
else:
output[:, :, 3] = mask.astype(np.uint8)
return output
elif output_format == "greenscreen":
# Create RGB output with background (matches chunked implementation)
output = np.full_like(frame, background_color, dtype=np.uint8)
if mask.dtype == bool:
output[mask] = frame[mask]
else:
mask_bool = mask.astype(bool)
output[mask_bool] = frame[mask_bool]
return output
else:
raise ValueError(f"Unsupported output format: {output_format}. Use 'alpha' or 'greenscreen'")
def get_memory_usage(self) -> Dict[str, float]:
"""
Get current memory usage statistics
Returns:
Dictionary with memory usage info
"""
stats = {}
if torch.cuda.is_available():
# GPU memory stats
stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / (1024**3)
stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / (1024**3)
stats['cuda_max_allocated_gb'] = torch.cuda.max_memory_allocated() / (1024**3)
return stats
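For orientation, a minimal driver for the simple processor above could look like the following. The import path and input video path are assumptions; the config keys mirror the ones read in __init__, and detections would normally come from the person detector.

```python
# Illustrative driver for the det-sam2-style processor defined above.
# Module path and video path are assumptions; adjust to your layout.
import cv2

from vr180_streaming.sam2_streaming_simple import SAM2StreamingProcessor  # assumed path

config = {
    "hardware": {"device": "cuda"},
    "matting": {
        "sam2_model_cfg": "sam2.1_hiera_l",
        "sam2_checkpoint": "segment-anything-2/checkpoints/sam2.1_hiera_large.pt",
        "memory_offload": True,
        "correction_interval": 300,
    },
    "streaming": {"buffer_frames": 10},
}

processor = SAM2StreamingProcessor(config)
cap = cv2.VideoCapture("/workspace/input_video.mp4")  # placeholder path

frame_idx = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    detections = []  # e.g. [{"box": [x1, y1, x2, y2]}] on frames where YOLO ran
    mask = processor.add_frame_and_detections(frame, detections, frame_idx)
    green = processor.apply_mask_to_frame(frame, mask, output_format="greenscreen")
    # ... hand `green` to a frame writer here ...
    frame_idx += 1

cap.release()
processor.cleanup()
```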

View File

@@ -15,7 +15,7 @@ import warnings
 from .frame_reader import StreamingFrameReader
 from .frame_writer import StreamingFrameWriter
 from .stereo_manager import StereoConsistencyManager
-from .sam2_streaming import SAM2StreamingProcessor
+from .sam2_streaming_simple import SAM2StreamingProcessor
 from .detector import PersonDetector
 from .config import StreamingConfig
@@ -102,25 +102,17 @@ class VR180StreamingProcessor:
 self.initialize()
 self.start_time = time.time()

-# Initialize SAM2 states for both eyes
-print("🎯 Initializing SAM2 streaming states...")
-left_state = self.sam2_processor.init_state(
-self.config.input.video_path,
-eye='left'
-)
-right_state = self.sam2_processor.init_state(
-self.config.input.video_path,
-eye='right'
-)
+# Simple SAM2 initialization (no complex state management needed)
+print("🎯 SAM2 streaming processor ready...")

 # Process first frame to establish detections
 print("🔍 Processing first frame for initial detection...")
-if not self._initialize_tracking(left_state, right_state):
+if not self._initialize_tracking():
 raise RuntimeError("Failed to initialize tracking - no persons detected")

 # Main streaming loop
 print("\n🎬 Starting streaming processing loop...")
-self._streaming_loop(left_state, right_state)
+self._streaming_loop()

 except KeyboardInterrupt:
 print("\n⚠️ Processing interrupted by user")
@@ -134,7 +126,7 @@ class VR180StreamingProcessor:
 finally:
 self._finalize()

-def _initialize_tracking(self, left_state: Dict, right_state: Dict) -> bool:
+def _initialize_tracking(self) -> bool:
 """Initialize tracking with first frame detection"""
 # Read and process first frame
 first_frame = self.frame_reader.read_frame()
@@ -158,19 +150,15 @@ class VR180StreamingProcessor:
print(f" Detected {len(detections)} person(s) in first frame") print(f" Detected {len(detections)} person(s) in first frame")
# Add detections to both eyes # Process with simple SAM2 approach
self.sam2_processor.add_detections(left_state, detections, frame_idx=0) left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, 0)
# Transfer detections to slave eye # Transfer detections to right eye
transferred_detections = self.stereo_manager.transfer_detections( transferred_detections = self.stereo_manager.transfer_detections(
detections, detections,
'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left' 'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
) )
self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0) right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, 0)
# Process and write first frame
left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
# Apply masks and write # Apply masks and write
processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks) processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -179,7 +167,7 @@ class VR180StreamingProcessor:
 self.frames_processed = 1
 return True

-def _streaming_loop(self, left_state: Dict, right_state: Dict) -> None:
+def _streaming_loop(self) -> None:
 """Main streaming processing loop"""
 frame_times = []
 last_log_time = time.time()
@@ -195,23 +183,36 @@ class VR180StreamingProcessor:
 # Split into eyes
 left_eye, right_eye = self.stereo_manager.split_frame(frame)

-# Propagate masks for both eyes
-left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
-left_state, right_state, left_eye, right_eye, frame_idx
-)
+# Check if we need to run detection for continuous correction
+detections = []
+if (self.config.matting.continuous_correction and
+frame_idx % self.config.matting.correction_interval == 0):
+print(f"\n🔄 Running YOLO detection for correction at frame {frame_idx}")
+master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
+detections = self.detector.detect_persons(master_eye)
+if detections:
+print(f" Detected {len(detections)} person(s) for correction")
+else:
+print(f" No persons detected for correction")
+
+# Process frames (with detections if this is a correction frame)
+left_masks = self.sam2_processor.add_frame_and_detections(left_eye, detections, frame_idx)
+
+# For right eye, transfer detections if we have them
+if detections:
+transferred_detections = self.stereo_manager.transfer_detections(
+detections,
+'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
+)
+right_masks = self.sam2_processor.add_frame_and_detections(right_eye, transferred_detections, frame_idx)
+else:
+right_masks = self.sam2_processor.add_frame_and_detections(right_eye, [], frame_idx)

 # Validate stereo consistency
 right_masks = self.stereo_manager.validate_masks(
 left_masks, right_masks, frame_idx
 )

-# Apply continuous correction if enabled
-if (self.config.matting.continuous_correction and
-frame_idx % self.config.matting.correction_interval == 0):
-self._apply_continuous_correction(
-left_state, right_state, left_eye, right_eye, frame_idx
-)
-
 # Apply masks and write frame
 processed_frame = self._apply_masks_to_frame(frame, left_masks, right_masks)
 self.frame_writer.write_frame(processed_frame)
@@ -282,21 +283,20 @@ class VR180StreamingProcessor:
 return left_processed

 def _apply_continuous_correction(self,
-left_state: Dict,
-right_state: Dict,
 left_eye: np.ndarray,
 right_eye: np.ndarray,
 frame_idx: int) -> None:
 """Apply continuous correction to maintain tracking accuracy"""
 print(f"\n🔄 Applying continuous correction at frame {frame_idx}")

-# Detect on master eye
+# Detect on master eye and add fresh detections
 master_eye = left_eye if self.stereo_manager.master_eye == 'left' else right_eye
-master_state = left_state if self.stereo_manager.master_eye == 'left' else right_state
-self.sam2_processor.apply_continuous_correction(
-master_state, master_eye, frame_idx, self.detector
-)
+detections = self.detector.detect_persons(master_eye)
+if detections:
+print(f" Adding {len(detections)} fresh detection(s) for correction")
+# Add fresh detections to help correct drift
+self.sam2_processor.add_frame_and_detections(master_eye, detections, frame_idx)

 # Transfer corrections to slave eye
 # Note: This is simplified - actual implementation would transfer the refined prompts

View File

@@ -0,0 +1,45 @@
"""
Timeout wrapper for SAM2 initialization to prevent hanging
"""
import signal
import functools
from typing import Any, Callable
class TimeoutError(Exception):
pass
def timeout(seconds: int = 120):
"""Decorator to add timeout to function calls"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
# Define signal handler
def timeout_handler(signum, frame):
raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
# Set signal handler
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
finally:
# Restore old handler
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
return result
return wrapper
return decorator
@timeout(120) # 2 minute timeout
def safe_init_state(predictor, video_path: str, **kwargs) -> Any:
"""Safely initialize SAM2 state with timeout"""
return predictor.init_state(
video_path=video_path,
**kwargs
)
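A usage sketch for the timeout helper above, assuming the definitions are in scope and a SAM2 video predictor already exists; note that SIGALRM-based timeouts only work in the main thread on POSIX systems.

```python
# Example use of safe_init_state() defined above. `predictor` is assumed to be
# a SAM2 video predictor built elsewhere; the video path is a placeholder.
try:
    state = safe_init_state(
        predictor,
        video_path="/workspace/input_video.mp4",
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
    )
except TimeoutError as exc:  # the TimeoutError defined in this module
    print(f"SAM2 init_state appears to be hanging: {exc}")
    state = None
```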