pixano.data.importers.importer

Importer(name, description, tables, splits, features_values=None)

Bases: ABC

Dataset Importer class

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| info | DatasetInfo | Dataset information |
| input_dirs | dict[str, Path] | Dataset input directories |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Dataset name | required |
| description | str | Dataset description | required |
| tables | dict[str, list[DatasetTable]] | Dataset tables | required |
| splits | list[str] | Dataset splits | required |
| features_values | FeaturesValues | Values for features | None |
Source code in pixano/data/importers/importer.py

```python
def __init__(
    self,
    name: str,
    description: str,
    tables: dict[str, list[DatasetTable]],
    splits: list[str],
    features_values: FeaturesValues = None,
):
    """Initialize Importer

    Args:
        name (str): Dataset name
        description (str): Dataset description
        tables (dict[str, list[DatasetTable]]): Dataset tables
        splits (list[str]): Dataset splits
        features_values (FeaturesValues, optional): Values for features
    """

    # Check input directories
    for source_path in self.input_dirs.values():
        if not source_path.exists():
            raise FileNotFoundError(f"{source_path} does not exist.")
        if not any(source_path.iterdir()):
            raise FileNotFoundError(f"{source_path} is empty.")

    # Create DatasetInfo
    self.info = DatasetInfo(
        id=shortuuid.uuid(),
        name=name,
        description=description,
        estimated_size="N/A",
        num_elements=0,
        splits=splits,
        tables=tables,
        features_values=features_values,
    )
```
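
Note that `__init__` validates `self.input_dirs`, which is listed in the attributes above but never assigned by the constructor itself: concrete importers are expected to set it before calling `super().__init__()`. A minimal sketch of a hypothetical subclass (the `MyImporter` name and its directory layout are illustrative, not part of Pixano):

```python
from pathlib import Path

from pixano.data.importers.importer import Importer


class MyImporter(Importer):
    """Hypothetical importer for a folder of images (illustrative only)"""

    def __init__(self, name: str, description: str, input_dir: Path, splits: list[str]):
        # input_dirs must be set before super().__init__(), which checks that
        # each directory exists and is not empty
        self.input_dirs = {"image": input_dir / "image"}
        tables = self.create_tables(media_fields={"image": "image"})
        super().__init__(name, description, tables, splits)

    def import_rows(self):
        ...  # yield processed rows, see import_rows() below
```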

copy_or_move_files(import_dir, ds_tables, copy)

Copy or move dataset files

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| import_dir | Path | Import directory | required |
| ds_tables | dict[str, dict[str, LanceTable]] | Dataset tables | required |
| copy | bool | True to copy files, False to move them | required |
Source code in pixano/data/importers/importer.py

```python
def copy_or_move_files(
    self,
    import_dir: Path,
    ds_tables: dict[str, dict[str, lancedb.db.LanceTable]],
    copy: bool,
):
    """Copy or move dataset files

    Args:
        import_dir (Path): Import directory
        ds_tables (dict[str, dict[str, lancedb.db.LanceTable]]): Dataset tables
        copy (bool): True to copy files, False to move them
    """

    if copy:
        for table in tqdm(
            ds_tables["media"].values(), desc="Copying media directories"
        ):
            for field in table.schema:
                if field.name in self.input_dirs:
                    field_dir = import_dir / "media" / field.name
                    field_dir.mkdir(parents=True, exist_ok=True)
                    if self.input_dirs[field.name] != field_dir:
                        shutil.copytree(
                            self.input_dirs[field.name],
                            field_dir,
                            dirs_exist_ok=True,
                        )
    else:
        for table in tqdm(
            ds_tables["media"].values(), desc="Moving media directories"
        ):
            for field in table.schema:
                if field.name in self.input_dirs:
                    field_dir = import_dir / "media" / field.name
                    if self.input_dirs[field.name] != field_dir:
                    self.input_dirs[field.name].rename(field_dir)
```
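
For reference, both branches resolve the same destination under `import_dir / "media" / <field name>`; only the filesystem operation differs. A hedged illustration with hypothetical paths:

```python
from pathlib import Path

import_dir = Path("library/coco")                     # hypothetical Pixano library entry
input_dirs = {"image": Path("datasets/coco/images")}  # hypothetical importer input

dest = import_dir / "media" / "image"
# copy=True  -> shutil.copytree(input_dirs["image"], dest, dirs_exist_ok=True)
# copy=False -> input_dirs["image"].rename(dest)  # a move; may fail across filesystems
```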

create_preview(import_dir, ds_tables)

Create dataset preview image

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| import_dir | Path | Import directory | required |
| ds_tables | dict[str, dict[str, LanceTable]] | Dataset tables | required |
Source code in pixano/data/importers/importer.py

```python
def create_preview(
    self,
    import_dir: Path,
    ds_tables: dict[str, dict[str, lancedb.db.LanceTable]],
):
    """Create dataset preview image

    Args:
        import_dir (Path): Import directory
        ds_tables (dict[str, dict[str, lancedb.db.LanceTable]]): Dataset tables
    """

    # Get list of image fields
    if "media" in ds_tables:
        if "image" in ds_tables["media"]:
            image_table = ds_tables["media"]["image"]
            if len(image_table) > 0:
                image_fields = [
                    field.name for field in image_table.schema if field.name != "id"
                ]
                with tqdm(desc="Creating dataset thumbnail", total=1) as progress:
                    tile_w = 64
                    tile_h = 64
                    preview = Image.new("RGB", (4 * tile_w, 2 * tile_h))
                    for i in range(8):
                        field = image_fields[i % len(image_fields)]
                        item_id = random.randrange(len(image_table))
                        item = image_table.to_lance().take([item_id]).to_pylist()[0]
                        with Image.open(BytesIO(item[field].preview_bytes)) as im:
                            preview.paste(
                                im,
                                ((i % 4) * tile_w, (int(i / 4) % 2) * tile_h),
                            )
                    preview.save(import_dir / "preview.png")
                    progress.update(1)
```
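
The thumbnail is a 4×2 grid of 64×64 tiles built from up to eight randomly sampled preview images. For `i` in 0..7, `i % 4` selects the column and `int(i / 4) % 2` the row (equivalent to `i // 4` over this range). The same placement logic in isolation:

```python
tile_w, tile_h = 64, 64
for i in range(8):
    x = (i % 4) * tile_w   # column 0..3
    y = (i // 4) * tile_h  # row 0 or 1
    print(f"tile {i} pasted at ({x}, {y})")
```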

create_tables(media_fields=None, object_fields=None)

Create dataset tables

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| media_fields | dict[str, str] | Media fields. Defaults to None. | None |
| object_fields | dict[str, str] | Object fields. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| dict[str, list[DatasetTable]] | Tables |

Source code in pixano/data/importers/importer.py

```python
def create_tables(
    self, media_fields: dict[str, str] = None, object_fields: dict[str, str] = None
):
    """Create dataset tables

    Args:
        media_fields (dict[str, str], optional): Media fields. Defaults to None.
        object_fields (dict[str, str], optional): Object fields. Defaults to None.

    Returns:
        dict[str, list[DatasetTable]]: Tables
    """

    if media_fields is None:
        media_fields = {"image": "image"}

    tables: dict[str, list[DatasetTable]] = {
        "main": [
            DatasetTable(
                name="db",
                fields={
                    "id": "str",
                    "original_id": "str",
                    "views": "[str]",
                    "split": "str",
                },
            )
        ],
        "media": [],
    }

    # Add media fields
    for field_name, field_type in media_fields.items():
        table_exists = False
        # If table for given field type exists
        for media_table in tables["media"]:
            if field_type == media_table.name and not table_exists:
                media_table.fields[field_name] = field_type
                table_exists = True
        # Else, create that table
        if not table_exists:
            tables["media"].append(
                DatasetTable(
                    name=field_type,
                    fields={
                        "id": "str",
                        field_name: field_type,
                    },
                )
            )

    # Add object fields
    if object_fields is not None:
        tables["objects"] = [
            DatasetTable(
                name="objects",
                fields={"id": "str", "item_id": "str", "view_id": "str"}
                | object_fields,
                source="Ground Truth",
            )
        ]

    return tables
```
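
As an example, calling `create_tables()` on a hypothetical importer instance with an extra media field and some object fields (the field names below are illustrative) produces a "main" group with the "db" table, one media table per distinct field type, and an "objects" table with the "Ground Truth" source:

```python
tables = importer.create_tables(
    media_fields={"image": "image", "depth": "image"},  # same type, so one shared table
    object_fields={"bbox": "bbox", "category": "str"},  # hypothetical object schema
)
# tables["main"][0].fields  == {"id": "str", "original_id": "str", "views": "[str]", "split": "str"}
# tables["media"][0].fields == {"id": "str", "image": "image", "depth": "image"}
# tables["objects"][0].fields contains id, item_id and view_id plus the object fields
```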

import_dataset(import_dir, copy=True)

Import dataset to Pixano format

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| import_dir | Path | Import directory | required |
| copy | bool | True to copy files to the import directory, False to move them. Defaults to True. | True |

Returns:

| Type | Description |
| --- | --- |
| Dataset | Imported dataset |

Source code in pixano/data/importers/importer.py

```python
def import_dataset(
    self,
    import_dir: Path,
    copy: bool = True,
) -> Dataset:
    """Import dataset to Pixano format

    Args:
        import_dir (Path): Import directory
        copy (bool, optional): True to copy files to the import directory, False to move them. Defaults to True.

    Returns:
        Dataset: Imported dataset
    """

    # Create dataset
    dataset = Dataset.create(import_dir, self.info)

    # Load dataset tables
    ds_tables = dataset.open_tables()

    # Initialize batches
    ds_batches: dict[str, dict[str, list]] = defaultdict(dict)
    for group_name, table_group in self.info.tables.items():
        for table in table_group:
            ds_batches[group_name][table.name] = []

    # Add rows to tables
    save_batch_size = 1024
    for rows in tqdm(self.import_rows(), desc="Importing dataset"):
        for group_name, table_group in self.info.tables.items():
            for table in table_group:
                # Store rows in a batch
                ds_batches[group_name][table.name].extend(
                    rows[group_name][table.name]
                )
                # If batch reaches 1024 rows, store in table
                if len(ds_batches[group_name][table.name]) >= save_batch_size:
                    pa_batch = pa.Table.from_pylist(
                        ds_batches[group_name][table.name],
                        schema=Fields(table.fields).to_schema(),
                    )
                    lance.write_dataset(
                        pa_batch,
                        uri=ds_tables[group_name][table.name].to_lance().uri,
                        mode="append",
                    )
                    ds_batches[group_name][table.name] = []

    # Store final batches
    for group_name, table_group in self.info.tables.items():
        for table in table_group:
            if len(ds_batches[group_name][table.name]) > 0:
                pa_batch = pa.Table.from_pylist(
                    ds_batches[group_name][table.name],
                    schema=Fields(table.fields).to_schema(),
                )
                lance.write_dataset(
                    pa_batch,
                    uri=ds_tables[group_name][table.name].to_lance().uri,
                    mode="append",
                )
                ds_batches[group_name][table.name] = []

    # Optimize and clear creation history
    for tables in ds_tables.values():
        for table in tables.values():
            table.to_lance().optimize.compact_files()
            table.to_lance().cleanup_old_versions(older_than=timedelta(0))

    # Refresh tables
    ds_tables = dataset.open_tables()

    # Raise error if generated dataset is empty
    if len(ds_tables["main"]["db"]) == 0:
        raise FileNotFoundError(
            "Generated dataset is empty. Please make sure that the paths to your media files are correct, and that they each contain subfolders for your splits."
        )

    # Create DatasetInfo
    dataset.info.num_elements = len(ds_tables["main"]["db"])
    dataset.info.estimated_size = estimate_size(import_dir)
    dataset.save_info()

    # Create thumbnail
    self.create_preview(import_dir, ds_tables)

    # Copy or move media directories
    self.copy_or_move_files(import_dir, ds_tables, copy)

    return Dataset(import_dir)
```
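
Putting it together, importing with a concrete importer (here the hypothetical `MyImporter` sketched earlier; the paths are placeholders):

```python
from pathlib import Path

importer = MyImporter(
    name="my_dataset",
    description="Demo dataset",
    input_dir=Path("datasets/my_dataset"),
    splits=["train", "val"],
)
dataset = importer.import_dataset(Path("library/my_dataset"), copy=True)
print(dataset.info.num_elements)
```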

import_rows() abstractmethod

Process dataset rows for import

Yields:

| Type | Description |
| --- | --- |
| Iterator | Processed rows |

Source code in pixano/data/importers/importer.py

```python
@abstractmethod
def import_rows(self) -> Iterator:
    """Process dataset rows for import

    Yields:
        Iterator: Processed rows
    """