Skip to content

Commit

Permalink
milvus-io#1653 IndexFlat performance improvement for NQ < thread_number (milvus-io#1674)
Browse files Browse the repository at this point in the history

* Optimize index flat L2/IP for SSE

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* parallel optimization

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* fix threshold

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* add changelog

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

* add changelog

Signed-off-by: sahuang <xiaohai.xu@zilliz.com>

Co-authored-by: sahuang <xiaohai.xu@zilliz.com>
  • Loading branch information
Xiaohai Xu and sahuang authored Mar 16, 2020
1 parent 3de34d3 commit 59dab6c
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 45 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ Please mark all change in change log and use the issue from GitHub
- \#1546 Move Config.cpp to config directory
- \#1547 Rename storage/file to storage/disk and rename classes
- \#1548 Move store/Directory to storage/Operation and add FSHandler
- \#1649 Fix Milvus crash on old CPU
- \#1619 Improve compact performance
- \#1649 Fix Milvus crash on old CPU
- \#1653 IndexFlat performance improvement for NQ < thread_number

## Task

Expand Down
2 changes: 1 addition & 1 deletion core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace faiss {
if (init_heap) ha->heapify ();

int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
if (ha->nh < 4) {
// omp for n2
int all_hash_size = thread_max_num * k;
float *value = new float[all_hash_size];
Expand Down
177 changes: 136 additions & 41 deletions core/src/index/thirdparty/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,39 +152,84 @@ static void knn_inner_product_sse (const float * x,
ConcurrentBitsetPtr bitset = nullptr)
{
size_t k = res->k;
size_t check_period = InterruptCallback::get_period_hint (ny * d);

check_period *= omp_get_max_threads();

for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);
size_t thread_max_num = omp_get_max_threads();
if (nx < 4) {
// few queries (nx < 4): parallelize the scan over the ny database vectors instead of over queries
size_t all_hash_size = thread_max_num * k;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];

for (size_t i = 0; i < nx; i++) {
// reset every per-thread heap value to -infinity before scanning y
for (size_t i = 0; i < all_hash_size; i++) {
value[i] = -1.0 / 0.0;
}
const float *x_i = x + i * d;
#pragma omp parallel for
for (size_t i = i0; i < i1; i++) {
const float * x_i = x + i * d;
const float * y_j = y;
for (size_t j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)) {
const float *y_j = y + j * d;
float ip = fvec_inner_product (x_i, y_j, d);

size_t thread_no = omp_get_thread_num();
float * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (ip > val_[0]) {
minheap_pop (k, val_, ids_);
minheap_push (k, val_, ids_, ip, j);
}
}
}

// merge the per-thread heaps into the result heap of query i
float * __restrict simi = res->get_val(i);
int64_t * __restrict idxi = res->get_ids (i);

minheap_heapify (k, simi, idxi);
for (size_t i = 0; i < all_hash_size; i++) {
if (value[i] > simi[0]) {
minheap_pop (k, simi, idxi);
minheap_push (k, simi, idxi, value[i], labels[i]);
}
}
minheap_reorder (k, simi, idxi);
}
delete[] value;
delete[] labels;

for (size_t j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)){
float ip = fvec_inner_product (x_i, y_j, d);
} else {
size_t check_period = InterruptCallback::get_period_hint (ny * d);
check_period *= thread_max_num;

for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);

#pragma omp parallel for
for (size_t i = i0; i < i1; i++) {
const float * x_i = x + i * d;
const float * y_j = y;

float * __restrict simi = res->get_val(i);
int64_t * __restrict idxi = res->get_ids (i);

if (ip > simi[0]) {
minheap_pop (k, simi, idxi);
minheap_push (k, simi, idxi, ip, j);
minheap_heapify (k, simi, idxi);

for (size_t j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)){
float ip = fvec_inner_product (x_i, y_j, d);

if (ip > simi[0]) {
minheap_pop (k, simi, idxi);
minheap_push (k, simi, idxi, ip, j);
}
}
y_j += d;
}
y_j += d;
minheap_reorder (k, simi, idxi);
}
minheap_reorder (k, simi, idxi);
InterruptCallback::check ();
}
InterruptCallback::check ();
}

}

