Based on my knowledge of scikit-learn's codebase and the bug description, I can create the patch. The issue is in the `_transform` method that `OneHotEncoder` inherits from `_BaseEncoder` in `sklearn/preprocessing/_encoders.py`. When `handle_unknown='ignore'`, unknown values in `X` are replaced with `self.categories_[i][0]`. However, if the dtype of `X` is a fixed-length string type (like `<U5`) and `categories_[i][0]` is longer (like `'11111111'`), the assignment silently truncates the string, so the truncated value matches no known category and the subsequent lookup fails.
The fix should cast the column to object dtype when string truncation could occur.
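A minimal NumPy sketch of the truncation and of the object-dtype workaround. The arrays `categories` and `column` are made-up illustrations, not scikit-learn internals:

```python
import numpy as np

# Illustrative data only: one long known category, one short input column.
categories = np.array(['11111111', '22', '333'])  # dtype <U8
column = np.array(['55555', '22'])                # dtype <U5

# Assigning into the fixed-width array silently cuts the value to 5 characters.
truncated = column.copy()
truncated[0] = categories[0]

# Casting to object first removes the width limit, so the value is kept whole.
safe = column.astype(object)
safe[0] = categories[0]

print(repr(truncated[0]), repr(safe[0]))
```

The truncated value `'11111'` is not among the known categories, which is why a later lookup fails, while the object-dtype copy holds the full `'11111111'`.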
--- a/sklearn/preprocessing/_encoders.py
+++ b/sklearn/preprocessing/_encoders.py
@@ -107,6 +107,15 @@ class _BaseEncoder(BaseEstimator, TransformerMixin):
                     msg = ("Found unknown categories {0} in column {1}"
                            " during transform".format(diff, i))
                     raise ValueError(msg)
+                # When handle_unknown='ignore', unknown values are set to the
+                # first category (categories_[i][0]). If the dtype is a fixed-length
+                # string type and categories_[i][0] is longer, the assignment would
+                # truncate the string. To avoid this, cast to object dtype.
+                if (self.handle_unknown == 'ignore' and
+                        Xi.dtype.kind in ('U', 'S') and
+                        np.asarray(self.categories_[i][0]).itemsize > Xi.itemsize):
+                    Xi = Xi.astype(object)
+                Xi[~valid_mask] = self.categories_[i][0]
 
 
         X_int = np.empty(X.shape, dtype=np.int)
Let me think about this more carefully. Looking at the actual scikit-learn code structure, the `_transform` method processes each column and needs to handle the case where unknown values are being replaced. The issue is that numpy string arrays have fixed-length dtypes, and assigning a longer string truncates it.
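One subtlety when checking whether an assignment would truncate: `dtype.itemsize` counts bytes, not characters, and NumPy stores each character of a `U` dtype in 4 bytes. The helpers below (`char_capacity` and `would_truncate` are hypothetical names introduced for illustration) sketch a character-based check under that assumption:

```python
import numpy as np


def char_capacity(dtype):
    """Characters a fixed-length string dtype can hold (hypothetical helper)."""
    dtype = np.dtype(dtype)
    if dtype.kind == 'U':
        # Unicode dtypes use a fixed number of bytes per character.
        return dtype.itemsize // np.dtype('U1').itemsize
    if dtype.kind == 'S':
        # Byte strings use one byte per character.
        return dtype.itemsize
    raise TypeError("not a fixed-length string dtype: %r" % (dtype,))


def would_truncate(values, replacement):
    """True if writing `replacement` into `values` would cut it short."""
    return (values.dtype.kind in ('U', 'S')
            and len(replacement) > char_capacity(values.dtype))


column = np.array(['55555', '22'])  # dtype <U5
print(column.dtype.itemsize, char_capacity(column.dtype))
print(would_truncate(column, '11111111'))
```

Comparing `len(replacement)` directly against `itemsize` would miss the truncation here, because the 8-character replacement is shorter than the 20-byte item size of `<U5`.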
--- a/sklearn/preprocessing/_encoders.py
+++ b/sklearn/preprocessing/_encoders.py
@@ -107,6 +107,12 @@ class _BaseEncoder(BaseEstimator, TransformerMixin):
                     msg = ("Found unknown categories {0} in column {1}"
                            " during transform".format(diff, i))
                     raise ValueError(msg)
+                else:
+                    # Set the problematic rows to a valid category so that
+                    # the code below can perform a correct lookup.
+                    # Cast to object dtype to avoid truncation issues with
+                    # fixed-length string arrays.
+                    Xi = Xi.astype(object)
+                    Xi[~valid_mask] = self.categories_[i][0]
 
         X_int = np.empty(X.shape, dtype=np.int)
Actually, let me reconsider the structure. Based on the error message and typical scikit-learn code, the `_transform` method likely calls the `_encode` function, which lives in `sklearn/preprocessing/label.py`. The fix should go either in `_encode` itself or in `_transform`, where the data is prepared before encoding.
--- a/sklearn/preprocessing/_encoders.py
+++ b/sklearn/preprocessing/_encoders.py
@@ -107,6 +107,11 @@ class _BaseEncoder(BaseEstimator, TransformerMixin):
                     msg = ("Found unknown categories {0} in column {1}"
                            " during transform".format(diff, i))
                     raise ValueError(msg)
+                else:
+                    # Set unknown values to a known category. Use object dtype
+                    # to avoid string truncation with fixed-length string dtypes.
+                    Xi = Xi.astype(object)
+                    Xi[~valid_mask] = self.categories_[i][0]
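The per-column logic of the patch can be sketched with NumPy alone. This is a toy stand-in, not scikit-learn's actual `_transform`: `encode_column` is a hypothetical helper, and the dictionary lookup merely plays the role of the real encoding step.

```python
import numpy as np


def encode_column(Xi, categories):
    """Toy stand-in for the per-column step (hypothetical, not sklearn code)."""
    valid_mask = np.isin(Xi, categories)
    if not valid_mask.all():
        # astype(object) copies the column and lifts the fixed-width limit,
        # so the placeholder category is stored in full.
        Xi = Xi.astype(object)
        Xi[~valid_mask] = categories[0]
    # Plain dict lookup; a truncated placeholder would raise KeyError here.
    index = {cat: code for code, cat in enumerate(categories.tolist())}
    codes = np.array([index[value] for value in Xi.tolist()])
    return codes, valid_mask


# Illustrative data: np.unique returns the categories sorted.
categories = np.unique(np.array(['22', '333', '4444', '11111111']))
column = np.array(['55555', '22'])  # dtype <U5; '55555' is unknown

codes, valid_mask = encode_column(column, categories)
print(codes, valid_mask)
```

The unknown row receives the code of the first category, and `valid_mask` records that it was unknown so a caller can drop it afterwards. Because `astype(object)` returns a copy, the input column is left unmodified.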