Add true streaming: process frames one at a time instead of loading the whole video up front
@@ -18,6 +18,7 @@ from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Generator
import warnings
import gc
from .timeout_init import safe_init_state, TimeoutError

# Import SAM2 components - these will be available after SAM2 installation
try:
@@ -92,41 +93,85 @@ class SAM2StreamingProcessor:
            raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")

    def init_state(self,
                   video_path: str,
                   video_info: Dict[str, Any],
                   eye: str = 'full') -> Dict[str, Any]:
        """
        Initialize inference state for streaming
        Initialize inference state for streaming (NO VIDEO LOADING)

        Args:
            video_path: Path to video file
            video_info: Video metadata dict with width, height, frame_count
            eye: Eye identifier ('left', 'right', or 'full')

        Returns:
            Inference state dictionary
        """
        # Initialize state with memory offloading enabled
        with torch.inference_mode():
            state = self.predictor.init_state(
                video_path=video_path,
                offload_video_to_cpu=self.memory_offload,
                offload_state_to_cpu=self.memory_offload,
                async_loading_frames=False # We'll provide frames directly
            )
        print(f" Initializing streaming state for {eye} eye...")

        # Monitor memory before initialization
        if torch.cuda.is_available():
            before_mem = torch.cuda.memory_allocated() / 1e9
            print(f" 📊 GPU memory before init: {before_mem:.1f}GB")

        # Create streaming state WITHOUT loading video frames
        state = self._create_streaming_state(video_info)

        # Monitor memory after initialization
        if torch.cuda.is_available():
            after_mem = torch.cuda.memory_allocated() / 1e9
            print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")

        self.states[eye] = state
        print(f" Initialized state for {eye} eye")
        print(f" ✅ Streaming state initialized for {eye} eye")

        return state

    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
        """Create streaming state for frame-by-frame processing"""
        # Create a streaming-compatible inference state
        # This mirrors SAM2's internal state structure but without video frames

        with torch.inference_mode():
            # Initialize empty inference state using SAM2's predictor
            # We'll manually provide frames via propagate calls
            inference_state = {
                'point_inputs_per_obj': {},
                'mask_inputs_per_obj': {},
                'cached_features': {},
                'constants': {},
                'obj_id_to_idx': {},
                'obj_idx_to_id': {},
                'obj_ids': [],
                'click_inputs_per_obj': {},
                'temp_output_dict_per_obj': {},
                'consolidated_frame_inds': {},
                'tracking_has_started': False,
                'num_frames': video_info['frame_count'],
                'video_height': video_info['height'],
                'video_width': video_info['width'],
                'device': self.device,
                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
                'offload_video_to_cpu': self.memory_offload,
                'offload_state_to_cpu': self.memory_offload,
                'inference_state': {},
            }

            # Initialize SAM2 constants that don't depend on video frames
            self.predictor._get_image_feature_cache = {}
            self.predictor._feature_bank = {}

        return inference_state

    def add_detections(self,
                       state: Dict[str, Any],
                       frame: np.ndarray,
                       detections: List[Dict[str, Any]],
                       frame_idx: int = 0) -> List[int]:
        """
        Add detection boxes as prompts to SAM2
        Add detection boxes as prompts to SAM2 with frame data

        Args:
            state: Inference state
            frame: Frame image (RGB numpy array)
            detections: List of detections with 'box' key
            frame_idx: Frame index to add prompts

@@ -137,6 +182,12 @@ class SAM2StreamingProcessor:
            warnings.warn(f"No detections to add at frame {frame_idx}")
            return []

        # Convert frame to tensor
        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
        frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension

        # Convert detections to SAM2 format
        boxes = []
        for det in detections:
@@ -145,9 +196,16 @@ class SAM2StreamingProcessor:

        boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)

        # Add boxes as prompts
        # Manually process frame and add prompts (streaming approach)
        with torch.inference_mode():
            _, object_ids, _ = self.predictor.add_new_points_or_box(
            # Process frame through SAM2's image encoder
            features = self.predictor._get_image_features(frame_tensor)

            # Store features in state for this frame
            state['cached_features'][frame_idx] = features

            # Add boxes as prompts for this specific frame
            _, object_ids, masks = self.predictor.add_new_points_or_box(
                inference_state=state,
                frame_idx=frame_idx,
                obj_id=0, # SAM2 will auto-increment
@@ -159,29 +217,78 @@ class SAM2StreamingProcessor:

        return object_ids

    def propagate_in_video_simple(self,
                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
    def propagate_single_frame(self,
                               state: Dict[str, Any],
                               frame: np.ndarray,
                               frame_idx: int) -> np.ndarray:
        """
        Simple propagation for single eye processing
        Propagate masks for a single frame (true streaming)

        Yields:
            (frame_idx, object_ids, masks) tuples
        Args:
            state: Inference state
            frame: Frame image (RGB numpy array)
            frame_idx: Frame index

        Returns:
            Combined mask for all objects
        """
        # Convert frame to tensor
        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
        if frame_tensor.ndim == 3:
            frame_tensor = frame_tensor.permute(2, 0, 1) # HWC -> CHW
        frame_tensor = frame_tensor.unsqueeze(0) # Add batch dimension

        with torch.inference_mode():
            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
                # Convert masks to numpy
                if isinstance(masks, torch.Tensor):
                    masks_np = masks.cpu().numpy()
                else:
                    masks_np = masks

                yield frame_idx, object_ids, masks_np
            # Process frame through SAM2's image encoder
            features = self.predictor._get_image_features(frame_tensor)

            # Store features in state for this frame
            state['cached_features'][frame_idx] = features

        # Get masks for current frame by propagating from previous frames
        masks = []
        for obj_id in state.get('obj_ids', []):
            # Use SAM2's mask propagation for this object
            try:
                obj_mask = self.predictor._propagate_single_object(
                    state, obj_id, frame_idx, features
                )
                if obj_mask is not None:
                    masks.append(obj_mask)
            except Exception as e:
                # If propagation fails, use empty mask
                print(f" Warning: Propagation failed for object {obj_id}: {e}")
                empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
                masks.append(empty_mask)

        # Combine all object masks
        if masks:
            combined_mask = torch.stack(masks).max(dim=0)[0]
            # Convert to numpy
            combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
        else:
            combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

        # Cleanup old features to prevent memory accumulation
        self._cleanup_old_features(state, frame_idx, keep_frames=10)

        return combined_mask_np

    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
        """Remove old cached features to prevent memory accumulation"""
        features_to_remove = []
        for frame_idx in state.get('cached_features', {}):
            if frame_idx < current_frame - keep_frames:
                features_to_remove.append(frame_idx)

        # Periodic memory cleanup
        if frame_idx % 100 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
        for frame_idx in features_to_remove:
            del state['cached_features'][frame_idx]

        # Periodic GPU memory cleanup
        if current_frame % 50 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def propagate_frame_pair(self,
                             left_state: Dict[str, Any],

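Taken together, the methods above form a frame-by-frame API: build a lightweight state with init_state, prompt SAM2 once with add_detections on the first frame, then call propagate_single_frame for every frame. A minimal driver sketch follows (run_streaming, frame_iter and detect_people are placeholder names for this illustration, not part of the commit):

def run_streaming(processor, video_path, video_info, frame_iter, detect_people):
    # Illustrative only: drives SAM2StreamingProcessor one RGB numpy frame at a time
    state = processor.init_state(video_path, video_info, eye='full')
    for frame_idx, frame in enumerate(frame_iter):
        if frame_idx == 0:
            # e.g. detections = [{'box': [x1, y1, x2, y2]}, ...]
            detections = detect_people(frame)
            processor.add_detections(state, frame, detections, frame_idx=0)
        mask = processor.propagate_single_frame(state, frame, frame_idx)  # HxW uint8
        yield frame_idx, mask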
@@ -102,14 +102,15 @@ class VR180StreamingProcessor:
        self.initialize()
        self.start_time = time.time()

        # Initialize SAM2 states for both eyes
        # Initialize SAM2 states for both eyes (streaming mode - no video loading)
        print("🎯 Initializing SAM2 streaming states...")
        video_info = self.frame_reader.get_video_info()
        left_state = self.sam2_processor.init_state(
            self.config.input.video_path,
            video_info,
            eye='left'
        )
        right_state = self.sam2_processor.init_state(
            self.config.input.video_path,
            video_info,
            eye='right'
        )
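The video_info dict passed to init_state only needs width, height and frame_count. frame_reader.get_video_info() is not shown in this diff; a stand-in built directly with OpenCV would look roughly like this (assumed shape, not the repo's implementation):

import cv2

def get_video_info(video_path: str) -> dict:
    # Hypothetical equivalent of frame_reader.get_video_info()
    cap = cv2.VideoCapture(video_path)
    info = {
        'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        'fps': cap.get(cv2.CAP_PROP_FPS),
    }
    cap.release()
    return info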
@@ -158,19 +159,19 @@ class VR180StreamingProcessor:

        print(f" Detected {len(detections)} person(s) in first frame")

        # Add detections to both eyes
        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
        # Add detections to both eyes (streaming - pass frame data)
        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)

        # Transfer detections to slave eye
        transferred_detections = self.stereo_manager.transfer_detections(
            detections,
            'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
        )
        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)
        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)

        # Process and write first frame
        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)

        # Apply masks and write
        processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
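_apply_masks_to_frame is not part of this diff; for orientation, one plausible shape for such a helper, given that propagate_single_frame returns one HxW uint8 mask per eye and the frame is side-by-side stereo, is sketched below (purely illustrative, the repo's implementation may differ):

import numpy as np

def apply_masks_to_frame(frame: np.ndarray, left_mask: np.ndarray, right_mask: np.ndarray) -> np.ndarray:
    # Hypothetical sketch: stitch the per-eye masks back into a full-width mask
    # and keep only the masked (person) pixels of the side-by-side frame.
    h, w = frame.shape[:2]
    half = w // 2
    full_mask = np.zeros((h, w), dtype=np.uint8)
    full_mask[:, :half] = left_mask
    full_mask[:, half:] = right_mask
    return frame * (full_mask[..., None] > 0)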
@@ -195,10 +196,9 @@ class VR180StreamingProcessor:
            # Split into eyes
            left_eye, right_eye = self.stereo_manager.split_frame(frame)

            # Propagate masks for both eyes
            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
                left_state, right_state, left_eye, right_eye, frame_idx
            )
            # Propagate masks for both eyes (streaming approach)
            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)

            # Validate stereo consistency
            right_masks = self.stereo_manager.validate_masks(
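One detail worth restating from the first file: propagate_single_frame calls _cleanup_old_features(state, frame_idx, keep_frames=10) after every frame, so the cached image features behave as a sliding window rather than growing with the video. A standalone restatement of that eviction policy (simplified, operating directly on the cached_features dict keyed by frame index):

import gc
import torch

def evict_old_features(cached_features: dict, current_frame: int, keep_frames: int = 10) -> None:
    # Drop feature maps older than the sliding window, then periodically
    # release cached GPU memory (mirrors the intent of _cleanup_old_features).
    stale = [idx for idx in list(cached_features) if idx < current_frame - keep_frames]
    for idx in stale:
        del cached_features[idx]
    if current_frame % 50 == 0 and torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()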