This commit is contained in:
2025-07-27 09:52:56 -07:00
parent 43be574729
commit 66895a87a0
3 changed files with 298 additions and 66 deletions

View File

@@ -283,16 +283,29 @@ class SAM2StreamingProcessor:
# Store features in state for this frame
state['cached_features'][frame_idx] = backbone_out
# Add boxes as prompts for this specific frame
try:
# Force ensure all inputs are on correct device
boxes_tensor = boxes_tensor.to(self.device)
# Convert boxes to points for manual implementation
# SAM2 expects corner points from boxes with labels 2,3
points = []
labels = []
for box in boxes:
# Convert box [x1, y1, x2, y2] to corner points
x1, y1, x2, y2 = box
points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners
labels.extend([2, 3]) # SAM2 standard labels for box corners
_, object_ids, masks = self.predictor.add_new_points_or_box(
points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
try:
# Use add_new_points instead of add_new_points_or_box to avoid device issues
_, object_ids, masks = self.predictor.add_new_points(
inference_state=state,
frame_idx=frame_idx,
obj_id=None, # Let SAM2 auto-assign
box=boxes_tensor
points=points_tensor,
labels=labels_tensor,
clear_old_points=True,
normalize_coords=True
)
# Update state with object tracking info
@@ -300,32 +313,25 @@ class SAM2StreamingProcessor:
state['tracking_has_started'] = True
except Exception as e:
print(f" Error in add_new_points_or_box: {e}")
print(f" Box tensor device: {boxes_tensor.device}")
print(f" Error in add_new_points: {e}")
print(f" Points tensor device: {points_tensor.device}")
print(f" Labels tensor device: {labels_tensor.device}")
print(f" Frame tensor device: {frame_tensor.device}")
# Check predictor components
print(f" Checking predictor device placement:")
if hasattr(self.predictor, 'image_encoder'):
try:
for name, param in self.predictor.image_encoder.named_parameters():
if param.device.type != 'cuda':
print(f" image_encoder.{name}: {param.device}")
break
except: pass
if hasattr(self.predictor, 'sam_prompt_encoder'):
try:
for name, param in self.predictor.sam_prompt_encoder.named_parameters():
if param.device.type != 'cuda':
print(f" sam_prompt_encoder.{name}: {param.device}")
break
except: pass
# Fallback: manually initialize object tracking
print(f" Using fallback manual object initialization")
object_ids = [i for i in range(len(detections))]
state['obj_ids'] = object_ids
state['tracking_has_started'] = True
# Check for any CPU tensors in predictor
print(f" Predictor type: {type(self.predictor)}")
print(f" Available predictor attributes: {[attr for attr in dir(self.predictor) if not attr.startswith('_')]}")
raise
# Store detection info for later use
for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
state['point_inputs_per_obj'][i] = {
frame_idx: {
'points': points_tensor[i*2:(i+1)*2],
'labels': labels_tensor[i*2:(i+1)*2]
}
}
self.object_ids = object_ids
print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")