Skip to content

pixano.app.routers.inference.conditional_generation

call_text_image_conditional_generation(dataset_id, conversation, messages, model, settings, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=DEFAULT_TEMPERATURE, image_regex=DEFAULT_IMAGE_REGEX, role_system=DEFAULT_ROLE_SYSTEM, role_user=DEFAULT_ROLE_USER, role_assistant=DEFAULT_ROLE_ASSISTANT) async

Call a text image conditional generation model for a conversation.

Parameters:

Name Type Description Default
dataset_id Annotated[str, Body(embed=True)]

The ID of the dataset to use.

required
conversation Annotated[EntityModel, Body(embed=True)]

The conversation to use.

required
messages Annotated[list[AnnotationModel], Body(embed=True)]

The messages to use.

required
model Annotated[str, Body(embed=True)]

The name of the model to use.

required
settings Annotated[Settings, Depends(get_settings)]

App settings.

required
max_new_tokens Annotated[int, Body(embed=True)]

The maximum number of tokens to generate.

DEFAULT_MAX_NEW_TOKENS
temperature Annotated[float, Body(embed=True)]

The temperature to use.

DEFAULT_TEMPERATURE
image_regex Annotated[str, Body(embed=True)]

The regular expression to use to extract images from the text.

DEFAULT_IMAGE_REGEX
role_system Annotated[str, Body(embed=True)]

The role of the system.

DEFAULT_ROLE_SYSTEM
role_user Annotated[str, Body(embed=True)]

The role of the user.

DEFAULT_ROLE_USER
role_assistant Annotated[str, Body(embed=True)]

The role of the assistant.

DEFAULT_ROLE_ASSISTANT

Returns: The generated message model.

Source code in pixano/app/routers/inference/conditional_generation.py
@router.post(
    "/text-image",
    response_model=AnnotationModel,
)
async def call_text_image_conditional_generation(
    dataset_id: Annotated[str, Body(embed=True)],
    conversation: Annotated[EntityModel, Body(embed=True)],
    messages: Annotated[list[AnnotationModel], Body(embed=True)],
    model: Annotated[str, Body(embed=True)],
    settings: Annotated[Settings, Depends(get_settings)],
    max_new_tokens: Annotated[int, Body(embed=True)] = DEFAULT_MAX_NEW_TOKENS,
    temperature: Annotated[float, Body(embed=True)] = DEFAULT_TEMPERATURE,
    image_regex: Annotated[str, Body(embed=True)] = DEFAULT_IMAGE_REGEX,
    role_system: Annotated[str, Body(embed=True)] = DEFAULT_ROLE_SYSTEM,
    role_user: Annotated[str, Body(embed=True)] = DEFAULT_ROLE_USER,
    role_assistant: Annotated[str, Body(embed=True)] = DEFAULT_ROLE_ASSISTANT,
) -> AnnotationModel:
    """Call a text image conditional generation model for a conversation.

    Args:
        dataset_id: The ID of the dataset to use.
        conversation: The conversation to use.
        messages: The messages to use. Must be non-empty and all belong to the
            same message table.
        model: The name of the model to use.
        settings: App settings.
        max_new_tokens: The maximum number of tokens to generate.
        temperature: The temperature to use.
        image_regex: The regular expression to use to extract images from the text.
        role_system: The role of the system.
        role_user: The role of the user.
        role_assistant: The role of the assistant.

    Raises:
        HTTPException: 400 if the conversation or messages fail schema
            validation, if no messages are provided, or if generation fails.

    Returns: The generated message model.
    """
    dataset = get_dataset(dataset_id=dataset_id, dir=settings.library_dir, media_dir=settings.media_dir)
    client = get_client_from_settings(settings=settings)

    # The target table must actually hold conversations.
    if not is_conversation(dataset.schema.schemas[conversation.table_info.name]):
        raise HTTPException(status_code=400, detail="Conversation must be a conversation.")

    conversation_row: Conversation = conversation.to_row(dataset)

    # Guard against an empty payload explicitly: previously an empty list fell
    # through to the "one table" check with a misleading error, and messages[0]
    # below would otherwise raise IndexError.
    if not messages:
        raise HTTPException(status_code=400, detail="At least one message is required.")
    if len({m.table_info.name for m in messages}) != 1:
        raise HTTPException(status_code=400, detail="Only one table for messages is allowed.")
    # All messages share one table, so checking the first row's schema suffices.
    if not is_message(dataset.schema.schemas[messages[0].table_info.name]):
        raise HTTPException(status_code=400, detail="Messages must be a message.")

    messages_rows: list[Message] = [m.to_row(dataset) for m in messages]

    source = get_model_source(dataset=dataset, model=model)

    # Surface any inference failure as a 400 with the underlying error message,
    # preserving the original exception as the cause.
    try:
        infered_message = await text_image_conditional_generation(
            client=client,
            source=source,
            media_dir=settings.media_dir,
            messages=messages_rows,
            conversation=conversation_row,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            image_regex=image_regex,
            role_system=role_system,
            role_user=role_user,
            role_assistant=role_assistant,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Wrap the generated row in the API model, reusing the messages' table info.
    message_model = AnnotationModel.from_row(row=infered_message, table_info=messages[0].table_info)

    return message_model