
Commit 10b2549

[python] Ingest performance improvements (#3865)
* first cut at ingest tuning
* mp optional param
* fix tests, add load from S3
* reshape fix
* lint
* add multiprocessing config warning
* performance tuning
* pass context
* lint
* add test
* interim checkpoint on enum cleanup
* add test for subset
* add new extend enum code (commented out)
* migrate to new enum API
* bump typeguard version to work around typeguard#504
* improve docstrings
* more cleanup
* loosen existence checks when from_anndata/from_h5ad are run in schema_only mode
* cleanup and additional error checking
* more docstring cleanup
* more docstring cleanup
* more docstrings
* PR fb
* PR fb
* more PR fb
* increase test code coverage
* additional code coverage
* lint
* PR fb
* PR fb
* more PR fb
* lint
* remove extraneous line break in docstring
1 parent 92fc2b8 commit 10b2549

11 files changed: +1273 −729 lines

apis/python/setup.py

Lines changed: 1 addition & 0 deletions
@@ -344,6 +344,7 @@ def run(self):
     install_requires=[
         "anndata>=0.10.1",
         "attrs>=22.2",
+        "more-itertools",
         "numpy",
         "pandas",
         "pyarrow",

apis/python/src/tiledbsoma/io/_registration/ambient_label_mappings.py

Lines changed: 588 additions & 443 deletions
Large diffs are not rendered by default.
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+import pandas as pd
+import pyarrow as pa
+
+from tiledbsoma import DataFrame
+
+
+def get_enumerations(
+    df: DataFrame, column_names: Sequence[str]
+) -> dict[str, pd.CategoricalDtype]:
+    """Look up enum info in the schema, and return it as a Pandas CategoricalDtype.
+    This is a convenience wrapper around ``DataFrame.get_enumeration_values``, for
+    use in the registration module."""
+
+    # Skip columns which are not of type dictionary.
+    column_names = [
+        c for c in column_names if pa.types.is_dictionary(df.schema.field(c).type)
+    ]
+    return {
+        k: pd.CategoricalDtype(categories=v, ordered=df.schema.field(k).type.ordered)
+        for k, v in df.get_enumeration_values(column_names).items()
+    }
+
+
+def extend_enumerations(df: DataFrame, columns: dict[str, pd.CategoricalDtype]) -> None:
+    """Extend enumerations as needed, starting with a CategoricalDtype for each
+    cat/enum/dict column. A convenience wrapper around
+    ``DataFrame.extend_enumeration_values``, for use in the registration module.
+
+    The DataFrame must be open for write.
+    """
+    current_enums = get_enumerations(df, list(columns.keys()))
+    columns_to_extend = {}
+    for column_name, cat_dtype in columns.items():
+        # First confirm this is a dictionary column. If it has been
+        # decategorical-ized, i.e., is an array of the value type, don't extend.
+        if column_name not in current_enums:
+            assert not pa.types.is_dictionary(df.schema.field(column_name).type)
+            continue
+
+        # Determine whether we have any new enum values in this column.
+        existing_dtype = current_enums[column_name]
+        new_enum_values = pd.Index(cat_dtype.categories).difference(
+            existing_dtype.categories, sort=False
+        )
+        if len(new_enum_values) == 0:
+            continue
+
+        # If there are new values, extend the array schema enum.
+        columns_to_extend[column_name] = pa.array(new_enum_values.to_numpy())
+
+    # And evolve the schema.
+    df.extend_enumeration_values(columns_to_extend, deduplicate=False)
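A hedged usage sketch of the two helpers above: the URI and the ``cell_type`` column are hypothetical, and the array is assumed to have been created with a dictionary-encoded (categorical) column.

    # Hypothetical usage; "obs_uri" and "cell_type" are made-up names.
    import pandas as pd
    import tiledbsoma

    obs_uri = "file:///tmp/exp/obs"

    with tiledbsoma.DataFrame.open(obs_uri, mode="r") as df:
        # Read the current enum state as Pandas categorical dtypes.
        current = get_enumerations(df, ["cell_type"])

    # Incoming data introduces a category the array has not seen yet.
    incoming = {
        "cell_type": pd.CategoricalDtype(categories=["B cell", "T cell", "NK cell"])
    }

    with tiledbsoma.DataFrame.open(obs_uri, mode="w") as df:
        # Only values absent from the existing enum are appended.
        extend_enumerations(df, incoming)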
Lines changed: 38 additions & 48 deletions
@@ -1,102 +1,92 @@
 # Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
 #
 # Licensed under the MIT License.
+from __future__ import annotations
 
-from typing import Dict, List, Tuple
+from typing import cast
 
 import anndata as ad
 import attrs
+import numpy as np
+import numpy.typing as npt
 import pandas as pd
 from typing_extensions import Self
 
-import tiledbsoma
-import tiledbsoma.logging
 
-
-@attrs.define(kw_only=True)
+@attrs.define(kw_only=True, frozen=True)
 class AxisIDMapping:
-    """
-    For a single to-be-appended AnnData/H5AD input in SOMA multi-file append-mode ingestion, this
+    """For a single to-be-appended AnnData/H5AD input in SOMA multi-file append-mode ingestion, this
     class tracks the mapping of input-data ``obs`` or ``var`` 0-up offsets to SOMA join ID values
     for the destination SOMA experiment.
 
-    See module-level comments for more information.
+    Private class
     """
 
-    # Tuple not List so this can't be modified by accident when passed into some function somewhere
-    data: Tuple[int, ...]
+    data: npt.NDArray[np.int64]
 
-    def is_identity(self) -> bool:
-        for i, data in enumerate(self.data):
-            if data != i:
-                return False
-        return True
+    def __attrs_post_init__(self) -> None:
+        self.data.setflags(write=False)
 
     def get_shape(self) -> int:
         if len(self.data) == 0:
             return 0
         else:
-            return 1 + max(self.data)
+            return int(self.data.max() + 1)
+
+    def is_identity(self) -> bool:
+        # fast rejection first
+        if self.get_shape() != len(self.data) or self.data[0] != 0:
+            return False
+
+        return np.array_equal(self.data, np.arange(0, len(self.data)))
 
     @classmethod
     def identity(cls, n: int) -> Self:
         """This maps 0-up input-file offsets to 0-up soma_joinid values. This is
         important for uns arrays which we never grow on ingest --- rather, we
         sub-nest the entire recursive ``uns`` data structure.
         """
-        return cls(data=tuple(range(n)))
+        return cls(data=np.arange(n, dtype=np.int64))
 
 
-@attrs.define(kw_only=True)
+@attrs.define(kw_only=True, frozen=True)
 class ExperimentIDMapping:
-    """
-    For a single to-be-appended AnnData/H5AD input in SOMA multi-file append-mode ingestion, this
-    class contains an ``ExperimentIDMapping`` for ``obs``, and one ``ExperimentIDMapping`` for
+    """For a single to-be-appended AnnData/H5AD input in SOMA multi-file append-mode ingestion, this
+    class contains an ``AxisIDMapping`` for ``obs``, and one ``AxisIDMapping`` for
     ``var`` in each measurement.
 
-    See module-level comments for more information.
+    Private class
     """
 
     obs_axis: AxisIDMapping
-    var_axes: Dict[str, AxisIDMapping]
+    var_axes: dict[str, AxisIDMapping]
 
     @classmethod
-    def from_isolated_anndata(
-        cls,
-        adata: ad.AnnData,
-        measurement_name: str,
-    ) -> Self:
-        """Factory method to compute offset-to-SOMA-join-ID mappings for a single input file in
-        isolation. This is used when a user is ingesting a single AnnData/H5AD to a single SOMA
-        experiment, not in append mode, allowing us to still have the bulk of the ingestor code to
-        be non-duplicated between non-append mode and append mode.
-        """
-        tiledbsoma.logging.logger.info(
-            "Registration: registering isolated AnnData object."
-        )
+    def from_anndata(cls, adata: ad.AnnData, *, measurement_name: str = "RNA") -> Self:
+        """Create a new ID mapping from an AnnData.
 
-        obs_mapping = AxisIDMapping(data=tuple(range(len(adata.obs))))
-        var_axes = {}
-        var_axes[measurement_name] = AxisIDMapping(data=tuple(range(len(adata.var))))
+        This is useful for creating a new Experiment from a single AnnData.
+        """
+        obs_axis = AxisIDMapping.identity(len(adata.obs))
+        var_axes = {measurement_name: AxisIDMapping.identity(len(adata.var))}
         if adata.raw is not None:
-            var_axes["raw"] = AxisIDMapping(data=tuple(range(len(adata.raw.var))))
-
-        return cls(obs_axis=obs_mapping, var_axes=var_axes)
+            var_axes["raw"] = AxisIDMapping.identity(len(adata.raw.var))
+        return cls(obs_axis=obs_axis, var_axes=var_axes)
 
 
-def get_dataframe_values(df: pd.DataFrame, field_name: str) -> List[str]:
+def get_dataframe_values(df: pd.DataFrame, field_name: str) -> pd.Series:  # type: ignore[type-arg]
     """Extracts the label values (e.g. cell barcode, gene symbol) from an AnnData/H5AD
     ``obs`` or ``var`` dataframe."""
     if field_name in df:
-        values = [str(e) for e in df[field_name]]
+        values = cast(pd.Series, df[field_name].astype(str))  # type: ignore[type-arg]
     elif df.index.name in (field_name, "index", None):
-        values = list(df.index)
+        values = cast(pd.Series, df.index.to_series().astype(str))  # type: ignore[type-arg]
     else:
-        raise ValueError(f"could not find field name {field_name} in dataframe")
+        raise ValueError(f"Could not find field name {field_name} in dataframe.")
 
     # Check the values are unique.
-    if len(values) != len(set(values)):
+    if not values.is_unique:
         raise ValueError(
-            f"non-unique registration values have been provided in field {field_name}"
+            f"Non-unique registration values have been provided in field {field_name}."
         )
     return values
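A small sketch of the mapping classes after this change; the AnnData built here is synthetic, and the barcode/gene names are made up.

    # Synthetic example of the frozen, NumPy-backed mappings above.
    import anndata as ad
    import numpy as np
    import pandas as pd

    adata = ad.AnnData(
        X=np.zeros((3, 2), dtype=np.float32),
        obs=pd.DataFrame(index=["AAAC", "AAAG", "AAAT"]),
        var=pd.DataFrame(index=["gene1", "gene2"]),
    )

    mapping = ExperimentIDMapping.from_anndata(adata, measurement_name="RNA")
    assert mapping.obs_axis.is_identity()
    assert mapping.obs_axis.get_shape() == 3
    assert mapping.var_axes["RNA"].data.tolist() == [0, 1]

    # The backing arrays are read-only, matching frozen=True on the classes;
    # mapping.obs_axis.data[0] = 7 would raise
    # "ValueError: assignment destination is read-only".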

apis/python/src/tiledbsoma/io/_util.py

Lines changed: 2 additions & 1 deletion
@@ -113,7 +113,8 @@ def _hack_patch_anndata() -> ContextManager[object]:
 
     @file_backing.AnnDataFileManager.filename.setter  # type: ignore[misc]
     def filename(
-        self: file_backing.AnnDataFileManager, filename: Union[Path, _FSPathWrapper]
+        self: file_backing.AnnDataFileManager,
+        filename: Union[Path, _FSPathWrapper, None],
     ) -> None:
         self._filename = filename
