Skip to content

Commit

Permalink
refactor: toggle TLS support with feature gate
Browse files Browse the repository at this point in the history
  • Loading branch information
kezhuw committed Apr 27, 2024
1 parent b68c967 commit 8c2056d
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 109 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ rust-version = "1.65"

[features]
default = []
tls = ["rustls", "rustls-pemfile", "webpki-roots"]
sasl = ["sasl-gssapi", "sasl-digest-md5"]
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
sasl-gssapi = ["rsasl/gssapi"]
Expand All @@ -34,9 +35,9 @@ hashbrown = "0.12.0"
hashlink = "0.8.0"
either = "1.9.0"
uuid = { version = "1.4.1", features = ["v4"] }
rustls = "0.23.2"
rustls-pemfile = "2"
webpki-roots = "0.26.1"
rustls = { version = "0.23.2", optional = true }
rustls-pemfile = { version = "2", optional = true }
webpki-roots = { version = "0.26.1", optional = true }
derive-where = "1.2.7"
tokio-rustls = "0.26.0"
fastrand = "2.0.2"
Expand Down
8 changes: 7 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::record::{self, Record, StaticRecord};
use crate::sasl::SaslOptions;
use crate::session::StateReceiver;
pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
#[cfg(feature = "tls")]
use crate::tls::TlsOptions;
use crate::util;

