SAM2 Video
Define the visualization utilities (from SAM2)
import matplotlib.pyplot as plt
import numpy as np
import torch
np.random.seed(3)
def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
ax.scatter(
neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
# select the device for computation
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"using device: {device}")
Instantiate the model
import torch
from pixano_inference.providers.sam2 import Sam2Provider
provider = Sam2Provider()
model = provider.load_model(
name="sam",
task="image_mask_generation",
device=torch.device("cuda") if torch.cuda.is_available() else "cpu",
path="facebook/sam2-hiera-tiny",
)
Call the model
from pathlib import Path
obj_ids = [0, 2]
frame_indexes = [0, 0]
points = [[[210, 350]], [[400, 500]]]
labels = [[1], [1]]
output = model.video_mask_generation(
sorted([f for f in Path("./docs/assets/examples/sam2/bedroom").glob("**/*") if f.is_file()]),
objects_ids=obj_ids,
frame_indexes=frame_indexes,
points=points,
labels=labels,
propagate=True,
)
Display the result
import os
from PIL import Image
video_dir = Path("./docs/assets/examples/sam2/bedroom/")
# scan all the JPEG frame names in this directory
frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))