Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ApiKey newtype to ensure key is always valid format #835

Merged
merged 10 commits into from
May 4, 2023
6 changes: 4 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ Before we can login to our local instance of shuttle, we need to create a user.
The following command inserts a user into the `auth` state with admin privileges:

```bash
docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key test-key
# the --key needs to be 16 alphanumeric characters
docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key dh9z58jttoes3qvt
```

Login to shuttle service in a new terminal window from the root of the shuttle directory:

```bash
cargo run --bin cargo-shuttle -- login --api-key "test-key"
# the --api-kei should be the same one you inserted in the auth state
cargo run --bin cargo-shuttle -- login --api-key "dh9z58jttoes3qvt"
```

The [shuttle examples](https://github.com/shuttle-hq/examples) are linked to the main repo as a [git submodule](https://git-scm.com/book/en/v2/Git-Tools-Submodules), to initialize it run the following commands:
Expand Down
81 changes: 80 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true }

[dependencies.shuttle-common]
workspace = true
features = ["backend", "models"]
features = ["backend", "models", "persist"]

[dev-dependencies]
axum-extra = { version = "0.7.1", features = ["cookie"] }
Expand Down
2 changes: 1 addition & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub(crate) async fn convert_key(
let User {
name, account_tier, ..
} = user_manager
.get_user_by_key(key.clone())
.get_user_by_key(key.as_ref().clone())
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;

Expand Down
16 changes: 9 additions & 7 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod user;
use std::{io, str::FromStr, time::Duration};

use args::StartArgs;
use shuttle_common::ApiKey;
use sqlx::{
migrate::Migrator,
query,
Expand All @@ -15,10 +16,7 @@ use sqlx::{
};
use tracing::info;

use crate::{
api::serve,
user::{AccountTier, Key},
};
use crate::{api::serve, user::AccountTier};
pub use api::ApiBuilder;
pub use args::{Args, Commands, InitArgs};

Expand All @@ -41,8 +39,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {

pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(ref key) => Key::from_str(key).unwrap(),
None => Key::new_random(),
Some(ref key) => ApiKey::parse(key).unwrap(),
None => ApiKey::generate(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
Expand All @@ -53,7 +51,11 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!("`{}` created as super user with key: {key}", args.name);
println!(
"`{}` created as super user with key: {}",
args.name,
key.as_ref()
);
Ok(())
}

Expand Down
64 changes: 27 additions & 37 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ use axum::{
http::request::Parts,
TypedHeader,
};
use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::claims::{Scope, ScopeBuilder};
use shuttle_common::{
claims::{Scope, ScopeBuilder},
ApiKey,
};
use sqlx::{query, Row, SqlitePool};
use tracing::{trace, Span};

Expand All @@ -19,7 +21,7 @@ use crate::{api::UserManagerState, error::Error};
pub trait UserManagement: Send + Sync {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error>;
async fn get_user(&self, name: AccountName) -> Result<User, Error>;
async fn get_user_by_key(&self, key: Key) -> Result<User, Error>;
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error>;
}

#[derive(Clone)]
Expand All @@ -30,7 +32,7 @@ pub struct UserManager {
#[async_trait]
impl UserManagement for UserManager {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error> {
let key = Key::new_random();
let key = ApiKey::generate();

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind(&name)
Expand All @@ -55,7 +57,7 @@ impl UserManagement for UserManager {
.ok_or(Error::UserNotFound)
}

async fn get_user_by_key(&self, key: Key) -> Result<User, Error> {
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE key = ?1")
.bind(&key)
.fetch_optional(&self.pool)
Expand All @@ -72,7 +74,7 @@ impl UserManagement for UserManager {
#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
pub struct User {
pub name: AccountName,
pub key: Key,
pub key: ApiKey,
pub account_tier: AccountTier,
}

Expand All @@ -81,7 +83,7 @@ impl User {
self.account_tier == AccountTier::Admin
}

pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self {
pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self {
Self {
name,
key,
Expand All @@ -104,9 +106,9 @@ where
let user_manager: UserManagerState = UserManagerState::from_ref(state);

let user = user_manager
.get_user_by_key(key)
.get_user_by_key(key.as_ref().clone())
.await
// Absord any error into `Unauthorized`
// Absorb any error into `Unauthorized`
.map_err(|_| Error::Unauthorized)?;

// Record current account name for tracing purposes
Expand All @@ -120,16 +122,21 @@ impl From<User> for shuttle_common::models::user::Response {
fn from(user: User) -> Self {
Self {
name: user.name.to_string(),
key: user.key.to_string(),
key: user.key.as_ref().to_string(),
account_tier: user.account_tier.to_string(),
}
}
}

#[derive(Clone, Debug, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct Key(String);
/// A wrapper around [ApiKey] so we can implement [FromRequestParts]
/// for it.
pub struct Key(ApiKey);

impl AsRef<ApiKey> for Key {
fn as_ref(&self) -> &ApiKey {
&self.0
}
}

#[async_trait]
impl<S> FromRequestParts<S> for Key
Expand All @@ -142,31 +149,14 @@ where
let key = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| Error::KeyMissing)
.and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?;
.and_then(|TypedHeader(Authorization(bearer))| {
let bearer = bearer.token().trim();
ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)
})?;

trace!(%key, "got bearer key");

Ok(key)
}
}

impl std::fmt::Display for Key {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl FromStr for Key {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}
trace!("got bearer key");

impl Key {
pub fn new_random() -> Self {
Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16))
Ok(Key(key))
}
}

Expand Down
2 changes: 1 addition & 1 deletion auth/tests/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn convert_api_key_to_jwt() {
// GET /auth/key with invalid bearer token.
let request = Request::builder()
.uri("/auth/key")
.header(AUTHORIZATION, "Bearer notadmin")
.header(AUTHORIZATION, "Bearer ndh9z58jttoefake")
.body(Body::empty())
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion auth/tests/api/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use shuttle_auth::{sqlite_init, ApiBuilder};
use sqlx::query;
use tower::ServiceExt;

pub(crate) const ADMIN_KEY: &str = "my-api-key";
pub(crate) const ADMIN_KEY: &str = "ndh9z58jttoes3qv";

pub(crate) struct TestApp {
pub router: Router,
Expand Down
Loading