Skip to content

Commit

Permalink
support quantization for search
Browse files Browse the repository at this point in the history
  • Loading branch information
kisasexypantera94 committed Apr 26, 2024
1 parent 376be15 commit c8adb70
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 25 deletions.
17 changes: 6 additions & 11 deletions src/faiss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexPQ.h>
#include <faiss/impl/ProductQuantizer.h>

using faiss::idx_t;
Expand Down Expand Up @@ -56,18 +57,12 @@ std::tuple<std::vector<float>, std::vector<int64_t>, std::vector<float>> Compute
return ComputeClustersCPU(vecs, dim, number_of_clusters, k, number_of_iterations, verbose);
}

std::vector<uint8_t> ComputeProductQuantization(const std::vector<float>& vecs, size_t dim, size_t M, size_t nbits)
faiss::IndexPQ BuildIndexPQ(const std::vector<float>& vecs, size_t dim, size_t M, size_t nbits)
{
faiss::ProductQuantizer pq(dim, M, nbits);
faiss::IndexPQ index_pq(int(dim), M, nbits);

pq.cp.niter = 50;
pq.cp.verbose = false; // print out per-iteration stats
index_pq.train(idx_t(vecs.size() / dim), vecs.data());
index_pq.add(idx_t(vecs.size() / dim), vecs.data());

pq.verbose = false;
pq.train(vecs.size() / dim, vecs.data());

std::vector<uint8_t> codes(vecs.size() / dim * M);
pq.compute_codes(vecs.data(), codes.data(), vecs.size() / dim);

return codes;
return index_pq;
}
7 changes: 3 additions & 4 deletions src/faiss.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <faiss/IndexPQ.h>

#include <cstdint>
#include <vector>

Expand All @@ -11,7 +13,4 @@ std::tuple<std::vector<float>, std::vector<int64_t>, std::vector<float>> Compute
size_t number_of_iterations = 20,
bool verbose = true);

