Skip to content

pixano_inference.client

Pixano inference client.

InferenceTooLongError

Bases: Exception

Exception raised when inference took too long.

PixanoInferenceClient(**data)

Bases: Settings

Pixano Inference Client.

Source code in pixano_inference/settings.py
def __init__(self, **data: Any):
    """Initialize the settings.

    Fills in ``num_cpus`` and ``num_gpus`` when the caller did not provide
    them: CPU count from the host, GPU count from CUDA when torch is
    installed and available, otherwise 0.
    """
    if "num_cpus" not in data:
        data["num_cpus"] = os.cpu_count()
    if "num_gpus" not in data:
        gpu_count = 0
        # Short-circuits exactly like the nested checks: torch is only
        # touched when it is installed.
        if is_torch_installed() and torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
        data["num_gpus"] = gpu_count

    super().__init__(**data)

_rest_call(path, method, timeout=60, **kwargs) async

Perform a REST call to the pixano inference server.

Source code in pixano_inference/client.py
async def _rest_call(
    self,
    path: str,
    method: Literal["GET", "POST", "PUT", "DELETE"],
    timeout: int = 60,
    **kwargs,
) -> Response:
    """Perform a REST call to the pixano inference server.

    Args:
        path: Endpoint path, with or without a leading slash.
        method: HTTP verb to use.
        timeout: Timeout in seconds for the httpx client.
        kwargs: Extra arguments forwarded to the httpx request function.

    Returns:
        The server response (guaranteed successful, otherwise an error is raised).
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        # Dispatch table instead of a match statement: one verb, one client method.
        verb_to_fn = {
            "GET": client.get,
            "POST": client.post,
            "PUT": client.put,
            "DELETE": client.delete,
        }
        request_fn = verb_to_fn.get(method)
        if request_fn is None:
            raise ValueError(
                f"Invalid REST call method. Expected one of ['GET', 'POST', 'PUT', 'DELETE'], but got "
                f"'{method}'."
            )

        # Drop a single leading slash so joining with the base URL does not
        # produce a double slash.
        normalized_path = path[1:] if path.startswith("/") else path
        response = await request_fn(f"{self.url}/{normalized_path}", **kwargs)
        raise_if_error(response)

        return response

connect(url) staticmethod

Connect to pixano inference.

Parameters:

Name Type Description Default
url str

The URL of the pixano inference server.

required
Source code in pixano_inference/client.py
@staticmethod
def connect(url: str) -> "PixanoInferenceClient":
    """Connect to pixano inference.

    Args:
        url: The URL of the pixano inference server.

    Returns:
        A client configured with the server's current settings.

    Raises:
        requests.HTTPError: If the settings endpoint returns an error status.
    """
    # Normalize so a trailing slash does not produce "//" in built URLs.
    url = url.rstrip("/")
    # Without a timeout, requests.get can hang forever on an unreachable
    # server; bound the wait instead.
    response = requests.get(f"{url}/app/settings/", timeout=60)
    # Fail with a clear HTTP error rather than a confusing JSON/validation
    # error when the server answered with a non-2xx status.
    response.raise_for_status()
    settings = Settings.model_validate(response.json())
    client = PixanoInferenceClient(url=url, **settings.model_dump())
    return client

delete(path, **kwargs) async

Perform a DELETE request to the pixano inference server.

Parameters:

Name Type Description Default
path str

The path of the request.

required
kwargs Any

The keyword arguments to pass to the request or httpx client.

{}
Source code in pixano_inference/client.py
async def delete(self, path: str, **kwargs: Any) -> Response:
    """Perform a DELETE request to the pixano inference server.

    Args:
        path: The path of the request.
        kwargs: The keyword arguments to pass to the request or httpx client.

    Returns:
        The server response.
    """
    response = await self._rest_call(path=path, method="DELETE", **kwargs)
    return response

delete_model(model_name) async

Delete a model.

Parameters:

Name Type Description Default
model_name str

The name of the model.

required
Source code in pixano_inference/client.py
async def delete_model(self, model_name: str) -> None:
    """Delete a model.

    Args:
        model_name: The name of the model.
    """
    route = f"providers/model/{model_name}"
    await self.delete(route)

get(path, **kwargs) async

Perform a GET request to the pixano inference server.

Parameters:

Name Type Description Default
path str

The path of the request.

required
kwargs Any

The keyword arguments to pass to the request or httpx client.

{}
Source code in pixano_inference/client.py
async def get(self, path: str, **kwargs: Any) -> Response:
    """Perform a GET request to the pixano inference server.

    Args:
        path: The path of the request.
        kwargs: The keyword arguments to pass to the request or httpx client.

    Returns:
        The server response.
    """
    response = await self._rest_call(path=path, method="GET", **kwargs)
    return response

get_settings() async

Get the settings for the pixano inference server.

Source code in pixano_inference/client.py
async def get_settings(self) -> Settings:
    """Get the settings for the pixano inference server.

    Returns:
        The server settings.
    """
    # _rest_call() (via get()) already raises on error responses, so the
    # previous extra raise_if_error() call here was redundant.
    response = await self.get("app/settings/")
    return Settings(**response.json())

image_mask_generation(request=None, poll_interval=0.1, timeout=60, task_id=None, asynchronous=False) async

image_mask_generation(request: ImageMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: str, asynchronous: Literal[True]) -> ImageMaskGenerationResponse | CeleryTask
image_mask_generation(request: ImageMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: None, asynchronous: Literal[True]) -> CeleryTask
image_mask_generation(request: ImageMaskGenerationRequest, poll_interval: float, timeout: float, task_id: str | None, asynchronous: Literal[False]) -> ImageMaskGenerationResponse
image_mask_generation(request: ImageMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: str | None, asynchronous: bool) -> ImageMaskGenerationResponse | CeleryTask

Perform an inference to perform image mask generation.

Source code in pixano_inference/client.py
async def image_mask_generation(
    self,
    request: ImageMaskGenerationRequest | None = None,
    poll_interval: float = 0.1,
    timeout: float = 60,
    task_id: str | None = None,
    asynchronous: bool = False,
) -> ImageMaskGenerationResponse | CeleryTask:
    """Perform an inference to perform image mask generation.

    Args:
        request: The request of the model.
        poll_interval: Waiting time between polls for synchronous requests.
        timeout: Time to wait for a response for synchronous requests.
        task_id: The id of the task to poll for results.
        asynchronous: If True, return a CeleryTask (or poll results when a
            task id is provided) instead of waiting for the response.

    Returns:
        The mask generation response, or a CeleryTask for asynchronous calls.
    """
    # Delegate to the generic inference entry point with the task route.
    call_kwargs = dict(
        route="tasks/image/mask_generation/",
        request=request,
        response_type=ImageMaskGenerationResponse,
        poll_interval=poll_interval,
        timeout=timeout,
        task_id=task_id,
        asynchronous=asynchronous,
    )
    return await self.inference(**call_kwargs)

image_zero_shot_detection(request=None, poll_interval=0.1, timeout=60, task_id=None, asynchronous=False) async

image_zero_shot_detection(request: ImageZeroShotDetectionRequest | None, poll_interval: float, timeout: float, task_id: str, asynchronous: Literal[True]) -> ImageZeroShotDetectionResponse | CeleryTask
image_zero_shot_detection(request: ImageZeroShotDetectionRequest | None, poll_interval: float, timeout: float, task_id: None, asynchronous: Literal[True]) -> CeleryTask
image_zero_shot_detection(request: ImageZeroShotDetectionRequest, poll_interval: float, timeout: float, task_id: str | None, asynchronous: Literal[False]) -> ImageZeroShotDetectionResponse
image_zero_shot_detection(request: ImageZeroShotDetectionRequest | None, poll_interval: float, timeout: float, task_id: str | None, asynchronous: bool) -> ImageZeroShotDetectionResponse | CeleryTask

Perform an inference to perform image zero-shot detection.

Source code in pixano_inference/client.py
async def image_zero_shot_detection(
    self,
    request: ImageZeroShotDetectionRequest | None = None,
    poll_interval: float = 0.1,
    timeout: float = 60,
    task_id: str | None = None,
    asynchronous: bool = False,
) -> ImageZeroShotDetectionResponse | CeleryTask:
    """Perform an inference to perform image zero-shot detection.

    Args:
        request: The request of the model.
        poll_interval: Waiting time between polls for synchronous requests.
        timeout: Time to wait for a response for synchronous requests.
        task_id: The id of the task to poll for results.
        asynchronous: If True, return a CeleryTask (or poll results when a
            task id is provided) instead of waiting for the response.

    Returns:
        The detection response, or a CeleryTask for asynchronous calls.
    """
    # NOTE(review): docstring previously said "video mask generation" —
    # a copy-paste error from video_mask_generation.
    return await self.inference(
        route="tasks/image/zero_shot_detection/",
        request=request,
        response_type=ImageZeroShotDetectionResponse,
        poll_interval=poll_interval,
        timeout=timeout,
        task_id=task_id,
        asynchronous=asynchronous,
    )

inference(route, request=None, response_type=None, poll_interval=0.1, timeout=60.0, task_id=None, asynchronous=False) async

inference(route: str, request: BaseRequest | None, response_type: type[BaseResponse] | None, poll_interval: float, timeout: float, task_id: str, asynchronous: Literal[True]) -> BaseResponse | CeleryTask
inference(route: str, request: BaseRequest | None, response_type: type[BaseResponse] | None, poll_interval: float, timeout: float, task_id: None, asynchronous: Literal[True]) -> CeleryTask
inference(route: str, request: BaseRequest, response_type: type[BaseResponse], poll_interval: float, timeout: float, task_id: str | None, asynchronous: Literal[False]) -> BaseResponse
inference(route: str, request: BaseRequest | None, response_type: type[BaseResponse] | None, poll_interval: float, timeout: float, task_id: str | None, asynchronous: bool) -> BaseResponse | CeleryTask

Perform a POST request to the pixano inference server.

Parameters:

Name Type Description Default
route str

The route of the request.

required
request BaseRequest | None

The request of the model.

None
response_type type[BaseResponse] | None

The type of the response.

None
poll_interval float

waiting time between subsequent requests to server to retrieve task results for synchronous requests.

0.1
timeout float

Time to wait for response for synchronous requests. If reached, the request will be aborted.

60.0
task_id str | None

The id of the task to poll for results.

None
asynchronous bool

If True then the function will be called asynchronously and returns a CeleryTask object or poll results when task id is provided.

False

Returns:

Type Description
BaseResponse | CeleryTask

A response from the pixano inference server.

Source code in pixano_inference/client.py
async def inference(
    self,
    route: str,
    request: BaseRequest | None = None,
    response_type: type[BaseResponse] | None = None,
    poll_interval: float = 0.1,
    timeout: float = 60.0,
    task_id: str | None = None,
    asynchronous: bool = False,
) -> BaseResponse | CeleryTask:
    """Perform a POST request to the pixano inference server.

    Args:
        route: The route of the request.
        request: The request of the model.
        response_type: The type of the response.
        poll_interval: waiting time between subsequent requests to server to retrieve task results for synchronous
            requests.
        timeout: Time to wait for response for synchronous requests. If reached, the request will be aborted.
        task_id: The id of the task to poll for results.
        asynchronous: If True then the function will be called asynchronously and returns a CeleryTask object or
            poll results when task id is provided.

    Returns:
        A response from the pixano inference server.
    """
    # Argument combinations are interdependent (e.g. synchronous calls need
    # both request and response_type); delegate validation to the helpers.
    _validate_task_id_asynchronous_request_response_type(
        task_id=task_id, asynchronous=asynchronous, request=request, response_type=response_type
    )
    _validate_poll_interval_timeout(poll_interval=poll_interval, timeout=timeout)

    # Submit a new task unless we are only polling an existing asynchronous
    # one (asynchronous=True with a task_id). celery_task is therefore only
    # bound in the branches that actually use it.
    if not asynchronous or task_id is None:
        request = cast(BaseRequest, request)
        celery_response: Response = await self.post(route, json=request.model_dump())
        celery_task: CeleryTask = CeleryTask.model_construct(**celery_response.json())

    # Asynchronous calls
    if asynchronous and task_id is None:
        # Fire-and-forget: hand back the task handle immediately.
        return celery_task
    elif asynchronous and task_id is not None:
        # Single poll of a previously submitted task.
        response_type = cast(type[BaseResponse], response_type)
        has_slash = route.endswith("/")
        task_route = route + f"{'' if has_slash else '/'}{task_id}"
        response: dict[str, Any] = (await self.get(task_route)).json()
        if response["status"] == states.SUCCESS:
            return response_type.model_validate(response)
        # Not finished yet: return the current task state so the caller can
        # poll again later.
        return CeleryTask.model_construct(**response)

    # Synchronous calls with polling for result retrieval and deletion of celery tasks after timeout.
    response_type = cast(type[BaseResponse], response_type)

    time = 0.0
    has_slash = route.endswith("/")
    task_route = route + f"{'' if has_slash else '/'}{celery_task.id}"
    while time < timeout:
        response = (await self.get(task_route)).json()
        if response["status"] == states.SUCCESS:
            return response_type.model_validate(response)
        elif response["status"] == states.FAILURE:
            raise ValueError("The inference failed. Please check your inputs.")
        # NOTE(review): only sleep time is accumulated, so the effective
        # timeout excludes HTTP round-trip latency of each poll.
        time += poll_interval
        await asyncio.sleep(poll_interval)
    # Timed out: remove the task server-side before giving up.
    await self.delete(task_route)
    raise InferenceTooLongError("The model is either busy or the task takes too long to perform.")

instantiate_model(provider, config, timeout=60) async

Instantiate a model.

Parameters:

Name Type Description Default
provider str

The model provider.

required
config ModelConfig

The configuration of the model.

required
timeout int

The timeout to wait for a response. Please note that even if the timeout is reached, the request will not be aborted.

60
Source code in pixano_inference/client.py
async def instantiate_model(self, provider: str, config: ModelConfig, timeout: int = 60) -> None:
    """Instantiate a model.

    Args:
        provider: The model provider.
        config: The configuration of the model.
        timeout: The timeout to wait for a response. Please note that even if the timeout is reached, the request
            will not be aborted.
    """
    # Serialize the payload (ModelConfig included) to JSON-safe types.
    payload = {"provider": provider, "config": config}
    await self.post("providers/instantiate", json=jsonable_encoder(payload), timeout=timeout)

list_models() async

List all models.

Source code in pixano_inference/client.py
async def list_models(self) -> list[ModelInfo]:
    """List all models currently registered on the server."""
    response = await self.get("app/models/")
    payload = response.json()
    return [ModelInfo.model_construct(**entry) for entry in payload]

post(path, **kwargs) async

Perform a POST request to the pixano inference server.

Parameters:

Name Type Description Default
path str

The path of the request.

required
kwargs Any

The keyword arguments to pass to the request or httpx client.

{}
Source code in pixano_inference/client.py
async def post(self, path: str, **kwargs: Any) -> Response:
    """Perform a POST request to the pixano inference server.

    Args:
        path: The path of the request.
        kwargs: The keyword arguments to pass to the request or httpx client.

    Returns:
        The server response.
    """
    response = await self._rest_call(path=path, method="POST", **kwargs)
    return response

put(path, **kwargs) async

Perform a PUT request to the pixano inference server.

Parameters:

Name Type Description Default
path str

The path of the request.

required
kwargs Any

The keyword arguments to pass to the request or httpx client.

{}
Source code in pixano_inference/client.py
async def put(self, path: str, **kwargs: Any) -> Response:
    """Perform a PUT request to the pixano inference server.

    Args:
        path: The path of the request.
        kwargs: The keyword arguments to pass to the request or httpx client.

    Returns:
        The server response.
    """
    response = await self._rest_call(path=path, method="PUT", **kwargs)
    return response

text_image_conditional_generation(request=None, poll_interval=0.1, timeout=60, task_id=None, asynchronous=False) async

text_image_conditional_generation(request: TextImageConditionalGenerationRequest | None, poll_interval: float, timeout: float, task_id: str, asynchronous: Literal[True]) -> TextImageConditionalGenerationResponse | CeleryTask
text_image_conditional_generation(request: TextImageConditionalGenerationRequest | None, poll_interval: float, timeout: float, task_id: None, asynchronous: Literal[True]) -> CeleryTask
text_image_conditional_generation(request: TextImageConditionalGenerationRequest, poll_interval: float, timeout: float, task_id: str | None, asynchronous: Literal[False]) -> TextImageConditionalGenerationResponse | CeleryTask
text_image_conditional_generation(request: TextImageConditionalGenerationRequest | None, poll_interval: float, timeout: float, task_id: str | None, asynchronous: bool) -> TextImageConditionalGenerationResponse | CeleryTask

Perform an inference to perform text-image conditional generation.

Source code in pixano_inference/client.py
async def text_image_conditional_generation(
    self,
    request: TextImageConditionalGenerationRequest | None = None,
    poll_interval: float = 0.1,
    timeout: float = 60,
    task_id: str | None = None,
    asynchronous: bool = False,
) -> TextImageConditionalGenerationResponse | CeleryTask:
    """Perform an inference to perform text-image conditional generation.

    Args:
        request: The request of the model.
        poll_interval: Waiting time between polls for synchronous requests.
        timeout: Time to wait for a response for synchronous requests.
        task_id: The id of the task to poll for results.
        asynchronous: If True, return a CeleryTask (or poll results when a
            task id is provided) instead of waiting for the response.

    Returns:
        The generation response, or a CeleryTask for asynchronous calls.
    """
    # Delegate to the generic inference entry point with the task route.
    call_kwargs = dict(
        route="tasks/multimodal/text-image/conditional_generation/",
        request=request,
        response_type=TextImageConditionalGenerationResponse,
        poll_interval=poll_interval,
        timeout=timeout,
        task_id=task_id,
        asynchronous=asynchronous,
    )
    return await self.inference(**call_kwargs)

video_mask_generation(request=None, poll_interval=0.1, timeout=60, task_id=None, asynchronous=False) async

video_mask_generation(request: VideoMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: str, asynchronous: Literal[True]) -> VideoMaskGenerationResponse | CeleryTask
video_mask_generation(request: VideoMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: None, asynchronous: Literal[True]) -> CeleryTask
video_mask_generation(request: VideoMaskGenerationRequest, poll_interval: float, timeout: float, task_id: str | None, asynchronous: Literal[False]) -> VideoMaskGenerationResponse
video_mask_generation(request: VideoMaskGenerationRequest | None, poll_interval: float, timeout: float, task_id: str | None, asynchronous: bool) -> VideoMaskGenerationResponse | CeleryTask

Perform an inference to perform video mask generation.

Source code in pixano_inference/client.py
async def video_mask_generation(
    self,
    request: VideoMaskGenerationRequest | None = None,
    poll_interval: float = 0.1,
    timeout: float = 60,
    task_id: str | None = None,
    asynchronous: bool = False,
) -> VideoMaskGenerationResponse | CeleryTask:
    """Perform an inference to perform video mask generation.

    Args:
        request: The request of the model.
        poll_interval: Waiting time between polls for synchronous requests.
        timeout: Time to wait for a response for synchronous requests.
        task_id: The id of the task to poll for results.
        asynchronous: If True, return a CeleryTask (or poll results when a
            task id is provided) instead of waiting for the response.

    Returns:
        The mask generation response, or a CeleryTask for asynchronous calls.
    """
    # Delegate to the generic inference entry point with the task route.
    call_kwargs = dict(
        route="tasks/video/mask_generation/",
        request=request,
        response_type=VideoMaskGenerationResponse,
        poll_interval=poll_interval,
        timeout=timeout,
        task_id=task_id,
        asynchronous=asynchronous,
    )
    return await self.inference(**call_kwargs)

raise_if_error(response)

Raise an error from a response.

Source code in pixano_inference/client.py
def raise_if_error(response: Response) -> None:
    """Raise an error from a response.

    Successful responses pass through silently; any other status is turned
    into an HTTPException carrying the status code plus whatever 'detail'
    and 'error' fields the JSON body provides.
    """
    if response.is_success:
        return

    message = f"HTTP {response.status_code}: {response.reason_phrase}"
    # The body may not be valid JSON (e.g. plain-text proxies); fall back to
    # an empty payload in that case.
    try:
        payload = response.json()
    except Exception:
        payload = {}

    # Append server-provided context in the same order as before: detail,
    # then error.
    for key in ("detail", "error"):
        value = payload.get(key, None)
        if value is not None:
            message += f" - {value}"
    raise HTTPException(response.status_code, detail=message)