test2/vr180_streaming/sam2_streaming.py
"""
SAM2 streaming processor for frame-by-frame video segmentation
NOTE: This is a template implementation. The actual SAM2 integration would need to:
1. Handle the fact that SAM2VideoPredictor loads the entire video internally
2. Potentially modify SAM2 to support frame-by-frame input
3. Or use a custom video loader that provides frames on demand
For a true streaming implementation, you may need to:
- Extend SAM2VideoPredictor to accept a frame generator instead of video path
- Implement a custom video loader that doesn't load all frames at once
- Use the memory offloading features more aggressively
"""
import torch
import numpy as np
import cv2
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc
# Import SAM2 components - these will be available after SAM2 installation
try:
from sam2.build_sam import build_sam2_video_predictor
from sam2.utils.misc import load_video_frames
except ImportError:
warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
class SAM2StreamingProcessor:
"""Streaming integration with SAM2 video predictor"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
# Processing parameters (set before _init_predictor)
self.memory_offload = config.get('matting', {}).get('memory_offload', True)
self.fp16 = config.get('matting', {}).get('fp16', True)
self.correction_interval = config.get('matting', {}).get('correction_interval', 300)
# SAM2 model configuration
model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
checkpoint = config.get('matting', {}).get('sam2_checkpoint',
'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
# Build predictor
self.predictor = None
self._init_predictor(model_cfg, checkpoint)
# State management
self.states = {} # eye -> inference state
self.object_ids = []
self.frame_count = 0
print(f"🎯 SAM2 streaming processor initialized:")
print(f" Model: {model_cfg}")
print(f" Device: {self.device}")
print(f" Memory offload: {self.memory_offload}")
print(f" FP16: {self.fp16}")
def _init_predictor(self, model_cfg: str, checkpoint: str) -> None:
"""Initialize SAM2 video predictor"""
try:
# Map config string to actual config path
config_mapping = {
'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
}
actual_config = config_mapping.get(model_cfg, model_cfg)
# Build predictor with VOS optimizations
self.predictor = build_sam2_video_predictor(
actual_config,
checkpoint,
device=self.device,
vos_optimized=True # Enable full model compilation for speed
)
# Set to eval mode and ensure all model components are on GPU
self.predictor.eval()
# Force all predictor components to GPU
self.predictor = self.predictor.to(self.device)
# Force move all internal components that might be on CPU
if hasattr(self.predictor, 'image_encoder'):
self.predictor.image_encoder = self.predictor.image_encoder.to(self.device)
if hasattr(self.predictor, 'memory_attention'):
self.predictor.memory_attention = self.predictor.memory_attention.to(self.device)
if hasattr(self.predictor, 'memory_encoder'):
self.predictor.memory_encoder = self.predictor.memory_encoder.to(self.device)
if hasattr(self.predictor, 'sam_mask_decoder'):
self.predictor.sam_mask_decoder = self.predictor.sam_mask_decoder.to(self.device)
if hasattr(self.predictor, 'sam_prompt_encoder'):
self.predictor.sam_prompt_encoder = self.predictor.sam_prompt_encoder.to(self.device)
            # Note: explicit FP16 conversion can cause dtype mismatches with
            # compiled models, so precision is left to SAM2's internal handling
            if self.fp16 and self.device.type == 'cuda':
                print(" FP16 requested; precision handled internally by SAM2")
print(f" All SAM2 components moved to {self.device}")
except Exception as e:
raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
def init_state(self,
video_info: Dict[str, Any],
eye: str = 'full') -> Dict[str, Any]:
"""
Initialize inference state for streaming (NO VIDEO LOADING)
Args:
video_info: Video metadata dict with width, height, frame_count
eye: Eye identifier ('left', 'right', or 'full')
Returns:
Inference state dictionary
"""
print(f" Initializing streaming state for {eye} eye...")
# Monitor memory before initialization
if torch.cuda.is_available():
before_mem = torch.cuda.memory_allocated() / 1e9
print(f" 📊 GPU memory before init: {before_mem:.1f}GB")
# Create streaming state WITHOUT loading video frames
state = self._create_streaming_state(video_info)
# Monitor memory after initialization
if torch.cuda.is_available():
after_mem = torch.cuda.memory_allocated() / 1e9
print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
self.states[eye] = state
print(f" ✅ Streaming state initialized for {eye} eye")
return state
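
    # Usage sketch (illustrative values): for VR180 stereo input, one streaming
    # state is kept per eye.
    #
    #   video_info = {'width': 3840, 'height': 2160, 'frame_count': 9000}
    #   left_state = processor.init_state(video_info, eye='left')
    #   right_state = processor.init_state(video_info, eye='right')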
def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
"""Create streaming state for frame-by-frame processing"""
        # Create a streaming-compatible inference state without loading video
        # frames; this mirrors SAM2's internal state structure and avoids the
        # dummy-video workaround.
        with torch.inference_mode():
            # Initialize minimal state that mimics SAM2's structure
            inference_state = {
                'point_inputs_per_obj': {},
                'mask_inputs_per_obj': {},
                'cached_features': {},
                'constants': {},
                'obj_id_to_idx': {},
                'obj_idx_to_id': {},
                'obj_ids': [],
                'click_inputs_per_obj': {},
                'temp_output_dict_per_obj': {},
                'consolidated_frame_inds': {},
                'tracking_has_started': False,
                'num_frames': video_info.get('total_frames', video_info.get('frame_count', 0)),
                'video_height': video_info['height'],
                'video_width': video_info['width'],
                'device': self.device,
                'storage_device': self.device,  # Keep everything on GPU
                'offload_video_to_cpu': False,
                'offload_state_to_cpu': False,
                # Required SAM2 internal structures
                'output_dict_per_obj': {},
                'output_dict': {},
                'frames': None,  # We provide frames manually
                'images': None,  # We provide images manually
                # Additional SAM2 tracking fields
                'frames_tracked_per_obj': {},
                'memory_bank': {},
                'num_obj_tokens': 0,
                'max_obj_ptr_num': 16,  # SAM2 default
                'multimask_output_in_sam': False,
                'use_multimask_token_for_obj_ptr': True,
                'max_inference_state_frames': -1,  # No limit for streaming
                'image_feature_cache': {},
            }
# Initialize some constants that SAM2 expects
inference_state['constants'] = {
'image_size': max(video_info['height'], video_info['width']),
'backbone_stride': 16, # Standard SAM2 backbone stride
'sam_mask_decoder_extra_args': {},
'sam_prompt_embed_dim': 256,
'sam_image_embedding_size': video_info['height'] // 16, # Assuming 16x downsampling
}
print(f" Created streaming-compatible state")
return inference_state
def _move_state_to_device(self, state: Dict[str, Any], device: torch.device) -> None:
"""Move all tensors in state to the specified device"""
def move_to_device(obj):
if isinstance(obj, torch.Tensor):
return obj.to(device)
elif isinstance(obj, dict):
return {k: move_to_device(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [move_to_device(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(move_to_device(item) for item in obj)
else:
return obj
# Move all state components to device
for key, value in state.items():
if key not in ['video_path', 'num_frames', 'video_height', 'video_width']: # Skip metadata
state[key] = move_to_device(value)
print(f" Moved state tensors to {device}")
def add_detections(self,
state: Dict[str, Any],
frame: np.ndarray,
detections: List[Dict[str, Any]],
frame_idx: int = 0) -> List[int]:
"""
Add detection boxes as prompts to SAM2 with frame data
Args:
state: Inference state
            frame: Frame image (BGR numpy array, OpenCV convention; converted to RGB internally)
detections: List of detections with 'box' key
frame_idx: Frame index to add prompts
Returns:
List of object IDs
"""
if not detections:
warnings.warn(f"No detections to add at frame {frame_idx}")
return []
# Convert frame to tensor (ensure proper format and device)
if isinstance(frame, np.ndarray):
# Convert BGR to RGB if needed (OpenCV uses BGR)
if frame.shape[-1] == 3:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_tensor = torch.from_numpy(frame).float().to(self.device)
else:
frame_tensor = frame.float().to(self.device)
if frame_tensor.ndim == 3:
frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension
# Normalize to [0, 1] range if needed
if frame_tensor.max() > 1.0:
frame_tensor = frame_tensor / 255.0
# Convert detections to SAM2 format
boxes = []
for det in detections:
box = det['box'] # [x1, y1, x2, y2]
boxes.append(box)
boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
# Manually process frame and add prompts (streaming approach)
with torch.inference_mode():
# Process frame through SAM2's image encoder
backbone_out = self.predictor.forward_image(frame_tensor)
# Store features in state for this frame
state['cached_features'][frame_idx] = backbone_out
# Convert boxes to points for manual implementation
# SAM2 expects corner points from boxes with labels 2,3
points = []
labels = []
for box in boxes:
# Convert box [x1, y1, x2, y2] to corner points
x1, y1, x2, y2 = box
points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners
labels.extend([2, 3]) # SAM2 standard labels for box corners
points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
            try:
                # Use add_new_points instead of add_new_points_or_box to avoid device issues.
                # SAM2 does not auto-assign object ids, so prompt one object per
                # detection with an explicit obj_id and that detection's two corner points.
                object_ids = []
                for obj_id in range(len(detections)):
                    _, returned_ids, masks = self.predictor.add_new_points(
                        inference_state=state,
                        frame_idx=frame_idx,
                        obj_id=obj_id,
                        points=points_tensor[obj_id * 2:(obj_id + 1) * 2],
                        labels=labels_tensor[obj_id * 2:(obj_id + 1) * 2],
                        clear_old_points=True,
                        normalize_coords=True
                    )
                    object_ids = list(returned_ids)
                # Update state with object tracking info
                state['obj_ids'] = object_ids
                state['tracking_has_started'] = True
except Exception as e:
print(f" Error in add_new_points: {e}")
print(f" Points tensor device: {points_tensor.device}")
print(f" Labels tensor device: {labels_tensor.device}")
print(f" Frame tensor device: {frame_tensor.device}")
# Fallback: manually initialize object tracking
print(f" Using fallback manual object initialization")
object_ids = [i for i in range(len(detections))]
state['obj_ids'] = object_ids
state['tracking_has_started'] = True
# Store detection info for later use
                for i in range(len(detections)):
state['point_inputs_per_obj'][i] = {
frame_idx: {
'points': points_tensor[i*2:(i+1)*2],
'labels': labels_tensor[i*2:(i+1)*2]
}
}
self.object_ids = object_ids
print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
return object_ids
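
    # Detection format sketch: only the 'box' key is consumed above, as pixel
    # coordinates [x1, y1, x2, y2] in the prompt frame (values illustrative).
    #
    #   detections = [{'box': [812.0, 340.0, 1260.0, 1580.0]}]
    #   obj_ids = processor.add_detections(state, first_frame, detections, frame_idx=0)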
def propagate_single_frame(self,
state: Dict[str, Any],
frame: np.ndarray,
frame_idx: int) -> np.ndarray:
"""
Propagate masks for a single frame (true streaming)
Args:
state: Inference state
            frame: Frame image (BGR numpy array, OpenCV convention; converted to RGB internally)
frame_idx: Frame index
Returns:
Combined mask for all objects
"""
# Convert frame to tensor (ensure proper format and device)
if isinstance(frame, np.ndarray):
# Convert BGR to RGB if needed (OpenCV uses BGR)
if frame.shape[-1] == 3:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_tensor = torch.from_numpy(frame).float().to(self.device)
else:
frame_tensor = frame.float().to(self.device)
if frame_tensor.ndim == 3:
frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension
# Normalize to [0, 1] range if needed
if frame_tensor.max() > 1.0:
frame_tensor = frame_tensor / 255.0
with torch.inference_mode():
# Process frame through SAM2's image encoder
backbone_out = self.predictor.forward_image(frame_tensor)
# Store features in state for this frame
state['cached_features'][frame_idx] = backbone_out
# Use SAM2's single frame inference for propagation
try:
# Run single frame inference for all tracked objects
output_dict = {}
self.predictor._run_single_frame_inference(
inference_state=state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=1,
is_init_cond_frame=False, # Not initialization frame
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True
)
# Extract masks from output
if output_dict and 'pred_masks' in output_dict:
pred_masks = output_dict['pred_masks']
# Combine all object masks
                    if pred_masks.shape[0] > 0:
                        # Take the per-pixel maximum across objects, then drop any
                        # singleton channel dimension so the result is (H, W)
                        combined_mask = pred_masks.max(dim=0)[0]
                        combined_mask_np = (combined_mask > 0.0).squeeze().cpu().numpy().astype(np.uint8) * 255
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
else:
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
except Exception as e:
print(f" Warning: Single frame inference failed: {e}")
combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
# Cleanup old features to prevent memory accumulation
self._cleanup_old_features(state, frame_idx, keep_frames=10)
return combined_mask_np
def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
"""Remove old cached features to prevent memory accumulation"""
features_to_remove = []
for frame_idx in state.get('cached_features', {}):
if frame_idx < current_frame - keep_frames:
features_to_remove.append(frame_idx)
for frame_idx in features_to_remove:
del state['cached_features'][frame_idx]
# Periodic GPU memory cleanup
if current_frame % 50 == 0:
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def propagate_frame_pair(self,
left_state: Dict[str, Any],
right_state: Dict[str, Any],
left_frame: np.ndarray,
right_frame: np.ndarray,
frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Propagate masks for a stereo frame pair
Args:
left_state: Left eye inference state
right_state: Right eye inference state
left_frame: Left eye frame
right_frame: Right eye frame
frame_idx: Current frame index
Returns:
Tuple of (left_masks, right_masks)
"""
        # Each eye is propagated independently through the streaming path.
        # (If frames were instead preloaded into the state, SAM2's
        # propagate_in_video generator could be used, extracting and combining
        # the per-object masks for the current frame_idx.)
        left_masks = self.propagate_single_frame(left_state, left_frame, frame_idx)
        right_masks = self.propagate_single_frame(right_state, right_frame, frame_idx)
        return left_masks, right_masks
def _propagate_single_frame(self,
state: Dict[str, Any],
frame: np.ndarray,
frame_idx: int) -> np.ndarray:
"""
Propagate masks for a single frame
Args:
state: Inference state
frame: Input frame
frame_idx: Frame index
Returns:
Combined mask for all objects
"""
# This is a simplified version - in practice we'd use the actual
# SAM2 propagation API which handles memory updates internally
# Get current masks from propagation
# Note: This is pseudo-code as the actual API may differ
masks = []
# For each tracked object
for obj_idx in range(len(self.object_ids)):
# Get mask for this object
# In reality, SAM2 handles this internally
obj_mask = self._get_object_mask(state, obj_idx, frame_idx)
masks.append(obj_mask)
# Combine all object masks
if masks:
combined_mask = np.max(masks, axis=0)
else:
combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
return combined_mask
def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray:
"""
Get mask for specific object (placeholder - actual implementation uses SAM2 API)
"""
# In practice, this would extract the mask from SAM2's internal state
# For now, return a placeholder
h, w = state.get('video_height', 1080), state.get('video_width', 1920)
return np.zeros((h, w), dtype=np.uint8)
def apply_continuous_correction(self,
state: Dict[str, Any],
frame: np.ndarray,
frame_idx: int,
detector: Any) -> None:
"""
Apply continuous correction by re-detecting and refining masks
Args:
state: Inference state
frame: Current frame
frame_idx: Frame index
detector: Person detector instance
"""
if frame_idx % self.correction_interval != 0:
return
print(f" 🔄 Applying continuous correction at frame {frame_idx}")
# Detect persons in current frame
new_detections = detector.detect_persons(frame)
        if new_detections:
            # Add refinement prompts, one box per tracked object. Detections are
            # paired with existing object ids by order, which assumes a stable
            # detection ordering between correction passes.
            with torch.inference_mode():
                obj_ids = self.object_ids if self.object_ids else list(range(len(new_detections)))
                for obj_id, det in zip(obj_ids, new_detections):
                    box_tensor = torch.tensor(det['box'], dtype=torch.float32, device=self.device)
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=frame_idx,
                        obj_id=obj_id,
                        box=box_tensor
                    )
def apply_mask_to_frame(self,
frame: np.ndarray,
mask: np.ndarray,
output_format: str = 'greenscreen',
background_color: List[int] = [0, 255, 0]) -> np.ndarray:
"""
Apply mask to frame with specified output format
Args:
frame: Input frame (BGR)
mask: Binary mask
output_format: 'alpha' or 'greenscreen'
background_color: Background color for greenscreen
Returns:
Processed frame
"""
if output_format == 'alpha':
# Add alpha channel
if mask.dtype != np.uint8:
mask = (mask * 255).astype(np.uint8)
# Create BGRA image
bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
bgra[:, :, :3] = frame
bgra[:, :, 3] = mask
return bgra
else: # greenscreen
# Create green background
background = np.full_like(frame, background_color, dtype=np.uint8)
# Expand mask to 3 channels
if mask.ndim == 2:
mask_3ch = np.expand_dims(mask, axis=2)
mask_3ch = np.repeat(mask_3ch, 3, axis=2)
else:
mask_3ch = mask
# Normalize mask to 0-1
if mask_3ch.dtype == np.uint8:
mask_float = mask_3ch.astype(np.float32) / 255.0
else:
mask_float = mask_3ch.astype(np.float32)
# Composite
result = (frame * mask_float + background * (1 - mask_float)).astype(np.uint8)
return result
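
    # Note on downstream handling (assumption about the writer stage): 'alpha'
    # output produces a 4-channel BGRA frame, so the encoder/muxer must use a
    # pixel format that preserves alpha; 'greenscreen' output stays 3-channel BGR.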
def cleanup(self) -> None:
"""Clean up resources"""
# Clear states
self.states.clear()
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Garbage collection
gc.collect()
print("🧹 SAM2 streaming processor cleaned up")
def get_memory_usage(self) -> Dict[str, float]:
"""Get current memory usage"""
memory_stats = {
'states_count': len(self.states),
'object_count': len(self.object_ids),
}
if torch.cuda.is_available():
memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9
return memory_stats
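

# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative only). The single placeholder box and
# the absence of a real person detector are assumptions; in the actual VR180
# pipeline the first-frame detections would come from the detector module.
# ---------------------------------------------------------------------------
def _example_streaming_loop(video_path: str, config: Dict[str, Any]) -> None:
    """Run detection prompts on the first frame, then propagate masks frame by frame."""
    processor = SAM2StreamingProcessor(config)

    # Probe video metadata without loading frames
    cap = cv2.VideoCapture(video_path)
    video_info = {
        'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
    }
    cap.release()

    state = processor.init_state(video_info, eye='full')
    for frame_idx, frame in enumerate(iter_video_frames(video_path)):
        if frame_idx == 0:
            # Placeholder prompt; replace with real detector output
            detections = [{'box': [100.0, 100.0, 500.0, 900.0]}]
            processor.add_detections(state, frame, detections, frame_idx=0)
        mask = processor.propagate_single_frame(state, frame, frame_idx)
        matted = processor.apply_mask_to_frame(frame, mask, output_format='greenscreen')
        # ... hand `matted` to the encoder/muxer stage here ...
    processor.cleanup()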