Skip to content

Commit

Permalink
TRanker class with config
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Sep 17, 2020
1 parent 971b704 commit bef5254
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ set(SOURCE_FILES
src/embedders/token_indexer.cpp
src/embedders/torch_embedder.cpp
src/nasty.cpp
src/rank.cpp
src/ranker.cpp
src/run_server.cpp
src/server_clustering.cpp
src/summarizer.cpp
Expand Down
1 change: 1 addition & 0 deletions configs/ranker.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
min_cluster_size: 3
3 changes: 3 additions & 0 deletions configs/server.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ clusterer_config_path: "configs/clusterer.pbtxt"

## Path to summarizer config
summarizer_config_path: "configs/summarizer.pbtxt"

## Path to ranker config
ranker_config_path: "configs/ranker.pbtxt"
7 changes: 4 additions & 3 deletions src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "document.h"
#include "document.pb.h"
#include "rank.h"
#include "util.h"

#include <optional>
Expand Down Expand Up @@ -61,11 +60,13 @@ namespace {
void TController::Init(
const THotState<TClusterIndex>* index,
rocksdb::DB* db,
std::unique_ptr<TAnnotator> annotator
std::unique_ptr<TAnnotator> annotator,
std::unique_ptr<TRanker> ranker
) {
Index = index;
Db = db;
Annotator = std::move(annotator);
Ranker = std::move(ranker);
Initialized.store(true, std::memory_order_release);
}

Expand Down Expand Up @@ -201,7 +202,7 @@ void TController::Threads(
const uint64_t fromTimestamp = index->TrueMaxTimestamp > period.value() ? index->TrueMaxTimestamp - period.value() : 0;

const auto indexIt = std::lower_bound(clusters.cbegin(), clusters.cend(), fromTimestamp, TNewsCluster::Compare);
const auto weightedClusters = Rank(indexIt, clusters.cend(), index->IterTimestamp, period.value());
const auto weightedClusters = Ranker->Rank(indexIt, clusters.cend(), index->IterTimestamp, period.value());
const auto& categoryClusters = weightedClusters.at(category.value());

Json::Value threads(Json::arrayValue);
Expand Down
6 changes: 4 additions & 2 deletions src/controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "annotator.h"
#include "clusterer.h"
#include "hot_state.h"
#include "ranker.h"

#include <drogon/HttpController.h>
#include <rocksdb/db.h>
Expand All @@ -20,7 +21,8 @@ class TController : public drogon::HttpController<TController, /* AutoCreation *
void Init(
const THotState<TClusterIndex>* index,
rocksdb::DB* db,
std::unique_ptr<TAnnotator> annotator
std::unique_ptr<TAnnotator> annotator,
std::unique_ptr<TRanker> ranker
);

void Put(
Expand Down Expand Up @@ -64,12 +66,12 @@ class TController : public drogon::HttpController<TController, /* AutoCreation *
drogon::HttpStatusCode existedCode
) const;


private:
std::atomic<bool> Initialized {false};

const THotState<TClusterIndex>* Index;

rocksdb::DB* Db;
std::unique_ptr<TAnnotator> Annotator;
std::unique_ptr<TRanker> Ranker;
};
8 changes: 6 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "annotator.h"
#include "clusterer.h"
#include "rank.h"
#include "ranker.h"
#include "run_server.h"
#include "summarizer.h"
#include "timer.h"
Expand All @@ -22,6 +22,7 @@ int main(int argc, char** argv) {
("annotator_config", po::value<std::string>()->default_value("configs/annotator.pbtxt"), "annotator_config")
("clusterer_config", po::value<std::string>()->default_value("configs/clusterer.pbtxt"), "clusterer_config")
("summarizer_config", po::value<std::string>()->default_value("configs/summarizer.pbtxt"), "summarizer_config")
("ranker_config", po::value<std::string>()->default_value("configs/ranker.pbtxt"), "ranker_config")
("ndocs", po::value<int>()->default_value(-1), "ndocs")
("save_not_news", po::bool_switch()->default_value(false), "save_not_news")
("languages", po::value<std::vector<std::string>>()->multitoken()->default_value(std::vector<std::string>{"ru", "en"}, "ru en"), "languages")
Expand Down Expand Up @@ -233,7 +234,10 @@ int main(int argc, char** argv) {
std::back_inserter(allClusters)
);
}
const auto tops = Rank(allClusters.begin(), allClusters.end(), clusterIndex.IterTimestamp, window);

const std::string rankerConfigPath = vm["ranker_config"].as<std::string>();
const TRanker ranker(rankerConfigPath);
const auto tops = ranker.Rank(allClusters.begin(), allClusters.end(), clusterIndex.IterTimestamp, window);
nlohmann::json outputJson = nlohmann::json::array();
for (auto it = tops.begin(); it != tops.end(); ++it) {
const auto category = static_cast<tg::ECategory>(std::distance(tops.begin(), it));
Expand Down
9 changes: 7 additions & 2 deletions src/proto/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ message TServerConfig {
string annotator_config_path = 12;
string clusterer_config_path = 13;
string summarizer_config_path = 14;
string ranker_config_path = 15;
}

message TCategoryModelConfig{
Expand Down Expand Up @@ -77,6 +78,10 @@ message TClustererConfig {
}

message TSummarizerConfig {
string hosts_rating = 3;
string alexa_rating = 4;
string hosts_rating = 1;
string alexa_rating = 2;
}

message TRankerConfig {
uint64 min_cluster_size = 1;
}
20 changes: 13 additions & 7 deletions src/rank.cpp → src/ranker.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "rank.h"
#include "ranker.h"
#include "util.h"

TWeightInfo ComputeClusterWeightPush(
Expand All @@ -19,12 +19,16 @@ TWeightInfo ComputeClusterWeightPush(
return TWeightInfo{clusterTime, rank, timeMultiplier, rank * timeMultiplier, cluster.GetSize()};
}

std::vector<std::vector<TWeightedNewsCluster>> Rank(
TRanker::TRanker(const std::string& configPath) {
::ParseConfig(configPath, Config);
}

std::vector<std::vector<TWeightedNewsCluster>> TRanker::Rank(
TClusters::const_iterator begin,
TClusters::const_iterator end,
uint64_t iterTimestamp,
uint64_t window
) {
) const {
std::vector<TWeightedNewsCluster> weightedClusters;
for (TClusters::const_iterator it = begin; it != end; it++) {
const TNewsCluster& cluster = *it;
Expand All @@ -33,12 +37,14 @@ std::vector<std::vector<TWeightedNewsCluster>> Rank(
}

std::stable_sort(weightedClusters.begin(), weightedClusters.end(),
[](const TWeightedNewsCluster& a, const TWeightedNewsCluster& b) {
if (a.WeightInfo.ClusterSize == b.WeightInfo.ClusterSize) {
[&](const TWeightedNewsCluster& a, const TWeightedNewsCluster& b) {
size_t firstSize = a.WeightInfo.ClusterSize;
size_t secondSize = b.WeightInfo.ClusterSize;
if (firstSize == secondSize) {
return a.WeightInfo.Weight > b.WeightInfo.Weight;
}
if (a.WeightInfo.ClusterSize < 3 || b.WeightInfo.ClusterSize < 3) {
return a.WeightInfo.ClusterSize > b.WeightInfo.ClusterSize;
if (firstSize < this->Config.min_cluster_size() || secondSize < this->Config.min_cluster_size()) {
return firstSize > secondSize;
}
return a.WeightInfo.Weight > b.WeightInfo.Weight;
}
Expand Down
25 changes: 17 additions & 8 deletions src/rank.h → src/ranker.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "agency_rating.h"
#pragma once

#include "clustering/clustering.h"
#include "db_document.h"
#include "config.pb.h"

#include <cstdint>
#include <unordered_map>
Expand All @@ -26,9 +27,17 @@ struct TWeightedNewsCluster {
{}
};

std::vector<std::vector<TWeightedNewsCluster>> Rank(
TClusters::const_iterator begin,
TClusters::const_iterator end,
uint64_t iterTimestamp,
uint64_t window
);
class TRanker {
public:
TRanker(const std::string& configPath);

std::vector<std::vector<TWeightedNewsCluster>> Rank(
TClusters::const_iterator begin,
TClusters::const_iterator end,
uint64_t iterTimestamp,
uint64_t window
) const;

private:
tg::TRankerConfig Config;
};
6 changes: 4 additions & 2 deletions src/run_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ int RunServer(const std::string& fname, uint16_t port) {
LOG_DEBUG("Creating summarizer");
std::unique_ptr<TSummarizer> summarizer = std::make_unique<TSummarizer>(config.summarizer_config_path());

LOG_DEBUG("Creating ranker");
std::unique_ptr<TRanker> ranker = std::make_unique<TRanker>(config.ranker_config_path());

TServerClustering serverClustering(std::move(clusterer), std::move(summarizer), db.get());

LOG_DEBUG("Launching server");
Expand All @@ -100,12 +103,11 @@ int RunServer(const std::string& fname, uint16_t port) {
auto controllerPtr = std::make_shared<TController>();
app().registerController(controllerPtr);


LOG_DEBUG("Launching clustering");
THotState<TClusterIndex> index;

auto initContoller = [&, annotator=std::move(annotator)]() mutable {
DrClassMap::getSingleInstance<TController>()->Init(&index, db.get(), std::move(annotator));
DrClassMap::getSingleInstance<TController>()->Init(&index, db.get(), std::move(annotator), std::move(ranker));
};

std::thread clusteringThread([&, sleep_ms=config.clusterer_sleep()]() {
Expand Down

0 comments on commit bef5254

Please sign in to comment.