
Interactive segmentation

Context

SAM (Segment Anything Model) is an open-source model released by Meta that predicts segmentation masks from prompts such as boxes, keypoints and/or existing masks.

Pixano's web app integrates SAM to quickly annotate your images. Using it first requires pre-computing the image embeddings.

This tutorial will help you unlock this feature.

Create the image embeddings

Install the requirements

  1. Pip dependencies

Install the official SAM repository, onnx to export the model, and transformers to compute the image embeddings.

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install onnx transformers

  2. Download the model and export it to ONNX format.
git clone https://github.com/facebookresearch/segment-anything.git

cd segment-anything

wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_h.onnx

cp sam_h.onnx /path/to/pixano/models/
# Default is the models/ directory under the Pixano library
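
Before copying the file, you can optionally check that the export produced a well-formed ONNX graph. This is a minimal sketch that only relies on the onnx package installed above:

import onnx

# Load the exported model and run the structural checker.
# Raises onnx.checker.ValidationError if the graph is malformed.
onnx_model = onnx.load("sam_h.onnx")
onnx.checker.check_model(onnx_model)
print("sam_h.onnx looks valid")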

Create the embeddings

We will use the Health Images dataset defined in the Build and query a dataset tutorial.

  1. Load the model and the dataset.
import torch
from transformers import SamModel, SamProcessor
from pixano.datasets import Dataset
from pixano.features import Image
from pathlib import Path

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device=device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

dataset = Dataset(
    Path("./pixano_library/health_dataset"),
    media_dir=Path("./assets/")
)

images: list[Image] = dataset.get_data("image")
num_images = len(images)

print(num_images)

>>> 11
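
Before computing anything, it can be useful to open one of the images and confirm that the media paths resolve. The open call below is the same one used in the embedding loop of step 3:

first_image = images[0]
# Load the first image as a PIL image to check that media paths resolve
pil_image = first_image.open(
    media_dir=dataset.media_dir,
    output_type="image",
)
print(first_image.id, pil_image.size)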

  2. Create the SAM embeddings table.
from pixano.features import ViewEmbedding
from pixano.datasets.dataset_schema import SchemaRelation
from lancedb.pydantic import Vector

class SAMViewEmbedding(ViewEmbedding):
    vector: Vector(1048576)  # flattened 256 x 64 x 64 SAM image embedding

sam_table = dataset.create_table(
    name="sam_embedding",
    schema=SAMViewEmbedding,
    relation_item=SchemaRelation.ONE_TO_ONE,
    mode="overwrite"
)
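
As a quick check that the table was registered, you can list the dataset schemas (the same schemas mapping is used in the next step to build the records):

# The new embedding table should now appear in the dataset schema
print(list(dataset.schema.schemas.keys()))
# Expected to include "sam_embedding" alongside the existing tables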

  3. Compute the embeddings.
import shortuuid
from pixano.features import ViewRef

embeddings = []
for image in images:
    # Load the actual image from the media directory
    pil_image = image.open(
        media_dir=dataset.media_dir,
        output_type="image",
    ).convert("RGB")
    with torch.inference_mode():
        # Compute the embeddings
        inputs = processor(pil_image, return_tensors="pt").to(device=device)
        output = model.get_image_embeddings(inputs["pixel_values"])
    # Build a validated record for the sam_embedding table
    embedding = dataset.schema.schemas["sam_embedding"](
        id=shortuuid.uuid(),
        item_ref=image.item_ref,
        view_ref=ViewRef(id=image.id, name=image.table_name),
        vector=output.flatten().tolist(),
        shape=output.squeeze().shape,
    )
    embeddings.append(embedding)

# Flush to the table
dataset.add_data("sam_embedding", embeddings)
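
To make sure the embeddings were written, you can read the table back the same way the images were fetched earlier; assuming get_data returns one record per image:

# Read the embeddings back: there should be one record per image
stored = dataset.get_data("sam_embedding")
print(len(stored), num_images)   # both should be 11
print(len(stored[0].vector))     # 1048576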

Use the interactive segmentation

With the app

Now you are all set to use SAM: check out the "Smart segmentation tool" section of the Using the app guide!