hnsw and binflat support
Signed-off-by: Yusheng.Ma <Yusheng.Ma@zilliz.com>
Presburger committed Aug 29, 2022
1 parent 40eab16 commit b583062
Showing 35 changed files with 2,826 additions and 2,128 deletions.
11 changes: 8 additions & 3 deletions CMakeLists.txt
@@ -7,7 +7,8 @@ include(cmake/utils/utils.cmake)
include(cmake/utils/compiler_check.cmake)
include(cmake/utils/platform_check.cmake)
include(cmake/libs/libfaiss.cmake)

include(cmake/libs/libannoy.cmake)
include(cmake/libs/libhnsw.cmake)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

add_definitions(-DAUTO_INITIALIZE_EASYLOGGINGPP=1)
@@ -21,10 +22,14 @@ include_directories(thirdparty/faiss)
include_directories(thirdparty/bitset)
include_directories(thirdparty)
find_package(OpenMP REQUIRED)
find_package(Boost REQUIRED)
include_directories(${OpenMP_CXX_INCLUDE_DIRS})


if (OPENMP_FOUND)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

find_package(Boost REQUIRED)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc
src/index/*.cc src/io/*.cc)
include_directories(src)
1 change: 1 addition & 0 deletions cmake/libs/libannoy.cmake
@@ -0,0 +1 @@
include_directories(thirdparty/annoy/src)
4 changes: 1 addition & 3 deletions cmake/libs/libfaiss.cmake
@@ -55,7 +55,7 @@ if(__X86_64)
-mf16c
-mavx512dq
-mavx512bw>)
add_library(faiss STATIC ${FAISS_SRCS})
add_library(faiss STATIC ${FAISS_SRCS})
add_dependencies(faiss faiss_avx512 knowhere_utils)
target_compile_options(faiss PRIVATE $<$<COMPILE_LANGUAGE:CXX>: -msse4.2
-mavx2 -mfma -mf16c -Wno-sign-compare
@@ -68,5 +68,3 @@ if(__X86_64)
target_compile_definitions(faiss PRIVATE FINTEGER=int)

endif()


1 change: 1 addition & 0 deletions cmake/libs/libhnsw.cmake
@@ -0,0 +1 @@
include_directories(thirdparty/hnswlib)
19 changes: 0 additions & 19 deletions cmake/utils/compiler_check.cmake
@@ -9,25 +9,6 @@ if(NOT COMPILER_SUPPORTS_CXX17)
)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
check_cxx_compiler_flag("-fopenmp=libomp" COMPILER_SUPPORTS_OMP)
set(CMAKE_CXX_FLAGS "-fopenmp=libomp ${CMAKE_CXX_FLAGS}")
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
check_cxx_compiler_flag("-Xclang -fopenmp" COMPILER_SUPPORTS_OMP)
set(CMAKE_CXX_FLAGS "-Xclang -fopenmp ${CMAKE_CXX_FLAGS}")
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
check_cxx_compiler_flag("-fopenmp" COMPILER_SUPPORTS_OMP)
set(CMAKE_CXX_FLAGS "-fopenmp ${CMAKE_CXX_FLAGS}")
endif()

if(NOT COMPILER_SUPPORTS_OMP)
message(FATAL_ERROR "compiler must support openmp.")
endif()

if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message(STATUS "Build in Debug mode")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
26 changes: 22 additions & 4 deletions include/knowhere/config.h
@@ -17,7 +17,7 @@ template <typename T>
struct Entry {};

enum PARAM_TYPE {
QUERY = 0x1,
SEARCH = 0x1,
RANGE = 0x2,
TRAIN = 0x4,
};
@@ -114,8 +114,8 @@ class EntryAccess {
return *this;
}
EntryAccess&
for_query() {
entry->type |= PARAM_TYPE::QUERY;
for_search() {
entry->type |= PARAM_TYPE::SEARCH;
return *this;
}
EntryAccess&
@@ -130,7 +130,7 @@
}
EntryAccess&
for_all() {
entry->type |= PARAM_TYPE::QUERY;
entry->type |= PARAM_TYPE::SEARCH;
entry->type |= PARAM_TYPE::RANGE;
entry->type |= PARAM_TYPE::TRAIN;
return *this;
@@ -259,6 +259,24 @@ class Config {
EntryAccess<decltype(PARAM)> PARAM##_access(std::get_if<Entry<decltype(PARAM)>>(&__DICT__[#PARAM])); \
PARAM##_access

class BaseConfig : public Config {
public:
int dim;
std::string metric_type;
int k;
float radius;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(dim).description("vector dims.").for_all();
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type).set_default("L2").description("distance metric type.").for_all();
KNOWHERE_CONFIG_DECLARE_FIELD(k)
.set_default(10)
.description("search for top k similar vector.")
.for_search()
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(radius).set_default(0.0f).description("range search radius.").for_range();
}
};

} // namespace knowhere

#endif /* CONFIG_H */
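
The BaseConfig block above is the pattern the rest of this commit's index configs follow: declare plain fields, then describe each one inside KNOHWERE_DECLARE_CONFIG with KNOWHERE_CONFIG_DECLARE_FIELD and the for_train/for_search/for_range accessors. As a rough illustration only (the class name and field names below are assumptions for demonstration, not taken from this diff), a derived config might look like this:

// Illustrative sketch, not part of the commit: a derived config written
// against the macros shown above. ExampleHnswConfig, M, and efConstruction
// are assumed names chosen for demonstration.
class ExampleHnswConfig : public BaseConfig {
 public:
    int M;
    int efConstruction;
    KNOHWERE_DECLARE_CONFIG(ExampleHnswConfig) {
        KNOWHERE_CONFIG_DECLARE_FIELD(M).set_default(30).description("hnsw graph out-degree.").for_train();
        KNOWHERE_CONFIG_DECLARE_FIELD(efConstruction)
            .set_default(360)
            .description("hnsw build-time candidate queue size.")
            .for_train();
    }
};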
10 changes: 9 additions & 1 deletion include/knowhere/expected.h
@@ -1,3 +1,6 @@
#ifndef EXPECTED_H
#define EXPECTED_H

#include <cassert>
#include <iostream>
#include <string>
@@ -11,9 +14,12 @@ enum class Error {
type_conflict_in_json,
invalid_metric_type,
empty_index,
not_implemented,
index_not_trained,
index_already_trained,
faiss_inner_error,

annoy_inner_error,
hnsw_inner_error,
};

template <typename E>
@@ -85,3 +91,5 @@ class expected {
};

} // namespace knowhere

#endif /* EXPECTED_H */
33 changes: 33 additions & 0 deletions src/common/metric.h
@@ -0,0 +1,33 @@
#ifndef METRIC_H
#define METRIC_H

#include <algorithm>
#include <string>
#include <unordered_map>

#include "faiss/MetricType.h"
#include "knowhere/expected.h"
namespace knowhere {

expected<faiss::MetricType, Error>
Str2FaissMetricType(std::string metric) {
static const std::unordered_map<std::string, faiss::MetricType> metric_map = {
{"L2", faiss::MetricType::METRIC_L2},
{"IP", faiss::MetricType::METRIC_INNER_PRODUCT},
{"JACCARD", faiss::MetricType::METRIC_Jaccard},
{"TANIMOTO", faiss::MetricType::METRIC_Tanimoto},
{"HAMMING", faiss::MetricType::METRIC_Hamming},
{"SUBSTRUCTURE", faiss::MetricType::METRIC_Substructure},
{"SUPERSTRUCTURE", faiss::MetricType::METRIC_Superstructure},
};

std::transform(metric.begin(), metric.end(), metric.begin(), toupper);
auto it = metric_map.find(metric);
if (it == metric_map.end())
return unexpected(Error::invalid_metric_type);
return it->second;
}

} // namespace knowhere

#endif /* METRIC_H */
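
Str2FaissMetricType gives the new index wrappers one place to map the user-facing metric string onto a faiss::MetricType, returning Error::invalid_metric_type for unknown names. A hedged call-site sketch follows; it assumes the expected<> wrapper from include/knowhere/expected.h exposes has_value() and value() accessors, which sit in the part of that class collapsed out of this diff.

#include <string>

#include "common/metric.h"

// Illustrative only: resolve a metric name, falling back to L2 when the
// lookup reports invalid_metric_type. has_value()/value() are assumed
// accessors on knowhere::expected, not confirmed by this diff.
static faiss::MetricType
ResolveMetricOrL2(const std::string& name) {
    auto metric = knowhere::Str2FaissMetricType(name);  // case-insensitive lookup
    if (!metric.has_value())
        return faiss::MetricType::METRIC_L2;
    return metric.value();
}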
206 changes: 206 additions & 0 deletions src/index/annoy/annoy.cc
@@ -0,0 +1,206 @@
#include "annoylib.h"
#include "index/annoy/annoy_config.h"
#include "kissrandom.h"
#include "knowhere/knowhere.h"
namespace knowhere {

using ThreadedBuildPolicy = AnnoyIndexSingleThreadedBuildPolicy;
class AnnoyIndexNode : public IndexNode {
public:
virtual Error
Build(const DataSet& dataset, const Config& cfg) override {
if (index_)
return Error::index_already_trained;

const AnnoyConfig& annoy_cfg = static_cast<const AnnoyConfig&>(cfg);
metric_type_ = annoy_cfg.metric_type;
auto dim = dataset.GetDim();
if (annoy_cfg.metric_type == "L2")
index_ =
new (std::nothrow) AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random, ThreadedBuildPolicy>(dim);
if (annoy_cfg.metric_type == "IP")
index_ =
new (std::nothrow) AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random, ThreadedBuildPolicy>(dim);
if (index_) {
auto p_data = dataset.GetTensor();
auto rows = dataset.GetRows();
for (int i = 0; i < rows; ++i) {
index_->add_item(i, static_cast<const float*>(p_data) + dim * i);
}
index_->build(annoy_cfg.n_trees);
return Error::success;
}

return Error::invalid_metric_type;
}
virtual Error
Train(const DataSet& dataset, const Config& cfg) override {
return Error::not_implemented;
}
virtual Error
Add(const DataSet& dataset, const Config& cfg) override {
return Error::not_implemented;
}
virtual expected<DataSetPtr, Error>
Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override {
if (!index_) {
return unexpected(Error::empty_index);
}

auto dim = dataset.GetDim();
auto rows = dataset.GetRows();
auto ts = dataset.GetTensor();
auto annoy_cfg = static_cast<const AnnoyConfig&>(cfg);
auto p_id = new (std::nothrow) int64_t[annoy_cfg.k * rows];
auto p_dist = new (std::nothrow) float[annoy_cfg.k * rows];

#pragma omp parallel for
for (unsigned int i = 0; i < rows; ++i) {
std::vector<int64_t> result;
result.reserve(annoy_cfg.k);
std::vector<float> distances;
distances.reserve(annoy_cfg.k);
index_->get_nns_by_vector(static_cast<const float*>(ts) + i * dim, annoy_cfg.k, annoy_cfg.search_k, &result,
&distances, bitset);

size_t result_num = result.size();
auto local_p_id = p_id + annoy_cfg.k * i;
auto local_p_dist = p_dist + annoy_cfg.k * i;
memcpy(local_p_id, result.data(), result_num * sizeof(int64_t));
memcpy(local_p_dist, distances.data(), result_num * sizeof(float));

for (; result_num < (size_t)annoy_cfg.k; result_num++) {
local_p_id[result_num] = -1;
local_p_dist[result_num] = 1.0 / 0.0;
}
}

auto results = std::make_shared<DataSet>();
results->SetIds(p_id);
results->SetDistance(p_dist);
return results;
}
virtual expected<DataSetPtr, Error>
SearchByRange(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override {
return unexpected(Error::not_implemented);
}

virtual expected<DataSetPtr, Error>
GetVectorByIds(const DataSet& dataset, const Config& cfg) const override {
if (!index_) {
return unexpected(Error::empty_index);
}

auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto p_ids = dataset.GetIds();

float* p_x = nullptr;
try {
p_x = new (std::nothrow) float[dim * rows];
for (int64_t i = 0; i < rows; i++) {
int64_t id = p_ids[i];
assert(id >= 0 && id < index_->get_n_items());
index_->get_item(id, p_x + i * dim);
}
} catch (...) {
std::unique_ptr<float> auto_del(p_x);
return unexpected(Error::annoy_inner_error);
}

auto results = std::make_shared<DataSet>();
results->SetTensor(p_x);

return results;
}
virtual Error
Serialization(BinarySet& binset) const override {
if (!index_) {
return Error::empty_index;
}

auto metric_type_length = metric_type_.length();
auto metric_type = std::make_shared<uint8_t[]>(metric_type_length);
memcpy(metric_type.get(), metric_type_.data(), metric_type_.length());

auto dim = Dims();
auto dim_data = std::make_shared<uint8_t[]>(sizeof(uint64_t));
memcpy(dim_data.get(), &dim, sizeof(uint64_t));

size_t index_length = index_->get_index_length();
auto index_data = std::make_shared<uint8_t[]>(index_length);
memcpy(index_data.get(), index_->get_index(), index_length);

binset.Append("annoy_metric_type", metric_type, metric_type_length);
binset.Append("annoy_dim", dim_data, sizeof(uint64_t));
binset.Append("annoy_index_data", index_data, index_length);

return Error::success;
}
virtual Error
Deserialization(const BinarySet& binset) override {
auto metric_type = binset.GetByName("annoy_metric_type");
metric_type_.resize(static_cast<size_t>(metric_type->size));
memcpy(metric_type_.data(), metric_type->data.get(), static_cast<size_t>(metric_type->size));

auto dim_data = binset.GetByName("annoy_dim");
uint64_t dim;
memcpy(&dim, dim_data->data.get(), static_cast<size_t>(dim_data->size));

if (metric_type_ == "L2") {
index_ =
new (std::nothrow) AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random, ThreadedBuildPolicy>(dim);
} else if (metric_type_ == "IP") {
index_ =
new (std::nothrow) AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random, ThreadedBuildPolicy>(dim);
}

auto index_data = binset.GetByName("annoy_index_data");
char* p = nullptr;
if (!index_->load_index(reinterpret_cast<void*>(index_data->data.get()), index_data->size, &p)) {
free(p);
return Error::annoy_inner_error;
}

return Error::success;
}

virtual std::unique_ptr<Config>
CreateConfig() const override {
return std::make_unique<AnnoyConfig>();
}
virtual int64_t
Dims() const override {
if (!index_)
return 0;
return index_->get_dim();
}
virtual int64_t
Size() const override {
if (!index_)
return 0;
return index_->cal_size();
}
virtual int64_t
Count() const override {
if (!index_)
return 0;
return index_->get_n_items();
}
virtual std::string
Type() const override {
return "ANNOY";
}
virtual ~AnnoyIndexNode() {
if (index_)
delete index_;
}

private:
AnnoyIndexInterface<int64_t, float>* index_;
std::string metric_type_;
};

KNOWHERE_REGISTER_GLOBAL(ANNOY, []() { return Index<AnnoyIndexNode>::Create(); });

} // namespace knowhere