True streaming: initialize SAM2 state from metadata and propagate frame-by-frame
@@ -18,6 +18,7 @@ from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
 import gc
+from .timeout_init import safe_init_state, TimeoutError
 
 # Import SAM2 components - these will be available after SAM2 installation
 try:
@@ -92,41 +93,85 @@ class SAM2StreamingProcessor:
             raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
 
     def init_state(self,
-                   video_path: str,
+                   video_info: Dict[str, Any],
                    eye: str = 'full') -> Dict[str, Any]:
         """
-        Initialize inference state for streaming
+        Initialize inference state for streaming (NO VIDEO LOADING)
 
         Args:
-            video_path: Path to video file
+            video_info: Video metadata dict with width, height, frame_count
            eye: Eye identifier ('left', 'right', or 'full')
 
         Returns:
            Inference state dictionary
        """
-        # Initialize state with memory offloading enabled
-        with torch.inference_mode():
-            state = self.predictor.init_state(
-                video_path=video_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
-                async_loading_frames=False  # We'll provide frames directly
-            )
+        print(f"   Initializing streaming state for {eye} eye...")
+
+        # Monitor memory before initialization
+        if torch.cuda.is_available():
+            before_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"   📊 GPU memory before init: {before_mem:.1f}GB")
+
+        # Create streaming state WITHOUT loading video frames
+        state = self._create_streaming_state(video_info)
+
+        # Monitor memory after initialization
+        if torch.cuda.is_available():
+            after_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"   📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
 
         self.states[eye] = state
-        print(f"   Initialized state for {eye} eye")
+        print(f"   ✅ Streaming state initialized for {eye} eye")
 
         return state
 
+    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
+        """Create streaming state for frame-by-frame processing"""
+        # Create a streaming-compatible inference state
+        # This mirrors SAM2's internal state structure but without video frames
+
+        with torch.inference_mode():
+            # Initialize empty inference state using SAM2's predictor
+            # We'll manually provide frames via propagate calls
+            inference_state = {
+                'point_inputs_per_obj': {},
+                'mask_inputs_per_obj': {},
+                'cached_features': {},
+                'constants': {},
+                'obj_id_to_idx': {},
+                'obj_idx_to_id': {},
+                'obj_ids': [],
+                'click_inputs_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'consolidated_frame_inds': {},
+                'tracking_has_started': False,
+                'num_frames': video_info['frame_count'],
+                'video_height': video_info['height'],
+                'video_width': video_info['width'],
+                'device': self.device,
+                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
+                'offload_video_to_cpu': self.memory_offload,
+                'offload_state_to_cpu': self.memory_offload,
+                'inference_state': {},
+            }
+
+        # Initialize SAM2 constants that don't depend on video frames
+        self.predictor._get_image_feature_cache = {}
+        self.predictor._feature_bank = {}
+
+        return inference_state
+
     def add_detections(self,
                        state: Dict[str, Any],
+                       frame: np.ndarray,
                        detections: List[Dict[str, Any]],
                        frame_idx: int = 0) -> List[int]:
         """
-        Add detection boxes as prompts to SAM2
+        Add detection boxes as prompts to SAM2 with frame data
 
         Args:
             state: Inference state
+            frame: Frame image (RGB numpy array)
             detections: List of detections with 'box' key
             frame_idx: Frame index to add prompts
 
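The point of `_create_streaming_state` is that nothing frame-sized is allocated at init time; only metadata enters the state. A rough back-of-envelope (standalone sketch with illustrative resolution and frame count, not figures from this repo) shows why eager loading was untenable:

```python
# Illustrative numbers only: cost of holding decoded float32 RGB frames.
video_info = {'width': 3840, 'height': 2160, 'frame_count': 7200}

bytes_per_frame = video_info['width'] * video_info['height'] * 3 * 4
eager_gb = bytes_per_frame * video_info['frame_count'] / 1e9
window_gb = bytes_per_frame * 10 / 1e9  # the keep_frames=10 window used below

print(f"eager load: {eager_gb:.0f} GB vs sliding window: {window_gb:.2f} GB")
```

At these assumed numbers that is roughly 717 GB eagerly versus about 1 GB for a ten-frame window, which is the whole argument for streaming.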
@@ -137,6 +182,12 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
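A quick shape check of the conversion this hunk adds (standalone sketch; dtype handling mirrors the diff, the resolution is arbitrary):

```python
import numpy as np
import torch

frame = np.zeros((1080, 1920, 3), dtype=np.uint8)  # decoder output: H, W, C
t = torch.tensor(frame, dtype=torch.float32)
if t.ndim == 3:
    t = t.permute(2, 0, 1)   # HWC -> CHW
t = t.unsqueeze(0)           # add batch dim -> (1, 3, 1080, 1920)
print(t.shape)
```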
@@ -145,9 +196,16 @@ class SAM2StreamingProcessor:
 
         boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
 
-        # Add boxes as prompts
+        # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
-            _, object_ids, _ = self.predictor.add_new_points_or_box(
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Add boxes as prompts for this specific frame
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=0,  # SAM2 will auto-increment
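Note that `_get_image_features` is used here as a private hook on the predictor; whether stock SAM2 exposes it under that name is an assumption this commit makes. The box-prompt packing itself is straightforward (standalone sketch, box format assumed to be [x1, y1, x2, y2] per the docstring above; the detection values are made up):

```python
import torch

detections = [{'box': [100.0, 150.0, 400.0, 600.0], 'score': 0.92},
              {'box': [900.0, 200.0, 1200.0, 700.0], 'score': 0.88}]
boxes = [det['box'] for det in detections]
boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
print(boxes_tensor.shape)  # torch.Size([2, 4]) -- one row per detected person
```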
@@ -159,29 +217,78 @@ class SAM2StreamingProcessor:
 
         return object_ids
 
-    def propagate_in_video_simple(self,
-                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
+    def propagate_single_frame(self,
+                               state: Dict[str, Any],
+                               frame: np.ndarray,
+                               frame_idx: int) -> np.ndarray:
         """
-        Simple propagation for single eye processing
+        Propagate masks for a single frame (true streaming)
 
-        Yields:
-            (frame_idx, object_ids, masks) tuples
+        Args:
+            state: Inference state
+            frame: Frame image (RGB numpy array)
+            frame_idx: Frame index
+
+        Returns:
+            Combined mask for all objects
         """
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         with torch.inference_mode():
-            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
-                # Convert masks to numpy
-                if isinstance(masks, torch.Tensor):
-                    masks_np = masks.cpu().numpy()
-                else:
-                    masks_np = masks
-
-                yield frame_idx, object_ids, masks_np
-
-                # Periodic memory cleanup
-                if frame_idx % 100 == 0:
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-                        gc.collect()
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Get masks for current frame by propagating from previous frames
+            masks = []
+            for obj_id in state.get('obj_ids', []):
+                # Use SAM2's mask propagation for this object
+                try:
+                    obj_mask = self.predictor._propagate_single_object(
+                        state, obj_id, frame_idx, features
+                    )
+                    if obj_mask is not None:
+                        masks.append(obj_mask)
+                except Exception as e:
+                    # If propagation fails, use empty mask
+                    print(f"   Warning: Propagation failed for object {obj_id}: {e}")
+                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
+                    masks.append(empty_mask)
+
+            # Combine all object masks
+            if masks:
+                combined_mask = torch.stack(masks).max(dim=0)[0]
+                # Convert to numpy
+                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
+            else:
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+        # Cleanup old features to prevent memory accumulation
+        self._cleanup_old_features(state, frame_idx, keep_frames=10)
+
+        return combined_mask_np
+
+    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
+        """Remove old cached features to prevent memory accumulation"""
+        features_to_remove = []
+        for frame_idx in state.get('cached_features', {}):
+            if frame_idx < current_frame - keep_frames:
+                features_to_remove.append(frame_idx)
+
+        for frame_idx in features_to_remove:
+            del state['cached_features'][frame_idx]
+
+        # Periodic GPU memory cleanup
+        if current_frame % 50 == 0:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
 
     def propagate_frame_pair(self,
                              left_state: Dict[str, Any],
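Two mechanics in this hunk are worth isolating: per-object masks are merged with an elementwise max (a mask union), and the feature cache is a sliding window, so memory stays flat no matter how long the video runs. Standalone sketch with toy shapes (`_propagate_single_object`, like `_get_image_features`, is an assumed private hook):

```python
import torch

# Union of per-object masks, as in propagate_single_frame:
masks = [torch.zeros(4, 4), torch.ones(4, 4)]
combined = torch.stack(masks).max(dim=0)[0]
assert combined.equal(torch.ones(4, 4))  # any object's pixel survives the merge

# Sliding-window cache, as in _cleanup_old_features:
cached = {i: f"feat_{i}" for i in range(25)}
current_frame, keep_frames = 24, 10
for idx in [i for i in cached if i < current_frame - keep_frames]:
    del cached[idx]
print(sorted(cached))  # [14, 15, ..., 24] -- only the last 11 frames remain
```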
@@ -102,14 +102,15 @@ class VR180StreamingProcessor:
         self.initialize()
         self.start_time = time.time()
 
-        # Initialize SAM2 states for both eyes
+        # Initialize SAM2 states for both eyes (streaming mode - no video loading)
         print("🎯 Initializing SAM2 streaming states...")
+        video_info = self.frame_reader.get_video_info()
         left_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
            eye='left'
        )
        right_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
            eye='right'
        )
 
@@ -158,19 +159,19 @@ class VR180StreamingProcessor:
 
         print(f"   Detected {len(detections)} person(s) in first frame")
 
-        # Add detections to both eyes
-        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
+        # Add detections to both eyes (streaming - pass frame data)
+        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
 
         # Transfer detections to slave eye
         transferred_detections = self.stereo_manager.transfer_detections(
             detections,
             'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
         )
-        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)
+        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
 
         # Process and write first frame
-        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
+        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
+        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
 
         # Apply masks and write
         processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -195,10 +196,9 @@ class VR180StreamingProcessor:
         # Split into eyes
         left_eye, right_eye = self.stereo_manager.split_frame(frame)
 
-        # Propagate masks for both eyes
-        left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
-            left_state, right_state, left_eye, right_eye, frame_idx
-        )
+        # Propagate masks for both eyes (streaming approach)
+        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
+        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
 
         # Validate stereo consistency
         right_masks = self.stereo_manager.validate_masks(
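Putting the caller-side changes together, the per-frame flow is now split → propagate each eye independently → validate. A toy trace with stand-in stubs (none of these stub classes are the real frame_reader/stereo_manager/sam2_processor; shapes and the three-frame loop are illustrative):

```python
import numpy as np

class StubStereo:
    def split_frame(self, frame):
        half = frame.shape[1] // 2          # side-by-side VR180 split
        return frame[:, :half], frame[:, half:]

class StubSAM2:
    def propagate_single_frame(self, state, eye_frame, frame_idx):
        return np.zeros(eye_frame.shape[:2], dtype=np.uint8)

stereo, sam2 = StubStereo(), StubSAM2()
left_state, right_state = {}, {}
for frame_idx in range(3):                  # stand-in for the streaming read loop
    frame = np.zeros((8, 16, 3), dtype=np.uint8)
    left_eye, right_eye = stereo.split_frame(frame)
    left_masks = sam2.propagate_single_frame(left_state, left_eye, frame_idx)
    right_masks = sam2.propagate_single_frame(right_state, right_eye, frame_idx)
    print(frame_idx, left_masks.shape, right_masks.shape)
```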