Compare commits

...

2 Commits

SHA1       Message                   Date
ee80ed28b6 add stuff true streaming  2025-07-27 08:54:19 -07:00
b5eae7b41d pytorch shit              2025-07-27 08:40:59 -07:00
3 changed files with 156 additions and 45 deletions

View File

@@ -83,6 +83,10 @@ else
     cd ..
 fi
 
+# Fix PyTorch version conflicts after SAM2 installation
+print_status "Fixing PyTorch version conflicts..."
+pip install torchaudio --upgrade --no-deps || print_error "Failed to upgrade torchaudio"
+
 # Download SAM2 checkpoints
 print_status "Downloading SAM2 checkpoints..."
 cd segment-anything-2/checkpoints
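
Because `--no-deps` skips dependency resolution, the torch/torchaudio pairing is worth verifying after the script runs. A minimal check, assuming both packages import cleanly:

# Sanity check after the --no-deps torchaudio upgrade: torch and torchaudio
# must share the same minor version (e.g. 2.3.x with 2.3.x).
import torch
import torchaudio

t = torch.__version__.split("+")[0].split(".")[:2]
ta = torchaudio.__version__.split("+")[0].split(".")[:2]
print(f"torch {torch.__version__}, torchaudio {torchaudio.__version__}")
assert t == ta, "torch/torchaudio minor versions differ; reinstall matching builds"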

View File

@@ -18,6 +18,7 @@ from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
 import gc
+from .timeout_init import safe_init_state, TimeoutError
 
 # Import SAM2 components - these will be available after SAM2 installation
 try:
@@ -92,41 +93,85 @@ class SAM2StreamingProcessor:
             raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
 
     def init_state(self,
                    video_path: str,
+                   video_info: Dict[str, Any],
                    eye: str = 'full') -> Dict[str, Any]:
         """
-        Initialize inference state for streaming
+        Initialize inference state for streaming (NO VIDEO LOADING)
 
         Args:
             video_path: Path to video file
+            video_info: Video metadata dict with width, height, frame_count
             eye: Eye identifier ('left', 'right', or 'full')
 
         Returns:
             Inference state dictionary
         """
-        # Initialize state with memory offloading enabled
-        with torch.inference_mode():
-            state = self.predictor.init_state(
-                video_path=video_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
-                async_loading_frames=False  # We'll provide frames directly
-            )
+        print(f" Initializing streaming state for {eye} eye...")
+
+        # Monitor memory before initialization
+        if torch.cuda.is_available():
+            before_mem = torch.cuda.memory_allocated() / 1e9
+            print(f" 📊 GPU memory before init: {before_mem:.1f}GB")
+
+        # Create streaming state WITHOUT loading video frames
+        state = self._create_streaming_state(video_info)
+
+        # Monitor memory after initialization
+        if torch.cuda.is_available():
+            after_mem = torch.cuda.memory_allocated() / 1e9
+            print(f" 📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
         self.states[eye] = state
-        print(f" Initialized state for {eye} eye")
+        print(f" ✅ Streaming state initialized for {eye} eye")
 
         return state
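
The before/after memory probes in init_state repeat one pattern; a small helper could factor it out (the helper name is hypothetical, not part of this diff):

import torch

def gpu_mem_gb() -> float:
    """Current CUDA allocation in GB; 0.0 on CPU-only hosts."""
    return torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0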
+    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
+        """Create streaming state for frame-by-frame processing"""
+        # Create a streaming-compatible inference state
+        # This mirrors SAM2's internal state structure but without video frames
+        with torch.inference_mode():
+            # Initialize empty inference state using SAM2's predictor
+            # We'll manually provide frames via propagate calls
+            inference_state = {
+                'point_inputs_per_obj': {},
+                'mask_inputs_per_obj': {},
+                'cached_features': {},
+                'constants': {},
+                'obj_id_to_idx': {},
+                'obj_idx_to_id': {},
+                'obj_ids': [],
+                'click_inputs_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'consolidated_frame_inds': {},
+                'tracking_has_started': False,
+                'num_frames': video_info['frame_count'],
+                'video_height': video_info['height'],
+                'video_width': video_info['width'],
+                'device': self.device,
+                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
+                'offload_video_to_cpu': self.memory_offload,
+                'offload_state_to_cpu': self.memory_offload,
+                'inference_state': {},
+            }
+
+        # Initialize SAM2 constants that don't depend on video frames
+        self.predictor._get_image_feature_cache = {}
+        self.predictor._feature_bank = {}
+
+        return inference_state
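
A usage sketch for the new entry point, with a hypothetical video_info dict shaped as the docstring describes (width, height, frame_count):

# Hypothetical values; in the pipeline this dict comes from the frame reader.
video_info = {"width": 2880, "height": 2880, "frame_count": 7200}

# processor is assumed to be an initialized SAM2StreamingProcessor
state = processor.init_state("input_vr180.mp4", video_info, eye="left")
assert state["num_frames"] == 7200 and not state["tracking_has_started"]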
     def add_detections(self,
                        state: Dict[str, Any],
+                       frame: np.ndarray,
                        detections: List[Dict[str, Any]],
                        frame_idx: int = 0) -> List[int]:
         """
-        Add detection boxes as prompts to SAM2
+        Add detection boxes as prompts to SAM2 with frame data
 
         Args:
             state: Inference state
+            frame: Frame image (RGB numpy array)
             detections: List of detections with 'box' key
             frame_idx: Frame index to add prompts
@@ -137,6 +182,12 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
@@ -145,9 +196,16 @@ class SAM2StreamingProcessor:
         boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
 
-        # Add boxes as prompts
+        # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
-            _, object_ids, _ = self.predictor.add_new_points_or_box(
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Add boxes as prompts for this specific frame
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=0,  # SAM2 will auto-increment
@@ -159,29 +217,78 @@ class SAM2StreamingProcessor:
         return object_ids
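
With the new signature, callers must pass the decoded frame alongside the boxes. A sketch, with processor and state as above, assuming frame is an RGB uint8 array and each box is [x1, y1, x2, y2] in pixels:

import numpy as np

frame = np.zeros((1440, 1440, 3), dtype=np.uint8)  # placeholder RGB frame
detections = [{"box": [400, 200, 900, 1300], "score": 0.92}]  # keys beyond 'box' assumed
object_ids = processor.add_detections(state, frame, detections, frame_idx=0)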
-    def propagate_in_video_simple(self,
-                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
+    def propagate_single_frame(self,
+                               state: Dict[str, Any],
+                               frame: np.ndarray,
+                               frame_idx: int) -> np.ndarray:
         """
-        Simple propagation for single eye processing
+        Propagate masks for a single frame (true streaming)
 
-        Yields:
-            (frame_idx, object_ids, masks) tuples
+        Args:
+            state: Inference state
+            frame: Frame image (RGB numpy array)
+            frame_idx: Frame index
+
+        Returns:
+            Combined mask for all objects
         """
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         with torch.inference_mode():
-            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
-                # Convert masks to numpy
-                if isinstance(masks, torch.Tensor):
-                    masks_np = masks.cpu().numpy()
-                else:
-                    masks_np = masks
-                yield frame_idx, object_ids, masks_np
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Get masks for current frame by propagating from previous frames
+            masks = []
+            for obj_id in state.get('obj_ids', []):
+                # Use SAM2's mask propagation for this object
+                try:
+                    obj_mask = self.predictor._propagate_single_object(
+                        state, obj_id, frame_idx, features
+                    )
+                    if obj_mask is not None:
+                        masks.append(obj_mask)
+                except Exception as e:
+                    # If propagation fails, use empty mask
+                    print(f" Warning: Propagation failed for object {obj_id}: {e}")
+                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
+                    masks.append(empty_mask)
+
+            # Combine all object masks
+            if masks:
+                combined_mask = torch.stack(masks).max(dim=0)[0]
+                # Convert to numpy
+                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
+            else:
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+        # Cleanup old features to prevent memory accumulation
+        self._cleanup_old_features(state, frame_idx, keep_frames=10)
+
+        return combined_mask_np
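
The per-frame contract this enables (decode, propagate, write, release) looks roughly like the loop below; frame_reader.frames() and writer are assumed stand-ins for the pipeline's actual I/O objects:

# Streaming driver sketch: only one decoded frame is alive at a time.
for frame_idx, frame in enumerate(frame_reader.frames()):
    mask = processor.propagate_single_frame(state, frame, frame_idx)
    writer.write(frame, mask)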
+    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
+        """Remove old cached features to prevent memory accumulation"""
+        features_to_remove = []
+        for frame_idx in state.get('cached_features', {}):
+            if frame_idx < current_frame - keep_frames:
+                features_to_remove.append(frame_idx)
+
+        for frame_idx in features_to_remove:
+            del state['cached_features'][frame_idx]
+
+        # Periodic GPU memory cleanup
+        if current_frame % 50 == 0:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
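
With keep_frames=10, features for frame i are dropped once current_frame exceeds i + keep_frames. A toy illustration of the retention rule (state built by hand, processor assumed as above):

state = {"cached_features": {i: object() for i in range(20)}}
processor._cleanup_old_features(state, current_frame=19, keep_frames=10)
print(sorted(state["cached_features"]))  # -> [9, 10, ..., 19]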
     def propagate_frame_pair(self,
                              left_state: Dict[str, Any],

View File

@@ -102,14 +102,15 @@ class VR180StreamingProcessor:
         self.initialize()
         self.start_time = time.time()
 
-        # Initialize SAM2 states for both eyes
+        # Initialize SAM2 states for both eyes (streaming mode - no video loading)
         print("🎯 Initializing SAM2 streaming states...")
+        video_info = self.frame_reader.get_video_info()
         left_state = self.sam2_processor.init_state(
             self.config.input.video_path,
+            video_info,
             eye='left'
         )
         right_state = self.sam2_processor.init_state(
             self.config.input.video_path,
+            video_info,
             eye='right'
         )
@@ -158,19 +159,19 @@ class VR180StreamingProcessor:
             print(f" Detected {len(detections)} person(s) in first frame")
 
-        # Add detections to both eyes
-        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
+        # Add detections to both eyes (streaming - pass frame data)
+        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
 
         # Transfer detections to slave eye
         transferred_detections = self.stereo_manager.transfer_detections(
             detections,
             'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
         )
-        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)
+        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
 
         # Process and write first frame
-        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
+        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
+        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
 
         # Apply masks and write
         processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
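
_apply_masks_to_frame is not shown in this diff; one plausible recombination for a side-by-side VR180 layout (an assumption, not necessarily the project's implementation):

import numpy as np

def apply_masks_side_by_side(frame: np.ndarray, left_mask: np.ndarray,
                             right_mask: np.ndarray) -> np.ndarray:
    """Keep only masked pixels in each eye half (illustrative only)."""
    full_mask = np.concatenate([left_mask, right_mask], axis=1)  # (h, w), values 0/1
    return frame * full_mask[..., None]  # broadcast over the RGB channels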
@@ -195,10 +196,9 @@ class VR180StreamingProcessor:
             # Split into eyes
             left_eye, right_eye = self.stereo_manager.split_frame(frame)
 
-            # Propagate masks for both eyes
-            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
-                left_state, right_state, left_eye, right_eye, frame_idx
-            )
+            # Propagate masks for both eyes (streaming approach)
+            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
+            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
 
             # Validate stereo consistency
             right_masks = self.stereo_manager.validate_masks(