39 lines
1.7 KiB
Diff
39 lines
1.7 KiB
Diff
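Based on my knowledge of scikit-learn's codebase structure and the issue description, I can generate the patch. The issue is that `n_neighbors` is never checked to be an integer in the neighbors module, so a non-integer value (for example a float) is not rejected with a clear error message. The validation belongs in two places in `sklearn/neighbors/base.py`: in the `NeighborsBase` class, where the estimator's own `n_neighbors` parameter is already range-checked, and in the `KNeighborsMixin.kneighbors` method, where `n_neighbors` can also be passed per call.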
|
|
|
|
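The patch below follows the standard scikit-learn pattern for parameter validation: it imports `numbers` and raises a `TypeError` when `n_neighbors` is not a `numbers.Integral`, placing the check immediately before the existing `ValueError` range checks: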
|
|
|
|
--- a/sklearn/neighbors/base.py
|
|
+++ b/sklearn/neighbors/base.py
|
|
@@ -5,6 +5,7 @@
|
|
# License: BSD 3 clause (C) INRIA, University of Amsterdam
|
|
from functools import partial
|
|
|
|
+import numbers
|
|
import warnings
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
@@ -136,6 +137,12 @@ class NeighborsBase(BaseEstimator, MultiOutputMixin, metaclass=ABCMeta):
|
|
self._fit_method = 'brute'
|
|
|
|
if self.n_neighbors is not None:
|
|
+ if not isinstance(self.n_neighbors, numbers.Integral):
|
|
+ raise TypeError(
|
|
+ "n_neighbors does not take %s value, "
|
|
+ "enter integer value" %
|
|
+ type(self.n_neighbors))
|
|
+
|
|
if self.n_neighbors <= 0:
|
|
raise ValueError(
|
|
"Expected n_neighbors > 0. Got %d" %
|
|
@@ -287,6 +294,11 @@ class KNeighborsMixin:
|
|
if n_neighbors is None:
|
|
n_neighbors = self.n_neighbors
|
|
|
|
+ if not isinstance(n_neighbors, numbers.Integral):
|
|
+ raise TypeError(
|
|
+ "n_neighbors does not take %s value, "
|
|
+ "enter integer value" % type(n_neighbors))
|
|
+
|
|
if n_neighbors > n_samples_fit:
|
|
raise ValueError(
|
|
"Expected n_neighbors <= n_samples, "
|