# test2/vr180_streaming/sam2_streaming.py
"""
SAM2 streaming processor for frame-by-frame video segmentation
NOTE: This is a template implementation. The actual SAM2 integration would need to:
1. Handle the fact that SAM2VideoPredictor loads the entire video internally
2. Potentially modify SAM2 to support frame-by-frame input
3. Or use a custom video loader that provides frames on demand
For a true streaming implementation, you may need to:
- Extend SAM2VideoPredictor to accept a frame generator instead of a video path
- Implement a custom video loader that doesn't load all frames at once
  (a minimal sketch of such a loader appears after the imports below)
- Use the memory offloading features more aggressively
"""
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc
# Import SAM2 components - these will be available after SAM2 installation
try:
from sam2.build_sam import build_sam2_video_predictor
from sam2.utils.misc import load_video_frames
except ImportError:
warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
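

# ---------------------------------------------------------------------------
# Minimal sketch of the on-demand frame loader mentioned in the module
# docstring above. The stock SAM2VideoPredictor does not consume a generator
# like this (it loads frames itself from video_path); the function only
# illustrates the shape a streaming frame source could take. Assumes OpenCV
# (cv2) is available; it is not used elsewhere in this module.
# ---------------------------------------------------------------------------
def iter_video_frames(video_path: str) -> Generator[np.ndarray, None, None]:
    """Yield BGR frames one at a time without loading the whole video"""
    import cv2  # local import so the module still loads without OpenCV
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            yield frame
    finally:
        cap.release()
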
class SAM2StreamingProcessor:
"""Streaming integration with SAM2 video predictor"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))
# Processing parameters (set before _init_predictor)
self.memory_offload = config.get('matting', {}).get('memory_offload', True)
self.fp16 = config.get('matting', {}).get('fp16', True)
self.correction_interval = config.get('matting', {}).get('correction_interval', 300)
# SAM2 model configuration
model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
checkpoint = config.get('matting', {}).get('sam2_checkpoint',
'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')
# Build predictor
self.predictor = None
self._init_predictor(model_cfg, checkpoint)
# State management
self.states = {} # eye -> inference state
self.object_ids = []
self.frame_count = 0
print(f"🎯 SAM2 streaming processor initialized:")
print(f" Model: {model_cfg}")
print(f" Device: {self.device}")
print(f" Memory offload: {self.memory_offload}")
print(f" FP16: {self.fp16}")
def _init_predictor(self, model_cfg: str, checkpoint: str) -> None:
"""Initialize SAM2 video predictor"""
try:
# Map config string to actual config path
config_mapping = {
'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
}
actual_config = config_mapping.get(model_cfg, model_cfg)
# Build predictor with VOS optimizations
self.predictor = build_sam2_video_predictor(
actual_config,
checkpoint,
device=self.device,
vos_optimized=True # Enable full model compilation for speed
)
# Set to eval mode
self.predictor.eval()
# Enable FP16 if requested
if self.fp16 and self.device.type == 'cuda':
self.predictor = self.predictor.half()
except Exception as e:
raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
def init_state(self,
video_path: str,
eye: str = 'full') -> Dict[str, Any]:
"""
Initialize inference state for streaming
Args:
video_path: Path to video file
eye: Eye identifier ('left', 'right', or 'full')
Returns:
Inference state dictionary
"""
# Initialize state with memory offloading enabled
with torch.inference_mode():
state = self.predictor.init_state(
video_path=video_path,
offload_video_to_cpu=self.memory_offload,
offload_state_to_cpu=self.memory_offload,
            async_loading_frames=False  # load all frames synchronously up front (SAM2 reads the video itself)
)
self.states[eye] = state
print(f" Initialized state for {eye} eye")
return state
def add_detections(self,
state: Dict[str, Any],
detections: List[Dict[str, Any]],
frame_idx: int = 0) -> List[int]:
"""
Add detection boxes as prompts to SAM2
Args:
state: Inference state
detections: List of detections with 'box' key
frame_idx: Frame index to add prompts
Returns:
List of object IDs
"""
if not detections:
warnings.warn(f"No detections to add at frame {frame_idx}")
return []
        # SAM2's add_new_points_or_box takes a single box per call, each with
        # its own obj_id, so add the detections one object at a time
        object_ids: List[int] = []
        with torch.inference_mode():
            for obj_id, det in enumerate(detections, start=1):
                box = np.array(det['box'], dtype=np.float32)  # [x1, y1, x2, y2]
                _, out_obj_ids, _ = self.predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    box=box
                )
                object_ids = list(out_obj_ids)
        self.object_ids = object_ids
        print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
        return object_ids
def propagate_in_video_simple(self,
state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
"""
Simple propagation for single eye processing
Yields:
(frame_idx, object_ids, masks) tuples
"""
        with torch.inference_mode():
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(state):
                # propagate_in_video yields per-object mask logits; threshold
                # at 0 to obtain binary masks before handing them downstream
                if isinstance(mask_logits, torch.Tensor):
                    masks_np = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
                else:
                    masks_np = np.asarray(mask_logits)
                yield frame_idx, object_ids, masks_np
                # Periodic memory cleanup
                if frame_idx % 100 == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()
def propagate_frame_pair(self,
left_state: Dict[str, Any],
right_state: Dict[str, Any],
left_frame: np.ndarray,
right_frame: np.ndarray,
frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Propagate masks for a stereo frame pair
Args:
left_state: Left eye inference state
right_state: Right eye inference state
left_frame: Left eye frame
right_frame: Right eye frame
frame_idx: Current frame index
Returns:
Tuple of (left_masks, right_masks)
"""
# For actual implementation, we would need to handle the video frames
# being already loaded in the state. This is a simplified version.
# In practice, SAM2's propagate_in_video would handle frame loading.
# Get masks from the current propagation state
# This is pseudo-code as actual integration would depend on
# how frames are provided to SAM2VideoPredictor
left_masks = np.zeros((left_frame.shape[0], left_frame.shape[1]), dtype=np.uint8)
right_masks = np.zeros((right_frame.shape[0], right_frame.shape[1]), dtype=np.uint8)
        # In an actual implementation, you would:
        # 1. Drive predictor.propagate_in_video() as a generator for each eye
        # 2. Extract the masks for the current frame_idx
        # 3. Combine the per-object masks if needed
        # A minimal sketch of that flow is given in propagate_stereo_sketch below.
return left_masks, right_masks
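
    # -----------------------------------------------------------------------
    # Sketch only: a lock-step version of the flow described in
    # propagate_frame_pair above. It assumes left_state and right_state were
    # created with init_state() on per-eye video files and that prompts were
    # added to both; it is not wired into the rest of the pipeline.
    # -----------------------------------------------------------------------
    def propagate_stereo_sketch(self,
                                left_state: Dict[str, Any],
                                right_state: Dict[str, Any]
                                ) -> Generator[Tuple[int, np.ndarray, np.ndarray], None, None]:
        """Yield (frame_idx, left_mask, right_mask) for each stereo frame pair"""
        with torch.inference_mode():
            left_gen = self.predictor.propagate_in_video(left_state)
            right_gen = self.predictor.propagate_in_video(right_state)
            for (l_idx, _, l_logits), (r_idx, _, r_logits) in zip(left_gen, right_gen):
                if l_idx != r_idx:
                    warnings.warn(f"Eye streams out of sync at frames {l_idx}/{r_idx}")
                # Threshold the per-object mask logits at 0 and merge all
                # objects into a single uint8 mask per eye
                left_mask = (l_logits > 0.0).any(dim=0).squeeze(0).cpu().numpy().astype(np.uint8) * 255
                right_mask = (r_logits > 0.0).any(dim=0).squeeze(0).cpu().numpy().astype(np.uint8) * 255
                yield l_idx, left_mask, right_mask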
def _propagate_single_frame(self,
state: Dict[str, Any],
frame: np.ndarray,
frame_idx: int) -> np.ndarray:
"""
Propagate masks for a single frame
Args:
state: Inference state
frame: Input frame
frame_idx: Frame index
Returns:
Combined mask for all objects
"""
# This is a simplified version - in practice we'd use the actual
# SAM2 propagation API which handles memory updates internally
# Get current masks from propagation
# Note: This is pseudo-code as the actual API may differ
masks = []
# For each tracked object
for obj_idx in range(len(self.object_ids)):
# Get mask for this object
# In reality, SAM2 handles this internally
obj_mask = self._get_object_mask(state, obj_idx, frame_idx)
masks.append(obj_mask)
# Combine all object masks
if masks:
combined_mask = np.max(masks, axis=0)
else:
combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
return combined_mask
def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray:
"""
Get mask for specific object (placeholder - actual implementation uses SAM2 API)
"""
# In practice, this would extract the mask from SAM2's internal state
# For now, return a placeholder
h, w = state.get('video_height', 1080), state.get('video_width', 1920)
return np.zeros((h, w), dtype=np.uint8)
def apply_continuous_correction(self,
state: Dict[str, Any],
frame: np.ndarray,
frame_idx: int,
detector: Any) -> None:
"""
Apply continuous correction by re-detecting and refining masks
Args:
state: Inference state
frame: Current frame
frame_idx: Frame index
detector: Person detector instance
"""
if frame_idx % self.correction_interval != 0:
return
print(f" 🔄 Applying continuous correction at frame {frame_idx}")
# Detect persons in current frame
new_detections = detector.detect_persons(frame)
        if new_detections:
            # Re-prompt the tracked objects with the newly detected boxes.
            # add_new_points_or_box takes one box per obj_id, so pair the i-th
            # detection with the i-th tracked object (a simple heuristic; a
            # full pipeline would match boxes to objects, e.g. by IoU)
            with torch.inference_mode():
                for obj_id, det in zip(self.object_ids, new_detections):
                    box = np.array(det['box'], dtype=np.float32)
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=frame_idx,
                        obj_id=obj_id,
                        box=box
                    )
def apply_mask_to_frame(self,
frame: np.ndarray,
mask: np.ndarray,
output_format: str = 'greenscreen',
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
"""
Apply mask to frame with specified output format
Args:
frame: Input frame (BGR)
mask: Binary mask
output_format: 'alpha' or 'greenscreen'
background_color: Background color for greenscreen
Returns:
Processed frame
"""
if output_format == 'alpha':
# Add alpha channel
if mask.dtype != np.uint8:
mask = (mask * 255).astype(np.uint8)
# Create BGRA image
bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
bgra[:, :, :3] = frame
bgra[:, :, 3] = mask
return bgra
else: # greenscreen
# Create green background
background = np.full_like(frame, background_color, dtype=np.uint8)
# Expand mask to 3 channels
if mask.ndim == 2:
mask_3ch = np.expand_dims(mask, axis=2)
mask_3ch = np.repeat(mask_3ch, 3, axis=2)
else:
mask_3ch = mask
# Normalize mask to 0-1
if mask_3ch.dtype == np.uint8:
mask_float = mask_3ch.astype(np.float32) / 255.0
else:
mask_float = mask_3ch.astype(np.float32)
# Composite
result = (frame * mask_float + background * (1 - mask_float)).astype(np.uint8)
return result
def cleanup(self) -> None:
"""Clean up resources"""
# Clear states
self.states.clear()
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Garbage collection
gc.collect()
print("🧹 SAM2 streaming processor cleaned up")
def get_memory_usage(self) -> Dict[str, float]:
"""Get current memory usage"""
memory_stats = {
'states_count': len(self.states),
'object_count': len(self.object_ids),
}
if torch.cuda.is_available():
memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9
return memory_stats
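

# ---------------------------------------------------------------------------
# Illustrative usage sketch (single-eye flow). The config keys mirror the ones
# read in __init__; 'input.mp4' and the hard-coded detection box are
# placeholders, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {
        'hardware': {'device': 'cuda'},
        'matting': {
            'sam2_model_cfg': 'sam2.1_hiera_l',
            'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
            'memory_offload': True,
            'fp16': True,
            'correction_interval': 300,
        },
    }
    processor = SAM2StreamingProcessor(config)
    state = processor.init_state('input.mp4', eye='full')
    # Detections would normally come from a person detector; a single box is
    # hard-coded here purely for illustration
    processor.add_detections(state, [{'box': [100, 100, 500, 800]}], frame_idx=0)
    for frame_idx, object_ids, masks in processor.propagate_in_video_simple(state):
        # masks: (num_objects, 1, H, W) binary masks; merge the objects and
        # composite with apply_mask_to_frame() as needed
        combined = masks.any(axis=0).squeeze(0).astype(np.uint8) * 255
    processor.cleanup()
    print(processor.get_memory_usage())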