Skip to content

Commit

Permalink
Merge pull request #10305 from Turbo87/mono-test-utils
Browse files Browse the repository at this point in the history
tests/util: Reduce monomorphization overhead for the `RequestHelper` code
  • Loading branch information
Turbo87 authored Jan 2, 2025
2 parents 52f2f0c + 113dcf2 commit 9607d28
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 54 deletions.
66 changes: 28 additions & 38 deletions src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::tests::{
CategoryListResponse, CategoryResponse, CrateList, CrateResponse, GoodCrate, OwnerResp,
OwnersResponse, VersionResponse,
};
use std::future::Future;

use http::{Method, Request};

Expand All @@ -33,6 +34,7 @@ use axum::body::{Body, Bytes};
use axum::extract::connect_info::MockConnectInfo;
use chrono::NaiveDateTime;
use cookie::Cookie;
use futures_util::FutureExt;
use http::header;
use secrecy::ExposeSecret;
use serde_json::json;
Expand Down Expand Up @@ -91,23 +93,28 @@ pub trait RequestHelper {
fn app(&self) -> &TestApp;

/// Run a request that is expected to succeed
async fn run<T>(&self, request: Request<impl Into<Body>>) -> Response<T> {
fn run<T>(&self, request: Request<impl Into<Body>>) -> impl Future<Output = Response<T>> {
let app = self.app();
let router = app.router().clone();
let request = request.map(Into::into);

// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
// extractor has something to extract.
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
let router = router.layer(MockConnectInfo(mocket_addr));
// This inner function is used to avoid long compile times
// due to monomorphization of the `run()` fn itself
async fn inner(app: &TestApp, request: Request<Body>) -> axum::response::Response<Bytes> {
let router = app.router().clone();

let request = request.map(Into::into);
let axum_response = router.oneshot(request).await.unwrap();
// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
// extractor has something to extract.
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
let router = router.layer(MockConnectInfo(mocket_addr));

let axum_response = router.oneshot(request).await.unwrap();

let (parts, body) = axum_response.into_parts();
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let bytes_response = axum::response::Response::from_parts(parts, bytes);
let (parts, body) = axum_response.into_parts();
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
axum::response::Response::from_parts(parts, bytes)
}

Response::new(bytes_response)
inner(app, request).map(Response::new)
}

/// Create a get request
Expand All @@ -134,26 +141,18 @@ pub trait RequestHelper {

/// Issue a PUT request
async fn put<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::PUT, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::PUT, path)
.with_body(body.into());

self.run(request).await
}

/// Issue a PATCH request
async fn patch<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::PATCH, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::PATCH, path)
.with_body(body.into());

self.run(request).await
}
Expand All @@ -166,13 +165,9 @@ pub trait RequestHelper {

/// Issue a DELETE request with a body... yes we do it, for crate owner removal
async fn delete_with_body<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::DELETE, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::DELETE, path)
.with_body(body.into());

self.run(request).await
}
Expand Down Expand Up @@ -256,11 +251,6 @@ fn req(method: Method, path: &str) -> MockRequest {
.unwrap()
}

fn is_json_body(body: &Bytes) -> bool {
(body.starts_with(b"{") && body.ends_with(b"}"))
|| (body.starts_with(b"[") && body.ends_with(b"]"))
}

/// A type that can generate unauthenticated requests
pub struct MockAnonymousUser {
app: TestApp,
Expand Down
17 changes: 16 additions & 1 deletion src/tests/util/mock_request.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use axum::body::Bytes;
use http::{header::IntoHeaderName, HeaderValue, Request};
use http::{header, header::IntoHeaderName, HeaderValue, Request};

pub type MockRequest = Request<Bytes>;

pub trait MockRequestExt {
fn header<K: IntoHeaderName>(&mut self, name: K, value: &str);
fn with_body(self, bytes: Bytes) -> Self;
}

impl MockRequestExt for MockRequest {
Expand All @@ -15,6 +16,20 @@ impl MockRequestExt for MockRequest {
self.headers_mut()
.append(name, HeaderValue::from_str(value).unwrap());
}

fn with_body(mut self, bytes: Bytes) -> Self {
if is_json_body(&bytes) {
self.header(header::CONTENT_TYPE, "application/json");
}

*self.body_mut() = bytes;
self
}
}

fn is_json_body(body: &Bytes) -> bool {
(body.starts_with(b"{") && body.ends_with(b"}"))
|| (body.starts_with(b"[") && body.ends_with(b"]"))
}

#[cfg(test)]
Expand Down
30 changes: 15 additions & 15 deletions src/tests/util/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ impl<T> Response<T> {
assert_ok!(from_utf8(bytes)).to_string()
}

pub fn status(&self) -> StatusCode {
self.response.status()
}

#[track_caller]
pub fn assert_redirect_ends_with(&self, target: &str) -> &Self {
let headers = self.response.headers();
Expand Down Expand Up @@ -105,21 +101,25 @@ fn json<T>(r: &hyper::Response<Bytes>) -> T
where
for<'de> T: serde::Deserialize<'de>,
{
let headers = r.headers();
fn inner(r: &hyper::Response<Bytes>) -> &Bytes {
let headers = r.headers();

assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");
assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");

let content_length = assert_some!(
r.headers().get(header::CONTENT_LENGTH),
"Missing content-length header"
);
let content_length = assert_ok!(content_length.to_str());
let content_length: usize = assert_ok!(content_length.parse());
let content_length = assert_some!(
r.headers().get(header::CONTENT_LENGTH),
"Missing content-length header"
);
let content_length = assert_ok!(content_length.to_str());
let content_length: usize = assert_ok!(content_length.parse());

let bytes = r.body();
assert_that!(*bytes, len(eq(content_length)));
let bytes = r.body();
assert_that!(*bytes, len(eq(content_length)));

bytes
}

match serde_json::from_slice(bytes) {
match serde_json::from_slice(inner(r)) {
Ok(t) => t,
Err(e) => panic!("failed to decode: {e:?}"),
}
Expand Down

0 comments on commit 9607d28

Please sign in to comment.