
Commit

Merge pull request scikit-learn#2199 from GaelVaroquaux/hc_linkage
[MRG+1] Hierarchical Agglomerative Clustering
GaelVaroquaux committed Mar 5, 2014
2 parents 10bfa55 + 2a4402b commit 39f3a76
Showing 29 changed files with 44,497 additions and 7,771 deletions.
4 changes: 2 additions & 2 deletions benchmarks/bench_plot_ward.py
@@ -8,9 +8,9 @@
from scipy.cluster import hierarchy
import pylab as pl

from sklearn.cluster import Ward
from sklearn.cluster import AgglomerativeClustering

ward = Ward(n_clusters=3)
ward = AgglomerativeClustering(n_clusters=3, linkage='ward')

n_samples = np.logspace(.5, 3, 9)
n_features = np.logspace(1, 3.5, 7)
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
@@ -62,7 +62,9 @@ Classes
:template: class.rst

cluster.AffinityPropagation
cluster.AgglomerativeClustering
cluster.DBSCAN
cluster.FeatureAgglomeration
cluster.KMeans
cluster.MiniBatchKMeans
cluster.MeanShift
172 changes: 146 additions & 26 deletions doc/modules/clustering.rst
@@ -76,12 +76,19 @@ Overview of clustering methods
- Few clusters, even cluster size, non-flat geometry
- Graph distance (e.g. nearest-neighbor graph)

* - :ref:`Hierarchical clustering <hierarchical_clustering>`
* - :ref:`Ward hierarchical clustering <hierarchical_clustering>`
- number of clusters
- Large `n_samples` and `n_clusters`
- Many clusters, possibly connectivity constraints
- Distances between points

* - :ref:`Agglomerative clustering <hierarchical_clustering>`
- number of clusters, linkage type, distance
- Large `n_samples` and `n_clusters`
- Many clusters, possibly connectivity constraints, non Euclidean
distances
- Any pairwise distance

* - :ref:`DBSCAN <dbscan>`
- neighborhood size
- Very large `n_samples`, medium `n_clusters`
@@ -471,35 +478,80 @@ Hierarchical clustering
=======================

Hierarchical clustering is a general family of clustering algorithms that
build nested clusters by merging them successively. This hierarchy of
clusters represented as a tree (or dendrogram). The root of the tree is
the unique cluster that gathers all the samples, the leaves being the
build nested clusters by merging or splitting them successively. This
hierarchy of clusters is represented as a tree (or dendrogram). The root of the
tree is the unique cluster that gathers all the samples, the leaves being the
clusters with only one sample. See the `Wikipedia page
<http://en.wikipedia.org/wiki/Hierarchical_clustering>`_ for more
details.
<http://en.wikipedia.org/wiki/Hierarchical_clustering>`_ for more details.

The :class:`AgglomerativeClustering` object performs a hierarchical clustering
using a bottom up approach: each observation starts in its own cluster, and
clusters are successively merged together. The linkage criteria determines the
metric used for the merge strategy:

- **Ward** minimizes the sum of squared differences within all clusters. It is a
variance-minimizing approach and in this sense is similar to the k-means
objective function but tackled with an agglomerative hierarchical
approach.
- **Maximum** or **complete linkage** minimizes the maximum distance between
observations of pairs of clusters.
- **Average linkage** minimizes the average of the distances between all
observations of pairs of clusters.

:class:`AgglomerativeClustering` can also scale to a large number of samples
when it is used jointly with a connectivity matrix, but it is computationally
expensive when no connectivity constraints are added between samples: it
considers all the possible merges at each step.
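
For illustration, a minimal sketch comparing the three linkage strategies on
toy data (the helper :func:`sklearn.datasets.make_blobs` and the printed
``labels_`` inspection are illustrative choices, not part of this change)::

    from sklearn.cluster import AgglomerativeClustering
    from sklearn.datasets import make_blobs

    # toy data: three well-separated blobs
    X, _ = make_blobs(n_samples=150, centers=3, random_state=0)

    for linkage in ('ward', 'complete', 'average'):
        model = AgglomerativeClustering(n_clusters=3, linkage=linkage)
        model.fit(X)
        # cluster assignment of the first ten samples
        print(linkage, model.labels_[:10])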

.. topic:: :class:`FeatureAgglomeration`

The :class:`FeatureAgglomeration` uses agglomerative clustering to
group together features that look very similar, thus decreasing the
number of features. It is a dimensionality reduction tool, see
:ref:`data_reduction`.

Different linkage types: Ward, complete and average linkage
------------------------------------------------------------

:class:`AgglomerativeClustering` supports Ward, average, and complete
linkage strategies.

The :class:`Ward` object performs a hierarchical clustering based on
the Ward algorithm, that is a variance-minimizing approach. At each
step, it minimizes the sum of squared differences within all clusters
(inertia criterion).
.. image:: ../auto_examples/cluster/images/plot_digits_linkage_1.png
:target: ../auto_examples/cluster/plot_digits_linkage.html
:scale: 43

This algorithm can scale to large number of samples when it is used jointly
with an connectivity matrix, but can be computationally expensive when no
connectivity constraints are added between samples: it considers at each step
all the possible merges.
.. image:: ../auto_examples/cluster/images/plot_digits_linkage_2.png
:target: ../auto_examples/cluster/plot_digits_linkage.html
:scale: 43

.. image:: ../auto_examples/cluster/images/plot_digits_linkage_3.png
:target: ../auto_examples/cluster/plot_digits_linkage.html
:scale: 43


Agglomerative clustering has a "rich get richer" behavior that leads to
uneven cluster sizes. In this regard, complete linkage is the worst
strategy, and Ward gives the most regular sizes. However, the affinity
(or distance used in clustering) cannot be varied with Ward, thus for
non-Euclidean metrics, average linkage is a good alternative.

.. topic:: Examples:

* :ref:`example_cluster_plot_digits_linkage.py`: exploration of the
different linkage strategies in a real dataset.


Adding connectivity constraints
-------------------------------

An interesting aspect of the :class:`Ward` object is that connectivity
constraints can be added to this algorithm (only adjacent clusters can be
merged together), through an connectivity matrix that defines for each
sample the neighboring samples following a given structure of the data. For
instance, in the swiss-roll example below, the connectivity constraints
forbid the merging of points that are not adjacent on the swiss roll, and
thus avoid forming clusters that extend across overlapping folds of the
roll.
An interesting aspect of :class:`AgglomerativeClustering` is that
connectivity constraints can be added to this algorithm (only adjacent
clusters can be merged together), through a connectivity matrix that defines
for each sample the neighboring samples following a given structure of the
data. For instance, in the swiss-roll example below, the connectivity
constraints forbid the merging of points that are not adjacent on the swiss
roll, and thus avoid forming clusters that extend across overlapping folds of
the roll.

.. |unstructured| image:: ../auto_examples/cluster/images/plot_ward_structured_vs_unstructured_1.png
:target: ../auto_examples/cluster/plot_ward_structured_vs_unstructured.html
@@ -511,16 +563,19 @@ roll.

.. centered:: |unstructured| |structured|

These constraints are useful to impose a certain local structure, but they
also make the algorithm faster, especially when the number of samples is
high.

The connectivity constraints are imposed via a connectivity matrix: a
scipy sparse matrix that has elements only at the intersection of a row
and a column with indices of the dataset that should be connected. This
matrix can be constructed from a-priori information, for instance if you
wish to cluster web pages, but only merging pages with a link pointing
matrix can be constructed from a-priori information: for instance, you
may wish to cluster web pages by only merging pages with a link pointing
from one to another. It can also be learned from the data, for instance
using :func:`sklearn.neighbors.kneighbors_graph` to restrict
merging to nearest neighbors as in the :ref:`swiss roll
<example_cluster_plot_ward_structured_vs_unstructured.py>` example, or
merging to nearest neighbors as in :ref:`this example
<example_cluster_plot_agglomerative_clustering.py>`, or
using :func:`sklearn.feature_extraction.image.grid_to_graph` to
enable only merging of neighboring pixels on an image, as in the
:ref:`Lena <example_cluster_plot_lena_ward_segmentation.py>` example.
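
As a hedged sketch (the two-dimensional random data and the choice of 10
neighbors below are illustrative assumptions), a k-nearest-neighbors
connectivity matrix can be passed directly to
:class:`AgglomerativeClustering`::

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering
    from sklearn.neighbors import kneighbors_graph

    # illustrative 2D data
    X = np.random.RandomState(42).rand(300, 2)

    # sparse matrix connecting each sample to its 10 nearest neighbors;
    # merges are then restricted to connected samples
    connectivity = kneighbors_graph(X, n_neighbors=10)

    model = AgglomerativeClustering(n_clusters=5, connectivity=connectivity,
                                    linkage='ward')
    labels = model.fit(X).labels_
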
@@ -538,6 +593,71 @@
Example of dimensionality reduction with feature agglomeration based on
Ward hierarchical clustering.

* :ref:`example_cluster_plot_agglomerative_clustering.py`

.. warning:: **Connectivity constraints with average and complete linkage**

   Connectivity constraints and complete or average linkage can enhance
   the 'rich get richer' aspect of agglomerative clustering,
   particularly so if they are built with
   :func:`sklearn.neighbors.kneighbors_graph`. In the limit of a small
   number of clusters, they tend to give a few macroscopically occupied
   clusters and almost empty ones (see the discussion in
   :ref:`example_cluster_plot_agglomerative_clustering.py`).

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_1.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering.html
:scale: 38

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_2.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering.html
:scale: 38

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_3.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering.html
:scale: 38

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_4.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering.html
:scale: 38


Varying the metric
-------------------

Average and complete linkage can be used with a variety of distances (or
affinities), in particular Euclidean distance (*l2*), Manhattan distance
(or Cityblock, or *l1*), cosine distance, or any precomputed affinity
matrix.

* *l1* distance is often good for sparse features, or sparse noise: i.e.
  many of the features are zero, as in text mining using occurrences of
  rare words.

* *cosine* distance is interesting because it is invariant to global
scalings of the signal.

The guideline for choosing a metric is to use one that maximizes the
distance between samples in different classes, and minimizes the distance
within each class.
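
As a rough sketch of the behavior described above (the random data, the
``affinity='manhattan'`` spelling and the parameter values are illustrative
assumptions), a non-Euclidean affinity is combined with average linkage as
follows::

    import numpy as np
    from sklearn.cluster import AgglomerativeClustering

    rng = np.random.RandomState(0)
    X = rng.rand(60, 20)              # illustrative dense data

    # non-Euclidean affinities require 'average' or 'complete' linkage
    model = AgglomerativeClustering(n_clusters=3, affinity='manhattan',
                                    linkage='average')
    labels = model.fit(X).labels_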

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_metrics_5.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering_metrics.html
:scale: 32

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_metrics_6.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering_metrics.html
:scale: 32

.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_metrics_7.png
:target: ../auto_examples/cluster/plot_agglomerative_clustering_metrics.html
:scale: 32

.. topic:: Examples:

* :ref:`example_cluster_plot_agglomerative_clustering_metrics.py`


.. _dbscan:

DBSCAN
59 changes: 59 additions & 0 deletions doc/modules/preprocessing.rst
@@ -469,3 +469,62 @@ values than observed values.

:class:`Imputer` can be used in a Pipeline as a way to build a composite
estimator that supports imputation. See :ref:`example_imputation.py`

.. _data_reduction:

Unsupervised data reduction
============================

If your number of features is high, it may be useful to reduce it with an
unsupervised step prior to supervised steps. Many of the
:ref:`unsupervised-learning` methods implement a `transform` method that
can be used to reduce the dimensionality. Below we discuss some specific
examples of this pattern that are heavily used.

.. topic:: **Pipelining**

The unsupervised data reduction and the supervised estimator can be
chained in one step. See :ref:`pipeline`.

.. currentmodule:: sklearn

PCA: principal component analysis
----------------------------------

:class:`decomposition.PCA` looks for a combination of features that
captures well the variance of the original features.
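
A minimal sketch (the digits dataset and ``n_components=10`` are only an
illustration)::

    from sklearn.datasets import load_digits
    from sklearn.decomposition import PCA

    X = load_digits().data                      # 64 pixel features per image
    X_reduced = PCA(n_components=10).fit_transform(X)
    print(X_reduced.shape)                      # (n_samples, 10)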

.. topic:: **Examples**

* :ref:`example_applications_face_recognition.py`

Random projections
-------------------

The :mod:`random_projection` module provides several tools for data
reduction by random projections. See the relevant section of the
documentation: :ref:`random_projection`.
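
A minimal sketch, assuming :class:`random_projection.GaussianRandomProjection`
with an explicit ``n_components`` (both choices are illustrative)::

    import numpy as np
    from sklearn.random_projection import GaussianRandomProjection

    X = np.random.RandomState(0).rand(100, 10000)    # high-dimensional data

    transformer = GaussianRandomProjection(n_components=500, random_state=0)
    X_reduced = transformer.fit_transform(X)         # shape (100, 500)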

.. topic:: **Examples**

* :ref:`example_plot_johnson_lindenstrauss_bound.py`

Feature agglomeration
----------------------

:class:`cluster.FeatureAgglomeration` applies
:ref:`hierarchical_clustering` to group together features that behave
similarly.
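
A minimal sketch (the digits data and ``n_clusters=8`` are illustrative)::

    from sklearn.cluster import FeatureAgglomeration
    from sklearn.datasets import load_digits

    X = load_digits().data                 # 64 pixel features per digit image

    # each output feature is the average of one cluster of original pixels
    agglo = FeatureAgglomeration(n_clusters=8)
    X_reduced = agglo.fit_transform(X)
    print(X_reduced.shape)                 # (n_samples, 8)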

.. topic:: **Examples**

* :ref:`example_cluster_plot_feature_agglomeration_vs_univariate_selection.py`
* :ref:`example_cluster_plot_digits_agglomeration.py`

.. topic:: **Feature scaling**

   Note that if features have very different scaling or statistical
   properties, :class:`cluster.FeatureAgglomeration` may not be able to
   capture the links between related features. Using a
   :class:`preprocessing.StandardScaler` can be useful in these settings.
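
Following the note above, a hedged sketch of chaining the two steps (the
pipeline layout and step names are illustrative assumptions)::

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.cluster import FeatureAgglomeration

    # standardize each feature before grouping similar ones together
    reducer = Pipeline([('scale', StandardScaler()),
                        ('agglo', FeatureAgglomeration(n_clusters=8))])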

31 changes: 16 additions & 15 deletions doc/tutorial/statistical_inference/unsupervised_learning.rst
@@ -151,26 +151,27 @@ algorithms. The simplest clustering algorithm is
Hierarchical agglomerative clustering: Ward
---------------------------------------------

A :ref:`hierarchical_clustering` method is a type of cluster analysis
A :ref:`hierarchical_clustering` method is a type of cluster analysis
that aims to build a hierarchy of clusters. In general, the various approaches
of this technique are either:

* **Agglomerative** - `bottom-up` approaches, or
* **Divisive** - `top-down` approaches.
* **Agglomerative** - `bottom-up` approaches: each observation starts in its
  own cluster, and clusters are iteratively merged so as to minimize a
  *linkage* criterion. This approach is particularly interesting when the
  clusters of interest are made of only a few observations. When the number
  of clusters is large, it is much more computationally efficient than
  k-means.

For estimating a large number of clusters, top-down approaches are both
statistically ill-posed and slow due to it starting with all observations
as one cluster, which it splits recursively. Agglomerative
hierarchical-clustering is a bottom-up approach that successively merges
observations together and is particularly useful when the clusters of interest
are made of only a few observations. *Ward* clustering minimizes a criterion
similar to k-means in a bottom-up approach. When the number of clusters is large,
it is much more computationally efficient than k-means.
* **Divisive** - `top-down` approaches: all observations start in one
cluster, which is iteratively split as one moves down the hierarchy.
For estimating large numbers of clusters, this approach is both slow (due
to all observations starting as one cluster, which it splits recursively)
and statistically ill-posed.

Connectivity-constrained clustering
.....................................

With Ward clustering, it is possible to specify which samples can be
With agglomerative clustering, it is possible to specify which samples can be
clustered together by giving a connectivity graph. Graphs in the scikit
are represented by their adjacency matrix. Often, a sparse matrix is used.
This can be useful, for instance, to retrieve connected regions (sometimes
@@ -212,10 +213,10 @@ transposed data.
>>> X = np.reshape(images, (len(images), -1))
>>> connectivity = grid_to_graph(*images[0].shape)

>>> agglo = cluster.WardAgglomeration(connectivity=connectivity,
... n_clusters=32)
>>> agglo = cluster.FeatureAgglomeration(connectivity=connectivity,
... n_clusters=32)
>>> agglo.fit(X) # doctest: +ELLIPSIS
WardAgglomeration(compute_full_tree='auto',...
FeatureAgglomeration(affinity='euclidean', compute_full_tree='auto',...
>>> X_reduced = agglo.transform(X)

>>> X_approx = agglo.inverse_transform(X_reduced)
22 changes: 16 additions & 6 deletions doc/whats_new.rst
@@ -146,12 +146,9 @@ Changelog
- Added :class:`linear_model.MultiTaskElasticNetCV` and
:class:`linear_model.MultiTaskLassoCV`. By `Manoj Kumar`_.

- :mod:`sklearn.hmm` is deprecated. Its removal is planned
for the 0.17 release.

- Use of :class:`covariance.EllipticEnvelop` has now been removed after
deprecation.
Please use :class:`covariance.EllipticEnvelope` instead.
- Added :class:`cluster.AgglomerativeClustering` for hierarchical
agglomerative clustering with average linkage, complete linkage and
ward strategies, by `Nelle Varoquaux`_ and `Gael Varoquaux`_.

- Fixed incorrect estimation of the degrees of freedom in
:func:`feature_selection.f_regression` when variates are not centered.
@@ -161,6 +158,19 @@ Changelog
API changes summary
-------------------

- :mod:`sklearn.hmm` is deprecated. Its removal is planned
for the 0.17 release.

- Use of :class:`covariance.EllipticEnvelop` has now been removed after
deprecation.
Please use :class:`covariance.EllipticEnvelope` instead.

- :class:`cluster.Ward` is deprecated. Use
:class:`cluster.AgglomerativeClustering` instead.

- :class:`cluster.WardClustering` is deprecated. Use
  :class:`cluster.AgglomerativeClustering` instead.

- Add score method to :class:`PCA <decomposition.PCA>` following the model of
probabilistic PCA and deprecate
:class:`ProbabilisticPCA <decomposition.ProbabilisticPCA>` model whose