Skip to content

Commit

Permalink
feat: ApiKey newtype to ensure key is always valid format (#835)
Browse files Browse the repository at this point in the history
* feat: ensure API key is valid

* feat: use ApiKey in auth

* refactor: clean up tests

* refactor: don't allocate in parse unless it succeeds

* fix: clippy

* fix: missing anyhow

* feat: impl debug/display for apikey
  • Loading branch information
oddgrd authored May 4, 2023
1 parent 0ec6509 commit fae2733
Show file tree
Hide file tree
Showing 14 changed files with 259 additions and 114 deletions.
6 changes: 4 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ Before we can login to our local instance of shuttle, we need to create a user.
The following command inserts a user into the `auth` state with admin privileges:

```bash
docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key test-key
# the --key needs to be 16 alphanumeric characters
docker compose --file docker-compose.rendered.yml --project-name shuttle-dev exec auth /usr/local/bin/service --state=/var/lib/shuttle-auth init --name admin --key dh9z58jttoes3qvt
```

Login to shuttle service in a new terminal window from the root of the shuttle directory:

```bash
cargo run --bin cargo-shuttle -- login --api-key "test-key"
# the --api-kei should be the same one you inserted in the auth state
cargo run --bin cargo-shuttle -- login --api-key "dh9z58jttoes3qvt"
```

The [shuttle examples](https://github.com/shuttle-hq/examples) are linked to the main repo as a [git submodule](https://git-scm.com/book/en/v2/Git-Tools-Submodules), to initialize it run the following commands:
Expand Down
81 changes: 80 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tracing-subscriber = { workspace = true }

[dependencies.shuttle-common]
workspace = true
features = ["backend", "models"]
features = ["backend", "models", "persist"]

[dev-dependencies]
axum-extra = { version = "0.7.1", features = ["cookie"] }
Expand Down
2 changes: 1 addition & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub(crate) async fn convert_key(
let User {
name, account_tier, ..
} = user_manager
.get_user_by_key(key.clone())
.get_user_by_key(key.as_ref().clone())
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;

Expand Down
16 changes: 9 additions & 7 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod user;
use std::{io, str::FromStr, time::Duration};

use args::StartArgs;
use shuttle_common::ApiKey;
use sqlx::{
migrate::Migrator,
query,
Expand All @@ -15,10 +16,7 @@ use sqlx::{
};
use tracing::info;

use crate::{
api::serve,
user::{AccountTier, Key},
};
use crate::{api::serve, user::AccountTier};
pub use api::ApiBuilder;
pub use args::{Args, Commands, InitArgs};

Expand All @@ -41,8 +39,8 @@ pub async fn start(pool: SqlitePool, args: StartArgs) -> io::Result<()> {

pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(ref key) => Key::from_str(key).unwrap(),
None => Key::new_random(),
Some(ref key) => ApiKey::parse(key).unwrap(),
None => ApiKey::generate(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
Expand All @@ -53,7 +51,11 @@ pub async fn init(pool: SqlitePool, args: InitArgs) -> io::Result<()> {
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!("`{}` created as super user with key: {key}", args.name);
println!(
"`{}` created as super user with key: {}",
args.name,
key.as_ref()
);
Ok(())
}

Expand Down
64 changes: 27 additions & 37 deletions auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ use axum::{
http::request::Parts,
TypedHeader,
};
use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::claims::{Scope, ScopeBuilder};
use shuttle_common::{
claims::{Scope, ScopeBuilder},
ApiKey,
};
use sqlx::{query, Row, SqlitePool};
use tracing::{trace, Span};

Expand All @@ -19,7 +21,7 @@ use crate::{api::UserManagerState, error::Error};
pub trait UserManagement: Send + Sync {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error>;
async fn get_user(&self, name: AccountName) -> Result<User, Error>;
async fn get_user_by_key(&self, key: Key) -> Result<User, Error>;
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error>;
}

#[derive(Clone)]
Expand All @@ -30,7 +32,7 @@ pub struct UserManager {
#[async_trait]
impl UserManagement for UserManager {
async fn create_user(&self, name: AccountName, tier: AccountTier) -> Result<User, Error> {
let key = Key::new_random();
let key = ApiKey::generate();

query("INSERT INTO users (account_name, key, account_tier) VALUES (?1, ?2, ?3)")
.bind(&name)
Expand All @@ -55,7 +57,7 @@ impl UserManagement for UserManager {
.ok_or(Error::UserNotFound)
}

async fn get_user_by_key(&self, key: Key) -> Result<User, Error> {
async fn get_user_by_key(&self, key: ApiKey) -> Result<User, Error> {
query("SELECT account_name, key, account_tier FROM users WHERE key = ?1")
.bind(&key)
.fetch_optional(&self.pool)
Expand All @@ -72,7 +74,7 @@ impl UserManagement for UserManager {
#[derive(Clone, Deserialize, PartialEq, Eq, Serialize, Debug)]
pub struct User {
pub name: AccountName,
pub key: Key,
pub key: ApiKey,
pub account_tier: AccountTier,
}

Expand All @@ -81,7 +83,7 @@ impl User {
self.account_tier == AccountTier::Admin
}

pub fn new(name: AccountName, key: Key, account_tier: AccountTier) -> Self {
pub fn new(name: AccountName, key: ApiKey, account_tier: AccountTier) -> Self {
Self {
name,
key,
Expand All @@ -104,9 +106,9 @@ where
let user_manager: UserManagerState = UserManagerState::from_ref(state);

let user = user_manager
.get_user_by_key(key)
.get_user_by_key(key.as_ref().clone())
.await
// Absord any error into `Unauthorized`
// Absorb any error into `Unauthorized`
.map_err(|_| Error::Unauthorized)?;

// Record current account name for tracing purposes
Expand All @@ -120,16 +122,21 @@ impl From<User> for shuttle_common::models::user::Response {
fn from(user: User) -> Self {
Self {
name: user.name.to_string(),
key: user.key.to_string(),
key: user.key.as_ref().to_string(),
account_tier: user.account_tier.to_string(),
}
}
}

#[derive(Clone, Debug, sqlx::Type, PartialEq, Hash, Eq, Serialize, Deserialize)]
#[serde(transparent)]
#[sqlx(transparent)]
pub struct Key(String);
/// A wrapper around [ApiKey] so we can implement [FromRequestParts]
/// for it.
pub struct Key(ApiKey);

impl AsRef<ApiKey> for Key {
fn as_ref(&self) -> &ApiKey {
&self.0
}
}

#[async_trait]
impl<S> FromRequestParts<S> for Key
Expand All @@ -142,31 +149,14 @@ where
let key = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| Error::KeyMissing)
.and_then(|TypedHeader(Authorization(bearer))| bearer.token().trim().parse())?;
.and_then(|TypedHeader(Authorization(bearer))| {
let bearer = bearer.token().trim();
ApiKey::parse(bearer).map_err(|_| Self::Rejection::Unauthorized)
})?;

trace!(%key, "got bearer key");

Ok(key)
}
}

impl std::fmt::Display for Key {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

impl FromStr for Key {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.to_string()))
}
}
trace!("got bearer key");

impl Key {
pub fn new_random() -> Self {
Self(Alphanumeric.sample_string(&mut rand::thread_rng(), 16))
Ok(Key(key))
}
}

Expand Down
2 changes: 1 addition & 1 deletion auth/tests/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn convert_api_key_to_jwt() {
// GET /auth/key with invalid bearer token.
let request = Request::builder()
.uri("/auth/key")
.header(AUTHORIZATION, "Bearer notadmin")
.header(AUTHORIZATION, "Bearer ndh9z58jttoefake")
.body(Body::empty())
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion auth/tests/api/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use shuttle_auth::{sqlite_init, ApiBuilder};
use sqlx::query;
use tower::ServiceExt;

pub(crate) const ADMIN_KEY: &str = "my-api-key";
pub(crate) const ADMIN_KEY: &str = "ndh9z58jttoes3qv";

pub(crate) struct TestApp {
pub router: Router,
Expand Down
Loading

0 comments on commit fae2733

Please sign in to comment.