65 lines
2.4 KiB
Diff
65 lines
2.4 KiB
Diff
Based on my knowledge of scikit-learn's codebase structure and the issue description, I can generate the patch. To expose a clusterer-compatible interface, the mixture models in scikit-learn need:
|
|
|
|
1. A `fit_predict` method added to the base mixture class
|
|
2. The `labels_` attribute stored after fitting
|
|
|
|
Let me generate the patch based on the standard scikit-learn mixture model implementation:
|
|
|
|
--- a/sklearn/mixture/_base.py
|
|
+++ b/sklearn/mixture/_base.py
|
|
@@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod
|
|
|
|
import numpy as np
|
|
|
|
-from ..base import BaseEstimator, DensityMixin
|
|
+from ..base import BaseEstimator, DensityMixin, ClusterMixin
|
|
from ..cluster import KMeans
|
|
from ..exceptions import ConvergenceWarning
|
|
from ..utils import check_random_state
|
|
@@ -65,7 +65,7 @@ def _check_shape(param, param_shape, name):
|
|
)
|
|
|
|
|
|
-class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
+class BaseMixture(ClusterMixin, DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
"""Base class for mixture models.
|
|
|
|
This abstract class specifies an interface for all mixture classes and
|
|
@@ -172,6 +172,30 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
"""
|
|
pass
|
|
|
|
+ def fit_predict(self, X, y=None):
|
|
+ """Estimate model parameters and predict the labels for the data samples in X.
|
|
+
|
|
+ The method fits the model n_init times and sets the parameters with
|
|
+ which the model has the largest likelihood or lower bound. Within each
|
|
+ trial, the method iterates between E-step and M-step for `max_iter`
|
|
+ times until the change of likelihood or lower bound is less than
|
|
+ `tol`, otherwise, a :class:`~sklearn.exceptions.ConvergenceWarning` is
|
|
+ raised. After fitting, it predicts the most probable label for the
|
|
+ input data points.
|
|
+
|
|
+ Parameters
|
|
+ ----------
|
|
+ X : array-like of shape (n_samples, n_features)
|
|
+ List of n_features-dimensional data points. Each row
|
|
+ corresponds to a single data point.
|
|
+
|
|
+ Returns
|
|
+ -------
|
|
+ labels : array, shape (n_samples,)
|
|
+ Component labels.
|
|
+ """
|
|
+ return self.fit(X, y).predict(X)
|
|
+
|
|
def fit(self, X, y=None):
|
|
"""Estimate model parameters with the EM algorithm.
|
|
|
|
@@ -262,6 +286,7 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
|
|
self._set_parameters(best_params)
|
|
self.n_iter_ = best_n_iter
|
|
+ self.labels_ = self.predict(X)
|
|
|
|
return self
|