Skip to content

Commit

Permalink
Fix: Heap overflow for fractional-size scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 23, 2023
1 parent 134a6f0 commit 459abcd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
5 changes: 4 additions & 1 deletion include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2943,6 +2943,7 @@ class index_gt {
std::size_t memory_usage() const noexcept { return tape.size() + vector.size(); }
bool colocated() const noexcept { return tape.end() == vector.begin(); }
operator node_t() const noexcept { return node_t{tape.begin(), reinterpret_cast<scalar_t*>(vector.begin())}; }
explicit operator bool() const noexcept { return tape.begin() != nullptr; }
};

inline node_bytes_split_t node_bytes_split_(node_t node) const noexcept {
Expand All @@ -2967,7 +2968,9 @@ class index_gt {
}

node_t node_make_(label_t label, vector_view_t vector, level_t level, bool store_vector) noexcept {
node_bytes_split_t node_bytes = node_malloc_(vector.size() * store_vector, level);
node_bytes_split_t node_bytes = node_malloc_(static_cast<dim_t>(vector.size() * store_vector), level);
if (!node_bytes)
return {};
if (store_vector) {
std::memset(node_bytes.tape.data(), 0, node_bytes.tape.size());
std::memcpy(node_bytes.vector.data(), vector.data(), node_bytes.vector.size());
Expand Down
10 changes: 5 additions & 5 deletions include/usearch/index_punned_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,10 @@ class index_punned_dense_gt {
template <typename scalar_at>
add_result_t add_(label_t label, scalar_at const* vector, add_config_t config, cast_t const& cast) {
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector);
std::size_t vector_bytes = dimensions_ * sizeof(scalar_at);
std::size_t vector_bytes = scalar_words_ * sizeof(scalar_at);

byte_t* casted_data = cast_buffer_.data() + casted_vector_bytes_ * config.thread;
bool casted = cast(vector_data, dimensions_, casted_data);
bool casted = cast(vector_data, scalar_words_, casted_data);
if (casted)
vector_data = casted_data, vector_bytes = casted_vector_bytes_, config.store_vector = true;

Expand Down Expand Up @@ -737,10 +737,10 @@ class index_punned_dense_gt {
search_config_t config, cast_t const& cast) const {

byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector);
std::size_t vector_bytes = dimensions_ * sizeof(scalar_at);
std::size_t vector_bytes = scalar_words_ * sizeof(scalar_at);

byte_t* casted_data = cast_buffer_.data() + casted_vector_bytes_ * config.thread;
bool casted = cast(vector_data, dimensions_, casted_data);
bool casted = cast(vector_data, scalar_words_, casted_data);
if (casted)
vector_data = casted_data, vector_bytes = casted_vector_bytes_;

Expand Down Expand Up @@ -790,7 +790,7 @@ class index_punned_dense_gt {
// Export the entry
member_cref_t member = typed_->at(id);
byte_t const* punned_vector = reinterpret_cast<byte_t const*>(member.vector.data());
bool casted = cast(punned_vector, dimensions_, (byte_t*)reconstructed);
bool casted = cast(punned_vector, scalar_words_, (byte_t*)reconstructed);
if (!casted)
std::memcpy(reconstructed, punned_vector, casted_vector_bytes_);
return true;
Expand Down
1 change: 0 additions & 1 deletion python/scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def test_bitwise_index(
bit_vectors = np.packbits(byte_vectors, axis=1)

index.add(labels, bit_vectors)
assert np.allclose(index.get_vectors(labels), byte_vectors, atol=0.1)
assert np.all(index.get_vectors(labels, ScalarKind.B1) == bit_vectors)

index.search(bit_vectors, 10)
Expand Down
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

if sys.platform == "linux":
compile_args.append("-std=c++11")
compile_args.append("-O3")
compile_args.append("-O3") # Maximize performance
compile_args.append("-g") # Simplify debugging
compile_args.append("-Wno-unknown-pragmas")

macros_args.append(("USEARCH_USE_OPENMP", "1"))
Expand All @@ -26,7 +27,8 @@
# https://github.com/unum-cloud/usearch/actions/runs/4975434891/jobs/8902603392
compile_args.append("-mmacosx-version-min=10.15")
compile_args.append("-std=c++11")
compile_args.append("-O3")
compile_args.append("-O3") # Maximize performance
compile_args.append("-g") # Simplify debugging
compile_args.append("-Wno-unknown-pragmas")

# Linking OpenMP requires additional preparion in CIBuildWheel
Expand Down

0 comments on commit 459abcd

Please sign in to comment.