Skip to content

Commit

Permalink
Merge pull request #15 from Lachstec/feature/clab-testing
Browse files Browse the repository at this point in the history
Implement skipping of certificate validation for testing purposes
  • Loading branch information
Lachstec authored Feb 27, 2024
2 parents 337e7c4 + 4f3a337 commit 1971774
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 107 deletions.
16 changes: 15 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ homepage = "https://github.com/Lachstec/ginmi"
[lib]
doctest = false

[features]
dangerous_configuration = ["dep:hyper-rustls", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"]

[dependencies]
tokio = { version = "1.35.1", features = ["rt-multi-thread", "macros"] }
prost = "0.12.3"
Expand All @@ -25,7 +28,18 @@ thiserror = "1.0.56"
tower-service = "0.3.2"
# Needs to match tonics version of http, else implementations of the Service trait break.
http = "0.2.0"
tower = "0.4.13"
tower = "0.4"

# Dependencies for dangerous configuration
hyper = { version = "0.14", features = ["http2"] }
hyper-rustls = { version = "0.24.0", optional = true, features = ["http2"] }
tower-http = { version = "0.4", optional = true}
rustls-pemfile = { version = "1", optional = true }
tokio-rustls = { version = "0.24.0", optional = true, features = ["dangerous_configuration"] }

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dev-dependencies]
tokio-test = "0.4.3"
Expand Down
3 changes: 2 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ fn main() {
"proto/google.proto",
],
&[proto_dir],
).expect("Failed to compile protobuf files");
)
.expect("Failed to compile protobuf files");
}
72 changes: 19 additions & 53 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,30 @@
use http::{HeaderValue, Request};
use std::error::Error;
use std::sync::Arc;
use std::task::{Context, Poll};
use tonic::codegen::Body;
use tower_service::Service;
use tonic::metadata::AsciiMetadataValue;
use tonic::service::Interceptor;
use tonic::{Request, Status};

/// Service that injects username and password into the request metadata
#[derive(Debug, Clone)]
pub struct AuthService<S> {
inner: S,
username: Option<Arc<HeaderValue>>,
password: Option<Arc<HeaderValue>>,
pub struct AuthInterceptor {
username: AsciiMetadataValue,
password: AsciiMetadataValue,
}

impl<S> AuthService<S> {
#[inline]
pub fn new(
inner: S,
username: Option<Arc<HeaderValue>>,
password: Option<Arc<HeaderValue>>,
) -> Self {
impl AuthInterceptor {
pub fn new(username: Option<AsciiMetadataValue>, password: Option<AsciiMetadataValue>) -> Self {
Self {
inner,
username,
password,
username: username.unwrap_or(AsciiMetadataValue::from_static("")),
password: password.unwrap_or(AsciiMetadataValue::from_static("")),
}
}
}

/// Implementation of Service so that it plays nicely with tonic.
/// Trait bounds have to match those specified on [`tonic::client::GrpcService`]
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AuthService<S>
where
S: Service<Request<ReqBody>, Response = ResBody>,
S::Error:,
ResBody: Body,
<ResBody as Body>::Error: Into<Box<dyn Error + Send + Sync>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

#[inline]
fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
if let Some(user) = &self.username {
if let Some(pass) = &self.password {
request
.headers_mut()
.insert("username", user.as_ref().clone());
request
.headers_mut()
.insert("password", pass.as_ref().clone());
}
}

self.inner.call(request)
impl Interceptor for AuthInterceptor {
fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
request
.metadata_mut()
.insert("username", self.username.clone());
request
.metadata_mut()
.insert("password", self.password.clone());
Ok(request)
}
}
14 changes: 6 additions & 8 deletions src/client/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::Client;
use crate::gen::gnmi::CapabilityResponse;
use crate::gen::gnmi::ModelData;

Expand All @@ -7,7 +6,7 @@ pub use crate::gen::gnmi::Encoding;
/// Capabilities of a given gNMI Target device.
///
/// Contains information about the capabilities that supported by a gNMI Target device.
/// Obtained via [`Client::capabilities`].
/// Obtained via [Client::capabilities](super::Client::capabilities).
#[derive(Debug, Clone)]
pub struct Capabilities(pub CapabilityResponse);

Expand All @@ -16,7 +15,7 @@ impl<'a> Capabilities {
///
/// # Examples
/// ```rust
/// # use ginmi::{Client, Capabilities};
/// # use ginmi::client::{Client, Capabilities};
/// # fn main() -> std::io::Result<()> {
/// # tokio_test::block_on(async {
/// # const CERT: &str = "CA Certificate";
Expand Down Expand Up @@ -46,7 +45,7 @@ impl<'a> Capabilities {
///
/// # Examples
/// ```rust
/// # use ginmi::{Client, Capabilities};
/// # use ginmi::client::{Client, Capabilities};
/// # fn main() -> std::io::Result<()> {
/// # tokio_test::block_on(async {
/// # const CERT: &str = "CA Certificate";
Expand All @@ -71,7 +70,7 @@ impl<'a> Capabilities {
self.0.supported_models.contains(&ModelData {
name: name.to_string(),
organization: organization.to_string(),
version: version.to_string()
version: version.to_string(),
})
}

Expand All @@ -82,7 +81,7 @@ impl<'a> Capabilities {
///
/// # Examples
/// ```rust
/// # use ginmi::{Client, Capabilities, Encoding};
/// # use ginmi::client::{Client, Capabilities, Encoding};
/// # fn main() -> std::io::Result<()> {
/// # tokio_test::block_on(async {
/// # const CERT: &str = "CA Certificate";
Expand All @@ -106,10 +105,9 @@ impl<'a> Capabilities {
Encoding::Bytes => 1,
Encoding::Proto => 2,
Encoding::Ascii => 3,
Encoding::JsonIetf => 4
Encoding::JsonIetf => 4,
};

