Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc module: report error on invalid subscription #561

Merged
merged 4 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test-utils/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
use crate::mocks::{Body, HttpResponse, Id, Uri};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Request, Response, Server};
use serde::Serialize;
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
Expand Down Expand Up @@ -95,6 +96,15 @@ pub fn invalid_params(id: Id) -> String {
)
}

pub fn call<T: Serialize>(method: &str, params: Vec<T>, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","method":{},"params":{},"id":{}}}"#,
serde_json::to_string(method).unwrap(),
serde_json::to_string(&params).unwrap(),
serde_json::to_string(&id).unwrap()
)
}

pub fn call_execution_failed(msg: &str, id: Id) -> String {
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{}"}},"id":{}}}"#,
Expand Down
23 changes: 22 additions & 1 deletion types/src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ pub struct RpcError<'a> {
pub id: Id<'a>,
}

impl<'a> RpcError<'a> {
/// Create a new `RpcError`.
pub fn new(error: ErrorObject<'a>, id: Id<'a>) -> Self {
Self { jsonrpc: TwoPointZero, error, id }
}
}

impl<'a> fmt::Display for RpcError<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", serde_json::to_string(&self).expect("infallible; qed"))
Expand All @@ -64,6 +71,13 @@ pub struct ErrorObject<'a> {
pub data: Option<&'a RawValue>,
}

impl<'a> ErrorObject<'a> {
/// Create a new `ErrorObject` with optional data.
pub fn new(code: ErrorCode, data: Option<&'a RawValue>) -> ErrorObject<'a> {
Self { code, message: code.message(), data }
}
}

impl<'a> From<ErrorCode> for ErrorObject<'a> {
fn from(code: ErrorCode) -> Self {
Self { code, message: code.message(), data: None }
Expand All @@ -73,7 +87,7 @@ impl<'a> From<ErrorCode> for ErrorObject<'a> {
impl<'a> PartialEq for ErrorObject<'a> {
fn eq(&self, other: &Self) -> bool {
let this_raw = self.data.map(|r| r.get());
let other_raw = self.data.map(|r| r.get());
let other_raw = other.data.map(|r| r.get());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated bug I found

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. No test for this eh?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed

self.code == other.code && self.message == other.message && this_raw == other_raw
}
}
Expand All @@ -98,6 +112,8 @@ pub const SERVER_IS_BUSY_CODE: i32 = -32604;
pub const CALL_EXECUTION_FAILED_CODE: i32 = -32000;
/// Unknown error.
pub const UNKNOWN_ERROR_CODE: i32 = -32001;
/// Invalid subscription error code.
pub const INVALID_SUBSCRIPTION_CODE: i32 = -32002;

/// Parse error message
pub const PARSE_ERROR_MSG: &str = "Parse error";
Expand Down Expand Up @@ -212,6 +228,11 @@ impl serde::Serialize for ErrorCode {
}
}

/// Create a invalid subscription ID error.
pub fn invalid_subscription_err(data: Option<&RawValue>) -> ErrorObject {
ErrorObject::new(ErrorCode::ServerError(INVALID_SUBSCRIPTION_CODE), data)
}

#[cfg(test)]
mod tests {
use super::{ErrorCode, ErrorObject, Id, RpcError, TwoPointZero};
Expand Down
17 changes: 13 additions & 4 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec
use beef::Cow;
use futures_channel::{mpsc, oneshot};
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
use jsonrpsee_types::to_json_raw_value;
use jsonrpsee_types::v2::error::{invalid_subscription_err, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
error::{Error, SubscriptionClosedError},
traits::ToRpcParams,
Expand Down Expand Up @@ -587,7 +589,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
err,
id
);
send_error(id, method_sink, ErrorCode::ServerError(-1).into());
send_error(id, method_sink, ErrorCode::ServerError(CALL_EXECUTION_FAILED_CODE).into());
}
})),
);
Expand All @@ -605,12 +607,18 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
unsubscribe_method_name,
id
);
send_error(id, tx, ErrorCode::ServerError(-1).into());
let err = to_json_raw_value(&"Invalid subscription ID type, must be integer").ok();
send_error(id, tx, invalid_subscription_err(err.as_deref()));
return;
}
};
subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id });
send_response(id, tx, "Unsubscribed", max_response_size);

