
pixano.models.inference_model

InferenceModel(name, model_id='', device='', description='')

Bases: ABC

Abstract parent class for OfflineModel and OnlineModel

Attributes:

Name         Type  Description
name         str   Model name
model_id     str   Model ID
device       str   Model GPU or CPU device
description  str   Model description

Parameters:

Name         Type  Description                               Default
name         str   Model name                                required
model_id     str   Model ID. Defaults to "".                 ''
device       str   Model GPU or CPU device. Defaults to "".  ''
description  str   Model description. Defaults to "".        ''
Source code in pixano/models/inference_model.py
def __init__(
    self,
    name: str,
    model_id: str = "",
    device: str = "",
    description: str = "",
) -> None:
    """Initialize model name and ID

    Args:
        name (str): Model name
        model_id (str, optional): Model ID. Defaults to "".
        device (str, optional): Model GPU or CPU device. Defaults to "".
        description (str, optional): Model description. Defaults to "".
    """

    self.name = name
    if model_id == "":
        self.model_id = f"{datetime.now().strftime('%y%m%d_%H%M%S')}_{name}"
    else:
        self.model_id = model_id
    self.device = device
    self.description = description
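
For context, here is a minimal sketch of how the constructor defaults behave. DummyModel and its no-op overrides are hypothetical and exist only to show that an empty model_id is replaced by a timestamp-based one; the import path follows this page's module name:

from pixano.models.inference_model import InferenceModel


class DummyModel(InferenceModel):
    """Hypothetical concrete subclass with no-op processing hooks."""

    def preannotate(self, batch, views, uri_prefix, threshold=0.0, prompt=""):
        return []

    def precompute_embeddings(self, batch, views, uri_prefix):
        return []


model = DummyModel(name="dummy", device="cpu", description="Example model")
# model_id was left empty, so it is generated from the current date, time, and name,
# e.g. "240101_120000_dummy"
print(model.model_id)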

create_table(process_type, views, dataset)

Create inference table in dataset

Parameters:

Name          Type       Description                                                           Default
process_type  str        Process type:                                                         required
                         - 'pre_ann' for pre-annotations to accept or reject as Ground Truth
                         - 'model_run' for annotations to compare to Ground Truth
                         - 'segment_emb' for segmentation embeddings
                         - 'search_emb' for semantic search embeddings
views         list[str]  Dataset views                                                         required
dataset       Dataset    Dataset                                                               required

Returns:

Type          Description
DatasetTable  Inference table

Source code in pixano/models/inference_model.py
def create_table(
    self,
    process_type: str,
    views: list[str],
    dataset: Dataset,
) -> DatasetTable:
    """Create inference table in dataset

    Args:
        process_type (str): Process type
                            - 'pre_ann' for pre-annotations to accept or reject as Ground Truth
                            - 'model_run' for annotations to compare to Ground Truth
                            - 'segment_emb' for segmentation embeddings
                            - 'search_emb' for semantic search embeddings
        views (list[str]): Dataset views
        dataset (Dataset): Dataset

    Returns:
        DatasetTable: Inference table
    """
    table = None
    table_group = None

    # Inference table filename
    table_filename = (
        f"emb_{self.model_id}" if "emb" in process_type else f"obj_{self.model_id}"
    )

    # Annotations schema
    if process_type in ["pre_ann", "model_run"]:
        table_group = "objects"
        # Create table
        table = DatasetTable(
            name=table_filename,
            fields={
                "id": "str",
                "item_id": "str",
                "view_id": "str",
                "bbox": "bbox",
                "mask": "compressedrle",
                "category": "str",
            },
            source=self.name if process_type == "model_run" else "Pre-annotation",
            type=None,
        )
    # Segmentation embeddings schema
    elif process_type == "segment_emb":
        table_group = "embeddings"
        # Add embedding column for each selected view
        fields = {"id": "str"}
        for view in views:
            fields[view] = "bytes"
        # Create table
        table = DatasetTable(
            name=table_filename,
            fields=fields,
            source=self.name,
            type="segment",
        )

    # Semantic search embeddings schema
    elif process_type == "search_emb":
        table_group = "embeddings"
        # Add vector column for each selected view
        fields = {"id": "str"}
        for view in views:
            fields[view] = "vector(512)"
        # Create table
        table = DatasetTable(
            name=table_filename,
            fields=fields,
            source=self.name,
            type="search",
        )

    # Create table
    if table and table_group:
        dataset.create_table(table, table_group)

    return table
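
As an illustration (not taken from the library documentation), a call that prepares a semantic search embedding table might look as follows, assuming model is a concrete InferenceModel subclass and dataset is an already-loaded Dataset with an "image" view:

# Hypothetical call: create the table backing semantic search embeddings.
table = model.create_table(
    process_type="search_emb",
    views=["image"],
    dataset=dataset,
)
# Per the source above, the table is named "emb_<model_id>", registered in the
# "embeddings" group, and holds an "id" column plus one "vector(512)" column per view.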

export_to_onnx(library_dir)

Export Torch model to ONNX

Parameters:

Name         Type  Description                Default
library_dir  Path  Dataset library directory  required
Source code in pixano/models/inference_model.py
def export_to_onnx(self, library_dir: Path):
    """Export Torch model to ONNX

    Args:
        library_dir (Path): Dataset library directory
    """

preannotate(batch, views, uri_prefix, threshold=0.0, prompt='')

Generate annotations for dataset rows

Parameters:

Name        Type         Description                               Default
batch       RecordBatch  Input batch                               required
views       list[str]    Dataset views                             required
uri_prefix  str          URI prefix for media files                required
threshold   float        Confidence threshold. Defaults to 0.0.    0.0
prompt      str          Annotation text prompt. Defaults to "".   ''

Returns:

Type        Description
list[dict]  Annotation rows

Source code in pixano/models/inference_model.py
def preannotate(
    self,
    batch: pa.RecordBatch,
    views: list[str],
    uri_prefix: str,
    threshold: float = 0.0,
    prompt: str = "",
) -> list[dict]:
    """Generate annotations for dataset rows

    Args:
        batch (pa.RecordBatch): Input batch
        views (list[str]): Dataset views
        uri_prefix (str): URI prefix for media files
        threshold (float, optional): Confidence threshold. Defaults to 0.0.
        prompt (str, optional): Annotation text prompt. Defaults to "".

    Returns:
        list[dict]: Annotation rows
    """

precompute_embeddings(batch, views, uri_prefix)

Precompute embeddings for dataset rows

Parameters:

Name        Type         Description                 Default
batch       RecordBatch  Input batch                 required
views       list[str]    Dataset views               required
uri_prefix  str          URI prefix for media files  required

Returns:

Type        Description
list[dict]  Embedding rows

Source code in pixano/models/inference_model.py
def precompute_embeddings(
    self,
    batch: pa.RecordBatch,
    views: list[str],
    uri_prefix: str,
) -> list[dict]:
    """Precompute embeddings for dataset rows

    Args:
        batch (pa.RecordBatch): Input batch
        views (list[str]): Dataset views
        uri_prefix (str): URI prefix for media files

    Returns:
        list[dict]: Embedding rows
    """

process_dataset(dataset_dir, views, process_type, splits=None, batch_size=1, threshold=0.0, prompt='')

Process dataset for annotations or embeddings

Parameters:

Name          Type       Description                                                           Default
dataset_dir   Path       Dataset directory                                                     required
views         list[str]  Dataset views                                                         required
process_type  str        Process type:                                                         required
                         - 'pre_ann' for pre-annotations to accept or reject as Ground Truth
                         - 'model_run' for annotations to compare to Ground Truth
                         - 'segment_emb' for segmentation embeddings
                         - 'search_emb' for semantic search embeddings
splits        list[str]  Dataset splits, all if None. Defaults to None.                        None
batch_size    int        Rows per process batch. Defaults to 1.                                1
threshold     float      Confidence threshold for predictions. Defaults to 0.0.                0.0
prompt        str        Annotation text prompt. Defaults to "".                               ''

Returns:

Type     Description
Dataset  Dataset

Source code in pixano/models/inference_model.py
def process_dataset(
    self,
    dataset_dir: Path,
    views: list[str],
    process_type: str,
    splits: list[str] = None,
    batch_size: int = 1,
    threshold: float = 0.0,
    prompt: str = "",
) -> Dataset:
    """Process dataset for annotations or embeddings

    Args:
        dataset_dir (Path): Dataset directory
        views (list[str]): Dataset views
        process_type (str): Process type
                            - 'pre_ann' for pre-annotations to accept or reject as Ground Truth
                            - 'model_run' for annotations to compare to Ground Truth
                            - 'segment_emb' for segmentation embeddings
                            - 'search_emb' for semantic search embeddings
        splits (list[str], optional): Dataset splits, all if None. Defaults to None.
        batch_size (int, optional): Rows per process batch. Defaults to 1.
        threshold (float, optional): Confidence threshold for predictions. Defaults to 0.0.
        prompt (str, optional): Annotation text prompt. Defaults to "".

    Returns:
        Dataset: Dataset
    """

    if process_type not in [
        "pre_ann",
        "model_run",
        "segment_emb",
        "search_emb",
    ]:
        raise ValueError(
            "Please choose a valid process type "
            "('pre_ann' or 'model_run' for annotations, "
            "'segment_emb' or 'search_emb' for segmentation or semantic search embeddings)"
        )

    if not views:
        raise ValueError("Please select which views you want to process on.")

    # Load dataset
    dataset = Dataset(dataset_dir)

    # Create inference table
    table = self.create_table(process_type, views, dataset)

    # Load dataset tables
    ds_tables = dataset.open_tables()

    # Load inference table
    table_group = "embeddings" if "emb" in process_type else "objects"
    table_lance = ds_tables[table_group][table.name].to_lance()

    # Create URI prefix
    uri_prefix = dataset.media_dir.absolute().as_uri()

    # Add rows to tables
    save_batch_size = 1024
    with tqdm(desc="Processing dataset", total=dataset.num_rows) as progress:
        for save_batch_index in range(ceil(dataset.num_rows / save_batch_size)):
            # Load rows
            process_batches = self._load_rows(
                dataset,
                ds_tables,
                splits,
                batch_size,
                save_batch_size,
                save_batch_index,
            )

            # Process rows
            save_batch = []
            for process_batch in process_batches:
                save_batch.extend(
                    self.precompute_embeddings(process_batch, views, uri_prefix)
                    if "emb" in process_type
                    else self.preannotate(
                        process_batch, views, uri_prefix, threshold, prompt
                    )
                )
                progress.update(batch_size)

            # Save rows
            pyarrow_save_batch = pa.Table.from_pylist(
                save_batch,
                schema=Fields(table.fields).to_schema(),
            )
            lance.write_dataset(
                pyarrow_save_batch,
                uri=table_lance.uri,
                mode="append",
            )

    # Optimize and clear creation history
    table_lance.optimize.compact_files()
    table_lance.cleanup_old_versions(older_than=timedelta(0))

    return dataset
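
Putting it together, a typical run might look like the following. MyModel is a hypothetical concrete subclass and the dataset path is an example; every keyword mirrors the signature above:

from pathlib import Path

model = MyModel(name="my_model", device="cuda")
dataset = model.process_dataset(
    dataset_dir=Path("library/my_dataset"),
    views=["image"],
    process_type="pre_ann",  # or "model_run", "segment_emb", "search_emb"
    splits=["train"],        # None processes every split
    batch_size=4,
    threshold=0.5,
)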