{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "f400486b", "metadata": {}, "outputs": [], "source": [ "# Copyright (c) Meta Platforms, Inc. and affiliates." ] }, { "cell_type": "markdown", "id": "a1ae39ff", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "# Interactive Instance Segmentation using SAM 3" ] }, { "cell_type": "markdown", "id": "b4a4b25c", "metadata": {}, "source": [ "Segment Anything Model 3 (SAM 3) predicts instance masks that indicate the desired object given geometric prompts (the SAM 1 task).\n", "The `SAM3Image` and `Sam3Processor` classes provide an easy interface to prompt the model. The user first sets an image using the `Sam3Processor.set_image` method, which computes the necessary image embeddings. Prompts can then be provided via the `SAM3Image.predict_inst` method, which efficiently predicts masks from those prompts. The model accepts point and box prompts, as well as masks from the previous iteration of prediction.\n", "\n", "This notebook follows the SAM 2 API for interactive image segmentation." ] }, { "cell_type": "markdown", "id": "644532a8", "metadata": {}, "source": [ "## Environment Set-up" ] }, { "cell_type": "markdown", "id": "07fabfee", "metadata": {}, "source": [ "First, install `sam3` in your environment using the [installation instructions](https://github.com/facebookresearch/sam3?tab=readme-ov-file#installation) in the repository." ] }, { "cell_type": "markdown", "id": "0be845da", "metadata": {}, "source": [ "## Set-up" ] }, { "cell_type": "markdown", "id": "33681dd1", "metadata": {}, "source": [ "Necessary imports and helper functions for displaying points, boxes, and masks."
] }, { "cell_type": "code", "execution_count": null, "id": "fe773ede", "metadata": {}, "outputs": [], "source": [ "using_colab = False" ] },
{ "cell_type": "code", "execution_count": null, "id": "79250a4e", "metadata": {}, "outputs": [], "source": [ "if using_colab:\n", "    import torch\n", "    import torchvision\n", "    print(\"PyTorch version:\", torch.__version__)\n", "    print(\"Torchvision version:\", torchvision.__version__)\n", "    print(\"CUDA is available:\", torch.cuda.is_available())\n", "    import sys\n", "    !{sys.executable} -m pip install opencv-python matplotlib scikit-learn\n", "    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'" ] },
{ "cell_type": "code", "execution_count": null, "id": "69b28288", "metadata": {}, "outputs": [], "source": [ "import os\n", "# if using Apple MPS, fall back to CPU for unsupported ops\n", "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import sam3\n", "sam3_root = os.path.join(os.path.dirname(sam3.__file__), \"..\")\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "33a15e2f-c7e1-4e5d-862f-fcb751a60b89", "metadata": {}, "outputs": [], "source": [ "# select the device for computation\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", "    device = torch.device(\"mps\")\n", "else:\n", "    device = torch.device(\"cpu\")\n", "print(f\"using device: {device}\")\n", "\n", "if device.type == \"cuda\":\n", "    # use bfloat16 for the entire notebook\n", "    torch.autocast(\"cuda\", dtype=torch.bfloat16).__enter__()\n", "    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)\n", "    if torch.cuda.get_device_properties(0).major >= 8:\n", "        torch.backends.cuda.matmul.allow_tf32 = True\n", "        torch.backends.cudnn.allow_tf32 = True\n", "elif device.type == \"mps\":\n", "    print(\n", "        \"\\nSupport for MPS devices is preliminary. SAM 3 is trained with CUDA and might \"\n", "        \"give numerically different outputs and sometimes degraded performance on MPS. \"\n", "        \"See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.\"\n", "    )" ] },
{ "cell_type": "code", "execution_count": null, "id": "29bc90d5", "metadata": {}, "outputs": [], "source": [ "np.random.seed(3)\n", "\n", "def show_mask(mask, ax, random_color=False, borders=True):\n", "    if random_color:\n", "        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", "    else:\n", "        color = np.array([30/255, 144/255, 255/255, 0.6])\n", "    h, w = mask.shape[-2:]\n", "    mask = mask.astype(np.uint8)\n", "    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", "    if borders:\n", "        import cv2\n", "        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n", "        # Try to smooth contours\n", "        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]\n", "        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)\n", "    ax.imshow(mask_image)\n", "\n", "def show_points(coords, labels, ax, marker_size=375):\n", "    pos_points = coords[labels==1]\n", "    neg_points = coords[labels==0]\n", "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n", "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n", "\n", "def show_box(box, ax):\n", "    x0, y0 = box[0], box[1]\n", "    w, h = box[2] - box[0], box[3] - box[1]\n", "    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))\n", "\n", "def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):\n", "    for i, (mask, score) in enumerate(zip(masks, scores)):\n", "        plt.figure(figsize=(10, 10))\n", "        plt.imshow(image)\n", "        show_mask(mask, plt.gca(), borders=borders)\n", "        if point_coords is not None:\n", "            assert input_labels is not None\n", "            show_points(point_coords, input_labels, plt.gca())\n", "        if box_coords is not None:\n", "            # boxes\n", "            show_box(box_coords, plt.gca())\n", "        if len(scores) > 1:\n", "            plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n", "        plt.axis('off')\n", "        plt.show()" ] },
{ "cell_type": "markdown", "id": "23842fb2", "metadata": {}, "source": [ "## Example image" ] },
{ "cell_type": "code", "execution_count": null, "id": "3c2e4f6b", "metadata": {}, "outputs": [], "source": [ "image = Image.open(f\"{sam3_root}/assets/images/truck.jpg\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "e30125fd", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "plt.imshow(image)\n", "plt.axis('on')\n", "plt.show()" ] },
{ "cell_type": "markdown", "id": "98b228b8", "metadata": {}, "source": [ "## Selecting objects with SAM 3" ] },
{ "cell_type": "markdown", "id": "0bb1927b", "metadata": {}, "source": [ "First, load the SAM 3 model. Running on CUDA and using the default model are recommended for best results." ] },
{ "cell_type": "code", "execution_count": null, "id": "7e28150b", "metadata": {}, "outputs": [], "source": [ "from sam3 import build_sam3_image_model\n", "from sam3.model.sam3_image_processor import Sam3Processor\n", "\n", "bpe_path = f\"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz\"\n", "model = build_sam3_image_model(bpe_path=bpe_path, enable_inst_interactivity=True)\n" ] },
{ "cell_type": "markdown", "id": "c925e829", "metadata": {}, "source": [ "Process the image to produce an image embedding by calling `Sam3Processor.set_image`."
] }, { "cell_type": "code", "execution_count": null, "id": "d95d48dd", "metadata": {}, "outputs": [], "source": [ "processor = Sam3Processor(model)\n", "inference_state = processor.set_image(image)" ] }, { "cell_type": "markdown", "id": "d8fc7a46", "metadata": {}, "source": [ "To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image." ] }, { "cell_type": "code", "execution_count": null, "id": "5c69570c", "metadata": {}, "outputs": [], "source": [ "input_point = np.array([[520, 375]])\n", "input_label = np.array([1])" ] }, { "cell_type": "code", "execution_count": null, "id": "a91ba973", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "plt.imshow(image)\n", "show_points(input_point, input_label, plt.gca())\n", "plt.axis('on')\n", "plt.show() " ] }, { "cell_type": "markdown", "id": "c765e952", "metadata": {}, "source": [ "Predict with `SAM3Image.predict_inst`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction." ] }, { "cell_type": "code", "execution_count": null, "id": "5373fd68", "metadata": {}, "outputs": [], "source": [ "masks, scores, logits = model.predict_inst(\n", " inference_state,\n", " point_coords=input_point,\n", " point_labels=input_label,\n", " multimask_output=True,\n", ")\n", "sorted_ind = np.argsort(scores)[::-1]\n", "masks = masks[sorted_ind]\n", "scores = scores[sorted_ind]\n", "logits = logits[sorted_ind]" ] }, { "cell_type": "markdown", "id": "c7f0e938", "metadata": {}, "source": [ "With `multimask_output=True` (the default setting), SAM 3 outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. 
This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask." ] }, { "cell_type": "code", "execution_count": null, "id": "47821187", "metadata": {}, "outputs": [], "source": [ "masks.shape # (number_of_masks) x H x W" ] }, { "cell_type": "code", "execution_count": null, "id": "e9c227a6", "metadata": {}, "outputs": [], "source": [ "show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)" ] }, { "cell_type": "markdown", "id": "3fa31f7c", "metadata": {}, "source": [ "## Specifying a specific object with additional points" ] }, { "cell_type": "markdown", "id": "88d6d29a", "metadata": {}, "source": [ "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`." 
] }, { "cell_type": "code", "execution_count": null, "id": "f6923b94", "metadata": {}, "outputs": [], "source": [ "input_point = np.array([[500, 375], [1125, 625]])\n", "input_label = np.array([1, 1])\n", "\n", "mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask" ] }, { "cell_type": "code", "execution_count": null, "id": "d98f96a1", "metadata": {}, "outputs": [], "source": [ "masks, scores, _ = model.predict_inst(\n", " inference_state,\n", " point_coords=input_point,\n", " point_labels=input_label,\n", " mask_input=mask_input[None, :, :],\n", " multimask_output=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "0ce8b82f", "metadata": {}, "outputs": [], "source": [ "masks.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "e06d5c8d", "metadata": {}, "outputs": [], "source": [ "show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)" ] }, { "cell_type": "markdown", "id": "c93e2087", "metadata": {}, "source": [ "To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied." 
] }, { "cell_type": "code", "execution_count": null, "id": "9a196f68", "metadata": {}, "outputs": [], "source": [ "input_point = np.array([[500, 375], [1125, 625]])\n", "input_label = np.array([1, 0])\n", "\n", "mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask" ] }, { "cell_type": "code", "execution_count": null, "id": "81a52282", "metadata": {}, "outputs": [], "source": [ "masks, scores, _ = model.predict_inst(\n", " inference_state,\n", " point_coords=input_point,\n", " point_labels=input_label,\n", " mask_input=mask_input[None, :, :],\n", " multimask_output=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "bfca709f", "metadata": {}, "outputs": [], "source": [ "show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)" ] }, { "cell_type": "markdown", "id": "41e2d5a9", "metadata": {}, "source": [ "## Specifying a specific object with a box" ] }, { "cell_type": "markdown", "id": "d61ca7ac", "metadata": {}, "source": [ "The model can also take a box as input, provided in xyxy format." ] }, { "cell_type": "code", "execution_count": null, "id": "8ea92a7b", "metadata": {}, "outputs": [], "source": [ "input_box = np.array([425, 600, 700, 875])" ] }, { "cell_type": "code", "execution_count": null, "id": "b35a8814", "metadata": {}, "outputs": [], "source": [ "masks, scores, _ = model.predict_inst(\n", " inference_state,\n", " point_coords=None,\n", " point_labels=None,\n", " box=input_box[None, :],\n", " multimask_output=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "3ffb4906", "metadata": {}, "outputs": [], "source": [ "show_masks(image, masks, scores, box_coords=input_box)" ] }, { "cell_type": "markdown", "id": "c1ed9f0a", "metadata": {}, "source": [ "## Combining points and boxes" ] }, { "cell_type": "markdown", "id": "8455d1c5", "metadata": {}, "source": [ "Points and boxes may be combined, just by including both types of prompts to the predictor. 
Here, a background point inside the box is used to select just the truck's tire instead of the entire wheel." ] }, { "cell_type": "code", "execution_count": null, "id": "90e2e547", "metadata": {}, "outputs": [], "source": [ "input_box = np.array([425, 600, 700, 875])\n", "input_point = np.array([[575, 750]])\n", "input_label = np.array([0])" ] }, { "cell_type": "code", "execution_count": null, "id": "6956d8c4", "metadata": {}, "outputs": [], "source": [ "masks, scores, logits = model.predict_inst(\n", "    inference_state,\n", "    point_coords=input_point,\n", "    point_labels=input_label,\n", "    box=input_box,\n", "    multimask_output=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "eb519a31", "metadata": {}, "outputs": [], "source": [ "show_masks(image, masks, scores, box_coords=input_box, point_coords=input_point, input_labels=input_label)" ] }, { "cell_type": "markdown", "id": "45ddbca3", "metadata": {}, "source": [ "## Batched prompt inputs" ] }, { "cell_type": "markdown", "id": "df6f18a0", "metadata": {}, "source": [ "`SAM3Image` can take multiple input prompts for the same image using the `predict_inst` method. For example, imagine we have several box outputs from an object detector."
] }, { "cell_type": "code", "execution_count": null, "id": "0a06681b", "metadata": {}, "outputs": [], "source": [ "input_boxes = np.array([\n", " [75, 275, 1725, 850],\n", " [425, 600, 700, 875],\n", " [1375, 550, 1650, 800],\n", " [1240, 675, 1400, 750],\n", "])" ] }, { "cell_type": "code", "execution_count": null, "id": "117521a3", "metadata": {}, "outputs": [], "source": [ "masks, scores, _ = model.predict_inst(\n", " inference_state,\n", " point_coords=None,\n", " point_labels=None,\n", " box=input_boxes,\n", " multimask_output=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "6a8f5d49", "metadata": {}, "outputs": [], "source": [ "masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W" ] }, { "cell_type": "code", "execution_count": null, "id": "c00c3681", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "plt.imshow(image)\n", "for mask in masks:\n", " show_mask(mask.squeeze(0), plt.gca(), random_color=True)\n", "for box in input_boxes:\n", " show_box(box, plt.gca())\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "b9a27b5d", "metadata": {}, "source": [ "## End-to-end batched inference\n", "If all prompts are available in advance, it is possible to run SAM 3 directly in an end-to-end fashion. This also allows batching over images." 
] }, { "cell_type": "code", "execution_count": null, "id": "d485f75b", "metadata": {}, "outputs": [], "source": [ "image1 = image  # truck.jpg from above\n", "image1_boxes = np.array([\n", "    [75, 275, 1725, 850],\n", "    [425, 600, 700, 875],\n", "    [1375, 550, 1650, 800],\n", "    [1240, 675, 1400, 750],\n", "])\n", "\n", "image2 = Image.open(f\"{sam3_root}/assets/images/groceries.jpg\")\n", "image2_boxes = np.array([\n", "    [450, 170, 520, 350],\n", "    [350, 190, 450, 350],\n", "    [500, 170, 580, 350],\n", "    [580, 170, 640, 350],\n", "])\n", "\n", "img_batch = [image1, image2]\n", "boxes_batch = [image1_boxes, image2_boxes]" ] },
{ "cell_type": "code", "execution_count": null, "id": "47932c99", "metadata": {}, "outputs": [], "source": [ "inference_state = processor.set_image_batch(img_batch)" ] },
{ "cell_type": "code", "execution_count": null, "id": "97af3c54", "metadata": {}, "outputs": [], "source": [ "masks_batch, scores_batch, _ = model.predict_inst_batch(\n", "    inference_state,\n", "    None,\n", "    None,\n", "    box_batch=boxes_batch,\n", "    multimask_output=False\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "id": "226df881", "metadata": {}, "outputs": [], "source": [ "# loop over `img` so that `image` (truck.jpg) is not overwritten\n", "for img, boxes, masks in zip(img_batch, boxes_batch, masks_batch):\n", "    plt.figure(figsize=(10, 10))\n", "    plt.imshow(img)\n", "    for mask in masks:\n", "        show_mask(mask.squeeze(0), plt.gca(), random_color=True)\n", "    for box in boxes:\n", "        show_box(box, plt.gca())" ] },
{ "cell_type": "markdown", "id": "46f30085", "metadata": {}, "source": [ "Similarly, we can define a batch of point prompts over a batch of images." ] },
{ "cell_type": "code", "execution_count": null, "id": "1ab929fc", "metadata": {}, "outputs": [], "source": [ "image1 = image  # truck.jpg from above\n", "image1_pts = np.array([\n", "    [[500, 375]],\n", "    [[650, 750]],\n", "])  # Bx1x2 where B corresponds to number of objects\n", "image1_labels = np.array([[1], [1]])\n", "\n", "image2_pts = np.array([\n", "    [[400, 300]],\n", "    [[630, 300]],\n", "])\n", "image2_labels = np.array([[1], [1]])\n", "\n", "pts_batch = [image1_pts, image2_pts]\n", "labels_batch = [image1_labels, image2_labels]" ] },
{ "cell_type": "code", "execution_count": null, "id": "848f8287", "metadata": {}, "outputs": [], "source": [ "masks_batch, scores_batch, _ = model.predict_inst_batch(inference_state, pts_batch, labels_batch, box_batch=None, multimask_output=True)\n", "\n", "# Select the best single mask per object\n", "best_masks = []\n", "for masks, scores in zip(masks_batch, scores_batch):\n", "    best_masks.append(masks[range(len(masks)), np.argmax(scores, axis=-1)])" ] },
{ "cell_type": "code", "execution_count": null, "id": "99b15c6c", "metadata": {}, "outputs": [], "source": [ "for img, points, labels, masks in zip(img_batch, pts_batch, labels_batch, best_masks):\n", "    plt.figure(figsize=(10, 10))\n", "    plt.imshow(img)\n", "    for mask in masks:\n", "        show_mask(mask, plt.gca(), random_color=True)\n", "    show_points(points, labels, plt.gca())" ] },
{ "cell_type": "code", "execution_count": null, "id": "4c1594a5-a0de-4477-91d4-db4504a78a83", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "74e3d07e-b0de-48a5-9d29-d639a0dbcdfc", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "d8b1de3a-a253-48ff-8a1c-d80742acbe86", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }