|
1 | 1 | # Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
|
2 | 2 | #
|
3 | 3 | # Licensed under the MIT License.
|
| 4 | +from __future__ import annotations |
4 | 5 |
|
5 |
| -from typing import Dict, List, Tuple |
| 6 | +from typing import cast |
6 | 7 |
|
7 | 8 | import anndata as ad
|
8 | 9 | import attrs
|
| 10 | +import numpy as np |
| 11 | +import numpy.typing as npt |
9 | 12 | import pandas as pd
|
10 | 13 | from typing_extensions import Self
|
11 | 14 |
|
12 |
| -import tiledbsoma |
13 |
| -import tiledbsoma.logging |
14 | 15 |
|
15 |
| - |
16 |
@attrs.define(kw_only=True, frozen=True)
class AxisIDMapping:
    """For a single to-be-appended AnnData/H5AD input in SOMA multi-file append-mode ingestion, this
    class tracks the mapping of input-data ``obs`` or ``var`` 0-up offsets to SOMA join ID values
    for the destination SOMA experiment.

    ``data[i]`` holds the SOMA join ID for input offset ``i``.

    Private class
    """

    data: npt.NDArray[np.int64]

    def __attrs_post_init__(self) -> None:
        # ``frozen=True`` only prevents rebinding ``data``; clearing the writeable flag
        # also prevents accidental element-wise modification of the mapping.
        # NOTE: no copy is made, so this flips the flag on the array object the caller passed in.
        self.data.setflags(write=False)

    def get_shape(self) -> int:
        """Return one more than the largest join ID in the mapping, or 0 for an empty mapping."""
        if len(self.data) == 0:
            return 0
        return int(self.data.max() + 1)

    def is_identity(self) -> bool:
        """Return True if every offset ``i`` maps to join ID ``i``.

        An empty mapping is treated as the identity (there is nothing that could differ).
        """
        n = len(self.data)
        if n == 0:
            # Without this guard the ``data[0]`` probe below raises IndexError on empty input.
            return True

        # fast rejection first
        if self.get_shape() != n or self.data[0] != 0:
            return False

        return bool(np.array_equal(self.data, np.arange(n)))

    @classmethod
    def identity(cls, n: int) -> Self:
        """This maps 0-up input-file offsets to 0-up soma_joinid values. This is
        important for uns arrays which we never grow on ingest --- rather, we
        sub-nest the entire recursive ``uns`` data structure.
        """
        return cls(data=np.arange(n, dtype=np.int64))
48 | 50 |
|
49 | 51 |
|
50 |
@attrs.define(kw_only=True, frozen=True)
class ExperimentIDMapping:
    """Offset-to-SOMA-join-ID mappings for a single to-be-appended AnnData/H5AD input in SOMA
    multi-file append-mode ingestion: one ``AxisIDMapping`` for ``obs``, plus one
    ``AxisIDMapping`` for ``var`` in each measurement.

    Private class
    """

    # Mapping for the ``obs`` axis.
    obs_axis: AxisIDMapping
    # Mappings for the ``var`` axes, keyed by measurement name.
    var_axes: dict[str, AxisIDMapping]

    @classmethod
    def from_anndata(cls, adata: ad.AnnData, *, measurement_name: str = "RNA") -> Self:
        """Build identity mappings sized to the axes of ``adata``.

        This is useful for creating a new Experiment from a single AnnData.

        Args:
            adata: The AnnData whose ``obs``/``var`` lengths size the mappings.
            measurement_name: Key under which the ``var`` mapping is stored.

        Returns:
            A mapping with an identity ``obs_axis``, an identity ``var`` axis stored under
            ``measurement_name``, and -- when ``adata.raw`` is present -- an identity axis
            sized to ``adata.raw.var`` stored under ``"raw"``.
        """
        make_identity = AxisIDMapping.identity

        obs_ids = make_identity(len(adata.obs))
        per_measurement = {measurement_name: make_identity(len(adata.var))}
        if adata.raw is not None:
            per_measurement["raw"] = make_identity(len(adata.raw.var))

        return cls(obs_axis=obs_ids, var_axes=per_measurement)
85 | 75 |
|
86 | 76 |
|
87 |
def get_dataframe_values(df: pd.DataFrame, field_name: str) -> pd.Series:  # type: ignore[type-arg]
    """Return the registration labels (e.g. cell barcode, gene symbol) of an AnnData/H5AD
    ``obs`` or ``var`` dataframe as a Series of strings.

    The labels come from the column ``field_name`` when it exists. Otherwise the dataframe
    index is used, provided the index is unnamed, named ``"index"``, or named ``field_name``.

    Raises:
        ValueError: If ``field_name`` is neither a column nor an acceptable index name, or if
            the stringified labels contain duplicates.
    """
    # Pick where the labels live first, then stringify once below.
    if field_name in df:
        source = df[field_name]
    elif df.index.name in (field_name, "index", None):
        source = df.index.to_series()
    else:
        raise ValueError(f"Could not find field name {field_name} in dataframe.")

    labels = cast(pd.Series, source.astype(str))  # type: ignore[type-arg]

    # Check the values are unique.
    if labels.is_unique:
        return labels
    raise ValueError(
        f"Non-unique registration values have been provided in field {field_name}."
    )
|
0 commit comments