Compare commits

...

2 Commits

Author SHA1 Message Date
ee80ed28b6 add stuff true streaming 2025-07-27 08:54:19 -07:00
b5eae7b41d pytorch shit 2025-07-27 08:40:59 -07:00
3 changed files with 156 additions and 45 deletions

View File

@@ -83,6 +83,10 @@ else
     cd ..
 fi
 
+# Fix PyTorch version conflicts after SAM2 installation
+print_status "Fixing PyTorch version conflicts..."
+pip install torchaudio --upgrade --no-deps || print_error "Failed to upgrade torchaudio"
+
 # Download SAM2 checkpoints
 print_status "Downloading SAM2 checkpoints..."
 cd segment-anything-2/checkpoints
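
Reviewer note: pip install --no-deps skips dependency resolution, so the upgraded torchaudio is never checked against the installed torch. A small post-install sanity check along these lines may be worth adding (a sketch, not part of this change; the "matching major.minor" rule assumes the usual torch/torchaudio release pairing):

# check_torch_versions.py - hypothetical post-install check, not in this diff.
# Confirms torch and torchaudio report compatible release versions, since the
# --no-deps upgrade above bypasses pip's resolver.
import torch
import torchaudio

torch_ver = torch.__version__.split("+")[0]       # strip local tags like "+cu121"
audio_ver = torchaudio.__version__.split("+")[0]
print(f"torch {torch_ver} / torchaudio {audio_ver}")
if torch_ver.split(".")[:2] != audio_ver.split(".")[:2]:
    raise SystemExit("torch/torchaudio minor versions do not match")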

View File

@@ -18,6 +18,7 @@ from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Generator
 import warnings
 import gc
+from .timeout_init import safe_init_state, TimeoutError
 
 # Import SAM2 components - these will be available after SAM2 installation
 try:
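
Reviewer note: the new .timeout_init module is not included in this compare, so only its exports are visible here. As a rough sketch of what a wrapper with this interface could look like (an assumption, not the actual module; signal-based, so Unix-only):

# timeout_init.py - hypothetical sketch only; the real module is not shown in this compare.
import signal
from typing import Any, Callable


class TimeoutError(Exception):
    """Raised when state initialization exceeds the allowed time."""


def safe_init_state(init_fn: Callable[..., Any], *args, timeout_s: int = 120, **kwargs) -> Any:
    """Run init_fn(*args, **kwargs), aborting with TimeoutError after timeout_s seconds."""
    def _handler(signum, frame):
        raise TimeoutError(f"init_state did not finish within {timeout_s}s")

    old_handler = signal.signal(signal.SIGALRM, _handler)
    signal.alarm(timeout_s)
    try:
        return init_fn(*args, **kwargs)
    finally:
        signal.alarm(0)                     # cancel the pending alarm
        signal.signal(signal.SIGALRM, old_handler)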
@@ -92,41 +93,85 @@ class SAM2StreamingProcessor:
             raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
 
     def init_state(self,
-                   video_path: str,
+                   video_info: Dict[str, Any],
                    eye: str = 'full') -> Dict[str, Any]:
         """
-        Initialize inference state for streaming
+        Initialize inference state for streaming (NO VIDEO LOADING)
 
         Args:
-            video_path: Path to video file
+            video_info: Video metadata dict with width, height, frame_count
             eye: Eye identifier ('left', 'right', or 'full')
 
         Returns:
             Inference state dictionary
         """
-        # Initialize state with memory offloading enabled
-        with torch.inference_mode():
-            state = self.predictor.init_state(
-                video_path=video_path,
-                offload_video_to_cpu=self.memory_offload,
-                offload_state_to_cpu=self.memory_offload,
-                async_loading_frames=False  # We'll provide frames directly
-            )
+        print(f"  Initializing streaming state for {eye} eye...")
+
+        # Monitor memory before initialization
+        if torch.cuda.is_available():
+            before_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"  📊 GPU memory before init: {before_mem:.1f}GB")
+
+        # Create streaming state WITHOUT loading video frames
+        state = self._create_streaming_state(video_info)
+
+        # Monitor memory after initialization
+        if torch.cuda.is_available():
+            after_mem = torch.cuda.memory_allocated() / 1e9
+            print(f"  📊 GPU memory after init: {after_mem:.1f}GB (+{after_mem-before_mem:.1f}GB)")
 
         self.states[eye] = state
-        print(f"  Initialized state for {eye} eye")
+        print(f"  ✅ Streaming state initialized for {eye} eye")
 
         return state
 
+    def _create_streaming_state(self, video_info: Dict[str, Any]) -> Dict[str, Any]:
+        """Create streaming state for frame-by-frame processing"""
+        # Create a streaming-compatible inference state
+        # This mirrors SAM2's internal state structure but without video frames
+        with torch.inference_mode():
+            # Initialize empty inference state using SAM2's predictor
+            # We'll manually provide frames via propagate calls
+            inference_state = {
+                'point_inputs_per_obj': {},
+                'mask_inputs_per_obj': {},
+                'cached_features': {},
+                'constants': {},
+                'obj_id_to_idx': {},
+                'obj_idx_to_id': {},
+                'obj_ids': [],
+                'click_inputs_per_obj': {},
+                'temp_output_dict_per_obj': {},
+                'consolidated_frame_inds': {},
+                'tracking_has_started': False,
+                'num_frames': video_info['frame_count'],
+                'video_height': video_info['height'],
+                'video_width': video_info['width'],
+                'device': self.device,
+                'storage_device': torch.device('cpu') if self.memory_offload else self.device,
+                'offload_video_to_cpu': self.memory_offload,
+                'offload_state_to_cpu': self.memory_offload,
+                'inference_state': {},
+            }
+
+        # Initialize SAM2 constants that don't depend on video frames
+        self.predictor._get_image_feature_cache = {}
+        self.predictor._feature_bank = {}
+
+        return inference_state
+
     def add_detections(self,
                        state: Dict[str, Any],
+                       frame: np.ndarray,
                        detections: List[Dict[str, Any]],
                        frame_idx: int = 0) -> List[int]:
         """
-        Add detection boxes as prompts to SAM2
+        Add detection boxes as prompts to SAM2 with frame data
 
         Args:
             state: Inference state
+            frame: Frame image (RGB numpy array)
            detections: List of detections with 'box' key
            frame_idx: Frame index to add prompts
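
Reviewer note: for clarity, the call shape of the reworked init_state looks roughly like the snippet below. The resolution and frame count are made-up values, and sam2_processor stands in for an existing SAM2StreamingProcessor instance; only the keys come from the docstring above.

# Illustrative only - values and the sam2_processor name are placeholders.
video_info = {"width": 5760, "height": 2880, "frame_count": 14400}
left_state = sam2_processor.init_state(video_info, eye="left")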
@@ -137,6 +182,12 @@ class SAM2StreamingProcessor:
             warnings.warn(f"No detections to add at frame {frame_idx}")
             return []
 
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         # Convert detections to SAM2 format
         boxes = []
         for det in detections:
@@ -145,9 +196,16 @@
         boxes_tensor = torch.tensor(boxes, dtype=torch.float32, device=self.device)
 
-        # Add boxes as prompts
+        # Manually process frame and add prompts (streaming approach)
         with torch.inference_mode():
-            _, object_ids, _ = self.predictor.add_new_points_or_box(
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Add boxes as prompts for this specific frame
+            _, object_ids, masks = self.predictor.add_new_points_or_box(
                 inference_state=state,
                 frame_idx=frame_idx,
                 obj_id=0,  # SAM2 will auto-increment
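
Reviewer note: the detector producing these inputs is not part of this compare. Based on the docstring ("detections with 'box' key") and SAM2's box prompts, an entry presumably looks like the sketch below; the pixel coordinates and the score field are illustrative, and the (x1, y1, x2, y2) ordering is an assumption about the project's convention. state, frame_rgb, and sam2_processor are placeholders for existing objects.

# Hypothetical detection entry consumed by add_detections.
detections = [
    {"box": [412.0, 180.0, 896.0, 1040.0], "score": 0.91},  # one person box, xyxy pixels
]
object_ids = sam2_processor.add_detections(state, frame_rgb, detections, frame_idx=0)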
@@ -159,29 +217,78 @@ class SAM2StreamingProcessor:
         return object_ids
 
-    def propagate_in_video_simple(self,
-                                  state: Dict[str, Any]) -> Generator[Tuple[int, List[int], np.ndarray], None, None]:
+    def propagate_single_frame(self,
+                               state: Dict[str, Any],
+                               frame: np.ndarray,
+                               frame_idx: int) -> np.ndarray:
         """
-        Simple propagation for single eye processing
+        Propagate masks for a single frame (true streaming)
 
-        Yields:
-            (frame_idx, object_ids, masks) tuples
+        Args:
+            state: Inference state
+            frame: Frame image (RGB numpy array)
+            frame_idx: Frame index
+
+        Returns:
+            Combined mask for all objects
         """
+        # Convert frame to tensor
+        frame_tensor = torch.tensor(frame, dtype=torch.float32, device=self.device)
+        if frame_tensor.ndim == 3:
+            frame_tensor = frame_tensor.permute(2, 0, 1)  # HWC -> CHW
+        frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
+
         with torch.inference_mode():
-            for frame_idx, object_ids, masks in self.predictor.propagate_in_video(state):
-                # Convert masks to numpy
-                if isinstance(masks, torch.Tensor):
-                    masks_np = masks.cpu().numpy()
-                else:
-                    masks_np = masks
-
-                yield frame_idx, object_ids, masks_np
-
-                # Periodic memory cleanup
-                if frame_idx % 100 == 0:
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-                    gc.collect()
+            # Process frame through SAM2's image encoder
+            features = self.predictor._get_image_features(frame_tensor)
+
+            # Store features in state for this frame
+            state['cached_features'][frame_idx] = features
+
+            # Get masks for current frame by propagating from previous frames
+            masks = []
+            for obj_id in state.get('obj_ids', []):
+                # Use SAM2's mask propagation for this object
+                try:
+                    obj_mask = self.predictor._propagate_single_object(
+                        state, obj_id, frame_idx, features
+                    )
+                    if obj_mask is not None:
+                        masks.append(obj_mask)
+                except Exception as e:
+                    # If propagation fails, use empty mask
+                    print(f"  Warning: Propagation failed for object {obj_id}: {e}")
+                    empty_mask = torch.zeros((frame.shape[0], frame.shape[1]), device=self.device)
+                    masks.append(empty_mask)
+
+            # Combine all object masks
+            if masks:
+                combined_mask = torch.stack(masks).max(dim=0)[0]
+                # Convert to numpy
+                combined_mask_np = combined_mask.cpu().numpy().astype(np.uint8)
+            else:
+                combined_mask_np = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
+
+        # Cleanup old features to prevent memory accumulation
+        self._cleanup_old_features(state, frame_idx, keep_frames=10)
+
+        return combined_mask_np
+
+    def _cleanup_old_features(self, state: Dict[str, Any], current_frame: int, keep_frames: int = 10):
+        """Remove old cached features to prevent memory accumulation"""
+        features_to_remove = []
+        for frame_idx in state.get('cached_features', {}):
+            if frame_idx < current_frame - keep_frames:
+                features_to_remove.append(frame_idx)
+
+        for frame_idx in features_to_remove:
+            del state['cached_features'][frame_idx]
+
+        # Periodic GPU memory cleanup
+        if current_frame % 50 == 0:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
 
     def propagate_frame_pair(self,
                              left_state: Dict[str, Any],
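
Reviewer note: taken together, these methods replace the old propagate_in_video generator with an explicit per-frame loop driven by the caller. A minimal single-eye sketch of that loop, under the assumption that frame_reader is iterable and detector.detect returns the detection dicts used above (processor, frame_reader, and detector are placeholder names, not code from this compare):

# Minimal single-eye streaming loop built on the new API in this diff.
video_info = frame_reader.get_video_info()            # {'width', 'height', 'frame_count'}
state = processor.init_state(video_info, eye="full")

for frame_idx, frame in enumerate(frame_reader):       # frame: RGB numpy array (H, W, 3)
    if frame_idx == 0:
        detections = detector.detect(frame)            # person boxes for the first frame
        processor.add_detections(state, frame, detections, frame_idx=0)
    mask = processor.propagate_single_frame(state, frame, frame_idx)
    # ... composite mask over frame and hand it to the writer ...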

View File

@@ -102,14 +102,15 @@ class VR180StreamingProcessor:
         self.initialize()
 
         self.start_time = time.time()
 
-        # Initialize SAM2 states for both eyes
+        # Initialize SAM2 states for both eyes (streaming mode - no video loading)
         print("🎯 Initializing SAM2 streaming states...")
+        video_info = self.frame_reader.get_video_info()
         left_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
             eye='left'
         )
         right_state = self.sam2_processor.init_state(
-            self.config.input.video_path,
+            video_info,
             eye='right'
         )
@@ -158,19 +159,19 @@
         print(f"  Detected {len(detections)} person(s) in first frame")
 
-        # Add detections to both eyes
-        self.sam2_processor.add_detections(left_state, detections, frame_idx=0)
+        # Add detections to both eyes (streaming - pass frame data)
+        self.sam2_processor.add_detections(left_state, left_eye, detections, frame_idx=0)
 
         # Transfer detections to slave eye
         transferred_detections = self.stereo_manager.transfer_detections(
             detections,
             'left_to_right' if self.stereo_manager.master_eye == 'left' else 'right_to_left'
         )
-        self.sam2_processor.add_detections(right_state, transferred_detections, frame_idx=0)
+        self.sam2_processor.add_detections(right_state, right_eye, transferred_detections, frame_idx=0)
 
         # Process and write first frame
-        left_masks = self.sam2_processor._propagate_single_frame(left_state, left_eye, 0)
-        right_masks = self.sam2_processor._propagate_single_frame(right_state, right_eye, 0)
+        left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, 0)
+        right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, 0)
 
         # Apply masks and write
         processed_frame = self._apply_masks_to_frame(first_frame, left_masks, right_masks)
@@ -195,10 +196,9 @@
             # Split into eyes
             left_eye, right_eye = self.stereo_manager.split_frame(frame)
 
-            # Propagate masks for both eyes
-            left_masks, right_masks = self.sam2_processor.propagate_frame_pair(
-                left_state, right_state, left_eye, right_eye, frame_idx
-            )
+            # Propagate masks for both eyes (streaming approach)
+            left_masks = self.sam2_processor.propagate_single_frame(left_state, left_eye, frame_idx)
+            right_masks = self.sam2_processor.propagate_single_frame(right_state, right_eye, frame_idx)
 
             # Validate stereo consistency
             right_masks = self.stereo_manager.validate_masks(