Skip to content

Commit

Permalink
Split connection to reader and writer to multiplex them (#51)
Browse files Browse the repository at this point in the history
changelog: changed
  • Loading branch information
kezhuw authored May 7, 2024
1 parent 03c7ba1 commit c71c44d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 125 deletions.
133 changes: 60 additions & 73 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::pin::Pin;
use std::ptr;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::task::{Context, Poll};
use std::time::Duration;

use bytes::buf::BufMut;
use ignore_result::Ignore;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::TcpStream;
use tokio::{select, time};
use tracing::{debug, trace};
Expand All @@ -26,17 +25,31 @@ use tls::*;
use crate::deadline::Deadline;
use crate::endpoint::{EndpointRef, IterableEndpoints};

const NOOP_VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);

#[derive(Debug)]
pub enum Connection {
Raw(TcpStream),
#[cfg(feature = "tls")]
Tls(TlsStream<TcpStream>),
}

pub trait AsyncReadToBuf: AsyncReadExt {
async fn read_to_buf(&mut self, buf: &mut impl BufMut) -> Result<usize>
where
Self: Unpin, {
let chunk = buf.chunk_mut();
let read_to = unsafe { std::mem::transmute(chunk.as_uninit_slice_mut()) };
let n = self.read(read_to).await?;
if n != 0 {
unsafe {
buf.advance_mut(n);
}
}
Ok(n)
}
}

impl<T> AsyncReadToBuf for T where T: AsyncReadExt {}

impl AsyncRead for Connection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Expand All @@ -56,6 +69,14 @@ impl AsyncWrite for Connection {
}
}

fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
Expand All @@ -73,86 +94,52 @@ impl AsyncWrite for Connection {
}
}

impl Connection {
pub fn new_raw(stream: TcpStream) -> Self {
Self::Raw(stream)
pub struct ConnReader<'a> {
conn: &'a mut Connection,
}

impl AsyncRead for ConnReader<'_> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
}
}

#[cfg(feature = "tls")]
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
Self::Tls(stream)
pub struct ConnWriter<'a> {
conn: &'a mut Connection,
}

impl AsyncWrite for ConnWriter<'_> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_write(cx, buf)
}

pub fn try_write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
let mut context = Context::from_waker(&waker);
match Pin::new(self).poll_write_vectored(&mut context, bufs) {
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
Poll::Ready(result) => result,
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_write_vectored(cx, bufs)
}

pub fn try_read_buf(&mut self, buf: &mut impl BufMut) -> Result<usize> {
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
let mut context = Context::from_waker(&waker);
let chunk = buf.chunk_mut();
let mut read_buf = unsafe { ReadBuf::uninit(chunk.as_uninit_slice_mut()) };
match Pin::new(self).poll_read(&mut context, &mut read_buf) {
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
Poll::Ready(Err(err)) => Err(err),
Poll::Ready(Ok(())) => {
let n = read_buf.filled().len();
unsafe { buf.advance_mut(n) };
Ok(n)
},
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_flush(cx)
}

pub async fn readable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.readable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, session) = stream.get_ref();
if session.wants_read() {
stream.readable().await
} else {
// plaintext data are available for read
std::future::ready(Ok(())).await
}
},
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_shutdown(cx)
}
}

