62 lines
2.4 KiB
Diff
62 lines
2.4 KiB
Diff
# scikit-learn__scikit-learn-13496
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
This is a scikit-learn repository issue, but the current working directory does not contain the scikit-learn codebase. Based on the issue description and the architect's analysis, the patch below was written against the standard scikit-learn file structure. Because the real source files were not available, the hunk offsets and context lines are unverified and may need adjusting before the patch applies cleanly:
|
|
|
|
```
|
|
--- a/sklearn/ensemble/iforest.py
|
|
+++ b/sklearn/ensemble/iforest.py
|
|
@@ -120,6 +120,13 @@ class IsolationForest(BaseBagging, OutlierMixin):
|
|
verbose : int, optional (default=0)
|
|
Controls the verbosity of the tree building process.
|
|
|
|
+ warm_start : bool, optional (default=False)
|
|
+ When set to ``True``, reuse the solution of the previous call to fit
|
|
+ and add more estimators to the ensemble, otherwise, just fit a whole
|
|
+ new forest. See :term:`the Glossary <warm_start>`.
|
|
+
|
|
+ .. versionadded:: 0.21
|
|
+
|
|
Attributes
|
|
----------
|
|
estimators_ : list of DecisionTreeClassifier
|
|
@@ -173,7 +180,8 @@ class IsolationForest(BaseBagging, OutlierMixin):
|
|
n_jobs=None,
|
|
behaviour='old',
|
|
random_state=None,
|
|
- verbose=0):
|
|
+ verbose=0,
|
|
+ warm_start=False):
|
|
super().__init__(
|
|
base_estimator=ExtraTreeRegressor(
|
|
max_features=1,
|
|
@@ -185,6 +193,7 @@ class IsolationForest(BaseBagging, OutlierMixin):
|
|
n_estimators=n_estimators,
|
|
max_samples=max_samples,
|
|
max_features=max_features,
|
|
+ warm_start=warm_start,
|
|
n_jobs=n_jobs,
|
|
random_state=random_state,
|
|
verbose=verbose)
|
|
--- a/sklearn/ensemble/tests/test_iforest.py
|
|
+++ b/sklearn/ensemble/tests/test_iforest.py
|
|
@@ -295,3 +295,18 @@ def test_iforest_chunks_works2(
|
|
):
|
|
with pytest.raises(Warning, match='y'):
|
|
check_warning_when_y_not_none(IsolationForest())
|
|
+
|
|
+
|
|
+def test_iforest_warm_start():
|
|
+ """Test that warm_start adds trees and keeps the fitted ones."""
|
|
+ X = iris.data
|
|
+ clf = IsolationForest(n_estimators=5, warm_start=True, random_state=0)
|
|
+ clf.fit(X)
|
|
+ assert len(clf.estimators_) == 5
|
|
+ first_tree = clf.estimators_[0]
|
|
+
|
|
+ # Grow the forest: the 5 fitted trees must be kept, not refitted
|
|
+ clf.set_params(n_estimators=10)
|
|
+ clf.fit(X)
|
|
+ assert len(clf.estimators_) == 10
|
|
+ assert clf.estimators_[0] is first_tree
|
|
```
|