add Colab support to the notebooks; pack config files in sam2_configs package during installation (#176)

Ronghang Hu
2024-08-08 11:03:22 -07:00
committed by GitHub
parent 6186d1529a
commit d421e0b040
6 changed files with 286 additions and 114 deletions
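The config-packaging half of the commit title is not visible in the notebook diff below. As a rough illustration only: bundling the `sam2_configs` YAML files into the installed package is typically done with setuptools `package_data`; the sketch below is an assumption about that mechanism, not the commit's actual `setup.py` change.

```python
# Illustrative sketch only -- NOT the literal setup.py diff from this commit.
# Shows one common way to ship the sam2_configs/*.yaml files inside the installed
# package so that Hydra config names (e.g. "sam2_hiera_l.yaml") still resolve after
# `pip install 'git+https://github.com/facebookresearch/segment-anything-2.git'`.
from setuptools import find_packages, setup

setup(
    name="segment-anything-2",                        # placeholder metadata
    version="1.0",
    packages=find_packages(exclude=("notebooks",)),   # picks up sam2/ and sam2_configs/
    package_data={"sam2_configs": ["*.yaml"]},        # include the config files themselves
)
```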

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

notebooks/video_predictor_example.ipynb

@@ -29,14 +29,83 @@
"- propagating clicks (or box) to get _masklets_ throughout the video\n",
"- segmenting and tracking multiple objects at the same time\n",
"\n",
"We use the terms _segment_ or _mask_ to refer to the model prediction for an object on a single frame, and _masklet_ to refer to the spatio-temporal masks across the entire video. \n",
"We use the terms _segment_ or _mask_ to refer to the model prediction for an object on a single frame, and _masklet_ to refer to the spatio-temporal masks across the entire video. "
]
},
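As context for the cell above (not part of the diff): a hedged sketch of the click-to-masklet workflow this notebook walks through, using the video predictor API as it exists in the repo around this commit. The frame directory, click coordinates, and config/checkpoint paths are placeholders.

```python
# Hedged sketch of the workflow described above; paths and coordinates are placeholders.
# Assumes a CUDA GPU and that segment-anything-2 is installed (see the setup cells below).
import numpy as np
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "../checkpoints/sam2_hiera_large.pt"
)
inference_state = predictor.init_state(video_path="./videos/bedroom")  # directory of JPEG frames

# One positive click (label 1) on frame 0 defines object 1 there: a per-frame "segment"/"mask".
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
    inference_state=inference_state, frame_idx=0, obj_id=1, points=points, labels=labels
)

# Propagating the prompt across the video yields the spatio-temporal "masklet".
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, obj_id in enumerate(out_obj_ids)
    }
```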
{
"cell_type": "markdown",
"id": "a887b90f-6576-4ef8-964e-76d3a156ccb6",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"id": "26616201-06df-435b-98fd-ad17c373bb4a",
"metadata": {},
"source": [
"## Environment Set-up"
]
},
{
"cell_type": "markdown",
"id": "8491a127-4c01-48f5-9dc5-f148a9417fdf",
"metadata": {},
"source": [
"If running locally using jupyter, first install `segment-anything-2` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything-2#installation) in the repository.\n",
"\n",
"If running locally using jupyter, first install `segment-anything-2` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything-2#installation) in the repository."
"If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'. Note that it's recommended to use **A100 or L4 GPUs when running in Colab** (T4 GPUs might also work, but could be slow and might run out of memory in some cases)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f74c53be-aab1-46b9-8c0b-068b52ef5948",
"metadata": {},
"outputs": [],
"source": [
"using_colab = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d824a4b2-71f3-4da3-bfc7-3249625e6730",
"metadata": {},
"outputs": [],
"source": [
"if using_colab:\n",
" import torch\n",
" import torchvision\n",
" print(\"PyTorch version:\", torch.__version__)\n",
" print(\"Torchvision version:\", torchvision.__version__)\n",
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
" import sys\n",
" !{sys.executable} -m pip install opencv-python matplotlib\n",
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything-2.git'\n",
"\n",
" !mkdir -p videos\n",
" !wget -P videos https://dl.fbaipublicfiles.com/segment_anything_2/assets/bedroom.zip\n",
" !unzip -d videos videos/bedroom.zip\n",
"\n",
" !mkdir -p ../checkpoints/\n",
" !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
]
},
{
"cell_type": "markdown",
"id": "22e6aa9d-487f-4207-b657-8cff0902343e",
"metadata": {},
"source": [
"## Set-up"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e5318a85-5bf7-4880-b2b3-15e4db24d796",
"metadata": {},
"outputs": [],
@@ -50,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "08ba49d8-8c22-4eba-a2ab-46eee839287f",
"metadata": {},
"outputs": [],
@@ -74,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "f5f3245e-b4d6-418b-a42a-a67e0b3b5aec",
"metadata": {},
"outputs": [],
@@ -89,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "1a5320fe-06d7-45b8-b888-ae00799d07fa",
"metadata": {},
"outputs": [],
@@ -143,17 +212,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "b94c87ca-fd1a-4011-9609-e8be1cbe3230",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f884825eef0>"
"<matplotlib.image.AxesImage at 0x7fdeec360250>"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
@@ -206,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "8967aed3-eb82-4866-b8df-0f4743255c2c",
"metadata": {},
"outputs": [
@@ -214,7 +283,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"frame loading (JPEG): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 33.78it/s]\n"
"frame loading (JPEG): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 35.92it/s]\n"
]
}
],
@@ -242,7 +311,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "d2646a1d-3401-438c-a653-55e0e56b7d9d",
"metadata": {},
"outputs": [],
@@ -272,7 +341,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"id": "3e749bab-0f36-4173-bf8d-0c20cd5214b3",
"metadata": {},
"outputs": [
@@ -333,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"id": "e1ab3ec7-2537-4158-bf98-3d0977d8908d",
"metadata": {},
"outputs": [
@@ -399,7 +468,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"id": "ab45e932-b0d5-4983-9718-6ee77d1ac31b",
"metadata": {},
"outputs": [
@@ -407,7 +476,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 23.85it/s]\n"
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 23.76it/s]\n"
]
},
{
@@ -591,7 +660,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "1a572ea9-5b7e-479c-b30c-93c38b121131",
"metadata": {},
"outputs": [
@@ -664,7 +733,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "baa96690-4a38-4a24-aa17-fd2f4db0e232",
"metadata": {},
"outputs": [
@@ -672,7 +741,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 23.94it/s]\n"
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 23.93it/s]\n"
]
},
{
@@ -862,7 +931,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "6dbe9183-abbb-4283-b0cb-d24f3d7beb34",
"metadata": {},
"outputs": [],
@@ -882,7 +951,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"id": "1cbfb273-4e14-495b-bd89-87a8baf52ae7",
"metadata": {},
"outputs": [
@@ -932,7 +1001,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "54906315-ab4c-4088-b866-4c22134d5b66",
"metadata": {},
"outputs": [
@@ -986,7 +1055,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "9cd90557-a0dc-442e-b091-9c74c831bef8",
"metadata": {},
"outputs": [
@@ -994,7 +1063,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 24.05it/s]\n"
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:08<00:00, 23.71it/s]\n"
]
},
{
@@ -1158,6 +1227,14 @@
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
]
},
{
"cell_type": "markdown",
"id": "e023f91f-0cc5-4980-ae8e-a13c5749112b",
"metadata": {},
"source": [
"Note that in addition to clicks or boxes, SAM 2 also supports directly using a **mask prompt** as input via the `add_new_mask` method in the `SAM2VideoPredictor` class. This can be helpful in e.g. semi-supervised VOS evaluations (see [tools/vos_inference.py](https://github.com/facebookresearch/segment-anything-2/blob/main/tools/vos_inference.py) for an example)."
]
},
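To make the mask-prompt note above concrete, here is a hedged sketch of calling `add_new_mask`. The mask file, frame index, and object id are placeholders, and `predictor`/`inference_state` are assumed to have been created as in the earlier cells; see tools/vos_inference.py in the repo for the reference usage.

```python
# Hedged sketch: prompting SAM 2 with a binary mask instead of clicks or a box.
# `predictor` and `inference_state` come from the earlier setup cells; the
# mask path, frame index, and object id below are placeholders.
import numpy as np
from PIL import Image

# A 2-D boolean array (H, W) marking the object on the prompted frame.
mask_prompt = np.array(Image.open("first_frame_object_mask.png")) > 0

_, out_obj_ids, out_mask_logits = predictor.add_new_mask(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    mask=mask_prompt,
)

# After that, propagation works exactly as with click or box prompts.
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    pass  # collect or visualize the per-frame masks here
```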
{
"cell_type": "markdown",
"id": "da018be8-a4ae-4943-b1ff-702c2b89cb68",
@@ -1176,7 +1253,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"id": "29b874c8-9f39-42d3-a667-54a0bd696410",
"metadata": {},
"outputs": [],
@@ -1204,7 +1281,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 21,
"id": "e22d896d-3cd5-4fa0-9230-f33e217035dc",
"metadata": {},
"outputs": [],
@@ -1224,7 +1301,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 22,
"id": "d13432fc-f467-44d8-adfe-3e0c488046b7",
"metadata": {},
"outputs": [
@@ -1276,7 +1353,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "95ecf61d-662b-4f98-ae62-46557b219842",
"metadata": {},
"outputs": [
@@ -1334,7 +1411,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 24,
"id": "86ca1bde-62a4-40e6-98e4-15606441e52f",
"metadata": {},
"outputs": [
@@ -1407,7 +1484,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 25,
"id": "17737191-d62b-4611-b2c6-6d0418a9ab74",
"metadata": {},
"outputs": [
@@ -1415,7 +1492,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:10<00:00, 19.93it/s]\n"
"propagate in video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:10<00:00, 19.77it/s]\n"
]
},
{