
Use numpy.typing.DTypeLike #594

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -31,6 +31,7 @@ jobs:
- name: Run pre-commit
uses: pre-commit/action@v2.0.0
- name: Check for Sphinx doc warnings
if: contains(matrix.python-version, '3.8')
Collaborator: 👍

Collaborator (Author): @ravwojdyla thanks for your (offline) suggestion to do this. The workflow syntax works perfectly. But the build is failing on 3.8 too, so I need to look into that!

run: |
cd docs
make html SPHINXOPTS="-W --keep-going -n"
6 changes: 5 additions & 1 deletion docs/conf.py
@@ -94,7 +94,11 @@ def filter(self, record: pylogging.LogRecord) -> bool:

autosummary_generate = True

nitpick_ignore = [("py:class", "sgkit.display.GenotypeDisplay")]
nitpick_ignore = [
("py:class", "sgkit.display.GenotypeDisplay"),
("py:class", "numpy.typing._dtype_like._DTypeDict"),
("py:class", "numpy.typing._dtype_like._SupportsDType"),
]


# FIXME: Workaround for linking xarray module
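
The two private classes ignored above leak into the documentation because numpy.typing.DTypeLike is a Union that refers to them directly, and Sphinx's nitpicky mode (the -n flag in the build step above) fails on any cross-reference it cannot resolve. A quick way to inspect the offending names, assuming NumPy >= 1.20:

import numpy.typing

# Printing the alias shows the full Union, including the private
# _SupportsDType and _DTypeDict members that Sphinx cannot link to.
print(numpy.typing.DTypeLike)
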
3 changes: 3 additions & 0 deletions setup.cfg
@@ -144,6 +144,9 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-sgkit.*]
allow_redefinition = True
[mypy-sgkit.io.bgen.*]
# avoid warning about an unused ignore under Python 3.8; the ignore is still needed for 3.7
warn_unused_ignores = False
[mypy-sgkit.*.tests.*]
disallow_untyped_defs = False
disallow_untyped_decorators = False
22 changes: 15 additions & 7 deletions sgkit/io/bgen/bgen_reader.py
@@ -23,12 +23,20 @@
import xarray as xr
import zarr
from cbgen import bgen_file, bgen_metafile
from numpy.typing import DTypeLike
from rechunker import api as rechunker_api
from xarray import Dataset

from sgkit import create_genotype_dosage_dataset
from sgkit.io.utils import dataframe_to_dict, encode_contigs
from sgkit.typing import ArrayLike, DType, PathType
from sgkit.typing import ArrayLike, PathType

try:
# needed to avoid Sphinx forward reference error for DTypeLike
# the try block is needed since SupportsIndex is not available in Python 3.7
from typing import SupportsIndex # type: ignore # noqa: F401
except ImportError: # pragma: no cover
pass

logger = logging.getLogger(__name__)

@@ -60,7 +68,7 @@ def __init__(
self,
path: PathType,
metafile_path: Optional[PathType] = None,
dtype: DType = "float32",
dtype: DTypeLike = "float32",
) -> None:
self.path = Path(path)
self.metafile_path = (
@@ -202,8 +210,8 @@ def read_bgen(
chunks: Union[str, int, Tuple[int, int, int]] = "auto",
lock: bool = False,
persist: bool = True,
contig_dtype: DType = "str",
gp_dtype: DType = "float32",
contig_dtype: DTypeLike = "str",
gp_dtype: DTypeLike = "float32",
) -> Dataset:
"""Read BGEN dataset.

@@ -394,7 +402,7 @@ def pack_variables(ds: Dataset) -> Dataset:
return ds


def unpack_variables(ds: Dataset, dtype: DType = "float32") -> Dataset:
def unpack_variables(ds: Dataset, dtype: DTypeLike = "float32") -> Dataset:
# Restore homozygous reference GP
gp = ds["call_genotype_probability"].astype(dtype)
if gp.sizes["genotypes"] != 2:
@@ -423,7 +431,7 @@ def rechunk_bgen(
chunk_length: int = 10_000,
chunk_width: int = 1_000,
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
probability_dtype: Optional[DType] = "uint8",
probability_dtype: Optional[DTypeLike] = "uint8",
max_mem: str = "4GB",
pack: bool = True,
tempdir: Optional[PathType] = None,
@@ -533,7 +541,7 @@ def bgen_to_zarr(
chunk_width: int = 1_000,
temp_chunk_length: int = 100,
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
probability_dtype: Optional[DType] = "uint8",
probability_dtype: Optional[DTypeLike] = "uint8",
max_mem: str = "4GB",
pack: bool = True,
tempdir: Optional[PathType] = None,
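
For context, numpy.typing.DTypeLike (new in NumPy 1.20) describes every value that np.dtype() accepts, so replacing the old DType = Any alias lets mypy actually validate these dtype parameters. A minimal sketch of the difference; the function below is illustrative, not part of sgkit:

import numpy as np
from numpy.typing import DTypeLike  # requires NumPy >= 1.20


def as_dtype(dtype: DTypeLike = "float32") -> np.dtype:
    # DTypeLike covers exactly the inputs np.dtype() understands:
    # strings ("float32"), Python/NumPy types (int, np.int64),
    # np.dtype instances, None, and structured-dtype specs.
    return np.dtype(dtype)


as_dtype("float32")   # ok: dtype string
as_dtype(np.int64)    # ok: NumPy scalar type
# as_dtype(object())  # now flagged by mypy; with DType = Any it passed silently
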
7 changes: 4 additions & 3 deletions sgkit/io/utils.py
@@ -6,13 +6,14 @@
import numpy as np
import xarray as xr
import zarr
from numpy.typing import DTypeLike

from ..typing import ArrayLike, DType
from ..typing import ArrayLike
from ..utils import encode_array, max_str_len


def dataframe_to_dict(
df: dd.DataFrame, dtype: Optional[Mapping[str, DType]] = None
df: dd.DataFrame, dtype: Optional[Mapping[str, DTypeLike]] = None
) -> Mapping[str, ArrayLike]:
""" Convert dask dataframe to dictionary of arrays """
arrs = {}
@@ -110,7 +111,7 @@ def zarrs_to_dataset(
def concatenate_and_rechunk(
zarrs: Sequence[zarr.Array],
chunks: Optional[Tuple[int, ...]] = None,
dtype: DType = None,
dtype: DTypeLike = None,
) -> da.Array:
"""Perform a concatenate and rechunk operation on a collection of Zarr arrays
to produce an array with a uniform chunking, suitable for saving as
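
With the Mapping[str, DTypeLike] annotation, the per-column dtypes passed to dataframe_to_dict are now type-checked as well. A hedged usage sketch; the column names here are hypothetical:

import dask.dataframe as dd
import pandas as pd

from sgkit.io.utils import dataframe_to_dict

# Hypothetical two-column frame, just to exercise the signature.
df = dd.from_pandas(
    pd.DataFrame({"position": [1, 2, 3], "score": [0.1, 0.2, 0.3]}),
    npartitions=1,
)

# mypy now checks that every value in this mapping is a valid DTypeLike.
arrays = dataframe_to_dict(df, dtype={"position": "int32", "score": "float32"})
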
5 changes: 3 additions & 2 deletions sgkit/io/vcf/vcf_reader.py
@@ -20,13 +20,14 @@
import numpy as np
import xarray as xr
from cyvcf2 import VCF, Variant
from numpy.typing import DTypeLike

from sgkit.io.utils import zarrs_to_dataset
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
from sgkit.model import DIM_SAMPLE, DIM_VARIANT, create_genotype_call_dataset
from sgkit.typing import ArrayLike, DType, PathType
from sgkit.typing import ArrayLike, PathType
from sgkit.utils import max_str_len

DEFAULT_MAX_ALT_ALLELES = (
@@ -104,7 +105,7 @@ def _normalize_fields(vcf: VCF, fields: Sequence[str]) -> Sequence[str]:

def _vcf_type_to_numpy_type_and_fill_value(
vcf_type: str, category: str, key: str
) -> Tuple[DType, Any]:
) -> Tuple[DTypeLike, Any]:
"""Convert the VCF Type to a NumPy dtype and fill value."""
if vcf_type == "Flag":
return "bool", False
13 changes: 7 additions & 6 deletions sgkit/stats/ld.py
@@ -8,10 +8,11 @@
import pandas as pd
from dask.dataframe import DataFrame
from numba import njit
from numpy.typing import DTypeLike
from xarray import Dataset

from sgkit import variables
from sgkit.typing import ArrayLike, DType
from sgkit.typing import ArrayLike
from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows


@@ -205,8 +206,8 @@ def _ld_matrix_jit(
chunk_window_stops: ArrayLike,
abs_chunk_start: int,
chunk_max_window_start: int,
index_dtype: DType,
value_dtype: DType,
index_dtype: DTypeLike,
value_dtype: DTypeLike,
threshold: float,
scores: ArrayLike,
) -> List[Any]: # pragma: no cover
@@ -246,7 +247,7 @@

if no_threshold or (res >= threshold and np.isfinite(res)):
rows.append(
(index_dtype(index), index_dtype(other), value_dtype(res), cmp)
(index_dtype(index), index_dtype(other), value_dtype(res), cmp) # type: ignore
)

return rows
@@ -258,8 +259,8 @@ def _ld_matrix(
chunk_window_stops: ArrayLike,
abs_chunk_start: int,
chunk_max_window_start: int,
index_dtype: DType,
value_dtype: DType,
index_dtype: DTypeLike,
value_dtype: DTypeLike,
threshold: float = np.nan,
scores: Optional[ArrayLike] = None,
) -> ArrayLike:
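
The new # type: ignore in _ld_matrix_jit is needed because the function calls index_dtype and value_dtype as constructors. That works at runtime when the caller passes a concrete scalar type such as np.int32, but mypy cannot prove that an arbitrary DTypeLike (for example the string "int32") is callable. A small sketch of the mismatch:

import numpy as np
from numpy.typing import DTypeLike

index_dtype: DTypeLike = np.int32
# Fine at runtime for a scalar type, but mypy rejects the call because
# not every DTypeLike (e.g. "int32") is callable; hence the ignore above.
index = index_dtype(7)
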
5 changes: 3 additions & 2 deletions sgkit/stats/pca.py
@@ -4,14 +4,15 @@
import numpy as np
import xarray as xr
from dask_ml.decomposition import TruncatedSVD
from numpy.typing import DTypeLike
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline
from typing_extensions import Literal
from xarray import DataArray, Dataset

from sgkit import variables

from ..typing import ArrayLike, DType, RandomStateType
from ..typing import ArrayLike, RandomStateType
from ..utils import conditional_merge_datasets
from .aggregation import count_call_alleles
from .preprocessing import PattersonScaler
@@ -331,7 +332,7 @@ def _allele_counts(
ds: Dataset,
variable: str,
check_missing: bool = True,
dtype: DType = "float32",
dtype: DTypeLike = "float32",
) -> DataArray:
if variable not in ds:
ds = count_call_alternate_alleles(ds)
5 changes: 3 additions & 2 deletions sgkit/tests/test_preprocessing.py
@@ -5,18 +5,19 @@
import numpy as np
import pytest
import xarray as xr
from numpy.typing import DTypeLike

import sgkit.stats.preprocessing
from sgkit import simulate_genotype_call_dataset
from sgkit.typing import ArrayLike, DType
from sgkit.typing import ArrayLike


def simulate_alternate_allele_counts(
n_variant: int,
n_sample: int,
ploidy: int,
chunks: Any = (10, 10),
dtype: DType = "i",
dtype: DTypeLike = "i",
seed: int = 0,
) -> ArrayLike:
rs = da.random.RandomState(seed)
3 changes: 1 addition & 2 deletions sgkit/typing.py
@@ -1,10 +1,9 @@
from pathlib import Path
from typing import Any, Union
from typing import Union

import dask.array as da
import numpy as np

ArrayLike = Union[np.ndarray, da.Array]
DType = Any
PathType = Union[str, Path]
RandomStateType = Union[np.random.RandomState, da.random.RandomState, int]
5 changes: 3 additions & 2 deletions sgkit/utils.py
@@ -3,15 +3,16 @@

import numpy as np
from numba import guvectorize
from numpy.typing import DTypeLike
from xarray import Dataset

from . import variables
from .typing import ArrayLike, DType
from .typing import ArrayLike


def check_array_like(
a: Any,
dtype: Union[None, DType, Set[DType]] = None,
dtype: Union[None, DTypeLike, Set[DTypeLike]] = None,
kind: Union[None, str, Set[str]] = None,
ndim: Union[None, int, Set[int]] = None,
) -> None:
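
Because check_array_like accepts either a single dtype or a set of them, the updated annotation keeps both call styles checkable. A usage sketch, assuming the function raises when a constraint is not met, as its name suggests:

import numpy as np

from sgkit.utils import check_array_like

x = np.zeros((10, 3), dtype="float32")

# Either a single DTypeLike or a set of DTypeLike values is accepted.
check_array_like(x, dtype="float32", ndim=2)
check_array_like(x, dtype={"float32", "float64"}, ndim={1, 2})
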
7 changes: 4 additions & 3 deletions sgkit/window.py
@@ -2,12 +2,13 @@

import dask.array as da
import numpy as np
from numpy.typing import DTypeLike
from xarray import Dataset

from sgkit.utils import conditional_merge_datasets, create_dataset
from sgkit.variables import window_contig, window_start, window_stop

from .typing import ArrayLike, DType
from .typing import ArrayLike

# Window definition (user code)

@@ -110,7 +111,7 @@ def moving_statistic(
statistic: Callable[..., ArrayLike],
size: int,
step: int,
dtype: DType,
dtype: DTypeLike,
**kwargs: Any,
) -> da.Array:
"""A Dask implementation of scikit-allel's moving_statistic function."""
@@ -135,7 +136,7 @@ def window_statistic(
statistic: Callable[..., ArrayLike],
window_starts: ArrayLike,
window_stops: ArrayLike,
dtype: DType,
dtype: DTypeLike,
chunks: Any = None,
new_axis: Union[None, int, Iterable[int]] = None,
**kwargs: Any,
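
moving_statistic is described as a Dask port of scikit-allel's function of the same name. A plain-NumPy sketch of the idea behind these signatures (not sgkit's Dask implementation) shows where the DTypeLike parameter fits:

import numpy as np
from numpy.typing import DTypeLike


def moving_mean(values: np.ndarray, size: int, step: int, dtype: DTypeLike) -> np.ndarray:
    # Apply the statistic to consecutive windows of `size` values,
    # advancing by `step`, and materialize the result with `dtype`.
    starts = range(0, len(values) - size + 1, step)
    return np.array([values[i : i + size].mean() for i in starts], dtype=dtype)


moving_mean(np.arange(12), size=4, step=4, dtype="float64")  # [1.5, 5.5, 9.5]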