test2/vr180_streaming/sam2_streaming_simple.py

"""
Simple SAM2 streaming processor based on det-sam2 pattern
Adapted for current segment-anything-2 API
"""
import gc
import os
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Dict, Any, List, Optional

import cv2
import numpy as np
import torch
# Import SAM2 components
try:
    from sam2.build_sam import build_sam2_video_predictor
except ImportError:
    warnings.warn("SAM2 not installed. Please install segment-anything-2 first.")


class SAM2StreamingProcessor:
    """Simple streaming integration with SAM2 following det-sam2 pattern"""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = torch.device(config.get('hardware', {}).get('device', 'cuda'))

        # SAM2 model configuration
        model_cfg_name = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l')
        checkpoint = config.get('matting', {}).get('sam2_checkpoint',
                                                   'segment-anything-2/checkpoints/sam2.1_hiera_large.pt')

        # Map config name to full path
        config_mapping = {
            'sam2.1_hiera_t': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml',
            'sam2.1_hiera_s': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml',
            'sam2.1_hiera_b+': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml',
            'sam2.1_hiera_l': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
        }
        model_cfg = config_mapping.get(model_cfg_name, model_cfg_name)

        # Build predictor (simple, clean approach)
        self.predictor = build_sam2_video_predictor(
            model_cfg,
            checkpoint,
            device=self.device
        )

        # Frame buffer for streaming (like det-sam2)
        self.frame_buffer = []
        self.frame_buffer_size = config.get('streaming', {}).get('buffer_frames', 10)

        # State management (simple)
        self.inference_state = None
        self.temp_dir = None
        self.object_ids = []

        # Memory management
        self.memory_offload = config.get('matting', {}).get('memory_offload', True)
        self.max_frames_to_track = config.get('matting', {}).get('correction_interval', 300)

        print("🎯 Simple SAM2 streaming processor initialized:")
        print(f"   Model: {model_cfg}")
        print(f"   Device: {self.device}")
        print(f"   Buffer size: {self.frame_buffer_size}")
        print(f"   Memory offload: {self.memory_offload}")

    def add_frame_and_detections(self,
                                 frame: np.ndarray,
                                 detections: List[Dict[str, Any]],
                                 frame_idx: int) -> np.ndarray:
        """
        Add frame to buffer and process detections (det-sam2 pattern)

        Args:
            frame: Input frame (BGR)
            detections: List of detections with 'box' key
            frame_idx: Global frame index

        Returns:
            Mask for current frame
        """
        # Add frame to buffer
        self.frame_buffer.append({
            'frame': frame,
            'frame_idx': frame_idx,
            'detections': detections
        })

        # Process when buffer is full or when we have detections
        if len(self.frame_buffer) >= self.frame_buffer_size or detections:
            return self._process_buffer()
        else:
            # Return empty mask if no processing yet
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

    def _process_buffer(self) -> np.ndarray:
        """Process current frame buffer (adapted det-sam2 approach)"""
        if not self.frame_buffer:
            return np.zeros((480, 640), dtype=np.uint8)

        try:
            # Create temporary directory for frames (current SAM2 API requirement);
            # _create_temp_frames removes any previous temp directory itself.
            self._create_temp_frames()

            # Initialize or reinitialize SAM2 state. The current SAM2 release
            # lacks det-sam2's incremental update_state, so every buffer flush
            # re-runs init_state on the freshly written temp directory. This is
            # the key difference from the det-sam2 reference implementation.
            first_init = self.inference_state is None
            self.inference_state = self.predictor.init_state(
                video_path=self.temp_dir,
                offload_video_to_cpu=self.memory_offload,
                offload_state_to_cpu=self.memory_offload
            )
            action = "Initialized" if first_init else "Reinitialized"
            print(f"   {action} SAM2 state with {len(self.frame_buffer)} frames")

            # Add detections as prompts (standard SAM2 API)
            self._add_detection_prompts()

            # Get masks via propagation
            masks = self._get_current_masks()

            # Clean up old frames to prevent memory accumulation
            self._cleanup_old_frames()

            return masks
        except Exception as e:
            print(f"   Warning: Buffer processing failed: {e}")
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

    def _create_temp_frames(self):
        """Create temporary directory with frame images for SAM2"""
        if self.temp_dir:
            self._cleanup_temp_frames()
        self.temp_dir = tempfile.mkdtemp(prefix='sam2_streaming_')

        for i, buffer_item in enumerate(self.frame_buffer):
            frame = buffer_item['frame']
            # Save as JPEG (SAM2 expects a directory of numerically named
            # JPEGs). The buffered frames are already BGR and cv2.imwrite
            # expects BGR, so no color conversion is needed here; SAM2
            # decodes the JPEGs as RGB on its own.
            frame_path = os.path.join(self.temp_dir, f"{i:05d}.jpg")
            cv2.imwrite(frame_path, frame)

    def _add_detection_prompts(self):
        """Add detection boxes as prompts to SAM2 (standard API)"""
        for buffer_idx, buffer_item in enumerate(self.frame_buffer):
            detections = buffer_item.get('detections', [])
            for det_idx, detection in enumerate(detections):
                box = detection['box']  # [x1, y1, x2, y2]
                # Use standard SAM2 API
                try:
                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
                        inference_state=self.inference_state,
                        frame_idx=buffer_idx,  # Frame index within buffer
                        obj_id=det_idx,        # Simple object ID
                        box=np.array(box, dtype=np.float32)
                    )
                    # Track object IDs
                    if det_idx not in self.object_ids:
                        self.object_ids.append(det_idx)
                except Exception as e:
                    print(f"   Warning: Failed to add detection: {e}")
                    continue

    def _get_current_masks(self) -> np.ndarray:
        """Get masks for current frame via propagation"""
        if not self.object_ids:
            # No objects to track
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

        try:
            # Use SAM2's propagate_in_video (standard API)
            latest_frame_idx = len(self.frame_buffer) - 1
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state,
                start_frame_idx=latest_frame_idx,
                max_frame_num_to_track=1,  # Just current frame
                reverse=False
            ):
                if out_frame_idx == latest_frame_idx and len(out_mask_logits) > 0:
                    # Combine all object masks; logits > 0 mark foreground.
                    # Move each mask to CPU before converting to numpy (the
                    # logits may live on the GPU) and squeeze the leading
                    # channel dim so the result has shape (H, W).
                    combined_mask = None
                    for mask_logit in out_mask_logits:
                        mask = (mask_logit > 0.0).squeeze(0).cpu().numpy()
                        combined_mask = mask if combined_mask is None else (combined_mask | mask)
                    return combined_mask.astype(np.uint8) * 255

            # If no masks found, return empty
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)
        except Exception as e:
            print(f"   Warning: Mask propagation failed: {e}")
            frame_shape = self.frame_buffer[-1]['frame'].shape
            return np.zeros((frame_shape[0], frame_shape[1]), dtype=np.uint8)

    def _cleanup_old_frames(self):
        """Clean up old frames from buffer (det-sam2 pattern)"""
        # Keep only recent frames to prevent memory accumulation
        if len(self.frame_buffer) > self.frame_buffer_size:
            self.frame_buffer = self.frame_buffer[-self.frame_buffer_size:]

        # Periodic GPU memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def _cleanup_temp_frames(self):
        """Clean up temporary frame directory"""
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)
        self.temp_dir = None

    def cleanup(self):
        """Clean up all resources"""
        self._cleanup_temp_frames()
        self.frame_buffer.clear()
        self.object_ids.clear()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

        print("🧹 Simple SAM2 streaming processor cleaned up")