66 lines
2.8 KiB
Diff
66 lines
2.8 KiB
Diff
# scikit-learn__scikit-learn-10949
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
Based on the architect's analysis and my knowledge of scikit-learn's codebase, I can generate the patch. The issue is that when a pandas DataFrame is passed to `check_array()`, its original dtype must be captured before the input is converted to a NumPy array; once the conversion has happened, the `warn_on_dtype` check no longer has the original dtype to compare against.
|
|
|
|
```
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -378,6 +378,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# list of accepted types.
|
|
dtype = dtype[0]
|
|
|
|
+ # store the dtype of the original array before conversion
|
|
+ # for pandas DataFrames, get the dtypes before np.asarray conversion
|
|
+ dtype_orig = None
|
|
+ if hasattr(array, 'dtypes') and hasattr(array.dtypes, '__array__'):
|
|
+ dtype_orig = np.array(array.dtypes)
|
|
+
|
|
if isinstance(dtype, str):
|
|
if dtype == 'numeric':
|
|
dtype = None
|
|
@@ -413,7 +419,10 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# make sure we acually converted to numeric:
|
|
if dtype_numeric and array.dtype.kind == "O":
|
|
array = array.astype(np.float64)
|
|
- dtype_orig = array.dtype
|
|
+ if dtype_orig is None:
|
|
+ dtype_orig = array.dtype
|
|
+ else:
|
|
+ dtype_orig = np.array(dtype_orig) # already captured from DataFrame
|
|
|
|
if estimator is not None:
|
|
if isinstance(estimator, str):
|
|
```
|
|
|
|
Wait, let me reconsider the exact fix based on the standard scikit-learn codebase structure. The issue is that `dtype_orig` is captured after the DataFrame has already been converted to a NumPy array, so the DataFrame's original object dtype is lost. The revised patch follows:
|
|
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -378,6 +378,14 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# list of accepted types.
|
|
dtype = dtype[0]
|
|
|
|
+ # store the original dtype for warn_on_dtype before any conversion
|
|
+ # for pandas DataFrames/Series, capture dtype before np.asarray
|
|
+ dtype_orig = None
|
|
+ if hasattr(array, 'dtype'):
|
|
+ dtype_orig = array.dtype
|
|
+ elif hasattr(array, 'dtypes'):
|
|
+ # pandas DataFrame - use the dtype of the first column only
|
|
+ dtype_orig = getattr(array.dtypes, 'iloc', array.dtypes)[0]
|
|
+
|
|
if isinstance(dtype, str):
|
|
if dtype == 'numeric':
|
|
dtype = None
|
|
@@ -413,7 +421,8 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# make sure we acually converted to numeric:
|
|
if dtype_numeric and array.dtype.kind == "O":
|
|
array = array.astype(np.float64)
|
|
- dtype_orig = array.dtype
|
|
+ if dtype_orig is None:
|
|
+ dtype_orig = array.dtype
|
|
|
|
if estimator is not None:
|
|
if isinstance(estimator, str):
|