Files
test2/vr180_streaming/sam2_streaming.py
2025-07-27 09:04:40 -07:00

527 lines
21 KiB
Python

"""
SAM2 streaming processor for frame-by-frame video segmentation
NOTE: This is a template implementation. The actual SAM2 integration would need to:
1. Handle the fact that SAM2VideoPredictor loads the entire video internally
2. Potentially modify SAM2 to support frame-by-frame input
3. Or use a custom video loader that provides frames on demand
For a true streaming implementation, you may need to:
- Extend SAM2VideoPredictor to accept a frame generator instead of video path
- Implement a custom video loader that doesn't load all frames at once
- Use the memory offloading features more aggressively
"""
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc
# Import SAM2 components - these will be available after SAM2 installation
try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.utils.misc import load_video_frames
except ImportError:
    # SAM2 is an optional dependency at import time. Bind the names to None so
    # the module still imports and later use fails on a checkable value
    # instead of an unbound-name NameError.
    build_sam2_video_predictor = None
    load_video_frames = None
    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
class SAM2StreamingProcessor:
    """Streaming integration with SAM2 video predictor.

    Frames are supplied one at a time. Each frame is preprocessed (resized to
    ``predictor.image_size``, scaled and normalised), encoded with the SAM2
    image encoder, and stored in the inference state's ``cached_features`` so
    SAM2's tracking routines can find it. Old cached features and old
    non-conditioning outputs are pruned to keep memory bounded on long videos.

    NOTE(review): this drives private SAM2VideoPredictor methods
    (``_run_single_frame_inference``, ``_get_orig_video_res_output``) and the
    per-object state layout (``output_dict_per_obj``) of recent SAM2 releases;
    re-verify against the installed SAM2 version when upgrading.
    """

    # Short model names -> SAM2 Hydra config paths.
    _MODEL_CONFIGS = {
        'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
        'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
        'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
        'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
    }

    # Per-channel normalisation statistics; intended to match the defaults of
    # SAM2's frame loader (sam2.utils.misc.load_video_frames) - confirm if the
    # checkpoint expects different values.
    _IMG_MEAN = (0.485, 0.456, 0.406)
    _IMG_STD = (0.229, 0.224, 0.225)

    def __init__(self, config: Dict[str, Any]):
        """Build the SAM2 predictor from ``config`` and reset tracking state.

        Args:
            config: Pipeline configuration. Reads ``hardware.device`` and the
                ``matting`` section (``memory_offload``, ``fp16``,
                ``correction_interval``, ``sam2_model_cfg``,
                ``sam2_checkpoint``).

        Raises:
            RuntimeError: If the SAM2 predictor cannot be built.
        """
        self.config = config
        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
        matting_cfg = config.get('matting', {})

        # Processing parameters (set before _init_predictor, which reads self.fp16)
        self.memory_offload = matting_cfg.get('memory_offload', True)
        self.fp16 = matting_cfg.get('fp16', True)
        self.correction_interval = matting_cfg.get('correction_interval', 300)

        # SAM2 model configuration
        model_cfg = matting_cfg.get('sam2_model_cfg', 'sam2.1_hiera_l')
        checkpoint = matting_cfg.get('sam2_checkpoint',
                                     'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')

        # Build predictor
        self.predictor = None
        self._init_predictor(model_cfg, checkpoint)

        # State management
        self.states = {}  # eye -> inference state
        self.object_ids = []
        self.frame_count = 0

        print("🎯 SAM2 streaming processor initialized:")
        print(f" Model: {model_cfg}")
        print(f" Device: {self.device}")
        print(f" Memory offload: {self.memory_offload}")
        print(f" FP16: {self.fp16}")

    def _init_predictor(self, model_cfg: str, checkpoint: str) -> None:
        """Build the SAM2 video predictor and switch it to eval mode.

        Args:
            model_cfg: Short model name (key of ``_MODEL_CONFIGS``) or a config path.
            checkpoint: Path to the SAM2 checkpoint file.

        Raises:
            RuntimeError: If SAM2 is not installed or building the predictor
                fails; the underlying error is chained as ``__cause__``.
        """
        # The module-level SAM2 import is optional: if it failed, the factory
        # name may be unbound or bound to None depending on the import guard.
        try:
            factory = build_sam2_video_predictor
        except NameError:
            factory = None
        if factory is None:
            raise RuntimeError(
                "Failed to initialize SAM2 predictor: SAM2 is not installed "
                "(install segment-anything-2 first)"
            )

        actual_config = self._MODEL_CONFIGS.get(model_cfg, model_cfg)
        try:
            # Build predictor with VOS optimizations
            self.predictor = factory(
                actual_config,
                checkpoint,
                device=self.device,
                vos_optimized=True,  # Enable full model compilation for speed
            )
            self.predictor.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}") from e

        # Note: FP16 conversion can cause type mismatches with compiled models
        # Let SAM2 handle precision internally via build_sam2_video_predictor options
        if self.fp16 and self.device.type == 'cuda':
            print(" FP16 enabled via SAM2 internal settings")

    def init_state(self,
                   video_info: Dict[str, Any],
                   eye: str = 'full') -> Dict[str, Any]:
        """Initialize inference state for streaming (NO VIDEO LOADING).

        Args:
            video_info: Video metadata dict with width, height and
                total_frames/frame_count.
            eye: Eye identifier ('left', 'right', or 'full').

        Returns:
            Inference state dictionary (also stored in ``self.states[eye]``).
        """
        print(f" Initializing streaming state for {eye} eye...")
        track_memory = torch.cuda.is_available()
        if track_memory:
            before_mem = torch.cuda.memory_allocated() / 1e9
            print(f" 📊 GPU memory before init: {before_mem:.1f}GB")

        # Create streaming state WITHOUT loading video frames
        state = self._create_streaming_state(video_info)

        if track_memory:
            after_mem = torch.cuda.memory_allocated() / 1e9
            print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")

        self.states[eye] = state
        print(f" ✅ Streaming state initialized for {eye} eye")
        return state

    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
        """Create an inference state for frame-by-frame processing.

        SAM2's ``init_state`` needs a video, so a one-frame black dummy video is
        written to a temp file, used to build the state, and always deleted.
        The state's frame count and resolution are then overwritten with the
        real video metadata. If that fails, a minimal hand-built state is
        returned instead (best effort, with a printed warning).
        """
        height = int(video_info['height'])
        width = int(video_info['width'])
        num_frames = video_info.get('total_frames', video_info.get('frame_count', 0))
        try:
            inference_state = self._init_state_from_dummy_video(width, height)
            # Update state with actual video info
            inference_state['num_frames'] = num_frames
            inference_state['video_height'] = height
            inference_state['video_width'] = width
        except Exception as e:
            print(f" Warning: Failed to create proper SAM2 state ({e}), using minimal state")
            # Fallback to minimal state
            inference_state = {
                'point_inputs_per_obj': {},
                'mask_inputs_per_obj': {},
                'cached_features': {},
                'constants': {},
                'obj_id_to_idx': {},
                'obj_idx_to_id': {},
                'obj_ids': [],
                'click_inputs_per_obj': {},
                'temp_output_dict_per_obj': {},
                'consolidated_frame_inds': {},
                'tracking_has_started': False,
                'num_frames': num_frames,
                'video_height': height,
                'video_width': width,
                'device': self.device,
                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                'offload_video_to_cpu': self.memory_offload,
                'offload_state_to_cpu': self.memory_offload,
            }
        return inference_state

    def _init_state_from_dummy_video(self, width: int, height: int) -> Dict[str, Any]:
        """Run SAM2 ``init_state`` on a temporary one-frame black video."""
        import tempfile
        import cv2

        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
            dummy_path = Path(tmp_file.name)
        try:
            writer = cv2.VideoWriter(str(dummy_path), cv2.VideoWriter_fourcc(*'mp4v'),
                                     1.0, (width, height))
            if not writer.isOpened():
                raise RuntimeError(f"could not open a VideoWriter for {dummy_path}")
            try:
                writer.write(np.zeros((height, width, 3), dtype=np.uint8))
            finally:
                writer.release()
            with torch.inference_mode():
                return self.predictor.init_state(
                    video_path=str(dummy_path),
                    offload_video_to_cpu=self.memory_offload,
                    offload_state_to_cpu=self.memory_offload,
                    async_loading_frames=True,
                )
        finally:
            # Remove the dummy file even when writing or init_state fails.
            dummy_path.unlink(missing_ok=True)

    def _preprocess_frame(self, frame: np.ndarray,
                          image_size: Optional[int] = None) -> torch.Tensor:
        """Convert an RGB frame into a normalised ``(1, 3, S, S)`` encoder input.

        Args:
            frame: HxWx3 (or HxW grayscale) RGB image. uint8 frames are scaled
                to [0, 1]; other dtypes are assumed to already be in [0, 1].
            image_size: Target square size S; defaults to ``predictor.image_size``.

        Returns:
            Float32 tensor on ``self.device``.
        """
        if image_size is None:
            image_size = self.predictor.image_size
        pixels = torch.from_numpy(np.ascontiguousarray(frame)).to(
            device=self.device, dtype=torch.float32)
        if pixels.ndim == 2:
            pixels = pixels.unsqueeze(-1).expand(-1, -1, 3)
        image = pixels.permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW
        if frame.dtype == np.uint8:
            image = image / 255.0
        image = torch.nn.functional.interpolate(
            image, size=(image_size, image_size), mode='bilinear', align_corners=False)
        mean = torch.tensor(self._IMG_MEAN, device=self.device).view(1, 3, 1, 1)
        std = torch.tensor(self._IMG_STD, device=self.device).view(1, 3, 1, 1)
        return (image - mean) / std

    def _cache_frame_features(self, state: Dict[str, Any], frame: np.ndarray,
                              frame_idx: int) -> None:
        """Encode ``frame`` and store it in ``state['cached_features']``."""
        with torch.inference_mode():
            image = self._preprocess_frame(frame)
            backbone_out = self.predictor.forward_image(image)
        # SAM2's _get_image_feature unpacks each cache entry as (image, backbone_out).
        state['cached_features'][frame_idx] = (image, backbone_out)

    def add_detections(self,
                       state: Dict[str, Any],
                       frame: np.ndarray,
                       detections: List[Dict[str, Any]],
                       frame_idx: int = 0) -> List[int]:
        """Add detection boxes as prompts to SAM2 with frame data.

        Each detection becomes its own object (object id = its index in
        ``detections``). After prompting, the prompt outputs are consolidated
        into tracking memory while this frame's features are still cached.

        Args:
            state: Inference state.
            frame: Frame image (RGB numpy array).
            detections: List of detections with a 'box' key ([x1, y1, x2, y2]
                in frame pixel coordinates).
            frame_idx: Frame index to add prompts.

        Returns:
            List of object IDs tracked in ``state`` (empty if no detections).
        """
        if not detections:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []

        self._cache_frame_features(state, frame, frame_idx)

        object_ids: List[int] = []
        with torch.inference_mode():
            for obj_id, det in enumerate(detections):
                # One call per box: SAM2 reshapes `box` into a single (1, 2, 2)
                # prompt. Keep the box on CPU, like SAM2's default empty
                # `points` tensor, so the two can be concatenated.
                _, object_ids, _ = self.predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    box=torch.as_tensor(det['box'], dtype=torch.float32),
                )
            # Turn the prompt outputs into memory now, while this frame's
            # features are guaranteed to be in the cache.
            self.predictor.propagate_in_video_preflight(state)

        self.object_ids = list(object_ids)
        print(f" Added {len(detections)} detections at frame {frame_idx}: objects {self.object_ids}")
        return self.object_ids

    def propagate_single_frame(self,
                               state: Dict[str, Any],
                               frame: np.ndarray,
                               frame_idx: int) -> np.ndarray:
        """Propagate masks for a single frame (true streaming).

        Args:
            state: Inference state.
            frame: Frame image (RGB numpy array).
            frame_idx: Frame index.

        Returns:
            Combined uint8 mask (0/255) over all tracked objects. If nothing is
            tracked yet, or inference fails (a warning is printed), the mask is
            all zeros with the frame's height and width.
        """
        empty_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
        if not state.get('obj_ids'):
            # Nothing is tracked yet, so encoding the frame would be wasted work.
            return empty_mask

        self._cache_frame_features(state, frame, frame_idx)
        try:
            combined = self._track_objects_on_frame(state, frame_idx)
            mask = empty_mask if combined is None else combined
        except Exception as e:
            # Deliberate best effort: one bad frame should not abort the stream.
            print(f" Warning: Single frame inference failed: {e}")
            mask = empty_mask

        # Cleanup old features/outputs to prevent memory accumulation
        self._cleanup_old_features(state, frame_idx, keep_frames=10)
        self._prune_old_outputs(state, frame_idx)
        return mask

    def _track_objects_on_frame(self, state: Dict[str, Any],
                                frame_idx: int) -> Optional[np.ndarray]:
        """Run per-object SAM2 tracking on a cached frame; return the union mask."""
        per_obj_outputs = state['output_dict_per_obj']
        tracked_frames = state.get('frames_tracked_per_obj')
        pred_masks_per_obj = []
        with torch.inference_mode():
            for obj_idx in range(len(state['obj_ids'])):
                obj_output_dict = per_obj_outputs[obj_idx]
                cond_out = obj_output_dict['cond_frame_outputs'].get(frame_idx)
                if cond_out is not None:
                    # Prompted frame: reuse the consolidated prompt output.
                    pred_masks = cond_out['pred_masks'].to(state['device'], non_blocking=True)
                else:
                    current_out, pred_masks = self.predictor._run_single_frame_inference(
                        inference_state=state,
                        output_dict=obj_output_dict,  # this object's memory bank
                        frame_idx=frame_idx,
                        batch_size=1,  # one object slice at a time
                        is_init_cond_frame=False,
                        point_inputs=None,
                        mask_inputs=None,
                        reverse=False,
                        run_mem_encoder=True,
                    )
                    # Store the output so later frames can attend to it.
                    obj_output_dict['non_cond_frame_outputs'][frame_idx] = current_out
                if tracked_frames is not None:
                    tracked_frames.setdefault(obj_idx, {})[frame_idx] = {'reverse': False}
                pred_masks_per_obj.append(pred_masks)

            if not pred_masks_per_obj:
                return None
            all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            _, video_res_masks = self.predictor._get_orig_video_res_output(state, all_pred_masks)
            # (num_objects, 1, H, W) logits -> (H, W) union of positive pixels
            combined = (video_res_masks > 0.0).any(dim=0).squeeze(0)
        return combined.cpu().numpy().astype(np.uint8) * 255

    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
        """Remove cached frame features older than ``keep_frames`` frames.

        Every 50th frame also empties the CUDA cache and runs the garbage collector.
        """
        cached = state.get('cached_features', {})
        cutoff = current_frame - keep_frames
        for stale_idx in [idx for idx in cached if idx < cutoff]:
            del cached[stale_idx]

        # Periodic GPU memory cleanup
        if current_frame % 50 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def _memory_window(self) -> int:
        """Number of past frames whose outputs SAM2's memory attention may read.

        The fallbacks are intended to mirror SAM2Base defaults
        (num_maskmem=7, stride 1, max_obj_ptrs_in_encoder=16) - confirm.
        """
        num_maskmem = getattr(self.predictor, 'num_maskmem', 7)
        stride = getattr(self.predictor, 'memory_temporal_stride_for_eval', 1)
        max_obj_ptrs = getattr(self.predictor, 'max_obj_ptrs_in_encoder', 16)
        return max(num_maskmem * stride, max_obj_ptrs)

    def _prune_old_outputs(self, state: Dict[str, Any], current_frame: int,
                           window: Optional[int] = None) -> None:
        """Drop per-object non-conditioning outputs older than the memory window.

        Conditioning (prompted) outputs are always kept.
        """
        if window is None:
            window = self._memory_window()
        cutoff = current_frame - window
        for obj_output_dict in state.get('output_dict_per_obj', {}).values():
            non_cond = obj_output_dict.get('non_cond_frame_outputs', {})
            for stale_idx in [idx for idx in non_cond if idx < cutoff]:
                del non_cond[stale_idx]

    def propagate_frame_pair(self,
                             left_state: Dict[str, Any],
                             right_state: Dict[str, Any],
                             left_frame: np.ndarray,
                             right_frame: np.ndarray,
                             frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Propagate masks for a stereo frame pair.

        Each eye is tracked independently with ``propagate_single_frame``.

        Args:
            left_state: Left eye inference state.
            right_state: Right eye inference state.
            left_frame: Left eye frame.
            right_frame: Right eye frame.
            frame_idx: Current frame index.

        Returns:
            Tuple of (left_masks, right_masks).
        """
        left_masks = self.propagate_single_frame(left_state, left_frame, frame_idx)
        right_masks = self.propagate_single_frame(right_state, right_frame, frame_idx)
        return left_masks, right_masks

    def _propagate_single_frame(self,
                                state: Dict[str, Any],
                                frame: np.ndarray,
                                frame_idx: int) -> np.ndarray:
        """Backward-compatible alias for ``propagate_single_frame``.

        Args:
            state: Inference state.
            frame: Input frame.
            frame_idx: Frame index.

        Returns:
            Combined mask for all objects.
        """
        return self.propagate_single_frame(state, frame, frame_idx)

    def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray:
        """Return the stored uint8 (0/255) mask of one object at ``frame_idx``.

        Looks up the object's conditioning, then non-conditioning, outputs in
        ``state['output_dict_per_obj']``. Returns zeros of the state's video
        size (default 1080x1920) when no output is stored for that frame.
        """
        h, w = state.get('video_height', 1080), state.get('video_width', 1920)
        obj_output_dict = state.get('output_dict_per_obj', {}).get(obj_idx)
        if obj_output_dict is not None:
            for storage_key in ('cond_frame_outputs', 'non_cond_frame_outputs'):
                out = obj_output_dict.get(storage_key, {}).get(frame_idx)
                if out is None:
                    continue
                with torch.inference_mode():
                    _, video_res = self.predictor._get_orig_video_res_output(state, out['pred_masks'])
                return (video_res[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
        return np.zeros((h, w), dtype=np.uint8)

    def apply_continuous_correction(self,
                                    state: Dict[str, Any],
                                    frame: np.ndarray,
                                    frame_idx: int,
                                    detector: Any) -> None:
        """Re-detect persons every ``correction_interval`` frames and refine tracks.

        New boxes are paired with already-tracked object ids in order, so this
        refines existing objects and never adds new ones. A non-positive
        ``correction_interval`` disables correction.

        NOTE(review): pairing is by detection order, not by overlap with the
        current masks, so identities can swap if the detector reorders people.

        Args:
            state: Inference state.
            frame: Current frame.
            frame_idx: Frame index.
            detector: Person detector instance (must provide ``detect_persons``).
        """
        if self.correction_interval <= 0 or frame_idx % self.correction_interval != 0:
            return

        print(f" 🔄 Applying continuous correction at frame {frame_idx}")
        # Detect persons in current frame
        new_detections = detector.detect_persons(frame)
        tracked_ids = list(state.get('obj_ids', []))
        if not new_detections or not tracked_ids:
            return

        # The prompt needs this frame's features in the cache.
        if frame_idx not in state.get('cached_features', {}):
            self._cache_frame_features(state, frame, frame_idx)

        with torch.inference_mode():
            for obj_id, det in zip(tracked_ids, new_detections):
                self.predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,  # Refine existing objects
                    box=torch.as_tensor(det['box'], dtype=torch.float32),
                )
            # Fold the refinement prompts into tracking memory.
            self.predictor.propagate_in_video_preflight(state)

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = 'greenscreen',
                            background_color: Optional[List[int]] = None) -> np.ndarray:
        """Apply mask to frame with specified output format.

        Args:
            frame: Input frame (BGR).
            mask: Binary mask (uint8 0/255, or bool/float in [0, 1]).
            output_format: 'alpha' or 'greenscreen' (any other value is
                treated as 'greenscreen').
            background_color: Background color for greenscreen; defaults to
                [0, 255, 0].

        Returns:
            Processed frame (BGRA for 'alpha', same channels as frame otherwise).
        """
        if background_color is None:
            background_color = [0, 255, 0]

        if output_format == 'alpha':
            if mask.ndim == 3:
                mask = mask[:, :, 0]
            if mask.dtype != np.uint8:
                mask = (mask * 255).astype(np.uint8)
            bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            bgra[:, :, :3] = frame
            bgra[:, :, 3] = mask
            return bgra

        # greenscreen
        background = np.full_like(frame, background_color, dtype=np.uint8)
        if mask.ndim == 2:
            mask_3ch = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        else:
            mask_3ch = mask
        # Normalize mask to 0-1
        if mask_3ch.dtype == np.uint8:
            mask_float = mask_3ch.astype(np.float32) / 255.0
        else:
            mask_float = mask_3ch.astype(np.float32)
        return (frame * mask_float + background * (1 - mask_float)).astype(np.uint8)

    def cleanup(self) -> None:
        """Drop all inference states and tracked ids, then release cached memory."""
        self.states.clear()
        self.object_ids = []
        if torch.cuda.is_available():
            # Wait for in-flight kernels before returning cached blocks.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
        print("🧹 SAM2 streaming processor cleaned up")

    def get_memory_usage(self) -> Dict[str, float]:
        """Return state/object counts plus CUDA allocated/reserved GB when available."""
        memory_stats = {
            'states_count': len(self.states),
            'object_count': len(self.object_ids),
        }
        if torch.cuda.is_available():
            memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
            memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9
        return memory_stats