Skip to content

Commit

Permalink
fix: Expired not fired in time due to connection retry backoff
Browse files Browse the repository at this point in the history
  • Loading branch information
kezhuw committed Mar 30, 2024
1 parent b0ce0b8 commit 0e4e201
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
9 changes: 9 additions & 0 deletions src/deadline.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::time::{self, Instant, Sleep};

Expand All @@ -20,6 +21,14 @@ impl Deadline {
/// Reports whether the deadline has passed.
///
/// A deadline without an underlying sleep never elapses, so `false` is
/// returned in that case.
pub fn elapsed(&self) -> bool {
    if let Some(sleep) = &self.sleep {
        sleep.is_elapsed()
    } else {
        false
    }
}

/// Time left until this deadline fires.
///
/// Returns [`Duration::MAX`] when no sleep is set, and [`Duration::ZERO`]
/// once the deadline instant is already in the past (the subtraction
/// saturates instead of panicking).
pub fn timeout(&self) -> Duration {
    self.sleep
        .as_ref()
        .map_or(Duration::MAX, |sleep| sleep.deadline().saturating_duration_since(Instant::now()))
}
}

impl Future for Deadline {
Expand Down
26 changes: 14 additions & 12 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,15 @@ impl IterableEndpoints {
self.start = self.next;
}

pub async fn next(&mut self) -> Option<EndpointRef<'_>> {
pub async fn next(&mut self, max_delay: Duration) -> Option<EndpointRef<'_>> {
let index = self.index()?;
self.delay(index).await;
self.delay(index, max_delay).await;
self.step();
Some(self.endpoints[index.offset].to_ref())
}