self.0.supported_encodings.contains(&enc)
}
}

73 changes: 45 additions & 28 deletions src/client/client.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
use crate::auth::AuthService;
use super::capabilities::Capabilities;
#[cfg(feature = "dangerous_configuration")]
use super::dangerous::DangerousClientBuilder;
use crate::auth::AuthInterceptor;
use crate::error::GinmiError;
use crate::gen::gnmi::g_nmi_client::GNmiClient;
use crate::gen::gnmi::CapabilityRequest;
use super::capabilities::Capabilities;
use http::HeaderValue;
use hyper::body::Bytes;
use std::str::FromStr;
use std::sync::Arc;
use tonic::codegen::{Body, InterceptedService, StdError};
use tonic::metadata::AsciiMetadataValue;
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Uri};

/// Provides the main functionality of connection to a target device
/// and manipulating configuration or querying telemetry.
#[derive(Debug, Clone)]
pub struct Client {
inner: GNmiClient<AuthService<Channel>>,
pub struct Client<T> {
pub(crate) inner: GNmiClient<T>,
}

impl<'a> Client {
impl<'a> Client<InterceptedService<Channel, AuthInterceptor>> {
/// Create a [`ClientBuilder`] that can create [`Client`]s.
pub fn builder(target: &'a str) -> ClientBuilder<'a> {
ClientBuilder::new(target)
}
}

impl<T> Client<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
/// Returns information from the target device about its capabilities
/// according to the [gNMI Specification Section 3.2.2](https://github.com/openconfig/reference/blob/master/rpc/gnmi/gnmi-specification.md#322-the-capabilityresponse-message)
///
Expand Down Expand Up @@ -48,17 +59,17 @@ impl<'a> Client {

#[derive(Debug, Copy, Clone)]
pub struct Credentials<'a> {
username: &'a str,
password: &'a str,
pub(crate) username: &'a str,
pub(crate) password: &'a str,
}

/// Builder for [`Client`]s
///
/// Used to configure and create instances of [`Client`].
#[derive(Debug, Clone)]
pub struct ClientBuilder<'a> {
target: &'a str,
creds: Option<Credentials<'a>>,
pub(crate) target: &'a str,
pub(crate) creds: Option<Credentials<'a>>,
tls_settings: Option<ClientTlsConfig>,
}

Expand Down Expand Up @@ -87,14 +98,23 @@ impl<'a> ClientBuilder<'a> {
self
}

#[cfg(feature = "dangerous_configuration")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous_configuration")))]
/// Access configuration options that are dangerous and require extra care.
pub fn dangerous(self) -> DangerousClientBuilder<'a> {
DangerousClientBuilder::from(self)
}

/// Consume the [`ClientBuilder`] and return a [`Client`].
///
/// # Errors
/// - Returns [`GinmiError::InvalidUriError`] if specified target is not a valid URI.
/// - Returns [`GinmiError::TransportError`] if the TLS-Settings are invalid.
/// - Returns [`GinmiError::TransportError`] if a connection to the target could not be
/// established.
pub async fn build(self) -> Result<Client, GinmiError> {
pub async fn build(
self,
) -> Result<Client<InterceptedService<Channel, AuthInterceptor>>, GinmiError> {
let uri = match Uri::from_str(self.target) {
Ok(u) => u,
Err(e) => return Err(GinmiError::InvalidUriError(e.to_string())),
Expand All @@ -107,22 +127,17 @@ impl<'a> ClientBuilder<'a> {
}

let channel = endpoint.connect().await?;

return if let Some(creds) = self.creds {
let user_header = HeaderValue::from_str(creds.username)?;
let pass_header = HeaderValue::from_str(creds.password)?;
Ok(Client {
inner: GNmiClient::new(AuthService::new(
channel,
Some(Arc::new(user_header)),
Some(Arc::new(pass_header)),
)),
})
} else {
Ok(Client {
inner: GNmiClient::new(AuthService::new(channel, None, None)),
})
let (username, password) = match self.creds {
Some(c) => (
Some(AsciiMetadataValue::from_str(c.username)?),
Some(AsciiMetadataValue::from_str(c.password)?),
),
None => (None, None),
};

Ok(Client {
inner: GNmiClient::with_interceptor(channel, AuthInterceptor::new(username, password)),
})
}
}

Expand All @@ -132,7 +147,9 @@ mod tests {

#[tokio::test]
async fn invalid_uri() {
let client = Client::builder("$$$$").build().await;
let client = Client::<InterceptedService<Channel, AuthInterceptor>>::builder("$$$$")
.build()
.await;
assert!(client.is_err());
}

Expand Down
Loading

0 comments on commit 1971774

Please sign in to comment.