Skip to content

pixano.inference.zero_shot_detection

image_zero_shot_detection(provider, media_dir, image, entity, source, classes, box_threshold=0.5, text_threshold=0.5, **provider_kwargs) async

Image zero shot detection task.

Parameters:

Name Type Description Default
provider InferenceProvider

Inference provider.

required
media_dir Path

Media directory.

required
image Image

Image to run zero-shot detection on.

required
entity Entity

Entity associated with the image.

required
source Source

The source referring to the model.

required
classes list[str] | str

List of classes to detect in the image.

required
box_threshold float

Box threshold for detection in the image.

0.5
text_threshold float

Text threshold for detection in the image.

0.5
provider_kwargs Any

Additional kwargs for the provider.

{}

Returns:

Type Description
list[tuple[BBox, Classification]]

List of BBoxes and Classifications detected in the image with respect to classes and threshold values.

Source code in pixano/inference/zero_shot_detection.py
async def image_zero_shot_detection(
    provider: InferenceProvider,
    media_dir: Path,
    image: Image,
    entity: Entity,
    source: Source,
    classes: list[str] | str,
    box_threshold: float = 0.5,
    text_threshold: float = 0.5,
    **provider_kwargs: Any,
) -> list[tuple[BBox, Classification]]:
    """Run zero-shot detection on an image and wrap each detection as annotations.

    Args:
        provider: Inference provider that performs the detection.
        media_dir: Media directory, used to open the image when ``image.url`` is not recognized as a URL.
        image: Image to run the detection on.
        entity: Entity referenced by every produced annotation.
        source: Source referring to the model; its name is sent as the model and its id is
            referenced by every produced annotation.
        classes: Class or list of classes to detect in the image.
        box_threshold: Box threshold forwarded to the provider.
        text_threshold: Text threshold forwarded to the provider.
        provider_kwargs: Additional kwargs for the provider.

    Returns:
        One ``(BBox, Classification)`` pair per detection returned by the provider.

    Raises:
        ValueError: If the provider returns boxes, scores and classes of different lengths.
    """
    # Images whose url is recognized by _is_url are passed as-is; otherwise the image is
    # opened from media_dir in "base64" mode.
    if _is_url(image.url):
        image_request = image.url
    else:
        image_request = image.open(media_dir, "base64")

    request = ImageZeroShotDetectionInput(
        image=image_request,
        model=source.name,
        classes=classes,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    response = await provider.image_zero_shot_detection(request, **provider_kwargs)

    # Provider metadata is spread last so its keys take precedence over the two defaults.
    raw_metadata = {
        "timestamp": response.timestamp.isoformat(),
        "processing_time": response.processing_time,
        **response.metadata,
    }
    inference_metadata = jsonable_encoder(raw_metadata)

    detections = zip(
        response.data.boxes,
        response.data.scores,
        response.data.classes,
        strict=True,  # a length mismatch between the three sequences is an error
    )

    annotations: list[tuple[BBox, Classification]] = []
    for coords, score, label in detections:
        # Fields shared by the box and its classification; the refs are rebuilt for each detection.
        shared_fields = {
            "item_ref": image.item_ref,
            "view_ref": ViewRef(name=image.table_name, id=image.id),
            "entity_ref": EntityRef(name=entity.table_name, id=entity.id),
            "source_ref": SourceRef(id=source.id),
            "inference_metadata": inference_metadata,
        }
        bbox = BBox(
            id=shortuuid.uuid(),
            coords=coords,
            format="xyxy",
            is_normalized=False,
            confidence=score,
            **shared_fields,
        )
        classification = Classification(
            id=shortuuid.uuid(),
            labels=[label],
            confidences=[score],
            **shared_fields,
        )
        annotations.append((bbox, classification))

    return annotations