-
Notifications
You must be signed in to change notification settings - Fork 312
/
TwoSat.py
52 lines (46 loc) · 1.42 KB
/
TwoSat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def find_SCC(graph):
    """Return the strongly connected components (SCCs) of a directed graph.

    Iterative, path-based depth-first search: no recursion is used, so deep
    graphs do not hit Python's recursion limit.

    Args:
        graph: adjacency list; ``graph[u]`` is a list of the successors of
            node ``u``. Nodes are the integers ``0 .. len(graph) - 1``.

    Returns:
        A list of components, each a list of node indices. Components are
        appended as their DFS root finishes and the list is reversed before
        returning, so a component appears before the components it has
        edges into (topological order of the condensation).
    """
    n = len(graph)
    # SCC: finished components, in order of completion.
    # S:   nodes on the current DFS path that are not yet assigned to an SCC.
    # P:   boundary stack; each entry is the 1-based position in S of a node
    #      that may still turn out to be the root of its component.
    SCC, S, P = [], [], []
    # depth[v]: 0 = unvisited, > 0 = 1-based position of v in S,
    # -1 = v already belongs to a finished component.
    depth = [0] * n

    # Non-negative entries mean "visit this node"; a negative entry ~v is a
    # marker meaning "all successors of v were handled, finish v".
    stack = list(range(n))
    while stack:
        node = stack.pop()
        if node < 0:
            # Post-visit of ~node: d is its 0-based index in S.
            d = depth[~node] - 1
            if P[-1] > d:
                # Its boundary survived, so it is the root of a component
                # made of everything from index d to the top of S.
                component = S[d:]
                del S[d:], P[-1]
                for member in component:
                    depth[member] = -1
                SCC.append(component)
        elif depth[node] > 0:
            # Edge back into the current path: everything above the target
            # collapses into one component, so drop those boundaries.
            while P[-1] > depth[node]:
                P.pop()
        elif depth[node] == 0:
            # First visit: push onto the path and schedule the successors,
            # with the post-visit marker underneath them.
            S.append(node)
            P.append(len(S))
            depth[node] = len(S)
            stack.append(~node)
            stack += graph[node]
        # depth[node] == -1: node is already in a finished component; skip.
    return SCC[::-1]
class TwoSat:
    """2-SAT instance over ``n`` boolean variables, solved via SCCs.

    Variable ``i`` (``0 <= i < n``) is written as the literal ``i``; its
    negation is written ``~i``. In the implication graph the literal ``i``
    is node ``i`` and the literal ``~i`` is node ``2 * n + ~i``, i.e.
    ``2 * n - 1 - i``, which is also where Python's negative indexing puts
    ``~i`` in a list of length ``2 * n``.
    """

    def __init__(self, n):
        """Create an instance with ``n`` variables and no clauses."""
        self.n = n
        # One adjacency list per literal (n variables + n negations).
        self.graph = [[] for _ in range(2 * n)]

    def _imply(self, x, y):
        """Add the implication edge ``x -> y`` between two literals."""
        # x may stay negative: list indexing wraps it to the same node that
        # the explicit normalisation below produces for y.
        self.graph[x].append(y if y >= 0 else 2 * self.n + y)

    def either(self, x, y):
        """either x or y must be True"""
        # (x or y) is equivalent to (not x -> y) and (not y -> x).
        self._imply(~x, y)
        self._imply(~y, x)

    def set(self, x):
        """x must be True"""
        # (not x -> x) makes any assignment with x False contradictory.
        self._imply(~x, x)

    def solve(self):
        """Decide satisfiability and build an assignment.

        Returns:
            ``(False, None)`` if some variable lies in the same strongly
            connected component as its negation; otherwise ``(True, values)``
            where ``values[i]`` is ``1`` or ``0`` for variable ``i``.
        """
        SCC = find_SCC(self.graph)
        # order[v]: position of v's component in the list from find_SCC.
        order = [0] * (2 * self.n)
        for i, comp in enumerate(SCC):
            for x in comp:
                order[x] = i
        for i in range(self.n):
            if order[i] == order[~i]:
                return False, None
        # A variable is True when its literal's component comes later in
        # the component order than its negation's component.
        return True, [+(order[i] > order[~i]) for i in range(self.n)]