diff --git a/vr180_streaming/sam2_streaming_simple.py b/vr180_streaming/sam2_streaming_simple.py index 60f1e30..a8402e4 100644 --- a/vr180_streaming/sam2_streaming_simple.py +++ b/vr180_streaming/sam2_streaming_simple.py @@ -28,25 +28,17 @@ class SAM2StreamingProcessor: self.device = torch.device(config.get('hardware', {}).get('device', 'cuda')) # SAM2 model configuration - model_cfg_name = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l') + model_cfg = config.get('matting', {}).get('sam2_model_cfg', 'sam2.1_hiera_l') checkpoint = config.get('matting', {}).get('sam2_checkpoint', 'segment-anything-2/checkpoints/sam2.1_hiera_large.pt') - # Map config name to full path - config_mapping = { - 'sam2.1_hiera_t': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_t.yaml', - 'sam2.1_hiera_s': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_s.yaml', - 'sam2.1_hiera_b+': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml', - 'sam2.1_hiera_l': 'segment-anything-2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml', - } - - model_cfg = config_mapping.get(model_cfg_name, model_cfg_name) - # Build predictor (simple, clean approach) + # Note: SAM2 uses Hydra to find configs automatically in sam2/configs/ self.predictor = build_sam2_video_predictor( - model_cfg, + model_cfg, # Just the config name, Hydra finds it automatically checkpoint, - device=self.device + device=self.device, + vos_optimized=True # Enable VOS optimizations for speed ) # Frame buffer for streaming (like det-sam2)