Skip to content

pixano.app.routers.inference.zero_shot_detection

ZeroShotOutput(**data)

Bases: BaseModel

Zero-shot detection output: a predicted bounding box (`bbox`) paired with its `classification`.

Source code in pydantic/main.py
def __init__(self, /, **data: Any) -> None:
    """Build the model by validating the keyword arguments given to the constructor.

    The keyword arguments are passed as a dict to the model's validator, together
    with this instance as ``self_instance``.

    Raises [`ValidationError`][pydantic_core.ValidationError] when the input data
    cannot be validated into a valid model.

    `self` is declared positional-only so that `self` remains usable as a field name.
    """
    # pytest and some other tools leave functions that set `__tracebackhide__` out of tracebacks
    __tracebackhide__ = True
    outcome = self.__pydantic_validator__.validate_python(data, self_instance=self)
    if outcome is self:
        # Normal case: the validator handed back this very instance.
        return
    # The validator returned a different object; that value is not used here, so warn the caller.
    warnings.warn(
        'A custom validator is returning a value other than `self`.\n'
        "Returning anything other than `self` from a top level model validator isn't supported when validating via `__init__`.\n"
        'See the `model_validator` docs (https://docs.pydantic.dev/latest/concepts/validators/#model-validators) for more details.',
        stacklevel=2,
    )

call_image_zero_shot_detection(dataset_id, image, entity, classes, model, box_table_name, class_table_name, settings, box_threshold=0.3, text_threshold=0.2) async

Perform zero shot detection on an image.

Parameters:

Name Type Description Default
dataset_id Annotated[str, Body(embed=True)]

The ID of the dataset to use.

required
image Annotated[ViewModel, Body(embed=True)]

The image to use for detection.

required
entity Annotated[EntityModel, Body(embed=True)]

The entity to use for detection.

required
classes Annotated[list[str] | str, Body(embed=True)]

Labels to detect.

required
model Annotated[str, Body(embed=True)]

The name of the model to use.

required
box_table_name Annotated[str, Body(embed=True)]

The name of the table to use for boxes in dataset.

required
class_table_name Annotated[str, Body(embed=True)]

The name of the table to use for classifications in dataset.

required
settings Annotated[Settings, Depends(get_settings)]

App settings.

required
box_threshold Annotated[float, Body(embed=True)]

Box threshold for detection in the image.

0.3
text_threshold Annotated[float, Body(embed=True)]

Text threshold for detection in the image.

0.2

Returns:

Type Description
list[ZeroShotOutput]

The predicted bboxes and classifications.

Source code in pixano/app/routers/inference/zero_shot_detection.py
@router.post(
    "/image",
    response_model=list[ZeroShotOutput],
)
async def call_image_zero_shot_detection(
    dataset_id: Annotated[str, Body(embed=True)],
    image: Annotated[ViewModel, Body(embed=True)],
    entity: Annotated[EntityModel, Body(embed=True)],
    classes: Annotated[list[str] | str, Body(embed=True)],
    model: Annotated[str, Body(embed=True)],
    box_table_name: Annotated[str, Body(embed=True)],
    class_table_name: Annotated[str, Body(embed=True)],
    settings: Annotated[Settings, Depends(get_settings)],
    box_threshold: Annotated[float, Body(embed=True)] = 0.3,
    text_threshold: Annotated[float, Body(embed=True)] = 0.2,
) -> list[ZeroShotOutput]:
    """Run zero shot detection on a single image and wrap the predictions as API models.

    Args:
        dataset_id: The ID of the dataset to use.
        image: The image to use for detection.
        entity: The entity to use for detection.
        classes: Labels to detect.
        model: The name of the model to use.
        box_table_name: The name of the table to use for boxes in dataset.
        class_table_name: The name of the table to use for classifications in dataset.
        settings: App settings.
        box_threshold: Box threshold for detection in the image.
        text_threshold: Text threshold for detection in the image.

    Returns:
        One `ZeroShotOutput` per predicted (bbox, classification) pair.

    Raises:
        HTTPException: With status 400 when the schema of the image's table is not
            an image schema, or when the detection call raises any exception.
    """
    dataset = get_dataset(dataset_id=dataset_id, dir=settings.library_dir, media_dir=settings.media_dir)
    client = get_client_from_settings(settings=settings)

    # Reject the request early when the referenced table does not hold images.
    view_schema = dataset.schema.schemas[image.table_info.name]
    if not is_image(view_schema):
        raise HTTPException(status_code=400, detail="Image must be an image.")

    # Turn the request models into dataset rows before calling the detector.
    img_row: Image = image.to_row(dataset)
    ent_row: Entity = entity.to_row(dataset)
    model_source = get_model_source(dataset=dataset, model=model)

    try:
        predictions: list[tuple[BBox, Classification]] = await image_zero_shot_detection(
            client=client,
            source=model_source,
            media_dir=settings.media_dir,
            image=img_row,
            entity=ent_row,
            classes=classes,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
    except Exception as err:
        # Any failure of the detection call is reported to the caller as a 400.
        raise HTTPException(status_code=400, detail=str(err)) from err

    # Each predicted row is wrapped with the caller-chosen table name for its kind.
    return [
        ZeroShotOutput(
            bbox=AnnotationModel.from_row(
                row=box_row,
                table_info=TableInfo(name=box_table_name, group="annotations", base_schema="BBox"),
            ),
            classification=AnnotationModel.from_row(
                row=class_row,
                table_info=TableInfo(name=class_table_name, group="annotations", base_schema="Classification"),
            ),
        )
        for box_row, class_row in predictions
    ]