Skip to content

pixano_inference.providers.sam2

Provider for the SAM2 model.

Sam2Provider(**kwargs)

Bases: ModelProvider

Provider for the SAM2 model.

Source code in pixano_inference/providers/sam2.py
def __init__(self, **kwargs):
    """Initialize the SAM2 provider.

    Calls ``assert_sam2_installed()`` first and then forwards every keyword
    argument, unchanged, to the parent class initializer.

    Args:
        kwargs: Keyword arguments passed through to the parent initializer.
    """
    assert_sam2_installed()
    super().__init__(**kwargs)

image_mask_generation(request, model, *args, **kwargs)

Generate a mask from the image.

Parameters:

Name Type Description Default
request ImageMaskGenerationRequest

Request for the generation.

required
model Sam2Model

Model to use for the generation.

required
args Any

Additional arguments.

()
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
ImageMaskGenerationOutput

Output of the generation.

Source code in pixano_inference/providers/sam2.py
def image_mask_generation(
    self,
    request: ImageMaskGenerationRequest,
    model: Sam2Model,  # type: ignore[override]
    *args: Any,
    **kwargs: Any,
) -> ImageMaskGenerationOutput:
    """Generate a mask from the image.

    If the request carries both a precomputed image embedding and high
    resolution features, they are converted to tensors and set on the model
    before generation. The predictor is reset once generation has been
    attempted, whether it succeeded or raised.

    Args:
        request: Request for the generation.
        model: Model to use for the generation.
        args: Additional arguments (not used by this method).
        kwargs: Additional keyword arguments (not used by this method).

    Returns:
        Output of the generation.

    Raises:
        ValueError: If only one of ``image_embedding`` and
            ``high_resolution_features`` is provided.
    """
    request_input = request.to_input()
    image = convert_string_to_image(request_input.image)

    # Validate and convert the optional precomputed features before touching
    # the predictor, so a bad request leaves the model state untouched.
    use_precomputed = False
    image_embedding = None
    high_resolution_features = None
    if request_input.image_embedding is not None and request_input.high_resolution_features is not None:
        image_embedding = vector_to_tensor(request_input.image_embedding)
        high_resolution_features = [vector_to_tensor(v) for v in request_input.high_resolution_features]
        use_precomputed = True
    elif request_input.image_embedding is not None or request_input.high_resolution_features is not None:
        raise ValueError("Both image_embedding and high_resolution_features must be provided.")

    # The image is passed as a decoded object and the embeddings are set on
    # the model directly, so none of them go through the dumped keyword args.
    model_input = request_input.model_dump(exclude={"image", "image_embedding", "high_resolution_features"})
    model_input["image"] = image

    try:
        if use_precomputed:
            model.set_image_embeddings(image, image_embedding, high_resolution_features)
        return model.image_mask_generation(**model_input)
    finally:
        # Always reset, even on failure, so that whatever was set on the
        # predictor for this request is not reused by the next one.
        model.predictor.reset_predictor()

load_model(name, task, device, path=None, processor_config={}, config={})

Load the model.

Parameters:

Name Type Description Default
name str

Name of the model.

required
task Task | str

Task of the model.

required
device device

Device to use for the model.

required
path Path | str | None

Path to the model.

None
processor_config dict

Processor configuration.

{}
config dict

Configuration for the model.

{}

Returns:

Type Description
Sam2Model

The loaded model.

Source code in pixano_inference/providers/sam2.py
def load_model(
    self,
    name: str,
    task: Task | str,
    device: "torch.device",
    path: Path | str | None = None,
    processor_config: dict = {},
    config: dict = {},
) -> Sam2Model:
    """Load the model.

    For the image mask generation task a SAM2 model is built, compiled with
    ``torch.compile`` and wrapped in a ``SAM2ImagePredictor``. For the video
    mask generation task a video predictor is built (with
    ``vos_optimized=True``) and compiled. In both cases ``path`` is used as a
    checkpoint path when it exists on disk and as a model id for the ``*_hf``
    builders otherwise.

    Args:
        name: Name of the model.
        task: Task of the model, as a ``Task`` or its string form.
        device: Device to use for the model.
        path: Path to the model checkpoint, or a model id when no such path
            exists on disk.
        processor_config: Processor configuration. Not used by this provider.
        config: Configuration for the model. Forwarded as keyword arguments
            to the SAM2 builder; its ``"torch_dtype"`` entry (default
            ``"bfloat16"``) is also passed to the returned model.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``task`` is neither image nor video mask generation.
    """
    # Work on a private copy: the returned model keeps a reference to its
    # config, and the signature's default dict is shared between calls.
    config = dict(config) if config is not None else {}

    task = str_to_task(task) if isinstance(task, str) else task
    if task == ImageTask.MASK_GENERATION:
        if path is not None and Path(path).exists():
            model = build_sam2(ckpt_path=path, mode="eval", device=device, **config)
        else:
            model = build_sam2_hf(model_id=path, mode="eval", device=device, **config)
        model = torch.compile(model)
        predictor = SAM2ImagePredictor(model)
    elif task == VideoTask.MASK_GENERATION:
        if path is not None and Path(path).exists():
            predictor = build_sam2_video_predictor(
                ckpt_path=path, mode="eval", device=device, vos_optimized=True, **config
            )
        else:
            predictor = build_sam2_video_predictor_hf(
                model_id=path, mode="eval", device=device, vos_optimized=True, **config
            )
        predictor = torch.compile(predictor)
    else:
        raise ValueError(f"Invalid task '{task}' for the SAM2 provider.")

    our_model = Sam2Model(
        name=name,
        provider="sam2",
        predictor=predictor,
        torch_dtype=config.get("torch_dtype", "bfloat16"),
        config=config,
    )

    return our_model

video_mask_generation(request, model, *args, **kwargs)

Generate masks from the video.

Parameters:

Name Type Description Default
request VideoMaskGenerationRequest

Request for the generation.

required
model Sam2Model

Model to use for the generation.

required
args Any

Additional arguments.

()
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
VideoMaskGenerationResponse

Response of the generation.

Source code in pixano_inference/providers/sam2.py
def video_mask_generation(
    self,
    request: VideoMaskGenerationRequest,
    model: Sam2Model,  # type: ignore[override]
    *args: Any,
    **kwargs: Any,
) -> VideoMaskGenerationResponse:
    """Run video mask generation for a request.

    The request is turned into its input model and dumped to a dict, the
    ``"video"`` entry is replaced by the result of
    ``convert_string_video_to_bytes_or_path``, and the dict is forwarded as
    keyword arguments to the model.

    Args:
        request: Request for the generation.
        model: Model to use for the generation.
        args: Additional arguments.
        kwargs: Additional keyword arguments.

    Returns:
        Response of the generation.
    """
    payload = request.to_input().model_dump()
    # Swap the serialized video for the form the model consumes.
    video = convert_string_video_to_bytes_or_path(payload["video"])
    payload["video"] = video
    return model.video_mask_generation(**payload)