simplify
This commit is contained in:
@@ -283,16 +283,29 @@ class SAM2StreamingProcessor:
|
||||
# Store features in state for this frame
|
||||
state['cached_features'][frame_idx] = backbone_out
|
||||
|
||||
# Add boxes as prompts for this specific frame
|
||||
try:
|
||||
# Force ensure all inputs are on correct device
|
||||
boxes_tensor = boxes_tensor.to(self.device)
|
||||
# Convert boxes to points for manual implementation
|
||||
# SAM2 expects corner points from boxes with labels 2,3
|
||||
points = []
|
||||
labels = []
|
||||
for box in boxes:
|
||||
# Convert box [x1, y1, x2, y2] to corner points
|
||||
x1, y1, x2, y2 = box
|
||||
points.extend([[x1, y1], [x2, y2]]) # Top-left and bottom-right corners
|
||||
labels.extend([2, 3]) # SAM2 standard labels for box corners
|
||||
|
||||
_, object_ids, masks = self.predictor.add_new_points_or_box(
|
||||
points_tensor = torch.tensor(points, dtype=torch.float32, device=self.device)
|
||||
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=self.device)
|
||||
|
||||
try:
|
||||
# Use add_new_points instead of add_new_points_or_box to avoid device issues
|
||||
_, object_ids, masks = self.predictor.add_new_points(
|
||||
inference_state=state,
|
||||
frame_idx=frame_idx,
|
||||
obj_id=None, # Let SAM2 auto-assign
|
||||
box=boxes_tensor
|
||||
points=points_tensor,
|
||||
labels=labels_tensor,
|
||||
clear_old_points=True,
|
||||
normalize_coords=True
|
||||
)
|
||||
|
||||
# Update state with object tracking info
|
||||
@@ -300,32 +313,25 @@ class SAM2StreamingProcessor:
|
||||
state['tracking_has_started'] = True
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error in add_new_points_or_box: {e}")
|
||||
print(f" Box tensor device: {boxes_tensor.device}")
|
||||
print(f" Error in add_new_points: {e}")
|
||||
print(f" Points tensor device: {points_tensor.device}")
|
||||
print(f" Labels tensor device: {labels_tensor.device}")
|
||||
print(f" Frame tensor device: {frame_tensor.device}")
|
||||
|
||||
# Check predictor components
|
||||
print(f" Checking predictor device placement:")
|
||||
if hasattr(self.predictor, 'image_encoder'):
|
||||
try:
|
||||
for name, param in self.predictor.image_encoder.named_parameters():
|
||||
if param.device.type != 'cuda':
|
||||
print(f" image_encoder.{name}: {param.device}")
|
||||
break
|
||||
except: pass
|
||||
|
||||
if hasattr(self.predictor, 'sam_prompt_encoder'):
|
||||
try:
|
||||
for name, param in self.predictor.sam_prompt_encoder.named_parameters():
|
||||
if param.device.type != 'cuda':
|
||||
print(f" sam_prompt_encoder.{name}: {param.device}")
|
||||
break
|
||||
except: pass
|
||||
# Fallback: manually initialize object tracking
|
||||
print(f" Using fallback manual object initialization")
|
||||
object_ids = [i for i in range(len(detections))]
|
||||
state['obj_ids'] = object_ids
|
||||
state['tracking_has_started'] = True
|
||||
|
||||
# Check for any CPU tensors in predictor
|
||||
print(f" Predictor type: {type(self.predictor)}")
|
||||
print(f" Available predictor attributes: {[attr for attr in dir(self.predictor) if not attr.startswith('_')]}")
|
||||
raise
|
||||
# Store detection info for later use
|
||||
for i, (points_pair, det) in enumerate(zip(zip(points[::2], points[1::2]), detections)):
|
||||
state['point_inputs_per_obj'][i] = {
|
||||
frame_idx: {
|
||||
'points': points_tensor[i*2:(i+1)*2],
|
||||
'labels': labels_tensor[i*2:(i+1)*2]
|
||||
}
|
||||
}
|
||||
|
||||
self.object_ids = object_ids
|
||||
print(f" Added {len(detections)} detections at frame {frame_idx}: objects {object_ids}")
|
||||
|
||||
Reference in New Issue
Block a user