Skip to content

Commit

Permalink
Fix crash when reference expression is passed in _eval. (typesense#…
Browse files Browse the repository at this point in the history
happy-san authored Aug 1, 2024
1 parent 05b0faa commit 65151cd
Showing 7 changed files with 59 additions and 26 deletions.
6 changes: 3 additions & 3 deletions include/collection.h
Original file line number Diff line number Diff line change
@@ -701,9 +701,9 @@ class Collection {

Option<bool> truncate_after_top_k(const std::string& field_name, size_t k);

void reference_populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
Option<bool> reference_populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;

int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const;

12 changes: 6 additions & 6 deletions include/index.h
Original file line number Diff line number Diff line change
@@ -796,13 +796,13 @@ class Index {
const uint32_t* exclude_token_ids, size_t exclude_token_ids_size, uint32_t*& filter_ids,
uint32_t& filter_ids_length, const std::vector<uint32_t>& curated_ids_sorted) const;

void populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
Option<bool> populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;

void populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;
Option<bool> populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const;

int64_t reference_string_sort_score(const std::string& field_name, const uint32_t& seq_id) const;

10 changes: 5 additions & 5 deletions src/collection.cpp
Original file line number Diff line number Diff line change
@@ -6828,12 +6828,12 @@ Option<bool> Collection::truncate_after_top_k(const string &field_name, size_t k
return Option<bool>(true);
}

void Collection::reference_populate_sort_mapping(int *sort_order, std::vector<size_t> &geopoint_indices,
std::vector<sort_by> &sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32> *, 3> &field_values)
const {
Option<bool> Collection::reference_populate_sort_mapping(int *sort_order, std::vector<size_t> &geopoint_indices,
std::vector<sort_by> &sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32> *, 3> &field_values)
const {
std::shared_lock lock(mutex);
index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values);
return index->populate_sort_mapping_with_lock(sort_order, geopoint_indices, sort_fields_std, field_values);
}

int64_t Collection::reference_string_sort_score(const string &field_name, const uint32_t& seq_id) const {
5 changes: 5 additions & 0 deletions src/collection_manager.cpp
Original file line number Diff line number Diff line change
@@ -1036,6 +1036,11 @@ bool CollectionManager::parse_sort_by_str(std::string sort_by_str, std::vector<s
i = open_paren_pos;
while(sort_by_str[++i] == ' ');

if (sort_by_str[i] == '$' && sort_by_str.find('(', i) != std::string::npos) {
// Reference expression inside `_eval()`
return false;
}

auto result = sort_by_str[i] == '[' ? parse_multi_eval(sort_by_str, i, sort_fields) :
parse_eval(sort_by_str, --i, sort_fields);
if (!result) {
6 changes: 6 additions & 0 deletions src/filter_result_iterator.cpp
Original file line number Diff line number Diff line change
@@ -793,6 +793,12 @@ void filter_result_iterator_t::init() {
ref_collection_name = ref_collection->name;

auto coll = cm.get_collection(collection_name);
if (coll == nullptr) {
status = Option<bool>(400, "Collection `" + collection_name + "` not found.");
validity = invalid;
return;
}

bool is_referenced = coll->referenced_in.count(ref_collection_name) > 0,
has_reference = ref_collection->is_referenced_in(collection_name);
if (!is_referenced && !has_reference) {
35 changes: 23 additions & 12 deletions src/index.cpp
Original file line number Diff line number Diff line change
@@ -2807,7 +2807,10 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
int sort_order[3]; // 1 or -1 based on DESC or ASC respectively
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values;
std::vector<size_t> geopoint_indices;
populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
auto populate_op = populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
if (!populate_op.ok()) {
return populate_op;
}

// Prepare excluded document IDs that we can later remove from the result set
uint32_t* excluded_result_ids = nullptr;
@@ -5006,6 +5009,10 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
bool found = false;
uint32_t index = 0;
auto const& eval = sort_fields[0].eval;
if (eval.eval_ids_vec.size() != count || eval.eval_ids_count_vec.size() != count) {
return Option<bool>(400, "Eval expressions count does not match the ids count.");
}

for (; index < count; index++) {
// ref_seq_id(s) can be unordered.
uint32_t ref_filter_index = 0;
@@ -5140,7 +5147,6 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
}

scores[1] = found ? eval.scores[index] : 0;
LOG(INFO) << "seq_id: " << seq_id << " ref_seq_id: " << ref_seq_id << " score: " << scores[1] << " index: " << index;
} else if(field_values[1] == &vector_distance_sentinel_value) {
scores[1] = float_to_int64_t(vector_distance);
} else if(field_values[1] == &vector_query_sentinel_value) {
@@ -6124,9 +6130,9 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root
return Option<bool>(true);
}

void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
Option<bool> Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
for (size_t i = 0; i < sort_fields_std.size(); i++) {
if (!sort_fields_std[i].reference_collection_name.empty()) {
auto& cm = CollectionManager::get_instance();
@@ -6138,8 +6144,11 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
ref_sort_fields_std.emplace_back(sort_fields_std[i]);
ref_sort_fields_std.front().reference_collection_name.clear();
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> ref_field_values;
ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices,
ref_sort_fields_std, ref_field_values);
auto populate_op = ref_collection->reference_populate_sort_mapping(ref_sort_order, ref_geopoint_indices,
ref_sort_fields_std, ref_field_values);
if (!populate_op.ok()) {
return populate_op;
}

sort_order[i] = ref_sort_order[0];
if (!ref_geopoint_indices.empty()) {
@@ -6171,7 +6180,7 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
search_begin_us, search_stop_us);
auto filter_init_op = filter_result_iterator.init_status();
if (!filter_init_op.ok()) {
return;
return filter_init_op;
}

filter_result_iterator.compute_iterators();
@@ -6200,13 +6209,15 @@ void Index::populate_sort_mapping(int* sort_order, std::vector<size_t>& geopoint
}
}
}

return Option<bool>(true);
}

void Index::populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
Option<bool> Index::populate_sort_mapping_with_lock(int* sort_order, std::vector<size_t>& geopoint_indices,
std::vector<sort_by>& sort_fields_std,
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3>& field_values) const {
std::shared_lock lock(mutex);
populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
return populate_sort_mapping(sort_order, geopoint_indices, sort_fields_std, field_values);
}

int Index::get_bounded_typo_cost(const size_t max_cost, const std::string& token, const size_t token_len,
11 changes: 11 additions & 0 deletions test/collection_join_test.cpp
Original file line number Diff line number Diff line change
@@ -5904,6 +5904,17 @@ TEST_F(CollectionJoinTest, SortByReference) {
ASSERT_EQ("2", res_obj["hits"][3]["document"].at("id"));
ASSERT_EQ("6", res_obj["hits"][4]["document"].at("id"));
ASSERT_EQ("1", res_obj["hits"][5]["document"].at("id"));

req_params = {
{"collection", "product"},
{"q", "tablet"},
{"query_by", "name"},
{"filter_by", "$stock(id: *)"},
{"sort_by", "_eval($stock(store_1:true || store_2:true)):desc"}
};
search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
ASSERT_FALSE(search_op.ok());
ASSERT_EQ("Parameter `sort_by` is malformed.", search_op.error());
}

TEST_F(CollectionJoinTest, FilterByReferenceAlias) {

0 comments on commit 65151cd

Please sign in to comment.