"""
|
|
SAM2 streaming processor for frame-by-frame video segmentation
|
|
|
|
NOTE: This is a template implementation. The actual SAM2 integration would need to:
|
|
1. Handle the fact that SAM2VideoPredictor loads the entire video internally
|
|
2. Potentially modify SAM2 to support frame-by-frame input
|
|
3. Or use a custom video loader that provides frames on demand
|
|
|
|
For a true streaming implementation, you may need to:
|
|
- Extend SAM2VideoPredictor to accept a frame generator instead of video path
|
|
- Implement a custom video loader that doesn't load all frames at once
|
|
- Use the memory offloading features more aggressively
|
|
"""
|
|
|
|
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc

# Import SAM2 components - these will be available after SAM2 installation
try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.utils.misc import load_video_frames
except ImportError:
    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")
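

# --- Illustrative sketch (added for clarity; not used by the class below) ---
# The module docstring calls for a custom video loader that provides frames on
# demand instead of loading the whole clip. A minimal generator along those
# lines might look like this; it assumes OpenCV-readable input and yields RGB
# frames. It is only a sketch of the idea, not SAM2's own loader.
def iter_video_frames(video_path: str) -> Generator[np.ndarray, None, None]:
    """Yield RGB frames one at a time without holding the whole video in memory."""
    import cv2  # local import, mirroring the style used elsewhere in this module

    cap = cv2.VideoCapture(video_path)
    try:
        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; convert to the RGB layout used downstream
            yield cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()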


class SAM2StreamingProcessor:
    """Streaming integration with SAM2 video predictor"""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))

        # Processing parameters (set before _init_predictor)
        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
        self.fp16 = config.get('matting', {}).get('fp16', True)
        self.correction_interval = config.get('matting', {}).get('correction_interval', 300)

        # SAM2 model configuration
        model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
        checkpoint = config.get('matting', {}).get('sam2_checkpoint',
                                                   'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')

        # Build predictor
        self.predictor = None
        self._init_predictor(model_cfg, checkpoint)

        # State management
        self.states = {}  # eye -> inference state
        self.object_ids = []
        self.frame_count = 0

        print("🎯 SAM2 streaming processor initialized:")
        print(f"   Model: {model_cfg}")
        print(f"   Device: {self.device}")
        print(f"   Memory offload: {self.memory_offload}")
        print(f"   FP16: {self.fp16}")
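
    # Example of the configuration structure consumed by __init__ above (keys
    # taken directly from the config.get(...) calls; the values shown are just
    # the defaults used when a key is missing):
    #
    #   {
    #       'hardware': {'device': 'cuda'},
    #       'matting': {
    #           'memory_offload': True,
    #           'fp16': True,
    #           'correction_interval': 300,
    #           'sam2_model_cfg': 'sam2.1_hiera_l',
    #           'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
    #       },
    #   }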

    def _init_predictor(self, model_cfg: str, checkpoint: str) -> None:
        """Initialize SAM2 video predictor"""
        try:
            # Map config string to actual config path
            config_mapping = {
                'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
                'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
                'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
                'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
            }

            actual_config = config_mapping.get(model_cfg, model_cfg)

            # Build predictor with VOS optimizations
            self.predictor = build_sam2_video_predictor(
                actual_config,
                checkpoint,
                device=self.device,
                vos_optimized=True  # Enable full model compilation for speed
            )

            # Set to eval mode
            self.predictor.eval()

            # Note: FP16 conversion can cause type mismatches with compiled models;
            # let SAM2 handle precision internally via build_sam2_video_predictor options
            if self.fp16 and self.device.type == 'cuda':
                print("   FP16 enabled via SAM2 internal settings")

        except Exception as e:
            raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")

    def init_state(self,
                   video_info: Dict[str, Any],
                   eye: str = 'full') -> Dict[str, Any]:
        """
        Initialize inference state for streaming (NO VIDEO LOADING)

        Args:
            video_info: Video metadata dict with width, height, frame_count
            eye: Eye identifier ('left', 'right', or 'full')

        Returns:
            Inference state dictionary
        """
        print(f"  Initializing streaming state for {eye} eye...")

        # Monitor memory before initialization
        if torch.cuda.is_available():
            before_mem = torch.cuda.memory_allocated() / 1e9
            print(f"  📊 GPU memory before init: {before_mem:.1f}GB")

        # Create streaming state WITHOUT loading video frames
        state = self._create_streaming_state(video_info)

        # Monitor memory after initialization
        if torch.cuda.is_available():
            after_mem = torch.cuda.memory_allocated() / 1e9
            print(f"  📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem - before_mem:.1f}GB)")

        self.states[eye] = state
        print(f"  ✅ Streaming state initialized for {eye} eye")

        return state

    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
        """Create a streaming state for frame-by-frame processing"""
        # Create a streaming-compatible inference state. This mirrors SAM2's
        # internal state structure but without video frames.

        # Use SAM2's init_state, but with a dummy 1-frame video so that no real
        # frames are loaded; frame access is overridden later.
        try:
            # Create a minimal dummy video file temporarily
            import tempfile
            import cv2
            import os

            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                dummy_path = tmp_file.name

            # Write a single-frame video
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(dummy_path, fourcc, 1.0, (video_info['width'], video_info['height']))
            dummy_frame = np.zeros((video_info['height'], video_info['width'], 3), dtype=np.uint8)
            out.write(dummy_frame)
            out.release()

            # Initialize with the dummy video (SAM2 only sees the single frame)
            try:
                with torch.inference_mode():
                    inference_state = self.predictor.init_state(
                        video_path=dummy_path,
                        offload_video_to_cpu=self.memory_offload,
                        offload_state_to_cpu=self.memory_offload,
                        async_loading_frames=True
                    )
            finally:
                # Clean up the dummy file even if init_state fails
                os.unlink(dummy_path)

            # Update state with actual video info
            inference_state['num_frames'] = video_info.get('total_frames', video_info.get('frame_count', 0))
            inference_state['video_height'] = video_info['height']
            inference_state['video_width'] = video_info['width']

        except Exception as e:
            print(f"  Warning: Failed to create proper SAM2 state ({e}), using minimal state")
            # Fall back to a minimal state
            inference_state = {
                'point_inputs_per_obj': {},
                'mask_inputs_per_obj': {},
                'cached_features': {},
                'constants': {},
                'obj_id_to_idx': {},
                'obj_idx_to_id': {},
                'obj_ids': [],
                'click_inputs_per_obj': {},
                'temp_output_dict_per_obj': {},
                'consolidated_frame_inds': {},
                'tracking_has_started': False,
                'num_frames': video_info.get('total_frames', video_info.get('frame_count', 0)),
                'video_height': video_info['height'],
                'video_width': video_info['width'],
                'device': self.device,
                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                'offload_video_to_cpu': self.memory_offload,
                'offload_state_to_cpu': self.memory_offload,
            }

        return inference_state

    def add_detections(self,
                       state: Dict[str, Any],
                       frame: np.ndarray,
                       detections: List[Dict[str, Any]],
                       frame_idx: int = 0) -> List[int]:
        """
        Add detection boxes as prompts to SAM2, along with the frame data

        Args:
            state: Inference state
            frame: Frame image (RGB numpy array)
            detections: List of detections with a 'box' key
            frame_idx: Frame index at which to add the prompts

        Returns:
            List of object IDs
        """
        if not detections:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []

        # Convert frame to tensor
        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension

        # Convert detections to SAM2 format: one [x1, y1, x2, y2] box per object
        boxes = [det['box'] for det in detections]

        # Manually process the frame and add prompts (streaming approach)
        object_ids: List[int] = []
        with torch.inference_mode():
            # Process the frame through SAM2's image encoder
            backbone_out = self.predictor.forward_image(frame_tensor)

            # Store features in the state for this frame
            state['cached_features'][frame_idx] = backbone_out

            # Add each box as a prompt for its own object. SAM2's
            # add_new_points_or_box() handles a single object per call, so we
            # loop over the boxes and assign explicit object IDs.
            for obj_id, box in enumerate(boxes):
                box_tensor = torch.tensor(box, dtype=torch.float32, device=self.device)
                _, object_ids, _ = self.predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    box=box_tensor
                )
                object_ids = list(object_ids)

        self.object_ids = object_ids
        print(f"  Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")

        return object_ids

    def propagate_single_frame(self,
                               state: Dict[str, Any],
                               frame: np.ndarray,
                               frame_idx: int) -> np.ndarray:
        """
        Propagate masks for a single frame (true streaming)

        Args:
            state: Inference state
            frame: Frame image (RGB numpy array)
            frame_idx: Frame index

        Returns:
            Combined mask for all objects
        """
        # Convert frame to tensor
        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension

        with torch.inference_mode():
            # Process the frame through SAM2's image encoder
            backbone_out = self.predictor.forward_image(frame_tensor)

            # Store features in the state for this frame
            state['cached_features'][frame_idx] = backbone_out

            # Use SAM2's single-frame inference for propagation
            try:
                # Run single-frame inference for all tracked objects
                output_dict = {}
                self.predictor._run_single_frame_inference(
                    inference_state=state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=1,
                    is_init_cond_frame=False,  # Not an initialization frame
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=False,
                    run_mem_encoder=True
                )

                # Extract masks from the output
                if output_dict and 'pred_masks' in output_dict:
                    pred_masks = output_dict['pred_masks']
                    # Combine all object masks
                    if pred_masks.shape[0] > 0:
                        combined_mask = pred_masks.max(dim=0)[0]
                        combined_mask_np = (combined_mask > 0.0).cpu().numpy().astype(np.uint8) * 255
                    else:
                        combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
                else:
                    combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

            except Exception as e:
                print(f"  Warning: Single frame inference failed: {e}")
                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        # Clean up old features to prevent memory accumulation
        self._cleanup_old_features(state, frame_idx, keep_frames=10)

        return combined_mask_np

    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
        """Remove old cached features to prevent memory accumulation"""
        features_to_remove = []
        for frame_idx in state.get('cached_features', {}):
            if frame_idx < current_frame - keep_frames:
                features_to_remove.append(frame_idx)

        for frame_idx in features_to_remove:
            del state['cached_features'][frame_idx]

        # Periodic GPU memory cleanup
        if current_frame % 50 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def propagate_frame_pair(self,
                             left_state: Dict[str, Any],
                             right_state: Dict[str, Any],
                             left_frame: np.ndarray,
                             right_frame: np.ndarray,
                             frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Propagate masks for a stereo frame pair

        Args:
            left_state: Left eye inference state
            right_state: Right eye inference state
            left_frame: Left eye frame
            right_frame: Right eye frame
            frame_idx: Current frame index

        Returns:
            Tuple of (left_masks, right_masks)
        """
        # This is a simplified placeholder. A real implementation would have to
        # handle the video frames already loaded in the state; in practice,
        # SAM2's propagate_in_video() takes care of frame loading, and the
        # actual integration depends on how frames are provided to
        # SAM2VideoPredictor.

        left_masks = np.zeros((left_frame.shape[0], left_frame.shape[1]), dtype=np.uint8)
        right_masks = np.zeros((right_frame.shape[0], right_frame.shape[1]), dtype=np.uint8)

        # In an actual implementation, you would:
        # 1. Consume the predictor.propagate_in_video() generator
        # 2. Extract the masks for the current frame_idx
        # 3. Combine multiple object masks if needed
        # (see the illustrative sketch right after this method)

        return left_masks, right_masks
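
    # Illustrative sketch (added for clarity): the steps listed above reference
    # SAM2's propagate_in_video() generator. A typical, non-streaming way to
    # consume it and merge per-object masks is shown below. It assumes the
    # inference state was created by predictor.init_state() on a real video, so
    # it documents the intended mask extraction rather than the streaming path
    # this template is aiming for.
    def _example_consume_propagation(self, state: Dict[str, Any]) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Yield (frame_idx, combined uint8 mask) pairs from SAM2's own propagation loop."""
        with torch.inference_mode():
            for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(state):
                # mask_logits holds one logit map per tracked object; threshold
                # at 0 and merge all objects into a single binary mask
                combined = (mask_logits > 0.0).any(dim=0).squeeze(0)
                yield frame_idx, combined.cpu().numpy().astype(np.uint8) * 255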

    def _propagate_single_frame(self,
                                state: Dict[str, Any],
                                frame: np.ndarray,
                                frame_idx: int) -> np.ndarray:
        """
        Propagate masks for a single frame (simplified placeholder)

        Args:
            state: Inference state
            frame: Input frame
            frame_idx: Frame index

        Returns:
            Combined mask for all objects
        """
        # This is a simplified version - in practice we'd use the actual SAM2
        # propagation API, which handles memory updates internally.

        # Get the current masks from propagation.
        # Note: this is pseudo-code, as the actual API may differ.
        masks = []

        # For each tracked object
        for obj_idx in range(len(self.object_ids)):
            # Get the mask for this object (in reality, SAM2 handles this internally)
            obj_mask = self._get_object_mask(state, obj_idx, frame_idx)
            masks.append(obj_mask)

        # Combine all object masks
        if masks:
            combined_mask = np.max(masks, axis=0)
        else:
            combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        return combined_mask

    def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray:
        """Get the mask for a specific object (placeholder - the actual implementation uses the SAM2 API)"""
        # In practice, this would extract the mask from SAM2's internal state.
        # For now, return an empty placeholder mask.
        h, w = state.get('video_height', 1080), state.get('video_width', 1920)
        return np.zeros((h, w), dtype=np.uint8)

    def apply_continuous_correction(self,
                                    state: Dict[str, Any],
                                    frame: np.ndarray,
                                    frame_idx: int,
                                    detector: Any) -> None:
        """
        Apply continuous correction by re-detecting persons and refining the masks

        Args:
            state: Inference state
            frame: Current frame
            frame_idx: Frame index
            detector: Person detector instance
        """
        if frame_idx % self.correction_interval != 0:
            return

        print(f"  🔄 Applying continuous correction at frame {frame_idx}")

        # Detect persons in the current frame
        new_detections = detector.detect_persons(frame)

        if new_detections:
            # Add new box prompts to refine tracking. add_new_points_or_box()
            # takes one object per call, so re-prompt each tracked object with a
            # fresh detection box (this assumes detections arrive in the same
            # order as the tracked objects).
            with torch.inference_mode():
                boxes = [det['box'] for det in new_detections]

                for obj_id, box in zip(self.object_ids, boxes):
                    box_tensor = torch.tensor(box, dtype=torch.float32, device=self.device)
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=frame_idx,
                        obj_id=obj_id,  # Refine an existing object
                        box=box_tensor
                    )

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = 'greenscreen',
                            background_color: List[int] = [0, 255, 0]) -> np.ndarray:
        """
        Apply mask to frame with the specified output format

        Args:
            frame: Input frame (BGR)
            mask: Binary mask
            output_format: 'alpha' or 'greenscreen'
            background_color: Background color for greenscreen

        Returns:
            Processed frame
        """
        if output_format == 'alpha':
            # Add alpha channel
            if mask.dtype != np.uint8:
                mask = (mask * 255).astype(np.uint8)

            # Create BGRA image
            bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            bgra[:, :, :3] = frame
            bgra[:, :, 3] = mask

            return bgra

        else:  # greenscreen
            # Create green background
            background = np.full_like(frame, background_color, dtype=np.uint8)

            # Expand mask to 3 channels
            if mask.ndim == 2:
                mask_3ch = np.expand_dims(mask, axis=2)
                mask_3ch = np.repeat(mask_3ch, 3, axis=2)
            else:
                mask_3ch = mask

            # Normalize mask to 0-1
            if mask_3ch.dtype == np.uint8:
                mask_float = mask_3ch.astype(np.float32) / 255.0
            else:
                mask_float = mask_3ch.astype(np.float32)

            # Composite: foreground where the mask is 1, background where it is 0
            result = (frame * mask_float + background * (1 - mask_float)).astype(np.uint8)

            return result

    def cleanup(self) -> None:
        """Clean up resources"""
        # Clear states
        self.states.clear()

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # Garbage collection
        gc.collect()

        print("🧹 SAM2 streaming processor cleaned up")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage"""
        memory_stats = {
            'states_count': len(self.states),
            'object_count': len(self.object_ids),
        }

        if torch.cuda.is_available():
            memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
            memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9

        return memory_stats
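

# --- Illustrative usage sketch (added for clarity; not part of the original template) ---
# A minimal frame-by-frame driving loop, assuming a detector object exposing the
# detect_persons(frame) -> List[Dict] interface referenced in
# apply_continuous_correction(), and an OpenCV-readable input video. The config
# keys mirror those read in __init__(). RGB/BGR handling and output writing are
# left to the caller; this only sketches how the class is meant to be driven.
def example_streaming_loop(video_path: str, detector: Any, config: Dict[str, Any]) -> None:
    import cv2  # local import, mirroring the style used elsewhere in this module

    cap = cv2.VideoCapture(video_path)
    video_info = {
        'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
    }
    cap.release()

    processor = SAM2StreamingProcessor(config)
    state = processor.init_state(video_info, eye='full')

    try:
        for frame_idx, frame_rgb in enumerate(iter_video_frames(video_path)):
            if frame_idx == 0:
                # Prompt SAM2 with person detections on the first frame
                processor.add_detections(state, frame_rgb, detector.detect_persons(frame_rgb), frame_idx)
                continue
            mask = processor.propagate_single_frame(state, frame_rgb, frame_idx)
            composited = processor.apply_mask_to_frame(frame_rgb, mask, output_format='greenscreen')
            # ... write `composited` to the output video here ...
            processor.apply_continuous_correction(state, frame_rgb, frame_idx, detector)
    finally:
        processor.cleanup()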