if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some() {
send_response(id, tx, "Unsubscribed", max_response_size);
} else {
let err = to_json_raw_value(&format!("Invalid subscription ID={}", sub_id)).ok();
send_error(id, tx, invalid_subscription_err(err.as_deref()))
}
})),
);
}
Expand Down Expand Up @@ -698,6 +706,7 @@ impl SubscriptionSink {
fn inner_close(&mut self, err: &SubscriptionClosedError) {
self.is_connected.take();
if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) {
tracing::debug!("Closing subscription: {:?}", self.uniq_sub.sub_id);
let msg = self.build_message(err).expect("valid json infallible; qed");
let _ = sink.unbounded_send(msg);
}
Expand Down
53 changes: 53 additions & 0 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@
#![cfg(test)]

use crate::types::error::{CallError, Error};
use crate::types::v2::{self, Response, RpcError};
use crate::types::DeserializeOwned;
use crate::{future::ServerHandle, RpcModule, WsServerBuilder};
use anyhow::anyhow;
use futures_util::future::join;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::to_json_raw_value;
use jsonrpsee_types::v2::error::invalid_subscription_err;
use serde_json::Value as JsonValue;
use std::{fmt, net::SocketAddr, time::Duration};
use tracing_subscriber::{EnvFilter, FmtSubscriber};
Expand All @@ -41,6 +45,11 @@ fn init_logger() {
let _ = FmtSubscriber::builder().with_env_filter(EnvFilter::from_default_env()).try_init();
}

fn deser_call<T: DeserializeOwned>(raw: String) -> T {
let out: Response<T> = serde_json::from_str(&raw).unwrap();
out.result
}

/// Applications can/should provide their own error.
#[derive(Debug)]
struct MyAppError;
Expand Down Expand Up @@ -107,6 +116,15 @@ async fn server_with_handles() -> (SocketAddr, ServerHandle) {
Ok("Yawn!")
})
.unwrap();
module
.register_subscription("subscribe_hello", "unsubscribe_hello", |_, sink, _| {
std::thread::spawn(move || loop {
let _ = sink;
std::thread::sleep(std::time::Duration::from_secs(30));
});
Ok(())
})
.unwrap();

let addr = server.local_addr().unwrap();

Expand Down Expand Up @@ -569,3 +587,38 @@ async fn run_forever() {
// Send the shutdown request from one handle and await the server on the second one.
join(server_handle.clone().stop().unwrap(), server_handle).with_timeout(TIMEOUT).await.unwrap();
}

#[tokio::test]
async fn unsubscribe_twice_should_indicate_error() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally these tests should be in the RPC module but I have to refactor the call_with for that to work because inner channel is dropped in call_with/call/execute...

init_logger();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

let sub_call = call("subscribe_hello", Vec::<()>::new(), Id::Num(0));
let sub_id: u64 = deser_call(client.send_request_text(sub_call).await.unwrap());

let unsub_call = call("unsubscribe_hello", vec![sub_id], Id::Num(1));
let unsub_1: String = deser_call(client.send_request_text(unsub_call).await.unwrap());
assert_eq!(&unsub_1, "Unsubscribed");

let unsub_call = call("unsubscribe_hello", vec![sub_id], Id::Num(2));
let unsub_2 = client.send_request_text(unsub_call).await.unwrap();
let unsub_2_err: RpcError = serde_json::from_str(&unsub_2).unwrap();
let sub_id = to_json_raw_value(&sub_id).unwrap();

let err = Some(to_json_raw_value(&format!("Invalid subscription ID={}", sub_id)).unwrap());
assert_eq!(unsub_2_err, RpcError::new(invalid_subscription_err(err.as_deref()), v2::Id::Number(2)));
}

#[tokio::test]
async fn unsubscribe_wrong_sub_id_type() {
init_logger();
let addr = server().await;
let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap();

let unsub =
client.send_request_text(call("unsubscribe_hello", vec!["string_is_not_supported"], Id::Num(0))).await.unwrap();
let unsub_2_err: RpcError = serde_json::from_str(&unsub).unwrap();
let err = Some(to_json_raw_value(&"Invalid subscription ID type, must be integer").unwrap());
assert_eq!(unsub_2_err, RpcError::new(invalid_subscription_err(err.as_deref()), v2::Id::Number(0)));
}