From 9fb494764c26ad62b025ff9388a1e459d68f0695 Mon Sep 17 00:00:00 2001 From: urisinger <60300761+urisinger@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:24:01 +0200 Subject: [PATCH] Add math-parser library (#2033) * start of parser * ops forgot * reorder files and work on executer * start of parser * ops forgot * reorder files and work on executer * Cleanup and fix tests * Integrate into the editor * added unit checking at parse time * fix tests * fix issues * fix editor intergration * update pest grammer to support units * units should be working, need to set up tests to know * make unit type store exponants as i32 * remove scale, insted just multiply the literal by the scale * unit now contains empty unit,remove options * add more tests and implement almost all unary operators * add evaluation context and variables * function calling, api might be refined later * add constants, change function call to not be as built into the parser and add tests * add function definitions * remove meval * remove raw-rs from workspace * add support for numberless units * fix unit handleing logic, add some "unit" tests(haha) * make it so units cant do implcit mul with idents * add bench and better tests * fix editor api * remove old test * change hashmap context to use deref * change constants to use hashmap instad of function --------- Co-authored-by: hypercube <0hypercube@gmail.com> Co-authored-by: Keavon Chambers --- Cargo.lock | 75 ++++- Cargo.toml | 6 +- frontend/wasm/Cargo.toml | 2 +- frontend/wasm/src/editor_api.rs | 90 +----- libraries/math-parser/Cargo.toml | 23 ++ libraries/math-parser/benches/bench.rs | 50 ++++ libraries/math-parser/src/ast.rs | 75 +++++ libraries/math-parser/src/constants.rs | 121 ++++++++ libraries/math-parser/src/context.rs | 83 ++++++ libraries/math-parser/src/executer.rs | 105 +++++++ libraries/math-parser/src/grammer.pest | 60 ++++ libraries/math-parser/src/lib.rs | 151 ++++++++++ libraries/math-parser/src/parser.rs | 385 +++++++++++++++++++++++++ libraries/math-parser/src/value.rs | 128 ++++++++ 14 files changed, 1260 insertions(+), 94 deletions(-) create mode 100644 libraries/math-parser/Cargo.toml create mode 100644 libraries/math-parser/benches/bench.rs create mode 100644 libraries/math-parser/src/ast.rs create mode 100644 libraries/math-parser/src/constants.rs create mode 100644 libraries/math-parser/src/context.rs create mode 100644 libraries/math-parser/src/executer.rs create mode 100644 libraries/math-parser/src/grammer.pest create mode 100644 libraries/math-parser/src/lib.rs create mode 100644 libraries/math-parser/src/parser.rs create mode 100644 libraries/math-parser/src/value.rs diff --git a/Cargo.lock b/Cargo.lock index f6343a8543..00ca1a0086 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2534,7 +2534,7 @@ dependencies = [ "graphite-editor", "js-sys", "log", - "meval", + "math-parser", "ron", "serde", "serde-wasm-bindgen", @@ -3534,6 +3534,19 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "math-parser" +version = "0.0.0" +dependencies = [ + "criterion", + "lazy_static", + "log", + "num-complex", + "pest", + "pest_derive", + "thiserror", +] + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -3592,15 +3605,6 @@ dependencies = [ "paste", ] -[[package]] -name = "meval" -version = "0.2.0" -source = "git+https://github.com/Titaniumtown/meval-rs#6bf579fd402928745cf4f24e5c975bece3285179" -dependencies = [ - "fnv", - "nom", -] - [[package]] name = "mime" version = "0.3.17" @@ -4341,6 +4345,51 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdbef9d1d47087a895abd220ed25eb4ad973a5e26f6a4367b038c25e28dfc2d9" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d3a6e3394ec80feb3b6393c725571754c6188490265c61aaf260810d6b95aa0" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94429506bde1ca69d1b5601962c73f4172ab4726571a59ea95931218cb0e930e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "pest_meta" +version = "2.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac8a071862e93690b6e34e9a5fb8e33ff3734473ac0245b27232222c4906a33f" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "petgraph" version = "0.6.5" @@ -6758,6 +6807,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "uds_windows" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 6fb4b5c3df..9e38721998 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "libraries/dyn-any", "libraries/path-bool", "libraries/bezier-rs", + "libraries/math-parser", "website/other/bezier-rs-demos/wasm", ] exclude = ["node-graph/gpu-compiler"] @@ -31,6 +32,7 @@ graph-craft = { path = "node-graph/graph-craft", features = ["serde"] } wgpu-executor = { path = "node-graph/wgpu-executor" } bezier-rs = { path = "libraries/bezier-rs", features = ["dyn-any"] } path-bool = { path = "libraries/path-bool", default-features = false } +math-parser = { path = "libraries/math-parser" } node-macro = { path = "node-graph/node-macro" } # Workspace dependencies @@ -77,7 +79,6 @@ glam = { version = "0.28", default-features = false, features = ["serde"] } base64 = "0.22" image = { version = "0.25", default-features = false, features = ["png"] } rustybuzz = "0.17" -meval = "0.2" spirv = "0.3" fern = { version = "0.6", features = ["colored"] } num_enum = "0.7" @@ -94,9 +95,6 @@ syn = { version = "2.0", default-features = false, features = [ ] } kurbo = { version = "0.11.0", features = ["serde"] } -[patch.crates-io] -meval = { git = "https://github.com/Titaniumtown/meval-rs" } - [profile.dev] opt-level = 1 diff --git a/frontend/wasm/Cargo.toml b/frontend/wasm/Cargo.toml index f30d5ed357..f8e93ffa4d 100644 --- a/frontend/wasm/Cargo.toml +++ b/frontend/wasm/Cargo.toml @@ -37,7 +37,7 @@ js-sys = { workspace = true } wasm-bindgen-futures = { workspace = true } bezier-rs = { workspace = true } glam = { workspace = true } -meval = { workspace = true } +math-parser = { workspace = true } wgpu = { workspace = true, features = [ "fragile-send-sync-non-atomic-wasm", ] } # We don't have wgpu on multiple threads (yet) https://github.com/gfx-rs/wgpu/blob/trunk/CHANGELOG.md#wgpu-types-now-send-sync-on-wasm diff --git a/frontend/wasm/src/editor_api.rs b/frontend/wasm/src/editor_api.rs index bb7f976aab..aef4660405 100644 --- a/frontend/wasm/src/editor_api.rs +++ b/frontend/wasm/src/editor_api.rs @@ -910,72 +910,17 @@ impl EditorHandle { #[wasm_bindgen(js_name = evaluateMathExpression)] pub fn evaluate_math_expression(expression: &str) -> Option { - // TODO: Rewrite our own purpose-built math expression parser that supports unit conversions. - - let mut context = meval::Context::new(); - context.var("tau", std::f64::consts::TAU); - context.func("log", f64::log10); - context.func("log10", f64::log10); - context.func("log2", f64::log2); - - // Insert asterisks where implicit multiplication is used in the expression string - let expression = implicit_multiplication_preprocess(expression); - - meval::eval_str_with_context(expression, &context).ok() -} - -// Modified from this public domain snippet: -// Discussion: -pub fn implicit_multiplication_preprocess(expression: &str) -> String { - let function = expression.to_lowercase().replace("log10(", "log(").replace("log2(", "logtwo(").replace("pi", "π").replace("tau", "τ"); - let valid_variables: Vec = "eπτ".chars().collect(); - let letters: Vec = ('a'..='z').chain('A'..='Z').collect(); - let numbers: Vec = ('0'..='9').collect(); - let function_chars: Vec = function.chars().collect(); - let mut output_string: String = String::new(); - let mut prev_chars: Vec = Vec::new(); - - for c in function_chars { - let mut add_asterisk: bool = false; - let prev_chars_len = prev_chars.len(); - - let prev_prev_char = if prev_chars_len >= 2 { *prev_chars.get(prev_chars_len - 2).unwrap() } else { ' ' }; - - let prev_char = if prev_chars_len >= 1 { *prev_chars.get(prev_chars_len - 1).unwrap() } else { ' ' }; - - let c_letters_var = letters.contains(&c) | valid_variables.contains(&c); - let prev_letters_var = valid_variables.contains(&prev_char) | letters.contains(&prev_char); - - if prev_char == ')' { - if (c == '(') | numbers.contains(&c) | c_letters_var { - add_asterisk = true; - } - } else if c == '(' { - if (valid_variables.contains(&prev_char) | (')' == prev_char) | numbers.contains(&prev_char)) && !letters.contains(&prev_prev_char) { - add_asterisk = true; - } - } else if numbers.contains(&prev_char) { - if (c == '(') | c_letters_var { - add_asterisk = true; - } - } else if letters.contains(&c) { - if numbers.contains(&prev_char) | (valid_variables.contains(&prev_char) && valid_variables.contains(&c)) { - add_asterisk = true; - } - } else if (numbers.contains(&c) | c_letters_var) && prev_letters_var { - add_asterisk = true; - } - - if add_asterisk { - output_string += "*"; - } - - prev_chars.push(c); - output_string += &c.to_string(); - } - - // We have to convert the Greek symbols back to ASCII because meval doesn't support unicode symbols as context constants - output_string.replace("logtwo(", "log2(").replace('π', "pi").replace('τ', "tau") + let value = math_parser::evaluate(expression) + .inspect_err(|err| error!("Math parser error on \"{expression}\": {err}")) + .ok()? + .0 + .inspect_err(|err| error!("Math evaluate error on \"{expression}\": {err} ")) + .ok()?; + let Some(real) = value.as_real() else { + error!("{value} was not a real; skipping."); + return None; + }; + Some(real) } /// Helper function for calling JS's `requestAnimationFrame` with the given closure @@ -1066,16 +1011,3 @@ fn auto_save_all_documents() { } }); } - -#[test] -fn implicit_multiplication_preprocess_tests() { - assert_eq!(implicit_multiplication_preprocess("2pi"), "2*pi"); - assert_eq!(implicit_multiplication_preprocess("sin(2pi)"), "sin(2*pi)"); - assert_eq!(implicit_multiplication_preprocess("2sin(pi)"), "2*sin(pi)"); - assert_eq!(implicit_multiplication_preprocess("2sin(3(4 + 5))"), "2*sin(3*(4 + 5))"); - assert_eq!(implicit_multiplication_preprocess("3abs(-4)"), "3*abs(-4)"); - assert_eq!(implicit_multiplication_preprocess("-1(4)"), "-1*(4)"); - assert_eq!(implicit_multiplication_preprocess("(-1)4"), "(-1)*4"); - assert_eq!(implicit_multiplication_preprocess("(((-1)))(4)"), "(((-1)))*(4)"); - assert_eq!(implicit_multiplication_preprocess("2sin(pi) + 2cos(tau)"), "2*sin(pi) + 2*cos(tau)"); -} diff --git a/libraries/math-parser/Cargo.toml b/libraries/math-parser/Cargo.toml new file mode 100644 index 0000000000..cc2097a144 --- /dev/null +++ b/libraries/math-parser/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "math-parser" +version = "0.0.0" +rust-version = "1.79" +edition = "2021" +authors = ["Graphite Authors "] +description = "Parser for Graphite style mathematics expressions" +license = "MIT OR Apache-2.0" + +[dependencies] +pest = "2.7" +pest_derive = "2.7.11" +thiserror = "1" +lazy_static = "1.5" +num-complex = "0.4" +log = { workspace = true } + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +name = "bench" +harness = false diff --git a/libraries/math-parser/benches/bench.rs b/libraries/math-parser/benches/bench.rs new file mode 100644 index 0000000000..392cc7ae8f --- /dev/null +++ b/libraries/math-parser/benches/bench.rs @@ -0,0 +1,50 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use math_parser::ast; +use math_parser::context::EvalContext; + +macro_rules! generate_benchmarks { + ($( $input:expr ),* $(,)?) => { + fn parsing_bench(c: &mut Criterion) { + $( + c.bench_function(concat!("parse ", $input), |b| { + b.iter(|| { + let _ = black_box(ast::Node::from_str($input)).unwrap(); + }); + }); + )* + } + + fn evaluation_bench(c: &mut Criterion) { + $( + let expr = ast::Node::from_str($input).unwrap().0; + let context = EvalContext::default(); + + c.bench_function(concat!("eval ", $input), |b| { + b.iter(|| { + let _ = black_box(expr.eval(&context)); + }); + }); + )* + } + + criterion_group!(benches, parsing_bench, evaluation_bench); + criterion_main!(benches); + }; +} + +generate_benchmarks! { + "(3 * (4 + sqrt(25)) - cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations + "((5 + 2 * (3 - sqrt(49)))^2) / (1 + sqrt(16)) + tau / 2", // Complex nested expression with constants + "log(100, 10) + (5 * sin(pi/4) + sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions + "(sqrt(144) * 2 + 5) / (3 * (4 - sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations + "cos(2 * pi) + tan(pi / 3) * log(32, 2) - sqrt(256)", // Multiple trigonometric and logarithmic functions + "(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - sqrt(225) + phi", // Mixed arithmetic with constants + "(5^2 + 3^3) * (sqrt(81) + sqrt(64)) - tau * log(1000, 10)", // Power and square root with constants + "((8 * sqrt(49) - 2 * e) + log(256, 2) / (2 + cos(pi))) * 1.5", // Nested functions and constants + "(tan(pi / 4) + 5) * (3 + sqrt(36)) / (log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm + "((3 * e + 2 * sqrt(100)) - cos(tau / 4)) * log(27, 3) + phi", // Mixed constant usage and functions + "(sqrt(100) + 5 * sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation + "((sin(pi/2) + cos(0)) * (e^2 - 2 * sqrt(16))) / (log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions + "(5 * (7 + sqrt(121)) - (log(243, 3) * phi)) + 3^5 / tau", // +} diff --git a/libraries/math-parser/src/ast.rs b/libraries/math-parser/src/ast.rs new file mode 100644 index 0000000000..4c42fc4b1d --- /dev/null +++ b/libraries/math-parser/src/ast.rs @@ -0,0 +1,75 @@ +use crate::value::Complex; + +#[derive(Debug, PartialEq, Eq)] +pub struct Unit { + // Exponent of length unit (meters) + pub length: i32, + // Exponent of mass unit (kilograms) + pub mass: i32, + // Exponent of time unit (seconds) + pub time: i32, +} + +impl Default for Unit { + fn default() -> Self { + Self::BASE_UNIT + } +} + +impl Unit { + pub const BASE_UNIT: Unit = Unit { length: 0, mass: 0, time: 0 }; + + pub const LENGTH: Unit = Unit { length: 1, mass: 0, time: 0 }; + pub const MASS: Unit = Unit { length: 0, mass: 1, time: 0 }; + pub const TIME: Unit = Unit { length: 0, mass: 0, time: 1 }; + + pub const VELOCITY: Unit = Unit { length: 1, mass: 0, time: -1 }; + pub const ACCELERATION: Unit = Unit { length: 1, mass: 0, time: -2 }; + + pub const FORCE: Unit = Unit { length: 1, mass: 1, time: -2 }; + + pub fn base_unit() -> Self { + Self::BASE_UNIT + } + + pub fn is_base(&self) -> bool { + *self == Self::BASE_UNIT + } +} + +#[derive(Debug, PartialEq)] +pub enum Literal { + Float(f64), + Complex(Complex), +} + +impl From for Literal { + fn from(value: f64) -> Self { + Self::Float(value) + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BinaryOp { + Add, + Sub, + Mul, + Div, + Pow, +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum UnaryOp { + Neg, + Sqrt, + Fac, +} + +#[derive(Debug, PartialEq)] +pub enum Node { + Lit(Literal), + Var(String), + FnCall { name: String, expr: Vec }, + BinOp { lhs: Box, op: BinaryOp, rhs: Box }, + UnaryOp { expr: Box, op: UnaryOp }, +} diff --git a/libraries/math-parser/src/constants.rs b/libraries/math-parser/src/constants.rs new file mode 100644 index 0000000000..62c69c3a2c --- /dev/null +++ b/libraries/math-parser/src/constants.rs @@ -0,0 +1,121 @@ +use std::{collections::HashMap, f64::consts::PI}; + +use lazy_static::lazy_static; +use num_complex::{Complex, ComplexFloat}; + +use crate::value::{Number, Value}; +lazy_static! { + pub static ref DEFAULT_FUNCTIONS: HashMap<&'static str, Box Option + Send + Sync>> = { + let mut map: HashMap<&'static str, Box Option + Send + Sync>> = HashMap::new(); + + map.insert( + "sin", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sin()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sin()))), + _ => None, + }), + ); + + map.insert( + "cos", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cos()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cos()))), + _ => None, + }), + ); + + map.insert( + "tan", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tan()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tan()))), + _ => None, + }), + ); + + map.insert( + "csc", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sin().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sin().recip()))), + _ => None, + }), + ); + + map.insert( + "sec", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cos().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cos().recip()))), + _ => None, + }), + ); + + map.insert( + "cot", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tan().recip()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tan().recip()))), + _ => None, + }), + ); + + map.insert( + "invsin", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.asin()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.asin()))), + _ => None, + }), + ); + + map.insert( + "invcos", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.acos()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.acos()))), + _ => None, + }), + ); + + map.insert( + "invtan", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.atan()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.atan()))), + _ => None, + }), + ); + + map.insert( + "invcsc", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().asin()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().asin()))), + _ => None, + }), + ); + + map.insert( + "invsec", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().acos()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().acos()))), + _ => None, + }), + ); + + map.insert( + "invcot", + Box::new(|values| match values { + [Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real((PI / 2.0 - real).atan()))), + [Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex((Complex::new(PI / 2.0, 0.0) - complex).atan()))), + _ => None, + }), + ); + + map + }; +} diff --git a/libraries/math-parser/src/context.rs b/libraries/math-parser/src/context.rs new file mode 100644 index 0000000000..b34dc6fbd8 --- /dev/null +++ b/libraries/math-parser/src/context.rs @@ -0,0 +1,83 @@ +use std::{ + collections::HashMap, + ops::{Deref, DerefMut}, +}; + +use crate::value::Value; + +//TODO: editor integration, implement these traits for whatever is needed, maybe merge them if needed +pub trait ValueProvider { + fn get_value(&self, name: &str) -> Option; +} + +pub trait FunctionProvider { + fn run_function(&self, name: &str, args: &[Value]) -> Option; +} + +pub struct ValueMap(HashMap); + +pub struct NothingMap; + +impl ValueProvider for &ValueMap { + fn get_value(&self, name: &str) -> Option { + self.0.get(name).cloned() + } +} + +impl ValueProvider for NothingMap { + fn get_value(&self, _: &str) -> Option { + None + } +} + +impl ValueProvider for ValueMap { + fn get_value(&self, name: &str) -> Option { + self.0.get(name).cloned() + } +} + +impl Deref for ValueMap { + type Target = HashMap; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for ValueMap { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl FunctionProvider for NothingMap { + fn run_function(&self, _: &str, _: &[Value]) -> Option { + None + } +} + +pub struct EvalContext { + values: V, + functions: F, +} + +impl Default for EvalContext { + fn default() -> Self { + Self { + values: NothingMap, + functions: NothingMap, + } + } +} + +impl EvalContext { + pub fn new(values: V, functions: F) -> Self { + Self { values, functions } + } + + pub fn get_value(&self, name: &str) -> Option { + self.values.get_value(name) + } + + pub fn run_function(&self, name: &str, args: &[Value]) -> Option { + self.functions.run_function(name, args) + } +} diff --git a/libraries/math-parser/src/executer.rs b/libraries/math-parser/src/executer.rs new file mode 100644 index 0000000000..ebac537b9f --- /dev/null +++ b/libraries/math-parser/src/executer.rs @@ -0,0 +1,105 @@ +use thiserror::Error; + +use crate::{ + ast::{Literal, Node}, + constants::DEFAULT_FUNCTIONS, + context::{EvalContext, FunctionProvider, ValueProvider}, + value::{Number, Value}, +}; + +#[derive(Debug, Error)] +pub enum EvalError { + #[error("Missing value: {0}")] + MissingValue(String), + + #[error("Missing function: {0}")] + MissingFunction(String), + #[error("Wrong type for function call")] + TypeError, +} + +impl Node { + pub fn eval(&self, context: &EvalContext) -> Result { + match self { + Node::Lit(lit) => match lit { + Literal::Float(num) => Ok(Value::from_f64(*num)), + Literal::Complex(num) => Ok(Value::Number(Number::Complex(*num))), + }, + + Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) { + (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))), + }, + Node::UnaryOp { expr, op } => match expr.eval(context)? { + Value::Number(num) => Ok(Value::Number(num.unary_op(*op))), + }, + Node::Var(name) => context.get_value(name).ok_or_else(|| EvalError::MissingValue(name.clone())), + Node::FnCall { name, expr } => { + let values = expr.iter().map(|expr| expr.eval(context)).collect::, EvalError>>()?; + if let Some(function) = DEFAULT_FUNCTIONS.get(&name.as_str()) { + function(&values).ok_or(EvalError::TypeError) + } else if let Some(val) = context.run_function(name, &values) { + Ok(val) + } else { + context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string())) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + ast::{BinaryOp, Literal, Node, UnaryOp}, + context::{EvalContext, ValueMap}, + value::Value, + }; + + macro_rules! eval_tests { + ($($name:ident: $expected:expr => $expr:expr),* $(,)?) => { + $( + #[test] + fn $name() { + let result = $expr.eval(&EvalContext::default()).unwrap(); + assert_eq!(result, $expected); + } + )* + }; + } + + eval_tests! { + test_addition: Value::from_f64(7.0) => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(3.0))), + op: BinaryOp::Add, + rhs: Box::new(Node::Lit(Literal::Float(4.0))), + }, + test_subtraction: Value::from_f64(1.0) => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(5.0))), + op: BinaryOp::Sub, + rhs: Box::new(Node::Lit(Literal::Float(4.0))), + }, + test_multiplication: Value::from_f64(12.0) => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(3.0))), + op: BinaryOp::Mul, + rhs: Box::new(Node::Lit(Literal::Float(4.0))), + }, + test_division: Value::from_f64(2.5) => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(5.0))), + op: BinaryOp::Div, + rhs: Box::new(Node::Lit(Literal::Float(2.0))), + }, + test_negation: Value::from_f64(-3.0) => Node::UnaryOp { + expr: Box::new(Node::Lit(Literal::Float(3.0))), + op: UnaryOp::Neg, + }, + test_sqrt: Value::from_f64(2.0) => Node::UnaryOp { + expr: Box::new(Node::Lit(Literal::Float(4.0))), + op: UnaryOp::Sqrt, + }, + test_power: Value::from_f64(8.0) => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(2.0))), + op: BinaryOp::Pow, + rhs: Box::new(Node::Lit(Literal::Float(3.0))), + }, + } +} diff --git a/libraries/math-parser/src/grammer.pest b/libraries/math-parser/src/grammer.pest new file mode 100644 index 0000000000..d7a61939df --- /dev/null +++ b/libraries/math-parser/src/grammer.pest @@ -0,0 +1,60 @@ +WHITESPACE = _{ " " | "\t" } + +// TODO: Proper indentation and formatting +program = _{ SOI ~ expr ~ EOI } + +expr = { atom ~ (infix ~ atom)* } +atom = _{ prefix? ~ primary ~ postfix? } +infix = _{ add | sub | mul | div | pow | paren } +add = { "+" } // Addition +sub = { "-" } // Subtraction +mul = { "*" } // Multiplication +div = { "/" } // Division +mod = { "%" } // Modulo +pow = { "^" } // Exponentiation +paren = { "" } // Implicit multiplication operator + +prefix = _{ neg | sqrt } +neg = { "-" } // Negation +sqrt = { "sqrt" } + +postfix = _{ fac } +fac = { "!" } // Factorial + +primary = _{ ("(" ~ expr ~ ")") | lit | constant | fn_call | ident } +fn_call = { ident ~ "(" ~ expr ~ ("," ~ expr)* ~ ")" } +ident = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } +lit = { unit | ((float | int) ~ unit?) } + +float = @{ int ~ "." ~ int? ~ exp? | int ~ exp } +exp = _{ ^"e" ~ ("+" | "-")? ~ int } +int = @{ ASCII_DIGIT+ } + +unit = ${ (scale ~ base_unit) | base_unit ~ !ident} +base_unit = _{ meter | second | gram } +meter = { "m" } +second = { "s" } +gram = { "g" } + +scale = _{ nano | micro | milli | centi | deci | deca | hecto | kilo | mega | giga | tera } +nano = { "n" } +micro = { "µ" | "u" } +milli = { "m" } +centi = { "c" } +deci = { "d" } +deca = { "da" } +hecto = { "h" } +kilo = { "k" } +mega = { "M" } +giga = { "G" } +tera = { "T" } + +// Constants +constant = { infinity | imaginary_unit | pi | tau | euler_number | golden_ratio | gravity_acceleration } +infinity = { "inf" | "INF" | "infinity" | "INFINITY" | "∞" } +imaginary_unit = { "i" | "I" } +pi = { "pi" | "PI" | "π" } +tau = { "tau" | "TAU" | "τ" } +euler_number = { "e" } +golden_ratio = { "phi" | "PHI" | "φ" } +gravity_acceleration = { "G" } diff --git a/libraries/math-parser/src/lib.rs b/libraries/math-parser/src/lib.rs new file mode 100644 index 0000000000..e1d0341564 --- /dev/null +++ b/libraries/math-parser/src/lib.rs @@ -0,0 +1,151 @@ +#![allow(unused)] + +pub mod ast; +mod constants; +pub mod context; +pub mod executer; +pub mod parser; +pub mod value; + +use ast::Unit; +use context::{EvalContext, ValueMap}; +use executer::EvalError; +use parser::ParseError; +use value::Value; + +pub fn evaluate(expression: &str) -> Result<(Result, Unit), ParseError> { + let expr = ast::Node::from_str(expression); + let context = EvalContext::default(); + expr.map(|(node, unit)| (node.eval(&context), unit)) +} + +#[cfg(test)] +mod tests { + use value::Number; + + use ast::Unit; + + use super::*; + + const EPSILON: f64 = 1e10_f64; + + macro_rules! test_end_to_end{ + ($($name:ident: $input:expr => ($expected_value:expr, $expected_unit:expr)),* $(,)?) => { + $( + #[test] + fn $name() { + let expected_value = $expected_value; + let expected_unit = $expected_unit; + + let expr = ast::Node::from_str($input); + let context = EvalContext::default(); + + let (actual_value, actual_unit) = expr.map(|(node, unit)| (node.eval(&context), unit)).unwrap(); + let actual_value = actual_value.unwrap(); + + + assert!(actual_unit == expected_unit, "Expected unit {:?} but found unit {:?}", expected_unit, actual_unit); + + let expected_value = expected_value.into(); + + match (actual_value, expected_value) { + (Value::Number(Number::Complex(actual_c)), Value::Number(Number::Complex(expected_c))) => { + assert!( + (actual_c.re.is_infinite() && expected_c.re.is_infinite()) || (actual_c.re - expected_c.re).abs() < EPSILON, + "Expected real part {}, but got {}", + expected_c.re, + actual_c.re + ); + assert!( + (actual_c.im.is_infinite() && expected_c.im.is_infinite()) || (actual_c.im - expected_c.im).abs() < EPSILON, + "Expected imaginary part {}, but got {}", + expected_c.im, + actual_c.im + ); + } + (Value::Number(Number::Real(actual_f)), Value::Number(Number::Real(expected_f))) => { + if actual_f.is_infinite() || expected_f.is_infinite() { + assert!( + actual_f.is_infinite() && expected_f.is_infinite() && actual_f == expected_f, + "Expected infinite value {}, but got {}", + expected_f, + actual_f + ); + } else if actual_f.is_nan() || expected_f.is_nan() { + assert!(actual_f.is_nan() && expected_f.is_nan(), "Expected NaN, but got {}", actual_f); + } else { + assert!((actual_f - expected_f).abs() < EPSILON, "Expected {}, but got {}", expected_f, actual_f); + } + } + // Handle mismatched types + _ => panic!("Mismatched types: expected {:?}, got {:?}", expected_value, actual_value), + } + + } + )* + }; + } + + test_end_to_end! { + // Basic arithmetic and units + infix_addition: "5 + 5" => (10., Unit::BASE_UNIT), + infix_subtraction_units: "5m - 3m" => (2., Unit::LENGTH), + infix_multiplication_units: "4s * 4s" => (16., Unit { length: 0, mass: 0, time: 2 }), + infix_division_units: "8m/2s" => (4., Unit::VELOCITY), + + // Order of operations + order_of_operations_negative_prefix: "-10 + 5" => (-5., Unit::BASE_UNIT), + order_of_operations_add_multiply: "5+1*1+5" => (11., Unit::BASE_UNIT), + order_of_operations_add_negative_multiply: "5+(-1)*1+5" => (9., Unit::BASE_UNIT), + order_of_operations_sqrt: "sqrt25 + 11" => (16., Unit::BASE_UNIT), + order_of_operations_sqrt_expression: "sqrt(25+11)" => (6., Unit::BASE_UNIT), + + // Parentheses and nested expressions + parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => (64., Unit::BASE_UNIT), + parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => (36., Unit::BASE_UNIT), + parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => (16., Unit::BASE_UNIT), + + // Square root and nested square root + sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => (10., Unit::BASE_UNIT), + sqrt_nested: "sqrt(sqrt(81))" => (3., Unit::BASE_UNIT), + sqrt_divide_expression: "sqrt((25 + 11) / 9)" => (2., Unit::BASE_UNIT), + + // Mixed square root and units + sqrt_multiply_units: "sqrt(16) * 2g + 5g" => (13., Unit::MASS), + sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => (12., Unit::BASE_UNIT), + sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => (16., Unit::BASE_UNIT), + + // Exponentiation + exponent_single: "2^3" => (8., Unit::BASE_UNIT), + exponent_mixed_operations: "2^3 + 4^2" => (24., Unit::BASE_UNIT), + exponent_nested: "2^(3+1)" => (16., Unit::BASE_UNIT), + + // Operations with negative values + negative_units_add_multiply: "-5s + (-3 * 2)s" => (-11., Unit::TIME), + negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => (-8., Unit::BASE_UNIT), + negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => (-7., Unit::BASE_UNIT), + multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => (9., Unit::BASE_UNIT), + add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => (8., Unit::BASE_UNIT), + add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => (3., Unit::BASE_UNIT), + + // Mathematical constants + constant_pi: "pi" => (std::f64::consts::PI, Unit::BASE_UNIT), + constant_e: "e" => (std::f64::consts::E, Unit::BASE_UNIT), + constant_phi: "phi" => (1.61803398875, Unit::BASE_UNIT), + constant_tau: "tau" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), + constant_infinity: "inf" => (f64::INFINITY, Unit::BASE_UNIT), + constant_infinity_symbol: "∞" => (f64::INFINITY, Unit::BASE_UNIT), + multiply_pi: "2 * pi" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT), + add_e_constant: "e + 1" => (std::f64::consts::E + 1.0, Unit::BASE_UNIT), + multiply_phi_constant: "phi * 2" => (1.61803398875 * 2.0, Unit::BASE_UNIT), + exponent_tau: "2^tau" => (2f64.powf(2.0 * std::f64::consts::PI), Unit::BASE_UNIT), + infinity_subtract_large_number: "inf - 1000" => (f64::INFINITY, Unit::BASE_UNIT), + + // Trigonometric functions + trig_sin_pi: "sin(pi)" => (0.0, Unit::BASE_UNIT), + trig_cos_zero: "cos(0)" => (1.0, Unit::BASE_UNIT), + trig_tan_pi_div_four: "tan(pi/4)" => (1.0, Unit::BASE_UNIT), + trig_sin_tau: "sin(tau)" => (0.0, Unit::BASE_UNIT), + trig_cos_tau_div_two: "cos(tau/2)" => (-1.0, Unit::BASE_UNIT), + } +} diff --git a/libraries/math-parser/src/parser.rs b/libraries/math-parser/src/parser.rs new file mode 100644 index 0000000000..632355316e --- /dev/null +++ b/libraries/math-parser/src/parser.rs @@ -0,0 +1,385 @@ +use std::num::{ParseFloatError, ParseIntError}; + +use lazy_static::lazy_static; +use num_complex::ComplexFloat; +use pest::{ + iterators::{Pair, Pairs}, + pratt_parser::{Assoc, Op, PrattParser}, + Parser, +}; +use pest_derive::Parser; +use thiserror::Error; + +use crate::{ + ast::{BinaryOp, Literal, Node, UnaryOp, Unit}, + context::EvalContext, + value::{Complex, Number, Value}, +}; + +#[derive(Parser)] +#[grammar = "./grammer.pest"] // Point to the grammar file +struct ExprParser; + +lazy_static! { + static ref PRATT_PARSER: PrattParser = { + PrattParser::new() + .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) + .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left)) + .op(Op::infix(Rule::pow, Assoc::Right)) + .op(Op::postfix(Rule::fac) | Op::postfix(Rule::EOI)) + .op(Op::prefix(Rule::sqrt)) + .op(Op::prefix(Rule::neg)) + }; +} + +#[derive(Error, Debug)] +pub enum TypeError { + #[error("Invalid BinOp: {0:?} {1:?} {2:?}")] + InvalidBinaryOp(Unit, BinaryOp, Unit), + + #[error("Invalid UnaryOp: {0:?}")] + InvalidUnaryOp(Unit, UnaryOp), +} + +#[derive(Error, Debug)] +pub enum ParseError { + #[error("ParseIntError: {0}")] + ParseInt(#[from] ParseIntError), + #[error("ParseFloatError: {0}")] + ParseFloat(#[from] ParseFloatError), + + #[error("TypeError: {0}")] + Type(#[from] TypeError), + + #[error("PestError: {0}")] + Pest(#[from] Box>), +} + +impl Node { + pub fn from_str(s: &str) -> Result<(Node, Unit), ParseError> { + let pairs = ExprParser::parse(Rule::program, s).map_err(Box::new)?; + let (node, metadata) = parse_expr(pairs)?; + Ok((node, metadata.unit)) + } +} + +struct NodeMetadata { + pub unit: Unit, +} + +impl NodeMetadata { + pub fn new(unit: Unit) -> Self { + Self { unit } + } +} + +fn parse_unit(pairs: Pairs) -> Result<(Unit, f64), ParseError> { + let mut scale = 1.0; + let mut length = 0; + let mut mass = 0; + let mut time = 0; + + for pair in pairs { + println!("found rule: {:?}", pair.as_rule()); + match pair.as_rule() { + Rule::nano => scale *= 1e-9, + Rule::micro => scale *= 1e-6, + Rule::milli => scale *= 1e-3, + Rule::centi => scale *= 1e-2, + Rule::deci => scale *= 1e-1, + Rule::deca => scale *= 1e1, + Rule::hecto => scale *= 1e2, + Rule::kilo => scale *= 1e3, + Rule::mega => scale *= 1e6, + Rule::giga => scale *= 1e9, + Rule::tera => scale *= 1e12, + + Rule::meter => length = 1, + Rule::gram => mass = 1, + Rule::second => time = 1, + + _ => unreachable!(), // All possible rules should be covered + } + } + + Ok((Unit { length, mass, time }, scale)) +} + +fn parse_const(pair: Pair) -> Literal { + match pair.as_rule() { + Rule::infinity => Literal::Float(f64::INFINITY), + Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)), + Rule::pi => Literal::Float(std::f64::consts::PI), + Rule::tau => Literal::Float(2.0 * std::f64::consts::PI), + Rule::euler_number => Literal::Float(std::f64::consts::E), + Rule::golden_ratio => Literal::Float(1.61803398875), + _ => unreachable!("Unexpected constant: {:?}", pair), + } +} + +fn parse_lit(mut pairs: Pairs) -> Result<(Literal, Unit), ParseError> { + let literal = match pairs.next() { + Some(lit) => match lit.as_rule() { + Rule::int => { + let value = lit.as_str().parse::()? as f64; + Literal::Float(value) + } + Rule::float => { + let value = lit.as_str().parse::()?; + Literal::Float(value) + } + Rule::unit => { + let (unit, scale) = parse_unit(lit.into_inner())?; + return Ok((Literal::Float(scale), unit)); + } + rule => unreachable!("unexpected rule: {:?}", rule), + }, + None => unreachable!("expected rule"), // No literal found + }; + + if let Some(unit_pair) = pairs.next() { + let unit_pairs = unit_pair.into_inner(); // Get the inner pairs for the unit + let (unit, scale) = parse_unit(unit_pairs)?; + + println!("found unit: {:?}", unit); + + Ok(( + match literal { + Literal::Float(num) => Literal::Float(num * scale), + Literal::Complex(num) => Literal::Complex(num * scale), + }, + unit, + )) + } else { + Ok((literal, Unit::BASE_UNIT)) + } +} + +fn parse_expr(pairs: Pairs) -> Result<(Node, NodeMetadata), ParseError> { + PRATT_PARSER + .map_primary(|primary| { + Ok(match primary.as_rule() { + Rule::lit => { + let (lit, unit) = parse_lit(primary.into_inner())?; + + (Node::Lit(lit), NodeMetadata { unit }) + } + Rule::fn_call => { + let mut pairs = primary.into_inner(); + let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string(); + + ( + Node::FnCall { + name, + expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::, ParseError>>()?, + }, + NodeMetadata::new(Unit::BASE_UNIT), + ) + } + Rule::constant => { + let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child")); + + (Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT)) + } + Rule::ident => { + let name = primary.as_str().to_string(); + + (Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT)) + } + Rule::expr => parse_expr(primary.into_inner())?, + Rule::float => { + let value = primary.as_str().parse::()?; + (Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT)) + } + rule => unreachable!("unexpected rule: {:?}", rule), + }) + }) + .map_prefix(|op, rhs| { + let (rhs, rhs_metadata) = rhs?; + let op = match op.as_rule() { + Rule::neg => UnaryOp::Neg, + Rule::sqrt => UnaryOp::Sqrt, + + rule => unreachable!("unexpected rule: {:?}", rule), + }; + + let node = Node::UnaryOp { expr: Box::new(rhs), op }; + let unit = rhs_metadata.unit; + + let unit = if !unit.is_base() { + match op { + UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit { + length: unit.length / 2, + mass: unit.mass / 2, + time: unit.time / 2, + }, + UnaryOp::Neg => unit, + op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))), + } + } else { + Unit::BASE_UNIT + }; + + Ok((node, NodeMetadata::new(unit))) + }) + .map_postfix(|lhs, op| { + let (lhs_node, lhs_metadata) = lhs?; + + let op = match op.as_rule() { + Rule::EOI => return Ok((lhs_node, lhs_metadata)), + Rule::fac => UnaryOp::Fac, + rule => unreachable!("unexpected rule: {:?}", rule), + }; + + if !lhs_metadata.unit.is_base() { + return Err(ParseError::Type(TypeError::InvalidUnaryOp(lhs_metadata.unit, op))); + } + + Ok((Node::UnaryOp { expr: Box::new(lhs_node), op }, lhs_metadata)) + }) + .map_infix(|lhs, op, rhs| { + let (lhs, lhs_metadata) = lhs?; + let (rhs, rhs_metadata) = rhs?; + + let op = match op.as_rule() { + Rule::add => BinaryOp::Add, + Rule::sub => BinaryOp::Sub, + Rule::mul => BinaryOp::Mul, + Rule::div => BinaryOp::Div, + Rule::pow => BinaryOp::Pow, + Rule::paren => BinaryOp::Mul, + rule => unreachable!("unexpected rule: {:?}", rule), + }; + + let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit); + + let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) { + (true, true) => match op { + BinaryOp::Mul => Unit { + length: lhs_unit.length + rhs_unit.length, + mass: lhs_unit.mass + rhs_unit.mass, + time: lhs_unit.time + rhs_unit.time, + }, + BinaryOp::Div => Unit { + length: lhs_unit.length - rhs_unit.length, + mass: lhs_unit.mass - rhs_unit.mass, + time: lhs_unit.time - rhs_unit.time, + }, + BinaryOp::Add | BinaryOp::Sub => { + if lhs_unit == rhs_unit { + lhs_unit + } else { + return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); + } + } + BinaryOp::Pow => { + return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); + } + }, + + (true, false) => match op { + BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))), + BinaryOp::Pow => { + //TODO: improve error type + //TODO: support 1 / int + if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) { + if (val - val as i32 as f64).abs() <= f64::EPSILON { + Unit { + length: lhs_unit.length * val as i32, + mass: lhs_unit.mass * val as i32, + time: lhs_unit.time * val as i32, + } + } else { + return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); + } + } else { + return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); + } + } + _ => lhs_unit, + }, + (false, true) => match op { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))), + _ => rhs_unit, + }, + (false, false) => Unit::BASE_UNIT, + }; + + let node = Node::BinOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }; + + Ok((node, NodeMetadata::new(unit))) + }) + .parse(pairs) +} + +//TODO: set up Unit test for Units +#[cfg(test)] +mod tests { + use super::*; + macro_rules! test_parser { + ($($name:ident: $input:expr => $expected:expr),* $(,)?) => { + $( + #[test] + fn $name() { + let result = Node::from_str($input).unwrap(); + assert_eq!(result.0, $expected); + } + )* + }; + } + + test_parser! { + test_parse_int_literal: "42" => Node::Lit(Literal::Float(42.0)), + test_parse_float_literal: "3.14" => Node::Lit(Literal::Float(3.14)), + test_parse_ident: "x" => Node::Var("x".to_string()), + test_parse_unary_neg: "-42" => Node::UnaryOp { + expr: Box::new(Node::Lit(Literal::Float(42.0))), + op: UnaryOp::Neg, + }, + test_parse_binary_add: "1 + 2" => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(1.0))), + op: BinaryOp::Add, + rhs: Box::new(Node::Lit(Literal::Float(2.0))), + }, + test_parse_binary_mul: "3 * 4" => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(3.0))), + op: BinaryOp::Mul, + rhs: Box::new(Node::Lit(Literal::Float(4.0))), + }, + test_parse_binary_pow: "2 ^ 3" => Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(2.0))), + op: BinaryOp::Pow, + rhs: Box::new(Node::Lit(Literal::Float(3.0))), + }, + test_parse_unary_sqrt: "sqrt(16)" => Node::UnaryOp { + expr: Box::new(Node::Lit(Literal::Float(16.0))), + op: UnaryOp::Sqrt, + }, + test_parse_sqr_ident: "sqr(16)" => Node::FnCall { + name:"sqr".to_string(), + expr: vec![Node::Lit(Literal::Float(16.0))] + }, + + test_parse_complex_expr: "(1 + 2) 3 - 4 ^ 2" => Node::BinOp { + lhs: Box::new(Node::BinOp { + lhs: Box::new(Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(1.0))), + op: BinaryOp::Add, + rhs: Box::new(Node::Lit(Literal::Float(2.0))), + }), + op: BinaryOp::Mul, + rhs: Box::new(Node::Lit(Literal::Float(3.0))), + }), + op: BinaryOp::Sub, + rhs: Box::new(Node::BinOp { + lhs: Box::new(Node::Lit(Literal::Float(4.0))), + op: BinaryOp::Pow, + rhs: Box::new(Node::Lit(Literal::Float(2.0))), + }), + } + } +} diff --git a/libraries/math-parser/src/value.rs b/libraries/math-parser/src/value.rs new file mode 100644 index 0000000000..959c36c154 --- /dev/null +++ b/libraries/math-parser/src/value.rs @@ -0,0 +1,128 @@ +use std::f64::consts::PI; + +use num_complex::ComplexFloat; + +use crate::ast::{BinaryOp, UnaryOp}; + +pub type Complex = num_complex::Complex; + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum Value { + Number(Number), +} + +impl Value { + pub fn from_f64(x: f64) -> Self { + Self::Number(Number::Real(x)) + } + + pub fn as_real(&self) -> Option { + match self { + Self::Number(Number::Real(val)) => Some(*val), + _ => None, + } + } +} + +impl From for Value { + fn from(x: f64) -> Self { + Self::from_f64(x) + } +} + +impl core::fmt::Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Value::Number(num) => num.fmt(f), + } + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum Number { + Real(f64), + Complex(Complex), +} + +impl std::fmt::Display for Number { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Number::Real(real) => real.fmt(f), + Number::Complex(complex) => complex.fmt(f), + } + } +} + +impl Number { + pub fn binary_op(self, op: BinaryOp, other: Number) -> Number { + match (self, other) { + (Number::Real(lhs), Number::Real(rhs)) => { + let result = match op { + BinaryOp::Add => lhs + rhs, + BinaryOp::Sub => lhs - rhs, + BinaryOp::Mul => lhs * rhs, + BinaryOp::Div => lhs / rhs, + BinaryOp::Pow => lhs.powf(rhs), + }; + Number::Real(result) + } + + (Number::Complex(lhs), Number::Complex(rhs)) => { + let result = match op { + BinaryOp::Add => lhs + rhs, + BinaryOp::Sub => lhs - rhs, + BinaryOp::Mul => lhs * rhs, + BinaryOp::Div => lhs / rhs, + BinaryOp::Pow => lhs.powc(rhs), + }; + Number::Complex(result) + } + + (Number::Real(lhs), Number::Complex(rhs)) => { + let lhs_complex = Complex::new(lhs, 0.0); + let result = match op { + BinaryOp::Add => lhs_complex + rhs, + BinaryOp::Sub => lhs_complex - rhs, + BinaryOp::Mul => lhs_complex * rhs, + BinaryOp::Div => lhs_complex / rhs, + BinaryOp::Pow => lhs_complex.powc(rhs), + }; + Number::Complex(result) + } + + (Number::Complex(lhs), Number::Real(rhs)) => { + let rhs_complex = Complex::new(rhs, 0.0); + let result = match op { + BinaryOp::Add => lhs + rhs_complex, + BinaryOp::Sub => lhs - rhs_complex, + BinaryOp::Mul => lhs * rhs_complex, + BinaryOp::Div => lhs / rhs_complex, + BinaryOp::Pow => lhs.powf(rhs), + }; + Number::Complex(result) + } + } + } + + pub fn unary_op(self, op: UnaryOp) -> Number { + match self { + Number::Real(real) => match op { + UnaryOp::Neg => Number::Real(-real), + UnaryOp::Sqrt => Number::Real(real.sqrt()), + + UnaryOp::Fac => todo!("Implement factorial"), + }, + + Number::Complex(complex) => match op { + UnaryOp::Neg => Number::Complex(-complex), + UnaryOp::Sqrt => Number::Complex(complex.sqrt()), + + UnaryOp::Fac => todo!("Implement factorial"), + }, + } + } + + pub fn from_f64(x: f64) -> Self { + Self::Real(x) + } +}