Skip to content

Commit

Permalink
Add sampling for value based faceting.
Browse files Browse the repository at this point in the history
  • Loading branch information
kishorenc committed Mar 2, 2024
1 parent cc81529 commit 720aad4
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 32 deletions.
5 changes: 4 additions & 1 deletion include/facet_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ class facet_index_t {
size_t get_facet_count(const std::string& field_name);

size_t intersect(facet& a_facet, const field& facet_field,
bool has_facet_query, const std::vector<std::vector<std::string>>& fvalue_searched_tokens,
bool has_facet_query,
bool estimate_facets,
size_t facet_sample_interval,
const std::vector<std::vector<std::string>>& fvalue_searched_tokens,
const std::vector<char>& symbols_to_index, const std::vector<char>& token_separators,
const uint32_t* result_ids, size_t result_id_len,
size_t max_facet_count, std::map<std::string, docid_count_t>& found,
Expand Down
4 changes: 3 additions & 1 deletion include/id_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class id_list_t {
void next();
void previous();
[[nodiscard]] uint32_t last_block_id() const;
void skip_n(uint32_t n);
void skip_to(uint32_t id);
void reset_cache();
[[nodiscard]] uint32_t id() const;
Expand Down Expand Up @@ -158,7 +159,8 @@ class id_list_t {

void uncompress(std::vector<uint32_t>& data);

size_t intersect_count(const uint32_t* res_ids, size_t res_ids_len);
size_t intersect_count(const uint32_t* res_ids, size_t res_ids_len,
bool estimate_facets, size_t facet_sample_interval);
};

template<class T>
Expand Down
3 changes: 2 additions & 1 deletion include/ids_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class ids_t {

static void uncompress(void*& obj, std::vector<uint32_t>& ids);

static size_t intersect_count(void*& obj, const uint32_t* result_ids, size_t result_ids_len);
static size_t intersect_count(void*& obj, const uint32_t* result_ids, size_t result_ids_len,
bool estimate_facets = false, size_t facet_sample_mod_value = 1);

static void to_expanded_id_lists(const std::vector<void*>& raw_id_lists, std::vector<id_list_t*>& id_lists,
std::vector<id_list_t*>& expanded_id_lists);
Expand Down
9 changes: 7 additions & 2 deletions src/facet_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,10 @@ size_t facet_index_t::get_facet_count(const std::string& field_name) {

//returns the count of matching seq_ids from result array
size_t facet_index_t::intersect(facet& a_facet, const field& facet_field,
bool has_facet_query, const std::vector<std::vector<std::string>>& fvalue_searched_tokens,
bool has_facet_query,
bool estimate_facets,
size_t facet_sample_interval,
const std::vector<std::vector<std::string>>& fvalue_searched_tokens,
const std::vector<char>& symbols_to_index, const std::vector<char>& token_separators,
const uint32_t* result_ids, size_t result_ids_len,
size_t max_facet_count, std::map<std::string, docid_count_t>& found,
Expand Down Expand Up @@ -389,7 +392,9 @@ size_t facet_index_t::intersect(facet& a_facet, const field& facet_field,
if (is_wildcard_no_filter_query) {
count = facet_count_it->count;
} else {
count = ids_t::intersect_count(ids, result_ids, result_ids_len);
auto val_count = ids_t::num_ids(ids);
bool estimate_facet_count = (estimate_facets && val_count > 300);
count = ids_t::intersect_count(ids, result_ids, result_ids_len, estimate_facet_count, facet_sample_interval);
}

if (count) {
Expand Down
69 changes: 57 additions & 12 deletions src/id_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ void id_list_t::iterator_t::reset_cache() {
curr_block = end_block = nullptr;
}

void id_list_t::iterator_t::skip_n(uint32_t n) {
while(curr_block != end_block) {
curr_index += n;
if(curr_index < curr_block->size()) {
return;
}

n = (curr_index - curr_block->size() + 1);
curr_block = curr_block->next;

delete [] ids;
ids = nullptr;

if(curr_block != end_block) {
curr_index = 0;
n--;
ids = curr_block->ids.uncompress();
} else {
reset_cache();
}
}
}

void id_list_t::iterator_t::skip_to(uint32_t id) {
// first look to skip within current block
if(id <= this->last_block_id()) {
Expand Down Expand Up @@ -699,23 +722,45 @@ uint32_t* id_list_t::uncompress() {
return arr;
}

size_t id_list_t::intersect_count(const uint32_t *res_ids, size_t res_ids_len) {
size_t id_list_t::intersect_count(const uint32_t *res_ids, size_t res_ids_len,
bool estimate_facets, size_t facet_sample_interval) {
size_t count = 0;
size_t res_index = 0;
auto it = new_iterator();

while(it.valid() && res_index < res_ids_len) {
if(it.id() < res_ids[res_index]) {
it.skip_to(res_ids[res_index]);
} else if(it.id() > res_ids[res_index]) {
// returns index that is >= to value or last if no such element is found.
res_index = std::lower_bound(res_ids + res_index, res_ids + res_ids_len, it.id()) - res_ids;
} else {
it.next();
res_index++;
count++;
if(estimate_facets) {
while(res_index < res_ids_len && it.valid()) {
auto skip_interval = rand() % (2 * facet_sample_interval);
if(it.id() == res_ids[res_index]) {
count++;
it.next();
res_index += skip_interval;
} else if(it.id() < res_ids[res_index]) {
it.skip_to(res_ids[res_index]);
} else {
res_index += skip_interval;
}
}
} else {
while(it.valid() && res_index < res_ids_len) {
if(it.id() == res_ids[res_index]) {
count++;
it.next();
res_index += 1;
} else if(it.id() < res_ids[res_index]) {
it.next();
} else {
res_index += 1;
}
}
}

return count;
//LOG(INFO) << "estimate_facets: " << estimate_facets << ", res_ids_len: " << res_ids_len
// << ", skip_interval: " << facet_sample_interval << ", count: " << count;

if(estimate_facets) {
count = count * facet_sample_interval;
}

return std::min<size_t>(ids_length, count);
}
5 changes: 3 additions & 2 deletions src/ids_t.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,14 @@ void ids_t::uncompress(void*& obj, std::vector<uint32_t>& ids) {
}
}

size_t ids_t::intersect_count(void*& obj, const uint32_t* result_ids, size_t result_ids_len) {
size_t ids_t::intersect_count(void*& obj, const uint32_t* result_ids, size_t result_ids_len,
bool estimate_facets, size_t facet_sample_mod_value) {
if(IS_COMPACT_IDS(obj)) {
compact_id_list_t* list = COMPACT_IDS_PTR(obj);
return list->intersect_count(result_ids, result_ids_len);
} else {
id_list_t* list = (id_list_t*)(obj);
return list->intersect_count(result_ids, result_ids_len);
return list->intersect_count(result_ids, result_ids_len, estimate_facets, facet_sample_mod_value);
}
}

Expand Down
11 changes: 4 additions & 7 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
facet_sample_percent = 1;
}

size_t mod_value = 100 / facet_sample_percent;
size_t facet_sample_mod_value = 100 / facet_sample_percent;

auto num_facet_values = facet_index_v4->get_facet_count(facet_field.name);
if(num_facet_values == 0) {
Expand All @@ -1392,6 +1392,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
std::string sort_order = a_facet.is_sort_by_alpha ? a_facet.sort_order : "";

facet_index_v4->intersect(a_facet, facet_field,use_facet_query,
estimate_facets, facet_sample_mod_value,
facet_infos[findex].fvalue_searched_tokens,
symbols_to_index, token_separators,
result_ids, results_size, max_facet_count, facet_results,
Expand Down Expand Up @@ -1462,7 +1463,7 @@ void Index::do_facets(std::vector<facet> & facets, facet_query_t & facet_query,
for(size_t i = 0; i < results_size; i++) {
// if sampling is enabled, we will skip a portion of the results to speed up things
if(estimate_facets) {
if(i % mod_value != 0) {
if(i % facet_sample_mod_value != 0) {
continue;
}
}
Expand Down Expand Up @@ -3638,11 +3639,7 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
}
}

for(auto& facet_kv: acc_facet.value_result_map) {
if(estimate_facets) {
facet_kv.second.count = size_t(double(facet_kv.second.count) * (100.0f / facet_sample_percent));
}
}
// value_result_map already contains the scaled counts

if(estimate_facets) {
acc_facet.sampled = true;
Expand Down
16 changes: 10 additions & 6 deletions test/collection_optimized_faceting_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ TEST_F(CollectionOptimizedFacetingTest, SampleFacetCounts) {

for(size_t i = 0; i < 1000; i++) {
nlohmann::json doc;
if(distr(gen) % 2 == 0) {
if(distr(gen) % 4 == 0) {
doc["color"] = "blue";
count_blue++;
} else {
Expand All @@ -1375,23 +1375,27 @@ TEST_F(CollectionOptimizedFacetingTest, SampleFacetCounts) {
ASSERT_TRUE(coll1->add(doc.dump()).ok());
}

auto res = coll1->search("*", {}, "", {"color"}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
auto res = coll1->search("*", {}, "color:blue || color:red", {"color"}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
4, {off}, 3, 3, 2, 2, false, "", true, 0, max_score, 10, 0, 4294967295UL, VALUE).get();
4, {off}, 3, 3, 2, 2, false, "", true, 0, max_score, 5, 0, 4294967295UL, VALUE).get();

ASSERT_EQ(1000, res["found"].get<size_t>());
ASSERT_EQ(1, res["facet_counts"].size());
ASSERT_EQ(2, res["facet_counts"][0]["counts"].size());

// verify approximate counts
ASSERT_GE(res["facet_counts"][0]["counts"][0]["count"].get<size_t>(), 250);
ASSERT_GE(res["facet_counts"][0]["counts"][1]["count"].get<size_t>(), 250);
ASSERT_GE(res["facet_counts"][0]["counts"][0]["count"].get<size_t>(), 700);
ASSERT_EQ("red", res["facet_counts"][0]["counts"][0]["value"].get<std::string>());

ASSERT_GE(res["facet_counts"][0]["counts"][1]["count"].get<size_t>(), 200);
ASSERT_EQ("blue", res["facet_counts"][0]["counts"][1]["value"].get<std::string>());

ASSERT_TRUE(res["facet_counts"][0]["sampled"].get<bool>());

// when sample threshold is high, don't estimate
res = coll1->search("*", {}, "", {"color"}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
res = coll1->search("*", {}, "color:blue || color:red", {"color"}, {}, {0}, 3, 1, FREQUENCY, {true}, 5,
spp::sparse_hash_set<std::string>(),
spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "", 20, {}, {}, {}, 0,
"<mark>", "</mark>", {}, 1000, true, false, true, "", false, 6000 * 1000, 4, 7, fallback,
Expand Down
21 changes: 21 additions & 0 deletions test/id_list_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,24 @@ TEST(IdListTest, IdListIteratorTest) {
ASSERT_EQ(10, count);
ASSERT_FALSE(iter2.valid());
}

TEST(IdListTest, IdListIntersectionTest) {
id_list_t id_list(2);
size_t res_len = 10*1000;

for (size_t i = 0; i < 1000; i++) {
id_list.upsert(i * 2);
}

// large res_ids + small id_list
uint32_t* res_ids = new uint32_t[res_len];
for(size_t i = 1; i < res_len; i++) {
res_ids[i] = (rand() % (res_len*10));
}
std::sort(res_ids, res_ids + res_len);

auto count = id_list.intersect_count(res_ids, res_len, false, 0);
ASSERT_NE(0, count);

delete [] res_ids;
}

0 comments on commit 720aad4

Please sign in to comment.