diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 11928e79ffc62..ff01d4ac835ba 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -208,6 +208,10 @@ def _hash_pandas_object( values, encoding=encoding, hash_key=hash_key, categorize=categorize ) + def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike: + values = np.asarray(values, dtype=object) + return lib.maybe_convert_objects(values, convert_non_numeric=True) + # Signature of "argmin" incompatible with supertype "ExtensionArray" def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override] # override base class by adding axis keyword diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 225cc888d50db..8818336132eb3 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -19,6 +19,7 @@ cast, overload, ) +import warnings import numpy as np @@ -33,6 +34,7 @@ cache_readonly, set_module, ) +from pandas.util._exceptions import find_stack_level from pandas.util._validators import ( validate_bool_kwarg, validate_insert_loc, @@ -86,6 +88,7 @@ AstypeArg, AxisInt, Dtype, + DtypeObj, FillnaOptions, InterpolateOptions, NumpySorter, @@ -383,13 +386,67 @@ def _from_factorized(cls, values, original): """ raise AbstractMethodError(cls) - def _cast_pointwise_result(self, values) -> ArrayLike: + @classmethod + def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: + """ + Strict analogue to _from_sequence, allowing only sequences of scalars + that should be specifically inferred to the given dtype. + + Parameters + ---------- + scalars : sequence + dtype : ExtensionDtype + + Raises + ------ + TypeError or ValueError + + Notes + ----- + This is called in a try/except block when casting the result of a + pointwise operation in ExtensionArray._cast_pointwise_result. + """ + try: + return cls._from_sequence(scalars, dtype=dtype, copy=False) + except (ValueError, TypeError): + raise + except Exception: + warnings.warn( + "_from_scalars should only raise ValueError or TypeError. " + "Consider overriding _from_scalars where appropriate.", + stacklevel=find_stack_level(), + ) + raise + + def _cast_pointwise_result(self, values, **kwargs) -> ArrayLike: """ + Construct an ExtensionArray after a pointwise operation. + Cast the result of a pointwise operation (e.g. Series.map) to an - array, preserve dtype_backend if possible. + array. This is not required to return an ExtensionArray of the same + type as self or of the same dtype. It can also return another + ExtensionArray of the same "family" if you implement multiple + ExtensionArrays/Dtypes that are interoperable (e.g. if you have float + array with units, this method can return an int array with units). + + If converting to your own ExtensionArray is not possible, this method + falls back to returning an array with the default type inference. + If you only need to cast to `self.dtype`, it is recommended to override + `_from_scalars` instead of this method. + + Parameters + ---------- + values : sequence + + Returns + ------- + ExtensionArray or ndarray """ - values = np.asarray(values, dtype=object) - return lib.maybe_convert_objects(values, convert_non_numeric=True) + try: + return type(self)._from_scalars(values, dtype=self.dtype) + except (ValueError, TypeError): + values = np.asarray(values, dtype=object) + return lib.maybe_convert_objects(values, convert_non_numeric=True) # ------------------------------------------------------------------------ # Must be a Sequence diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index c15a196dc6727..86140229b724e 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -622,7 +622,8 @@ def _from_factorized(cls, values, original) -> Self: return cls(values, dtype=original.dtype) def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) + values = np.asarray(values, dtype=object) + result = lib.maybe_convert_objects(values, convert_non_numeric=True) if result.dtype.kind == self.dtype.kind: try: # e.g. test_groupby_agg_extension diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 7d055e2143112..23f3fa046dab4 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -111,15 +111,15 @@ def _from_sequence_of_strings(cls, strings, *, dtype: ExtensionDtype, copy=False def _from_factorized(cls, values, original): return cls(values) - def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) - try: - # If this were ever made a non-test EA, special-casing could - # be avoided by handling Decimal in maybe_convert_objects - res = type(self)._from_sequence(result, dtype=self.dtype) - except (ValueError, TypeError): - return result - return res + # test to ensure that the base class _cast_pointwise_result works as expected + # def _cast_pointwise_result(self, values): + # try: + # # If this were ever made a non-test EA, special-casing could + # # be avoided by handling Decimal in maybe_convert_objects + # res = type(self)._from_sequence(values, dtype=self.dtype) + # except (ValueError, TypeError): + # return values + # return res _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray) diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 1878fac1b8111..5eb1a9ff286fa 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -94,11 +94,14 @@ def _from_factorized(cls, values, original): return cls([UserDict(x) for x in values if x != ()]) def _cast_pointwise_result(self, values): - result = super()._cast_pointwise_result(values) try: - return type(self)._from_sequence(result, dtype=self.dtype) + return type(self)._from_sequence(values, dtype=self.dtype) except (ValueError, TypeError): - return result + # TODO replace with public function + from pandas._libs import lib + + values = np.asarray(values, dtype=object) + return lib.maybe_convert_objects(values, convert_non_numeric=True) def __getitem__(self, item): if isinstance(item, tuple):