Skip to content

Commit

Permalink
Vector query parse bugs (#1937)
Browse files Browse the repository at this point in the history
* fix _vector_query parsing in sort_by param

* update model and make brace counter integer

* _vector_query bug fixes

* set default distance_threshold to max float value
  • Loading branch information
krunal1313 authored Sep 6, 2024
1 parent 1128f7f commit 2b4d7ba
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 18 deletions.
7 changes: 3 additions & 4 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,9 @@ class Index {
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
const std::map<basic_string<char>, reference_filter_result_t>& references,
std::vector<uint32_t>& filter_indexes,
int64_t max_field_match_score,
int64_t* scores,
int64_t& match_score_index, float vector_distance = 0,
std::vector<uint32_t>& filter_indexes, int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index, bool& should_skip,
float vector_distance = 0,
const std::string& collection_name = "") const;

void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
Expand Down
3 changes: 2 additions & 1 deletion include/vector_query_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#include <string>
#include <vector>
#include "option.h"
#include <limits>

class Collection;

struct vector_query_t {
std::string field_name;
size_t k = 0;
size_t flat_search_cutoff = 0;
float distance_threshold = 2.01;
float distance_threshold = std::numeric_limits<float>::max();
std::vector<float> values;

uint32_t seq_id = 0;
Expand Down
72 changes: 61 additions & 11 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,15 +3175,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
0, scores, match_score_index, vec_dist_score,
0, scores, match_score_index, should_skip, vec_dist_score,
collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, std::move(references));
kv.vector_distance = vec_dist_score;
int ret = topster->add(&kv);
Expand Down Expand Up @@ -3606,15 +3611,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT));
int64_t match_score_index = -1;
int64_t scores[3] = {0};
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

for(int i = 0; i < 3; i++) {
found_kv->scores[i] = scores[i];
}
Expand All @@ -3627,15 +3637,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
int64_t scores[3] = {0};
int64_t match_score = float_to_int64_t((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT);
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -4765,15 +4780,20 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, eval_filter_indexes, best_field_match_score,
scores, match_score_index, 0, collection_name);
scores, match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
status = Option<bool>(compute_sort_scores_op.code(), compute_sort_scores_op.error());
return;
}

if(should_skip) {
return;
}

query_len = std::min<size_t>(15, query_len);

// NOTE: `query_len` is total tokens matched across fields.
Expand Down Expand Up @@ -5074,7 +5094,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const std::vector<size_t>& geopoint_indices,
uint32_t seq_id, const std::map<basic_string<char>, reference_filter_result_t>& references,
std::vector<uint32_t>& filter_indexes, int64_t max_field_match_score, int64_t* scores,
int64_t& match_score_index, float vector_distance,
int64_t& match_score_index, bool& should_skip, float vector_distance,
const std::string& collection_name) const {

int64_t geopoint_distances[3];
Expand Down Expand Up @@ -5223,7 +5243,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[0].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[0].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[0].vector_query.query.values.data(), values.data(), &sort_fields[0].vector_query.vector_index->num_dim);


if(dist > sort_fields[0].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, we won't add this document to the results
should_skip = true;
}

scores[0] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5338,7 +5363,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[1].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[1].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[1].vector_query.query.values.data(), values.data(), &sort_fields[1].vector_query.vector_index->num_dim);


if(dist > sort_fields[1].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, we won't add this document to the results
should_skip = true;
}

scores[1] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5452,7 +5482,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[2].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[2].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[2].vector_query.query.values.data(), values.data(), &sort_fields[2].vector_query.vector_index->num_dim);


if(dist > sort_fields[2].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, we won't add this document to the results
should_skip = true;
}

scores[2] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5626,14 +5661,19 @@ Option<bool> Index::do_phrase_search(const size_t num_search_fields, const std::
int64_t match_score = phrase_match_id_scores[seq_id];
int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, eval_filter_indexes, match_score, scores,
match_score_index, 0, collection_name);
match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -5800,15 +5840,20 @@ Option<bool> Index::do_infix_search(const size_t num_search_fields, const std::v

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values,
geopoint_indices, seq_id, references,
eval_filter_indexes, 100, scores, match_score_index,
0, collection_name);
should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -6227,15 +6272,20 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, filter_indexes, 100, scores,
match_score_index, 0, collection_name);
match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
compute_sort_score_status = new Option<bool>(compute_sort_scores_op.code(), compute_sort_scores_op.error());
break;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down
13 changes: 11 additions & 2 deletions src/vector_query_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,21 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
}

if(param_kv[0] == "distance_threshold") {
if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 2.0) {
auto search_schema = const_cast<Collection*>(coll)->get_schema();
auto vector_field_it = search_schema.find(vector_query.field_name);

if(!StringUtils::is_float(param_kv[1])) {
return Option<bool>(400, "Malformed vector query string: "
"`distance_threshold` parameter must be a float.");
}

auto distance_threshold = std::stof(param_kv[1]);
if(vector_field_it->vec_dist == cosine && (distance_threshold < 0.0 || distance_threshold > 2.0)) {
return Option<bool>(400, "Malformed vector query string: "
"`distance_threshold` parameter must be a float between 0.0-2.0.");
}

vector_query.distance_threshold = std::stof(param_kv[1]);
vector_query.distance_threshold = distance_threshold;
}

