Skip to content

Commit

Permalink
VisitedList: Vec<usize> -> Vec<u8> (qdrant#5336)
Browse files Browse the repository at this point in the history
* Do not use VisitedListHandle::count_visits_since()

* VisitedList: use u8 instead of usize

* Test VisitedList
  • Loading branch information
xzfc authored and timvisee committed Nov 8, 2024
1 parent 1128ac8 commit 300c452
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 26 deletions.
16 changes: 14 additions & 2 deletions lib/segment/src/index/hnsw_index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::thread;

use atomic_refcell::AtomicRefCell;
use bitvec::prelude::BitSlice;
use bitvec::vec::BitVec;
use common::counter::hardware_counter::HardwareCounterCell;
#[cfg(target_os = "linux")]
use common::cpu::linux_low_thread_priority;
Expand Down Expand Up @@ -335,7 +336,6 @@ impl<TGraphLinks: GraphLinks> HNSWIndex<TGraphLinks> {

let visited_pool = VisitedPool::new();
let mut block_filter_list = visited_pool.get(total_vector_count);
let visits_iteration = block_filter_list.get_current_iteration_id();

let payload_m = config.payload_m.unwrap_or(config.m);

Expand All @@ -346,6 +346,13 @@ impl<TGraphLinks: GraphLinks> HNSWIndex<TGraphLinks> {
graph_layers_builder.get_average_connectivity_on_level(0);
let average_links_per_0_level_int = (average_links_per_0_level as usize).max(1);

let mut indexed_vectors_set = if config.m != 0 {
// Every vector is already indexed in the main graph, so skip counting.
BitVec::new()
} else {
BitVec::repeat(false, total_vector_count)
};

for (field, _) in payload_index.indexed_fields() {
debug!("building additional index for field {}", &field);

Expand Down Expand Up @@ -385,12 +392,13 @@ impl<TGraphLinks: GraphLinks> HNSWIndex<TGraphLinks> {
&mut additional_graph,
payload_block.condition,
&mut block_filter_list,
&mut indexed_vectors_set,
)?;
graph_layers_builder.merge_from_other(additional_graph);
}
}

let indexed_payload_vectors = block_filter_list.count_visits_since(visits_iteration);
let indexed_payload_vectors = indexed_vectors_set.count_ones();

debug_assert!(indexed_vectors >= indexed_payload_vectors || config.m == 0);
indexed_vectors = indexed_vectors.max(indexed_payload_vectors);
Expand Down Expand Up @@ -429,6 +437,7 @@ impl<TGraphLinks: GraphLinks> HNSWIndex<TGraphLinks> {
graph_layers_builder: &mut GraphLayersBuilder,
condition: FieldCondition,
block_filter_list: &mut VisitedListHandle,
indexed_vectors_set: &mut BitVec,
) -> OperationResult<()> {
block_filter_list.next_iteration();

Expand All @@ -450,6 +459,9 @@ impl<TGraphLinks: GraphLinks> HNSWIndex<TGraphLinks> {

for block_point_id in points_to_index.iter().copied() {
block_filter_list.check_and_update_visited(block_point_id);
if !indexed_vectors_set.is_empty() {
indexed_vectors_set.set(block_point_id as usize, true);
}
}

let insert_points = |block_point_id| {
Expand Down
63 changes: 39 additions & 24 deletions lib/segment/src/index/visited_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pub struct VisitedListHandle<'a> {
/// and reuse same counter for multiple queries.
#[derive(Debug)]
struct VisitedList {
// NOTE(review): each field appears twice because this is a rendered diff without -/+ markers.
// Per the commit title ("VisitedList: Vec<usize> -> Vec<u8>") the two `usize` declarations
// below are presumably the removed ones and the `u8` declarations their replacements.
current_iter: usize,
visit_counters: Vec<usize>,
/// Identifier of the iteration in progress; it is the value stamped into a point's counter
/// when that point is marked.
current_iter: u8,
/// One counter per point offset, holding the iteration identifier the point was last marked with.
visit_counters: Vec<u8>,
}

impl Default for VisitedList {
Expand Down Expand Up @@ -54,41 +54,31 @@ impl<'a> VisitedListHandle<'a> {
}
}

/// Returns the `current_iter` value of the underlying visited list.
///
/// NOTE(review): this getter has the pre-change `usize` return type and looks like one of the
/// lines this commit deletes — confirm against the real diff before relying on it.
pub fn get_current_iteration_id(&self) -> usize {
self.visited_list.current_iter
}

/// Count how many points were visited since the given iteration.
///
/// Scans every stored counter and counts those that are greater than or equal to
/// `iteration_id`.
///
/// NOTE(review): the `>=` test only identifies "marked at or after `iteration_id`" while the
/// iteration counter never wraps around; the commit message says callers should stop using this
/// method, and it looks like one of the deleted lines — confirm against the real diff.
pub fn count_visits_since(&self, iteration_id: usize) -> usize {
self.visited_list
.visit_counters
.iter()
// Keep only the counters stamped with `iteration_id` or a larger value.
.filter(|x| **x >= iteration_id)
.count()
}

/// Return `true` if visited
///
/// Read-only lookup: a `point_id` beyond the end of the counter list is reported as not visited
/// instead of panicking.
pub fn check(&self, point_id: PointOffsetType) -> bool {
self.visited_list
.visit_counters
// `get` yields `None` for an out-of-range offset; `map_or` below turns that into `false`.
.get(point_id as usize)
// NOTE(review): two alternative comparisons follow because the diff rendering dropped the
// -/+ markers. Presumably the `>=` line is the removed one and the `==` line its replacement
// (a point counts as visited only when its counter equals the current iteration) — confirm.
.map_or(false, |x| *x >= self.visited_list.current_iter)
.map_or(false, |x| *x == self.visited_list.current_iter)
}

/// Updates visited list
/// return `true` if point was visited before
///
/// NOTE(review): the body below contains two implementations back to back because the diff
/// rendering dropped the -/+ markers. Presumably the first one (resize + `>=`) is the removed
/// code and the trailing `std::mem::replace` expression is its replacement — confirm against
/// the real diff.
pub fn check_and_update_visited(&mut self, point_id: PointOffsetType) -> bool {
// First variant: grow the list with zeroed counters so that `idx` is always in range.
let idx = point_id as usize;
if idx >= self.visited_list.visit_counters.len() {
self.visited_list.visit_counters.resize(idx + 1, 0);
}
// Remember the old counter, stamp the point with the current iteration, and report it as
// previously visited when the old counter is at least the current iteration.
let prev_value = self.visited_list.visit_counters[idx];
self.visited_list.visit_counters[idx] = self.visited_list.current_iter;
prev_value >= self.visited_list.current_iter
// Second variant: swap the current iteration into the counter in one step and report the
// point as previously visited when the old counter equals the current iteration. It indexes
// the list directly, so an out-of-range `point_id` panics here instead of growing the list.
std::mem::replace(
&mut self.visited_list.visit_counters[point_id as usize],
self.visited_list.current_iter,
) == self.visited_list.current_iter
}

/// Starts a new iteration, which un-marks every point at once.
///
/// A point counts as visited only while its stored counter equals `current_iter`, so moving
/// to a fresh identifier invalidates all existing marks without touching the counters.
///
/// The identifier is a `u8`, so it wraps around after 255. Stale counters left over from the
/// previous cycle could then collide with a reused identifier and make untouched points look
/// visited. To prevent that, on wrap-around every counter is reset to `0` and the identifier
/// restarts at `1`; `0` is therefore never a live identifier after a wrap and a zeroed counter
/// always reads as "not visited".
pub fn next_iteration(&mut self) {
    // Advance exactly once; `wrapping_add` avoids the overflow panic of `+= 1` in debug builds.
    self.visited_list.current_iter = self.visited_list.current_iter.wrapping_add(1);
    if self.visited_list.current_iter == 0 {
        // Wrapped around: drop every mark from the previous cycle and skip identifier 0,
        // which is reserved for "never visited".
        self.visited_list.current_iter = 1;
        self.visited_list.visit_counters.fill(0);
    }
}
}

Expand Down Expand Up @@ -134,3 +124,28 @@ impl Default for VisitedPool {
VisitedPool::new()
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises marking, re-marking and iteration roll-over for a single point.
    #[test]
    fn test_visited_list() {
        let pool = VisitedPool::new();
        let mut handle = pool.get(10);

        // Two rounds: the second one starts after the iteration counter has been advanced
        // 260 times, i.e. past the 256 values a `u8` can hold.
        for _round in 0..2 {
            // Nothing is marked when the round begins.
            let seen_before_marking = handle.check(0);
            assert!(!seen_before_marking);

            // The first update reports "not visited before" and leaves the point marked.
            let first_update = handle.check_and_update_visited(0);
            assert!(!first_update);
            assert!(handle.check(0));

            // A repeated update reports the earlier mark and keeps it in place.
            let second_update = handle.check_and_update_visited(0);
            assert!(second_update);
            assert!(handle.check(0));

            // The mark must never resurface in any later iteration, including after the
            // counter has wrapped around.
            let mut remaining_steps = 260;
            while remaining_steps > 0 {
                handle.next_iteration();
                assert!(!handle.check(0));
                remaining_steps -= 1;
            }
        }
    }
}

0 comments on commit 300c452

Please sign in to comment.