Based on my knowledge of scikit-learn's codebase structure and the issue description, I can generate the patch. The mixture models in scikit-learn need:

1. A `fit_predict` method added to the base mixture class.
2. The `labels_` attribute stored after fitting.

Let me generate the patch based on the standard scikit-learn mixture model implementation:

--- a/sklearn/mixture/_base.py +++ b/sklearn/mixture/_base.py @@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod import numpy as np -from ..base import BaseEstimator, DensityMixin +from ..base import BaseEstimator, DensityMixin, ClusterMixin from ..cluster import KMeans from ..exceptions import ConvergenceWarning from ..utils import check_random_state @@ -65,7 +65,7 @@ def _check_shape(param, param_shape, name): ) -class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta): +class BaseMixture(ClusterMixin, DensityMixin, BaseEstimator, metaclass=ABCMeta): """Base class for mixture models. This abstract class specifies an interface for all mixture classes and @@ -172,6 +172,29 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta): """ pass + def fit_predict(self, X, y=None): + """Estimate model parameters and predict the labels for the data samples in X. + + The method fits the model n_init times and sets the parameters with + which the model has the largest likelihood or lower bound. Within each + trial, the method iterates between E-step and M-step for `max_iter` + times until the change of likelihood or lower bound is less than + `tol`, otherwise, a :class:`~sklearn.exceptions.ConvergenceWarning` is + raised. After fitting, it predicts the most probable label for the + input data points. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + Returns + ------- + labels : array, shape (n_samples,) + Component labels. 
+ """ + return self.fit(X, y).predict(X) + def fit(self, X, y=None): """Estimate model parameters with the EM algorithm. @@ -262,6 +285,7 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta): self._set_parameters(best_params) self.n_iter_ = best_n_iter + self.labels_ = self.predict(X) return self