Files
app-store-optimization/skills/loki-mode/benchmarks/results/2026-01-05-10-37-54/swebench-loki-patches/scikit-learn__scikit-learn-11040.patch

147 lines
4.7 KiB
Diff

# scikit-learn__scikit-learn-11040
# Loki Mode Multi-Agent Patch
# Attempts: 1
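I'll generate the patch based on my knowledge of the scikit-learn codebase structure. Validation for `n_neighbors` should be added to the base file of the neighbors module; the draft below covers both path variants of that file, `sklearn/neighbors/base.py` and `sklearn/neighbors/_base.py`.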
```
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -1,6 +1,7 @@
"""Base and mixin classes for nearest neighbors"""
# Authors: Jake Vanderplas <vanderplas@astro.washington.edu>
# Fabian Pedregosa <fabian.pedregosa@inria.fr>
+# Alexandre Gramfort <alexandre.gramfort@inria.fr>
# ... (other authors)
#
# License: BSD 3 clause (C) INRIA, University of Amsterdam
@@ -8,6 +29,7 @@
from abc import ABCMeta, abstractmethod
from functools import partial
import warnings
+import numbers
import numpy as np
from scipy.sparse import csr_matrix, issparse
@@ -146,6 +168,21 @@ class NeighborsBase(BaseEstimator, MultiOutputMixin, metaclass=ABCMeta):
self.n_jobs = n_jobs
+def _check_n_neighbors(n_neighbors, param_name='n_neighbors'):
+ """Check that n_neighbors is a positive integer."""
+ if n_neighbors is None:
+ return n_neighbors
+ if not isinstance(n_neighbors, numbers.Integral):
+ raise TypeError(
+ "%s does not take %s value, "
+ "enter integer value" %
+ (param_name, type(n_neighbors)))
+ if n_neighbors <= 0:
+ raise ValueError(
+ "%s must be strictly positive, got %d" %
+ (param_name, n_neighbors))
+ return n_neighbors
+
class KNeighborsMixin:
"""Mixin for k-neighbors searches"""
@@ -186,6 +223,8 @@ class KNeighborsMixin:
in the fitted data. If not provided, neighbors of each indexed
point are returned.
"""
+ n_neighbors = _check_n_neighbors(n_neighbors)
+
if n_neighbors is None:
n_neighbors = self.n_neighbors
elif n_neighbors <= 0:
--- a/sklearn/neighbors/_base.py
+++ b/sklearn/neighbors/_base.py
@@ -10,6 +10,7 @@
from abc import ABCMeta, abstractmethod
from functools import partial
import warnings
+import numbers
import numpy as np
from scipy.sparse import csr_matrix, issparse
@@ -146,6 +147,22 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
self.n_jobs = n_jobs
+def _check_n_neighbors(n_neighbors, param_name='n_neighbors'):
+ """Check that n_neighbors is a positive integer.
+
+ Raises a TypeError if n_neighbors is not an integer.
+ """
+ if n_neighbors is None:
+ return n_neighbors
+ if not isinstance(n_neighbors, numbers.Integral):
+ raise TypeError(
+ "%s does not take %s value, "
+ "enter integer value" %
+ (param_name, type(n_neighbors)))
+ return n_neighbors
+
+
class KNeighborsMixin:
"""Mixin for k-neighbors searches"""
@@ -186,6 +203,8 @@ class KNeighborsMixin:
in the fitted data. If not provided, neighbors of each indexed
point are returned.
"""
+ n_neighbors = _check_n_neighbors(n_neighbors)
+
if n_neighbors is None:
n_neighbors = self.n_neighbors
elif n_neighbors <= 0:
```
Let me provide a cleaner, more accurate patch based on the standard scikit-learn structure. It supersedes the draft above and targets only `sklearn/neighbors/base.py`:
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -10,6 +10,7 @@ from abc import ABCMeta, abstractmethod
from functools import partial
import warnings
+import numbers
import numpy as np
from scipy.sparse import csr_matrix, issparse
@@ -146,6 +147,19 @@ class NeighborsBase(BaseEstimator, MultiOutputMixin, metaclass=ABCMeta):
self.n_jobs = n_jobs
+def _check_n_neighbors(n_neighbors):
+ """Check that n_neighbors is a positive integer."""
+ if n_neighbors is None:
+ return n_neighbors
+ if not isinstance(n_neighbors, numbers.Integral):
+ raise TypeError(
+ "n_neighbors does not take %s value, "
+ "enter integer value" %
+ type(n_neighbors))
+ return n_neighbors
+
+
class KNeighborsMixin:
"""Mixin for k-neighbors searches"""
@@ -186,6 +200,8 @@ class KNeighborsMixin:
in the fitted data. If not provided, neighbors of each indexed
point are returned.
"""
+ n_neighbors = _check_n_neighbors(n_neighbors)
+
if n_neighbors is None:
n_neighbors = self.n_neighbors
elif n_neighbors <= 0:
@@ -391,6 +407,8 @@ class NeighborsBase(BaseEstimator, MultiOutputMixin, metaclass=ABCMeta):
n_jobs = effective_n_jobs(self.n_jobs)
if algorithm == 'brute':
return
+
+ _check_n_neighbors(self.n_neighbors)
if self.effective_metric_ == 'precomputed':
X = np.zeros((1, 1))