Skip to content

Commit

Permalink
Add Actix and config validation (qdrant#1463)
Browse files Browse the repository at this point in the history
* actix validation

* add check for memmap/indexing_threshold

* fix actix json settings

* Validate settings configuration on start, print pretty warnings on fail

* Add more validation rules for settings and nested types

* Move shared validation logic into collection/operations

* Show validation warning in log when loading some internal configs

* Show prettier actix JSON validation errors

* Stubs for pretty handling of query errors, reformat validation errors

* Use crate flatten function, the hard work was already done for us

We don't have to flatten validation errors into their qualified field
names ourselves because there is a utility function for this.

* Configure actix path validator

* Actix endpoints don't require public

* Extend validation to more actix types

* Validate all remaining actix path and query properties

* Rephrase range validation messages to clearly describe that they're inclusive

* Validate all query params to respond with pretty deserialize errors

* Nicely format JSON payload deserialize error responses

* Improve error reporting for upsert point batches

* Add basic validation test that checks a path, query and payload value

* Add some validation constraints

* Add simple validation error render test

* Update Cargo.lock

---------

Co-authored-by: timvisee <tim+github@visee.me>
  • Loading branch information
2 people authored and generall committed Apr 11, 2023
1 parent 2ddce55 commit f0f9229
Show file tree
Hide file tree
Showing 39 changed files with 945 additions and 403 deletions.
247 changes: 173 additions & 74 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ clap = { version = "4.2.1", features = ["derive"] }
serde_cbor = { version = "0.11.2"}
uuid = { version = "1.3", features = ["v4", "serde"] }
sys-info = "0.9.1"
actix-web-validator = "5.0.1"
validator = { version = "0.16", features = ["derive"] }

config = "~0.13.3"

Expand Down
2 changes: 2 additions & 0 deletions lib/collection/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ tonic = { version = "0.8.3", features = ["gzip", "tls"] }
tower = "0.4.13"
uuid = { version = "1.3", features = ["v4", "serde"] }
url = { version = "2", features = ["serde"] }
actix-web-validator = "5.0.1"
validator = { version = "0.16", features = ["derive"] }

segment = {path = "../segment"}
api = {path = "../api"}
Expand Down
7 changes: 5 additions & 2 deletions lib/collection/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use tar::Builder as TarBuilder;
use tokio::fs::{copy, create_dir_all, remove_dir_all, remove_file, rename};
use tokio::runtime::Handle;
use tokio::sync::{Mutex, RwLock, RwLockWriteGuard};
use validator::Validate;

use crate::collection_state::{ShardInfo, State};
use crate::common::is_ready::IsReady;
Expand All @@ -36,7 +37,7 @@ use crate::operations::types::{
CountResult, LocalShardInfo, NodeType, PointRequest, Record, RemoteShardInfo, ScrollRequest,
ScrollResult, SearchRequest, SearchRequestBatch, UpdateResult,
};
use crate::operations::{CollectionUpdateOperations, Validate};
use crate::operations::CollectionUpdateOperations;
use crate::optimizers_builder::OptimizersConfig;
use crate::shards::channel_service::ChannelService;
use crate::shards::collection_shard_distribution::CollectionShardDistribution;
Expand Down Expand Up @@ -248,9 +249,10 @@ impl Collection {
panic!(
"Can't read collection config due to {}\nat {}",
err,
path.to_str().unwrap()
path.to_str().unwrap(),
)
});
collection_config.validate_and_warn();