Expand Down Expand Up @@ -1538,6 +1539,7 @@ pub(crate) struct Version(u32, u32, u32);
/// A builder for [Client].
#[derive(Clone, Debug)]
pub struct Connector {
#[cfg(feature = "tls")]
tls: Option<TlsOptions>,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: Option<SaslOptions>,
Expand All @@ -1555,6 +1557,7 @@ pub struct Connector {
impl Connector {
fn new() -> Self {
Self {
#[cfg(feature = "tls")]
tls: None,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: None,
Expand Down Expand Up @@ -1624,6 +1627,7 @@ impl Connector {
}

/// Specifies tls options for connections to ZooKeeper.
#[cfg(feature = "tls")]
pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
self.tls = Some(options);
self
Expand All @@ -1649,13 +1653,14 @@ impl Connector {
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
let builder = Session::builder()
.with_tls(self.tls.take())
.with_session(self.session.take())
.with_authes(&self.authes)
.with_readonly(self.readonly)
.with_detached(self.detached)
.with_session_timeout(self.session_timeout)
.with_connection_timeout(self.connection_timeout);
#[cfg(feature = "tls")]
let builder = builder.with_tls(self.tls.take());
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
let builder = builder.with_sasl(self.sasl.take());
let (mut session, state_receiver) = builder.build()?;
Expand Down Expand Up @@ -1685,6 +1690,7 @@ impl Connector {
///
/// Same to [Self::connect] except that `server1` will use tls encrypted protocol given
/// the connection string `server1,tcp://server2,tcp+tls://server3`.
#[cfg(feature = "tls")]
pub async fn secure_connect(&mut self, cluster: &str) -> Result<Client> {
self.connect_internally(true, cluster).await
}
Expand Down
1 change: 1 addition & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl Error {
Self::Other(OtherError { message: message.into(), source })
}

#[allow(dead_code)]
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::new_other(message.into(), Some(Arc::new(source)))
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ mod record;
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
mod sasl;
mod session;
#[cfg(feature = "tls")]
mod tls;
mod util;

pub use self::acl::{Acl, Acls, AuthId, AuthUser, Permission};
pub use self::error::Error;
#[cfg(feature = "tls")]
pub use self::tls::TlsOptions;
pub use crate::client::*;
#[cfg(feature = "sasl-digest-md5")]
Expand Down
87 changes: 75 additions & 12 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::time::Duration;

use bytes::buf::BufMut;
use ignore_result::Ignore;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::TcpStream;
use tokio::{select, time};
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use tracing::{debug, trace};

#[cfg(feature = "tls")]
mod tls {
pub use std::sync::Arc;

pub use rustls::pki_types::ServerName;
pub use rustls::ClientConfig;
pub use tokio_rustls::client::TlsStream;
pub use tokio_rustls::TlsConnector;
}
#[cfg(feature = "tls")]
use tls::*;

use crate::deadline::Deadline;
use crate::endpoint::{EndpointRef, IterableEndpoints};

const NOOP_VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);

#[derive(Debug)]
pub enum Connection {
Tls(TlsStream<TcpStream>),
Raw(TcpStream),
#[cfg(feature = "tls")]
Tls(TlsStream<TcpStream>),
}

impl AsyncRead for Connection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
Expand All @@ -41,20 +51,23 @@ impl AsyncWrite for Connection {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
Expand All @@ -65,6 +78,7 @@ impl Connection {
Self::Raw(stream)
}

#[cfg(feature = "tls")]
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
Self::Tls(stream)
}
Expand Down Expand Up @@ -97,6 +111,7 @@ impl Connection {
pub async fn readable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.readable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, session) = stream.get_ref();
if session.wants_read() {
Expand All @@ -112,6 +127,7 @@ impl Connection {
pub async fn writable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.writable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, _session) = stream.get_ref();
stream.writable().await
Expand All @@ -122,6 +138,7 @@ impl Connection {
pub fn wants_write(&self) -> bool {
match self {
Self::Raw(_) => false,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (_stream, session) = stream.get_ref();
session.wants_write()
Expand Down Expand Up @@ -160,13 +177,33 @@ impl Connection {

#[derive(Clone)]
pub struct Connector {
tls: TlsConnector,
#[cfg(feature = "tls")]
tls: Option<TlsConnector>,
timeout: Duration,
}

impl Connector {
pub fn new(config: impl Into<Arc<ClientConfig>>) -> Self {
Self { tls: TlsConnector::from(config.into()), timeout: Duration::from_secs(10) }
#[cfg(feature = "tls")]
#[allow(dead_code)]
pub fn new() -> Self {
Self { tls: None, timeout: Duration::from_secs(10) }
}

#[cfg(not(feature = "tls"))]
pub fn new() -> Self {
Self { timeout: Duration::from_secs(10) }
}

#[cfg(feature = "tls")]
pub fn with_tls(config: ClientConfig) -> Self {
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
}

#[cfg(feature = "tls")]
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
let domain = ServerName::try_from(host).unwrap().to_owned();
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
Ok(Connection::new_tls(stream))
}

pub fn timeout(&self) -> Duration {
Expand All @@ -178,6 +215,14 @@ impl Connector {
}

pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
if endpoint.tls {
#[cfg(feature = "tls")]
if self.tls.is_none() {
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
}
#[cfg(not(feature = "tls"))]
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
}
select! {
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
Expand All @@ -186,9 +231,10 @@ impl Connector {
Err(err) => Err(err),
Ok(sock) => {
let connection = if endpoint.tls {
let domain = ServerName::try_from(endpoint.host).unwrap().to_owned();
let stream = self.tls.connect(domain, sock).await?;
Connection::new_tls(stream)
#[cfg(not(feature = "tls"))]
unreachable!("tls not supported");
#[cfg(feature = "tls")]
self.connect_tls(sock, endpoint.host).await?
} else {
Connection::new_raw(sock)
};
Expand Down Expand Up @@ -231,3 +277,20 @@ impl Connector {
None
}
}

#[cfg(test)]
mod tests {
use std::io::ErrorKind;

use super::Connector;
use crate::deadline::Deadline;
use crate::endpoint::EndpointRef;

#[tokio::test]
async fn raw() {
let connector = Connector::new();
let endpoint = EndpointRef::new("host1", 2181, true);
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unsupported);
}
}
10 changes: 8 additions & 2 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::proto::{AuthPacket, ConnectRequest, ConnectResponse, ErrorCode, OpCod
use crate::record;
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
use crate::sasl::{SaslInitiator, SaslOptions, SaslSession};
#[cfg(feature = "tls")]
use crate::tls::TlsOptions;

pub const PASSWORD_LEN: usize = 16;
Expand All @@ -61,6 +62,7 @@ impl RequestOperation for (WatcherId, StateResponser) {

#[derive(Default)]
pub struct Builder {
#[cfg(feature = "tls")]
tls: Option<TlsOptions>,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: Option<SaslOptions>,
Expand All @@ -73,6 +75,7 @@ pub struct Builder {
}

impl Builder {
#[cfg(feature = "tls")]
pub fn with_tls(self, tls: Option<TlsOptions>) -> Self {
Self { tls, ..self }
}
Expand Down Expand Up @@ -125,7 +128,10 @@ impl Builder {
} else if self.connection_timeout < Duration::ZERO {
return Err(Error::BadArguments(&"connection timeout must not be negative"));
}
let tls_config = self.tls.unwrap_or_default().into_config()?;
#[cfg(feature = "tls")]
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
#[cfg(not(feature = "tls"))]
let connector = Connector::new();
let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected);
let now = Instant::now();
let (watch_manager, unwatch_receiver) = WatchManager::new();
Expand All @@ -142,7 +148,7 @@ impl Builder {
tick_timeout: Duration::ZERO,
ping_timeout: Duration::ZERO,
session_expired_timeout: Duration::ZERO,
connector: Connector::new(tls_config),
connector,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl_options: self.sasl,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
Expand Down
Loading

0 comments on commit 8c2056d

Please sign in to comment.