Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
kisasexypantera94 committed Apr 18, 2024
1 parent 3473d25 commit bed4a19
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
18 changes: 9 additions & 9 deletions src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace urukrama {

template <typename T>
GraphConstructor<T>::GraphConstructor(std::span<const Point<T>> points, const size_t R, const size_t L)
Graph<T>::Graph(std::span<const Point<T>> points, const size_t R, const size_t L)
: m_R(R), m_L(L), m_dimension(points.front().size()), m_points(points)
{
Init();
Expand All @@ -33,7 +33,7 @@ GraphConstructor<T>::GraphConstructor(std::span<const Point<T>> points, const si


template <typename T>
void GraphConstructor<T>::Init()
void Graph<T>::Init()
{
std::mt19937_64 random_engine{std::random_device{}()};
std::uniform_int_distribution<size_t> dis(0, m_points.size() - 1);
Expand All @@ -48,7 +48,7 @@ void GraphConstructor<T>::Init()
}

template <typename T>
size_t GraphConstructor<T>::ProcessPoints(size_t s_idx, float alpha)
size_t Graph<T>::ProcessPoints(size_t s_idx, float alpha)
{
size_t good = 0;

Expand Down Expand Up @@ -80,13 +80,13 @@ size_t GraphConstructor<T>::ProcessPoints(size_t s_idx, float alpha)
}

template <typename T>
T GraphConstructor<T>::Distance(const Point<T>& a, const Point<T>& b)
T Graph<T>::Distance(const Point<T>& a, const Point<T>& b)
{
return (a - b).squaredNorm();
}

template <typename T>
size_t GraphConstructor<T>::FindMedoid()
size_t Graph<T>::FindMedoid()
{
Point<T> centroid = std::reduce(m_points.begin(), m_points.end(), Point<T>(m_dimension)) / m_points.size();

Expand All @@ -96,7 +96,7 @@ size_t GraphConstructor<T>::FindMedoid()
}

template <typename T>
GraphConstructor<T>::GreedySearchResult GraphConstructor<T>::GreedySearch(size_t s_idx, const Point<T>& query, size_t k)
Graph<T>::GreedySearchResult Graph<T>::GreedySearch(size_t s_idx, const Point<T>& query, size_t k)
{
HashSet<size_t> fast_visited;
fast_visited.reserve(m_L * 2);
Expand Down Expand Up @@ -135,7 +135,7 @@ GraphConstructor<T>::GreedySearchResult GraphConstructor<T>::GreedySearch(size_t
}

template <typename T>
void GraphConstructor<T>::RobustPrune(size_t p_idx, const std::vector<std::pair<T, size_t>>& candidates, float alpha)
void Graph<T>::RobustPrune(size_t p_idx, const std::vector<std::pair<T, size_t>>& candidates, float alpha)
{
auto& p_n_out = m_n_out[p_idx];

Expand Down Expand Up @@ -166,7 +166,7 @@ void GraphConstructor<T>::RobustPrune(size_t p_idx, const std::vector<std::pair<
}
}

template class GraphConstructor<float>;
template class GraphConstructor<uint8_t>;
template class Graph<float>;
template class Graph<uint8_t>;

} // namespace urukrama
14 changes: 7 additions & 7 deletions src/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
namespace urukrama {

template <typename T>
class GraphConstructor {
class Graph {
public:
GraphConstructor(std::span<const Point<T>> points, const size_t R, const size_t L);
Graph(std::span<const Point<T>> points, const size_t R, const size_t L);

GraphConstructor(const GraphConstructor&) = delete;
GraphConstructor(GraphConstructor&&) = delete;
GraphConstructor& operator=(const GraphConstructor&) = delete;
GraphConstructor& operator=(GraphConstructor&&) = delete;
Graph(const Graph&) = delete;
Graph(Graph&&) = delete;
Graph& operator=(const Graph&) = delete;
Graph& operator=(Graph&&) = delete;

~GraphConstructor() = default;
~Graph() = default;

private:
template <typename K>
Expand Down
8 changes: 4 additions & 4 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
#include <ranges>
#include <vector>

constexpr size_t NUM_BATCHES = 40;
constexpr size_t NUM_BATCHES = 8;
constexpr size_t l = 2;
constexpr size_t M = 32;

auto PrepareBatches()
{
auto [points, dimension, num_points] = urukrama::FVecsRead("/data/deep1M_base.fvecs");
auto [points, dimension, num_points] = urukrama::FVecsRead("../data/deep1b/deep1M_base.fvecs");
// auto pq = ComputeProductQuantization(points, dimension, M);
const auto [clusters, indices, distances] = ComputeClusters(points, dimension, NUM_BATCHES, l);

Expand Down Expand Up @@ -50,7 +50,7 @@ int main()
{
const auto batches = PrepareBatches();

boost::asio::thread_pool pool;
boost::asio::thread_pool pool(8);
std::atomic<size_t> sum_proc_time = 0;

for (const auto& [batch_idx, batch]: batches | std::views::enumerate) {
Expand All @@ -59,7 +59,7 @@ int main()


auto t0 = high_resolution_clock::now();
urukrama::GraphConstructor gc(std::span{batch}, 70, 75);
urukrama::Graph gc(std::span{batch}, 70, 75);
auto duration = duration_cast<seconds>(high_resolution_clock::now() - t0).count();

sum_proc_time += duration;
Expand Down

0 comments on commit bed4a19

Please sign in to comment.