59 lines
2.3 KiB
Diff
59 lines
2.3 KiB
Diff
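Based on my knowledge of scikit-learn's codebase structure and the issue description, I can generate the patch. `IsolationForest` is defined in `sklearn/ensemble/iforest.py` (renamed to `sklearn/ensemble/_iforest.py` in newer versions). Because the hunk context below is reconstructed from memory rather than copied from the checked-out file, its line numbers and context lines may need to be rebased before the patch applies cleanly. Based on the issue, I need to: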
|
|
|
|
1. Expose the `warm_start` parameter in `IsolationForest.__init__()` and pass it through to the `BaseBagging` constructor
|
|
2. Document it in the `IsolationForest` class docstring
|
|
3. Add a test for the `warm_start` functionality
|
|
|
|
Let me create the patch based on the known structure of scikit-learn's IsolationForest:
|
|
|
|
--- a/sklearn/ensemble/iforest.py
+++ b/sklearn/ensemble/iforest.py
@@ -120,6 +120,11 @@ class IsolationForest(BaseBagging, OutlierMixin):
     verbose : int, optional (default=0)
         Controls the verbosity of the tree building process.
 
+    warm_start : bool, optional (default=False)
+        When set to ``True``, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble, otherwise, just fit a whole
+        new forest. See :term:`the Glossary <warm_start>`.
+
     Attributes
     ----------
     estimators_ : list of DecisionTreeClassifier
@@ -173,7 +178,8 @@ class IsolationForest(BaseBagging, OutlierMixin):
                  max_features=1.,
                  bootstrap=False,
                  n_jobs=None,
                  random_state=None,
-                 verbose=0):
+                 verbose=0,
+                 warm_start=False):
         super().__init__(
             base_estimator=ExtraTreeRegressor(
@@ -186,6 +192,7 @@ class IsolationForest(BaseBagging, OutlierMixin):
             n_estimators=n_estimators,
             max_samples=max_samples,
             max_features=max_features,
+            warm_start=warm_start,
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose)
|
|
--- a/sklearn/ensemble/tests/test_iforest.py
+++ b/sklearn/ensemble/tests/test_iforest.py
@@ -217,3 +217,21 @@ def test_iforest_chunks_works2(
     assert_array_equal(
         pred_scores,
         iforest.score_samples(X))
+
+
+def test_iforest_warm_start():
+    """Check that warm_start grows the forest and keeps the fitted trees."""
+    X = iris.data
+
+    clf = IsolationForest(n_estimators=5, warm_start=True, random_state=0)
+    clf.fit(X)
+    first_trees = list(clf.estimators_)
+
+    clf.set_params(n_estimators=10)
+    clf.fit(X)
+
+    assert len(clf.estimators_) == 10
+    # Refitting from scratch would also yield 10 trees, so additionally
+    # check that the first 5 trees are the very same objects as before.
+    for old_tree, new_tree in zip(first_trees, clf.estimators_):
+        assert new_tree is old_tree
|