Skip to content

Commit

Permalink
Merge pull request #145 from kpcyrd/ratelimits
Browse files Browse the repository at this point in the history
Add ratelimit_throttle
  • Loading branch information
kpcyrd authored Jan 12, 2020
2 parents f4a305d + bb1feaf commit aab8a52
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ For everything else please have a look at the [detailed list][1].
- [pgp_pubkey_armored](https://sn0int.readthedocs.io/en/latest/reference.html#pgp-pubkey-armored)
- [print](https://sn0int.readthedocs.io/en/latest/reference.html#print)
- [psl_domain_from_dns_name](https://sn0int.readthedocs.io/en/latest/reference.html#psl-domain-from-dns-name)
- [ratelimit_throttle](https://sn0int.readthedocs.io/en/latest/reference.html#ratelimit-throttle)
- [regex_find](https://sn0int.readthedocs.io/en/latest/reference.html#regex-find)
- [regex_find_all](https://sn0int.readthedocs.io/en/latest/reference.html#regex-find-all)
- [semver_match](https://sn0int.readthedocs.io/en/latest/reference.html#semver-match)
Expand Down
16 changes: 16 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,22 @@ Returns the parent domain according to the public suffix list. For
domain = psl_domain_from_dns_name('www.a.b.c.d.example.co.uk')
print(domain == 'example.co.uk')
ratelimit_throttle
------------------

Create a ratelimit that can only be passed x times every y milliseconds. This
limit is global for a single ``run`` and also works with threads.

.. code-block:: lua

    -- allow this to pass every 250ms
    ratelimit_throttle('foo', 1, 250)
    -- allow this to pass not more than 4 times per second
    ratelimit_throttle('foo', 4, 1000)

This is useful if you need to coordinate your executions to stay below a
certain request threshold.

regex_find
----------

Expand Down
11 changes: 11 additions & 0 deletions modules/harness/ratelimit.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Description: Run script with a global ratelimit
-- Version: 0.1.0
-- License: GPL-3.0

function run()
    -- 20 passes at no more than 4 per 1000ms: the first 4 pass right away,
    -- every further group of 4 has to wait for the window to move on, so the
    -- whole loop takes a little over 4 seconds
    local remaining = 20
    while remaining > 0 do
        ratelimit_throttle('foo', 4, 1000)
        info(sn0int_time())
        remaining = remaining - 1
    end
end
20 changes: 19 additions & 1 deletion src/engine/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use crate::lazy::Lazy;
use crate::runtime;
use crate::sockets::{Socket, SocketOptions, TlsData};
use crate::web::{HttpSession, HttpRequest, RequestOptions};
use crate::worker::{Event, LogEvent, DatabaseEvent, StdioEvent};
use crate::worker::{Event, LogEvent, DatabaseEvent, StdioEvent, RatelimitEvent};
use crate::ratelimits::RatelimitResponse;
use chrootable_https::{self, Resolver};
use serde_json;
use std::collections::HashMap;
use std::result;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::thread;
use rand::prelude::*;
use rand::distributions::Alphanumeric;

Expand Down Expand Up @@ -114,6 +116,21 @@ pub trait State {
reply.map_err(|err| format_err!("Failed to read stdin: {:?}", err))
}

/// Blocks until the ratelimit identified by `key` may be passed.
///
/// The request (`passes` allowed within `time` milliseconds) is sent as an
/// `Event::Ratelimit`. While the reply is `Retry`, this thread sleeps for
/// the suggested delay and asks again; `Pass` ends the wait.
///
/// Fails if the reply can't be received or decoded, or if the reply
/// carries an error string.
fn ratelimit(&self, key: String, passes: u32, time: u32) -> Result<()> {
    let request = Event::Ratelimit(RatelimitEvent::new(key, passes, time));
    loop {
        self.send(&request);
        let raw = self.recv()?;
        let response: result::Result<RatelimitResponse, String> = serde_json::from_value(raw)?;
        match response {
            Ok(RatelimitResponse::Pass) => return Ok(()),
            // limit is exhausted right now, wait and ask again
            Ok(RatelimitResponse::Retry(delay)) => thread::sleep(delay),
            Err(err) => bail!("Unexpected error case for ratelimit: {}", err),
        }
    }
}

#[inline]
fn random_id(&self) -> String {
thread_rng().sample_iter(&Alphanumeric).take(16).collect()
Expand Down Expand Up @@ -451,6 +468,7 @@ pub fn ctx<'a>(env: Environment, logger: Arc<Mutex<Box<dyn Reporter>>>) -> (hlua
runtime::pgp_pubkey_armored(&mut lua, state.clone());
runtime::print(&mut lua, state.clone());
runtime::psl_domain_from_dns_name(&mut lua, state.clone());
runtime::ratelimit_throttle(&mut lua, state.clone());
runtime::regex_find(&mut lua, state.clone());
runtime::regex_find_all(&mut lua, state.clone());
runtime::semver_match(&mut lua, state.clone());
Expand Down
1 change: 1 addition & 0 deletions src/engine/isolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ pub fn spawn_module(module: Module,
Event::Log(event) => tx.send(Event2::Log(event)),
Event::Database(object) => supervisor.send_event_callback(object, &tx),
Event::Stdio(object) => object.apply(&mut supervisor, tx, &mut reader),
Event::Ratelimit(req) => supervisor.send_event_callback(req, &tx),
Event::Blob(blob) => supervisor.send_event_callback(blob, &tx),
Event::Exit(event) => {
if let ExitEvent::Err(err) = &event {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod models;
pub mod paths;
pub mod psl;
pub mod options;
pub mod ratelimits;
pub mod registry;
pub mod repl;
pub mod runtime;
Expand Down
67 changes: 67 additions & 0 deletions src/ratelimits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use chrono::prelude::*;
use crate::worker::RatelimitSender;
use std::collections::HashMap;
use std::time::Duration;

/// Collection of ratelimit buckets, one per key.
pub struct Ratelimiter {
// maps a ratelimit key to the bucket that records its granted passes
buckets: HashMap<String, Bucket>,
}

impl Ratelimiter {
    /// Creates a ratelimiter that doesn't know any buckets yet.
    pub fn new() -> Ratelimiter {
        Ratelimiter { buckets: HashMap::new() }
    }

    /// Checks the bucket for `key` and reports the outcome through `tx`.
    ///
    /// The bucket is created the first time a key is seen. `passes` is the
    /// number of passes allowed within `time` milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if the reply can't be sent through `tx`.
    pub fn pass(&mut self, tx: RatelimitSender, key: &str, passes: u32, time: u32) {
        let bucket = self.buckets
            .entry(key.to_string())
            .or_insert_with(Bucket::new);
        let outcome = bucket.pass(passes as usize, time);
        tx.send(Ok(outcome)).unwrap();
    }
}

/// Records the passes that have been granted for a single ratelimit key.
struct Bucket {
// timestamp of every granted pass; entries older than the requested
// window are dropped each time `pass` is called
passes: Vec<DateTime<Utc>>,
}

impl Bucket {
    /// Creates a bucket without any recorded passes.
    pub fn new() -> Bucket {
        Bucket { passes: Vec::new() }
    }

    /// Tries to take one pass out of this bucket.
    ///
    /// `passes` is the number of passes allowed within the last `time`
    /// milliseconds. Returns `Pass` and records the current time if there is
    /// room left. Otherwise returns `Retry` with the time until the oldest
    /// recorded pass leaves the window.
    pub fn pass(&mut self, passes: usize, time: u32) -> RatelimitResponse {
        let now = Utc::now();
        let window = chrono::Duration::milliseconds(i64::from(time));
        let cutoff = now - window;
        // forget every pass that already left the window
        self.passes.retain(|stamp| *stamp >= cutoff);

        if self.passes.len() < passes {
            self.passes.push(Utc::now());
            return RatelimitResponse::Pass;
        }

        match self.passes.iter().min() {
            Some(oldest) => {
                // everything retained is at most `window` old, so this
                // duration is never negative
                let delay = window - (now - *oldest);
                RatelimitResponse::Retry(delay.to_std().unwrap())
            }
            // This should never happen unless passes is zero
            None => RatelimitResponse::Retry(Duration::from_millis(100)),
        }
    }
}

/// Answer to a ratelimit request.
#[derive(Debug, Serialize, Deserialize)]
pub enum RatelimitResponse {
/// The limit is exhausted right now; the requester sleeps for this
/// duration and then asks again.
Retry(Duration),
/// The pass has been recorded and the requester may continue.
Pass,
}
1 change: 1 addition & 0 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import_fns!(logger);
import_fns!(options);
import_fns!(pgp);
import_fns!(psl);
import_fns!(ratelimits);
import_fns!(regex);
import_fns!(semver);
import_fns!(sleep);
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/ratelimits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use crate::errors::*;
use crate::engine::ctx::State;
use crate::hlua;
use std::sync::Arc;


/// Registers the lua function `ratelimit_throttle(key, passes, time)`.
///
/// The lua function forwards its arguments to `State::ratelimit`; an error
/// is handed to `State::set_error` before it is returned to lua.
pub fn ratelimit_throttle(lua: &mut hlua::Lua, state: Arc<dyn State>) {
    let throttle = move |key: String, passes: u32, time: u32| -> Result<()> {
        match state.ratelimit(key, passes, time) {
            Ok(()) => Ok(()),
            Err(err) => Err(state.set_error(err)),
        }
    };
    lua.set("ratelimit_throttle", hlua::function3(throttle))
}
35 changes: 35 additions & 0 deletions src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::engine::{self, Module};
use crate::engine::isolation::Supervisor;
use crate::models::*;
use serde_json;
use crate::ratelimits::{Ratelimiter, RatelimitResponse};
use crate::shell::Shell;
use std::collections::HashMap;
use std::result;
Expand All @@ -23,12 +24,14 @@ use threadpool::ThreadPool;

type DbSender = mpsc::Sender<result::Result<Option<i32>, String>>;
pub type VoidSender = mpsc::Sender<result::Result<(), String>>;
pub type RatelimitSender = mpsc::Sender<result::Result<RatelimitResponse, String>>;

/// Serializable event type; the ratelimit request, for example, is wrapped
/// in `Event::Ratelimit` before it is sent.
#[derive(Debug, Serialize, Deserialize)]
pub enum Event {
Log(LogEvent),
Database(DatabaseEvent),
Stdio(StdioEvent),
/// Request to pass a ratelimit, answered with a `RatelimitResponse`.
Ratelimit(RatelimitEvent),
Blob(Blob),
Exit(ExitEvent),
}
Expand All @@ -38,6 +41,7 @@ pub enum Event2 {
Start,
Log(LogEvent),
Database((DatabaseEvent, DbSender)),
Ratelimit((RatelimitEvent, RatelimitSender)),
Blob((Blob, VoidSender)),
Exit(ExitEvent),
}
Expand Down Expand Up @@ -283,6 +287,32 @@ impl StdioEvent {
}
}

/// Request to pass the ratelimit identified by `key`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RatelimitEvent {
/// Name of the ratelimit bucket this request counts against.
key: String,
/// Number of passes allowed within `time`.
passes: u32,
/// Length of the ratelimit window in milliseconds.
time: u32,
}

// A ratelimit request is answered with a `RatelimitResponse`.
impl EventWithCallback for RatelimitEvent {
type Payload = RatelimitResponse;

/// Pairs this request with the sender for its reply in an `Event2::Ratelimit`.
fn with_callback(self, tx: mpsc::Sender<result::Result<Self::Payload, String>>) -> Event2 {
Event2::Ratelimit((self, tx))
}
}

impl RatelimitEvent {
#[inline]
pub fn new(key: String, passes: u32, time: u32) -> RatelimitEvent {
RatelimitEvent {
key,
passes,
time,
}
}
}

pub fn spawn(rl: &mut Shell, module: &Module, args: Vec<(serde_json::Value, Option<String>, Vec<Blob>)>, params: &Params, proxy: Option<SocketAddr>, options: HashMap<String, String>) -> usize {
// This function hangs if args is empty, so return early if that's the case
if args.is_empty() {
Expand Down Expand Up @@ -328,6 +358,8 @@ pub fn spawn(rl: &mut Shell, module: &Module, args: Vec<(serde_json::Value, Opti
expected += 1;
}

let mut ratelimit = Ratelimiter::new();

let mut errors = 0;
let mut failed = Vec::new();
let timeout = Duration::from_millis(100);
Expand All @@ -344,6 +376,7 @@ pub fn spawn(rl: &mut Shell, module: &Module, args: Vec<(serde_json::Value, Opti
},
Event2::Log(log) => log.apply(&mut stack.prefixed(name)),
Event2::Database((db, tx)) => db.apply(tx, &mut stack.prefixed(name), rl.db(), verbose),
Event2::Ratelimit((req, tx)) => ratelimit.pass(tx, &req.key, req.passes, req.time),
Event2::Blob((blob, tx)) => rl.store_blob(tx, &blob),
Event2::Exit(event) => {
debug!("Received exit: {:?} -> {:?}", name, event);
Expand Down Expand Up @@ -399,6 +432,7 @@ pub fn spawn_fn<F, T>(label: &str, f: F, clear: bool) -> Result<T>
Some(Event::Log(log)) => log.apply(&mut *spinner),
Some(Event::Database(_)) => (),
Some(Event::Stdio(_)) => (),
Some(Event::Ratelimit(_)) => (),
Some(Event::Blob(_)) => (),
// TODO: refactor
Some(Event::Exit(ExitEvent::Ok)) => break,
Expand Down Expand Up @@ -491,6 +525,7 @@ pub fn spawn_multi<T: Task, F>(tasks: Vec<T>, mut done_fn: F, threads: usize) ->
},
Event2::Log(log) => log.apply(&mut stack.prefixed(&name)),
Event2::Database(_) => (),
Event2::Ratelimit(_) => (),
Event2::Blob(_) => (),
Event2::Exit(event) => {
debug!("Received exit: {:?} -> {:?}", name, event);
Expand Down

0 comments on commit aab8a52

Please sign in to comment.