From 4a3f8aadf2bc2aaea3f6dcc3442c43bdaac92976 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 27 Dec 2020 13:37:53 -0800 Subject: [PATCH 1/7] BUG: GroupBy.idxmax/idxmin with EA dtypes --- pandas/core/arrays/base.py | 8 ++++++-- pandas/tests/groupby/test_function.py | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bd5cf43e19e9f..df563015d7ecd 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -591,7 +591,7 @@ def argsort( mask=np.asarray(self.isna()), ) - def argmin(self): + def argmin(self, axis=None, skipna: bool = True): """ Return the index of minimum value. @@ -606,9 +606,11 @@ def argmin(self): -------- ExtensionArray.argmax """ + if not skipna: + raise NotImplementedError return nargminmax(self, "argmin") - def argmax(self): + def argmax(self, axis=None, skipna: bool = True): """ Return the index of maximum value. @@ -623,6 +625,8 @@ def argmax(self): -------- ExtensionArray.argmin """ + if not skipna: + raise NotImplementedError return nargminmax(self, "argmax") def fillna(self, value=None, method=None, limit=None): diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 8d7fcbfcfe694..f532e496ccca9 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -531,10 +531,16 @@ def test_idxmin_idxmax_returns_int_types(func, values): } ) df["c_date"] = pd.to_datetime(df["c_date"]) + df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific") + df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0] + df["c_period"] = df["c_date"].dt.to_period("W") result = getattr(df.groupby("name"), func)() expected = DataFrame(values, index=Index(["A", "B"], name="name")) + expected["c_date_tz"] = expected["c_date"] + expected["c_timedelta"] = expected["c_date"] + expected["c_period"] = expected["c_date"] tm.assert_frame_equal(result, expected) From b9dfbd6e3a0cc4dc871408f63105545936fc27d2 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 17 Jan 2021 16:44:36 -0800 Subject: [PATCH 2/7] TST: extension test for argmin/argmax with skipna --- pandas/core/arrays/base.py | 16 +++++++++------- pandas/core/base.py | 10 ++-------- pandas/tests/extension/base/methods.py | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 51149de502243..317cee73777ef 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -30,7 +30,7 @@ from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, Substitution -from pandas.util._validators import validate_fillna_kwargs +from pandas.util._validators import validate_bool_kwarg, validate_fillna_kwargs from pandas.core.dtypes.cast import maybe_cast_to_extension_array from pandas.core.dtypes.common import ( @@ -596,7 +596,7 @@ def argsort( mask=np.asarray(self.isna()), ) - def argmin(self, axis=None, skipna: bool = True): + def argmin(self, skipna: bool = True) -> int: """ Return the index of minimum value. @@ -611,11 +611,12 @@ def argmin(self, axis=None, skipna: bool = True): -------- ExtensionArray.argmax """ - if not skipna: - raise NotImplementedError + validate_bool_kwarg(skipna, "skipna") + if not skipna and self.isna().any(): + return -1 return nargminmax(self, "argmin") - def argmax(self, axis=None, skipna: bool = True): + def argmax(self, skipna: bool = True) -> int: """ Return the index of maximum value. @@ -630,8 +631,9 @@ def argmax(self, axis=None, skipna: bool = True): -------- ExtensionArray.argmin """ - if not skipna: - raise NotImplementedError + validate_bool_kwarg(skipna, "skipna") + if not skipna and self.isna().any(): + return -1 return nargminmax(self, "argmax") def fillna(self, value=None, method=None, limit=None): diff --git a/pandas/core/base.py b/pandas/core/base.py index b603ba31f51dd..feba5f477510b 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -726,10 +726,7 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - if not skipna and delegate.isna().any(): - return -1 - else: - return delegate.argmax() + return delegate.argmax(skipna=skipna) else: return nanops.nanargmax(delegate, skipna=skipna) @@ -784,10 +781,7 @@ def argmin(self, axis=None, skipna=True, *args, **kwargs) -> int: skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - if not skipna and delegate.isna().any(): - return -1 - else: - return delegate.argmin() + return delegate.argmin(skipna=skipna) else: return nanops.nanargmin(delegate, skipna=skipna) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 7e7f1f1a6e025..238bf502a7c62 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -107,6 +107,21 @@ def test_argmin_argmax_all_na(self, method, data, na_value): with pytest.raises(ValueError, match=err_msg): getattr(data_na, method)() + @pytest.mark.parametrize( + "op_name, skipna, expected", + [ + ("argmax", True, 0), + ("argmin", True, 2), + ("argmax", False, -1), + ("argmin", False, -1), + ], + ) + def test_argmin_argmax_skipna( + self, op_name, skipna, expected, data_missing_for_sorting + ): + result = getattr(data_missing_for_sorting, op_name)(skipna=skipna) + tm.assert_almost_equal(result, expected) + @pytest.mark.parametrize( "op_name, skipna, expected", [ From 305b934ea8835590e7f20f59579b039fa39cdf75 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 18 Jan 2021 09:46:09 -0800 Subject: [PATCH 3/7] revert --- pandas/core/arrays/base.py | 4 ++-- pandas/core/base.py | 10 ++++++++-- pandas/tests/extension/base/methods.py | 15 --------------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 317cee73777ef..26588e9a21e93 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -613,7 +613,7 @@ def argmin(self, skipna: bool = True) -> int: """ validate_bool_kwarg(skipna, "skipna") if not skipna and self.isna().any(): - return -1 + raise NotImplementedError return nargminmax(self, "argmin") def argmax(self, skipna: bool = True) -> int: @@ -633,7 +633,7 @@ def argmax(self, skipna: bool = True) -> int: """ validate_bool_kwarg(skipna, "skipna") if not skipna and self.isna().any(): - return -1 + raise NotImplementedError return nargminmax(self, "argmax") def fillna(self, value=None, method=None, limit=None): diff --git a/pandas/core/base.py b/pandas/core/base.py index feba5f477510b..b603ba31f51dd 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -726,7 +726,10 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int: skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - return delegate.argmax(skipna=skipna) + if not skipna and delegate.isna().any(): + return -1 + else: + return delegate.argmax() else: return nanops.nanargmax(delegate, skipna=skipna) @@ -781,7 +784,10 @@ def argmin(self, axis=None, skipna=True, *args, **kwargs) -> int: skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs) if isinstance(delegate, ExtensionArray): - return delegate.argmin(skipna=skipna) + if not skipna and delegate.isna().any(): + return -1 + else: + return delegate.argmin() else: return nanops.nanargmin(delegate, skipna=skipna) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 238bf502a7c62..7e7f1f1a6e025 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -107,21 +107,6 @@ def test_argmin_argmax_all_na(self, method, data, na_value): with pytest.raises(ValueError, match=err_msg): getattr(data_na, method)() - @pytest.mark.parametrize( - "op_name, skipna, expected", - [ - ("argmax", True, 0), - ("argmin", True, 2), - ("argmax", False, -1), - ("argmin", False, -1), - ], - ) - def test_argmin_argmax_skipna( - self, op_name, skipna, expected, data_missing_for_sorting - ): - result = getattr(data_missing_for_sorting, op_name)(skipna=skipna) - tm.assert_almost_equal(result, expected) - @pytest.mark.parametrize( "op_name, skipna, expected", [ From 403951c3f950e07bcfadfbc844155d22c5b211b5 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 18 Jan 2021 14:41:34 -0800 Subject: [PATCH 4/7] whatsnew --- doc/source/whatsnew/v1.3.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 6a85bfd852e19..b19d6d282ca54 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -339,6 +339,7 @@ Groupby/resample/rolling - Fixed bug in :meth:`DataFrameGroupBy.sum` and :meth:`SeriesGroupBy.sum` causing loss of precision through using Kahan summation (:issue:`38778`) - Fixed bug in :meth:`DataFrameGroupBy.cumsum`, :meth:`SeriesGroupBy.cumsum`, :meth:`DataFrameGroupBy.mean` and :meth:`SeriesGroupBy.mean` causing loss of precision through using Kahan summation (:issue:`38934`) - Bug in :meth:`.Resampler.aggregate` and :meth:`DataFrame.transform` raising ``TypeError`` instead of ``SpecificationError`` when missing keys had mixed dtypes (:issue:`39025`) +- Bug in :meth:`GroupBy.idxmin` and :meth:`GroupBy.idxmax` with ``ExtendionDtype`` columns (:issue:`38733`) Reshaping ^^^^^^^^^ From ed7190a49d18e4901818d36551bea68554926a86 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 19 Jan 2021 07:22:43 -0800 Subject: [PATCH 5/7] Update doc/source/whatsnew/v1.3.0.rst Co-authored-by: Joris Van den Bossche --- doc/source/whatsnew/v1.3.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index b19d6d282ca54..a249d0cf39a0b 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -339,7 +339,7 @@ Groupby/resample/rolling - Fixed bug in :meth:`DataFrameGroupBy.sum` and :meth:`SeriesGroupBy.sum` causing loss of precision through using Kahan summation (:issue:`38778`) - Fixed bug in :meth:`DataFrameGroupBy.cumsum`, :meth:`SeriesGroupBy.cumsum`, :meth:`DataFrameGroupBy.mean` and :meth:`SeriesGroupBy.mean` causing loss of precision through using Kahan summation (:issue:`38934`) - Bug in :meth:`.Resampler.aggregate` and :meth:`DataFrame.transform` raising ``TypeError`` instead of ``SpecificationError`` when missing keys had mixed dtypes (:issue:`39025`) -- Bug in :meth:`GroupBy.idxmin` and :meth:`GroupBy.idxmax` with ``ExtendionDtype`` columns (:issue:`38733`) +- Bug in :meth:`.DataFrameGroupBy.idxmin` and :meth:`.DataFrameGroupBy.idxmax` with ``ExtensionDtype`` columns (:issue:`38733`) Reshaping ^^^^^^^^^ From 1d8e28e99fff22aee3f989532f792732af32c5c6 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 19 Jan 2021 11:01:17 -0800 Subject: [PATCH 6/7] update docstring --- pandas/core/arrays/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 26588e9a21e93..b0979218e099c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -603,6 +603,10 @@ def argmin(self, skipna: bool = True) -> int: In case of multiple occurrences of the minimum value, the index corresponding to the first occurrence is returned. + Parameters + ---------- + skipna : bool, default True + Returns ------- int @@ -623,6 +627,10 @@ def argmax(self, skipna: bool = True) -> int: In case of multiple occurrences of the maximum value, the index corresponding to the first occurrence is returned. + Parameters + ---------- + skipna : bool, default True + Returns ------- int From 2fd45ad2c9ba9fee00b724eedaf6fa88dfdf66ba Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 19 Jan 2021 11:23:05 -0800 Subject: [PATCH 7/7] test for NotImplementedError --- pandas/tests/extension/base/methods.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 7e7f1f1a6e025..3518f3b29e8c2 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -128,6 +128,16 @@ def test_argreduce_series( result = getattr(ser, op_name)(skipna=skipna) tm.assert_almost_equal(result, expected) + def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting): + # GH#38733 + data = data_missing_for_sorting + + with pytest.raises(NotImplementedError, match=""): + data.argmin(skipna=False) + + with pytest.raises(NotImplementedError, match=""): + data.argmax(skipna=False) + @pytest.mark.parametrize( "na_position, expected", [