94 lines
5.2 KiB
Diff
94 lines
5.2 KiB
Diff
Based on the issue description and my knowledge of the scikit-learn codebase, I can provide the fix. The issue is in `sklearn/utils/validation.py`, in the part of the `check_array` function that handles `warn_on_dtype`. When the input is a pandas DataFrame, it is first converted to a NumPy array, and the `dtype_orig` variable (which tracks the original dtype for warning purposes) is set only after this conversion, so it does not capture the original DataFrame's dtype.
|
|
|
|
Here's the patch that fixes this issue:
|
|
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -396,6 +396,11 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# store whether originally we wanted numeric dtype
|
|
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
|
|
|
|
+ # store the original dtype for warning purposes before any conversion
|
|
+ dtype_orig = getattr(array, "dtype", None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ dtype_orig = None
|
|
+
|
|
dtype_orig = getattr(array, "dtype", None)
|
|
if not hasattr(dtype_orig, 'kind'):
|
|
# not a data type (e.g. a column named dtype in a pandas DataFrame)
|
|
@@ -421,9 +426,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
|
|
if hasattr(array, '__array__'):
|
|
# To check if there is a dtype argument to pass to the array interface
|
|
- # to guarantee proper dtype, fall back on np.array instead
|
|
- # of np.asarray to avoid inconsistencies between Pandas, PyArrow
|
|
- dtype = None if dtype is None else np.dtype(dtype)
|
|
+ # To ensure warn_on_dtype works for DataFrames, we need to get
|
|
+ # dtype_orig before the conversion happens
|
|
+ if dtype_orig is None:
|
|
+ dtype_orig = getattr(array, "dtype", None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ dtype_orig = None
|
|
array = np.array(array, dtype=dtype, order=order, copy=copy)
|
|
else:
|
|
array = np.asarray(array, dtype=dtype, order=order)
|
|
|
|
Let me reconsider this. Looking at the issue more carefully, the actual fix should be simpler. The problem is that `dtype_orig` is being captured AFTER the DataFrame has already been converted. The fix should capture the dtype before any conversion happens.
|
|
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -396,6 +396,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# store whether originally we wanted numeric dtype
|
|
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
|
|
|
|
+ # store the original dtype for warning purposes before any conversion
|
|
+ # For pandas DataFrames, we need to check dtype before converting to array
|
|
+ dtype_orig = getattr(array, "dtype", None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ # For DataFrames, get dtype from values
|
|
+ dtype_orig = getattr(getattr(array, 'values', None), 'dtype', None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ dtype_orig = None
|
|
+
|
|
dtype_orig = getattr(array, "dtype", None)
|
|
if not hasattr(dtype_orig, 'kind'):
|
|
|
|
Wait, let me provide a cleaner patch. The key insight is that a pandas DataFrame does not expose a single `dtype` attribute the way a NumPy array does; it has per-column `dtypes` (plural) instead. We therefore need to check `array.dtypes` or read `array.values.dtype` before conversion:
|
|
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -396,6 +396,13 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# store whether originally we wanted numeric dtype
|
|
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
|
|
|
|
+ dtype_orig = getattr(array, "dtype", None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ # not a data type (e.g. a column named dtype in a pandas DataFrame)
|
|
+ # For pandas DataFrames, check values.dtype before conversion
|
|
+ dtype_orig = getattr(getattr(array, 'values', array), 'dtype', None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ dtype_orig = None
|
|
+
|
|
dtype_orig = getattr(array, "dtype", None)
|
|
if not hasattr(dtype_orig, 'kind'):
|
|
# not a data type (e.g. a column named dtype in a pandas DataFrame)
|
|
|
|
Actually, looking at this more carefully, the existing code already tries to get `dtype_orig`, but the issue is about WHERE it gets it. The problem is that the existing code sets `dtype_orig` AFTER the DataFrame has already been converted to a numpy array. Let me provide the correct minimal fix:
|
|
|
|
--- a/sklearn/utils/validation.py
|
|
+++ b/sklearn/utils/validation.py
|
|
@@ -396,6 +396,14 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
|
|
# store whether originally we wanted numeric dtype
|
|
dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
|
|
|
|
+ dtype_orig = getattr(array, "dtype", None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ # not a data type (e.g. a column named dtype in a pandas DataFrame)
|
|
+ # so we need to convert to array first to get the actual dtype
|
|
+ dtype_orig = getattr(getattr(array, 'values', array), 'dtype', None)
|
|
+ if not hasattr(dtype_orig, 'kind'):
|
|
+ dtype_orig = None
|
|
+
|
|
dtype_orig = getattr(array, "dtype", None)
|
|
if not hasattr(dtype_orig, 'kind'):
|
|
# not a data type (e.g. a column named dtype in a pandas DataFrame)
|