Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector query parse bugs #1937

Merged
merged 5 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,9 @@ class Index {
std::array<spp::sparse_hash_map<uint32_t, int64_t, Hasher32>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
const std::map<basic_string<char>, reference_filter_result_t>& references,
std::vector<uint32_t>& filter_indexes,
int64_t max_field_match_score,
int64_t* scores,
int64_t& match_score_index, float vector_distance = 0,
std::vector<uint32_t>& filter_indexes, int64_t max_field_match_score,
int64_t* scores, int64_t& match_score_index, bool& should_skip,
float vector_distance = 0,
const std::string& collection_name = "") const;

void process_curated_ids(const std::vector<std::pair<uint32_t, uint32_t>>& included_ids,
Expand Down
3 changes: 2 additions & 1 deletion include/vector_query_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#include <string>
#include <vector>
#include "option.h"
#include <limits>

class Collection;

struct vector_query_t {
std::string field_name;
size_t k = 0;
size_t flat_search_cutoff = 0;
float distance_threshold = 2.01;
float distance_threshold = std::numeric_limits<float>::max();
std::vector<float> values;

uint32_t seq_id = 0;
Expand Down
72 changes: 61 additions & 11 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,15 +3175,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
0, scores, match_score_index, vec_dist_score,
0, scores, match_score_index, should_skip, vec_dist_score,
collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

KV kv(searched_queries.size(), seq_id, distinct_id, match_score_index, scores, std::move(references));
kv.vector_distance = vec_dist_score;
int ret = topster->add(&kv);
Expand Down Expand Up @@ -3606,15 +3611,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT));
int64_t match_score_index = -1;
int64_t scores[3] = {0};
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

for(int i = 0; i < 3; i++) {
found_kv->scores[i] = scores[i];
}
Expand All @@ -3627,15 +3637,20 @@ Option<bool> Index::search(std::vector<query_tokens_t>& field_query_tokens, cons
int64_t scores[3] = {0};
int64_t match_score = float_to_int64_t((1.0 / (seq_id_to_rank[seq_id] + 1)) * VECTOR_SEARCH_WEIGHT);
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields_std, sort_order, field_values,
geopoint_indices, seq_id, references, eval_filter_indexes,
match_score, scores, match_score_index,
match_score, scores, match_score_index, should_skip,
vec_result.second, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if (group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -4765,15 +4780,20 @@ Option<bool> Index::search_across_fields(const std::vector<token_t>& query_token

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, eval_filter_indexes, best_field_match_score,
scores, match_score_index, 0, collection_name);
scores, match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
status = Option<bool>(compute_sort_scores_op.code(), compute_sort_scores_op.error());
return;
}

if(should_skip) {
return;
}

query_len = std::min<size_t>(15, query_len);

// NOTE: `query_len` is total tokens matched across fields.
Expand Down Expand Up @@ -5074,7 +5094,7 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const std::vector<size_t>& geopoint_indices,
uint32_t seq_id, const std::map<basic_string<char>, reference_filter_result_t>& references,
std::vector<uint32_t>& filter_indexes, int64_t max_field_match_score, int64_t* scores,
int64_t& match_score_index, float vector_distance,
int64_t& match_score_index, bool& should_skip, float vector_distance,
const std::string& collection_name) const {

int64_t geopoint_distances[3];
Expand Down Expand Up @@ -5223,7 +5243,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[0].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[0].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[0].vector_query.query.values.data(), values.data(), &sort_fields[0].vector_query.vector_index->num_dim);


if(dist > sort_fields[0].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, this document won't be added to the results
should_skip = true;
}

scores[0] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5338,7 +5363,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[1].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[1].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[1].vector_query.query.values.data(), values.data(), &sort_fields[1].vector_query.vector_index->num_dim);


if(dist > sort_fields[1].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, this document won't be added to the results
should_skip = true;
}

