Skip to content

Commit

Permalink
add constant folding for WHERE in uops (tinygrad#3584)
Browse files Browse the repository at this point in the history
* add constant folding for WHERE in uops

* prereqs for generic constant folding

* fix test

* disable slow overflow logic

* make that test faster
  • Loading branch information
geohot authored Mar 2, 2024
1 parent 3b7e3fa commit aa9b013
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 29 deletions.
10 changes: 9 additions & 1 deletion test/test_linearizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def test_arg_dedup(self):
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

def test_load_removed(self):
  # A WHERE whose condition is a constant must evaluate to exactly the selected operand.
  lhs, rhs = Tensor.rand(1).realize(), Tensor.rand(1).realize()
  picked_when_true = Tensor.where(Tensor(True), lhs, rhs).numpy()
  picked_when_false = Tensor.where(Tensor(False), lhs, rhs).numpy()
  # True selects the first operand, False selects the second.
  np.testing.assert_equal(lhs.numpy(), picked_when_true)
  np.testing.assert_equal(rhs.numpy(), picked_when_false)

def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.

Expand Down Expand Up @@ -209,7 +217,7 @@ def helper_test_simplify(uop, dtype, vin, arg=None):
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
arg=TernaryOps.WHERE).uop == UOps.ALU
arg=TernaryOps.WHERE).uop == UOps.CONST

def helper_realized_ast(r:Tensor):
s = create_schedule([r.lazydata])
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ def test_simple_conv3d(self):

@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_padded_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)

Expand Down
18 changes: 17 additions & 1 deletion test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.runtime.ops_python import exec_alu
from tinygrad.codegen.uops import exec_alu
from test.test_dtype import is_dtype_supported

def _uops_to_prg(uops):
Expand Down Expand Up @@ -113,5 +113,21 @@ class TestExecALU(TestUOps):
def test_sqrt(self):
  # SQRT of an integer zero operand must evaluate to zero.
  result = exec_alu(UnaryOps.SQRT, dtypes.int, (0,))
  self.assertEqual(result, 0)

@unittest.skip("not enabled because it's slow")
def test_overflow(self):
  # Expected values assume integer results wrap to the bit width of the dtype.
  # unsigned 8-bit: results wrap modulo 256
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1)), 255)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1000)), 24)

  # signed 8-bit: results wrap into [-128, 127]
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-100, 100)), 56)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-1000, 0)), 24)
  self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-130, 0)), 126)

  # float operands with an integer dtype
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1.0, 1.0)), 2)
  # 2.0 ** 7 instead of math.exp2(7): math.exp2 only exists on Python 3.11+
  self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-(2.0 ** 7), 0)), -128)

# When executed as a script, run every test in this module with verbose per-test output.
if __name__ == '__main__':
  unittest.main(verbosity=2)
27 changes: 26 additions & 1 deletion tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import functools
import functools, math
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
Expand All @@ -25,6 +25,30 @@ class UOp:
def __repr__(self):
  """Return a one-line, column-aligned description: uop, dtype, the uops of the inputs, and arg."""
  uop_col = str(self.uop)
  # a missing dtype is rendered as an empty (but still padded) column
  dtype_col = '' if self.dtype is None else str(self.dtype)
  inputs_col = str([src.uop for src in self.vin])
  return f"{uop_col:20s}: {dtype_col:25s} {inputs_col:32s} {self.arg}"

def exec_alu(arg, dtype, p):
  """Evaluate a single ALU op on Python values.

  Args:
    arg: the op to evaluate; compared against members of TernaryOps, UnaryOps and BinaryOps.
    dtype: the result dtype. Only DIV consults it, to choose floor division for integer dtypes.
    p: indexable operands: p[0] for unary ops, p[0] and p[1] for binary ops, p[0..2] for WHERE.

  Returns:
    The raw Python result. It is NOT wrapped to the range of `dtype` (see the note before the return).

  Raises:
    NotImplementedError: if `arg` is not one of the ops handled here.
    ZeroDivisionError: propagated from Python for integer-dtype DIV, or MOD, with a zero divisor
      (assuming the operands are plain Python numbers).
  """
  if arg == TernaryOps.WHERE: ret = p[1] if p[0] else p[2]
  # log2 is -inf at 0 and nan for negative inputs instead of raising
  elif arg == UnaryOps.LOG2: ret = math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
  elif arg == UnaryOps.EXP2:
    # too-large results saturate to inf instead of raising
    try: ret = math.exp(p[0]*math.log(2))
    except OverflowError: ret = math.inf
  # sqrt of a negative input is nan instead of raising
  elif arg == UnaryOps.SQRT: ret = math.sqrt(p[0]) if p[0] >= 0 else math.nan
  elif arg == UnaryOps.SIN: ret = math.sin(p[0])
  elif arg == UnaryOps.NEG: ret = -p[0]
  elif arg == BinaryOps.MUL: ret = p[0]*p[1]
  elif arg == BinaryOps.ADD: ret = p[0]+p[1]
  elif arg == BinaryOps.SUB: ret = p[0]-p[1]
  elif arg == BinaryOps.XOR: ret = p[0]^p[1]
  elif arg == BinaryOps.MAX: ret = max(p[0], p[1])
  elif arg == BinaryOps.CMPEQ: ret = p[0] == p[1]
  elif arg == BinaryOps.CMPLT: ret = p[0] < p[1]
  # integer dtypes use floor division; otherwise true division, with nan for a zero divisor
  elif arg == BinaryOps.DIV: ret = p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
  elif arg == BinaryOps.MOD: ret = p[0]%p[1]
  # without this branch an unsupported op would fall through and fail with UnboundLocalError on `ret`
  else: raise NotImplementedError(f"no support for {arg}")
  # NOTE: wrapping integer results to the bit width of `dtype` is intentionally disabled because it is slow.
  # The disabled logic is kept for reference:
  #if not dtypes.is_int(dtype): return ret
  #adjusted = 0 if dtypes.is_unsigned(dtype) else 2 ** (dtype.itemsize * 8 - 1)
  #return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted
  return ret

def uop_alu_resolve(u:UOp) -> sint:
if u.uop == UOps.CONST: return u.arg
elif u.uop == UOps.DEFINE_VAR: return u.arg
Expand Down Expand Up @@ -68,6 +92,7 @@ def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(),
# constant folding
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
# zero folding
Expand Down
28 changes: 3 additions & 25 deletions tinygrad/runtime/ops_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,14 @@
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, math, struct
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOp, UOps, exec_alu
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.codegen.kernel import LinearizerOptions

def exec_alu(arg, dtype, p):
  """Evaluate one ALU op `arg` on the operands in `p` and return the Python result.

  `dtype` is only consulted by DIV, to pick floor division for integer dtypes.
  Raises NotImplementedError when `arg` is not one of the ops handled below.
  """
  # TODO: make this complete and correctly honor the dtypes
  # TODO: use this for constant folding
  if arg == TernaryOps.WHERE:
    result = p[1] if p[0] else p[2]
  elif arg == UnaryOps.LOG2:
    # -inf at zero, nan for negative inputs
    if p[0] > 0: result = math.log2(p[0])
    elif p[0] == 0: result = -math.inf
    else: result = math.nan
  elif arg == UnaryOps.EXP2:
    # too-large results saturate to inf
    try:
      result = math.exp(p[0]*math.log(2))
    except OverflowError:
      result = math.inf
  elif arg == UnaryOps.SQRT:
    # nan for negative inputs
    if p[0] >= 0: result = math.sqrt(p[0])
    else: result = math.nan
  elif arg == UnaryOps.SIN:
    result = math.sin(p[0])
  elif arg == UnaryOps.NEG:
    result = -p[0]
  elif arg == BinaryOps.MUL:
    result = p[0]*p[1]
  elif arg == BinaryOps.ADD:
    result = p[0]+p[1]
  elif arg == BinaryOps.SUB:
    result = p[0]-p[1]
  elif arg == BinaryOps.XOR:
    result = p[0]^p[1]
  elif arg == BinaryOps.MAX:
    result = max(p[0], p[1])
  elif arg == BinaryOps.CMPEQ:
    result = p[0] == p[1]
  elif arg == BinaryOps.CMPLT:
    result = p[0] < p[1]
  elif arg == BinaryOps.DIV:
    # floor division for integer dtypes; otherwise true division with nan for a zero divisor
    if dtypes.is_int(dtype): result = p[0]//p[1]
    elif p[1] != 0: result = p[0]/p[1]
    else: result = math.nan
  elif arg == BinaryOps.MOD:
    result = p[0]%p[1]
  else:
    raise NotImplementedError(f"no support for {arg}")
  return result

def _load(m, i):
  """Return m[i], raising IndexError for a negative index or one at/after len(m)."""
  size = len(m)
  out_of_bounds = i < 0 or i >= size
  if out_of_bounds:
    # negative indices are rejected explicitly rather than wrapping around like normal Python indexing
    raise IndexError(f"load out of bounds, size is {size} and access is {i}")
  return m[i]
Expand Down

0 comments on commit aa9b013

Please sign in to comment.