async fn delay(&self, index: Index) {
let timeout = Self::timeout(index, self.endpoints.len());
async fn delay(&self, index: Index, max_delay: Duration) {
let timeout = max_delay.min(Self::timeout(index, self.endpoints.len()));
if timeout != Duration::ZERO {
tokio::time::sleep(timeout).await;
}
Expand Down Expand Up @@ -338,22 +338,24 @@ mod tests {

#[tokio::test]
async fn test_iterable_endpoints_next() {
use std::time::Duration;

use assertor::*;

use super::{parse_connect_string, EndpointRef, Index, IterableEndpoints};
let (endpoints, _) = parse_connect_string("host1:2181,tcp://host2,tcp+tls://host3:2182", true).unwrap();
let mut endpoints = IterableEndpoints::from(endpoints.as_slice());
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host1", 2181, true)));
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host2", 2181, false)));
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host3", 2182, true)));
assert_eq!(endpoints.next().await, None);
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host1", 2181, true)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host2", 2181, false)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host3", 2182, true)));
assert_eq!(endpoints.next(Duration::MAX).await, None);

endpoints.cycle();
let start = std::time::Instant::now();
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host1", 2181, true)));
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host2", 2181, false)));
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host3", 2182, true)));
assert_eq!(endpoints.next().await, Some(EndpointRef::new("host1", 2181, true)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host1", 2181, true)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host2", 2181, false)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host3", 2182, true)));
assert_eq!(endpoints.next(Duration::MAX).await, Some(EndpointRef::new("host1", 2181, true)));
let delay = IterableEndpoints::timeout(Index { offset: 0, cycles: 1 }, 3)
+ IterableEndpoints::timeout(Index { offset: 1, cycles: 1 }, 3);
let now = std::time::Instant::now();
Expand Down
3 changes: 2 additions & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ impl Session {
let mut depot = Depot::for_serving();
let mut unwatch_requester = self.unwatch_receiver.take().unwrap();
endpoints.cycle();
endpoints.reset();
self.serve_once(conn, &mut endpoints, &mut buf, &mut depot, &mut requester, &mut unwatch_requester).await;
while !self.session_state.is_terminated() {
let conn = match self.start(&mut endpoints, &mut buf, &mut connecting_trans).await {
Expand Down Expand Up @@ -538,7 +539,7 @@ impl Session {
buf: &mut Vec<u8>,
depot: &mut Depot,
) -> Result<Connection, Error> {
let Some(endpoint) = endpoints.next().await else {
let Some(endpoint) = endpoints.next(deadline.timeout()).await else {
return Err(Error::NoHosts);
};
let mut conn = match self.connector.connect(endpoint, deadline).await {
Expand Down
46 changes: 25 additions & 21 deletions tests/zookeeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,29 @@ async fn test_tls() {
assert_eq!(client2.get_data("/a").await.unwrap_err(), zk::Error::NoAuth);
}

/// Test helper for waiting until a session reaches a given state.
trait StateWaiter {
/// Waits until the session state equals `expected`.
///
/// `timeout` optionally bounds the wait. The implementation for
/// `zk::StateWatcher` in this file falls back to 60 seconds when it is
/// `None`, and panics if the timeout elapses or a different terminal state
/// is observed first.
async fn wait(&mut self, expected: zk::SessionState, timeout: Option<Duration>);
}

impl StateWaiter for zk::StateWatcher {
    /// Waits until the watched session state equals `expected`.
    ///
    /// `timeout` bounds the whole wait rather than each individual state
    /// change; `None` falls back to 60 seconds.
    ///
    /// # Panics
    ///
    /// Panics if a terminal state other than `expected` is observed, or if
    /// `expected` has not been reached when the timeout elapses.
    async fn wait(&mut self, expected: zk::SessionState, timeout: Option<Duration>) {
        let timeout = timeout.unwrap_or_else(|| Duration::from_secs(60));
        // Pin the timer once so the same deadline is polled on every loop
        // iteration; `tokio::pin!` does this safely, no `unsafe` required.
        let sleep = tokio::time::sleep(timeout);
        tokio::pin!(sleep);
        let mut got = self.state();
        loop {
            // Check for the expected state first so that waiting for a
            // terminal state (e.g. `Expired`) succeeds instead of panicking.
            if got == expected {
                break;
            } else if got.is_terminated() {
                panic!("expect {expected}, but got terminal state {got}")
            }
            select! {
                state = self.changed() => got = state,
                _ = &mut sleep => panic!("expect {expected}, but still {got} after {}s", timeout.as_secs()),
            }
        }
    }
}

#[cfg(target_os = "linux")]
#[test_case(true; "tls")]
#[test_case(false; "plaintext")]
Expand Down Expand Up @@ -1811,15 +1834,7 @@ async fn test_readonly(tls: bool) {
cluster.by_id(2).stop();

// Quorum session will eventually expire.
let mut timeout = tokio::time::sleep(2 * client.session_timeout());
loop {
select! {
state = state_watcher.changed() => if state == zk::SessionState::Expired {
break
},
_ = unsafe { Pin::new_unchecked(&mut timeout) } => panic!("expect Expired, but got {}", state_watcher.state()),
}
}
state_watcher.wait(zk::SessionState::Expired, Some(2 * client.session_timeout())).await;

logs.wait_for_message("Read-only server started").unwrap();

Expand All @@ -1842,18 +1857,7 @@ async fn test_readonly(tls: bool) {
cluster.by_id(1).start();
cluster.by_id(2).start();

let mut timeout = tokio::time::sleep(Duration::from_secs(60));
loop {
select! {
state = state_watcher.changed() => match state {
zk::SessionState::SyncConnected => break,
zk::SessionState::Disconnected | zk::SessionState::ConnectedReadOnly => continue,
state => panic!("expect SyncConnected, but got {}", state),
},
_ = unsafe { Pin::new_unchecked(&mut timeout) } => panic!("expect SyncConnected, but got {}", state_watcher.state()),
}
}

state_watcher.wait(zk::SessionState::SyncConnected, None).await;
client.create("/z", b"", PERSISTENT_OPEN).await.unwrap();
}

Expand Down

0 comments on commit 0e4e201

Please sign in to comment.