Skip to content

Commit

Permalink
feat: gateway init (#363)
Browse files Browse the repository at this point in the history
* refactor: split off start command

* refactor: move migration to main

* feat: init command

* refactor: clap 4 convention

* refactor: StartArgs
  • Loading branch information
chesedo authored Sep 30, 2022
1 parent 45d4976 commit d434f19
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 46 deletions.
53 changes: 45 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ axum = { version = "0.5.8", features = [ "headers" ] }
base64 = "0.13"
bollard = "0.13"
chrono = "0.4"
clap = { version = "3.1", features = [ "derive" ] }
clap = { version = "4.0.0", features = [ "derive" ] }
convert_case = "0.5.0"
futures = "0.3.21"
http = "0.2.8"
Expand Down
47 changes: 36 additions & 11 deletions gateway/src/args.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,55 @@
use std::net::SocketAddr;

use clap::Parser;
use clap::{Parser, Subcommand};

#[derive(Parser, Debug, Clone)]
use crate::auth::Key;

#[derive(Parser, Debug)]
pub struct Args {
/// Uri to the `.sqlite` file used to store state
#[arg(long, default_value = "./gateway.sqlite")]
pub state: String,

#[command(subcommand)]
pub command: Commands,
}

#[derive(Subcommand, Debug)]
pub enum Commands {
Start(StartArgs),
Init(InitArgs),
}

#[derive(clap::Args, Debug, Clone)]
pub struct StartArgs {
/// Address to bind the control plane to
#[clap(long, default_value = "127.0.0.1:8001")]
#[arg(long, default_value = "127.0.0.1:8001")]
pub control: SocketAddr,
/// Address to bind the user plane to
#[clap(long, default_value = "127.0.0.1:8000")]
#[arg(long, default_value = "127.0.0.1:8000")]
pub user: SocketAddr,
/// Default image to deploy user runtimes into
#[clap(long, default_value = "public.ecr.aws/shuttle/deployer:latest")]
#[arg(long, default_value = "public.ecr.aws/shuttle/deployer:latest")]
pub image: String,
/// Prefix to add to the name of all docker resources managed by
/// this service
#[clap(long, default_value = "shuttle_prod_")]
#[arg(long, default_value = "shuttle_prod_")]
pub prefix: String,
/// The address at which an active runtime container will find
/// the provisioner service
#[clap(long, default_value = "provisioner")]
#[arg(long, default_value = "provisioner")]
pub provisioner_host: String,
/// The Docker Network name in which to deploy user runtimes
#[clap(long, default_value = "shuttle_default")]
#[arg(long, default_value = "shuttle_default")]
pub network_name: String,
/// Uri to the `.sqlite` file used to store state
#[clap(long, default_value = "./gateway.sqlite")]
pub state: String,
}

#[derive(clap::Args, Debug, Clone)]
pub struct InitArgs {
/// Name of initial account to create
#[arg(long)]
pub name: String,
/// Key to assign to initial account
#[arg(long)]
pub key: Option<Key>,
}
50 changes: 47 additions & 3 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
use clap::Parser;
use futures::prelude::*;
use shuttle_gateway::api::make_api;
use shuttle_gateway::args::Args;
use shuttle_gateway::args::{Args, Commands, InitArgs};
use shuttle_gateway::auth::Key;
use shuttle_gateway::proxy::make_proxy;
use shuttle_gateway::service::GatewayService;
use shuttle_gateway::worker::Worker;
use shuttle_gateway::{api::make_api, args::StartArgs};
use sqlx::migrate::{MigrateDatabase, Migrator};
use sqlx::{query, Sqlite, SqlitePool};
use std::io;
use std::path::Path;
use std::sync::Arc;
use tracing::{error, info, trace};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");

#[tokio::main]
async fn main() -> io::Result<()> {
let args = Args::parse();
Expand All @@ -32,7 +38,28 @@ async fn main() -> io::Result<()> {
.with(opentelemetry)
.init();

let gateway = Arc::new(GatewayService::init(args.clone()).await);
if !Path::new(&args.state).exists() {
Sqlite::create_database(&args.state).await.unwrap();
}

info!(
"state db: {}",
std::fs::canonicalize(&args.state)
.unwrap()
.to_string_lossy()
);
let db = SqlitePool::connect(&args.state).await.unwrap();

MIGRATIONS.run(&db).await.unwrap();

match args.command {
Commands::Start(start_args) => start(db, start_args).await,
Commands::Init(init_args) => init(db, init_args).await,
}
}

async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> {
let gateway = Arc::new(GatewayService::init(args.clone(), db).await);

let worker = Worker::new(Arc::clone(&gateway));
gateway.set_sender(Some(worker.sender())).await.unwrap();
Expand Down Expand Up @@ -61,3 +88,20 @@ async fn main() -> io::Result<()> {

Ok(())
}

async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> {
let key = match args.key {
Some(key) => key,
None => Key::new_random(),
};

query("INSERT INTO accounts (account_name, key, super_user) VALUES (?1, ?2, 1)")
.bind(&args.name)
.bind(&key)
.execute(&db)
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

println!("`{}` created as super user with key: {key}", args.name);
Ok(())
}
29 changes: 6 additions & 23 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::path::Path as StdPath;
use std::sync::Arc;

use axum::body::Body;
Expand All @@ -15,23 +14,21 @@ use hyper_reverse_proxy::ReverseProxy;
use once_cell::sync::Lazy;
use rand::distributions::{Alphanumeric, DistString};
use sqlx::error::DatabaseError;
use sqlx::migrate::{MigrateDatabase, Migrator};
use sqlx::sqlite::{Sqlite, SqlitePool};
use sqlx::sqlite::SqlitePool;
use sqlx::types::Json as SqlxJson;
use sqlx::{query, Error as SqlxError, Row};
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tracing::{debug, error, info};
use tracing::{debug, error};

use crate::args::Args;
use crate::args::StartArgs;
use crate::auth::{Key, User};
use crate::project::{self, Project};
use crate::worker::Work;
use crate::{AccountName, Context, Error, ErrorKind, ProjectName, Refresh, Service};

static PROXY_CLIENT: Lazy<ReverseProxy<HttpConnector<GaiResolver>>> =
Lazy::new(|| ReverseProxy::new(Client::new()));
static MIGRATIONS: Migrator = sqlx::migrate!("./migrations");

impl From<SqlxError> for Error {
fn from(err: SqlxError) -> Self {
Expand Down Expand Up @@ -59,8 +56,8 @@ impl<'d> ContainerSettingsBuilder<'d> {
}
}

pub async fn from_args(self, args: &Args) -> ContainerSettings {
let Args {
pub async fn from_args(self, args: &StartArgs) -> ContainerSettings {
let StartArgs {
prefix,
network_name,
provisioner_host,
Expand Down Expand Up @@ -182,27 +179,13 @@ impl GatewayService {
///
/// * `args` - The [`Args`] with which the service was
/// started. Will be passed as [`Context`] to workers and state.
pub async fn init(args: Args) -> Self {
pub async fn init(args: StartArgs, db: SqlitePool) -> Self {
let docker = Docker::connect_with_local_defaults().unwrap();

let container_settings = ContainerSettings::builder(&docker).from_args(&args).await;

let provider = GatewayContextProvider::new(docker, container_settings);

let state = args.state;

if !StdPath::new(&state).exists() {
Sqlite::create_database(&state).await.unwrap();
}

info!(
"state db: {}",
std::fs::canonicalize(&state).unwrap().to_string_lossy()
);
let db = SqlitePool::connect(&state).await.unwrap();

MIGRATIONS.run(&db).await.unwrap();

let sender = Mutex::new(None);

Self {
Expand Down

0 comments on commit d434f19

Please sign in to comment.