scores[1] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5449,7 +5479,12 @@ Option<bool> Index::compute_sort_scores(const std::vector<sort_by>& sort_fields,
const auto& values = sort_fields[2].vector_query.vector_index->vecdex->getDataByLabel<float>(seq_id);
const auto& dist_func = sort_fields[2].vector_query.vector_index->space->get_dist_func();
float dist = dist_func(sort_fields[2].vector_query.query.values.data(), values.data(), &sort_fields[2].vector_query.vector_index->num_dim);


if(dist > sort_fields[2].vector_query.query.distance_threshold) {
// if the computed distance is greater than distance_threshold, this document won't be added to the results
should_skip = true;
}

scores[2] = float_to_int64_t(dist);
} catch(...) {
// probably not found
Expand Down Expand Up @@ -5621,14 +5656,19 @@ Option<bool> Index::do_phrase_search(const size_t num_search_fields, const std::
int64_t match_score = phrase_match_id_scores[seq_id];
int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, eval_filter_indexes, match_score, scores,
match_score_index, 0, collection_name);
match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -5795,15 +5835,20 @@ Option<bool> Index::do_infix_search(const size_t num_search_fields, const std::v

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values,
geopoint_indices, seq_id, references,
eval_filter_indexes, 100, scores, match_score_index,
0, collection_name);
should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
return compute_sort_scores_op;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down Expand Up @@ -6222,15 +6267,20 @@ Option<bool> Index::search_wildcard(filter_node_t const* const& filter_tree_root

int64_t scores[3] = {0};
int64_t match_score_index = -1;
bool should_skip = false;

auto compute_sort_scores_op = compute_sort_scores(sort_fields, sort_order, field_values, geopoint_indices,
seq_id, references, filter_indexes, 100, scores,
match_score_index, 0, collection_name);
match_score_index, should_skip, 0, collection_name);
if (!compute_sort_scores_op.ok()) {
compute_sort_score_status = new Option<bool>(compute_sort_scores_op.code(), compute_sort_scores_op.error());
break;
}

if(should_skip) {
continue;
}

uint64_t distinct_id = seq_id;
if(group_limit != 0) {
distinct_id = 1;
Expand Down
13 changes: 11 additions & 2 deletions src/vector_query_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,21 @@ Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_qu
}

if(param_kv[0] == "distance_threshold") {
if(!StringUtils::is_float(param_kv[1]) || std::stof(param_kv[1]) < 0.0 || std::stof(param_kv[1]) > 2.0) {
auto search_schema = const_cast<Collection*>(coll)->get_schema();
auto vector_field_it = search_schema.find(vector_query.field_name);

if(!StringUtils::is_float(param_kv[1])) {
return Option<bool>(400, "Malformed vector query string: "
"`distance_threshold` parameter must be a float.");
}

auto distance_threshold = std::stof(param_kv[1]);
if(vector_field_it->vec_dist == cosine && (distance_threshold < 0.0 || distance_threshold > 2.0)) {
return Option<bool>(400, "Malformed vector query string: "
"`distance_threshold` parameter must be a float between 0.0-2.0.");
}

vector_query.distance_threshold = std::stof(param_kv[1]);
vector_query.distance_threshold = distance_threshold;
}

if(param_kv[0] == "alpha") {
Expand Down
81 changes: 81 additions & 0 deletions test/collection_vector_search_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4945,4 +4945,85 @@ TEST_F(CollectionVectorTest, TestRestoringImages) {
coll = collectionManager.get_collection("test").get();

ASSERT_EQ(1, coll->get_summary_json()["num_documents"]);
}

// Exercises `distance_threshold` on a vector field configured with inner-product ("ip") distance.
//
// Part 1: sorts by `_vector_query(...)` with `distance_threshold:1` and expects that only the
//         documents whose computed distance does not exceed the threshold are returned.
// Part 2: repeats the search without a `distance_threshold` and expects all documents back,
//         including ones with negative and very large inner-product distances.
//
// The RNG is seeded with a fixed value, so the generated embeddings (and therefore the expected
// distances below) are deterministic for a given standard library implementation.
TEST_F(CollectionVectorTest, TestDistanceThresholdWithIP) {
    // NOTE: the original schema spelled the key as "optinal"; it is corrected to "optional" here
    // so that the option is actually applied. Every document below supplies an embedding, so the
    // expectations are not affected.
    auto schema_json = R"({
        "name": "products",
        "fields":[
            {"name": "name","type": "string"},
            {"name": "rank_score", "type": "float"},
            {"name": "embedding","type": "float[]", "num_dim":5, "optional":true, "vec_dist": "ip"}
        ],
        "default_sorting_field": "rank_score"
    })"_json;

    auto coll_op = collectionManager.create_collection(schema_json);
    ASSERT_TRUE(coll_op.ok());
    auto coll = coll_op.get();

    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib(-1,1);
    std::uniform_int_distribution<>distrib2(0,100);

    // Index 5 documents with random 5-dimensional embeddings and random integer rank scores.
    nlohmann::json doc;
    for (auto i = 0; i < 5; ++i) {
        std::vector<float> vector(5);
        std::generate(vector.begin(), vector.end(), [&](){ return distrib(rng); });

        doc["name"] = "document_" + std::to_string(i);
        doc["rank_score"] = distrib2(rng);
        doc["embedding"] = vector;
        ASSERT_TRUE(coll->add(doc.dump()).ok());
    }

    // Current wall-clock time in microseconds, as passed to `do_search`.
    const auto current_ts_micros = []() {
        return std::chrono::duration_cast<std::chrono::microseconds>(
                std::chrono::system_clock::now().time_since_epoch()).count();
    };

    // result ids exceeding distance_threshold will be skipped
    std::map<std::string, std::string> req_params = {
        {"collection", "products"},
        {"q", "document"},
        {"query_by", "*"},
        {"sort_by", "_text_match:desc,_vector_query(embedding:([0.11731103425347378, -0.6694758317235057, -0.6211945774857595, -0.27966758971688255, -0.4683744007950299],distance_threshold:1)):asc,rank_score:desc"},
        {"exclude_fields", "embedding"}
    };
    nlohmann::json embedded_params;
    std::string json_res;
    auto now_ts = current_ts_micros();
    auto search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
    // Fail fast with a clear signal instead of parsing a possibly empty/invalid response body.
    ASSERT_TRUE(search_op.ok());
    auto res = nlohmann::json::parse(json_res);

    ASSERT_EQ(2, res["found"].get<size_t>());
    ASSERT_EQ(93, res["hits"][0]["document"]["rank_score"].get<size_t>());
    // Distances are floats: compare with a ULP tolerance rather than exact bit equality.
    ASSERT_FLOAT_EQ(0.2189185470342636, res["hits"][0]["vector_distance"].get<float>());
    ASSERT_EQ(51, res["hits"][1]["document"]["rank_score"].get<size_t>());
    ASSERT_FLOAT_EQ(0.7371898889541626, res["hits"][1]["vector_distance"].get<float>());

    // inner product distances should work when distance_threshold is not given
    req_params = {
        {"collection", "products"},
        {"q", "document"},
        {"query_by", "*"},
        {"sort_by", "_text_match:desc,_vector_query(embedding:([-100,-100,-100,-100,-100])):asc,rank_score:desc"},
        {"exclude_fields", "embedding"}
    };
    now_ts = current_ts_micros();
    search_op = collectionManager.do_search(req_params, embedded_params, json_res, now_ts);
    ASSERT_TRUE(search_op.ok());
    res = nlohmann::json::parse(json_res);

    ASSERT_EQ(5, res["found"].get<size_t>());

    ASSERT_EQ("document_1", res["hits"][0]["document"]["name"]);
    ASSERT_FLOAT_EQ(-45.23314666748047, res["hits"][0]["vector_distance"].get<float>());
    ASSERT_EQ("document_2", res["hits"][1]["document"]["name"]);
    ASSERT_FLOAT_EQ(-38.66290283203125, res["hits"][1]["vector_distance"].get<float>());
    ASSERT_EQ("document_4", res["hits"][2]["document"]["name"]);
    ASSERT_FLOAT_EQ(-36.0988655090332, res["hits"][2]["vector_distance"].get<float>());
    ASSERT_EQ("document_3", res["hits"][3]["document"]["name"]);
    ASSERT_FLOAT_EQ(9.637892723083496, res["hits"][3]["vector_distance"].get<float>());
    ASSERT_EQ("document_0", res["hits"][4]["document"]["name"]);
    ASSERT_FLOAT_EQ(288.0364685058594, res["hits"][4]["vector_distance"].get<float>());
}
Loading