107 lines
4.1 KiB
Diff
107 lines
4.1 KiB
Diff
# scikit-learn__scikit-learn-11281
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
Based on my knowledge of scikit-learn's codebase structure and the requirements of the issue, I'll generate a patch that adds a clusterer-compatible interface to the mixture models. The key changes are:
|
|
|
|
1. Add a `labels_` attribute that stores the labels of the training data, set during `fit()`
|
|
2. Add a `fit_predict` method to the base mixture class (`BaseMixture`)
|
|
|
|
```
|
|
--- a/sklearn/mixture/_base.py
|
|
+++ b/sklearn/mixture/_base.py
|
|
@@ -172,6 +172,7 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
best_params = self._get_parameters()
|
|
best_n_iter = n_iter
|
|
|
|
+ # Always do a final e-step to guarantee that the labels_ and the
|
|
# Always do a final e-step to guarantee that the labels_ and the
|
|
# inference attributes are computed for the best parameters
|
|
self._set_parameters(best_params)
|
|
@@ -188,6 +189,7 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
self.n_iter_ = best_n_iter
|
|
self.lower_bound_ = max_lower_bound
|
|
|
|
+ self.labels_ = log_resp.argmax(axis=1)
|
|
return self
|
|
|
|
def _e_step(self, X):
|
|
@@ -261,6 +263,28 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
self._check_is_fitted()
|
|
return self._estimate_weighted_log_prob(X).argmax(axis=1)
|
|
|
|
+ def fit_predict(self, X, y=None):
|
|
+ """Estimate model parameters using X and predict the labels for X.
|
|
+
|
|
+ The method fits the model n_init times and sets the parameters with
|
|
+ which the model has the largest likelihood or lower bound. Within each
|
|
+ trial, the method iterates between E-step and M-step for `max_iter`
|
|
+ times until the change of likelihood or lower bound is less than
|
|
+ `tol`, otherwise, a :class:`~sklearn.exceptions.ConvergenceWarning` is
|
|
+ raised. After fitting, it predicts the most probable label for the
|
|
+ input data points.
|
|
+
|
|
+ Parameters
|
|
+ ----------
|
|
+ X : array-like of shape (n_samples, n_features)
|
|
+ List of n_features-dimensional data points. Each row
|
|
+ corresponds to a single data point.
|
|
+
|
|
+ Returns
|
|
+ -------
|
|
+ labels : array, shape (n_samples,)
|
|
+ Component labels.
|
|
+ """
|
|
+ self.fit(X, y)
|
|
+ return self.labels_
|
|
+
|
|
def score_samples(self, X):
|
|
"""Compute the weighted log probabilities for each sample.
|
|
```
|
|
|
|
Here is the complete patch in unified diff format:
|
|
|
|
--- a/sklearn/mixture/_base.py
|
|
+++ b/sklearn/mixture/_base.py
|
|
@@ -188,6 +188,7 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
self.n_iter_ = best_n_iter
|
|
self.lower_bound_ = max_lower_bound
|
|
|
|
+ self.labels_ = log_resp.argmax(axis=1)
|
|
return self
|
|
|
|
def _e_step(self, X):
|
|
@@ -261,6 +262,28 @@ class BaseMixture(DensityMixin, BaseEstimator, metaclass=ABCMeta):
|
|
self._check_is_fitted()
|
|
return self._estimate_weighted_log_prob(X).argmax(axis=1)
|
|
|
|
+ def fit_predict(self, X, y=None):
|
|
+ """Estimate model parameters using X and predict the labels for X.
|
|
+
|
|
+ The method fits the model n_init times and sets the parameters with
|
|
+ which the model has the largest likelihood or lower bound. Within each
|
|
+ trial, the method iterates between E-step and M-step for `max_iter`
|
|
+ times until the change of likelihood or lower bound is less than
|
|
+ `tol`, otherwise, a :class:`~sklearn.exceptions.ConvergenceWarning` is
|
|
+ raised. After fitting, it predicts the most probable label for the
|
|
+ input data points.
|
|
+
|
|
+ Parameters
|
|
+ ----------
|
|
+ X : array-like of shape (n_samples, n_features)
|
|
+ List of n_features-dimensional data points. Each row
|
|
+ corresponds to a single data point.
|
|
+
|
|
+ y : Ignored
|
|
+ Not used, present for API consistency by convention.
|
|
+
|
|
+ Returns
|
|
+ -------
|
|
+ labels : array, shape (n_samples,)
|
|
+ Component labels.
|
|
+ """
|
|
+ self.fit(X, y)
|
|
+ return self.labels_
|
|
+
|
|
def score_samples(self, X):
|
|
"""Compute the weighted log probabilities for each sample.
|