MetaOps.KERNEL (tinygrad#5543)
geohot authored Jul 18, 2024
1 parent d3b0982, commit fa7e734
Showing 19 changed files with 39 additions and 39 deletions.
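
The whole diff is one mechanical rename: the root op of a schedulable AST, MetaOps.SINK, becomes MetaOps.KERNEL. A minimal migration sketch for downstream code follows (a sketch only; the import paths and the toy computation are assumptions about the tinygrad API at this commit, not part of the change):

from tinygrad import Tensor
from tinygrad.ops import MetaOps                      # assumed import path
from tinygrad.engine.schedule import create_schedule  # assumed import path

t = Tensor.empty(4) + Tensor.empty(4)  # toy computation: one elementwise kernel
sched = create_schedule([t.lazydata])
# buffer realizations schedule as non-KERNEL metaops, hence the filter below
kernels = [si for si in sched if si.ast.op is MetaOps.KERNEL]  # was: MetaOps.SINK
print(f"{len(kernels)} kernel AST(s) in the schedule")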
docs/abstractions2.py: 2 changes (1 addition & 1 deletion)
@@ -53,7 +53,7 @@
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
-sink = LazyOp(MetaOps.SINK, (st_0,))
+sink = LazyOp(MetaOps.KERNEL, (st_0,))

# convert the computation to a "linearized" format (print the format)
from tinygrad.engine.realize import get_kernel, CompiledRunner
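For context, the hunk above sits in the hand-built "add two ints" walkthrough of docs/abstractions2.py; a self-contained sketch of the updated AST construction follows (ld_1 and the import paths are reconstructed by symmetry with the lines shown, so treat them as assumptions):

from tinygrad import dtypes
from tinygrad.ops import LazyOp, BufferOps, BinaryOps, MetaOps, MemBuffer  # assumed import path
from tinygrad.shape.shapetracker import ShapeTracker

# out[0] = in[1] + in[2]; three single-element int32 buffers
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
sink = LazyOp(MetaOps.KERNEL, (st_0,))  # root op, renamed from MetaOps.SINK

The file then lowers this AST with get_kernel and CompiledRunner, as the import in the hunk shows.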
examples/handcode_opt.py: 2 changes (1 addition & 1 deletion)
@@ -69,7 +69,7 @@ def get_sched_bert():
print(f"optimizing for {Device.DEFAULT}")

sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
-sched = [x for x in sched if x.ast.op is MetaOps.SINK]
+sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]

# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
examples/openpilot/compile2.py: 4 changes (2 additions & 2 deletions)
@@ -49,7 +49,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

# confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
-    assert all(si.ast.op is MetaOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
+    assert all(si.ast.op is MetaOps.KERNEL or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
return schedule, schedule_independent, inputs

def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
@@ -105,7 +105,7 @@ def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
#exit(0)

schedule, schedule_independent, inputs = get_schedule(onnx_data)
-schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.SINK)
+schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.KERNEL)
print(f"{len(schedule_input)} inputs")

run_schedule(schedule_independent)
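compile2.py splits the schedule with tinygrad's partition helper; a small sketch of that idiom in isolation (assuming partition lives in tinygrad.helpers and returns the matching items first, as the call above implies):

from tinygrad.helpers import partition

evens, odds = partition(list(range(10)), lambda x: x % 2 == 0)
# items satisfying the predicate come back first, order preserved
assert evens == [0, 2, 4, 6, 8] and odds == [1, 3, 5, 7, 9]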
extra/optimization/helpers.py: 2 changes (1 addition & 1 deletion)
@@ -10,7 +10,7 @@

# kernel unpacker
from tinygrad.codegen.kernel import Kernel
-def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
+def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
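ast_str_to_ast wraps a bare tuple of LazyOps in a MetaOps.KERNEL root so older serialized ASTs still load; a hypothetical usage sketch follows (the log path is taken from the fuzz commands quoted later in this diff, and the one-repr-per-line format is an assumption):

from extra.optimization.helpers import ast_str_to_lin

# each line is assumed to hold a repr() of a LazyOp (or a tuple of LazyOps)
with open("/tmp/beautiful_mnist.kernels.txt") as f:
    lin = ast_str_to_lin(f.readline().strip())
lin.hand_coded_optimizations()  # the same Kernel calls test_schedule.py exercises
lin.linearize()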
test/external/external_benchmark_schedule.py: 2 changes (1 addition & 1 deletion)
@@ -20,7 +20,7 @@
with Timing("***** model schedule in "):
sched = out.schedule()

-asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.SINK])
+asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.KERNEL])
uops = []
with Profiling(PROFILE):
with Timing("***** model uops in "):
test/test_const_folding.py: 2 changes (1 addition & 1 deletion)
@@ -9,7 +9,7 @@
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = create_schedule(t.lazydata.lbs)
-  asts = [s for s in schedule if s.ast.op is MetaOps.SINK]
+  asts = [s for s in schedule if s.ast.op is MetaOps.KERNEL]
assert len(asts) == desired_count

class TestUnaryOpsConstFolding(unittest.TestCase):
test/test_conv_shapetracker.py: 2 changes (1 addition & 1 deletion)
@@ -13,7 +13,7 @@ def test_conv_3x3_one_view(self):
# first run to init the weights, they are saved in seen
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
# run it again to get the kernels
-    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.SINK]
+    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
for st in [x.arg.st for x in sched[0].ast.lazyops if x.op is BufferOps.LOAD]:
assert len(st.views) == 1
test/test_linearizer.py: 6 changes (3 additions & 3 deletions)
@@ -785,7 +785,7 @@ def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):

def test_div_collapse(self):
def helper(t, msg, max_ops=0):
-      sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
+      sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1

lin = Kernel(sched[0].ast)
@@ -806,7 +806,7 @@ def helper(t, msg, max_ops=0):

def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
-    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1
lin = Kernel(sched[0].ast)
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
@@ -1159,7 +1159,7 @@ def test_matvec(self):
assert k.upcasted == 1

def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp], inputs:List[Tensor], *args, **kwargs):
-  if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.SINK, ast)
+  if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.KERNEL, ast)
inbufs = [x.lazydata.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
test/test_linearizer_failures.py: 14 changes (7 additions & 7 deletions)
@@ -253,7 +253,7 @@ def test_failure_32(self):

def test_failure_33(self):
# UOps.UNMUL left after linearize
-    ast = LazyOp(op=MetaOps.SINK, src=(
+    ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=ReduceOps.SUM, src=(
LazyOp(op=BinaryOps.MUL, src=(
@@ -282,7 +282,7 @@ def test_failure_33(self):

# from fuzzing on metal
def test_failure_34(self, unroll=False):
-    ast = LazyOp(op=MetaOps.SINK, src=(
+    ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=BinaryOps.MAX, src=(
LazyOp(op=ReduceOps.SUM, src=(
@@ -298,7 +298,7 @@ def test_failure_35(self): self.test_failure_34(True)
# from world fuzz_linearizer: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_N=100 FUZZ_NTH=84 python3 ./test/external/fuzz_linearizer.py
def test_failure_36(self):
# UOps.UNMUL left after linearize
-    ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
+    ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])

@@ -307,14 +307,14 @@ def test_failure_36(self):
def test_failure_37(self):
# beautiful mnist kernel number 28: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=28 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
-    ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
+    ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_38(self):
# beautiful mnist kernel number 87: 6 possible TC axis_choices (2 for axis_buf1 and 3 reduce) and first/second reduce axis fail for both axis_buf1 choices
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=87 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
-    ast = LazyOp(op=MetaOps.SINK, src=(
+    ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=ReduceOps.SUM, src=(
LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(
@@ -327,14 +327,14 @@ def test_failure_38(self):
def test_failure_39(self):
# beautiful mnist kernel number 127: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=127 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
-    ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
+    ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_40(self):
# beautiful mnist kernel number 3:
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 DEBUG=2 FUZZ_NTH=3 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
-    ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
+    ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for amt in [16,32]:
opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["METAL", "GPU"])
test/test_nn.py: 4 changes (2 additions & 2 deletions)
@@ -440,15 +440,15 @@ def test_embedding_one_kernel(self):
[12, 19, 8, 1]])
result = layer(a)
schedule = create_schedule([result.lazydata])
-    self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "first run realizes arange, weight, and embedding")
+    self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "first run realizes arange, weight, and embedding")
run_schedule(schedule)

b = Tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result = layer(b)
schedule = create_schedule([result.lazydata])
-    self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "second run realizes embedding only")
+    self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "second run realizes embedding only")
run_schedule(schedule)

def test_load_state_dict(self):
test/test_schedule.py: 4 changes (2 additions & 2 deletions)
@@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i,out in enumerate(s.outputs):
seen.add(out)
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
-  if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.SINK]
+  if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
@@ -37,7 +37,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (sink) ops linearize
for s in sched:
-    if s.ast.op is not MetaOps.SINK: continue
+    if s.ast.op is not MetaOps.KERNEL: continue
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
(diff for the remaining 8 changed files not loaded)
