Skip to content

Commit

Permalink
fix(resource-recorder)!: disable service id endpoint (#1644)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaro00 authored Feb 29, 2024
1 parent b637bef commit 0b97911
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 148 deletions.
47 changes: 42 additions & 5 deletions common/src/backends/client/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl Client {
/// Interact with all the data relating to projects
#[allow(async_fn_in_trait)]
pub trait ProjectsDal {
/// Get a user project
async fn get_user_project(
&self,
user_token: &str,
project_name: &str,
) -> Result<models::project::Response, Error>;

/// Check the HEAD of a user project
async fn head_user_project(&self, user_token: &str, project_name: &str) -> Result<bool, Error>;

/// Get the projects that belong to a user
async fn get_user_projects(
&self,
Expand All @@ -56,22 +66,49 @@ pub trait ProjectsDal {
}

impl ProjectsDal for Client {
#[instrument(skip_all)]
async fn get_user_project(
&self,
user_token: &str,
project_name: &str,
) -> Result<models::project::Response, Error> {
self.public_client
.request(
Method::GET,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await
}

#[instrument(skip_all)]
async fn head_user_project(&self, user_token: &str, project_name: &str) -> Result<bool, Error> {
self.public_client
.request_raw(
Method::HEAD,
format!("projects/{}", project_name).as_str(),
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await?;

Ok(true)
}

#[instrument(skip_all)]
async fn get_user_projects(
&self,
user_token: &str,
) -> Result<Vec<models::project::Response>, Error> {
let projects = self
.public_client
self.public_client
.request(
Method::GET,
"projects",
None::<()>,
Some(Authorization::bearer(user_token).expect("to build an authorization bearer")),
)
.await?;

Ok(projects)
.await
}
}

Expand Down
30 changes: 19 additions & 11 deletions common/src/backends/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use bytes::Bytes;
use headers::{ContentType, Header, HeaderMapExt};
use http::{Method, Request, StatusCode, Uri};
use hyper::{body, client::HttpConnector, Body, Client};
Expand Down Expand Up @@ -36,38 +37,48 @@ pub struct ServicesApiClient {
}

impl ServicesApiClient {
/// Make a new client that connects to the given endpoint
fn new(base: Uri) -> Self {
Self {
client: Client::new(),
base,
}
}

/// Make a get request to a path on the service
pub async fn request<B: Serialize, T: DeserializeOwned, H: Header>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
) -> Result<T, Error> {
let bytes = self.request_raw(method, path, body, extra_header).await?;
let json = serde_json::from_slice(&bytes)?;

Ok(json)
}

pub async fn request_raw<B: Serialize, H: Header>(
&self,
method: Method,
path: &str,
body: Option<B>,
extra_header: Option<H>,
) -> Result<Bytes, Error> {
let uri = format!("{}{path}", self.base);
trace!(uri, "calling inner service");

let mut req = Request::builder().method(method).uri(uri);
let headers = req
.headers_mut()
.expect("new request to have mutable headers");

headers.typed_insert(ContentType::json());

if let Some(extra_header) = extra_header {
headers.typed_insert(extra_header);
}
if body.is_some() {
headers.typed_insert(ContentType::json());
}

let cx = Span::current().context();

global::get_text_map_propagator(|propagator| {
propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut().unwrap()))
});
Expand All @@ -79,18 +90,15 @@ impl ServicesApiClient {
};

let resp = self.client.request(req?).await?;

trace!(response = ?resp, "Load response");

if resp.status() != StatusCode::OK {
return Err(Error::RequestError(resp.status()));
}

let body = resp.into_body();
let bytes = body::to_bytes(body).await?;
let json = serde_json::from_slice(&bytes)?;
let bytes = body::to_bytes(resp.into_body()).await?;

Ok(json)
Ok(bytes)
}
}

Expand Down
26 changes: 23 additions & 3 deletions common/src/backends/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use tracing::instrument;

use crate::claims::{Claim, Scope};
use crate::claims::{AccountTier, Claim, Scope};

use self::client::{ProjectsDal, ResourceDal};

Expand All @@ -17,6 +17,7 @@ pub mod trace;
pub trait ClaimExt {
/// Verify that the [Claim] has the [Scope::Admin] scope.
fn is_admin(&self) -> bool;
fn is_deployer(&self) -> bool;
/// Verify that the user's current project count is lower than the account limit in [Claim::limits].
fn can_create_project(&self, current_count: u32) -> bool;
/// Verify that the user has permission to provision RDS instances.
Expand All @@ -31,12 +32,21 @@ pub trait ClaimExt {
projects_dal: &G,
project_name: &str,
) -> Result<bool, client::Error>;
/// Verify if the claim subject has ownership of a project.
async fn owns_project_id<G: ProjectsDal>(
&self,
projects_dal: &G,
project_id: &str,
) -> Result<bool, client::Error>;
}

impl ClaimExt for Claim {
fn is_admin(&self) -> bool {
self.scopes.contains(&Scope::Admin)
}
fn is_deployer(&self) -> bool {
self.tier == AccountTier::Deployer
}

fn can_create_project(&self, current_count: u32) -> bool {
self.is_admin() || self.limits.project_limit() > current_count
Expand Down Expand Up @@ -71,7 +81,17 @@ impl ClaimExt for Claim {
project_name: &str,
) -> Result<bool, client::Error> {
let token = self.token.as_ref().expect("token to be set");
let projects = projects_dal.get_user_projects(token).await?;
Ok(projects.iter().any(|project| project.name == project_name))
projects_dal.head_user_project(token, project_name).await
}

#[instrument(skip_all)]
async fn owns_project_id<G: ProjectsDal>(
&self,
projects_dal: &G,
project_id: &str,
) -> Result<bool, client::Error> {
let token = self.token.as_ref().expect("token to be set");
let projects = projects_dal.get_user_project_ids(token).await?;
Ok(projects.iter().any(|id| id == project_id))
}
}
26 changes: 13 additions & 13 deletions deployer/src/persistence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use shuttle_common::{claims::Claim, resource::Type};
use shuttle_proto::{
provisioner::{self, DatabaseRequest},
resource_recorder::{
self, record_request, RecordRequest, ResourceIds, ResourceResponse, ResourcesResponse,
ResultResponse, ServiceResourcesRequest,
self, record_request, ProjectResourcesRequest, RecordRequest, ResourceIds,
ResourceResponse, ResourcesResponse, ResultResponse,
},
};
use sqlx::{
Expand Down Expand Up @@ -364,17 +364,18 @@ impl ResourceManager for Persistence {
service_id: &Ulid,
claim: Claim,
) -> Result<ResourcesResponse> {
let mut req = tonic::Request::new(ServiceResourcesRequest {
service_id: service_id.to_string(),
let mut req = tonic::Request::new(ProjectResourcesRequest {
project_id: self.project_id.to_string(),
});

req.extensions_mut().insert(claim.clone());

info!(%service_id, "Getting resources from resource-recorder");
info!(%self.project_id, "Getting resources from resource-recorder");
let res = self
.resource_recorder_client
.as_mut()
.expect("to have the resource recorder set up")
.get_service_resources(req)
.get_project_resources(req)
.await
.map_err(PersistenceError::ResourceRecorder)
.map(|res| res.into_inner())?;
Expand All @@ -384,8 +385,7 @@ impl ResourceManager for Persistence {
info!("Got no resources from resource-recorder");
// Check if there are cached resources on the local persistence.
let resources: std::result::Result<Vec<Resource>, sqlx::Error> =
sqlx::query_as("SELECT * FROM resources WHERE service_id = ?")
.bind(service_id.to_string())
sqlx::query_as("SELECT * FROM resources")
.fetch_all(&self.pool)
.await;

Expand All @@ -410,17 +410,18 @@ impl ResourceManager for Persistence {
self.insert_resources(local_resources, service_id, claim.clone())
.await?;

let mut req = tonic::Request::new(ServiceResourcesRequest {
service_id: service_id.to_string(),
let mut req = tonic::Request::new(ProjectResourcesRequest {
project_id: self.project_id.to_string(),
});

req.extensions_mut().insert(claim);

info!("Getting resources from resource-recorder again");
let res = self
.resource_recorder_client
.as_mut()
.expect("to have the resource recorder set up")
.get_service_resources(req)
.get_project_resources(req)
.await
.map_err(PersistenceError::ResourceRecorder)
.map(|res| res.into_inner())?;
Expand All @@ -433,8 +434,7 @@ impl ResourceManager for Persistence {
info!("Deleting local resources");
// Now that we know that the resources are in resource-recorder,
// we can safely delete them from here to prevent de-sync issues and to not hinder project deletion
sqlx::query("DELETE FROM resources WHERE service_id = ?")
.bind(service_id.to_string())
sqlx::query("DELETE FROM resources")
.execute(&self.pool)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion proto/resource-recorder.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ service ResourceRecorder {
// Get the resources belonging to a project
rpc GetProjectResources(ProjectResourcesRequest) returns (ResourcesResponse);

// Get the resources belonging to a service
// Discontinued
rpc GetServiceResources(ServiceResourcesRequest) returns (ResourcesResponse);

// Get a resource
Expand Down
4 changes: 2 additions & 2 deletions proto/src/generated/resource_recorder.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 18 additions & 26 deletions provisioner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rand::Rng;
use shuttle_common::backends::auth::VerifyClaim;
use shuttle_common::backends::client::gateway;
use shuttle_common::backends::ClaimExt;
use shuttle_common::claims::Scope;
use shuttle_common::claims::{Claim, Scope};
use shuttle_common::models::project::ProjectName;
pub use shuttle_proto::provisioner::provisioner_server::ProvisionerServer;
use shuttle_proto::provisioner::{
Expand Down Expand Up @@ -460,6 +460,21 @@ impl ShuttleProvisioner {

Ok(DatabaseDeletionResponse {})
}

async fn verify_ownership(&self, claim: &Claim, project_name: &str) -> Result<(), Status> {
if !claim.is_admin()
&& !claim.is_deployer()
&& !claim
.owns_project(&self.gateway_client, project_name)
.await
.map_err(|_| Status::internal("could not verify project ownership"))?
{
let status = Status::permission_denied("the request lacks the authorizations");
error!(error = &status as &dyn std::error::Error);
return Err(status);
}
Ok(())
}
}

#[tonic::async_trait]
Expand All @@ -470,24 +485,12 @@ impl Provisioner for ShuttleProvisioner {
request: Request<DatabaseRequest>,
) -> Result<Response<DatabaseResponse>, Status> {
request.verify(Scope::ResourcesWrite)?;

let claim = request.get_claim()?;

let request = request.into_inner();
if !ProjectName::is_valid(&request.project_name) {
return Err(Status::invalid_argument("invalid project name"));
}

// Check project ownership.
if !claim
.owns_project(&self.gateway_client, &request.project_name)
.await
.map_err(|_| Status::internal("can not verify project ownership"))?
{
let status = Status::permission_denied("the request lacks the authorizations");
error!(error = &status as &dyn std::error::Error);
return Err(status);
}
self.verify_ownership(&claim, &request.project_name).await?;

let db_type = request.db_type.unwrap();

Expand Down Expand Up @@ -539,22 +542,11 @@ impl Provisioner for ShuttleProvisioner {
) -> Result<Response<DatabaseDeletionResponse>, Status> {
request.verify(Scope::ResourcesWrite)?;
let claim = request.get_claim()?;

let request = request.into_inner();
if !ProjectName::is_valid(&request.project_name) {
return Err(Status::invalid_argument("invalid project name"));
}

// Check project ownership.
if !claim
.owns_project(&self.gateway_client, &request.project_name)
.await
.map_err(|_| Status::internal("can not verify project ownership"))?
{
let status = Status::permission_denied("the request lacks the authorizations");
error!(error = &status as &dyn std::error::Error);
return Err(status);
}
self.verify_ownership(&claim, &request.project_name).await?;

let db_type = request.db_type.unwrap();

Expand Down
Loading

0 comments on commit 0b97911

Please sign in to comment.