Skip to content

Commit

Permalink
Fix poll_capacity to wake in combination with max_send_buffer_size
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar committed Dec 9, 2021
1 parent 88037ae commit a5c60b2
Showing 3 changed files with 90 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/proto/streams/prioritize.rs
Original file line number Diff line number Diff line change
@@ -741,6 +741,11 @@ impl Prioritize {
stream.buffered_send_data -= len as usize;
stream.requested_send_capacity -= len;

// If the capacity was limited because of the
// max_send_buffer_size, then consider waking
// the send task again...
stream.notify_if_can_buffer_more();

// Assign the capacity back to the connection that
// was just consumed from the stream in the previous
// line.
11 changes: 11 additions & 0 deletions src/proto/streams/stream.rs
Original file line number Diff line number Diff line change
@@ -279,6 +279,17 @@ impl Stream {
}
}

/// If the capacity was limited because of the max_send_buffer_size,
/// then consider waking the send task again...
pub fn notify_if_can_buffer_more(&mut self) {
// Only notify if the capacity exceeds the amount of buffered data
if self.send_flow.available() > self.buffered_send_data {
self.send_capacity_inc = true;
tracing::trace!(" notifying task");
self.notify_send();
}
}

/// Returns `Err` when the decrement cannot be completed due to overflow.
pub fn dec_content_length(&mut self, len: usize) -> Result<(), ()> {
match self.content_length {
74 changes: 74 additions & 0 deletions tests/h2-tests/tests/flow_control.rs
Original file line number Diff line number Diff line change
@@ -1668,3 +1668,77 @@ async fn max_send_buffer_size_overflow() {

join(srv, client).await;
}

#[tokio::test]
async fn max_send_buffer_size_poll_capacity_wakes_task() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();

let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/"))
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[][..]).eos()).await;
};

let client = async move {
let (mut client, mut conn) = client::Builder::new()
.max_send_buffer_size(5)
.handshake::<_, Bytes>(io)
.await
.unwrap();
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();

let (response, mut stream) = client.send_request(request, false).unwrap();

let response = conn.drive(response).await.unwrap();

assert_eq!(response.status(), StatusCode::OK);

assert_eq!(stream.capacity(), 0);
const TO_SEND: usize = 20;
stream.reserve_capacity(TO_SEND);
assert_eq!(
stream.capacity(),
5,
"polled capacity not over max buffer size"
);

let t1 = tokio::spawn(async move {
let mut sent = 0;
let buf = [0; TO_SEND];
loop {
match poll_fn(|cx| stream.poll_capacity(cx)).await {
None => panic!("no cap"),
Some(Err(e)) => panic!("cap error: {:?}", e),
Some(Ok(cap)) => {
stream
.send_data(buf[sent..(sent + cap)].to_vec().into(), false)
.unwrap();
sent += cap;
if sent >= TO_SEND {
break;
}
}
}
}
stream.send_data(Bytes::new(), true).unwrap();
});

// Wait for the connection to close
conn.await.unwrap();
t1.await.unwrap();
};

join(srv, client).await;
}

0 comments on commit a5c60b2

Please sign in to comment.