Skip to content

Commit

Permalink
Clustering merge refactoring
Browse files: browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Oct 17, 2020
1 parent 1a2ce61 commit 861ea23
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions src/clustering/slink.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,6 +78,7 @@ TClusters TSlinkClustering::Cluster(
embeddingKeysWeights[embeddingKeyWeight.embedding_key()] = embeddingKeyWeight.weight();
}
const size_t docSize = docs.size();
const size_t intersectionSize = Config.intersection_size();
std::vector<size_t> labels;
labels.reserve(docSize);

Expand All @@ -92,27 +93,22 @@ TClusters TSlinkClustering::Cluster(
std::vector<TDbDocument>::const_iterator end = begin + batchSize;

std::vector<size_t> newLabels = ClusterBatch(begin, end, embeddingKeysWeights);
size_t newMaxLabel = maxLabel;
for (auto& label : newLabels) {
label += maxLabel;
newMaxLabel = std::max(newMaxLabel, label);
}
maxLabel = newMaxLabel;
std::for_each(newLabels.begin(), newLabels.end(), [&](size_t& i){ i += maxLabel; });
maxLabel = *std::max_element(newLabels.begin(), newLabels.end());

assert(begin->Url == docs[batchStart].Url);
for (size_t i = batchStart; i < batchStart + Config.intersection_size() && i < labels.size(); i++) {
for (size_t i = batchStart; i < batchStart + intersectionSize && i < labels.size(); i++) {
size_t oldLabel = labels[i];
int j = i - batchStart;
assert(j >= 0 && static_cast<size_t>(j) < newLabels.size());
size_t newLabel = newLabels[j];
size_t batchIndex = static_cast<size_t>(i - batchStart);
size_t newLabel = newLabels.at(batchIndex);
oldLabelsToNew[oldLabel] = newLabel;
}
if (batchStart == 0) {
for (size_t i = 0; i < std::min(static_cast<size_t>(Config.intersection_size()), newLabels.size()); i++) {
for (size_t i = 0; i < std::min(intersectionSize, newLabels.size()); i++) {
labels.push_back(newLabels[i]);
}
}
for (size_t i = Config.intersection_size(); i < newLabels.size(); i++) {
for (size_t i = intersectionSize; i < newLabels.size(); i++) {
labels.push_back(newLabels[i]);
}
assert(batchStart == static_cast<size_t>(std::distance(docs.begin(), begin)));
Expand All @@ -122,8 +118,8 @@ TClusters TSlinkClustering::Cluster(
}

prevBatchEnd = batchStart + batchSize;
batchStart = batchStart + batchSize - Config.intersection_size();
begin = end - Config.intersection_size();
batchStart = prevBatchEnd - intersectionSize;
begin = end - intersectionSize;
}
assert(labels.size() == docs.size());
for (auto& label : labels) {
Expand Down

0 comments on commit 861ea23

Please sign in to comment.