Skip to content

pixano_inference.impls.sam2.image

SAM2 image segmentation model.

Sam2ImageModel(config)

Bases: SegmentationModel

Native Ray Serve model for SAM2 image mask generation.

model_params contract:

  • path (str, required): HuggingFace model ID or local checkpoint path.
  • torch_dtype (str, default "bfloat16"): Torch dtype for autocast.
  • compile (bool, default True): Whether to torch.compile the model.

Any remaining keys are forwarded to build_sam2 / build_sam2_hf.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `ModelDeploymentConfig` | Model deployment configuration. | *required* |
Source code in pixano_inference/impls/sam2/image.py
def __init__(self, config: ModelDeploymentConfig) -> None:
    """Create the deployment wrapper.

    The SAM2 predictor and the autocast dtype are populated later by
    ``load_model``; until then both attributes are ``None``.

    Args:
        config: Model deployment configuration.
    """
    super().__init__(config)
    # Filled in by load_model(); kept as None so unload() is safe to call
    # even before a model was ever loaded.
    self._torch_dtype: Any = None
    self._predictor: Any = None

metadata property

Model metadata including path and dtype.

load_model()

Load the SAM2 image predictor.

Source code in pixano_inference/impls/sam2/image.py
def load_model(self) -> None:
    """Load the SAM2 image predictor.

    Reads the ``model_params`` contract (``path``, ``torch_dtype``,
    ``compile``) from the deployment config, builds the SAM2 model either
    from a local checkpoint or from the HuggingFace Hub, optionally
    ``torch.compile``s it, and wraps it in a ``SAM2ImagePredictor``.

    Raises:
        ValueError: If the required ``path`` model parameter is missing
            or ``None``.
    """
    from pixano_inference.utils.package import assert_sam2_installed

    assert_sam2_installed()

    import torch
    from sam2.build_sam import build_sam2, build_sam2_hf
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # Copy so the pops below do not mutate the shared config mapping.
    params = dict(self._config.model_params)
    path = params.pop("path", None)
    if path is None:
        # The contract marks "path" as required; fail loudly with a clear
        # message instead of a bare KeyError (or a confusing downstream
        # failure when path is explicitly None).
        raise ValueError("model_params['path'] is required for Sam2ImageModel.")
    torch_dtype_str = params.pop("torch_dtype", "bfloat16")
    compile_model = params.pop("compile", True)

    device = resolve_device(self._config)
    self._torch_dtype = resolve_torch_dtype(torch_dtype_str)

    # A path that exists on disk is treated as a local checkpoint;
    # anything else is assumed to be a HuggingFace model ID. Remaining
    # params are forwarded to the SAM2 builders untouched.
    if Path(path).exists():
        model = build_sam2(ckpt_path=path, mode="eval", device=device, **params)
    else:
        model = build_sam2_hf(model_id=path, mode="eval", device=device, **params)

    if compile_model:
        model = torch.compile(model)

    self._predictor = SAM2ImagePredictor(model)
    logger.info("Sam2ImageModel '%s' loaded on %s (dtype=%s)", self.model_name, device, torch_dtype_str)

predict(input)

Run SAM2 image mask generation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `SegmentationInput` | Segmentation input with image, prompts, and options. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `SegmentationOutput` | Segmentation output with masks, scores, and optionally embeddings. |

Source code in pixano_inference/impls/sam2/image.py
def predict(self, input: SegmentationInput) -> SegmentationOutput:
    """Run SAM2 image mask generation.

    Decodes the input image, optionally resets the predictor and restores
    cached embeddings, runs the SAM2 prompt-based prediction under
    autocast, and packages masks/scores (plus optional embeddings and
    low-resolution logits) into a SegmentationOutput.

    Args:
        input: Segmentation input with image, prompts, and options.

    Returns:
        Segmentation output with masks, scores, and optionally embeddings.

    Raises:
        ValueError: If multimask output is requested with a count other
            than 3, or if only one of image_embedding /
            high_resolution_features is provided on reset.
    """
    import torch

    from pixano_inference.schemas.nd_array import NDArrayFloat
    from pixano_inference.schemas.rle import CompressedRLE
    from pixano_inference.utils.media import convert_string_to_image

    pil_image = convert_string_to_image(input.image)
    # Validates prompt consistency before touching the predictor;
    # presumably checks points/labels/boxes agree in count — confirm in
    # validate_prompts.
    validate_prompts(input.points, input.labels, input.boxes)

    # SAM2's multimask head is fixed at 3 hypotheses; reject other counts.
    if input.multimask_output and input.num_multimask_outputs != 3:
        raise ValueError("The number of multimask outputs is not configurable for SAM2 and must be 3.")

    # Handle predictor reset and optional embedding restoration
    if input.reset_predictor:
        self._predictor.reset_predictor()
        # Either both cached features are supplied (restore them and skip
        # re-encoding the image) or neither is.
        if input.image_embedding is not None and input.high_resolution_features is not None:
            self._set_image_embeddings(pil_image, input.image_embedding, input.high_resolution_features)
        elif input.image_embedding is not None or input.high_resolution_features is not None:
            raise ValueError("Both image_embedding and high_resolution_features must be provided.")

    # Prepare numpy inputs
    input_points: np.ndarray | None = None
    input_labels: np.ndarray | None = None
    input_boxes: np.ndarray | None = None
    input_mask: np.ndarray | None = None

    if input.points is not None:
        # Pads ragged per-prompt point lists into rectangular arrays.
        input_points, input_labels = pad_points_and_labels(input.points, input.labels)
    if input.boxes is not None:
        input_boxes = np.array(input.boxes, dtype=np.int32)
    if input.mask_input is not None:
        input_mask = input.mask_input.to_numpy()

    with torch.inference_mode():
        with torch.autocast(self._predictor.device.type, dtype=self._torch_dtype):
            # Only (re-)encode the image when the predictor has no image
            # set, so repeated prompts against the same image reuse the
            # cached features.
            if not self._predictor._is_image_set:
                self._predictor.set_image(pil_image)

            masks, scores, low_res_masks = self._predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_boxes,
                mask_input=input_mask,
                multimask_output=input.multimask_output,
                # Keep full-resolution masks binary for rendering/saving.
                # The third SAM2 output already contains the low-res logits
                # needed for iterative refinement.
                return_logits=False,
            )

    # Ensure 4D: [num_prompts, num_masks, H, W]
    # SAM2 returns 3D for a single prompt and 2D for a single prompt with
    # a single mask; normalize so the encoding loop below is uniform.
    if len(masks.shape) == 3:
        masks = np.expand_dims(masks, 0)
        scores = np.expand_dims(scores, 0)
    elif len(masks.shape) == 2:
        masks = np.expand_dims(masks, (0, 1))
        scores = np.expand_dims(scores, (0, 1))

    # Build output
    # RLE-compress each mask: list of prompts, each a list of mask RLEs.
    out_masks = [
        [CompressedRLE.from_mask(mask.astype(np.uint8)) for mask in prediction_masks] for prediction_masks in masks
    ]
    out_scores = NDArrayFloat.from_numpy(scores)

    out_image_embedding: NDArrayFloat | None = None
    out_high_resolution_features: list[NDArrayFloat] | None = None
    out_mask_logits: NDArrayFloat | None = None

    if input.return_image_embedding:
        # NOTE(review): reaches into the predictor's private _features
        # cache; this depends on SAM2 internals — verify on version bumps.
        embed = self._predictor._features["image_embed"]
        # Cast to float32 before serialization (autocast may have produced
        # bfloat16, which does not round-trip through plain Python floats).
        embed_list = embed.to(torch.float32).flatten().tolist()
        # shape[1:] drops the leading batch dimension.
        out_image_embedding = NDArrayFloat(values=embed_list, shape=list(embed.shape[1:]))

        hr_feats = self._predictor._features["high_res_feats"]
        out_high_resolution_features = [
            NDArrayFloat(
                values=feat.to(torch.float32).flatten().tolist(),
                shape=list(feat.shape[1:]),
            )
            for feat in hr_feats
        ]

    if input.return_logits:
        # Normalize a single 2D logits map to 3D for consistent encoding.
        if len(low_res_masks.shape) == 2:
            low_res_masks = np.expand_dims(low_res_masks, 0)
        out_mask_logits = NDArrayFloat.from_numpy(low_res_masks)

    return SegmentationOutput(
        masks=out_masks,
        scores=out_scores,
        image_embedding=out_image_embedding,
        high_resolution_features=out_high_resolution_features,
        mask_logits=out_mask_logits,
    )

unload()

Free resources.

Source code in pixano_inference/impls/sam2/image.py
def unload(self) -> None:
    """Release the predictor and reclaim memory best-effort."""
    predictor = self._predictor
    if predictor is not None:
        # Drop both references before collecting so the predictor (and
        # its model weights) become unreachable.
        self._predictor = None
        del predictor
    gc.collect()
    try:
        import torch

        torch.cuda.empty_cache()
    except Exception:
        # torch may be absent or CUDA unavailable; cleanup stays best-effort.
        pass