Skip to content

pixano_inference.ray.routes.inference

Capability-based inference routes with schema bridging.

register_inference_routes(app, deployment_manager)

Register capability-based inference endpoints.

Parameters:

| Name                 | Type                | Description                      | Default  |
| -------------------- | ------------------- | -------------------------------- | -------- |
| `app`                | `FastAPI`           | FastAPI application.             | required |
| `deployment_manager` | `DeploymentManager` | The deployment manager instance. | required |
Source code in pixano_inference/ray/routes/inference.py
def register_inference_routes(app: FastAPI, deployment_manager: DeploymentManager) -> None:
    """Register capability-based inference endpoints.

    One route group is registered per capability (segmentation, tracking,
    vlm, detection).  JSON routes take a typed request model with inline
    media; each ``/binary`` sibling accepts the media as an uploaded binary
    part and rebuilds the same request model via
    ``_build_binary_request_from_request``.

    Args:
        app: FastAPI application.
        deployment_manager: The deployment manager instance.
    """

    async def _parse_tracking_binary(request: Request) -> TrackingRequest:
        """Build a TrackingRequest from an uploaded-frames binary request."""
        return await _build_binary_request_from_request(
            request,
            TrackingRequest,
            max_part_size=_LEGACY_BINARY_METADATA_MAX_PART_SIZE,
            file_field="frames",
            payload_key="video",
        )

    def _submit_tracking(request: TrackingRequest) -> dict[str, Any]:
        """Validate the model, enqueue a tracking job, and serialize its initial state.

        Raises:
            HTTPException: 500 when the job cannot be fetched back
                immediately after submission.
        """
        # Validate before building the input so an unknown model fails fast.
        _get_validated_handle(deployment_manager, request.model, "tracking")
        input_obj = request.to_input()
        job_id = deployment_manager.submit_tracking_job(request.model, input_obj)
        job = deployment_manager.get_tracking_job(job_id)
        if job is None:
            raise HTTPException(status_code=500, detail=f"Tracking job '{job_id}' was not created.")
        return _serialize_tracking_job(job_id, job)

    def _tracking_job_or_404(job_id: str, job: Any) -> dict[str, Any]:
        """Serialize a looked-up tracking job, or raise 404 when it is missing."""
        if job is None:
            raise HTTPException(status_code=404, detail=f"Tracking job '{job_id}' not found")
        return _serialize_tracking_job(job_id, job)

    @app.post("/inference/segmentation/")
    async def segmentation(request: SegmentationRequest) -> dict[str, Any]:
        """Run segmentation inference."""
        input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "segmentation")

    @app.post("/inference/segmentation/binary")
    async def segmentation_binary(request: Request) -> dict[str, Any]:
        """Run segmentation inference from a binary image upload."""
        parsed_request = await _build_binary_request_from_request(
            request,
            SegmentationRequest,
            max_part_size=_LEGACY_BINARY_METADATA_MAX_PART_SIZE,
            file_field="image",
            payload_key="image",
        )
        input_obj = parsed_request.to_input()
        return await _run_inference(deployment_manager, parsed_request.model, input_obj, "segmentation")

    @app.post("/inference/tracking/")
    async def tracking(request: TrackingRequest) -> dict[str, Any]:
        """Run tracking inference."""
        input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "tracking")

    @app.post("/inference/tracking/binary")
    async def tracking_binary(request: Request) -> dict[str, Any]:
        """Run tracking inference from uploaded frame binaries."""
        parsed_request = await _parse_tracking_binary(request)
        input_obj = parsed_request.to_input()
        return await _run_inference(deployment_manager, parsed_request.model, input_obj, "tracking")

    @app.post("/inference/tracking/jobs/")
    async def tracking_job_submit(request: TrackingRequest) -> dict[str, Any]:
        """Submit an asynchronous tracking job."""
        return _submit_tracking(request)

    @app.post("/inference/tracking/jobs/binary")
    async def tracking_job_submit_binary(request: Request) -> dict[str, Any]:
        """Submit an asynchronous tracking job from uploaded frame binaries."""
        parsed_request = await _parse_tracking_binary(request)
        return _submit_tracking(parsed_request)

    @app.get("/inference/tracking/jobs/{job_id}")
    async def tracking_job_status(job_id: str) -> dict[str, Any]:
        """Poll the current status of an asynchronous tracking job."""
        return _tracking_job_or_404(job_id, deployment_manager.get_tracking_job(job_id))

    @app.delete("/inference/tracking/jobs/{job_id}")
    async def tracking_job_cancel(job_id: str) -> dict[str, Any]:
        """Cancel an asynchronous tracking job."""
        return _tracking_job_or_404(job_id, deployment_manager.cancel_tracking_job(job_id))

    @app.post("/inference/vlm/")
    async def vlm(request: VLMRequest) -> dict[str, Any]:
        """Run VLM inference."""
        input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "vlm")

    @app.post("/inference/detection/")
    async def detection(request: DetectionRequest) -> dict[str, Any]:
        """Run detection inference."""
        input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "detection")