Skip to content

Commit

Permalink
refactor(auth, gateway): use user_id over account_name (#1674)
Browse files Browse the repository at this point in the history
* feat(auth): user_id not null, remove migration insertions

* fix(auth): insert user_id on command inserts

* feat: rename account_name to user_id in most places

* nits

* nit2

* fix: auth tests almost working

* yeet: auth/refresh

* fix: span

* fix: auth tests

* feat: set old account name header for tracing in old deployers

* nit: clarify start_last_deploy

* fix: sql comment

* fix: userid comment

* feat(deployer): use claim instead of user id header for tracing

* nit: remove request.path tracing field

* Revert "nit: remove request.path tracing field"

This reverts commit 0be50c3.

* less clone

* feat(auth): keep get account by name endpoint

* fmt

* ci: unstable

* clippy

* fix: migration drift fixes

* fix: migration drift fixes 2

* revert: migration drift fixes

* fix: endpoint ordering

* test: set empty field on endpoint

* Revert "test: set empty field on endpoint"

This reverts commit bc82a89.
  • Loading branch information
jonaro00 authored Mar 15, 2024
1 parent 27c88d3 commit 4937f4b
Show file tree
Hide file tree
Showing 22 changed files with 325 additions and 414 deletions.
27 changes: 27 additions & 0 deletions auth/migrations/0004_user_ids_part2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- All rows should have user_ids at this point (added in the application logic before this migration was introduced)
ALTER TABLE users
ALTER COLUMN user_id SET NOT NULL;


-- Switch the foreign key(fk) on subscriptions and remove the old fk
ALTER TABLE subscriptions
ADD COLUMN user_id TEXT;

UPDATE subscriptions
SET user_id = users.user_id
FROM users
WHERE subscriptions.account_name = users.account_name;

ALTER TABLE subscriptions
DROP CONSTRAINT subscriptions_account_name_fkey,
ADD FOREIGN KEY (user_id) REFERENCES users (user_id),
ALTER COLUMN user_id SET NOT NULL,
DROP COLUMN account_name,
-- Add back the unique pair constraint
ADD CONSTRAINT user_id_type UNIQUE (user_id, type);


-- Switch the primary key on users
ALTER TABLE users
DROP CONSTRAINT users_pkey,
ADD PRIMARY KEY (user_id);
19 changes: 11 additions & 8 deletions auth/src/api/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use crate::{
};

use super::handlers::{
convert_key, delete_subscription, get_public_key, get_user, health_check, post_subscription,
post_user, put_user_reset_key, refresh_token,
convert_key, delete_subscription, get_public_key, get_user, get_user_by_name,
post_subscription, post_user, put_user_reset_key,
};

pub type UserManagerState = Arc<Box<dyn UserManagement>>;
Expand Down Expand Up @@ -62,24 +62,27 @@ impl Default for ApiBuilder {
impl ApiBuilder {
pub fn new() -> Self {
let router = Router::new()
.route("/", get(health_check))
// health check: 200 OK
.route("/", get(|| async move {}))
.route("/auth/key", get(convert_key))
.route("/auth/refresh", post(refresh_token))
.route("/public-key", get(get_public_key))
.route("/users/:account_name", get(get_user))
// used by console to get user based on auth0 name
.route("/users/name/:account_name", get(get_user_by_name))
// users are created based on auth0 name by console
.route("/users/:account_name/:account_tier", post(post_user))
.route("/users/:user_id", get(get_user))
.route("/users/reset-api-key", put(put_user_reset_key))
.route("/users/:account_name/subscribe", post(post_subscription))
.route("/users/:user_id/subscribe", post(post_subscription))
.route(
"/users/:account_name/subscribe/:subscription_id",
"/users/:user_id/subscribe/:subscription_id",
delete(delete_subscription),
)
.route_layer(from_extractor::<Metrics>())
.layer(
TraceLayer::new(|request| {
request_span!(
request,
request.params.account_name = field::Empty,
request.params.user_id = field::Empty,
request.params.account_tier = field::Empty,
)
})
Expand Down
55 changes: 28 additions & 27 deletions auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::Error,
user::{AccountName, Admin, Key},
user::{Admin, Key},
};
use axum::{
extract::{Path, State},
Expand All @@ -9,33 +9,46 @@ use axum::{
use http::StatusCode;
use shuttle_common::{
claims::{AccountTier, Claim},
models::user::{self, SubscriptionRequest},
models::user::{self, SubscriptionRequest, UserId},
};
use tracing::instrument;
use tracing::{field, instrument, Span};

use super::{
builder::{KeyManagerState, UserManagerState},
RouterState,
};

#[instrument(skip_all, fields(account.name = %account_name))]
#[instrument(skip_all, fields(account.user_id = %user_id))]
pub(crate) async fn get_user(
_: Admin,
State(user_manager): State<UserManagerState>,
Path(account_name): Path<AccountName>,
Path(user_id): Path<UserId>,
) -> Result<Json<user::Response>, Error> {
let user = user_manager.get_user(account_name).await?;
let user = user_manager.get_user(user_id).await?;

Ok(Json(user.into()))
}

#[instrument(skip_all, fields(account.name = %account_name, account.tier = %account_tier))]
#[instrument(skip_all, fields(account.name = %account_name, account.user_id = field::Empty))]
pub(crate) async fn get_user_by_name(
_: Admin,
State(user_manager): State<UserManagerState>,
Path(account_name): Path<String>,
) -> Result<Json<user::Response>, Error> {
let user = user_manager.get_user_by_name(&account_name).await?;
Span::current().record("account.user_id", &user.id);

Ok(Json(user.into()))
}

#[instrument(skip_all, fields(account.name = %account_name, account.tier = %account_tier, account.user_id = field::Empty))]
pub(crate) async fn post_user(
_: Admin,
State(user_manager): State<UserManagerState>,
Path((account_name, account_tier)): Path<(AccountName, AccountTier)>,
Path((account_name, account_tier)): Path<(String, AccountTier)>,
) -> Result<Json<user::Response>, Error> {
let user = user_manager.create_user(account_name, account_tier).await?;
Span::current().record("account.user_id", &user.id);

Ok(Json(user.into()))
}
Expand All @@ -44,24 +57,19 @@ pub(crate) async fn put_user_reset_key(
State(user_manager): State<UserManagerState>,
key: Key,
) -> Result<(), Error> {
let account_name = user_manager.get_user_by_key(key.into()).await?.name;
let user_id = user_manager.get_user_by_key(key.into()).await?.id;

user_manager.reset_key(account_name).await
user_manager.reset_key(user_id).await
}

pub(crate) async fn post_subscription(
_: Admin,
State(user_manager): State<UserManagerState>,
Path(account_name): Path<AccountName>,
Path(user_id): Path<UserId>,
payload: Json<SubscriptionRequest>,
) -> Result<(), Error> {
user_manager
.insert_subscription(
&account_name,
&payload.id,
&payload.r#type,
payload.quantity,
)
.insert_subscription(&user_id, &payload.id, &payload.r#type, payload.quantity)
.await?;

Ok(())
Expand All @@ -70,20 +78,15 @@ pub(crate) async fn post_subscription(
pub(crate) async fn delete_subscription(
_: Admin,
State(user_manager): State<UserManagerState>,
Path((account_name, subscription_id)): Path<(AccountName, String)>,
Path((user_id, subscription_id)): Path<(UserId, String)>,
) -> Result<(), Error> {
user_manager
.delete_subscription(&account_name, &subscription_id)
.delete_subscription(&user_id, &subscription_id)
.await?;

Ok(())
}

// Dummy health-check returning 200 if the auth server is up.
pub(crate) async fn health_check() -> Result<(), Error> {
Ok(())
}

/// Convert a valid API-key bearer token to a JWT.
pub(crate) async fn convert_key(
_: Admin,
Expand All @@ -99,7 +102,7 @@ pub(crate) async fn convert_key(
.map_err(|_| StatusCode::UNAUTHORIZED)?;

let claim = Claim::new(
user.name.to_string(),
user.id.clone(),
user.account_tier.into(),
user.account_tier,
user,
Expand All @@ -112,8 +115,6 @@ pub(crate) async fn convert_key(
Ok(Json(response))
}

pub(crate) async fn refresh_token() {}

pub(crate) async fn get_public_key(State(key_manager): State<KeyManagerState>) -> Vec<u8> {
key_manager.public_key().to_vec()
}
5 changes: 3 additions & 2 deletions auth/src/args.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::net::SocketAddr;

use clap::{Parser, Subcommand};
use shuttle_common::models::user::UserId;

#[derive(Parser, Debug)]
pub struct Args {
Expand Down Expand Up @@ -37,9 +38,9 @@ pub struct StartArgs {

#[derive(clap::Args, Debug, Clone)]
pub struct InitArgs {
/// Name of initial account to create
/// User id of account to create
#[arg(long)]
pub name: String,
pub user_id: UserId,
/// Key to assign to initial account
#[arg(long)]
pub key: Option<String>,
Expand Down
23 changes: 4 additions & 19 deletions auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,18 @@ pub async fn init(pool: PgPool, args: InitArgs, tier: AccountTier) -> io::Result
None => ApiKey::generate(),
};

query("INSERT INTO users (account_name, key, account_tier) VALUES ($1, $2, $3)")
.bind(&args.name)
query("INSERT INTO users (account_name, key, account_tier, user_id) VALUES ($1, $2, $3, $4)")
.bind("")
.bind(&key)
.bind(tier.to_string())
.bind(&args.user_id)
.execute(&pool)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!(
"`{}` created as {} with key: {}",
args.name,
args.user_id,
tier,
key.as_ref()
);
Expand All @@ -68,21 +69,5 @@ pub async fn pgpool_init(db_uri: &str) -> io::Result<PgPool> {
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

// Post-migration logic for 0003.
// This is done here to skip the need for postgres extensions.
let names: Vec<(String,)> =
sqlx::query_as("SELECT account_name FROM users WHERE user_id IS NULL")
.fetch_all(&pool)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
for (name,) in names {
sqlx::query("UPDATE users SET user_id = $1 WHERE account_name = $2")
.bind(User::new_user_id())
.bind(name)
.execute(&pool)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
}

Ok(pool)
}
Loading

0 comments on commit 4937f4b

Please sign in to comment.