-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Yusheng.Ma <Yusheng.Ma@zilliz.com>
- Loading branch information
1 parent
40eab16
commit b583062
Showing
35 changed files
with
2,826 additions
and
2,128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include_directories(thirdparty/annoy/src) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include_directories(thirdparty/hnswlib) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef METRIC_H | ||
#define METRIC_H | ||
|
||
#include <algorithm> | ||
#include <string> | ||
#include <unordered_map> | ||
|
||
#include "faiss/MetricType.h" | ||
#include "knowhere/expected.h" | ||
namespace knowhere { | ||
|
||
expected<faiss::MetricType, Error> | ||
Str2FaissMetricType(std::string metric) { | ||
static const std::unordered_map<std::string, faiss::MetricType> metric_map = { | ||
{"L2", faiss::MetricType::METRIC_L2}, | ||
{"IP", faiss::MetricType::METRIC_INNER_PRODUCT}, | ||
{"JACCARD", faiss::MetricType::METRIC_Jaccard}, | ||
{"TANIMOTO", faiss::MetricType::METRIC_Tanimoto}, | ||
{"HAMMING", faiss::MetricType::METRIC_Hamming}, | ||
{"SUBSTRUCTURE", faiss::MetricType::METRIC_Substructure}, | ||
{"SUPERSTRUCTURE", faiss::MetricType::METRIC_Superstructure}, | ||
}; | ||
|
||
std::transform(metric.begin(), metric.end(), metric.begin(), toupper); | ||
auto it = metric_map.find(metric); | ||
if (it == metric_map.end()) | ||
return unexpected(Error::invalid_metric_type); | ||
return it->second; | ||
} | ||
|
||
} // namespace knowhere | ||
|
||
#endif /* METRIC_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

