Sparse idf dot #4126

Merged · 14 commits · Apr 29, 2024
Changes from 1 commit
[WIP] prepare idf stats for search query context
generall committed Apr 26, 2024
commit c9bc7d9d9507bd8d8f0a69ea9c11d9cded2910c6
1 change: 1 addition & 0 deletions lib/api/src/grpc/proto/collections.proto
@@ -47,6 +47,7 @@ message VectorsConfigDiff {
 
 message SparseVectorParams {
   optional SparseIndexConfig index = 1; // Configuration of sparse index
+  optional bool idf = 2; // If true, use Inverse Document Frequency for sparse vector scoring
 }
 
 message SparseVectorConfig {
3 changes: 3 additions & 0 deletions lib/api/src/grpc/qdrant.rs
@@ -112,6 +112,9 @@ pub struct SparseVectorParams {
     /// Configuration of sparse index
     #[prost(message, optional, tag = "1")]
     pub index: ::core::option::Option<SparseIndexConfig>,
+    /// If true, use Inverse Document Frequency for sparse vector scoring
+    #[prost(bool, optional, tag = "2")]
+    pub idf: ::core::option::Option<bool>,
 }
 #[derive(serde::Serialize)]
 #[allow(clippy::derive_partial_eq_without_eq)]
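For orientation, a minimal sketch of how the generated type can be populated once this lands; the construction is illustrative and assumes the caller already has the rest of the collection-creation request in place:

use api::grpc::qdrant::SparseVectorParams;

// Enable IDF-aware scoring for a sparse vector while leaving the
// index configuration at its defaults. `idf: None` keeps prior behavior.
let sparse_params = SparseVectorParams {
    index: None,
    idf: Some(true),
};
assert_eq!(sparse_params.idf, Some(true));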
64 changes: 53 additions & 11 deletions lib/collection/src/collection_manager/segments_searcher.rs
@@ -15,15 +15,18 @@ use segment::types::{
     Filter, Indexes, PointIdType, ScoredPoint, SearchParams, SegmentConfig, SeqNumberType,
     WithPayload, WithPayloadInterface, WithVector,
 };
+use tinyvec::TinyVec;
 use tokio::runtime::Handle;
 use tokio::task::JoinHandle;
 
 use super::holders::segment_holder::LockedSegmentHolder;
 use crate::collection_manager::holders::segment_holder::{LockedSegment, SegmentHolder};
 use crate::collection_manager::probabilistic_segment_search_sampling::find_search_sampling_over_point_distribution;
 use crate::collection_manager::search_result_aggregator::BatchResultAggregator;
+use crate::config::CollectionConfig;
 use crate::operations::query_enum::QueryEnum;
 use crate::operations::types::{CollectionResult, CoreSearchRequestBatch, Record};
+use crate::optimizers_builder::DEFAULT_INDEXING_THRESHOLD_KB;
 
 type BatchOffset = usize;
 type SegmentOffset = usize;
@@ -149,15 +152,46 @@ impl SegmentsSearcher {
         (result_aggregator, searches_to_rerun)
     }
 
-    pub async fn search(
+    pub async fn prepare_query_context(
         segments: LockedSegmentHolder,
-        batch_request: Arc<CoreSearchRequestBatch>,
-        runtime_handle: &Handle,
-        sampling_enabled: bool,
-        is_stopped: Arc<AtomicBool>,
-        mut query_context: QueryContext,
-    ) -> CollectionResult<Vec<Vec<ScoredPoint>>> {
-        // ToDo: accumulate IDF here
+        batch_request: &CoreSearchRequestBatch,
+        collection_config: &CollectionConfig,
+    ) -> CollectionResult<Option<QueryContext>> {
+        let indexing_threshold_kb = collection_config
+            .optimizer_config
+            .indexing_threshold
+            .unwrap_or(DEFAULT_INDEXING_THRESHOLD_KB);
+        let full_scan_threshold_kb = collection_config.hnsw_config.full_scan_threshold;
+
+        const DEFAULT_CAPACITY: usize = 3;
+        let mut idf_vectors: TinyVec<[&str; DEFAULT_CAPACITY]> = Default::default();
+
+        // Check that the requested vector names exist, and remember which
+        // sparse vectors have IDF scoring enabled.
+        for req in &batch_request.searches {
+            let vector_name = req.query.get_vector_name();
+            collection_config.params.get_distance(vector_name)?;
+            if let Some(sparse_vector_params) = collection_config
+                .params
+                .get_sparse_vector_params_opt(vector_name)
+            {
+                if sparse_vector_params.idf.unwrap_or_default()
+                    && !idf_vectors.contains(&vector_name)
+                {
+                    idf_vectors.push(vector_name);
+                }
+            }
+        }
+
+        let mut query_context =
+            QueryContext::new(indexing_threshold_kb.max(full_scan_threshold_kb));
+
+        // Seed IDF statistics only for the vectors that opted in above.
+        for search_request in &batch_request.searches {
+            search_request
+                .query
+                .iterate_sparse(|vector_name, sparse_vector| {
+                    if idf_vectors.contains(&vector_name) {
+                        query_context.init_idf(vector_name, &sparse_vector.indices);
+                    }
+                })
+        }
 
         // Do blocking calls in a blocking task: `segment.get().read()` calls might block async runtime
         let task = {
@@ -175,15 +209,23 @@ impl SegmentsSearcher {
                     let segment = locked_segment.get();
                     let segment_guard = segment.read();
                     query_context.add_available_point_count(segment_guard.available_point_count());
+                    // ToDo: update idf stats
                 }
                 Some(query_context)
             })
         };
 
-        let Some(query_context) = task.await? else {
-            return Ok(Vec::new());
-        };
+        Ok(task.await?)
+    }
+
+    pub async fn search(
+        segments: LockedSegmentHolder,
+        batch_request: Arc<CoreSearchRequestBatch>,
+        runtime_handle: &Handle,
+        sampling_enabled: bool,
+        is_stopped: Arc<AtomicBool>,
+        query_context: QueryContext,
+    ) -> CollectionResult<Vec<Vec<ScoredPoint>>> {
         let query_context_acr = Arc::new(query_context);
 
         // Using block to ensure `segments` variable is dropped in the end of it
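Both `ToDo` markers above point at the same gap: the per-segment loop already counts available points, but does not yet fold each segment's per-dimension document frequencies into `query_context`. A rough sketch of what that accumulation could look like, where `idf_stats_mut` and `sparse_posting_len` are hypothetical helpers, not APIs from this diff:

// Hypothetical accumulation pass, run per segment while the read guard
// is held: every dimension seeded by `init_idf` receives the number of
// points in this segment whose sparse vector contains that dimension.
for (vector_name, dim_counts) in query_context.idf_stats_mut() {
    for (dim_id, count) in dim_counts.iter_mut() {
        *count += segment_guard.sparse_posting_len(vector_name, *dim_id);
    }
}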
10 changes: 9 additions & 1 deletion lib/collection/src/config.rs
@@ -245,6 +245,12 @@ impl CollectionParams {
         })
     }
 
+    pub fn get_sparse_vector_params_opt(&self, vector_name: &str) -> Option<&SparseVectorParams> {
+        self.sparse_vectors
+            .as_ref()
+            .and_then(|sparse_vectors| sparse_vectors.get(vector_name))
+    }
+
     pub fn get_sparse_vector_params_mut(
         &mut self,
         vector_name: &str,
@@ -310,7 +316,9 @@ impl CollectionParams {
     ) -> CollectionResult<()> {
         for (vector_name, update_params) in update_vectors.0.iter() {
             let sparse_vector_params = self.get_sparse_vector_params_mut(vector_name)?;
-            let SparseVectorParams { index } = update_params.clone();
+            let SparseVectorParams { index, idf } = update_params.clone();
+
+            sparse_vector_params.idf = idf;
 
             if let Some(index) = index {
                 if let Some(existing_index) = &mut sparse_vector_params.index {
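A small usage sketch of the new accessor, assuming a `CollectionParams` value named `params` and a sparse vector called "text" (both names are illustrative):

// `get_sparse_vector_params_opt` returns None for dense vectors and for
// unknown names, so "no sparse params" and "idf unset" collapse to false.
let idf_enabled = params
    .get_sparse_vector_params_opt("text")
    .and_then(|sparse_params| sparse_params.idf)
    .unwrap_or(false);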
2 changes: 2 additions & 0 deletions lib/collection/src/operations/conversions.rs
@@ -601,6 +601,7 @@ impl From<api::grpc::qdrant::SparseVectorParams> for SparseVectorParams {
                 full_scan_threshold: index_config.full_scan_threshold.map(|v| v as usize),
                 on_disk: index_config.on_disk,
             }),
+            idf: sparse_vector_params.idf,
         }
     }
 }
@@ -614,6 +615,7 @@ impl From<SparseVectorParams> for api::grpc::qdrant::SparseVectorParams {
                     on_disk: index_config.on_disk,
                 }
             }),
+            idf: sparse_vector_params.idf,
         }
     }
 }
6 changes: 6 additions & 0 deletions lib/collection/src/operations/types.rs
@@ -1334,12 +1334,18 @@ pub struct SparseVectorParams {
     /// Custom params for index. If none - values from collection configuration are used.
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub index: Option<SparseIndexParams>,
+
+    /// If true, include inverse document frequency in the scoring of sparse vectors.
+    /// Default: false
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub idf: Option<bool>,
 }
 
 impl Anonymize for SparseVectorParams {
     fn anonymize(&self) -> Self {
         Self {
             index: self.index.anonymize(),
+            idf: self.idf,
         }
     }
 }
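Because the field is optional and skipped when unset, existing collection configs round-trip unchanged. A sketch of the JSON shape, assuming `SparseVectorParams` also derives `Deserialize`:

// "idf" omitted => None (current behavior); present => Some(true/false).
let params: SparseVectorParams = serde_json::from_str(r#"{ "idf": true }"#)?;
assert_eq!(params.idf, Some(true));
assert!(params.index.is_none());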
31 changes: 14 additions & 17 deletions lib/collection/src/shards/local_shard/shard_ops.rs
@@ -5,7 +5,6 @@ use async_trait::async_trait;
 use futures::future::try_join_all;
 use itertools::Itertools;
 use segment::data_types::order_by::{Direction, OrderBy};
-use segment::data_types::query_context::QueryContext;
 use segment::types::{
     ExtendedPointId, Filter, ScoredPoint, WithPayload, WithPayloadInterface, WithVector,
 };
@@ -21,7 +20,6 @@ use crate::operations::types::{
     CountRequestInternal, CountResult, PointRequestInternal, Record, UpdateResult, UpdateStatus,
 };
 use crate::operations::OperationWithClockTag;
-use crate::optimizers_builder::DEFAULT_INDEXING_THRESHOLD_KB;
 use crate::shards::local_shard::LocalShard;
 use crate::shards::shard_trait::ShardOperation;
 use crate::update_handler::{OperationData, UpdateSignal};
@@ -33,26 +31,25 @@ impl LocalShard {
         search_runtime_handle: &Handle,
         timeout: Option<Duration>,
     ) -> CollectionResult<Vec<Vec<ScoredPoint>>> {
-        let (collection_params, indexing_threshold_kb, full_scan_threshold_kb) = {
+        let (query_context, collection_params) = {
             let collection_config = self.collection_config.read().await;
-            (
-                collection_config.params.clone(),
-                collection_config
-                    .optimizer_config
-                    .indexing_threshold
-                    .unwrap_or(DEFAULT_INDEXING_THRESHOLD_KB),
-                collection_config.hnsw_config.full_scan_threshold,
+
+            let query_context_opt = SegmentsSearcher::prepare_query_context(
+                self.segments.clone(),
+                &core_request,
+                &collection_config,
             )
-        };
+            .await?;
 
-        // check vector names existing
-        for req in &core_request.searches {
-            collection_params.get_distance(req.query.get_vector_name())?;
-        }
+            let Some(query_context) = query_context_opt else {
+                // No segments to search
+                return Ok(vec![]);
+            };
 
-        let is_stopped = StoppingGuard::new();
+            (query_context, collection_config.params.clone())
+        };
 
-        let query_context = QueryContext::new(indexing_threshold_kb.max(full_scan_threshold_kb));
+        let is_stopped = StoppingGuard::new();
 
         let search_request = SegmentsSearcher::search(
             Arc::clone(&self.segments),
14 changes: 14 additions & 0 deletions lib/segment/src/data_types/query_context.rs
@@ -38,4 +38,18 @@ impl QueryContext {
     pub fn add_available_point_count(&mut self, count: usize) {
         self.available_point_count += count;
     }
+
+    pub fn init_idf(&mut self, vector_name: &str, indices: &[DimId]) {
+        // ToDo: Would be nice to have an implementation of `entry` for `TinyMap`.
+        let idf = if let Some(idf) = self.idf.get_mut(vector_name) {
+            idf
+        } else {
+            self.idf.insert(vector_name.to_string(), HashMap::default());
+            self.idf.get_mut(vector_name).unwrap()
+        };
+
+        for index in indices {
+            idf.insert(*index, 0);
+        }
+    }
 }
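`init_idf` only registers the dimensions a query actually touches, seeding each with a zero document frequency; the per-segment pass is expected to fill in real counts later. The eventual weighting is not part of this commit, but a BM25-style IDF over the accumulated stats is the natural candidate — an assumption, sketched here:

// Assumed final step (not in this diff): turn an accumulated document
// frequency and the total available point count into a BM25-style IDF.
fn idf_weight(total_points: usize, doc_freq: usize) -> f32 {
    let n = total_points as f32;
    let df = doc_freq as f32;
    ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
}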