tensor metadata (tinygrad#5271)
wozeparrot authored Jul 9, 2024
1 parent 7f642aa commit 9150a6b
Showing 6 changed files with 137 additions and 24 deletions.
55 changes: 54 additions & 1 deletion test/test_tensor.py
@@ -2,7 +2,8 @@
import torch
import unittest, copy, mmap, random, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import getenv, temp, CI
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import getenv, temp, CI, _METADATA
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported
@@ -604,5 +605,57 @@ def f(x, m, W):
assert W.grad is None
f(x, m, W)

class TestTensorMetadata(unittest.TestCase):
def test_matmul(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
assert out.lazydata.metadata.name == "matmul"
s = create_schedule([out.lazydata])
assert len(s[-1].metadata) == 1
assert s[-1].metadata[0].name == "matmul"

def test_relu(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
assert out.lazydata.metadata.name == "relu"
s = create_schedule([out.lazydata])
assert len(s[-1].metadata) == 1
assert s[-1].metadata[0].name == "relu"

def test_complex(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
assert out.lazydata.metadata.name == "__mul__"
assert out.lazydata.srcs[0].metadata.name == "relu"
assert out.lazydata.srcs[1].metadata.name == "sigmoid"
s = create_schedule([out.lazydata])
assert len(s[-1].metadata) == 3
assert s[-1].metadata[0].name == "relu"
assert s[-1].metadata[1].name == "sigmoid"
assert s[-1].metadata[2].name == "__mul__"

def test_complex_backward(self):
_METADATA.set(None)
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
assert out.lazydata.metadata.name == "sum"
out.backward()
assert x.grad.lazydata.metadata.name == "relu"
assert x.grad.lazydata.metadata.backward
assert y.grad.lazydata.metadata.name == "sigmoid"
assert y.grad.lazydata.metadata.backward
s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])
assert len(s[-1].metadata) == 3
assert s[-1].metadata[0].name == "sigmoid"
assert s[-1].metadata[1].name == "sigmoid"
assert s[-1].metadata[1].backward
assert s[-1].metadata[2].name == "relu"

if __name__ == '__main__':
unittest.main()
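
Note (not part of the diff): the wrapper added in tinygrad/tensor.py only records metadata for the outermost Tensor call, which is why composite ops keep their user-facing name in these tests. A minimal sketch of that behavior:

    from tinygrad import Tensor

    x, W = Tensor.rand(3), Tensor.rand(3, 3)
    out = x.matmul(W)                  # matmul lowers to mul/sum internally,
    print(out.lazydata.metadata.name)  # but the recorded name is the outer call: "matmul"
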
20 changes: 15 additions & 5 deletions tinygrad/engine/realize.py
@@ -1,7 +1,7 @@
from typing import List, Dict, Optional, cast, Generator, Tuple
import time
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata
from tinygrad.ops import BufferOps, LoadOps, LazyOp
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -148,6 +148,7 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
class ExecItem:
prg: Runner
bufs: List[Optional[Buffer]]
metadata: Optional[List[Metadata]] = None
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
@@ -159,15 +160,16 @@ def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False
if DEBUG >= 2:
ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)" + # noqa: E501
f" {[repr(m) if DEBUG >= 3 else str(m) for m in self.metadata] if self.metadata else ''}"))
self.prg.first_run = False
return et

def lower_schedule_item(si:ScheduleItem) -> ExecItem:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
if si.ast[0].op is BufferOps.STORE:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals])
return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals], si.metadata)
out, ast = si.outputs[0], si.ast[0]
if ast.op is LoadOps.COPY:
kernel_type = BufferCopy
@@ -180,7 +182,15 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
raise RuntimeError(f"don't know how to lower {ast}")

def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule): yield lower_schedule_item(schedule.pop(0))
while len(schedule):
si = schedule.pop(0)
try: yield lower_schedule_item(si)
except Exception as e:
if DEBUG >= 2:
print(f"error lowering {si.ast[0].op}")
print("tensor operations:")
pprint.pprint(si.metadata, indent=2)
raise e

# **************** main run function ****************

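
Note (not part of the diff): ExecItem.metadata is what the DEBUG>=2 timing line and the new error path above print. A rough sketch of inspecting it on a lowered schedule, using only APIs touched by this commit:

    from tinygrad import Tensor
    from tinygrad.engine.schedule import create_schedule
    from tinygrad.engine.realize import lower_schedule

    out = Tensor.rand(3).relu()
    for ei in lower_schedule(create_schedule([out.lazydata])):
        # load-op items carry no metadata; compute kernels list the Tensor ops that produced them
        print(ei.prg.display_name, ei.metadata)
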
16 changes: 9 additions & 7 deletions tinygrad/engine/schedule.py
@@ -4,7 +4,7 @@
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ConstType, ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -23,6 +23,7 @@
class ScheduleItem:
ast: Tuple[LazyOp, ...]
bufs: Tuple[Buffer, ...]
metadata: Optional[List[Metadata]] = None
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -96,20 +97,21 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], re
"""describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
if (out:=outs[0]).op is LoadOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
return (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), ), [x.base for x in out.srcs], {}
if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: return (LazyOp(out.op, (), out.arg), ), [x.base for x in out.srcs], {}
return (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), ), [x.base for x in out.srcs], {}, []
if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: return (LazyOp(out.op, (), out.arg), ), [x.base for x in out.srcs], {}, []
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
ast: List[LazyOp] = []
inputs: List[LazyBuffer] = []
for i, out in enumerate(outs):
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache={})
lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache=cache)
output_view, vv = output_view.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
return tuple(ast), inputs, var_vals
return tuple(ast), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])

# *** DAG creation: decide which LazyBuffers should realize ***

@@ -301,7 +303,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
for out in ps[0]: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps[3]])
for out in ps[0]: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0)))
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps[0][0]]:
in_degree[x] -= 1
@@ -366,4 +368,4 @@ def find_replace_buffer(buf, st, en):
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast[0].op in LoadOps for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
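
Note (not part of the diff): because _lower_lazybuffer now collects the metadata of every LazyBuffer folded into a kernel, a fused kernel's ScheduleItem lists all contributing Tensor-level ops, as test_complex above checks. A minimal sketch:

    from tinygrad import Tensor
    from tinygrad.engine.schedule import create_schedule

    a, b = Tensor.rand(3), Tensor.rand(3)
    si = create_schedule([(a.relu() * b.sigmoid()).lazydata])[-1]
    print([str(m) for m in si.metadata])  # e.g. ['relu', 'sigmoid', '__mul__']
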
17 changes: 14 additions & 3 deletions tinygrad/helpers.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
import itertools, urllib.request, subprocess, shutil, math, json
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
@@ -101,10 +102,20 @@ def __gt__(self, x): return self.value > x
def __lt__(self, x): return self.value < x

DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)

@dataclass(frozen=True)
class Metadata:
name: str
caller: str
backward: bool = False
def __hash__(self): return hash(self.name)
def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
def __str__(self): return self.name + (" bw" if self.backward else "")
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

# **************** global state Counters ****************

class GlobalCounters:
@@ -306,4 +317,4 @@ def fn(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))
print(bar[:ncols+1],flush=True,end='\n'*close,file=sys.stderr)

class trange(tqdm):
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
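
Note (not part of the diff): a quick illustration of how the new Metadata dataclass renders, which is what the DEBUG output in realize.py prints; the caller string here is a made-up example of the module:line::function format produced by the wrapper in tinygrad/tensor.py:

    from tinygrad.helpers import Metadata

    m = Metadata(name="sigmoid", caller="examples/model.py:10::forward", backward=True)
    print(str(m))   # sigmoid bw
    print(repr(m))  # sigmoid bw - examples/model.py:10::forward
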
8 changes: 4 additions & 4 deletions tinygrad/lazy.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
@@ -17,16 +17,16 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
if enable_cache: lazycache[cache_key] = ret
return ret

view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
class LazyBuffer:
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None):
self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, dtype, st.shape, st.size, metadata
self._base: Optional[LazyBuffer] = None
if base is None:
# properties on base
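
Note (not part of the diff): any LazyBuffer created while _METADATA is set picks it up, and the Tensor-method wrapper leaves an already-set value alone, so an explicitly set tag propagates through nested ops. A minimal sketch; the name and caller strings are invented for illustration:

    from tinygrad import Tensor
    from tinygrad.helpers import _METADATA, Metadata

    token = _METADATA.set(Metadata(name="my_block", caller="example.py:1::main"))
    y = Tensor.rand(4).relu()        # LazyBuffers created here inherit the active metadata
    _METADATA.reset(token)
    print(y.lazydata.metadata.name)  # my_block
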
45 changes: 41 additions & 4 deletions tinygrad/tensor.py
@@ -1,14 +1,15 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct
import dataclasses
import time, math, itertools, functools, struct, sys, inspect
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
import numpy as np

from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps, truncate
@@ -20,18 +21,19 @@
# **** start with two base classes, Tensor and Function ****

class Function:
def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor):
def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
self.device = device
self.needs_input_grad = [t.requires_grad for t in tensors]
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
if self.requires_grad: self.parents = tensors
self.metadata = metadata

def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

@classmethod
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
ctx = fxn(x[0].device, *x)
ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
ret = Tensor.__new__(Tensor)
ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
@@ -740,7 +742,9 @@ def backward(self) -> Tensor:

for t0 in reversed(self._deepwalk()):
if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
grads = t0._ctx.backward(t0.grad.lazydata)
_METADATA.reset(token)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
@@ -3091,3 +3095,36 @@ def custom_random(out:Buffer):
if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
out.copyin(rng_np_buffer.data)

def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
if _METADATA.get() is not None: return fn(*args, **kwargs)

caller_frame = sys._getframe(frame := 1)
caller_module = caller_frame.f_globals.get("__name__", None)
caller_func = caller_frame.f_code.co_name
if caller_module is None: return fn(*args, **kwargs)

# if its called from nn we want to step up frames until we are out of nn
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
caller_frame = sys._getframe(frame := frame + 1)
caller_module = caller_frame.f_globals.get("__name__", None)
if caller_module is None: return fn(*args, **kwargs)

# if its called from a lambda in tinygrad we want to look two more frames up
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
caller_module = caller_frame.f_globals.get("__name__", None)
if caller_module is None: return fn(*args, **kwargs)
caller_func = caller_frame.f_code.co_name
caller_lineno = caller_frame.f_lineno

token = _METADATA.set(Metadata(name=fn.__name__, caller=f"{caller_module}:{caller_lineno}::{caller_func}"))
ret = fn(*args, **kwargs)
_METADATA.reset(token)
return ret
return _wrapper

if TRACEMETA >= 1:
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
if name in ["__class__", "__init__", "__repr__", "backward", "sequential"]: continue
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))

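
Note (not part of the diff): putting it together, with TRACEMETA enabled (the default) forward ops are tagged with their user-facing name and call site, and backward ops reuse that tag with backward=True. A rough sketch; the exact caller string depends on where the calls are made:

    from tinygrad import Tensor

    x = Tensor.rand(3, requires_grad=True)
    out = x.relu().sum()
    out.backward()
    print(out.lazydata.metadata)           # sum
    print(x.grad.lazydata.metadata)        # relu bw
    print(repr(x.grad.lazydata.metadata))  # e.g. relu bw - __main__:4::<module>
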