first commit
This commit is contained in:
2214
examples/saco_gold_silver_eval_example.ipynb
Normal file
2214
examples/saco_gold_silver_eval_example.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
256
examples/saco_gold_silver_vis_example.ipynb
Normal file
256
examples/saco_gold_silver_vis_example.ipynb
Normal file
@@ -0,0 +1,256 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "37048f21",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "154d8663",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using_colab = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b85d99d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if using_colab:\n",
|
||||
" import torch\n",
|
||||
" import torchvision\n",
|
||||
" print(\"PyTorch version:\", torch.__version__)\n",
|
||||
" print(\"Torchvision version:\", torchvision.__version__)\n",
|
||||
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
|
||||
" import sys\n",
|
||||
" !{sys.executable} -m pip install opencv-python matplotlib scikit-learn\n",
|
||||
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "da21a3bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from glob import glob\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import sam3.visualization_utils as utils\n",
|
||||
"\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"\n",
|
||||
"COLORS = utils.pascal_color_map()[1:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57e85e7e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Load the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a796734e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preapre the data path\n",
|
||||
"ANNOT_DIR = None # PUT YOUR ANNOTATION PATH HERE\n",
|
||||
"IMG_DIR = None # PUT YOUR IMAGE PATH HERE\n",
|
||||
"\n",
|
||||
"# Load the SA-CO/Gold annotation files\n",
|
||||
"annot_file_list = glob(os.path.join(ANNOT_DIR, \"*gold*.json\"))\n",
|
||||
"annot_dfs = utils.get_annot_dfs(file_list=annot_file_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "74bf92b1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Show the annotation files being loaded"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a95620ec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5ce211d3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. Examples of the data format"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ba749db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"gold_fg_sports_equipment_merged_a_release_test\"].keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4b6dc186",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"gold_fg_sports_equipment_merged_a_release_test\"][\"info\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c41091b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"gold_fg_sports_equipment_merged_a_release_test\"][\"images\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a7df5771",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"gold_fg_sports_equipment_merged_a_release_test\"][\"annotations\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5673a63f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"3. Visualize the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1fc2a24",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Select a target dataset\n",
|
||||
"target_dataset_name = \"gold_fg_food_merged_a_release_test\"\n",
|
||||
"\n",
|
||||
"import cv2\n",
|
||||
"from pycocotools import mask as mask_util\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"# Group GT annotations by image_id\n",
|
||||
"gt_image_np_pairs = annot_dfs[target_dataset_name][\"images\"]\n",
|
||||
"gt_annotations = annot_dfs[target_dataset_name][\"annotations\"]\n",
|
||||
"\n",
|
||||
"gt_image_np_map = {img[\"id\"]: img for _, img in gt_image_np_pairs.iterrows()}\n",
|
||||
"gt_image_np_ann_map = defaultdict(list)\n",
|
||||
"for _, ann in gt_annotations.iterrows():\n",
|
||||
" image_id = ann[\"image_id\"]\n",
|
||||
" if image_id not in gt_image_np_ann_map:\n",
|
||||
" gt_image_np_ann_map[image_id] = []\n",
|
||||
" gt_image_np_ann_map[image_id].append(ann)\n",
|
||||
"\n",
|
||||
"positiveNPs = common_image_ids = [img_id for img_id in gt_image_np_map.keys() if img_id in gt_image_np_ann_map and gt_image_np_ann_map[img_id]]\n",
|
||||
"negativeNPs = [img_id for img_id in gt_image_np_map.keys() if img_id not in gt_image_np_ann_map or not gt_image_np_ann_map[img_id]]\n",
|
||||
"\n",
|
||||
"num_image_nps_to_show = 10\n",
|
||||
"fig, axes = plt.subplots(num_image_nps_to_show, 3, figsize=(15, 5 * num_image_nps_to_show))\n",
|
||||
"for idx in range(num_image_nps_to_show):\n",
|
||||
" rand_idx = np.random.randint(len(positiveNPs))\n",
|
||||
" image_id = positiveNPs[rand_idx]\n",
|
||||
" noun_phrase = gt_image_np_map[image_id][\"text_input\"]\n",
|
||||
" img_rel_path = gt_image_np_map[image_id][\"file_name\"]\n",
|
||||
" full_path = os.path.join(IMG_DIR, f\"{img_rel_path}\")\n",
|
||||
" img = cv2.imread(full_path)\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" gt_annotation = gt_image_np_ann_map[image_id]\n",
|
||||
"\n",
|
||||
" def display_image_in_subplot(img, axes, row, col, title=\"\"):\n",
|
||||
" axes[row, col].imshow(img)\n",
|
||||
" axes[row, col].set_title(title)\n",
|
||||
" axes[row, col].axis('off')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" noun_phrases = [noun_phrase]\n",
|
||||
" annot_masks = [mask_util.decode(ann[\"segmentation\"]) for ann in gt_annotation]\n",
|
||||
"\n",
|
||||
" # Show the image\n",
|
||||
" display_image_in_subplot(img, axes, idx, 0, f\"{noun_phrase}\")\n",
|
||||
"\n",
|
||||
" # Show all masks over a white background\n",
|
||||
" all_masks = utils.draw_masks_to_frame(\n",
|
||||
" frame=np.ones_like(img)*255, masks=annot_masks, colors=COLORS[: len(annot_masks)]\n",
|
||||
" )\n",
|
||||
" display_image_in_subplot(all_masks, axes, idx, 1, f\"{noun_phrase} - Masks only\")\n",
|
||||
"\n",
|
||||
" # Show masks overlaid on the image\n",
|
||||
" masked_frame = utils.draw_masks_to_frame(\n",
|
||||
" frame=img, masks=annot_masks, colors=COLORS[: len(annot_masks)]\n",
|
||||
" )\n",
|
||||
" display_image_in_subplot(masked_frame, axes, idx, 2, f\"{noun_phrase} - Masks overlaid\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "84a20e0e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "a2cedcd3-26e1-430d-b718-764d51077f86",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
137
examples/saco_veval_eval_example.ipynb
Normal file
137
examples/saco_veval_eval_example.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e0d2e74",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from sam3.eval.saco_veval_eval import VEvalEvaluator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b31ab5d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATASETS_TO_EVAL = [\n",
|
||||
" \"saco_veval_sav_test\",\n",
|
||||
" \"saco_veval_yt1b_test\",\n",
|
||||
" \"saco_veval_smartglasses_test\",\n",
|
||||
"]\n",
|
||||
"# Update to the directory where the GT annotation and PRED files exist\n",
|
||||
"GT_DIR = None # PUT YOUR ANNOTATION PATH HERE\n",
|
||||
"PRED_DIR = None # PUT YOUR IMAGE PATH HERE"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a602fef",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"all_eval_res = {}\n",
|
||||
"for dataset_name in DATASETS_TO_EVAL:\n",
|
||||
" gt_annot_file = os.path.join(GT_DIR, dataset_name + \".json\")\n",
|
||||
" pred_file = os.path.join(PRED_DIR, dataset_name + \"_preds.json\")\n",
|
||||
" eval_res_file = os.path.join(PRED_DIR, dataset_name + \"_eval_res.json\")\n",
|
||||
"\n",
|
||||
" if os.path.exists(eval_res_file):\n",
|
||||
" with open(eval_res_file, \"r\") as f:\n",
|
||||
" eval_res = json.load(f)\n",
|
||||
" else:\n",
|
||||
" # Alternatively, we can run the evaluator offline first\n",
|
||||
" # by leveraging sam3/eval/saco_veval_eval.py\n",
|
||||
" print(f\"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===\")\n",
|
||||
" veval_evaluator = VEvalEvaluator(\n",
|
||||
" gt_annot_file=gt_annot_file, eval_res_file=eval_res_file\n",
|
||||
" )\n",
|
||||
" eval_res = veval_evaluator.run_eval(pred_file=pred_file)\n",
|
||||
" print(f\"=== Results saved to {eval_res_file} ===\")\n",
|
||||
"\n",
|
||||
" all_eval_res[dataset_name] = eval_res"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a6dbec47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"REPORT_METRICS = {\n",
|
||||
" \"video_mask_demo_cgf1_micro_50_95\": \"cgf1\",\n",
|
||||
" \"video_mask_all_phrase_HOTA\": \"pHOTA\",\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc28d29f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res_to_print = []\n",
|
||||
"for dataset_name in DATASETS_TO_EVAL:\n",
|
||||
" eval_res = all_eval_res[dataset_name]\n",
|
||||
" row = [dataset_name]\n",
|
||||
" for metric_k, metric_v in REPORT_METRICS.items():\n",
|
||||
" row.append(eval_res[\"dataset_results\"][metric_k])\n",
|
||||
" res_to_print.append(row)\n",
|
||||
"\n",
|
||||
"# Print dataset header (each dataset spans 2 metrics: 13 + 3 + 13 = 29 chars)\n",
|
||||
"print(\"| \" + \" | \".join(f\"{ds:^29}\" for ds in DATASETS_TO_EVAL) + \" |\")\n",
|
||||
"\n",
|
||||
"# Print metric header\n",
|
||||
"metrics = list(REPORT_METRICS.values())\n",
|
||||
"print(\"| \" + \" | \".join(f\"{m:^13}\" for _ in DATASETS_TO_EVAL for m in metrics) + \" |\")\n",
|
||||
"\n",
|
||||
"# Print eval results\n",
|
||||
"values = []\n",
|
||||
"for row in res_to_print:\n",
|
||||
" values.extend([f\"{v * 100:^13.1f}\" for v in row[1:]])\n",
|
||||
"print(\"| \" + \" | \".join(values) + \" |\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9976908b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "bdaa3851-85de-435f-9582-efb46951a1d0",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
269
examples/saco_veval_vis_example.ipynb
Normal file
269
examples/saco_veval_vis_example.ipynb
Normal file
@@ -0,0 +1,269 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "37048f21",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "154d8663",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using_colab = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b85d99d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if using_colab:\n",
|
||||
" import torch\n",
|
||||
" import torchvision\n",
|
||||
" print(\"PyTorch version:\", torch.__version__)\n",
|
||||
" print(\"Torchvision version:\", torchvision.__version__)\n",
|
||||
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
|
||||
" import sys\n",
|
||||
" !{sys.executable} -m pip install opencv-python matplotlib scikit-learn\n",
|
||||
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "da21a3bc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from glob import glob\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import utils\n",
|
||||
"\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"\n",
|
||||
"COLORS = utils.pascal_color_map()[1:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57e85e7e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Load the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a796734e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preapre the data path\n",
|
||||
"DATA_DIR = \"./sam3_saco_veval_data\" # PUT YOUR DATA PATH HERE\n",
|
||||
"ANNOT_DIR = os.path.join(DATA_DIR, \"annotation\")\n",
|
||||
"\n",
|
||||
"# Load the SACO/Veval annotation files\n",
|
||||
"annot_file_list = glob(os.path.join(ANNOT_DIR, \"*veval*.json\"))\n",
|
||||
"annot_dfs = utils.get_annot_dfs(file_list=annot_file_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "74bf92b1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Show the annotation files being loaded"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a95620ec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5ce211d3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. Examples of the data format"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ba749db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"].keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4b6dc186",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"][\"info\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c41091b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"][\"videos\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a7df5771",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"][\"annotations\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "24d2861c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"][\"categories\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f9f98f27",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"annot_dfs[\"saco_veval_yt1b_val\"][\"video_np_pairs\"].head(3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5673a63f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"3. Visualize the data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "da827d09",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Select a target dataset\n",
|
||||
"target_dataset_name = \"saco_veval_yt1b_val\"\n",
|
||||
"\n",
|
||||
"# visualize a random positive video-np pair\n",
|
||||
"df_pairs = annot_dfs[target_dataset_name][\"video_np_pairs\"]\n",
|
||||
"df_positive_pairs = df_pairs[df_pairs.num_masklets > 0]\n",
|
||||
"rand_idx = np.random.randint(len(df_positive_pairs))\n",
|
||||
"pair_row = df_positive_pairs.iloc[rand_idx]\n",
|
||||
"video_id = pair_row.video_id\n",
|
||||
"noun_phrase = pair_row.noun_phrase\n",
|
||||
"print(f\"Randomly selected video-np pair: video_id={video_id}, noun_phrase={noun_phrase}\")\n",
|
||||
"\n",
|
||||
"def display_image_in_subplot(img, axes, row, col, title=\"\"):\n",
|
||||
" axes[row, col].imshow(img)\n",
|
||||
" axes[row, col].set_title(title)\n",
|
||||
" axes[row, col].axis('off')\n",
|
||||
"\n",
|
||||
"num_frames_to_show = 5 # Number of frames to show per dataset\n",
|
||||
"every_n_frames = 4 # Interval between frames to show\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(num_frames_to_show, 3, figsize=(15, 5 * num_frames_to_show))\n",
|
||||
"\n",
|
||||
"for idx in range(0, num_frames_to_show):\n",
|
||||
" sampled_frame_idx = idx * every_n_frames\n",
|
||||
" print(f\"Reading annotations for frame {sampled_frame_idx}\")\n",
|
||||
" # Get the frame and the corresponding masks and noun phrases\n",
|
||||
" frame, annot_masks, annot_noun_phrases = utils.get_all_annotations_for_frame(\n",
|
||||
" annot_dfs[target_dataset_name], video_id=video_id, frame_idx=sampled_frame_idx, data_dir=DATA_DIR, dataset=target_dataset_name\n",
|
||||
" )\n",
|
||||
" # Filter masks and noun phrases by the selected noun phrase\n",
|
||||
" annot_masks = [m for m, np in zip(annot_masks, annot_noun_phrases) if np == noun_phrase]\n",
|
||||
"\n",
|
||||
" # Show the frame\n",
|
||||
" display_image_in_subplot(frame, axes, idx, 0, f\"{target_dataset_name} - {noun_phrase} - Frame {sampled_frame_idx}\")\n",
|
||||
"\n",
|
||||
" # Show the annotated masks\n",
|
||||
" if annot_masks is None:\n",
|
||||
" print(f\"No masks found for video_id {video_id} at frame {sampled_frame_idx}\")\n",
|
||||
" else:\n",
|
||||
" # Show all masks over a white background\n",
|
||||
" all_masks = utils.draw_masks_to_frame(\n",
|
||||
" frame=np.ones_like(frame)*255, masks=annot_masks, colors=COLORS[: len(annot_masks)]\n",
|
||||
" )\n",
|
||||
" display_image_in_subplot(all_masks, axes, idx, 1, f\"{target_dataset_name} - {noun_phrase} - Frame {sampled_frame_idx} - Masks\")\n",
|
||||
" \n",
|
||||
" # Show masks overlaid on the frame\n",
|
||||
" masked_frame = utils.draw_masks_to_frame(\n",
|
||||
" frame=frame, masks=annot_masks, colors=COLORS[: len(annot_masks)]\n",
|
||||
" )\n",
|
||||
" display_image_in_subplot(masked_frame, axes, idx, 2, f\"Dataset: {target_dataset_name} - {noun_phrase} - Frame {sampled_frame_idx} - Masks overlaid\")\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a2a23152",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
242
examples/sam3_agent.ipynb
Normal file
242
examples/sam3_agent.ipynb
Normal file
@@ -0,0 +1,242 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SAM 3 Agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook shows an example of how an MLLM can use SAM 3 as a tool, i.e., \"SAM 3 Agent\", to segment more complex text queries such as \"the leftmost child wearing blue vest\"."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Env Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First install `sam3` in your environment using the [installation instructions](https://github.com/facebookresearch/sam3?tab=readme-ov-file#installation) in the repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"# turn on tfloat32 for Ampere GPUs\n",
|
||||
"# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n",
|
||||
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
||||
"torch.backends.cudnn.allow_tf32 = True\n",
|
||||
"\n",
|
||||
"# use bfloat16 for the entire notebook. If your card doesn't support it, try float16 instead\n",
|
||||
"torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n",
|
||||
"\n",
|
||||
"# inference mode for the whole notebook. Disable if you need gradients\n",
|
||||
"torch.inference_mode().__enter__()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"SAM3_ROOT = os.path.dirname(os.getcwd())\n",
|
||||
"os.chdir(SAM3_ROOT)\n",
|
||||
"\n",
|
||||
"# setup GPU to use - A single GPU is good with the purpose of this demo\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||
"_ = os.system(\"nvidia-smi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build SAM3 Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sam3\n",
|
||||
"from sam3 import build_sam3_image_model\n",
|
||||
"from sam3.model.sam3_image_processor import Sam3Processor\n",
|
||||
"\n",
|
||||
"sam3_root = os.path.dirname(sam3.__file__)\n",
|
||||
"bpe_path = f\"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz\"\n",
|
||||
"model = build_sam3_image_model(bpe_path=bpe_path)\n",
|
||||
"processor = Sam3Processor(model, confidence_threshold=0.5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## LLM Setup\n",
|
||||
"\n",
|
||||
"Config which MLLM to use, it can either be a model served by vLLM that you launch from your own machine or a model is served via external API. If you want to using a vLLM model, we also provided insturctions below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"LLM_CONFIGS = {\n",
|
||||
" # vLLM-served models\n",
|
||||
" \"qwen3_vl_8b_thinking\": {\n",
|
||||
" \"provider\": \"vllm\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-VL-8B-Thinking\",\n",
|
||||
" },\n",
|
||||
" # models served via external APIs\n",
|
||||
" # add your own\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"model = \"qwen3_vl_8b_thinking\"\n",
|
||||
"LLM_API_KEY = \"DUMMY_API_KEY\"\n",
|
||||
"\n",
|
||||
"llm_config = LLM_CONFIGS[model]\n",
|
||||
"llm_config[\"api_key\"] = LLM_API_KEY\n",
|
||||
"llm_config[\"name\"] = model\n",
|
||||
"\n",
|
||||
"# setup API endpoint\n",
|
||||
"if llm_config[\"provider\"] == \"vllm\":\n",
|
||||
" LLM_SERVER_URL = \"http://0.0.0.0:8001/v1\" # replace this with your vLLM server address as needed\n",
|
||||
"else:\n",
|
||||
" LLM_SERVER_URL = llm_config[\"base_url\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup vLLM server \n",
|
||||
"This step is only required if you are using a model served by vLLM, skip this step if you are calling LLM using an API like Gemini and GPT.\n",
|
||||
"\n",
|
||||
"* Install vLLM (in a separate conda env from SAM 3 to avoid dependency conflicts).\n",
|
||||
" ```bash\n",
|
||||
" conda create -n vllm python=3.12\n",
|
||||
" pip install vllm --extra-index-url https://download.pytorch.org/whl/cu128\n",
|
||||
" ```\n",
|
||||
"* Start vLLM server on the same machine of this notebook\n",
|
||||
" ```bash\n",
|
||||
" # qwen 3 VL 8B thinking\n",
|
||||
" vllm serve Qwen/Qwen3-VL-8B-Thinking --tensor-parallel-size 4 --allowed-local-media-path / --enforce-eager --port 8001\n",
|
||||
" ```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run SAM3 Agent Inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from functools import partial\n",
|
||||
"from IPython.display import display, Image\n",
|
||||
"from sam3.agent.client_llm import send_generate_request as send_generate_request_orig\n",
|
||||
"from sam3.agent.client_sam3 import call_sam_service as call_sam_service_orig\n",
|
||||
"from sam3.agent.inference import run_single_image_inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"output": {
|
||||
"id": 689664053567678,
|
||||
"loadingStatus": "loaded"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prepare input args and run single image inference\n",
|
||||
"image = \"assets/images/test_image.jpg\"\n",
|
||||
"prompt = \"the leftmost child wearing blue vest\"\n",
|
||||
"image = os.path.abspath(image)\n",
|
||||
"send_generate_request = partial(send_generate_request_orig, server_url=LLM_SERVER_URL, model=llm_config[\"model\"], api_key=llm_config[\"api_key\"])\n",
|
||||
"call_sam_service = partial(call_sam_service_orig, sam3_processor=processor)\n",
|
||||
"output_image_path = run_single_image_inference(\n",
|
||||
" image, prompt, llm_config, send_generate_request, call_sam_service,\n",
|
||||
" debug=True, output_dir=\"agent_output\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# display output\n",
|
||||
"if output_image_path is not None:\n",
|
||||
" display(Image(filename=output_image_path))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"fileHeader": "",
|
||||
"fileUid": "be59e249-6c09-4634-a9e7-1f06fd233c42",
|
||||
"isAdHoc": false,
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
846
examples/sam3_for_sam1_task_example.ipynb
Normal file
846
examples/sam3_for_sam1_task_example.ipynb
Normal file
@@ -0,0 +1,846 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f400486b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a1ae39ff",
|
||||
"metadata": {
|
||||
"jp-MarkdownHeadingCollapsed": true
|
||||
},
|
||||
"source": [
|
||||
"# Interactive Instance Segmentation using SAM 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b4a4b25c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Segment Anything Model 3 (SAM 3) predicts instance masks that indicate the desired object given geometric prompts (SAM 1 task).\n",
|
||||
"The `SAM3Image` and `Sam3Processor` classes provide an easy interface to prompt the model. The user first sets an image using the `Sam3Processor.set_image` method, which computes the necessary image embeddings. Then, prompts can be provided via the `predict` method to efficiently predict masks from those prompts. The model can take as input both point and box prompts, as well as masks from the previous iteration of prediction.\n",
|
||||
"\n",
|
||||
"This notebook follows the SAM 2 API for interactive image segmentation.\n",
|
||||
"\n",
|
||||
"# <a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/sam3/blob/main/notebooks/sam3_for_sam1_task_example.ipynb\">\n",
|
||||
"# <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
|
||||
"# </a>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "644532a8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Environment Set-up"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07fabfee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First install `sam3` in your environment using the [installation instructions](https://github.com/facebookresearch/sam3?tab=readme-ov-file#installation) in the repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0be845da",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set-up"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "33681dd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Necessary imports and helper functions for displaying points, boxes, and masks."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fe773ede",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using_colab = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "79250a4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if using_colab:\n",
|
||||
" import torch\n",
|
||||
" import torchvision\n",
|
||||
" print(\"PyTorch version:\", torch.__version__)\n",
|
||||
" print(\"Torchvision version:\", torchvision.__version__)\n",
|
||||
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
|
||||
" import sys\n",
|
||||
" !{sys.executable} -m pip install opencv-python matplotlib scikit-learn\n",
|
||||
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "69b28288",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"# if using Apple MPS, fall back to CPU for unsupported ops\n",
|
||||
"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import Image\n",
|
||||
"import sam3\n",
|
||||
"sam3_root = os.path.join(os.path.dirname(sam3.__file__), \"..\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "33a15e2f-c7e1-4e5d-862f-fcb751a60b89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# select the device for computation\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda\")\n",
|
||||
"elif torch.backends.mps.is_available():\n",
|
||||
" device = torch.device(\"mps\")\n",
|
||||
"else:\n",
|
||||
" device = torch.device(\"cpu\")\n",
|
||||
"print(f\"using device: {device}\")\n",
|
||||
"\n",
|
||||
"if device.type == \"cuda\":\n",
|
||||
" # use bfloat16 for the entire notebook\n",
|
||||
" torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n",
|
||||
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
|
||||
" if torch.cuda.get_device_properties(0).major >= 8:\n",
|
||||
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
||||
" torch.backends.cudnn.allow_tf32 = True\n",
|
||||
"elif device.type == \"mps\":\n",
|
||||
" print(\n",
|
||||
" \"\\nSupport for MPS devices is preliminary. SAM 3 is trained with CUDA and might \"\n",
|
||||
" \"give numerically different outputs and sometimes degraded performance on MPS. \"\n",
|
||||
" \"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.\"\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29bc90d5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.random.seed(3)\n",
|
||||
"\n",
|
||||
"def show_mask(mask, ax, random_color=False, borders = True):\n",
|
||||
" if random_color:\n",
|
||||
" color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
|
||||
" else:\n",
|
||||
" color = np.array([30/255, 144/255, 255/255, 0.6])\n",
|
||||
" h, w = mask.shape[-2:]\n",
|
||||
" mask = mask.astype(np.uint8)\n",
|
||||
" mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
|
||||
" if borders:\n",
|
||||
" import cv2\n",
|
||||
" contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) \n",
|
||||
" # Try to smooth contours\n",
|
||||
" contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]\n",
|
||||
" mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) \n",
|
||||
" ax.imshow(mask_image)\n",
|
||||
"\n",
|
||||
"def show_points(coords, labels, ax, marker_size=375):\n",
|
||||
" pos_points = coords[labels==1]\n",
|
||||
" neg_points = coords[labels==0]\n",
|
||||
" ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
|
||||
" ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n",
|
||||
"\n",
|
||||
"def show_box(box, ax):\n",
|
||||
" x0, y0 = box[0], box[1]\n",
|
||||
" w, h = box[2] - box[0], box[3] - box[1]\n",
|
||||
" ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) \n",
|
||||
"\n",
|
||||
"def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):\n",
|
||||
" for i, (mask, score) in enumerate(zip(masks, scores)):\n",
|
||||
" plt.figure(figsize=(10, 10))\n",
|
||||
" plt.imshow(image)\n",
|
||||
" show_mask(mask, plt.gca(), borders=borders)\n",
|
||||
" if point_coords is not None:\n",
|
||||
" assert input_labels is not None\n",
|
||||
" show_points(point_coords, input_labels, plt.gca())\n",
|
||||
" if box_coords is not None:\n",
|
||||
" # boxes\n",
|
||||
" show_box(box_coords, plt.gca())\n",
|
||||
" if len(scores) > 1:\n",
|
||||
" plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
|
||||
" plt.axis('off')\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "23842fb2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3c2e4f6b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"image = Image.open(f\"{sam3_root}/assets/images/truck.jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e30125fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(10, 10))\n",
|
||||
"plt.imshow(image)\n",
|
||||
"plt.axis('on')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "98b228b8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Selecting objects with SAM 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0bb1927b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, load the SAM 3 model. Running on CUDA and using the default model are recommended for best results."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e28150b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sam3 import build_sam3_image_model\n",
|
||||
"from sam3.model.sam3_image_processor import Sam3Processor\n",
|
||||
"\n",
|
||||
"bpe_path = f\"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz\"\n",
|
||||
"model = build_sam3_image_model(bpe_path=bpe_path, enable_inst_interactivity=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c925e829",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Process the image to produce an image embedding by calling `Sam3Processor.set_image`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d95d48dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"processor = Sam3Processor(model)\n",
|
||||
"inference_state = processor.set_image(image)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d8fc7a46",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5c69570c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_point = np.array([[520, 375]])\n",
|
||||
"input_label = np.array([1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a91ba973",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(10, 10))\n",
|
||||
"plt.imshow(image)\n",
|
||||
"show_points(input_point, input_label, plt.gca())\n",
|
||||
"plt.axis('on')\n",
|
||||
"plt.show() "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c765e952",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Predict with `SAM3Image.predict_inst`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5373fd68",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, logits = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=input_point,\n",
|
||||
" point_labels=input_label,\n",
|
||||
" multimask_output=True,\n",
|
||||
")\n",
|
||||
"sorted_ind = np.argsort(scores)[::-1]\n",
|
||||
"masks = masks[sorted_ind]\n",
|
||||
"scores = scores[sorted_ind]\n",
|
||||
"logits = logits[sorted_ind]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c7f0e938",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With `multimask_output=True` (the default setting), SAM 3 outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47821187",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks.shape # (number_of_masks) x H x W"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9c227a6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3fa31f7c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specifying a specific object with additional points"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "88d6d29a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f6923b94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_point = np.array([[500, 375], [1125, 625]])\n",
|
||||
"input_label = np.array([1, 1])\n",
|
||||
"\n",
|
||||
"mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d98f96a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, _ = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=input_point,\n",
|
||||
" point_labels=input_label,\n",
|
||||
" mask_input=mask_input[None, :, :],\n",
|
||||
" multimask_output=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0ce8b82f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e06d5c8d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c93e2087",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9a196f68",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_point = np.array([[500, 375], [1125, 625]])\n",
|
||||
"input_label = np.array([1, 0])\n",
|
||||
"\n",
|
||||
"mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "81a52282",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, _ = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=input_point,\n",
|
||||
" point_labels=input_label,\n",
|
||||
" mask_input=mask_input[None, :, :],\n",
|
||||
" multimask_output=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bfca709f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "41e2d5a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specifying a specific object with a box"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d61ca7ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The model can also take a box as input, provided in xyxy format."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8ea92a7b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_box = np.array([425, 600, 700, 875])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b35a8814",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, _ = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=None,\n",
|
||||
" point_labels=None,\n",
|
||||
" box=input_box[None, :],\n",
|
||||
" multimask_output=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ffb4906",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"show_masks(image, masks, scores, box_coords=input_box)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c1ed9f0a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Combining points and boxes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8455d1c5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Points and boxes may be combined, just by including both types of prompts to the predictor. Here this can be used to select just the trucks's tire, instead of the entire wheel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "90e2e547",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_box = np.array([425, 600, 700, 875])\n",
|
||||
"input_point = np.array([[575, 750]])\n",
|
||||
"input_label = np.array([0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6956d8c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, logits = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=input_point,\n",
|
||||
" point_labels=input_label,\n",
|
||||
" box=input_box,\n",
|
||||
" multimask_output=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eb519a31",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"show_masks(image, masks, scores, box_coords=input_box, point_coords=input_point, input_labels=input_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45ddbca3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Batched prompt inputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "df6f18a0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`SAM3Image` can take multiple input prompts for the same image, using `predict_inst` method. For example, imagine we have several box outputs from an object detector."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0a06681b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_boxes = np.array([\n",
|
||||
" [75, 275, 1725, 850],\n",
|
||||
" [425, 600, 700, 875],\n",
|
||||
" [1375, 550, 1650, 800],\n",
|
||||
" [1240, 675, 1400, 750],\n",
|
||||
"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "117521a3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks, scores, _ = model.predict_inst(\n",
|
||||
" inference_state,\n",
|
||||
" point_coords=None,\n",
|
||||
" point_labels=None,\n",
|
||||
" box=input_boxes,\n",
|
||||
" multimask_output=False,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6a8f5d49",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c00c3681",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(10, 10))\n",
|
||||
"plt.imshow(image)\n",
|
||||
"for mask in masks:\n",
|
||||
" show_mask(mask.squeeze(0), plt.gca(), random_color=True)\n",
|
||||
"for box in input_boxes:\n",
|
||||
" show_box(box, plt.gca())\n",
|
||||
"plt.axis('off')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b9a27b5d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## End-to-end batched inference\n",
|
||||
"If all prompts are available in advance, it is possible to run SAM 3 directly in an end-to-end fashion. This also allows batching over images."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d485f75b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"image1 = image # truck.jpg from above\n",
|
||||
"image1_boxes = np.array([\n",
|
||||
" [75, 275, 1725, 850],\n",
|
||||
" [425, 600, 700, 875],\n",
|
||||
" [1375, 550, 1650, 800],\n",
|
||||
" [1240, 675, 1400, 750],\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"image2 = Image.open(f\"{sam3_root}/assets/images/groceries.jpg\")\n",
|
||||
"image2_boxes = np.array([\n",
|
||||
" [450, 170, 520, 350],\n",
|
||||
" [350, 190, 450, 350],\n",
|
||||
" [500, 170, 580, 350],\n",
|
||||
" [580, 170, 640, 350],\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"img_batch = [image1, image2]\n",
|
||||
"boxes_batch = [image1_boxes, image2_boxes]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47932c99",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inference_state = processor.set_image_batch(img_batch)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97af3c54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks_batch, scores_batch, _ = model.predict_inst_batch(\n",
|
||||
" inference_state,\n",
|
||||
" None,\n",
|
||||
" None, \n",
|
||||
" box_batch=boxes_batch, \n",
|
||||
" multimask_output=False\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "226df881",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for image, boxes, masks in zip(img_batch, boxes_batch, masks_batch):\n",
|
||||
" plt.figure(figsize=(10, 10))\n",
|
||||
" plt.imshow(image) \n",
|
||||
" for mask in masks:\n",
|
||||
" show_mask(mask.squeeze(0), plt.gca(), random_color=True)\n",
|
||||
" for box in boxes:\n",
|
||||
" show_box(box, plt.gca())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "46f30085",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Similarly, we can have a batch of point prompts defined over a batch of images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1ab929fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"image1 = image # truck.jpg from above\n",
|
||||
"image1_pts = np.array([\n",
|
||||
" [[500, 375]],\n",
|
||||
" [[650, 750]]\n",
|
||||
" ]) # Bx1x2 where B corresponds to number of objects \n",
|
||||
"image1_labels = np.array([[1], [1]])\n",
|
||||
"\n",
|
||||
"image2_pts = np.array([\n",
|
||||
" [[400, 300]],\n",
|
||||
" [[630, 300]],\n",
|
||||
"])\n",
|
||||
"image2_labels = np.array([[1], [1]])\n",
|
||||
"\n",
|
||||
"pts_batch = [image1_pts, image2_pts]\n",
|
||||
"labels_batch = [image1_labels, image2_labels]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "848f8287",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"masks_batch, scores_batch, _ = model.predict_inst_batch(inference_state, pts_batch, labels_batch, box_batch=None, multimask_output=True)\n",
|
||||
"\n",
|
||||
"# Select the best single mask per object\n",
|
||||
"best_masks = []\n",
|
||||
"for masks, scores in zip(masks_batch,scores_batch):\n",
|
||||
" best_masks.append(masks[range(len(masks)), np.argmax(scores, axis=-1)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "99b15c6c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for image, points, labels, masks in zip(img_batch, pts_batch, labels_batch, best_masks):\n",
|
||||
" plt.figure(figsize=(10, 10))\n",
|
||||
" plt.imshow(image) \n",
|
||||
" for mask in masks:\n",
|
||||
" show_mask(mask, plt.gca(), random_color=True)\n",
|
||||
" show_points(points, labels, plt.gca())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4c1594a5-a0de-4477-91d4-db4504a78a83",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74e3d07e-b0de-48a5-9d29-d639a0dbcdfc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d8b1de3a-a253-48ff-8a1c-d80742acbe86",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
979
examples/sam3_for_sam2_video_task_example.ipynb
Normal file
979
examples/sam3_for_sam2_video_task_example.ipynb
Normal file
@@ -0,0 +1,979 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3c3b1c46-9f5c-41c1-9101-85db8709ec0d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6e7a0db5-7f04-4845-8b11-684fe6e9f7f2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Video object segmentation with SAM 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "162d0b3c-4207-442d-969c-aa1cbb8fd4ad",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook shows how to use SAM 3 for video object segmentation in videos, illustrating the use of the `Sam3TrackerPredictor` class.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This notebook follows the SAM 2 API for interactive video segmentation.\n",
|
||||
"\n",
|
||||
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/sam3/blob/main/notebooks/sam3_for_sam2_video_task_example.ipynb\">\n",
|
||||
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
|
||||
"</a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26616201-06df-435b-98fd-ad17c373bb4a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Environment Set-up"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8491a127-4c01-48f5-9dc5-f148a9417fdf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First install `sam3` in your environment using the [installation instructions](https://github.com/facebookresearch/sam3?tab=readme-ov-file#installation) in the repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f74c53be-aab1-46b9-8c0b-068b52ef5948",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"using_colab = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d824a4b2-71f3-4da3-bfc7-3249625e6730",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if using_colab:\n",
|
||||
" import torch\n",
|
||||
" import torchvision\n",
|
||||
" print(\"PyTorch version:\", torch.__version__)\n",
|
||||
" print(\"Torchvision version:\", torchvision.__version__)\n",
|
||||
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
|
||||
" import sys\n",
|
||||
" !{sys.executable} -m pip install opencv-python matplotlib scikit-learn\n",
|
||||
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22e6aa9d-487f-4207-b657-8cff0902343e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set-up"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d3cae821",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# select the device for computation\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda\")\n",
|
||||
"elif torch.backends.mps.is_available():\n",
|
||||
" device = torch.device(\"mps\")\n",
|
||||
"else:\n",
|
||||
" device = torch.device(\"cpu\")\n",
|
||||
"print(f\"using device: {device}\")\n",
|
||||
"\n",
|
||||
"if device.type == \"cuda\":\n",
|
||||
" # use bfloat16 for the entire notebook\n",
|
||||
" torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n",
|
||||
" # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n",
|
||||
" if torch.cuda.get_device_properties(0).major >= 8:\n",
|
||||
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
||||
" torch.backends.cudnn.allow_tf32 = True\n",
|
||||
"\n",
|
||||
"elif device.type == \"mps\":\n",
|
||||
" print(\n",
|
||||
" \"\\nSupport for MPS devices is preliminary. SAM 3 is trained with CUDA and might \"\n",
|
||||
" \"give numerically different outputs and sometimes degraded performance on MPS. \"\n",
|
||||
" \"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.\"\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5318a85-5bf7-4880-b2b3-15e4db24d796",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import glob\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import cv2\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"import sam3\n",
|
||||
"import torch\n",
|
||||
"from PIL import Image\n",
|
||||
"from sam3.visualization_utils import show_box, show_mask, show_points\n",
|
||||
"\n",
|
||||
"# font size for axes titles\n",
|
||||
"plt.rcParams[\"axes.titlesize\"] = 12\n",
|
||||
"plt.rcParams[\"figure.titlesize\"] = 12\n",
|
||||
"\n",
|
||||
"sam3_root = os.path.join(os.path.dirname(sam3.__file__), \"..\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ae8e0779-751f-4224-9b04-ed0f0b406500",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading the SAM 3 tracking predictor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f5f3245e-b4d6-418b-a42a-a67e0b3b5aec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sam3.model_builder import build_sam3_video_model\n",
|
||||
"\n",
|
||||
"sam3_model = build_sam3_video_model()\n",
|
||||
"predictor = sam3_model.tracker\n",
|
||||
"predictor.backbone = sam3_model.detector.backbone"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dff46b10-c17a-4a26-8004-8c6d80806b0a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Initialize the inference state"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f594ac71-a6b9-461d-af27-500fa1d1a420",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Just like SAM 2, SAM 3 requires stateful inference for interactive video segmentation, so we need to initialize an **inference state** on this video.\n",
|
||||
"\n",
|
||||
"During initialization, it loads all the JPEG frames in `video_path` and stores their pixels in `inference_state` (as shown in the progress bar below)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9baa05c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"video_path = f\"{sam3_root}/assets/videos/bedroom.mp4\"\n",
|
||||
"inference_state = predictor.init_state(video_path=video_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "edb1f3f6-d74d-4016-934c-8d2a14d1a543",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Example 1: Segment & track one object"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aa2d3127-67b2-45d2-9f32-8fe3e10dc5eb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: if you have run any previous tracking using this `inference_state`, please reset it first via `clear_all_points_in_video`.\n",
|
||||
"\n",
|
||||
"(The cell below is just for illustration; it's not needed to call `clear_all_points_in_video` here as this `inference_state` is just freshly initialized above.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d2646a1d-3401-438c-a653-55e0e56b7d9d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictor.clear_all_points_in_video(inference_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26aeb04d-8cba-4f57-95da-6e5a1796003e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 1: Add a first click on a frame"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "695c7749-b523-4691-aad0-7558c5d1d68c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To get started, let's try to segment the child on the left.\n",
|
||||
"\n",
|
||||
"Here we make a **positive click** at (x, y) = (210, 350) with label `1`, by sending their coordinates and labels into the `add_new_points` API.\n",
|
||||
"\n",
|
||||
"Note: label `1` indicates a *positive click (to add a region)* while label `0` indicates a *negative click (to remove a region)*."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd6778a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load the frames for visualization\n",
|
||||
"cap = cv2.VideoCapture(video_path)\n",
|
||||
"video_frames_for_vis = []\n",
|
||||
"while True:\n",
|
||||
" ret, frame = cap.read()\n",
|
||||
" if not ret:\n",
|
||||
" break\n",
|
||||
" video_frames_for_vis.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
|
||||
"cap.release()\n",
|
||||
"frame0 = video_frames_for_vis[0]\n",
|
||||
"\n",
|
||||
"width, height = frame0.shape[1], frame0.shape[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3e749bab-0f36-4173-bf8d-0c20cd5214b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a positive click at (x, y) = (210, 350) to get started\n",
|
||||
"points = np.array([[210, 350]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1], np.int32)\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
" clear_old_points=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(frame0)\n",
|
||||
"show_points(points, labels, plt.gca())\n",
|
||||
"show_mask((video_res_masks[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "89457875-93fa-40ed-b6dc-4e1c971a27f9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 2: Add a second click to refine the prediction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a75eb21b-1413-452c-827b-a04093c30c78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Hmm, it seems that although we wanted to segment the child on the left, the model predicts the mask for only the shorts -- this can happen since there is ambiguity from a single click about what the target object should be. We can refine the mask on this frame via another positive click on the child's shirt.\n",
|
||||
"\n",
|
||||
"Here we make a **second positive click** at (x, y) = (250, 220) with label `1` to expand the mask.\n",
|
||||
"\n",
|
||||
"Note: we need to send **all the clicks and their labels** (i.e. not just the last click) when calling `add_new_points`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1ab3ec7-2537-4158-bf98-3d0977d8908d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a 2nd positive click at (x, y) = (250, 220) to refine the mask\n",
|
||||
"# sending all clicks (and their labels) to `add_new_points_or_box`\n",
|
||||
"points = np.array([[210, 350], [250, 220]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1, 1], np.int32)\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
" clear_old_points=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(frame0)\n",
|
||||
"show_points(points, labels, plt.gca())\n",
|
||||
"show_mask((video_res_masks[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "df4ab457-d91d-4ac8-b350-fbcd549fd3fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With this 2nd refinement click, now we get a segmentation mask of the entire child on frame 0."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f52015ac-1b7b-4c59-bca3-c2b28484cf46",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 3: Propagate the prompts to get the masklet across the video"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "30b025bd-cd58-4bfb-9572-c8d2fd0a02ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To get the masklet throughout the entire video, we propagate the prompts using the `propagate_in_video` API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ab45e932-b0d5-4983-9718-6ee77d1ac31b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# run propagation throughout the video and collect the results in a dict\n",
|
||||
"video_segments = {} # video_segments contains the per-frame segmentation results\n",
|
||||
"for frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores in predictor.propagate_in_video(inference_state, start_frame_idx=0, max_frame_num_to_track=240, reverse=False, propagate_preflight=True):\n",
|
||||
" video_segments[frame_idx] = {\n",
|
||||
" out_obj_id: (video_res_masks[i] > 0.0).cpu().numpy()\n",
|
||||
" for i, out_obj_id in enumerate(out_obj_ids)\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# render the segmentation results every few frames\n",
|
||||
"vis_frame_stride = 30\n",
|
||||
"plt.close(\"all\")\n",
|
||||
"for out_frame_idx in range(0, len(video_frames_for_vis), vis_frame_stride):\n",
|
||||
" plt.figure(figsize=(6, 4))\n",
|
||||
" plt.title(f\"frame {out_frame_idx}\")\n",
|
||||
" plt.imshow(video_frames_for_vis[out_frame_idx])\n",
|
||||
" for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n",
|
||||
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3e801b70-72df-4a72-b3fe-84f145e5e3f6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 4: Add new prompts to further refine the masklet"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "478958ab-29b4-4a75-bba4-adb1b03d0a2b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"It appears that in the output masklet above, there are some small imperfections in boundary details on frame 150.\n",
|
||||
"\n",
|
||||
"With SAM 3 we can fix the model predictions interactively. We can add a **negative click** at (x, y) = (82, 415) on this frame with label `0` to refine the masklet. Here we call the `add_new_points_or_box` API with a different `frame_idx` argument to indicate the frame index we want to refine."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a572ea9-5b7e-479c-b30c-93c38b121131",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 150 # further refine some details on this frame\n",
|
||||
"ann_obj_id = 1 # give a unique id to the object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# show the segment before further refinement\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx} -- before refinement\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)\n",
|
||||
"\n",
|
||||
"# Let's add a negative click on this frame at (x, y) = (82, 415) to refine the segment\n",
|
||||
"points = np.array([[82, 410]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([0], np.int32)\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
" clear_old_points=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the segment after the further refinement\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx} -- after refinement\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"show_points(points, labels, plt.gca())\n",
|
||||
"show_mask((video_res_masks > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50a3950a-acf1-435c-bd64-94297267b5e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 5: Propagate the prompts (again) to get the masklet across the video"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b1954ecf-c2ec-4f9c-8d10-c4f527a10cd2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's get an updated masklet for the entire video. Here we call `propagate_in_video` again to propagate all the prompts after adding the new refinement click above."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "baa96690-4a38-4a24-aa17-fd2f4db0e232",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# run propagation throughout the video and collect the results in a dict\n",
|
||||
"video_segments = {} # video_segments contains the per-frame segmentation results\n",
|
||||
"for frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores in predictor.propagate_in_video(inference_state, start_frame_idx=0, max_frame_num_to_track=300, reverse=False, propagate_preflight=True):\n",
|
||||
" video_segments[frame_idx] = {\n",
|
||||
" out_obj_id: (video_res_masks[i] > 0.0).cpu().numpy()\n",
|
||||
" for i, out_obj_id in enumerate(out_obj_ids)\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# render the segmentation results every few frames\n",
|
||||
"vis_frame_stride = 30\n",
|
||||
"plt.close(\"all\")\n",
|
||||
"for out_frame_idx in range(0, len(video_frames_for_vis), vis_frame_stride):\n",
|
||||
" plt.figure(figsize=(6, 4))\n",
|
||||
" plt.title(f\"frame {out_frame_idx}\")\n",
|
||||
" plt.imshow(video_frames_for_vis[out_frame_idx])\n",
|
||||
" for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n",
|
||||
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "607507e3-6a2b-4fd7-944c-2371bdab9d01",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The segments now look good on all frames."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2502bb5a-3e1f-43d0-9f58-33f8676fff0d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Example 2: Segment an object using box prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8e2d26c8-0432-48c6-997e-4a3b77bb5f6d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: if you have run any previous tracking using this `inference_state`, please reset it first via `clear_all_points_in_video`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6dbe9183-abbb-4283-b0cb-d24f3d7beb34",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictor.clear_all_points_in_video(inference_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ceb6eae9-0f4c-434f-8089-a46c9ca59da5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition to using clicks as inputs, SAM 3 also supports segmenting and tracking objects in a video via **bounding boxes**.\n",
|
||||
"\n",
|
||||
"In the example below, we segment the child on the right using a **box prompt** of (x_min, y_min, x_max, y_max) = (300, 0, 500, 400) on frame 0 as input into the `add_new_points_or_box` API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1cbfb273-4e14-495b-bd89-87a8baf52ae7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a box at (x_min, y_min, x_max, y_max) = (300, 0, 500, 400) to get started\n",
|
||||
"box = np.array([[300, 0, 500, 400]], dtype=np.float32)\n",
|
||||
"\n",
|
||||
"rel_box = [[xmin / width, ymin / height, xmax / width, ymax / height] for xmin, ymin, xmax, ymax in box]\n",
|
||||
"rel_box = np.array(rel_box, dtype=np.float32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points_or_box(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" box=rel_box,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"show_box(box[0], plt.gca())\n",
|
||||
"show_mask((video_res_masks[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bd3f9ba7-bf4d-47e5-9b02-8a424cab42cc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here, SAM 3 gets a pretty good segmentation mask of the entire child, even though the input bounding box is not perfectly tight around the object.\n",
|
||||
"\n",
|
||||
"Similar to the previous example, if the returned mask from is not perfect when using a box prompt, we can also further **refine** the output using positive or negative clicks. To illustrate this, here we make a **positive click** at (x, y) = (460, 60) with label `1` to expand the segment around the child's hair.\n",
|
||||
"\n",
|
||||
"Note: to refine the segmentation mask from a box prompt, we need to send **both the original box input and all subsequent refinement clicks and their labels** when calling `add_new_points_or_box`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "54906315-ab4c-4088-b866-4c22134d5b66",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a positive click at (x, y) = (460, 60) to refine the mask\n",
|
||||
"points = np.array([[460, 60]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1], np.int32)\n",
|
||||
"# note that we also need to send the original box input along with\n",
|
||||
"# the new refinement click together into `add_new_points_or_box`\n",
|
||||
"box = np.array([[300, 0, 500, 400]], dtype=np.float32)\n",
|
||||
"\n",
|
||||
"rel_box = [[xmin / width, ymin / height, xmax / width, ymax / height] for xmin, ymin, xmax, ymax in box]\n",
|
||||
"rel_box = np.array(rel_box, dtype=np.float32)\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points_or_box(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
" box=rel_box,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"show_box(box[0], plt.gca())\n",
|
||||
"show_points(points, labels, plt.gca())\n",
|
||||
"show_mask((video_res_masks[0][0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "73128cd6-dbfa-49f7-8d79-1a8e19835f7f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Then, to get the masklet throughout the entire video, we propagate the prompts using the `propagate_in_video` API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9cd90557-a0dc-442e-b091-9c74c831bef8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# run propagation throughout the video and collect the results in a dict\n",
|
||||
"video_segments = {} # video_segments contains the per-frame segmentation results\n",
|
||||
"for frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores in predictor.propagate_in_video(inference_state, start_frame_idx=0, max_frame_num_to_track=300, reverse=False, propagate_preflight=True):\n",
|
||||
" video_segments[frame_idx] = {\n",
|
||||
" out_obj_id: (video_res_masks[i] > 0.0).cpu().numpy()\n",
|
||||
" for i, out_obj_id in enumerate(out_obj_ids)\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# render the segmentation results every few frames\n",
|
||||
"vis_frame_stride = 30\n",
|
||||
"plt.close(\"all\")\n",
|
||||
"for out_frame_idx in range(0, len(video_frames_for_vis), vis_frame_stride):\n",
|
||||
" plt.figure(figsize=(6, 4))\n",
|
||||
" plt.title(f\"frame {out_frame_idx}\")\n",
|
||||
" plt.imshow(video_frames_for_vis[out_frame_idx])\n",
|
||||
" for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n",
|
||||
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e023f91f-0cc5-4980-ae8e-a13c5749112b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that in addition to clicks or boxes, SAM 3 also supports directly using a **mask prompt** as input via the `add_new_mask` method in the `Sam3TrackerPredictor` class. This can be helpful in e.g. semi-supervised VOS evaluations (see [tools/vos_inference.py](https://github.com/facebookresearch/sam2/blob/main/tools/vos_inference.py) for an example)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da018be8-a4ae-4943-b1ff-702c2b89cb68",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Example 3: Segment multiple objects simultaneously"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dea6c04c-3072-4876-b394-879321a48c4a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: if you have run any previous tracking using this `inference_state`, please reset it first via `clear_all_points_in_video`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29b874c8-9f39-42d3-a667-54a0bd696410",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictor.clear_all_points_in_video(inference_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "48f3f7e6-4821-468c-84e4-f3a0435c9149",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 1: Add two objects on a frame"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "95158714-86d7-48a9-8365-b213f97cc9ca",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"SAM 3 can also segment and track two or more objects at the same time. One way, of course, is to do them one by one. However, it would be more efficient to batch them together (e.g. so that we can share the image features between objects to reduce computation costs).\n",
|
||||
"\n",
|
||||
"This time, let's focus on object parts and segment **the shirts of both childen** in this video. Here we add prompts for these two objects and assign each of them a unique object id."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e22d896d-3cd5-4fa0-9230-f33e217035dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompts = {} # hold all the clicks we add for visualization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "59d9ac57-b14a-4237-828d-927e422c518b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Add the first object (the left child's shirt) with a **positive click** at (x, y) = (200, 300) on frame 0.\n",
|
||||
"\n",
|
||||
"We assign it to object id `2` (it can be arbitrary integers, and only needs to be unique for each object to track), which is passed to the `add_new_points_or_box` API to distinguish the object we are clicking upon."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d13432fc-f467-44d8-adfe-3e0c488046b7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a positive click at (x, y) = (200, 300) to get started on the first object\n",
|
||||
"points = np.array([[200, 300]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1], np.int32)\n",
|
||||
"prompts[ann_obj_id] = points, labels\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points_or_box(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"for i, out_obj_id in enumerate(out_obj_ids):\n",
|
||||
" show_points(points, labels, plt.gca())\n",
|
||||
" show_points(*prompts[out_obj_id], plt.gca())\n",
|
||||
" show_mask((video_res_masks[i][0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1bbbd51b-e1e2-4c36-99ec-1d9a1b49b0cd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Hmm, this time we just want to select the child's shirt, but the model predicts the mask for the entire child. Let's refine the prediction with a **negative click** at (x, y) = (275, 175)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95ecf61d-662b-4f98-ae62-46557b219842",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# add the first object\n",
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's add a 2nd negative click at (x, y) = (275, 175) to refine the first object\n",
|
||||
"# sending all clicks (and their labels) to `add_new_points_or_box`\n",
|
||||
"points = np.array([[200, 300], [275, 175]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1, 0], np.int32)\n",
|
||||
"prompts[ann_obj_id] = points, labels\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points_or_box(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=rel_points,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"for i, out_obj_id in enumerate(out_obj_ids):\n",
|
||||
" show_points(points, labels, plt.gca())\n",
|
||||
" show_points(*prompts[out_obj_id], plt.gca())\n",
|
||||
" show_mask((video_res_masks[i][0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "194718c1-734d-446c-a3ef-361057de2f31",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"After the 2nd negative click, now we get the left child's shirt as our first object.\n",
|
||||
"\n",
|
||||
"Let's move on to the second object (the right child's shirt) with a positive click at (x, y) = (400, 150) on frame 0. Here we assign object id `3` to this second object (it can be arbitrary integers, and only needs to be unique for each object to track).\n",
|
||||
"\n",
|
||||
"Note: when there are multiple objects, the `add_new_points_or_box` API will return a list of masks for each object."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "86ca1bde-62a4-40e6-98e4-15606441e52f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ann_frame_idx = 0 # the frame index we interact with\n",
|
||||
"ann_obj_id = 3 # give a unique id to each object we interact with (it can be any integers)\n",
|
||||
"\n",
|
||||
"# Let's now move on to the second object we want to track (giving it object id `3`)\n",
|
||||
"# with a positive click at (x, y) = (400, 150)\n",
|
||||
"points = np.array([[400, 150]], dtype=np.float32)\n",
|
||||
"# for labels, `1` means positive click and `0` means negative click\n",
|
||||
"labels = np.array([1], np.int32)\n",
|
||||
"prompts[ann_obj_id] = points, labels\n",
|
||||
"\n",
|
||||
"rel_points = [[x / width, y / height] for x, y in points]\n",
|
||||
"points_tensor = torch.tensor(rel_points, dtype=torch.float32)\n",
|
||||
"points_labels_tensor = torch.tensor(labels, dtype=torch.int32)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `add_new_points_or_box` returns masks for all objects added so far on this interacted frame\n",
|
||||
"_, out_obj_ids, low_res_masks, video_res_masks = predictor.add_new_points_or_box(\n",
|
||||
" inference_state=inference_state,\n",
|
||||
" frame_idx=ann_frame_idx,\n",
|
||||
" obj_id=ann_obj_id,\n",
|
||||
" points=points_tensor,\n",
|
||||
" labels=points_labels_tensor,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# show the results on the current (interacted) frame on all objects\n",
|
||||
"plt.figure(figsize=(9, 6))\n",
|
||||
"plt.title(f\"frame {ann_frame_idx}\")\n",
|
||||
"plt.imshow(video_frames_for_vis[ann_frame_idx])\n",
|
||||
"for i, out_obj_id in enumerate(out_obj_ids):\n",
|
||||
" show_points(points, labels, plt.gca())\n",
|
||||
" show_points(*prompts[out_obj_id], plt.gca())\n",
|
||||
" show_mask((video_res_masks[i][0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a1f7add8-d577-4597-ae2f-654b8c7b05e0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This time the model predicts the mask of the shirt we want to track in just one click. Nice!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "448733b8-ea8b-4078-995f-b676c3b558ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Step 2: Propagate the prompts to get masklets across the video"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "60bd73de-d669-41c8-b6ba-943883f0caa2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we propagate the prompts for both objects to get their masklets throughout the video.\n",
|
||||
"\n",
|
||||
"Note: when there are multiple objects, the `propagate_in_video` API will return a list of masks for each object."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "17737191-d62b-4611-b2c6-6d0418a9ab74",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# run propagation throughout the video and collect the results in a dict\n",
|
||||
"video_segments = {} # video_segments contains the per-frame segmentation results\n",
|
||||
"for frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores in predictor.propagate_in_video(inference_state, start_frame_idx=0, max_frame_num_to_track=300, reverse=False, propagate_preflight=True):\n",
|
||||
" video_segments[frame_idx] = {\n",
|
||||
" out_obj_id: (video_res_masks[i] > 0.0).cpu().numpy()\n",
|
||||
" for i, out_obj_id in enumerate(out_obj_ids)\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# render the segmentation results every few frames\n",
|
||||
"vis_frame_stride = 30\n",
|
||||
"plt.close(\"all\")\n",
|
||||
"for out_frame_idx in range(0, len(video_frames_for_vis), vis_frame_stride):\n",
|
||||
" plt.figure(figsize=(6, 4))\n",
|
||||
" plt.title(f\"frame {out_frame_idx}\")\n",
|
||||
" plt.imshow(video_frames_for_vis[out_frame_idx])\n",
|
||||
" for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n",
|
||||
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18a0b9d7-c78f-432b-afb0-11f2ea5b652a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Looks like both children's shirts are well segmented in this video.\n",
|
||||
"\n",
|
||||
"Now you can try SAM 3 on your own videos and use cases! "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
672
examples/sam3_image_batched_inference.ipynb
Normal file
672
examples/sam3_image_batched_inference.ipynb
Normal file
File diff suppressed because one or more lines are too long
757
examples/sam3_image_interactive.ipynb
Normal file
757
examples/sam3_image_interactive.ipynb
Normal file
File diff suppressed because one or more lines are too long
374
examples/sam3_image_predictor_example.ipynb
Normal file
374
examples/sam3_image_predictor_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
1603
examples/sam3_video_predictor_example.ipynb
Normal file
1603
examples/sam3_video_predictor_example.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user