def register_inference_routes(app: FastAPI, deployment_manager: DeploymentManager) -> None:
"""Register capability-based inference endpoints.
Args:
app: FastAPI application.
deployment_manager: The deployment manager instance.
"""
@app.post("/inference/segmentation/")
async def segmentation(request: SegmentationRequest) -> dict[str, Any]:
"""Run segmentation inference."""
input_obj = request.to_input()
return await _run_inference(deployment_manager, request.model, input_obj, "segmentation")
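
    # A minimal client sketch for the JSON endpoint above. The exact
    # SegmentationRequest fields are an assumption: "model" is taken from the
    # handler, and a base64-encoded "image" field is inferred from the binary
    # variant's payload_key; adjust names to the real schema.
    #
    #   import base64
    #   import httpx
    #
    #   payload = {
    #       "model": "my-segmentation-model",  # assumed model identifier
    #       "image": base64.b64encode(open("cell.png", "rb").read()).decode(),
    #   }
    #   response = httpx.post("http://localhost:8000/inference/segmentation/", json=payload)
    #   response.raise_for_status()
    #   print(response.json())
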
@app.post("/inference/segmentation/binary")
async def segmentation_binary(
request: Request,
) -> dict[str, Any]:
"""Run segmentation inference from a binary image upload."""
parsed_request = await _build_binary_request_from_request(
request,
SegmentationRequest,
max_part_size=_LEGACY_BINARY_METADATA_MAX_PART_SIZE,
file_field="image",
payload_key="image",
)
input_obj = parsed_request.to_input()
return await _run_inference(deployment_manager, parsed_request.model, input_obj, "segmentation")
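
    # A client sketch for the binary variant, assuming the helper reads a
    # multipart/form-data body with the upload under the "image" part (per
    # file_field above) and the remaining request fields as form parts; the
    # metadata layout is an assumption, not a documented contract.
    #
    #   import httpx
    #
    #   with open("cell.png", "rb") as fh:
    #       response = httpx.post(
    #           "http://localhost:8000/inference/segmentation/binary",
    #           files={"image": ("cell.png", fh, "image/png")},
    #           data={"model": "my-segmentation-model"},  # assumed form field
    #       )
    #   response.raise_for_status()
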
@app.post("/inference/tracking/")
async def tracking(request: TrackingRequest) -> dict[str, Any]:
"""Run tracking inference."""
input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "tracking")

@app.post("/inference/tracking/binary")
async def tracking_binary(
request: Request,
) -> dict[str, Any]:
"""Run tracking inference from uploaded frame binaries."""
parsed_request = await _build_binary_request_from_request(
request,
TrackingRequest,
max_part_size=_LEGACY_BINARY_METADATA_MAX_PART_SIZE,
file_field="frames",
payload_key="video",
)
input_obj = parsed_request.to_input()
        return await _run_inference(deployment_manager, parsed_request.model, input_obj, "tracking")

@app.post("/inference/tracking/jobs/")
async def tracking_job_submit(request: TrackingRequest) -> dict[str, Any]:
"""Submit an asynchronous tracking job."""
_get_validated_handle(deployment_manager, request.model, "tracking")
input_obj = request.to_input()
job_id = deployment_manager.submit_tracking_job(request.model, input_obj)
job = deployment_manager.get_tracking_job(job_id)
        if job is None:
            raise HTTPException(status_code=500, detail=f"Tracking job '{job_id}' was not created")
        return _serialize_tracking_job(job_id, job)

@app.post("/inference/tracking/jobs/binary")
async def tracking_job_submit_binary(request: Request) -> dict[str, Any]:
"""Submit an asynchronous tracking job from uploaded frame binaries."""
parsed_request = await _build_binary_request_from_request(
request,
TrackingRequest,
max_part_size=_LEGACY_BINARY_METADATA_MAX_PART_SIZE,
file_field="frames",
payload_key="video",
)
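        # As above, validate the model's tracking capability before queueing.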
_get_validated_handle(deployment_manager, parsed_request.model, "tracking")
input_obj = parsed_request.to_input()
job_id = deployment_manager.submit_tracking_job(parsed_request.model, input_obj)
job = deployment_manager.get_tracking_job(job_id)
        if job is None:
            raise HTTPException(status_code=500, detail=f"Tracking job '{job_id}' was not created")
        return _serialize_tracking_job(job_id, job)

@app.get("/inference/tracking/jobs/{job_id}")
async def tracking_job_status(job_id: str) -> dict[str, Any]:
"""Poll the current status of an asynchronous tracking job."""
job = deployment_manager.get_tracking_job(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Tracking job '{job_id}' not found")
        return _serialize_tracking_job(job_id, job)

@app.delete("/inference/tracking/jobs/{job_id}")
async def tracking_job_cancel(job_id: str) -> dict[str, Any]:
"""Cancel an asynchronous tracking job."""
job = deployment_manager.cancel_tracking_job(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Tracking job '{job_id}' not found")
return _serialize_tracking_job(job_id, job)
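
    # An end-to-end client sketch for the job endpoints above: submit a job,
    # poll until it reaches a terminal state, and cancel it on timeout. The
    # "job_id" and "status" keys (and the status values) are assumptions about
    # the shape produced by _serialize_tracking_job.
    #
    #   import time
    #   import httpx
    #
    #   base = "http://localhost:8000/inference/tracking/jobs"
    #   body = {"model": "my-tracking-model"}  # plus the other TrackingRequest fields
    #   job = httpx.post(f"{base}/", json=body).json()
    #   deadline = time.monotonic() + 600
    #   while time.monotonic() < deadline:
    #       job = httpx.get(f"{base}/{job['job_id']}").json()
    #       if job.get("status") in {"succeeded", "failed", "cancelled"}:
    #           break
    #       time.sleep(2.0)
    #   else:  # loop exhausted without a terminal state: give up and cancel
    #       httpx.delete(f"{base}/{job['job_id']}")
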
@app.post("/inference/vlm/")
async def vlm(request: VLMRequest) -> dict[str, Any]:
"""Run VLM inference."""
input_obj = request.to_input()
        return await _run_inference(deployment_manager, request.model, input_obj, "vlm")

@app.post("/inference/detection/")
async def detection(request: DetectionRequest) -> dict[str, Any]:
"""Run detection inference."""
input_obj = request.to_input()
return await _run_inference(deployment_manager, request.model, input_obj, "detection")
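

if __name__ == "__main__":
    # Wiring sketch: build an app, register these routes, and serve with
    # uvicorn. The zero-argument DeploymentManager() constructor is an
    # assumption for illustration; construct it however the real class
    # requires.
    import uvicorn

    app = FastAPI()
    register_inference_routes(app, DeploymentManager())
    uvicorn.run(app, host="127.0.0.1", port=8000)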