if(param_kv[0] == "alpha") {
Expand Down
81 changes: 81 additions & 0 deletions test/collection_vector_search_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4945,4 +4945,85 @@ TEST_F(CollectionVectorTest, TestRestoringImages) {
coll = collectionManager.get_collection("test").get();

ASSERT_EQ(1, coll->get_summary_json()["num_documents"]);
}

// Exercises `distance_threshold` on a vector field declared with `"vec_dist": "ip"` when the
// vector query is supplied through the `sort_by` parameter:
//   1. with `distance_threshold:1`, only the hits whose vector distance is within the threshold
//      are expected in the response (2 of the 5 indexed documents);
//   2. with no `distance_threshold`, all 5 documents are expected back, including hits whose
//      distance is negative or far larger than 2.0.
TEST_F(CollectionVectorTest, TestDistanceThresholdWithIP) {
    // NOTE(review): the key "optinal" in the embedding field looks like a typo of "optional".
    // It is left untouched because it is part of the schema sent to the server; presumably the
    // unknown key is ignored and the field stays required -- confirm before changing it.
    auto schema_json = R"({
"name": "products",
"fields":[
{"name": "name","type": "string"},
{"name": "rank_score", "type": "float"},
{"name": "embedding","type": "float[]", "num_dim":5, "optinal":true, "vec_dist": "ip"}
],
"default_sorting_field": "rank_score"
})"_json;

    auto coll_op = collectionManager.create_collection(schema_json);
    ASSERT_TRUE(coll_op.ok());
    auto coll = coll_op.get();

    // Fixed seed: the expected rank scores and distances asserted below depend on the exact
    // sequence of values drawn from `rng`, so the order of the draws must not change.
    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib(-1,1);
    std::uniform_int_distribution<>distrib2(0,100);

    nlohmann::json doc;
    for (auto i = 0; i < 5; ++i) {
        std::vector<float> vector(5);
        std::generate(vector.begin(), vector.end(), [&](){ return distrib(rng); });

        doc["name"] = "document_" + std::to_string(i);
        doc["rank_score"] = distrib2(rng);
        doc["embedding"] = vector;
        ASSERT_TRUE(coll->add(doc.dump()).ok());
    }

    // result ids whose distance exceeds distance_threshold will be skipped
    std::map<std::string, std::string> req_params = {
        {"collection", "products"},
        {"q", "document"},
        {"query_by", "*"},
        {"sort_by", "_text_match:desc,_vector_query(embedding:([0.11731103425347378, -0.6694758317235057, -0.6211945774857595, -0.27966758971688255, -0.4683744007950299],distance_threshold:1)):asc,rank_score:desc"},
        {"exclude_fields", "embedding"}
    };
    nlohmann::json embedded_params;
    std::string json_res;
    auto now_ts = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch()).count();
    auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
    // Fail here with a clear signal instead of on a confusing JSON lookup further down.
    ASSERT_TRUE(search_op.ok());
    auto res = nlohmann::json::parse(json_res);

    ASSERT_EQ(2, res["found"].get<size_t>());
    ASSERT_EQ(93, res["hits"][0]["document"]["rank_score"].get<size_t>());
    // Distances are floats: compare with a floating-point tolerance rather than exact equality.
    ASSERT_FLOAT_EQ(0.2189185470342636, res["hits"][0]["vector_distance"].get<float>());
    ASSERT_EQ(51, res["hits"][1]["document"]["rank_score"].get<size_t>());
    ASSERT_FLOAT_EQ(0.7371898889541626, res["hits"][1]["vector_distance"].get<float>());

    // inner product distances should work when distance_threshold is not given
    req_params = {
        {"collection", "products"},
        {"q", "document"},
        {"query_by", "*"},
        {"sort_by", "_text_match:desc,_vector_query(embedding:([-100,-100,-100,-100,-100])):asc,rank_score:desc"},
        {"exclude_fields", "embedding"}
    };
    now_ts = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch()).count();
    search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
    ASSERT_TRUE(search_op.ok());
    res = nlohmann::json::parse(json_res);

    ASSERT_EQ(5, res["found"].get<size_t>());

    // Hits are expected in ascending order of vector distance.
    ASSERT_EQ("document_1", res["hits"][0]["document"]["name"]);
    ASSERT_FLOAT_EQ(-45.23314666748047, res["hits"][0]["vector_distance"].get<float>());
    ASSERT_EQ("document_2", res["hits"][1]["document"]["name"]);
    ASSERT_FLOAT_EQ(-38.66290283203125, res["hits"][1]["vector_distance"].get<float>());
    ASSERT_EQ("document_4", res["hits"][2]["document"]["name"]);
    ASSERT_FLOAT_EQ(-36.0988655090332, res["hits"][2]["vector_distance"].get<float>());
    ASSERT_EQ("document_3", res["hits"][3]["document"]["name"]);
    ASSERT_FLOAT_EQ(9.637892723083496, res["hits"][3]["vector_distance"].get<float>());
    ASSERT_EQ("document_0", res["hits"][4]["document"]["name"]);
    ASSERT_FLOAT_EQ(288.0364685058594, res["hits"][4]["vector_distance"].get<float>());
}

0 comments on commit 2b4d7ba

Please sign in to comment.