forked from davisking/dlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Copyright (C) 2015 Davis E. King (davis@dlib.net) | ||
// License: Boost Software License See LICENSE.txt for the full license. | ||
#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ | ||
#define DLIB_BOTTOM_uP_CLUSTER_Hh_ | ||
|
||
#include <queue> | ||
#include <map> | ||
|
||
#include "bottom_up_cluster_abstract.h" | ||
#include "../algs.h" | ||
#include "../matrix.h" | ||
#include "../disjoint_subsets.h" | ||
#include "../graph_utils.h" | ||
|
||
|
||
namespace dlib | ||
{ | ||
|
||
// ---------------------------------------------------------------------------------------- | ||
|
||
namespace buc_impl | ||
{ | ||
// Folds cluster src into cluster dest inside the pairwise distance matrix.
// For every index i, the distance between i and dest becomes the larger of
// dists(i,dest) and dists(i,src).  The result is stored in both dists(i,dest)
// and dists(dest,i) so the row and column belonging to dest stay mirrored.
// Entries that do not involve dest are left untouched.
inline void merge_sets (
    matrix<double>& dists,
    unsigned long dest,
    unsigned long src
)
{
    const long num_rows = dists.nr();
    for (long i = 0; i < num_rows; ++i)
    {
        // Keep the larger of the two distances as the merged distance.
        const double farthest = std::max(dists(i,dest), dists(i,src));
        dists(i,dest) = farthest;
        dists(dest,i) = farthest;
    }
}
|
||
struct compare_dist | ||
{ | ||
bool operator() ( | ||
const sample_pair& a, | ||
const sample_pair& b | ||
) const | ||
{ | ||
return a.distance() > b.distance(); | ||
} | ||
}; | ||
} | ||
|
||
// ---------------------------------------------------------------------------------------- | ||
|
||
template < | ||
typename EXP | ||
> | ||
unsigned long bottom_up_cluster ( | ||
const matrix_exp<EXP>& dists_, | ||
std::vector<unsigned long>& labels, | ||
unsigned long min_num_clusters, | ||
double max_dist = std::numeric_limits<double>::infinity() | ||
) | ||
{ | ||
matrix<double> dists = matrix_cast<double>(dists_); | ||
// make sure requires clause is not broken | ||
DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, | ||
"\t unsigned long bottom_up_cluster()" | ||
<< "\n\t Invalid inputs were given to this function." | ||
<< "\n\t dists.nr(): " << dists.nr() | ||
<< "\n\t dists.nc(): " << dists.nc() | ||
<< "\n\t min_num_clusters: " << min_num_clusters | ||
); | ||
|
||
using namespace buc_impl; | ||
|
||
labels.resize(dists.nr()); | ||
disjoint_subsets sets; | ||
sets.set_size(dists.nr()); | ||
if (labels.size() == 0) | ||
return 0; | ||
|
||
// push all the edges in the graph into a priority queue so the best edges to merge | ||
// come first. | ||
std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que; | ||
for (long r = 0; r < dists.nr(); ++r) | ||
for (long c = r+1; c < dists.nc(); ++c) | ||
que.push(sample_pair(r,c,dists(r,c))); | ||
|
||
// Now start merging nodes. | ||
for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) | ||
{ | ||
// find the next best thing to merge. | ||
double best_dist = que.top().distance(); | ||
unsigned long a = sets.find_set(que.top().index1()); | ||
unsigned long b = sets.find_set(que.top().index2()); | ||
que.pop(); | ||
// we have been merging and modifying the distances, so make sure this distance | ||
// is still valid and these guys haven't been merged already. | ||
while(a == b || best_dist < dists(a,b)) | ||
{ | ||
// Haven't merged it yet, so put it back in with updated distance for | ||
// reconsideration later. | ||
if (a != b) | ||
que.push(sample_pair(a, b, dists(a, b))); | ||
|
||
best_dist = que.top().distance(); | ||
a = sets.find_set(que.top().index1()); | ||
b = sets.find_set(que.top().index2()); | ||
que.pop(); | ||
} | ||
|
||
|
||
// now merge these sets if the best distance is small enough | ||
if (best_dist > max_dist) | ||
break; | ||
unsigned long news = sets.merge_sets(a,b); | ||
unsigned long olds = (news==a)?b:a; | ||
merge_sets(dists, news, olds); | ||
} | ||
|
||
// figure out which cluster each element is in. Also make sure the labels are | ||
// contiguous. | ||
std::map<unsigned long, unsigned long> relabel; | ||
for (unsigned long r = 0; r < labels.size(); ++r) | ||
{ | ||
unsigned long l = sets.find_set(r); | ||
// relabel to make contiguous | ||
if (relabel.count(l) == 0) | ||
{ | ||
unsigned long next = relabel.size(); | ||
relabel[l] = next; | ||
} | ||
labels[r] = relabel[l]; | ||
} | ||
|
||
|
||
return relabel.size(); | ||
} | ||
|
||
// ---------------------------------------------------------------------------------------- | ||
|
||
} | ||
|
||
#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2015 Davis E. King (davis@dlib.net) | ||
// License: Boost Software License See LICENSE.txt for the full license. | ||
#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ | ||
#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ | ||
|
||
#include "../matrix.h" | ||
|
||
namespace dlib | ||
{ | ||
|
||
// ---------------------------------------------------------------------------------------- | ||
|
||
template < | ||
typename EXP | ||
> | ||
unsigned long bottom_up_cluster ( | ||
const matrix_exp<EXP>& dists, | ||
std::vector<unsigned long>& labels, | ||
unsigned long min_num_clusters, | ||
double max_dist = std::numeric_limits<double>::infinity() | ||
); | ||
/*! | ||
requires | ||
- dists.nr() == dists.nc() | ||
- min_num_clusters > 0 | ||
- dists == trans(dists) | ||
(i.e. dists should be symmetric) | ||
ensures | ||
- Runs a bottom up agglomerative clustering algorithm. | ||
- Interprets dists as a matrix that gives the distances between dists.nr() | ||
items. In particular, we take dists(i,j) to be the distance between the ith | ||
and jth element of some set. This function clusters the elements of this set | ||
into at least min_num_clusters (or dists.nr() if there aren't enough | ||
elements). Additionally, within each cluster, the maximum pairwise distance | ||
between any two cluster elements is <= max_dist. | ||
- returns the number of clusters found. | ||
- #labels.size() == dists.nr() | ||
- for all valid i: | ||
- #labels[i] == the cluster ID of the node with index i (i.e. the node | ||
corresponding to the distances dists(i,*)). | ||
- 0 <= #labels[i] < the number of clusters found | ||
(i.e. cluster IDs are assigned contiguously and start at 0) | ||
!*/ | ||
|
||
// ---------------------------------------------------------------------------------------- | ||
|
||
} | ||
|
||
#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters