Skip to content

Commit

Permalink
GEMM kernel search
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Jan 27, 2023
1 parent 1239931 commit 6d5e1a8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions extra/kernel_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def test():
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (8, 1, 32, 112, 112, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (8, 32, 112, 112))
elif int(os.getenv("GEMM", "0")):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 512, 512, 1), (0, 1, 512, 0), 0), View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 0, 1, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(1, 1, 512, 512, 1, 1, 1, 512), views=[View((1, 1, 512, 512, 1, 1, 1, 512), (0, 0, 1, 0, 0, 0, 0, 512), 0)]), hostbuf=GPUBuffer(shape=(512, 512), force_create=True))
op0 = LazyOp(BinaryOps.MUL, (buf0,buf1,), None)
op1 = LazyOp(ReduceOps.SUM, (op0,), (1, 1, 512, 512, 1, 1, 1, 1))
ast = LazyOp(MovementOps.RESHAPE, (op1,), (512, 512))
else:
# reduce
buf0 = GPUBuffer(shape=ShapeTracker(shape=(3, 1, 32, 3, 3, 32, 112, 112), views=[View((3, 32, 225, 225), (50176, 150528, 224, 1), 0), ZeroView((3, 32, 224, 224), ((0, 3), (0, 32), (0, 225), (0, 225))), View((3, 1, 32, 3, 3, 32, 112, 112), (1620000, 1620000, 0, 225, 1, 50625, 450, 2), 0)]), hostbuf=GPUBuffer(shape=(32, 3, 224, 224), force_create=True))
Expand Down
1 change: 1 addition & 0 deletions tinygrad/llops/ops_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __call__(self, *args):
(str() if DEBUG <= 1 or CL.CACHE is not None else f"tm {et/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)"))
GlobalCounters.global_ops += self.op_estimate
GlobalCounters.global_mem += sum([x.size//4 for x in args[2:] if isinstance(x, cl.Buffer)])
return e if CL.CACHE is None else None

# **** end CL wrappers ****

Expand Down

0 comments on commit 6d5e1a8

Please sign in to comment.