std::vector<uint8_t> ComputeProductQuantization(const std::vector<float>& vecs,
size_t dim,
size_t M = 32,
size_t nbits = 8);
faiss::IndexPQ BuildIndexPQ(const std::vector<float>& vecs, size_t dim, size_t M = 32, size_t nbits = 8);
7 changes: 4 additions & 3 deletions src/in_memory_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ std::vector<HashSet<size_t>> InMemoryGraph<T>::BuildIndexInBatches(size_t s_idx)
for (const float alpha: {1.0f, 1.2f}) {
for (size_t start = 0, end = start; start < m_points.size(); start = end + 1,
end = std::min({start * 2, start + size_t(float(m_points.size()) * 0.02), m_points.size() - 1})) {
std::atomic<size_t> cnt_found = 0;
const size_t batch_size = end - start + 1;

std::vector<HashSet<size_t>> n_out_deltas(end - start + 1);
std::atomic<size_t> cnt_found = 0;
std::vector<HashSet<size_t>> n_out_deltas(batch_size);

tbb::parallel_for(tbb::blocked_range<size_t>(start, end + 1), [&](tbb::blocked_range<size_t> r) {
for (size_t p_idx = r.begin(); p_idx < r.end(); ++p_idx) {
Expand Down Expand Up @@ -116,7 +117,7 @@ std::vector<HashSet<size_t>> InMemoryGraph<T>::BuildIndexInBatches(size_t s_idx)
ksp::log::Info("Processed batch: range=[{}..{}], precision=[{}]",
start,
end,
float(cnt_found) / float(end - start + 1));
float(cnt_found) / float(batch_size));
}
}

Expand Down
17 changes: 14 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "faiss.hpp"
#include "in_memory_graph.hpp"
#include "on_disk_graph.hpp"
#include "utils.hpp"
Expand All @@ -10,7 +11,7 @@

int main()
{
auto points = [] {
auto [points, index_pq] = [] {
auto [flat_points, dimension, num_points] = urukrama::FVecsRead("/data/deep1M_base.fvecs");
std::vector<urukrama::Point<float>> points;

Expand All @@ -21,7 +22,17 @@ int main()
std::mt19937_64 random_engine{std::random_device{}()};
std::ranges::shuffle(points, random_engine);

return points;
// TODO: flatten the shuffled points in place instead of copying them into a second buffer
std::vector<float> flat_points_shuffled;
flat_points_shuffled.reserve(flat_points.size());

for (const auto& x: points | std::views::join) {
flat_points_shuffled.emplace_back(x);
}

auto index_pq = BuildIndexPQ(flat_points_shuffled, dimension);

return std::make_pair(std::move(points), std::move(index_pq));
}();

ksp::log::Info("Loaded points: size=[{}]", points.size());
Expand All @@ -38,7 +49,7 @@ int main()
using namespace std::chrono;

auto t0 = high_resolution_clock::now();
auto top = on_disk_graph.GreedySearch(p, 1);
auto top = on_disk_graph.GreedySearchWithPQ(index_pq, p, 1);
auto t1 = high_resolution_clock::now();

total_search_time += duration_cast<microseconds>(t1 - t0).count();
Expand Down
23 changes: 20 additions & 3 deletions src/on_disk_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,25 @@ OnDiskGraph<T>::OnDiskGraph(std::string_view index_filename)
m_medoid_idx);
}

// Greedy graph search that ranks candidates using distances computed by
// `index_pq`'s flat-codes distance computer instead of FullPrecisionDistance.
//
// index_pq: FAISS PQ index supplying the distance computer.
//           NOTE(review): presumably its ids coincide with this graph's point
//           indices (i.e. it was built from the same, identically ordered
//           points) -- confirm against the caller.
// query:    query point; its raw data pointer is handed to the distance computer.
// k:        forwarded unchanged to GreedySearchInternal.
//
// Returns whatever GreedySearchInternal produces for the PQ distance functor.
template <typename T>
std::vector<std::pair<T, size_t>> OnDiskGraph<T>::GreedySearchWithPQ(const faiss::IndexPQ& index_pq,
                                                                     const Point<T>& query,
                                                                     size_t k) const
{
    // get_FlatCodesDistanceComputer() hands back a heap-allocated object that the
    // caller owns. Keeping it in a unique_ptr releases it on every exit path
    // (normal return or exception) instead of leaking one computer per search.
    const std::unique_ptr<faiss::FlatCodesDistanceComputer> computer(index_pq.get_FlatCodesDistanceComputer());
    computer->set_query(query.data());

    // The lambda only borrows `computer`; it is invoked synchronously inside
    // GreedySearchInternal, so the reference cannot outlive this scope.
    return GreedySearchInternal([&](size_t p_idx) { return (*computer)(faiss::idx_t(p_idx)); }, k);
}

// Greedy graph search that scores candidates with FullPrecisionDistance
// against `query`. All traversal logic lives in GreedySearchInternal; this
// wrapper only supplies the distance functor and forwards `k`.
template <typename T>
std::vector<std::pair<T, size_t>> OnDiskGraph<T>::GreedySearch(const Point<T>& query, size_t k) const
{
    const auto exact_distance = [&](size_t point_idx) { return FullPrecisionDistance(point_idx, query); };

    return GreedySearchInternal(exact_distance, k);
}

template <typename T>
std::vector<std::pair<T, size_t>> OnDiskGraph<T>::GreedySearchInternal(auto distance_func, size_t k) const
{
HashSet<size_t> fast_visited;
fast_visited.reserve(m_L * 2);
Expand All @@ -45,7 +62,7 @@ std::vector<std::pair<T, size_t>> OnDiskGraph<T>::GreedySearch(const Point<T>& q

BoundedSortedVector<T, size_t> candidates(m_L);
candidates.reserve(m_L + 1);
candidates.emplace(Distance(m_medoid_idx, query), m_medoid_idx);
candidates.emplace(distance_func(m_medoid_idx), m_medoid_idx);

while (true) {
auto it = std::find_if(candidates.begin(), candidates.end(), [&](const auto& c) {
Expand All @@ -65,7 +82,7 @@ std::vector<std::pair<T, size_t>> OnDiskGraph<T>::GreedySearch(const Point<T>& q

for (const size_t n_idx: n_out | std::views::filter([](const size_t n_idx) { return n_idx != DUMMY_P_IDX; })) {
if (not fast_visited.contains(n_idx)) {
candidates.emplace(Distance(n_idx, query), n_idx);
candidates.emplace(distance_func(n_idx), n_idx);
}
}
}
Expand Down Expand Up @@ -163,7 +180,7 @@ OnDiskGraph<float>::DataType OnDiskGraph<float>::GetDataType()
}

template <typename T>
T OnDiskGraph<T>::Distance(const size_t a_idx, const Point<T>& b) const
T OnDiskGraph<T>::FullPrecisionDistance(const size_t a_idx, const Point<T>& b) const
{
std::span point = GetPoint(a_idx);

Expand Down
6 changes: 5 additions & 1 deletion src/on_disk_graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "types.hpp"

#include <mio/mmap.hpp>
#include <faiss/IndexPQ.h>

#include <cstddef>
#include <cstdint>
#include <memory>
Expand All @@ -22,6 +23,7 @@ class OnDiskGraph {

public:
std::vector<std::pair<T, size_t>> GreedySearch(const Point<T>& query, size_t k) const;
std::vector<std::pair<T, size_t>> GreedySearchWithPQ(const faiss::IndexPQ& index_pq, const Point<T>& query, size_t k) const;

static void Write(const InMemoryGraph<T>& in_mem_graph, std::string_view filename);

Expand All @@ -41,6 +43,8 @@ class OnDiskGraph {
};

private:
std::vector<std::pair<T, size_t>> GreedySearchInternal(auto distance_func, size_t k) const;

std::span<const float> GetPoint(size_t p_idx) const;
std::span<const size_t> GetPointNeighbors(size_t p_idx) const;

Expand All @@ -49,7 +53,7 @@ class OnDiskGraph {
template <typename U>
U ReadAs(size_t offset) const;

T Distance(const size_t a_idx, const Point<T>& b) const;
T FullPrecisionDistance(const size_t a_idx, const Point<T>& b) const;

static DataType GetDataType();

Expand Down

0 comments on commit c8adb70

Please sign in to comment.