let ring = HashRing::fair(HASH_RING_SHARD_SCALE);
let mut shard_holder = ShardHolder::new(path, ring).expect("Can not create shard holder");
Expand Down Expand Up @@ -1464,6 +1466,7 @@ impl Collection {
ar.unpack(target_dir)?;

let config = CollectionConfig::load(target_dir)?;
config.validate_and_warn();
let configured_shards = config.params.shard_number.get();

for shard_id in 0..configured_shards {
Expand Down
4 changes: 3 additions & 1 deletion lib/collection/src/collection_state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
use validator::Validate;

use crate::collection::Collection;
use crate::config::CollectionConfig;
Expand All @@ -14,8 +15,9 @@ pub struct ShardInfo {
pub replicas: HashMap<PeerId, ReplicaState>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Validate, Clone, PartialEq)]
pub struct State {
#[validate]
pub config: CollectionConfig,
pub shards: HashMap<ShardId, ShardInfo>,
#[serde(default)]
Expand Down
19 changes: 16 additions & 3 deletions lib/collection/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@ use segment::common::anonymize::Anonymize;
use segment::data_types::vectors::DEFAULT_VECTOR_NAME;
use segment::types::{HnswConfig, QuantizationConfig, VectorDataConfig};
use serde::{Deserialize, Serialize};
use validator::Validate;
use wal::WalOptions;

use crate::operations::types::{CollectionError, CollectionResult, VectorParams, VectorsConfig};
use crate::operations::validation;
use crate::optimizers_builder::OptimizersConfig;

pub const COLLECTION_CONFIG_FILE: &str = "config.json";

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, PartialEq, Eq)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone, PartialEq, Eq)]
pub struct WalConfig {
/// Size of a single WAL segment in MB
#[validate(range(min = 1))]
pub wal_capacity_mb: usize,
/// Number of WAL segments to create ahead of actually used ones
pub wal_segments_ahead: usize,
Expand All @@ -44,7 +47,7 @@ impl Default for WalConfig {
}
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, PartialEq, Eq)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct CollectionParams {
/// Configuration of the vector storage
Expand Down Expand Up @@ -97,11 +100,15 @@ fn default_on_disk_payload() -> bool {
false
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, PartialEq)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone, PartialEq)]
pub struct CollectionConfig {
#[validate]
pub params: CollectionParams,
#[validate]
pub hnsw_config: HnswConfig,
#[validate]
pub optimizer_config: OptimizersConfig,
#[validate]
pub wal_config: WalConfig,
#[serde(default)]
pub quantization_config: Option<QuantizationConfig>,
Expand Down Expand Up @@ -131,6 +138,12 @@ impl CollectionConfig {
let config_path = path.join(COLLECTION_CONFIG_FILE);
config_path.exists()
}

/// Run all validation rules for this collection configuration and emit a
/// log warning for every rule that fails.
///
/// This never returns an error or aborts loading — problems are only
/// surfaced in the log so an existing collection can still be opened.
pub fn validate_and_warn(&self) {
    // A valid configuration stays silent; only failures are reported.
    match self.validate() {
        Ok(()) => (),
        Err(errs) => validation::warn_validation_errors("Collection configuration file", &errs),
    }
}
}

