diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index 4af24fd6a9..1b100e8fc5 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -4,10 +4,14 @@ from dataclasses import dataclass, field from functools import cache, partial from pathlib import Path +from typing import TYPE_CHECKING from legacy_api_wrap import legacy_api from packaging.version import Version +if TYPE_CHECKING: + from importlib.metadata import PackageMetadata + try: from dask.array import Array as DaskArray except ImportError: @@ -60,14 +64,14 @@ def __exit__(self, *_excinfo) -> None: os.chdir(self._old_cwd.pop()) -def pkg_metadata(package): +def pkg_metadata(package) -> PackageMetadata: from importlib.metadata import metadata return metadata(package) @cache -def pkg_version(package): +def pkg_version(package) -> Version: from importlib.metadata import version return Version(version(package)) diff --git a/src/scanpy/tools/_dendrogram.py b/src/scanpy/tools/_dendrogram.py index f60f0ae2e9..eb7acaa030 100644 --- a/src/scanpy/tools/_dendrogram.py +++ b/src/scanpy/tools/_dendrogram.py @@ -6,12 +6,14 @@ from typing import TYPE_CHECKING -import pandas as pd +from anndata import AnnData from pandas.api.types import CategoricalDtype +from scanpy.get._aggregated import aggregate + from .. import logging as logg from .._compat import old_positionals -from .._utils import _doc_params, raise_not_implemented_error_if_backed_type +from .._utils import _doc_params from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -19,7 +21,8 @@ from collections.abc import Sequence from typing import Any - from anndata import AnnData + import numpy as np + import pandas as pd @old_positionals( @@ -118,7 +121,6 @@ def dendrogram( >>> sc.pl.dotplot(adata, markers, groupby='bulk_labels', dendrogram=True) """ - raise_not_implemented_error_if_backed_type(adata.X, "dendrogram") if isinstance(groupby, str): # if not a list, turn into a list groupby = [groupby] @@ -135,10 +137,11 @@ def dendrogram( ) if var_names is None: - rep_df = pd.DataFrame( - _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + rep = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + rep_adata = AnnData(rep) + categorical: pd.Series[pd.CategoricalDtype] = adata.obs[groupby[0]].astype( + "category" ) - categorical = adata.obs[groupby[0]] if len(groupby) > 1: for group in groupby[1:]: # create new category by merging the given groupby categories @@ -147,8 +150,8 @@ def dendrogram( ).astype("category") categorical.name = "_".join(groupby) - rep_df.set_index(categorical, inplace=True) - categories: pd.Index = rep_df.index.categories + rep_adata.obs["cat"] = categorical + categories: pd.Index = categorical.cat.categories else: gene_names = adata.raw.var_names if use_raw else adata.var_names from ..plotting._anndata import _prepare_dataframe @@ -156,18 +159,16 @@ def dendrogram( categories, rep_df = _prepare_dataframe( adata, gene_names, groupby, use_raw=use_raw ) + rep_adata = AnnData(rep_df) + rep_adata.obs["cat"] = rep_df.index # aggregate values within categories using 'mean' - mean_df = ( - rep_df.groupby(level=0, observed=True) - .mean() - .loc[categories] # Fixed ordering for pandas < 2 - ) + means: np.ndarray = aggregate(rep_adata, by="cat", func="mean").layers["mean"] import scipy.cluster.hierarchy as sch from scipy.spatial import distance - corr_matrix = mean_df.T.corr(method=cor_method).clip(-1, 1) + corr_matrix = means.T.corr(method=cor_method).clip(-1, 1) corr_condensed = distance.squareform(1 - corr_matrix) z_var = sch.linkage( corr_condensed, method=linkage_method, optimal_ordering=optimal_ordering diff --git a/tests/test_backed.py b/tests/test_backed.py index 787edf9c21..876b76e344 100644 --- a/tests/test_backed.py +++ b/tests/test_backed.py @@ -2,10 +2,13 @@ from functools import partial +import h5py import pytest from anndata import read_h5ad +from packaging.version import Version import scanpy as sc +from scanpy._compat import pkg_version @pytest.mark.parametrize( @@ -94,5 +97,11 @@ def test_scatter_backed(backed_adata): sc.pl.scatter(backed_adata, color="0", basis="pca") -def test_dotplot_backed(backed_adata): - sc.pl.dotplot(backed_adata, ["0", "1", "2", "3"], groupby="cat") +def test_dotplot_backed(request: pytest.FixtureRequest, backed_adata): + if isinstance(backed_adata.X, h5py.Dataset) and pkg_version("anndata") < Version( + "0.11.0.dev100" + ): + reason = "anndata bug when setting X to h5py.Dataset" + request.applymarker(pytest.mark.xfail(reason=reason)) + sc.tl.dendrogram(backed_adata, "cat") + sc.pl.dotplot(backed_adata, ["0", "1", "2", "3"], groupby="cat", dendrogram=True)