Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: toggle TLS support with feature gate #47

Merged
merged 1 commit into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ rust-version = "1.65"

[features]
default = []
tls = ["rustls", "rustls-pemfile", "webpki-roots"]
sasl = ["sasl-gssapi", "sasl-digest-md5"]
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
sasl-gssapi = ["rsasl/gssapi"]
Expand All @@ -34,9 +35,9 @@ hashbrown = "0.12.0"
hashlink = "0.8.0"
either = "1.9.0"
uuid = { version = "1.4.1", features = ["v4"] }
rustls = "0.23.2"
rustls-pemfile = "2"
webpki-roots = "0.26.1"
rustls = { version = "0.23.2", optional = true }
rustls-pemfile = { version = "2", optional = true }
webpki-roots = { version = "0.26.1", optional = true }
derive-where = "1.2.7"
tokio-rustls = "0.26.0"
fastrand = "2.0.2"
Expand Down
8 changes: 7 additions & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::record::{self, Record, StaticRecord};
use crate::sasl::SaslOptions;
use crate::session::StateReceiver;
pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
#[cfg(feature = "tls")]
use crate::tls::TlsOptions;
use crate::util;

Expand Down Expand Up @@ -1538,6 +1539,7 @@ pub(crate) struct Version(u32, u32, u32);
/// A builder for [Client].
#[derive(Clone, Debug)]
pub struct Connector {
#[cfg(feature = "tls")]
tls: Option<TlsOptions>,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: Option<SaslOptions>,
Expand All @@ -1555,6 +1557,7 @@ pub struct Connector {
impl Connector {
fn new() -> Self {
Self {
#[cfg(feature = "tls")]
tls: None,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: None,
Expand Down Expand Up @@ -1624,6 +1627,7 @@ impl Connector {
}

/// Specifies tls options for connections to ZooKeeper.
#[cfg(feature = "tls")]
pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
self.tls = Some(options);
self
Expand All @@ -1649,13 +1653,14 @@ impl Connector {
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
let builder = Session::builder()
.with_tls(self.tls.take())
.with_session(self.session.take())
.with_authes(&self.authes)
.with_readonly(self.readonly)
.with_detached(self.detached)
.with_session_timeout(self.session_timeout)
.with_connection_timeout(self.connection_timeout);
#[cfg(feature = "tls")]
let builder = builder.with_tls(self.tls.take());
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
let builder = builder.with_sasl(self.sasl.take());
let (mut session, state_receiver) = builder.build()?;
Expand Down Expand Up @@ -1685,6 +1690,7 @@ impl Connector {
///
/// Same to [Self::connect] except that `server1` will use tls encrypted protocol given
/// the connection string `server1,tcp://server2,tcp+tls://server3`.
#[cfg(feature = "tls")]
pub async fn secure_connect(&mut self, cluster: &str) -> Result<Client> {
self.connect_internally(true, cluster).await
}
Expand Down
1 change: 1 addition & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl Error {
Self::Other(OtherError { message: message.into(), source })
}

#[allow(dead_code)]
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::new_other(message.into(), Some(Arc::new(source)))
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ mod record;
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
mod sasl;
mod session;
#[cfg(feature = "tls")]
mod tls;
mod util;

pub use self::acl::{Acl, Acls, AuthId, AuthUser, Permission};
pub use self::error::Error;
#[cfg(feature = "tls")]
pub use self::tls::TlsOptions;
pub use crate::client::*;
#[cfg(feature = "sasl-digest-md5")]
Expand Down
87 changes: 75 additions & 12 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::time::Duration;

use bytes::buf::BufMut;
use ignore_result::Ignore;
use rustls::pki_types::ServerName;
use rustls::ClientConfig;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::TcpStream;
use tokio::{select, time};
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use tracing::{debug, trace};

#[cfg(feature = "tls")]
mod tls {
pub use std::sync::Arc;

pub use rustls::pki_types::ServerName;
pub use rustls::ClientConfig;
pub use tokio_rustls::client::TlsStream;
pub use tokio_rustls::TlsConnector;
}
#[cfg(feature = "tls")]
use tls::*;

use crate::deadline::Deadline;
use crate::endpoint::{EndpointRef, IterableEndpoints};

const NOOP_VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);

#[derive(Debug)]
pub enum Connection {
Tls(TlsStream<TcpStream>),
Raw(TcpStream),
#[cfg(feature = "tls")]
Tls(TlsStream<TcpStream>),
}

impl AsyncRead for Connection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
Expand All @@ -41,20 +51,23 @@ impl AsyncWrite for Connection {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
Expand All @@ -65,6 +78,7 @@ impl Connection {
Self::Raw(stream)
}

#[cfg(feature = "tls")]
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
Self::Tls(stream)
}
Expand Down Expand Up @@ -97,6 +111,7 @@ impl Connection {
pub async fn readable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.readable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, session) = stream.get_ref();
if session.wants_read() {
Expand All @@ -112,6 +127,7 @@ impl Connection {
pub async fn writable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.writable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, _session) = stream.get_ref();
stream.writable().await
Expand All @@ -122,6 +138,7 @@ impl Connection {
pub fn wants_write(&self) -> bool {
match self {
Self::Raw(_) => false,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (_stream, session) = stream.get_ref();
session.wants_write()
Expand Down Expand Up @@ -160,13 +177,33 @@ impl Connection {

#[derive(Clone)]
pub struct Connector {
tls: TlsConnector,
#[cfg(feature = "tls")]
tls: Option<TlsConnector>,
timeout: Duration,
}

impl Connector {
pub fn new(config: impl Into<Arc<ClientConfig>>) -> Self {
Self { tls: TlsConnector::from(config.into()), timeout: Duration::from_secs(10) }
#[cfg(feature = "tls")]
#[allow(dead_code)]
pub fn new() -> Self {
Self { tls: None, timeout: Duration::from_secs(10) }
}

#[cfg(not(feature = "tls"))]
pub fn new() -> Self {
Self { timeout: Duration::from_secs(10) }
}

#[cfg(feature = "tls")]
pub fn with_tls(config: ClientConfig) -> Self {
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
}

#[cfg(feature = "tls")]
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
let domain = ServerName::try_from(host).unwrap().to_owned();
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
Ok(Connection::new_tls(stream))
}

pub fn timeout(&self) -> Duration {
Expand All @@ -178,6 +215,14 @@ impl Connector {
}

pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
if endpoint.tls {
#[cfg(feature = "tls")]
if self.tls.is_none() {
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
}
#[cfg(not(feature = "tls"))]
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
}
select! {
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
Expand All @@ -186,9 +231,10 @@ impl Connector {
Err(err) => Err(err),
Ok(sock) => {
let connection = if endpoint.tls {
let domain = ServerName::try_from(endpoint.host).unwrap().to_owned();
let stream = self.tls.connect(domain, sock).await?;
Connection::new_tls(stream)
#[cfg(not(feature = "tls"))]
unreachable!("tls not supported");
#[cfg(feature = "tls")]
self.connect_tls(sock, endpoint.host).await?
} else {
Connection::new_raw(sock)
};
Expand Down Expand Up @@ -231,3 +277,20 @@ impl Connector {
None
}
}

#[cfg(test)]
mod tests {
use std::io::ErrorKind;

use super::Connector;
use crate::deadline::Deadline;
use crate::endpoint::EndpointRef;

#[tokio::test]
async fn raw() {
let connector = Connector::new();
let endpoint = EndpointRef::new("host1", 2181, true);
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
assert_eq!(err.kind(), ErrorKind::Unsupported);
}
}
10 changes: 8 additions & 2 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::proto::{AuthPacket, ConnectRequest, ConnectResponse, ErrorCode, OpCod
use crate::record;
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
use crate::sasl::{SaslInitiator, SaslOptions, SaslSession};
#[cfg(feature = "tls")]
use crate::tls::TlsOptions;

pub const PASSWORD_LEN: usize = 16;
Expand All @@ -61,6 +62,7 @@ impl RequestOperation for (WatcherId, StateResponser) {

#[derive(Default)]
pub struct Builder {
#[cfg(feature = "tls")]
tls: Option<TlsOptions>,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl: Option<SaslOptions>,
Expand All @@ -73,6 +75,7 @@ pub struct Builder {
}

impl Builder {
#[cfg(feature = "tls")]
pub fn with_tls(self, tls: Option<TlsOptions>) -> Self {
Self { tls, ..self }
}
Expand Down Expand Up @@ -125,7 +128,10 @@ impl Builder {
} else if self.connection_timeout < Duration::ZERO {
return Err(Error::BadArguments(&"connection timeout must not be negative"));
}
let tls_config = self.tls.unwrap_or_default().into_config()?;
#[cfg(feature = "tls")]
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
#[cfg(not(feature = "tls"))]
let connector = Connector::new();
let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected);
let now = Instant::now();
let (watch_manager, unwatch_receiver) = WatchManager::new();
Expand All @@ -142,7 +148,7 @@ impl Builder {
tick_timeout: Duration::ZERO,
ping_timeout: Duration::ZERO,
session_expired_timeout: Duration::ZERO,
connector: Connector::new(tls_config),
connector,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
sasl_options: self.sasl,
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
Expand Down
Loading
Loading