impl CollectionParams {
Expand Down
24 changes: 18 additions & 6 deletions lib/collection/src/operations/cluster_ops.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use validator::Validate;

use crate::shards::shard::{PeerId, ShardId};

Expand All @@ -17,41 +18,52 @@ pub enum ClusterOperations {
DropReplica(DropReplicaOperation),
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct MoveShardOperation {
pub move_shard: MoveShard,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct ReplicateShardOperation {
pub replicate_shard: MoveShard,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct DropReplicaOperation {
pub drop_replica: Replica,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct AbortTransferOperation {
pub abort_transfer: MoveShard,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct MoveShard {
pub shard_id: ShardId,
pub to_peer_id: PeerId,
pub from_peer_id: PeerId,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(rename_all = "snake_case")]
pub struct Replica {
pub shard_id: ShardId,
pub peer_id: PeerId,
}

/// Validation for the cluster-operation enum: each variant wraps exactly one
/// validatable payload, so validation simply forwards to the wrapped value.
impl Validate for ClusterOperations {
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        // Exhaustive on purpose — adding a variant must force a decision here.
        match self {
            Self::MoveShard(inner) => inner.validate(),
            Self::ReplicateShard(inner) => inner.validate(),
            Self::AbortTransfer(inner) => inner.validate(),
            Self::DropReplica(inner) => inner.validate(),
        }
    }
}
18 changes: 15 additions & 3 deletions lib/collection/src/operations/config_diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use segment::types::HnswConfig;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;

use crate::config::{CollectionParams, WalConfig};
use crate::operations::types::CollectionResult;
Expand Down Expand Up @@ -35,21 +36,27 @@ pub trait DiffConfig<T: DeserializeOwned + Serialize> {
}
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Copy, Clone, PartialEq, Eq, Merge, Hash)]
#[derive(
Debug, Deserialize, Serialize, JsonSchema, Validate, Copy, Clone, PartialEq, Eq, Merge, Hash,
)]
#[serde(rename_all = "snake_case")]
pub struct HnswConfigDiff {
/// Number of edges per node in the index graph. Larger the value - more accurate the search, more space required.
#[validate(range(min = 4, max = 10_000))]
pub m: Option<usize>,
/// Number of neighbours to consider during the index building. Larger the value - more accurate the search, more time required to build the index.
#[validate(range(min = 4))]
pub ef_construct: Option<usize>,
/// Minimal size (in KiloBytes) of vectors for additional payload-based indexing.
/// If payload chunk is smaller than `full_scan_threshold_kb` additional indexing won't be used -
/// in this case full-scan search should be preferred by query planner and additional indexing is not required.
/// Note: 1Kb = 1 vector of size 256
#[serde(alias = "full_scan_threshold_kb")]
#[validate(range(min = 1000))]
pub full_scan_threshold: Option<usize>,
/// Number of parallel threads used for background index building. If 0 - auto selection.
#[serde(default)]
#[validate(range(min = 1000))]
pub max_indexing_threads: Option<usize>,
/// Store HNSW index on disk. If set to false, the index will be stored in RAM. Default: false
#[serde(default)]
Expand All @@ -59,7 +66,9 @@ pub struct HnswConfigDiff {
pub payload_m: Option<usize>,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Merge, PartialEq, Eq, Hash)]
#[derive(
Debug, Deserialize, Serialize, JsonSchema, Validate, Clone, Merge, PartialEq, Eq, Hash,
)]
pub struct WalConfigDiff {
/// Size of a single WAL segment in MB
pub wal_capacity_mb: Option<usize>,
Expand All @@ -75,7 +84,7 @@ pub struct CollectionParamsDiff {
pub write_consistency_factor: Option<NonZeroU32>,
}

#[derive(Debug, Deserialize, Serialize, JsonSchema, Clone, Merge)]
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone, Merge)]
pub struct OptimizersConfigDiff {
/// The minimal fraction of deleted vectors in a segment, required to perform segment optimization
pub deleted_threshold: Option<f64>,
Expand Down Expand Up @@ -104,15 +113,18 @@ pub struct OptimizersConfigDiff {
/// To enable memmap storage, lower the threshold
/// Note: 1Kb = 1 vector of size 256
#[serde(alias = "memmap_threshold_kb")]
#[validate(range(min = 1000))]
pub memmap_threshold: Option<usize>,
/// Maximum size (in KiloBytes) of vectors allowed for plain index.
/// Default value based on <https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md>
/// Note: 1Kb = 1 vector of size 256
#[serde(alias = "indexing_threshold_kb")]
#[validate(range(min = 1000))]
pub indexing_threshold: Option<usize>,
/// Minimum interval between forced flushes.
pub flush_interval_sec: Option<u64>,
/// Maximum available threads for optimization workers
#[validate(range(min = 1))]
pub max_optimization_threads: Option<usize>,
}

Expand Down
21 changes: 21 additions & 0 deletions lib/collection/src/operations/consistency_params.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::borrow::Cow;

use api::grpc::qdrant::{
read_consistency, ReadConsistency as ReadConsistencyGrpc,
ReadConsistencyType as ReadConsistencyTypeGrpc,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError as ValidatorError, ValidationErrors};

/// Read consistency parameter
///
Expand All @@ -26,6 +29,24 @@ pub enum ReadConsistency {
Type(ReadConsistencyType),
}

/// Validation for read-consistency parameters.
///
/// The only invalid value is a numeric factor of zero; any positive factor
/// and every named consistency type are accepted as-is.
impl Validate for ReadConsistency {
    fn validate(&self) -> Result<(), ValidationErrors> {
        // A factor of zero would mean "read from no replicas" — reject it
        // with a range-style error so it renders like derived validations.
        if let ReadConsistency::Factor(factor) = self {
            if *factor == 0 {
                let mut error = ValidatorError::new("range");
                error.add_param(Cow::from("value"), factor);
                error.add_param(Cow::from("min"), &1);
                let mut errors = ValidationErrors::new();
                errors.add("factor", error);
                return Err(errors);
            }
        }
        Ok(())
    }
}

impl Default for ReadConsistency {
fn default() -> Self {
ReadConsistency::Factor(1)
Expand Down
22 changes: 13 additions & 9 deletions lib/collection/src/operations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ pub mod point_ops;
pub mod shared_storage_config;
pub mod snapshot_ops;
pub mod types;
pub mod validation;

use std::collections::HashMap;

use segment::types::{ExtendedPointId, PayloadFieldSchema};
use serde::{Deserialize, Serialize};
use validator::Validate;

use self::types::CollectionResult;
use crate::hash_ring::HashRing;
use crate::shards::shard::ShardId;

#[derive(Debug, Deserialize, Serialize, Default, Clone)]
#[derive(Debug, Deserialize, Serialize, Validate, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub struct CreateIndex {
pub field_name: String,
Expand Down Expand Up @@ -85,18 +86,21 @@ impl FieldIndexOperations {
}
}

/// Stateless validation of operation content.
/// Checks for `CollectionError::BadInput`
pub trait Validate {
fn validate(&self) -> CollectionResult<()>;
/// Validation for field-index operations.
///
/// Creating an index carries a user-supplied payload that must be checked;
/// deleting one only names an existing index, so it is always valid.
impl Validate for FieldIndexOperations {
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        if let FieldIndexOperations::CreateIndex(create_index) = self {
            create_index.validate()
        } else {
            // DeleteIndex has nothing to validate.
            Ok(())
        }
    }
}

impl Validate for CollectionUpdateOperations {
fn validate(&self) -> CollectionResult<()> {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
match self {
CollectionUpdateOperations::PointOperation(operation) => operation.validate(),
CollectionUpdateOperations::PayloadOperation(_) => Ok(()),
CollectionUpdateOperations::FieldIndexOperation(_) => Ok(()),
CollectionUpdateOperations::PayloadOperation(operation) => operation.validate(),
CollectionUpdateOperations::FieldIndexOperation(operation) => operation.validate(),
}
}
}
Expand Down
Loading

0 comments on commit f0f9229

Please sign in to comment.