From a47a169a73c9bf1317f59a36eff34356e172f983 Mon Sep 17 00:00:00 2001
From: Andreas Albert
Date: Fri, 25 Apr 2025 18:22:03 +0200
Subject: [PATCH] fix: Make enums tolerate multiple sequence types

---
# Reviewer notes (this section sits after the "---" separator and is not part
# of the commit message or of the diff itself):
# - dataframely/columns/enum.py: __init__ now stores `list(categories)`
#   instead of the object that was passed in.
# - tests/column_types/test_enum.py: adds test_different_sequences, which is
#   parametrized over list and tuple for both the dy.Enum categories and the
#   pl.Enum dtype categories. It only calls S.validate(df); it makes no
#   further assertions.
# - Presumably the copy normalizes the stored sequence type so that list and
#   tuple inputs behave the same; confirm against how self.categories is
#   consumed by dtype/validation, which this patch does not show.
 dataframely/columns/enum.py     | 2 +-
 tests/column_types/test_enum.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py
index c5f75f0..7f2e43b 100644
--- a/dataframely/columns/enum.py
+++ b/dataframely/columns/enum.py
@@ -49,7 +49,7 @@ def __init__(
             alias=alias,
             metadata=metadata,
         )
-        self.categories = categories
+        self.categories = list(categories)
 
     @property
     def dtype(self) -> pl.DataType:
diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py
index 5a0d8a3..f2389b3 100644
--- a/tests/column_types/test_enum.py
+++ b/tests/column_types/test_enum.py
@@ -52,3 +52,12 @@ def test_valid_cast(
     schema = create_schema("test", {"a": enum})
     df = df_type(data)
     assert schema.is_valid(df, cast=True) == valid
+
+
+@pytest.mark.parametrize("type1", [list, tuple])
+@pytest.mark.parametrize("type2", [list, tuple])
+def test_different_sequences(type1: type, type2: type) -> None:
+    allowed = ["a", "b"]
+    S = create_schema("test", {"x": dy.Enum(type1(allowed))})
+    df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(type2(allowed)))})
+    S.validate(df)