35 lines
1.9 KiB
Diff
35 lines
1.9 KiB
Diff
Based on the issue description, the problem is that NCA's `_check_params` method validates parameter types too strictly, raising a `TypeError` for values that are numerically valid:
|
|
1. It checks `isinstance(self.tol, float)`, which fails when `tol=1` (an `int`).
|
|
2. It checks `isinstance(self.n_components, int)`, which fails when `n_components` is an `np.int64`, because NumPy integer scalars are not subclasses of Python's built-in `int`.
|
|
|
|
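The fix is to use `numbers.Integral` and `numbers.Real` from the `numbers` module instead of `int` and `float` directly. These abstract base classes are the standard Python way to accept any numeric type that behaves like an integer or a real number: NumPy registers its scalar types with them, and every `Integral` is also a `Real`, so both `n_components=np.int64(2)` and `tol=1` pass the checks.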
|
|
|
|
Following standard scikit-learn conventions, here is the patch:
|
|
|
|
--- a/sklearn/neighbors/_nca.py
|
|
+++ b/sklearn/neighbors/_nca.py
|
|
@@ -21,6 +21,7 @@ from ..decomposition import PCA
|
|
from ..utils.multiclass import check_classification_targets
|
|
from ..utils.validation import check_is_fitted, check_array
|
|
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
|
|
+import numbers
|
|
|
|
|
|
class NeighborhoodComponentsAnalysis(TransformerMixin, BaseEstimator):
|
|
@@ -295,7 +296,7 @@ class NeighborhoodComponentsAnalysis(TransformerMixin, BaseEstimator):
|
|
# Check the preferred dimensionality of the projected space
|
|
if self.n_components is not None:
|
|
- if not isinstance(self.n_components, int):
|
|
+ if not isinstance(self.n_components, numbers.Integral):
|
|
raise TypeError("n_components must be an integer.")
|
|
if self.n_components < 1:
|
|
raise ValueError("n_components must be at least 1.")
|
|
@@ -305,7 +306,7 @@ class NeighborhoodComponentsAnalysis(TransformerMixin, BaseEstimator):
|
|
|
|
# Check the convergence tolerance
|
|
if self.tol is not None:
|
|
- if not isinstance(self.tol, float):
|
|
+ if not isinstance(self.tol, numbers.Real):
|
|
raise TypeError("tol must be a float.")
|
|
if self.tol < 0:
|
|
raise ValueError("tol must be a non-negative float.")
|