Files
app-store-optimization/skills/loki-mode/benchmarks/results/2026-01-05-10-37-54/swebench-loki-patches/scikit-learn__scikit-learn-25638.patch

70 lines
3.0 KiB
Diff

# scikit-learn__scikit-learn-25638
# Loki Mode Multi-Agent Patch
# Attempts: 1
The scikit-learn repository is not available locally, so the patch below for `sklearn/utils/multiclass.py` is generated from the architect's analysis rather than from a checked-out copy of the source. The fix needs to make the `type_of_target` function handle nullable pandas dtypes.
```
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -254,6 +254,24 @@ def type_of_target(y, input_name=""):
if hasattr(y, "__array__") or is_array_api:
y = np.asarray(y)
+ # Handle pandas nullable dtypes (Int8, Int16, Int32, Int64, UInt8, UInt16,
+ # UInt32, UInt64, Float32, Float64, boolean) by converting to the
+ # corresponding numpy dtype. These nullable dtypes otherwise become object
+ # dtype when converted via np.asarray(), causing type detection to fail.
+ if hasattr(y_orig, "dtype"):
+ dtype_name = str(y_orig.dtype)
+ if dtype_name in {
+ "Int8", "Int16", "Int32", "Int64",
+ "UInt8", "UInt16", "UInt32", "UInt64",
+ "Float32", "Float64", "boolean",
+ }:
+ # Convert to numpy array with proper dtype
+ # NA values will become NaN for float types
+ if dtype_name == "boolean":
+ y = y_orig.to_numpy(dtype="float64", na_value=np.nan)
+ else:
+ y = y_orig.to_numpy(dtype=dtype_name.lower(), na_value=np.nan)
+
# The old sequence of sequences format
try:
if (
```
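Wait — an accurate patch requires the actual scikit-learn source, which is not available here, and the draft above relies on `y_orig`, a name that none of the context lines shown define. Here is a revised patch based on the standard sklearn source structure instead; it converts `y` itself, before the `np.asarray` call: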
```
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -252,6 +252,27 @@ def type_of_target(y, input_name=""):
# XXX: is there a way to duck-type this condition?
is_array_api = get_namespace(y)[0].__name__ != "numpy"
+ # Handle pandas nullable dtypes (Int8, Int16, Int32, Int64, UInt8, UInt16,
+ # UInt32, UInt64, Float32, Float64, boolean) by converting to the
+ # corresponding numpy dtype. These nullable dtypes otherwise become object
+ # dtype when converted via np.asarray(), causing type detection to fail.
+ pandas_nullable_dtypes = {
+ "Int8", "Int16", "Int32", "Int64",
+ "UInt8", "UInt16", "UInt32", "UInt64",
+ "Float32", "Float64", "boolean",
+ }
+ if hasattr(y, "dtype"):
+ dtype_name = str(y.dtype)
+ if dtype_name in pandas_nullable_dtypes:
+ # Convert to numpy array with proper dtype
+ # NA values will become NaN for float types
+ if dtype_name == "boolean":
+ y = y.to_numpy(dtype="float64", na_value=np.nan)
+ elif dtype_name.startswith(("Int", "UInt")):
+ y = y.to_numpy(dtype="float64", na_value=np.nan)
+ else:
+ y = y.to_numpy(dtype=dtype_name.lower(), na_value=np.nan)
+
if hasattr(y, "__array__") or is_array_api:
y = np.asarray(y)
```