static void knn_L2sqr_sse (
Expand All @@ -196,37 +241,87 @@ static void knn_L2sqr_sse (
{
size_t k = res->k;

size_t check_period = InterruptCallback::get_period_hint (ny * d);
check_period *= omp_get_max_threads();

for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);
size_t thread_max_num = omp_get_max_threads();
if (nx < 4) {
// few queries (nx < 4): parallelize the scan over the ny database vectors instead of over queries
size_t all_hash_size = thread_max_num * k;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];

for (size_t i = 0; i < nx; i++) {
// reset the per-thread heap buffers (values to +infinity) before scanning y
for (size_t i = 0; i < all_hash_size; i++) {
value[i] = 1.0 / 0.0;
}
for (size_t i = 0; i < k; i++) {
labels[i] = -1;
}
const float *x_i = x + i * d;
#pragma omp parallel for
for (size_t i = i0; i < i1; i++) {
const float * x_i = x + i * d;
const float * y_j = y;
size_t j;
float * simi = res->get_val(i);
int64_t * idxi = res->get_ids (i);

maxheap_heapify (k, simi, idxi);
for (j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)){
for (size_t j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)) {
const float *y_j = y + j * d;
float disij = fvec_L2sqr (x_i, y_j, d);

if (disij < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, disij, j);
size_t thread_no = omp_get_thread_num();
float * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (disij < val_[0]) {
maxheap_pop (k, val_, ids_);
maxheap_push (k, val_, ids_, disij, j);
}
}
y_j += d;
}

// merge the per-thread heaps into the result heap of query i
float * __restrict simi = res->get_val(i);
int64_t * __restrict idxi = res->get_ids (i);
memcpy(simi, value, k * sizeof(float));
memcpy(idxi, labels, k * sizeof(int64_t));
maxheap_heapify (k, simi, idxi, value, labels, k);
for (size_t i = k; i < all_hash_size; i++) {
if (value[i] < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, value[i], labels[i]);
}
}
maxheap_reorder (k, simi, idxi);
}
InterruptCallback::check ();
}
delete[] value;
delete[] labels;

} else {
size_t check_period = InterruptCallback::get_period_hint (ny * d);
check_period *= thread_max_num;

for (size_t i0 = 0; i0 < nx; i0 += check_period) {
size_t i1 = std::min(i0 + check_period, nx);

#pragma omp parallel for
for (size_t i = i0; i < i1; i++) {
const float * x_i = x + i * d;
const float * y_j = y;
float * simi = res->get_val(i);
int64_t * idxi = res->get_ids (i);

maxheap_heapify (k, simi, idxi);

for (size_t j = 0; j < ny; j++) {
if(!bitset || !bitset->test(j)){
float disij = fvec_L2sqr (x_i, y_j, d);

if (disij < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, disij, j);
}
}
y_j += d;
}
maxheap_reorder (k, simi, idxi);
}
InterruptCallback::check ();
}
}
}


Expand Down Expand Up @@ -899,4 +994,4 @@ void pairwise_L2sqr (int64_t d,
}


} // namespace faiss
} // namespace faiss
4 changes: 2 additions & 2 deletions core/src/index/thirdparty/faiss/utils/hamming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ void hammings_knn_hc (
if (init_heap) ha->heapify ();

int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
if (ha->nh < 4) {
// omp for n2
int all_hash_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_hash_size];
Expand Down Expand Up @@ -432,7 +432,7 @@ void hammings_knn_hc_1 (
}

int thread_max_num = omp_get_max_threads();
if (ha->nh < thread_max_num) {
if (ha->nh < 4) {
// omp for n2
int all_hash_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_hash_size];
Expand Down

0 comments on commit 59dab6c

Please sign in to comment.