Skip to content

Commit

Permalink
fuzz uop schedules (tinygrad#5345)
Browse files Browse the repository at this point in the history
* basic blocks + cleanups

* fixups

* elif is better for future me

* fuzz_schedule_max_paths

* fix linter
  • Loading branch information
Qazalin authored Jul 9, 2024
1 parent d5a68ae commit bee96a1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 58 deletions.
7 changes: 4 additions & 3 deletions test/external/fuzz_schedule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
from tinygrad.device import Buffer
from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
Expand All @@ -10,6 +10,7 @@
from tinygrad.tensor import Tensor, _to_np_dtype

# context variables paired with the values to try for each (presumably consumed by fuzz_schedule when it sweeps tunable params -- confirm there)
ctx_vars = { MULTIOUTPUT: (0, 1) }
# upper bound on how many toposorts find_all_toposorts collects; taken from getenv("FUZZ_SCHEDULE_MAX_PATHS") with a default of 10
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)

def fuzz_schedule(outs:List[LazyBuffer]):
# find toposorts across all tunable params
Expand Down Expand Up @@ -73,7 +74,7 @@ def _exec_si(si:ScheduleItem, seed:int):
ei.run()

T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
visited: Set[T] = set()
ret: List[Tuple[T, ...]] = []
path: List[T] = []
Expand All @@ -85,7 +86,7 @@ def recurse_paths(path:List[T]):
path.append(v)
visited.add(v)
recurse_paths(path)
if len(ret) >= getenv("FUZZ_SCHEDULE_MAX_PATHS", 10): return
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
# backtrack
for u in graph[v]: in_degree[u] += 1
path.pop()
Expand Down
78 changes: 33 additions & 45 deletions test/external/fuzz_uops.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,55 @@
import itertools
from collections import defaultdict
import numpy as np
from dataclasses import replace
from typing import Dict, List, Set, Tuple
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
from typing import DefaultDict, Dict, List, Tuple
from tinygrad.codegen.uops import END_FOR_UOP, UOp, UOpGraph
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.helpers import DEBUG, colored
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import _to_np_dtype
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts

def fuzz_uops(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
  """Turn every toposort of the uop graph into a linear uop list with its scopes closed.

  Each ordering produced by find_all_toposorts(graph, in_degree) must end with a SINK uop, which is
  dropped. For every IF in the ordering an ENDIF is appended at the end, and for every RANGE an
  ENDRANGE is inserted directly after the last uop of loops_children[range] found in the ordering.

  Returns one mutable list of uops per ordering.
  """
  # TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
  fuzzed: List[List[UOp]] = []
  for order in find_all_toposorts(graph, in_degree):
    assert order[-1].op is UOps.SINK, f"didn't end with SINK, ended with {order[-1]}"
    seq = list(order[:-1])
    fuzzed.append(seq)
    # seq is extended while it is being walked, so the END* uops added below are visited by this loop too
    for cur in seq:
      if cur.op is UOps.IF:
        seq.append(UOp(UOps.ENDIF, None, (cur,)))
      if cur.op is UOps.RANGE:
        last_child = max(seq.index(child) for child in loops_children[cur] if child in seq)
        seq.insert(last_child+1, UOp(UOps.ENDRANGE, None, (cur,)))
  return fuzzed
def fuzz_uops(uops:UOpGraph) -> List[Tuple[UOp, ...]]:
  """Build alternative orderings of a linearized uop list by shuffling uops inside basic blocks.

  The uops are cut into basic blocks at scope boundaries: a uop whose op is a key of END_FOR_UOP
  starts a new block, and a uop whose op is one of the closing ops stored in END_FOR_UOP gets a block
  of its own, followed by a fresh block. Every block is toposorted on its own, using only the src
  edges between uops of that same block. The per-block orderings are then concatenated in the
  original block order, so uops never move across a scope boundary.

  Returns the distinct orderings, in the order they were generated. Generation stops once
  FUZZ_SCHEDULE_MAX_PATHS orderings exist; the cap is checked after inserting, so the first ordering
  is always kept.
  """
  # loop-invariant: the ops that close a scope (second entry of every END_FOR_UOP value)
  closing_ops = {ends[1] for ends in END_FOR_UOP.values()}

  blocks: List[List[UOp]] = [[]]
  for u in uops:
    if u.op in END_FOR_UOP: blocks.append([u])
    elif u.op in closing_ops: blocks.extend([[u], []])
    else: blocks[-1].append(u)

  paths_for_block: List[List[Tuple[UOp, ...]]] = []
  for bb in blocks:
    # set membership keeps the edge scan linear in the number of srcs instead of rescanning the block per src
    members = set(bb)
    children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
    in_degree: Dict[UOp, int] = {}
    for u in bb:
      in_degree[u] = 0
      for x in u.src:
        # srcs living in another block are ignored, block order already satisfies them
        if x in members:
          children[x].append(u)
          in_degree[u] += 1
    paths_for_block.append(find_all_toposorts(children, in_degree))

  # a dict is used as an insertion-ordered set of orderings
  paths: Dict[Tuple[UOp, ...], None] = {}
  for combo in itertools.product(*paths_for_block):
    paths[tuple(uop for path in combo for uop in path)] = None
    if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
  return list(paths)

class UOpsFuzzerRunner(CompiledRunner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
assert self.p.uops is not None and len(self.p.uops.fuzz_paths) >= 1
assert self.p.uops is not None and len(self.p.uops._fuzz_paths) >= 1
init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
init_globals = {i[0]:buf for i, buf in zip(self.p.globals, rawbufs)}
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow"))
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops._fuzz_paths)} uop permutations for {init_name}", "yellow"))