#include "annoylib.h"
#include "index/annoy/annoy_config.h"
#include "kissrandom.h"
#include "knowhere/knowhere.h"
namespace knowhere { | ||
|
||
// Build policy passed as the last template argument of every AnnoyIndex instantiated in this file;
// it is bound to the policy named AnnoyIndexSingleThreadedBuildPolicy.
using ThreadedBuildPolicy = AnnoyIndexSingleThreadedBuildPolicy; | ||
class AnnoyIndexNode : public IndexNode { | ||
public: | ||
virtual Error | ||
Build(const DataSet& dataset, const Config& cfg) override { | ||
if (index_) | ||
return Error::index_already_trained; | ||
|
||
const AnnoyConfig& annoy_cfg = static_cast<const AnnoyConfig&>(cfg); | ||
metric_type_ = annoy_cfg.metric_type; | ||
auto dim = dataset.GetDim(); | ||
if (annoy_cfg.metric_type == "L2") | ||
index_ = | ||
new (std::nothrow) AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random, ThreadedBuildPolicy>(dim); | ||
if (annoy_cfg.metric_type == "IP") | ||
index_ = | ||
new (std::nothrow) AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random, ThreadedBuildPolicy>(dim); | ||
if (index_) { | ||
auto p_data = dataset.GetTensor(); | ||
auto rows = dataset.GetRows(); | ||
for (int i = 0; i < rows; ++i) { | ||
index_->add_item(i, static_cast<const float*>(p_data) + dim * i); | ||
} | ||
index_->build(annoy_cfg.n_trees); | ||
return Error::success; | ||
} | ||
|
||
return Error::invalid_metric_type; | ||
} | ||
virtual Error | ||
Train(const DataSet& dataset, const Config& cfg) override { | ||
return Error::not_implemented; | ||
} | ||
virtual Error | ||
Add(const DataSet& dataset, const Config& cfg) override { | ||
return Error::not_implemented; | ||
} | ||
virtual expected<DataSetPtr, Error> | ||
Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { | ||
if (!index_) { | ||
return unexpected(Error::empty_index); | ||
} | ||
|
||
auto dim = dataset.GetDim(); | ||
auto rows = dataset.GetRows(); | ||
auto ts = dataset.GetTensor(); | ||
auto annoy_cfg = static_cast<const AnnoyConfig&>(cfg); | ||
auto p_id = new (std::nothrow) int64_t[annoy_cfg.k * rows]; | ||
auto p_dist = new (std::nothrow) float[annoy_cfg.k * rows]; | ||
|
||
#pragma omp parallel for | ||
for (unsigned int i = 0; i < rows; ++i) { | ||
std::vector<int64_t> result; | ||
result.reserve(annoy_cfg.k); | ||
std::vector<float> distances; | ||
distances.reserve(annoy_cfg.k); | ||
index_->get_nns_by_vector(static_cast<const float*>(ts) + i * dim, annoy_cfg.k, annoy_cfg.search_k, &result, | ||
&distances, bitset); | ||
|
||
size_t result_num = result.size(); | ||
auto local_p_id = p_id + annoy_cfg.k * i; | ||
auto local_p_dist = p_dist + annoy_cfg.k * i; | ||
memcpy(local_p_id, result.data(), result_num * sizeof(int64_t)); | ||
memcpy(local_p_dist, distances.data(), result_num * sizeof(float)); | ||
|
||
for (; result_num < (size_t)annoy_cfg.k; result_num++) { | ||
local_p_id[result_num] = -1; | ||
local_p_dist[result_num] = 1.0 / 0.0; | ||
} | ||
} | ||
|
||
auto results = std::make_shared<DataSet>(); | ||
results->SetIds(p_id); | ||
results->SetDistance(p_dist); | ||
return results; | ||
} | ||
virtual expected<DataSetPtr, Error> | ||
SearchByRange(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { | ||
return unexpected(Error::not_implemented); | ||
} | ||
|
||
virtual expected<DataSetPtr, Error> | ||
GetVectorByIds(const DataSet& dataset, const Config& cfg) const override { | ||
if (!index_) { | ||
return unexpected(Error::empty_index); | ||
} | ||
|
||
auto rows = dataset.GetRows(); | ||
auto dim = dataset.GetDim(); | ||
auto p_ids = dataset.GetIds(); | ||
|
||
float* p_x = nullptr; | ||
try { | ||
p_x = new (std::nothrow) float[dim * rows]; | ||
for (int64_t i = 0; i < rows; i++) { | ||
int64_t id = p_ids[i]; | ||
assert(id >= 0 && id < index_->get_n_items()); | ||
index_->get_item(id, p_x + i * dim); | ||
} | ||
} catch (...) { | ||
std::unique_ptr<float> auto_del(p_x); | ||
return unexpected(Error::annoy_inner_error); | ||
} | ||
|
||
auto results = std::make_shared<DataSet>(); | ||
results->SetTensor(p_x); | ||
|
||
return results; | ||
} | ||
virtual Error | ||
Serialization(BinarySet& binset) const override { | ||
if (!index_) { | ||
return Error::empty_index; | ||
} | ||
|
||
auto metric_type_length = metric_type_.length(); | ||
auto metric_type = std::make_shared<uint8_t[]>(metric_type_length); | ||
memcpy(metric_type.get(), metric_type_.data(), metric_type_.length()); | ||
|
||
auto dim = Dims(); | ||
auto dim_data = std::make_shared<uint8_t[]>(sizeof(uint64_t)); | ||
memcpy(dim_data.get(), &dim, sizeof(uint64_t)); | ||
|
||
size_t index_length = index_->get_index_length(); | ||
auto index_data = std::make_shared<uint8_t[]>(index_length); | ||
memcpy(index_data.get(), index_->get_index(), index_length); | ||
|
||
binset.Append("annoy_metric_type", metric_type, metric_type_length); | ||
binset.Append("annoy_dim", dim_data, sizeof(uint64_t)); | ||
binset.Append("annoy_index_data", index_data, index_length); | ||
|
||
return Error::success; | ||
} | ||
virtual Error | ||
Deserialization(const BinarySet& binset) override { | ||
auto metric_type = binset.GetByName("annoy_metric_type"); | ||
metric_type_.resize(static_cast<size_t>(metric_type->size)); | ||
memcpy(metric_type_.data(), metric_type->data.get(), static_cast<size_t>(metric_type->size)); | ||
|
||
auto dim_data = binset.GetByName("annoy_dim"); | ||
uint64_t dim; | ||
memcpy(&dim, dim_data->data.get(), static_cast<size_t>(dim_data->size)); | ||
|
||
if (metric_type_ == "L2") { | ||
index_ = | ||
new (std::nothrow) AnnoyIndex<int64_t, float, ::Euclidean, ::Kiss64Random, ThreadedBuildPolicy>(dim); | ||
} else if (metric_type_ == "IP") { | ||
index_ = | ||
new (std::nothrow) AnnoyIndex<int64_t, float, ::DotProduct, ::Kiss64Random, ThreadedBuildPolicy>(dim); | ||
} | ||
|
||
auto index_data = binset.GetByName("annoy_index_data"); | ||
char* p = nullptr; | ||
if (!index_->load_index(reinterpret_cast<void*>(index_data->data.get()), index_data->size, &p)) { | ||
free(p); | ||
return Error::annoy_inner_error; | ||
} | ||
|
||
return Error::success; | ||
} | ||
|
||
virtual std::unique_ptr<Config> | ||
CreateConfig() const override { | ||
return std::make_unique<AnnoyConfig>(); | ||
} | ||
virtual int64_t | ||
Dims() const override { | ||
if (!index_) | ||
return 0; | ||
return index_->get_dim(); | ||
} | ||
virtual int64_t | ||
Size() const override { | ||
if (!index_) | ||
return 0; | ||
return index_->cal_size(); | ||
} | ||
virtual int64_t | ||
Count() const override { | ||
if (!index_) | ||
return 0; | ||
return index_->get_n_items(); | ||
} | ||
virtual std::string | ||
Type() const override { | ||
return "ANNOY"; | ||
} | ||
virtual ~AnnoyIndexNode() { | ||
if (index_) | ||
delete index_; | ||
} | ||
|
||
private: | ||
AnnoyIndexInterface<int64_t, float>* index_; | ||
std::string metric_type_; | ||
}; | ||
|
||
// Registers the Annoy node under the identifier ANNOY; the lambda is the factory and returns
// Index<AnnoyIndexNode>::Create().
// NOTE(review): KNOWHERE_REGISTER_GLOBAL is defined elsewhere — presumably it adds the factory
// to a process-wide registry at static-initialisation time; confirm against the macro.
KNOWHERE_REGISTER_GLOBAL(ANNOY, []() { return Index<AnnoyIndexNode>::Create(); }); | ||
|
||
} // namespace knowhere |
Oops, something went wrong.