forked from teorth/equational_theories
-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_implications.py
executable file
·209 lines (175 loc) · 7.87 KB
/
process_implications.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
Example usage:
```sh
$ lake exe extract_implications | python scripts/process_implications.py
```
"""
from collections import defaultdict
import json
import os
from random import sample
import re
from sys import argv, stdin
import networkx as nx
def transitive_closure(pairs):
    """Return the transitive closure of a binary relation.

    Args:
        pairs: Iterable of ``(a, b)`` tuples. It may be a one-shot iterator;
            it is consumed exactly once and never modified.

    Returns:
        A new set holding every input pair plus every ``(a, c)`` for which a
        chain ``a -> ... -> c`` of input pairs exists.
    """
    # Materialise the input first: a generator would otherwise be exhausted
    # by the indexing loop and leave nothing to seed the closure with.
    closure = set(pairs)
    successors = defaultdict(list)
    for a, b in closure:
        successors[a].append(b)

    # Only the pairs discovered in the previous round can produce new pairs,
    # so each round extends just that frontier by one original edge.
    frontier = closure
    while frontier:
        frontier = {
            (a, c)
            for a, b in frontier
            for c in successors.get(b, ())
        } - closure
        closure |= frontier
    return closure
def longest_path(pairs, src, dst):
    """Return the longest path from ``src`` to ``dst`` as a tuple of nodes.

    Args:
        pairs: Iterable of directed edges ``(a, b)``. Self-loops are ignored.
            The remaining graph is expected to be acyclic; a cycle reachable
            from ``src`` would keep the search from terminating.
        src: Node the path starts at.
        dst: Node the path ends at. Its outgoing edges are never followed.

    Returns:
        A tuple ``(src, ..., dst)`` with as many nodes as possible.

    Raises:
        ValueError: If ``dst`` cannot be reached from ``src``.
    """
    fwd = defaultdict(list)
    for a, b in pairs:
        if b != a:
            fwd[a].append(b)

    # cache[node] is the longest node -> dst path, or None when dst cannot be
    # reached from node (a dead end must not abort the whole search).
    cache = {dst: (dst,)}
    # Sentinel instead of None/truthiness so falsy nodes such as 0 or ''
    # are still recognised as real children.
    no_child = object()
    # Explicit stack since we can't DFS over 4000 vertices in Python
    stack = [src]
    while stack:
        node = stack[-1]
        if node in cache:
            stack.pop()
            continue
        children = fwd.get(node, ())
        pending = next((child for child in children if child not in cache), no_child)
        if pending is not no_child:
            # Resolve children first; this node is revisited afterwards.
            stack.append(pending)
            continue
        tails = [cache[child] for child in children if cache[child] is not None]
        if tails:
            cache[node] = (node,) + max(tails, key=len)
        else:
            cache[node] = None
        stack.pop()

    if cache[src] is None:
        raise ValueError(f'no path from {src!r} to {dst!r}')
    return cache[src]
def get_unknown_implications(universe, known_implies, known_not_implies):
    """Return the pairs of laws whose implication status is still open.

    Args:
        universe: Collection of law names to consider.
        known_implies: Pairs ``(a, b)`` meaning ``a`` is known to imply ``b``.
        known_not_implies: Pairs ``(a, b)`` meaning ``a`` is known not to
            imply ``b``.

    Returns:
        The set of ``(a, b)`` with both members in ``universe`` that are
        neither derivable as an implication nor refuted by a known
        non-implication.
    """
    all_implications = set(transitive_closure(known_implies))
    # Every law implies itself. Do not rely on the caller having supplied the
    # reflexive pairs: without them, trivial pairs (a, a) and even the
    # directly refuted pairs would be reported as unknown.
    all_implications |= {(eq, eq) for eq in universe}

    fwd_implications = defaultdict(set)
    bwd_implications = defaultdict(set)
    for a, b in all_implications:
        fwd_implications[a].add(b)
        bwd_implications[b].add(a)

    # If a does not imply b, then nothing implied by a (c) can imply anything
    # that implies b (d); otherwise a => c => d => b would follow.
    all_negative_implications = {
        (c, d)
        for a, b in known_not_implies
        for c in fwd_implications[a] | {a}
        for d in bwd_implications[b] | {b}
    }

    every_pair = {(a, b) for a in universe for b in universe}
    return every_pair - all_implications - all_negative_implications
def parse_proofs_file_internal(universe, known_implies, known_not_implies, equations_files, file_name):
    """Scan Lean sources with regexes and accumulate the facts they state.

    The three set arguments are updated in place.

    Args:
        universe: Set of law names; every law mentioned is added.
        known_implies: Set of ``(a, b)`` pairs meaning ``a`` implies ``b``.
        known_not_implies: Set of ``(a, b)`` pairs meaning some magma
            satisfies ``a`` but not ``b``.
        equations_files: Paths of files with ``abbrev EquationN`` lines; each
            such law is added to ``universe`` together with its reflexive
            pair.
        file_name: Path of the Lean file whose ``theorem`` lines are parsed.

    Returns:
        The same ``(universe, known_implies, known_not_implies)`` objects.

    Raises:
        UnicodeDecodeError: If ``file_name`` is not valid UTF-8 (the error is
            reported on stdout first, then re-raised).
    """
    # This code is buggy: it doesn't verify that the proofs are correct.
    # It is also extremely sensitive to formatting of the proof types. There's
    # probably a way to get this directly from Lean.
    equation_decl = re.compile(r'abbrev\s+(Equation\d+)\s+')
    unconditional_thm = re.compile(
        r'theorem\s+.*\[Magma\s+G\]\s*:\s*(Equation\d+)\s*G\s*:=')
    implication_thm = re.compile(
        r'theorem\s+.*\[Magma\s+G\]\s*\(.\s*:\s*(Equation\d+)\s+G\)\s*:\s*(Equation\d+)\s+G\s*:=')
    counterexample_thm = re.compile(
        r'theorem\s+.*:\s*∃.*\(_:\s*Magma\s+G\),\s*(Equation\d+)\s+G\s*∧\s*¬\s*(Equation\d+)\s+G\s*:=')

    # The patterns contain non-ASCII symbols, so decode explicitly as UTF-8
    # rather than with the platform default, and close every handle.
    for equations_path in equations_files:
        with open(equations_path, encoding='utf-8') as handle:
            for line in handle:
                if m := equation_decl.match(line):
                    universe.add(m.group(1))
                    known_implies.add((m.group(1), m.group(1)))

    unconditionals = set()
    try:
        with open(file_name, encoding='utf-8') as handle:
            for line in handle:
                if m := unconditional_thm.match(line):
                    universe.add(m.group(1))
                    unconditionals.add(m.group(1))
                elif m := implication_thm.match(line):
                    universe.add(m.group(1))
                    universe.add(m.group(2))
                    known_implies.add((m.group(1), m.group(2)))
                elif m := counterexample_thm.match(line):
                    universe.add(m.group(1))
                    universe.add(m.group(2))
                    known_not_implies.add((m.group(1), m.group(2)))
    except UnicodeDecodeError as err:
        print(f"File {file_name} encounter error: {err}")
        raise

    # A law that holds unconditionally is implied by every law. Apply this
    # only after the whole file is read so that laws first mentioned below
    # the theorem are covered as well.
    known_implies.update((eq, ueq) for eq in universe for ueq in unconditionals)
    return universe, known_implies, known_not_implies
def parse_proofs_file(equations_files, file_name):
    """Parse one Lean proofs file into freshly created result sets.

    Args:
        equations_files: Paths forwarded to ``parse_proofs_file_internal``.
        file_name: Path of the proofs file to parse.

    Returns:
        A ``(universe, known_implies, known_not_implies)`` tuple of sets
        filled in by ``parse_proofs_file_internal``.
    """
    # Order matters: universe, known_implies, known_not_implies.
    collected = (set(), set(), set())
    parse_proofs_file_internal(*collected, equations_files, file_name)
    return collected
def parse_proofs_files(equations_files, files):
    """Parse several Lean proofs files into one shared group of result sets.

    Args:
        equations_files: Paths forwarded to ``parse_proofs_file_internal``
            on every call.
        files: Iterable of proofs file paths, parsed in the given order.

    Returns:
        A ``(universe, known_implies, known_not_implies)`` tuple of sets
        accumulated over all of ``files``.
    """
    # Order matters: universe, known_implies, known_not_implies.
    collected = (set(), set(), set())
    for proofs_path in files:
        parse_proofs_file_internal(*collected, equations_files, proofs_path)
    return collected
def parse_extracted_implications():
    """Read extracted implication data as JSON from stdin and print a report.

    The JSON object is read through three keys: ``unconditionals`` (law
    names), ``implications`` (objects with ``lhs`` and ``rhs``) and ``facts``
    (objects with ``satisfied`` and ``refuted`` lists of law names). Law names
    are handled as ``Equation<N>``: the text after the first eight characters
    is parsed as an integer.

    Printed, in order: input counts, the number of equivalence classes, the
    sizes of the derived implication / non-implication / missing sets, the
    class count if all remaining candidates held, the longest chain from
    ``Equation2`` to ``Equation1``, and the sorted irreducible missing
    implications.
    """
    def law_number(law):
        # Numeric id following the 8-character 'Equation' prefix.
        return int(law[8:])

    output = json.load(stdin)
    print(f'Parsed {len(output["unconditionals"])} unconditionals, {len(output["implications"])} implications, {len(output["facts"])} facts')

    # Collect every law mentioned anywhere in the input.
    universe = set(output['unconditionals'])
    for implication in output['implications']:
        universe.add(implication['lhs'])
        universe.add(implication['rhs'])
    for example in output['facts']:
        universe.update(example['satisfied'])
        universe.update(example['refuted'])

    # Reflexive pairs, stated implications, and "everything implies an
    # unconditional law".
    known_implies = {(eq, eq) for eq in universe}
    known_implies |= {(item['lhs'], item['rhs']) for item in output['implications']}
    known_implies |= {(eq, ueq) for ueq in output['unconditionals'] for eq in universe}

    # Mutually implying laws form one strongly connected component; each
    # component is named after its lowest-numbered member.
    law_graph = nx.DiGraph()
    law_graph.add_nodes_from(universe)
    law_graph.add_edges_from(known_implies)
    comp_names = {}
    for comp in nx.strongly_connected_components(law_graph):
        representative = f'Equation{min(map(law_number, comp))}'
        comp_names.update(dict.fromkeys(comp, representative))
    names = set(comp_names.values())
    print(f'Processing {len(names)} equivalence classes of {len(universe)} laws')

    comp_implies = {(comp_names[lhs], comp_names[rhs]) for lhs, rhs in known_implies}
    all_implications = transitive_closure(comp_implies)
    print('All implications:', len(all_implications))

    fwd_implications = {name: set() for name in names}
    bwd_implications = {name: set() for name in names}
    for stronger, weaker in all_implications:
        fwd_implications[stronger].add(weaker)
        bwd_implications[weaker].add(stronger)

    # Per example: everything implied by a satisfied law holds in it, and
    # everything implying a refuted law fails in it, so no member of the
    # first group can imply a member of the second.
    all_negative_implications = set()
    for example in output['facts']:
        holds = set()
        for eq in example['satisfied']:
            holds |= fwd_implications[comp_names[eq]]
        fails = set()
        for eq in example['refuted']:
            fails |= bwd_implications[comp_names[eq]]
        all_negative_implications |= {(p, n) for p in holds for n in fails}
    print('All negative implications:', len(all_negative_implications))

    every_pair = {(a, b) for a in names for b in names}
    missing_implications = every_pair - all_implications - all_negative_implications
    print(f'Missing implications: {len(missing_implications)}')

    # Two passes, in this order. First drop pairs whose left side implies
    # another class that is itself still missing against the same right side.
    weakest_lhs = {
        (lhs, rhs) for lhs, rhs in missing_implications
        if not any(
            (succ, rhs) in missing_implications
            for succ in fwd_implications[lhs] if succ != lhs
        )
    }
    # Then, among the survivors, drop pairs whose right side is implied by
    # another class that is still a candidate for the same left side.
    irreducible = {
        (lhs, rhs) for lhs, rhs in weakest_lhs
        if not any(
            (lhs, pred) in weakest_lhs
            for pred in bwd_implications[rhs] if pred != rhs
        )
    }

    conjectured = nx.DiGraph()
    conjectured.add_nodes_from(names)
    conjectured.add_edges_from(comp_implies)
    conjectured.add_edges_from(irreducible)
    merged_classes = list(nx.strongly_connected_components(conjectured))
    print('Equivalence classes if all conjectured implications hold:', len(merged_classes))

    # Keep only implications whose converse is known to fail.
    strict_implications = {
        (lhs, rhs) for lhs, rhs in all_implications
        if (rhs, lhs) in all_negative_implications
    }
    path = longest_path(strict_implications, 'Equation2', 'Equation1')
    print('Longest known chain of non-equivalent implications: ', ' => '.join(eq[8:] for eq in path))

    print(f'Irreducible missing implications: {len(irreducible)}')
    ordered = sorted(irreducible, key=lambda pair: (law_number(pair[0]), law_number(pair[1])))
    for lhs, rhs in ordered:
        print(lhs, '=>', rhs)
if __name__ == '__main__':
    # Without arguments, process the extractor's JSON from standard input.
    if len(argv) == 1:
        parse_extracted_implications()
        raise SystemExit

    # Validate explicitly: an `assert` is stripped under `python -O`, and a
    # bare `except:` would also hide unrelated failures.
    file_name = argv[1]
    if not os.path.exists(file_name):
        print('Usage: python process_implications.py <file_name.lean>')
        raise SystemExit(1)

    # Computed but not passed on: the parser below is deliberately called
    # with an empty list of equation files.
    equations_file = os.path.join(os.path.dirname(file_name), "Equations/Basic.lean")
    universe, known_implies, known_not_implies = parse_proofs_file([], file_name)
    all_unknown = get_unknown_implications(universe, known_implies, known_not_implies)
    print(f'Found {len(all_unknown)} unknown implications')
    if all_unknown:
        # Show at most ten pairs; announce it only when that is a proper
        # subset of the unknowns.
        k = min(10, len(all_unknown))
        if k < len(all_unknown):
            print('Sample of', k, 'unknown implications:')
        # `sample` needs a sequence, hence the list copy of the set.
        for a, b in sample(list(all_unknown), k):
            print(f'{a} => {b}')