This commit is contained in:
2025-07-27 09:26:47 -07:00
parent 9b7f36fec2
commit 43be574729

View File

@@ -81,14 +81,31 @@ class SAM2StreamingProcessor:
vos_optimized=True # Enable full model compilation for speed
)
# Set to eval mode
# Set to eval mode and ensure all model components are on GPU
self.predictor.eval()
# Force all predictor components to GPU
self.predictor = self.predictor.to(self.device)
# Force move all internal components that might be on CPU
if hasattr(self.predictor, 'image_encoder'):
self.predictor.image_encoder = self.predictor.image_encoder.to(self.device)
if hasattr(self.predictor, 'memory_attention'):
self.predictor.memory_attention = self.predictor.memory_attention.to(self.device)
if hasattr(self.predictor, 'memory_encoder'):
self.predictor.memory_encoder = self.predictor.memory_encoder.to(self.device)
if hasattr(self.predictor, 'sam_mask_decoder'):
self.predictor.sam_mask_decoder = self.predictor.sam_mask_decoder.to(self.device)
if hasattr(self.predictor, 'sam_prompt_encoder'):
self.predictor.sam_prompt_encoder = self.predictor.sam_prompt_encoder.to(self.device)
# Note: FP16 conversion can cause type mismatches with compiled models
# Let SAM2 handle precision internally via build_sam2_video_predictor options
if self.fp16 and self.device.type == 'cuda':
print(" FP16 enabled via SAM2 internal settings")
print(f" All SAM2 components moved to {self.device}")
except Exception as e:
raise RuntimeError(f"Failed to initialize SAM2 predictor: {e}")
@@ -286,11 +303,28 @@ class SAM2StreamingProcessor:
print(f" Error in add_new_points_or_box: {e}")
print(f" Box tensor device: {boxes_tensor.device}")
print(f" Frame tensor device: {frame_tensor.device}")
print(f" State device keys: {[k for k in state.keys() if 'device' in k.lower()]}")
# Try to inspect state tensor devices
for key, value in state.items():
if isinstance(value, torch.Tensor):
print(f" State[{key}] device: {value.device}")
# Check predictor components
print(f" Checking predictor device placement:")
if hasattr(self.predictor, 'image_encoder'):
try:
for name, param in self.predictor.image_encoder.named_parameters():
if param.device.type != 'cuda':
print(f" image_encoder.{name}: {param.device}")
break
except: pass
if hasattr(self.predictor, 'sam_prompt_encoder'):
try:
for name, param in self.predictor.sam_prompt_encoder.named_parameters():
if param.device.type != 'cuda':
print(f" sam_prompt_encoder.{name}: {param.device}")
break
except: pass
# Check for any CPU tensors in predictor
print(f" Predictor type: {type(self.predictor)}")
print(f" Available predictor attributes: {[attr for attr in dir(self.predictor) if not attr.startswith('_')]}")
raise
self.object_ids = object_ids