70 lines
3.0 KiB
Diff
70 lines
3.0 KiB
Diff
# scikit-learn__scikit-learn-25638
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
The scikit-learn repository is not available locally, so the patch below to `sklearn/utils/multiclass.py` is based on the architect's analysis. The fix must make the `type_of_target` function handle pandas nullable dtypes (`Int8`–`Int64`, `UInt8`–`UInt64`, `Float32`, `Float64`, `boolean`). `np.asarray()` converts these dtypes to `object` dtype, which breaks target-type detection.
|
|
|
|
```
|
|
--- a/sklearn/utils/multiclass.py
|
|
+++ b/sklearn/utils/multiclass.py
|
|
@@ -254,6 +254,24 @@ def type_of_target(y, input_name=""):
|
|
if hasattr(y, "__array__") or is_array_api:
|
|
y = np.asarray(y)
|
|
|
|
+ # Handle pandas nullable dtypes (Int8, Int16, Int32, Int64, UInt8, UInt16,
|
|
+ # UInt32, UInt64, Float32, Float64, boolean) by converting to the
|
|
+ # corresponding numpy dtype. These nullable dtypes otherwise become object
|
|
+ # dtype when converted via np.asarray(), causing type detection to fail.
|
|
+ if hasattr(y_orig, "dtype"):
|
|
+ dtype_name = str(y_orig.dtype)
|
|
+ if dtype_name in {
|
|
+ "Int8", "Int16", "Int32", "Int64",
|
|
+ "UInt8", "UInt16", "UInt32", "UInt64",
|
|
+ "Float32", "Float64", "boolean",
|
|
+ }:
|
|
+ # Convert to numpy array with proper dtype
|
|
+ # NA values will become NaN for float types
|
|
+ if dtype_name == "boolean":
|
|
+ y = y_orig.to_numpy(dtype="float64", na_value=np.nan)
|
|
+ else:
|
|
+ y = y_orig.to_numpy(dtype=dtype_name.lower(), na_value=np.nan)
|
|
+
|
|
# The old sequence of sequences format
|
|
try:
|
|
if (
|
|
```
|
|
|
|
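Correction: the patch above is not usable as written. It reads `y_orig`, which the hunk never defines. It also casts nullable integer dtypes to the corresponding NumPy integer dtype (e.g. `int64`) while passing `na_value=np.nan`, and a NumPy integer dtype cannot represent NaN. The revised patch below makes two changes:

- It converts nullable dtypes before `np.asarray(y)` is called.
- It maps nullable integer and boolean dtypes to `float64`.

Note that it is based on the standard scikit-learn source structure, not on the actual repository, so the hunk offsets and context lines still need to be checked against the real file: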
|
|
|
|
```
|
|
--- a/sklearn/utils/multiclass.py
|
|
+++ b/sklearn/utils/multiclass.py
|
|
@@ -252,6 +252,27 @@ def type_of_target(y, input_name=""):
|
|
# XXX: is there a way to duck-type this condition?
|
|
is_array_api = get_namespace(y)[0].__name__ != "numpy"
|
|
|
|
+ # Handle pandas nullable dtypes (Int8, Int16, Int32, Int64, UInt8, UInt16,
|
|
+ # UInt32, UInt64, Float32, Float64, boolean) by converting to the
|
|
+ # corresponding numpy dtype. These nullable dtypes otherwise become object
|
|
+ # dtype when converted via np.asarray(), causing type detection to fail.
|
|
+ pandas_nullable_dtypes = {
|
|
+ "Int8", "Int16", "Int32", "Int64",
|
|
+ "UInt8", "UInt16", "UInt32", "UInt64",
|
|
+ "Float32", "Float64", "boolean",
|
|
+ }
|
|
+ if hasattr(y, "dtype"):
|
|
+ dtype_name = str(y.dtype)
|
|
+ if dtype_name in pandas_nullable_dtypes:
|
|
+ # Convert to numpy array with proper dtype
|
|
+ # NA values will become NaN for float types
|
|
+ if dtype_name == "boolean":
|
|
+ y = y.to_numpy(dtype="float64", na_value=np.nan)
|
|
+ elif dtype_name.startswith(("Int", "UInt")):
|
|
+ y = y.to_numpy(dtype="float64", na_value=np.nan)
|
|
+ else:
|
|
+ y = y.to_numpy(dtype=dtype_name.lower(), na_value=np.nan)
|
|
+
|
|
if hasattr(y, "__array__") or is_array_api:
|
|
y = np.asarray(y)
|
|
|
|
```
|