Skip to content

Commit

Permalink
Add PQXDH tests
Browse files Browse the repository at this point in the history
  • Loading branch information
moiseev-signal authored May 23, 2023
1 parent b0a1bf2 commit 28e112b
Show file tree
Hide file tree
Showing 11 changed files with 2,217 additions and 1,704 deletions.
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/protocol/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ armv8 = ["aes/armv8", "aes-gcm-siv/armv8"]
criterion = "0.4"
proptest = "1.0"
futures-util = "0.3.7"
env_logger = "0.8.1"

[build-dependencies]
prost-build = "0.9"
Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ pub use session_cipher::{
message_decrypt, message_decrypt_prekey, message_decrypt_signal, message_encrypt,
};
pub use state::{
GenericSignedPreKey, KyberPreKeyId, KyberPreKeyRecord, PreKeyBundle, PreKeyId, PreKeyRecord,
SessionRecord, SignedPreKeyId, SignedPreKeyRecord,
GenericSignedPreKey, KyberPreKeyId, KyberPreKeyRecord, PreKeyBundle, PreKeyBundleContent,
PreKeyId, PreKeyRecord, SessionRecord, SignedPreKeyId, SignedPreKeyRecord,
};
pub use storage::{
Context, Direction, IdentityKeyStore, InMemIdentityKeyStore, InMemKyberPreKeyStore,
Expand Down
2 changes: 1 addition & 1 deletion rust/protocol/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod prekey;
mod session;
mod signed_prekey;

pub use bundle::PreKeyBundle;
pub use bundle::{PreKeyBundle, PreKeyBundleContent};
pub use kyber_prekey::{KyberPreKeyId, KyberPreKeyRecord};
pub use prekey::{PreKeyId, PreKeyRecord};
pub use session::SessionRecord;
Expand Down
101 changes: 100 additions & 1 deletion rust/protocol/src/state/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
//

use crate::state::{PreKeyId, SignedPreKeyId};
use crate::{kem, DeviceId, IdentityKey, KyberPreKeyId, PublicKey, Result};
use crate::{kem, DeviceId, IdentityKey, KyberPreKeyId, PublicKey, Result, SignalProtocolError};
use std::clone::Clone;
use std::convert::{TryFrom, TryInto};

#[derive(Clone)]
struct SignedPreKey {
Expand Down Expand Up @@ -41,6 +42,95 @@ impl KyberPreKey {
}
}

// Represents the raw contents of the pre-key bundle without any notion of required/optional
// fields.
// Can be used as a "builder" for PreKeyBundle, in which case all the validation will happen in
// PreKeyBundle::new.
pub struct PreKeyBundleContent {
pub registration_id: Option<u32>,
pub device_id: Option<DeviceId>,
pub pre_key_id: Option<PreKeyId>,
pub pre_key_public: Option<PublicKey>,
pub ec_pre_key_id: Option<SignedPreKeyId>,
pub ec_pre_key_public: Option<PublicKey>,
pub ec_pre_key_signature: Option<Vec<u8>>,
pub identity_key: Option<IdentityKey>,
pub kyber_pre_key_id: Option<KyberPreKeyId>,
pub kyber_pre_key_public: Option<kem::PublicKey>,
pub kyber_pre_key_signature: Option<Vec<u8>>,
}

impl From<PreKeyBundle> for PreKeyBundleContent {
fn from(bundle: PreKeyBundle) -> Self {
Self {
registration_id: Some(bundle.registration_id),
device_id: Some(bundle.device_id),
pre_key_id: bundle.pre_key_id,
pre_key_public: bundle.pre_key_public,
ec_pre_key_id: Some(bundle.ec_signed_pre_key.id),
ec_pre_key_public: Some(bundle.ec_signed_pre_key.public_key),
ec_pre_key_signature: Some(bundle.ec_signed_pre_key.signature),
identity_key: Some(bundle.identity_key),
kyber_pre_key_id: bundle.kyber_pre_key.as_ref().map(|kyber| kyber.id),
kyber_pre_key_public: bundle
.kyber_pre_key
.as_ref()
.map(|kyber| kyber.public_key.clone()),
kyber_pre_key_signature: bundle
.kyber_pre_key
.as_ref()
.map(|kyber| kyber.signature.clone()),
}
}
}

impl TryFrom<PreKeyBundleContent> for PreKeyBundle {
type Error = SignalProtocolError;

fn try_from(content: PreKeyBundleContent) -> Result<Self> {
let mut bundle = PreKeyBundle::new(
content.registration_id.ok_or_else(|| {
SignalProtocolError::InvalidArgument("registration_id is required".to_string())
})?,
content.device_id.ok_or_else(|| {
SignalProtocolError::InvalidArgument("device_id is required".to_string())
})?,
content
.pre_key_id
.and_then(|id| content.pre_key_public.map(|public| (id, public))),
content.ec_pre_key_id.ok_or_else(|| {
SignalProtocolError::InvalidArgument("signed_pre_key_id is required".to_string())
})?,
content.ec_pre_key_public.ok_or_else(|| {
SignalProtocolError::InvalidArgument(
"signed_pre_key_public is required".to_string(),
)
})?,
content.ec_pre_key_signature.ok_or_else(|| {
SignalProtocolError::InvalidArgument(
"signed_pre_key_signature is required".to_string(),
)
})?,
content.identity_key.ok_or_else(|| {
SignalProtocolError::InvalidArgument("identity_key is required".to_string())
})?,
)?;

fn zip3<T, U, V>(x: Option<T>, y: Option<U>, z: Option<V>) -> Option<(T, U, V)> {
x.zip(y).zip(z).map(|((x, y), z)| (x, y, z))
}

if let Some((kyber_id, kyber_public, kyber_sig)) = zip3(
content.kyber_pre_key_id,
content.kyber_pre_key_public,
content.kyber_pre_key_signature,
) {
bundle = bundle.with_kyber_pre_key(kyber_id, kyber_public, kyber_sig);
}
Ok(bundle)
}
}

#[derive(Clone)]
pub struct PreKeyBundle {
registration_id: u32,
Expand Down Expand Up @@ -149,4 +239,13 @@ impl PreKeyBundle {
.as_ref()
.map(|pre_key| pre_key.signature.as_ref()))
}

pub fn modify<F>(self, modify: F) -> Result<Self>
where
F: FnOnce(&mut PreKeyBundleContent),
{
let mut content = self.into();
modify(&mut content);
content.try_into()
}
}
2 changes: 1 addition & 1 deletion rust/protocol/src/state/kyber_prekey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::convert::TryInto;
use std::fmt;

