Based on the issue description, I understand the problem:

1. The regression was introduced in PR #10830.
2. `mutual_info_score` fails on object-dtype arrays containing strings with "ValueError: could not convert string to float: 'b'".
3. Previously (0.21.1), this worked without a warning.

The problem is in the `check_clusterings` function in `sklearn/metrics/cluster/_supervised.py`. PR #10830 likely added input validation using `check_array`, whose default `dtype="numeric"` tries to convert object-dtype input to float. Clustering metrics, however, should accept categorical label data, including strings.

Let me create the patch based on my understanding of the sklearn codebase and this issue:

```
--- a/sklearn/metrics/cluster/_supervised.py
+++ b/sklearn/metrics/cluster/_supervised.py
@@ -40,9 +40,9 @@ def check_clusterings(labels_true, labels_pred):
     labels_true : array-like of shape (n_samples,)
         The true labels.
     labels_pred : array-like of shape (n_samples,)
         The predicted labels.
     """
-    labels_true = check_array(labels_true, ensure_2d=False, ensure_min_samples=0)
-    labels_pred = check_array(labels_pred, ensure_2d=False, ensure_min_samples=0)
+    labels_true = check_array(labels_true, ensure_2d=False, ensure_min_samples=0, dtype=None)
+    labels_pred = check_array(labels_pred, ensure_2d=False, ensure_min_samples=0, dtype=None)
 
     # input checks
```

The fix is to add `dtype=None` to both `check_array` calls. This preserves the input dtype instead of converting to float, so string/object arrays pass through unchanged.

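To illustrate the difference, here is a minimal sketch that calls `check_array` directly on an object-dtype array of string labels, once with the default arguments and once with `dtype=None`. The sample labels are made up for illustration, and the exact error message may vary across scikit-learn versions.

```python
import numpy as np
from sklearn.utils import check_array

# Object-dtype array of string labels, similar to the input in the issue.
# (The values here are illustrative, not taken from the report.)
labels = np.array(["a", "b", "a", "b"], dtype=object)

# Default dtype="numeric": object arrays are converted to float, which
# cannot succeed for string labels.
try:
    check_array(labels, ensure_2d=False, ensure_min_samples=0)
    default_outcome = "accepted"
except ValueError as exc:
    default_outcome = f"ValueError: {exc}"

# dtype=None: the input dtype is preserved, so the labels pass validation.
checked = check_array(
    labels, ensure_2d=False, ensure_min_samples=0, dtype=None)

print(default_outcome)
print(checked.dtype, checked.tolist())
```

The first call is expected to be rejected with a `ValueError`, while the second returns the labels with their object dtype intact.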
--- a/sklearn/metrics/cluster/_supervised.py
+++ b/sklearn/metrics/cluster/_supervised.py
@@ -43,8 +43,10 @@ def check_clusterings(labels_true, labels_pred):
     labels_pred : array-like of shape (n_samples,)
         The predicted labels.
     """
-    labels_true = check_array(labels_true, ensure_2d=False, ensure_min_samples=0)
-    labels_pred = check_array(labels_pred, ensure_2d=False, ensure_min_samples=0)
+    labels_true = check_array(
+        labels_true, ensure_2d=False, ensure_min_samples=0, dtype=None)
+    labels_pred = check_array(
+        labels_pred, ensure_2d=False, ensure_min_samples=0, dtype=None)
 
     # input checks
     if labels_true.ndim != 1:
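A quick way to sanity-check the patched behaviour is a small regression-style script along the following lines. This is only a sketch: the label values are invented for illustration, and it assumes a scikit-learn build in which `check_clusterings` accepts object-dtype input.

```python
import math

import numpy as np
from sklearn.metrics import mutual_info_score

# Object-dtype string labels; three of each class, so the data is balanced.
labels = np.array(["a", "b", "a", "b", "b", "a"], dtype=object)

# With the fix in place this should not raise. Comparing a labelling with
# itself yields its entropy, which for two balanced classes is ln(2).
score = mutual_info_score(labels, labels)

print(score, math.log(2))
```

If the regression is still present, the call raises the "could not convert string to float" error from the issue instead of returning a score.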