super().__call__(rawbufs, var_vals, wait)
ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}

for i, path in enumerate(self.p.uops.fuzz_paths):
for i, path in enumerate(self.p.uops._fuzz_paths):
# setup prg
uops = UOpGraph([])
uops._uops = list(path)
if DEBUG >= 6: uops.print()
if DEBUG >= 5: uops.print()
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
if DEBUG >= 4: print(self.p.src)
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
Expand All @@ -49,30 +64,3 @@ def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=Fals
except AssertionError as e:
print(colored(name, "red"))
raise e

def find_all_toposorts(graph:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]) -> List[Tuple[UOp, ...]]:
  """Enumerate topological orderings of a uop graph by backtracking.

  graph maps a uop to the uops that depend on it and in_degree maps every uop to its number of
  unscheduled dependencies. A DEFINE_ACC is only scheduled once all of its srcs are in the path and
  is then placed in front of the earliest of them; every other uop is appended.

  in_degree is modified in place during the search. It is restored on normal backtracking but NOT
  when the search stops early because FUZZ_UOPS_MAX_PATHS (default 10) orderings were collected, so
  callers should pass a copy.

  Raises RuntimeError if no complete ordering exists. Asserts that all returned orderings are unique.
  """
  # the cap does not change during the search, read it once instead of once per visited uop
  max_paths = getenv("FUZZ_UOPS_MAX_PATHS", 10)
  visited: Set[UOp] = set()
  ret: List[Tuple[UOp, ...]] = []
  path: List[UOp] = []

  def recurse_paths(path:List[UOp]):
    for v, d in in_degree.items():
      if d != 0 or v in visited: continue
      if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue
      for u in graph[v]: in_degree[u] -= 1
      if v.op is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.src), v)
      else: path.append(v)
      visited.add(v)
      recurse_paths(path)
      if len(ret) >= max_paths: return
      # backtrack
      for u in graph[v]: in_degree[u] += 1
      # v may have been inserted mid-path (DEFINE_ACC), so remove v itself instead of popping the tail
      path.remove(v)
      visited.remove(v)
    if len(path) == len(in_degree): ret.append(tuple(path))
  recurse_paths(path)

  if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
  # verify all paths are unique
  assert len(ret) == len(set(ret))
  return ret
18 changes: 9 additions & 9 deletions tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class UOps(Enum):
# these two are not graph nodes
ENDRANGE = auto(); ENDIF = auto() # noqa: E702

# scope-opening op -> (op passed as the `end` bound when collecting the scope's children, END* op inserted to close the scope)
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}

def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
@dataclass(frozen=True, eq=False)
class UOp:
Expand Down Expand Up @@ -368,9 +370,9 @@ def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.o
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]

@property
def uops(self):
def uops(self) -> List[UOp]:
if self._uops is None: self.linearize()
return self._uops
return cast(List[UOp], self._uops)

def graph(self):
from tinygrad.engine.graph import graph_uops
Expand Down Expand Up @@ -412,8 +414,7 @@ def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))

# scope children impact the toposort and END* insertion
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in reversed(in_degree) if p.op in end_for_uop}
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}

queue:List[Tuple[int, UOp]] = []
def push(u:UOp):
Expand All @@ -426,10 +427,6 @@ def push(u:UOp):
for u in children:
if in_degree[u] == 0: push(u)

if getenv("FUZZ_UOPS", 0):
from test.external.fuzz_uops import fuzz_uops
self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)

self._uops = []
while queue:
p,x = heapq.heappop(queue)
Expand All @@ -443,11 +440,14 @@ def push(u:UOp):
if in_degree[u] == 0: push(u)

for u in (self._uops):
if u.op in end_for_uop: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(end_for_uop[u.op][1], None, (u,)))
if u.op in END_FOR_UOP: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))

assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]

if getenv("FUZZ_UOPS"):
from test.external.fuzz_uops import fuzz_uops
self._fuzz_paths = fuzz_uops(self)
if do_type_verify: type_verify(self.uops)

# *** checker functions ***
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/engine/realize.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
else:
prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
if hasattr(prg.uops, "fuzz_paths"):
if hasattr(prg.uops, "_fuzz_paths"):
from test.external.fuzz_uops import UOpsFuzzerRunner
return UOpsFuzzerRunner(replace(prg, dname=dname))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
Expand Down

0 comments on commit bee96a1

Please sign in to comment.