/// A unique identifier selecting among this client's known signed pre-keys.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct KyberPreKeyId(u32);

impl From<u32> for KyberPreKeyId {
Expand Down
2 changes: 1 addition & 1 deletion rust/protocol/src/state/prekey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use prost::Message;
use std::fmt;

/// A unique identifier selecting among this client's known pre-keys.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct PreKeyId(u32);

impl From<u32> for PreKeyId {
Expand Down
2 changes: 1 addition & 1 deletion rust/protocol/src/state/signed_prekey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::convert::AsRef;
use std::fmt;

/// A unique identifier selecting among this client's known signed pre-keys.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct SignedPreKeyId(u32);

impl From<u32> for SignedPreKeyId {
Expand Down
30 changes: 30 additions & 0 deletions rust/protocol/src/storage/inmem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ impl InMemPreKeyStore {
pre_keys: HashMap::new(),
}
}

/// Returns all registered pre-key ids
pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
self.pre_keys.keys()
}
}

impl Default for InMemPreKeyStore {
Expand Down Expand Up @@ -166,6 +171,11 @@ impl InMemSignedPreKeyStore {
signed_pre_keys: HashMap::new(),
}
}

/// Returns all registered signed pre-key ids
pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
self.signed_pre_keys.keys()
}
}

impl Default for InMemSignedPreKeyStore {
Expand Down Expand Up @@ -213,6 +223,11 @@ impl InMemKyberPreKeyStore {
kyber_pre_keys: HashMap::new(),
}
}

/// Returns all registered Kyber pre-key ids
pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
self.kyber_pre_keys.keys()
}
}

impl Default for InMemKyberPreKeyStore {
Expand Down Expand Up @@ -396,6 +411,21 @@ impl InMemSignalProtocolStore {
sender_key_store: InMemSenderKeyStore::new(),
})
}

/// Returns all registered pre-key ids
pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
self.pre_key_store.all_pre_key_ids()
}

/// Returns all registered signed pre-key ids
pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
self.signed_pre_key_store.all_signed_pre_key_ids()
}

/// Returns all registered Kyber pre-key ids
pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
self.kyber_pre_key_store.all_kyber_pre_key_ids()
}
}

#[async_trait(?Send)]
Expand Down
Loading

0 comments on commit 28e112b

Please sign in to comment.