Skip to content

Commit

Permalink
Restructure and rebase on new format of equations
Browse files Browse the repository at this point in the history
  • Loading branch information
carlini committed Sep 28, 2024
1 parent 384bd17 commit 14967bf
Show file tree
Hide file tree
Showing 533 changed files with 33,248 additions and 109,833 deletions.
1 change: 1 addition & 0 deletions equational_theories.lean
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import equational_theories.Subgraph
import equational_theories.AllEquations
import equational_theories.SimpleRewrites
272 changes: 272 additions & 0 deletions equational_theories/SimpleRewrites.lean

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions equational_theories/SimpleRewrites/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Auto-generate all proofs that can be solved by substitution and variable re-naming. Specifically, all theorems are of the form

```
theorem Equation2237_implies_Equation2235 := λ x y z w u => h x y z w u w
```

Running `src/find_simple_rewrites.py` will automatically generate this list.

This scripts finds simple rewrites by first syntactically checking if it looks like there might be a possible rewrite rule that matches, and if so, runs a very simple (but slower) inference check to make sure it works.

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions equational_theories/SimpleRewrites/src/find_simple_rewrites.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import utils
import os
import re
import itertools

equations_txt = open("equations.txt", "r").read().split("\n")[:-1]

def get_eq():
"""
Parse the data out of the equations.txt file and turn it into trees
"""
fns = []
for eq in equations_txt:
oeq = eq
eq = eq.split("∀")[1]
variables, eq = eq.split(":")
variables = variables.strip().split()
rule = eq.split(",")[1]
fns.append((variables, utils.make_tree(rule)))

return fns

equations = get_eq()

did = {}
remap_to_rule = {}

for i,(v_a, a) in enumerate(equations):
print(i)
for j,(v_b, b) in enumerate(equations):
if i == j: continue

remap = {}
for chr1, chr2 in zip(str(a), str(b)):
if chr1 != chr2:
remap[chr1] = chr2
if '(' in remap or ' ' in remap or ')' in remap: continue

a_rename = a.rename(remap)
if not utils.is_same_under_rewriting(a_rename, b):
continue

remapk = tuple(sorted(remap.items()))
if remapk not in remap_to_rule:
remap_to_rule[remapk] = []
oo = (f"theorem Equation{i+1}_implies_Equation{j+1} (G : Type*) [Magma G] (h : Equation{i+1} G) : Equation{j+1} G := λ " + " ".join(v_b) + " => h " + " ".join([remap.get(x) or x for x in v_a]))
remap_to_rule[remapk].append(oo)


if not os.path.exists("theorems"):
os.makedirs("theorems")

for rule, outs in remap_to_rule.items():
fname = "theorems/Rewrite_" + "_".join([f"{k}{v}" for k,v in rule]) + ".lean"
proofs = "\n".join(outs)
proofs = """import Mathlib.Tactic
import Mathlib.Data.Nat.Defs
import equational_theories.AllEquations
import equational_theories.Magma
namespace SimpleRewrites
""" + proofs + "\nend SimpleRewrites"
open(fname, "w").write(proofs)


151 changes: 151 additions & 0 deletions equational_theories/SimpleRewrites/src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from itertools import product
import re
from typing import List, Callable

import re

# Define the expression node
class ExprNode:
def __init__(self, value, left=None, right=None):
self.value = value # Operator or operand
self.left = left
self.right = right

def __repr__(self):
if self.left and self.right:
return f"({self.left} {self.value} {self.right})"
return "("+self.value+")"

def get_leafs(root):
def traverse(node, leaves):
if not node:
return

if not node.left and not node.right: # Leaf node
leaves.add(node.value)
else:
traverse(node.left, leaves)
traverse(node.right, leaves)

leaf_set = set()
traverse(root, leaf_set)
return leaf_set

def rename(root, rename_map):
def traverse(node):
if not node:
return None

if not node.left and not node.right: # Leaf node
new_value = rename_map.get(node.value, node.value)
return ExprNode(new_value)

new_left = traverse(node.left)
new_right = traverse(node.right)
return ExprNode(node.value, new_left, new_right)

return traverse(root)

def is_same_under_rewriting(left, right):
def traverse(node, mapping):
if not node.left and not node.right: # Leaf node
if node.value not in mapping:
if node.value in mapping.values():
return False # This value is already mapped to another variable
mapping[node.value] = len(mapping)
return mapping[node.value]

if not node.left or not node.right:
return None # Invalid expression tree

left_result = traverse(node.left, mapping)
right_result = traverse(node.right, mapping)

if left_result is None or right_result is None:
return None

return (node.value, left_result, right_result)

mapping1 = {}
left_structure = traverse(left, mapping1)
mapping2 = {}
right_structure = traverse(right, mapping2)

if left_structure == right_structure:
return {v:k for k,v in mapping1.items()}, {v:k for k,v in mapping2.items()}
return None

# Parser implementation
class Parser:
def __init__(self, expression):
self.expression = expression.replace(' ', '')
self.index = 0
self.length = len(self.expression)

def parse(self):
return self.parse_expression()

def parse_expression(self):
nodes = [self.parse_term()]

while self.current_char() == '∘' or self.current_char() == '.':
op = self.current_char()
self.advance()
right = self.parse_term()
nodes.append(op)
nodes.append(right)

# Build the tree (left-associative)
node = nodes[0]
for i in range(1, len(nodes), 2):
node = ExprNode(nodes[i], left=node, right=nodes[i+1])

return node

def parse_term(self):
char = self.current_char()
if char == '(':
self.advance()
node = self.parse_expression()
if self.current_char() != ')':
raise ValueError("Mismatched parentheses")
self.advance()
return node
else:
return self.parse_variable()

def parse_variable(self):
match = re.match(r'[a-zA-Z_]\w*', self.expression[self.index:])
if not match:
raise ValueError(f"Invalid character at index {self.index}")
var = match.group(0)
self.index += len(var)
return ExprNode(var)

def current_char(self):
if self.index < self.length:
return self.expression[self.index]
return None

def advance(self):
self.index += 1

# Function to convert expression tree to prefix notation
def expr_to_prefix(node):
if node.value == '∘':
left = expr_to_prefix(node.left)
right = expr_to_prefix(node.right)
return f"f({left}, {right})"
else:
return node.value

def make_tree(equation):
lhs_expr, rhs_expr = equation.split('=')
parser_lhs = Parser(lhs_expr)
tree_lhs = parser_lhs.parse()

# Parse RHS
parser_rhs = Parser(rhs_expr)
tree_rhs = parser_rhs.parse()

return ExprNode("=", left=tree_lhs, right=tree_rhs)
Loading

0 comments on commit 14967bf

Please sign in to comment.