Skip to content

Commit

Permalink
sparse float vector index to support mmap when loading from index file (
Browse files Browse the repository at this point in the history
#530)

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian authored Apr 26, 2024
1 parent 714ba32 commit 2ab12eb
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 37 deletions.
60 changes: 51 additions & 9 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <sys/mman.h>

#include "index/hnsw/hnsw_config.h"
#include "index/sparse/sparse_inverted_index.h"
#include "index/sparse/sparse_inverted_index_config.h"
#include "io/file_io.h"
#include "io/memory_io.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"
Expand All @@ -38,10 +41,7 @@ class SparseInvertedIndexNode : public IndexNode {
}

~SparseInvertedIndexNode() override {
if (index_ != nullptr) {
delete index_;
index_ = nullptr;
}
delete_index();
}

Status
Expand All @@ -58,7 +58,7 @@ class SparseInvertedIndexNode : public IndexNode {
drop_ratio_build);
if (index_ != nullptr) {
LOG_KNOWHERE_WARNING_ << Type() << " deleting old index during train";
delete index_;
delete_index();
}
index_ = index;
return Status::success;
Expand Down Expand Up @@ -183,8 +183,7 @@ class SparseInvertedIndexNode : public IndexNode {
Deserialize(const BinarySet& binset, const Config& config) override {
if (index_) {
LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old";
delete index_;
index_ = nullptr;
delete_index();
}
auto binary = binset.GetByName(Type());
if (binary == nullptr) {
Expand All @@ -194,12 +193,35 @@ class SparseInvertedIndexNode : public IndexNode {
MemoryIOReader reader(binary->data.get(), binary->size);
index_ = new sparse::InvertedIndex<T>();
// no need to set use_wand_ of index_, since it will be set in Load()
return index_->Load(reader);
return index_->Load(reader, false);
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
throw std::runtime_error("DeserializeFromFile not supported for current index type");
if (index_) {
LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old";
delete_index();
}
auto cfg = static_cast<const knowhere::BaseConfig&>(config);
auto reader = knowhere::FileReader(filename);
map_size_ = reader.size();
int map_flags = MAP_SHARED;
if (cfg.enable_mmap_pop.has_value() && cfg.enable_mmap_pop.value()) {
#ifdef MAP_POPULATE
map_flags |= MAP_POPULATE;
#endif
}
map_ = static_cast<char*>(mmap(nullptr, map_size_, PROT_READ, map_flags, reader.descriptor(), 0));
if (map_ == MAP_FAILED) {
LOG_KNOWHERE_ERROR_ << "Failed to mmap file: " << strerror(errno);
return Status::disk_file_error;
}
if (madvise(map_, map_size_, MADV_RANDOM) != 0) {
LOG_KNOWHERE_WARNING_ << "Failed to madvise file: " << strerror(errno);
}
index_ = new sparse::InvertedIndex<T>();
MemoryIOReader map_reader((uint8_t*)map_, map_size_);
return index_->Load(map_reader, true);
}

[[nodiscard]] std::unique_ptr<BaseConfig>
Expand Down Expand Up @@ -229,8 +251,28 @@ class SparseInvertedIndexNode : public IndexNode {
}

private:
void
delete_index() {
if (index_ != nullptr) {
delete index_;
index_ = nullptr;
}
if (map_ != nullptr) {
auto res = munmap(map_, map_size_);
if (res != 0) {
LOG_KNOWHERE_ERROR_ << "Failed to munmap when trying to delete index: " << strerror(errno);
}
map_ = nullptr;
map_size_ = 0;
}
}

sparse::InvertedIndex<T>* index_{};
std::shared_ptr<ThreadPool> search_pool_;

// if map_ is not nullptr, it means the index is mmapped from disk.
char* map_ = nullptr;
size_t map_size_ = 0;
}; // class SparseInvertedIndexNode

KNOWHERE_SIMPLE_REGISTER_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, fp32, /*use_wand=*/false);
Expand Down
15 changes: 10 additions & 5 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class InvertedIndex {
}

Status
Load(MemoryIOReader& reader) {
Load(MemoryIOReader& reader, bool is_mmap) {
std::unique_lock<std::shared_mutex> lock(mu_);
int64_t rows;
readBinaryPOD(reader, rows);
Expand All @@ -91,11 +91,16 @@ class InvertedIndex {
for (int64_t i = 0; i < rows; ++i) {
size_t count;
readBinaryPOD(reader, count);
raw_data_.emplace_back(count);
if (count == 0) {
continue;
if (is_mmap) {
raw_data_.emplace_back(count, reader.data() + reader.tellg(), false);
reader.advance(count * SparseRow<T>::element_size());
} else {
raw_data_.emplace_back(count);
if (count == 0) {
continue;
}
reader.read(raw_data_[i].data(), count * SparseRow<T>::element_size());
}
reader.read(raw_data_[i].data(), count * SparseRow<T>::element_size());
add_row_to_index(raw_data_[i], i);
}

Expand Down
10 changes: 10 additions & 0 deletions src/io/memory_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ struct MemoryIOReader : public faiss::IOReader {
return rp_;
}

// returns the new tellg() result.
size_t
advance(size_t size) {
rp_ += size;
if (rp_ > total_) {
rp_ = total_;
}
return rp_;
}

void
reset() {
rp_ = 0;
Expand Down
66 changes: 43 additions & 23 deletions tests/ut/test_sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,50 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") {
auto gt = knowhere::BruteForce::SearchSparse(train_ds, query_ds, conf, nullptr);
check_distance_decreasing(*gt.value());

auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);

knowhere::BinarySet bs;
REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
auto use_mmap = GENERATE(true, false);
auto tmp_file = "/tmp/knowhere_sparse_inverted_index_test";
{
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);
REQUIRE(idx.Size() > 0);
REQUIRE(idx.Count() == nb);

knowhere::BinarySet bs;
REQUIRE(idx.Serialize(bs) == knowhere::Status::success);
if (use_mmap) {
auto binary = bs.GetByName(idx.Type());
auto data = binary->data.get();
auto size = binary->size;
// if tmp_file already exists, remove it
std::remove(tmp_file);
std::ofstream out(tmp_file, std::ios::binary);
out.write((const char*)data, size);
out.close();
REQUIRE(idx.DeserializeFromFile(tmp_file, json) == knowhere::Status::success);
} else {
REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success);
}

auto results = idx.Search(*query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
check_distance_decreasing(*results.value());
auto drop_ratio_build = json[knowhere::indexparam::DROP_RATIO_BUILD].get<float>();
auto drop_ratio_search = json[knowhere::indexparam::DROP_RATIO_SEARCH].get<float>();
if (drop_ratio_build == 0 && drop_ratio_search == 0) {
REQUIRE(recall == 1);
} else {
// most test cases are above 0.95, only a few between 0.9 and 0.95
REQUIRE(recall >= 0.85);
auto results = idx.Search(*query_ds, json, nullptr);
REQUIRE(results.has_value());
float recall = GetKNNRecall(*gt.value(), *results.value());
check_distance_decreasing(*results.value());
auto drop_ratio_build = json[knowhere::indexparam::DROP_RATIO_BUILD].get<float>();
auto drop_ratio_search = json[knowhere::indexparam::DROP_RATIO_SEARCH].get<float>();
if (drop_ratio_build == 0 && drop_ratio_search == 0) {
REQUIRE(recall == 1);
} else {
// most test cases are above 0.95, only a few between 0.9 and 0.95
REQUIRE(recall >= 0.85);
}
// idx to destruct and munmap
}
if (use_mmap) {
REQUIRE(std::remove(tmp_file) == 0);
}
}

Expand Down

0 comments on commit 2ab12eb

Please sign in to comment.