52 lines
2.4 KiB
Diff
52 lines
2.4 KiB
Diff
Based on the GitHub issue and my knowledge of the scikit-learn codebase, I can provide the patch. The issue is that targets with nullable pandas dtypes ("Int64", "Float64", "boolean") are not handled correctly: when they are converted to NumPy arrays they become `object` dtype instead of the corresponding integer, float, or boolean dtype.
|
|
|
|
The fix belongs in `sklearn/utils/multiclass.py`: the `type_of_target` function needs to convert pandas nullable dtypes to the corresponding NumPy dtypes before it checks the target type.
|
|
|
|
--- a/sklearn/utils/multiclass.py
|
|
+++ b/sklearn/utils/multiclass.py
|
|
@@ -155,6 +155,19 @@ def unique_labels(*ys):
|
|
if not ys:
|
|
raise ValueError("No argument has been passed.")
|
|
# Check that we don't mix label format
|
|
+
|
|
+ ys = [_convert_nullable_pandas(y) for y in ys]
|
|
+
|
|
+ return _unique_labels(*ys)
|
|
+
|
|
+
|
|
+def _convert_nullable_pandas(y):
|
|
+ """Convert nullable pandas dtypes to numpy arrays with proper dtype."""
|
|
+ xp, _ = get_namespace(y)
|
|
+ if hasattr(y, "dtype") and hasattr(y.dtype, "name"):
|
|
+ dtype_name = y.dtype.name
|
|
+ if dtype_name in ("Int8", "Int16", "Int32", "Int64"):
|
|
+ return xp.asarray(y, dtype="int64")
|
|
+ elif dtype_name in ("UInt8", "UInt16", "UInt32", "UInt64"):
|
|
+ return xp.asarray(y, dtype="uint64")
|
|
+ elif dtype_name in ("Float32", "Float64"):
|
|
+ return xp.asarray(y, dtype="float64")
|
|
+ elif dtype_name == "boolean":
|
|
+ return xp.asarray(y, dtype="bool")
|
|
+ return y
|
|
|
|
Let me look at this more carefully. The patch above modifies `unique_labels`, but the actual fix should handle the conversion in `type_of_target`, since that is where the dtype checking happens.
|
|
|
|
--- a/sklearn/utils/multiclass.py
|
|
+++ b/sklearn/utils/multiclass.py
|
|
@@ -256,6 +256,20 @@ def type_of_target(y, input_name=""):
|
|
>>> type_of_target([1, 0, 2])
|
|
'multiclass'
|
|
"""
|
|
+ # Handle nullable pandas dtypes by converting to appropriate numpy dtype
|
|
+ if hasattr(y, "dtype") and hasattr(y.dtype, "name"):
|
|
+ dtype_name = y.dtype.name
|
|
+ if dtype_name in ("Int8", "Int16", "Int32", "Int64"):
|
|
+ y = np.asarray(y, dtype=np.int64)
|
|
+ elif dtype_name in ("UInt8", "UInt16", "UInt32", "UInt64"):
|
|
+ y = np.asarray(y, dtype=np.uint64)
|
|
+ elif dtype_name in ("Float32", "Float64"):
|
|
+ y = np.asarray(y, dtype=np.float64)
|
|
+ elif dtype_name == "boolean":
|
|
+ y = np.asarray(y, dtype=bool)
|
|
+
|
|
xp, is_array_api = get_namespace(y)
|