pub async fn writable(&self) -> Result<()> {
match self {
Self::Raw(stream) => stream.writable().await,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (stream, _session) = stream.get_ref();
stream.writable().await
},
}
impl Connection {
pub fn new_raw(stream: TcpStream) -> Self {
Self::Raw(stream)
}

pub fn wants_write(&self) -> bool {
match self {
Self::Raw(_) => false,
#[cfg(feature = "tls")]
Self::Tls(stream) => {
let (_stream, session) = stream.get_ref();
session.wants_write()
},
}
pub fn split(&mut self) -> (ConnReader<'_>, ConnWriter<'_>) {
let reader = ConnReader { conn: self };
let writer = ConnWriter { conn: unsafe { std::ptr::read(&reader.conn) } };
(reader, writer)
}

pub fn try_flush(&mut self) -> Result<()> {
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
let mut context = Context::from_waker(&waker);
match Pin::new(self).poll_flush(&mut context) {
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
Poll::Ready(result) => result,
}
#[cfg(feature = "tls")]
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
Self::Tls(stream)
}

pub async fn command(self, cmd: &str) -> Result<String> {
Expand Down
26 changes: 6 additions & 20 deletions src/session/depot.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::collections::VecDeque;
use std::io::{self, IoSlice};
use std::io::IoSlice;

use hashbrown::HashMap;
use strum::IntoEnumIterator;
use tokio::io::AsyncWriteExt;
use tracing::debug;

use super::connection::Connection;
use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};
use super::types::WatchMode;
use super::xid::Xid;
Expand Down Expand Up @@ -229,26 +229,12 @@ impl Depot {
.any(|mode| self.watching_paths.contains_key(&(path, mode)))
}

pub fn write_operations(&mut self, conn: &mut Connection) -> Result<(), Error> {
pub async fn write_to(&mut self, write: &mut (impl AsyncWriteExt + Unpin)) -> Result<(), Error> {
if !self.has_pending_writes() {
if let Err(err) = conn.try_flush() {
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(());
}
return Err(Error::other(err));
}
return Ok(());
write.flush().await.map_err(Error::other)?;
return std::future::pending().await;
}
let result = conn.try_write_vectored(self.writing_slices.as_slice());
let mut written_bytes = match result {
Err(err) => {
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(());
}
return Err(Error::other(err));
},
Ok(written_bytes) => written_bytes,
};
let mut written_bytes = write.write_vectored(self.writing_slices.as_slice()).await.map_err(Error::other)?;
let written_slices = self
.writing_slices
.iter()
Expand Down
48 changes: 16 additions & 32 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ mod types;
mod watch;
mod xid;

use std::io;
use std::time::Duration;

use ignore_result::Ignore;
use tokio::io::AsyncWriteExt;
use tokio::select;
use tokio::sync::mpsc;
use tokio::time::{self, Instant};
use tracing::field::display;
use tracing::{debug, info, instrument, warn, Span};

use self::connection::{Connection, Connector};
use self::connection::{AsyncReadToBuf, Connection, Connector};
pub use self::depot::Depot;
use self::event::WatcherEvent;
pub use self::request::{
Expand Down Expand Up @@ -478,21 +478,6 @@ impl Session {
Ok(())
}

fn read_connection(&mut self, conn: &mut Connection, buf: &mut Vec<u8>) -> Result<(), Error> {
match conn.try_read_buf(buf) {
Ok(0) => {
return Err(Error::ConnectionLoss);
},
Err(err) => {
if err.kind() != io::ErrorKind::WouldBlock {
return Err(Error::other(err));
}
},
_ => {},
}
Ok(())
}

fn handle_recv_buf(&mut self, recved: &mut Vec<u8>, depot: &mut Depot) -> Result<(), Error> {
let mut reading = recved.as_slice();
if self.session_state == SessionState::Disconnected {
Expand Down Expand Up @@ -522,14 +507,15 @@ impl Session {
let mut pinged = false;
let mut tick = time::interval(self.tick_timeout);
tick.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
let (mut reader, mut writer) = conn.split();
while !(self.session_state.is_connected() && depot.is_empty()) {
select! {
_ = conn.readable() => {
self.read_connection(conn, buf)?;
self.handle_recv_buf(buf, depot)?;
r = reader.read_to_buf(buf) => match r.map_err(Error::other)? {
0 => return Err(Error::ConnectionLoss),
_ => self.handle_recv_buf(buf, depot)?,
},
_ = conn.writable(), if depot.has_pending_writes() || conn.wants_write() => {
depot.write_operations(conn)?;
r = depot.write_to(&mut writer) => {
r?;
self.last_send = Instant::now();
},
now = tick.tick() => {
Expand All @@ -543,7 +529,6 @@ impl Session {
// "zookeeper.enforce.auth.enabled".
pinged = true;
self.send_ping(depot, Instant::now());
depot.write_operations(conn)?;
}
}
Ok(())
Expand Down Expand Up @@ -574,19 +559,20 @@ impl Session {
let mut err = None;
let mut channel_halted = false;
depot.start();
while !(channel_halted && depot.is_empty() && !conn.wants_write()) {
let (mut reader, mut writer) = conn.split();
while !(channel_halted && depot.is_empty()) {
select! {
Some(endpoint) = Self::poll(&mut seek_for_writable), if seek_for_writable.is_some() => {
seek_for_writable = None;
err = Some(Error::with_message(format!("encounter writable server {}", endpoint)));
channel_halted = true;
},
_ = conn.readable() => {
self.read_connection(conn, buf)?;
self.handle_recv_buf(buf, depot)?;
r = reader.read_to_buf(buf) => match r.map_err(Error::other)? {
0 => return Err(Error::ConnectionLoss),
_ => self.handle_recv_buf(buf, depot)?,
},
_ = conn.writable(), if depot.has_pending_writes() || conn.wants_write() => {
depot.write_operations(conn)?;
r = depot.write_to(&mut writer) => {
r?;
self.last_send = Instant::now();
},
r = requester.recv(), if !channel_halted => {
Expand All @@ -600,8 +586,6 @@ impl Session {
continue;
};
depot.push_session(operation);
depot.write_operations(conn)?;
self.last_send = Instant::now();
},
r = unwatch_requester.recv() => if let Some((watcher_id, responser)) = r {
self.watch_manager.remove_watcher(watcher_id, responser, depot);
Expand All @@ -612,11 +596,11 @@ impl Session {
}
if self.last_ping.is_none() && now >= self.last_send + self.ping_timeout {
self.send_ping(depot, now);
depot.write_operations(conn)?;
}
},
}
}
writer.flush().await.map_err(Error::other)?;
Err(err.unwrap_or(Error::ClientClosed))
}

Expand Down

0 comments on commit c71c44d

Please sign in to comment.