"""
|
|
SAM2 streaming processor for frame-by-frame video segmentation
|
|
|
|
NOTE: This is a template implementation. The actual SAM2 integration would need to:
|
|
1. Handle the fact that SAM2VideoPredictor loads the entire video internally
|
|
2. Potentially modify SAM2 to support frame-by-frame input
|
|
3. Or use a custom video loader that provides frames on demand
|
|
|
|
For a true streaming implementation, you may need to:
|
|
- Extend SAM2VideoPredictor to accept a frame generator instead of video path
|
|
- Implement a custom video loader that doesn't load all frames at once
|
|
- Use the memory offloading features more aggressively
|
|
"""

import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc

# Import SAM2 components - these will be available after SAM2 installation
try:
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.utils.misc import load_video_frames
except ImportError:
    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")


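# --- Illustrative sketch (not part of the original template) -----------------
# The module docstring mentions extending SAM2 to accept a frame generator
# instead of a video path. The generator below is a minimal sketch of that
# idea, assuming OpenCV (cv2) is available; SAM2VideoPredictor would still
# need to be modified to consume frames this way, which this file does not do.
def iter_video_frames(video_path: str) -> Generator[np.ndarray, None, None]:
    """Yield BGR frames one at a time without loading the whole video into memory."""
    import cv2  # local import so this module does not hard-require OpenCV

    cap = cv2.VideoCapture(video_path)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            yield frame
    finally:
        cap.release()

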
class SAM2StreamingProcessor:
    """Streaming integration with SAM2 video predictor"""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))

        # SAM2 model configuration
        model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
        checkpoint = config.get('matting', {}).get(
            'sam2_checkpoint',
            'segment-anything-2/checkpoints/sam2.1_hiera_large.pt'
        )

        # Build predictor
        self.predictor = None
        self._init_predictor(model_cfg, checkpoint)

        # Processing parameters
        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
        self.fp16 = config.get('matting', {}).get('fp16', True)
        self.correction_interval = config.get('matting', {}).get('correction_interval', 300)

        # State management
        self.states = {}  # eye -> inference state
        self.object_ids = []
        self.frame_count = 0

        print("🎯 SAM2 streaming processor initialized:")
        print(f"   Model: {model_cfg}")
        print(f"   Device: {self.device}")
        print(f"   Memory offload: {self.memory_offload}")
        print(f"   FP16: {self.fp16}")

    def _init_predictor(self, model_cfg: str, checkpoint: str) -> None:
        """Initialize SAM2 video predictor"""
        try:
            # Map config string to actual config path
            config_mapping = {
                'sam2.1_hiera_t': 'configs/sam2.1/sam2.1_hiera_t.yaml',
                'sam2.1_hiera_s': 'configs/sam2.1/sam2.1_hiera_s.yaml',
                'sam2.1_hiera_b+': 'configs/sam2.1/sam2.1_hiera_b+.yaml',
                'sam2.1_hiera_l': 'configs/sam2.1/sam2.1_hiera_l.yaml',
            }

            actual_config = config_mapping.get(model_cfg, model_cfg)

            # Build predictor with VOS optimizations
            self.predictor = build_sam2_video_predictor(
                actual_config,
                checkpoint,
                device=self.device,
                vos_optimized=True  # Enable full model compilation for speed
            )

            # Set to eval mode
            self.predictor.eval()

            # Enable FP16 if requested.
            # NOTE: casting the whole model with .half() is aggressive; SAM2 is
            # usually run under torch.autocast instead, so this may need tuning.
            if self.fp16 and self.device.type == 'cuda':
                self.predictor = self.predictor.half()

        except Exception as e:
            raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}") from e

    def init_state(self,
                   video_path: str,
                   eye: str = 'full') -> Dict[str, Any]:
        """
        Initialize inference state for streaming

        Args:
            video_path: Path to video file
            eye: Eye identifier ('left', 'right', or 'full')

        Returns:
            Inference state dictionary
        """
        # Initialize state with memory offloading enabled
        with torch.inference_mode():
            state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                offload_state_to_cpu=self.memory_offload,
                async_loading_frames=False  # Load frames synchronously up front
            )

        self.states[eye] = state
        print(f"   Initialized state for {eye} eye")

        return state

    def add_detections(self,
                       state: Dict[str, Any],
                       detections: List[Dict[str, Any]],
                       frame_idx: int = 0) -> List[int]:
        """
        Add detection boxes as prompts to SAM2

        Args:
            state: Inference state
            detections: List of detections with 'box' key
            frame_idx: Frame index at which to add the prompts

        Returns:
            List of object IDs
        """
        if not detections:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []

        # Add each detection as a separate object prompt.
        # add_new_points_or_box takes one box per (frame_idx, obj_id) pair, so
        # assign a distinct obj_id per detection; SAM2 does not auto-increment.
        object_ids: List[int] = []
        with torch.inference_mode():
            for obj_id, det in enumerate(detections):
                box = torch.tensor(det['box'], dtype=torch.float32, device=self.device)  # [x1, y1, x2, y2]
                self.predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    box=box
                )
                object_ids.append(obj_id)

        self.object_ids = object_ids
        print(f"   Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")

        return object_ids

    def propagate_in_video_simple(self,
                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
        """
        Simple propagation for single eye processing

        Yields:
            (frame_idx, object_ids, masks) tuples, where masks are per-object
            mask logits; threshold at 0 (masks > 0) to obtain binary masks.
        """
        with torch.inference_mode():
            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
                # Convert masks to numpy
                if isinstance(masks, torch.Tensor):
                    masks_np = masks.cpu().numpy()
                else:
                    masks_np = masks

                yield frame_idx, object_ids, masks_np

                # Periodic memory cleanup
                if frame_idx % 100 == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()

    def propagate_frame_pair(self,
                             left_state: Dict[str, Any],
                             right_state: Dict[str, Any],
                             left_frame: np.ndarray,
                             right_frame: np.ndarray,
                             frame_idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Propagate masks for a stereo frame pair

        Args:
            left_state: Left eye inference state
            right_state: Right eye inference state
            left_frame: Left eye frame
            right_frame: Right eye frame
            frame_idx: Current frame index

        Returns:
            Tuple of (left_masks, right_masks)
        """
        # PLACEHOLDER: in the actual implementation the frames are already loaded
        # into each state, and SAM2's propagate_in_video would drive frame loading.
        # This simplified version returns empty masks of the correct shape.
        left_masks = np.zeros((left_frame.shape[0], left_frame.shape[1]), dtype=np.uint8)
        right_masks = np.zeros((right_frame.shape[0], right_frame.shape[1]), dtype=np.uint8)

        # In the actual implementation, you would:
        # 1. Use the predictor.propagate_in_video() generator for each state
        # 2. Extract the masks for the current frame_idx
        # 3. Combine multiple object masks if needed (see the sketch in
        #    _combine_object_masks below)
        return left_masks, right_masks
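
    # --- Illustrative sketch (not part of the original template) --------------
    # propagate_frame_pair() above describes combining per-object masks from
    # predictor.propagate_in_video(). The helper below is a minimal sketch of
    # that combination step, assuming `masks` holds SAM2 mask logits shaped
    # (num_objects, 1, H, W) or (num_objects, H, W), as yielded by
    # propagate_in_video_simple().
    def _combine_object_masks(self, masks: np.ndarray) -> np.ndarray:
        """Threshold per-object mask logits and merge them into one binary mask."""
        masks = np.asarray(masks)
        if masks.ndim == 4:  # (num_objects, 1, H, W) -> (num_objects, H, W)
            masks = masks[:, 0]

        # Logits > 0 mean "object present"; OR all objects together
        combined = np.any(masks > 0, axis=0)
        return combined.astype(np.uint8) * 255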

    def _propagate_single_frame(self,
                                state: Dict[str, Any],
                                frame: np.ndarray,
                                frame_idx: int) -> np.ndarray:
        """
        Propagate masks for a single frame

        Args:
            state: Inference state
            frame: Input frame
            frame_idx: Frame index

        Returns:
            Combined mask for all objects
        """
        # This is a simplified version - in practice we'd use the actual
        # SAM2 propagation API which handles memory updates internally

        # Get current masks from propagation
        # Note: This is pseudo-code as the actual API may differ
        masks = []

        # For each tracked object
        for obj_idx in range(len(self.object_ids)):
            # Get mask for this object
            # In reality, SAM2 handles this internally
            obj_mask = self._get_object_mask(state, obj_idx, frame_idx)
            masks.append(obj_mask)

        # Combine all object masks
        if masks:
            combined_mask = np.max(masks, axis=0)
        else:
            combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        return combined_mask

    def _get_object_mask(self, state: Dict[str, Any], obj_idx: int, frame_idx: int) -> np.ndarray:
        """
        Get mask for specific object (placeholder - actual implementation uses SAM2 API)
        """
        # In practice, this would extract the mask from SAM2's internal state
        # For now, return a placeholder at the video resolution
        h, w = state.get('video_height', 1080), state.get('video_width', 1920)
        return np.zeros((h, w), dtype=np.uint8)

    def apply_continuous_correction(self,
                                    state: Dict[str, Any],
                                    frame: np.ndarray,
                                    frame_idx: int,
                                    detector: Any) -> None:
        """
        Apply continuous correction by re-detecting and refining masks

        Args:
            state: Inference state
            frame: Current frame
            frame_idx: Frame index
            detector: Person detector instance
        """
        if frame_idx % self.correction_interval != 0:
            return

        print(f"   🔄 Applying continuous correction at frame {frame_idx}")

        # Detect persons in current frame
        new_detections = detector.detect_persons(frame)

        if new_detections:
            # Add refinement prompts, re-prompting each tracked object with its
            # refreshed box. NOTE: this assumes detections come back in the same
            # order as the tracked objects; a robust implementation would match
            # boxes to objects (e.g. by IoU) before re-prompting.
            with torch.inference_mode():
                for obj_id, det in zip(self.object_ids, new_detections):
                    box = torch.tensor(det['box'], dtype=torch.float32, device=self.device)
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=frame_idx,
                        obj_id=obj_id,
                        box=box
                    )

    def apply_mask_to_frame(self,
                            frame: np.ndarray,
                            mask: np.ndarray,
                            output_format: str = 'greenscreen',
                            background_color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        Apply mask to frame with specified output format

        Args:
            frame: Input frame (BGR)
            mask: Binary mask (bool, 0/1 uint8, 0-255 uint8, or float 0-1)
            output_format: 'alpha' or 'greenscreen'
            background_color: Background color (BGR) for greenscreen

        Returns:
            Processed frame (BGRA for 'alpha', BGR for 'greenscreen')
        """
        # Normalize the mask to uint8 in the 0-255 range regardless of the
        # input convention (bool, 0/1 uint8, 0-255 uint8, or float 0-1)
        if mask.dtype == np.bool_:
            mask = mask.astype(np.uint8) * 255
        elif mask.dtype != np.uint8:
            mask = (np.clip(mask.astype(np.float32), 0.0, 1.0) * 255).astype(np.uint8)
        elif mask.max() <= 1:
            mask = mask * 255

        if output_format == 'alpha':
            # Create BGRA image with the mask as the alpha channel
            bgra = np.zeros((frame.shape[0], frame.shape[1], 4), dtype=np.uint8)
            bgra[:, :, :3] = frame
            bgra[:, :, 3] = mask

            return bgra

        else:  # greenscreen
            # Create green background
            background = np.full_like(frame, background_color, dtype=np.uint8)

            # Expand mask to 3 channels
            if mask.ndim == 2:
                mask_3ch = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
            else:
                mask_3ch = mask

            # Normalize mask to 0-1
            mask_float = mask_3ch.astype(np.float32) / 255.0

            # Composite foreground over the background
            result = (frame * mask_float + background * (1 - mask_float)).astype(np.uint8)

            return result

    def cleanup(self) -> None:
        """Clean up resources"""
        # Clear states
        self.states.clear()

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # Garbage collection
        gc.collect()

        print("🧹 SAM2 streaming processor cleaned up")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage"""
        memory_stats = {
            'states_count': len(self.states),
            'object_count': len(self.object_ids),
        }

        if torch.cuda.is_available():
            memory_stats['cuda_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
            memory_stats['cuda_reserved_gb'] = torch.cuda.memory_reserved() / 1e9

        return memory_stats
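

# --- Illustrative usage sketch (not part of the original template) -----------
# Shows how the pieces above are meant to fit together. The config values, the
# video path, and the initial detection box are hypothetical; a real pipeline
# would obtain detections from a person detector such as the one passed to
# apply_continuous_correction().
if __name__ == '__main__':
    example_config = {
        'hardware': {'device': 'cuda'},
        'matting': {
            'sam2_model_cfg': 'sam2.1_hiera_l',
            'sam2_checkpoint': 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
            'memory_offload': True,
            'fp16': False,
            'correction_interval': 300,
        },
    }

    processor = SAM2StreamingProcessor(example_config)
    state = processor.init_state('input.mp4', eye='full')

    # Prompt SAM2 with a person box on the first frame
    processor.add_detections(state, [{'box': [100.0, 50.0, 400.0, 700.0]}], frame_idx=0)

    # Propagate through the video; masks are per-object logits, so threshold
    # and merge them before compositing with apply_mask_to_frame()
    for frame_idx, object_ids, masks in processor.propagate_in_video_simple(state):
        combined = np.any(np.asarray(masks) > 0, axis=0).squeeze().astype(np.uint8) * 255
        # frame = ...  # read the matching frame with your own loader
        # output = processor.apply_mask_to_frame(frame, combined, 'greenscreen')

    processor.cleanup()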