From 9150a6be7a30bbd17f0b84f3352fac7af0c68b73 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Tue, 9 Jul 2024 00:45:40 +0000 Subject: [PATCH] tensor metadata (#5271) --- test/test_tensor.py | 55 ++++++++++++++++++++++++++++++++++++- tinygrad/engine/realize.py | 20 ++++++++++---- tinygrad/engine/schedule.py | 16 ++++++----- tinygrad/helpers.py | 17 ++++++++++-- tinygrad/lazy.py | 8 +++--- tinygrad/tensor.py | 45 +++++++++++++++++++++++++++--- 6 files changed, 137 insertions(+), 24 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index f1ce7ea4a154f..e1ece7dfe1b8d 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -2,7 +2,8 @@ import torch import unittest, copy, mmap, random, math from tinygrad import Tensor, Device, dtypes -from tinygrad.helpers import getenv, temp, CI +from tinygrad.engine.schedule import create_schedule +from tinygrad.helpers import getenv, temp, CI, _METADATA from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from hypothesis import given, settings, strategies as strat from test.helpers import is_dtype_supported @@ -604,5 +605,57 @@ def f(x, m, W): assert W.grad is None f(x, m, W) +class TestTensorMetadata(unittest.TestCase): + def test_matmul(self): + _METADATA.set(None) + x = Tensor.rand(3, requires_grad=True) + W = Tensor.rand(3, 3, requires_grad=True) + out = x.matmul(W) + assert out.lazydata.metadata.name == "matmul" + s = create_schedule([out.lazydata]) + assert len(s[-1].metadata) == 1 + assert s[-1].metadata[0].name == "matmul" + + def test_relu(self): + _METADATA.set(None) + x = Tensor.rand(3, requires_grad=True) + out = x.relu() + assert out.lazydata.metadata.name == "relu" + s = create_schedule([out.lazydata]) + assert len(s[-1].metadata) == 1 + assert s[-1].metadata[0].name == "relu" + + def test_complex(self): + _METADATA.set(None) + x = Tensor.rand(3, requires_grad=True) + y = Tensor.rand(3, requires_grad=True) + out = x.relu() * y.sigmoid() + assert out.lazydata.metadata.name == "__mul__" + assert out.lazydata.srcs[0].metadata.name == "relu" + assert out.lazydata.srcs[1].metadata.name == "sigmoid" + s = create_schedule([out.lazydata]) + assert len(s[-1].metadata) == 3 + assert s[-1].metadata[0].name == "relu" + assert s[-1].metadata[1].name == "sigmoid" + assert s[-1].metadata[2].name == "__mul__" + + def test_complex_backward(self): + _METADATA.set(None) + x = Tensor.rand(3, requires_grad=True) + y = Tensor.rand(3, requires_grad=True) + out = (x.relu() * y.sigmoid()).sum() + assert out.lazydata.metadata.name == "sum" + out.backward() + assert x.grad.lazydata.metadata.name == "relu" + assert x.grad.lazydata.metadata.backward + assert y.grad.lazydata.metadata.name == "sigmoid" + assert y.grad.lazydata.metadata.backward + s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata]) + assert len(s[-1].metadata) == 3 + assert s[-1].metadata[0].name == "sigmoid" + assert s[-1].metadata[1].name == "sigmoid" + assert s[-1].metadata[1].backward + assert s[-1].metadata[2].name == "relu" + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index aa7fa8adacd7e..68670bcb292f1 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,7 +1,7 @@ from typing import List, Dict, Optional, cast, Generator, Tuple -import time +import time, pprint from dataclasses import dataclass, replace -from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING +from tinygrad.helpers import colored, 
getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata from tinygrad.ops import BufferOps, LoadOps, LazyOp from tinygrad.device import Device, Buffer from tinygrad.shape.symbolic import Variable, sym_infer, sint @@ -148,6 +148,7 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner: class ExecItem: prg: Runner bufs: List[Optional[Buffer]] + metadata: Optional[List[Metadata]] = None def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]: bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs] et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2) @@ -159,7 +160,8 @@ def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False if DEBUG >= 2: ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else "" print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 - (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501 + (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)" + # noqa: E501 + f" {[repr(m) if DEBUG >= 3 else str(m) for m in self.metadata] if self.metadata else ''}")) self.prg.first_run = False return et @@ -167,7 +169,7 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem: assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL") if si.ast[0].op is BufferOps.STORE: runner = get_runner(si.outputs[0].device, si.ast) - return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals]) + return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals], si.metadata) out, ast = si.outputs[0], si.ast[0] if ast.op is LoadOps.COPY: kernel_type = BufferCopy @@ -180,7 +182,15 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem: raise RuntimeError(f"don't know how to lower {ast}") def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]: - while len(schedule): yield lower_schedule_item(schedule.pop(0)) + while len(schedule): + si = schedule.pop(0) + try: yield lower_schedule_item(si) + except Exception as e: + if DEBUG >= 2: + print(f"error lowering {si.ast[0].op}") + print("tensor operations:") + pprint.pprint(si.metadata, indent=2) + raise e # **************** main run function **************** diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c7d31912d1f6d..04226d62fa145 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -4,7 +4,7 @@ from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer -from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv +from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, 
colored, prod, dedup, all_int, merge_dicts, getenv, Metadata from tinygrad.shape.symbolic import Variable from tinygrad.dtype import ConstType, ImageDType, dtypes from tinygrad.lazy import LazyBuffer @@ -23,6 +23,7 @@ class ScheduleItem: ast: Tuple[LazyOp, ...] bufs: Tuple[Buffer, ...] + metadata: Optional[List[Metadata]] = None @property def outputs(self) -> Tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" @@ -96,20 +97,21 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], re """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals""" if (out:=outs[0]).op is LoadOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]: rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,)))) - return (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), ), [x.base for x in out.srcs], {} - if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: return (LazyOp(out.op, (), out.arg), ), [x.base for x in out.srcs], {} + return (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), ), [x.base for x in out.srcs], {}, [] + if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: return (LazyOp(out.op, (), out.arg), ), [x.base for x in out.srcs], {}, [] var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs]) assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN} + cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {} ast: List[LazyOp] = [] inputs: List[LazyBuffer] = [] for i, out in enumerate(outs): output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape) output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st - lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache={}) + lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache=cache) output_view, vv = output_view.simplify().unbind() if vv: var_vals.update(vv) ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view))) - return tuple(ast), inputs, var_vals + return tuple(ast), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]) # *** DAG creation: decide which LazyBuffers should realize *** @@ -301,7 +303,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe for out in ps[0]: realized_lazybuffer(out, kernel_number) var_vals = merge_dicts([var_vals, ps[3]]) for out in ps[0]: del out.srcs # can only schedule once - schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))) + schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4])) if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n") for x in graph[ps[0][0]]: in_degree[x] -= 1 @@ -366,4 +368,4 @@ def find_replace_buffer(buf, st, en): def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. 
assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast[0].op in LoadOps for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule] + return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 339fd58073561..29b2354243eba 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys -import itertools, urllib.request, subprocess, shutil, math, json +import itertools, urllib.request, subprocess, shutil, math, json, contextvars +from dataclasses import dataclass from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 from typing_extensions import TypeGuard @@ -101,10 +102,20 @@ def __gt__(self, x): return self.value > x def __lt__(self, x): return self.value < x DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1) -WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1) +WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1) MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0) +@dataclass(frozen=True) +class Metadata: + name: str + caller: str + backward: bool = False + def __hash__(self): return hash(self.name) + def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "") + def __str__(self): return self.name + (" bw" if self.backward else "") +_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None) + # **************** global state Counters **************** class GlobalCounters: @@ -306,4 +317,4 @@ def fn(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1)) print(bar[:ncols+1],flush=True,end='\n'*close,file=sys.stderr) class trange(tqdm): - def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs) \ No newline at end of file + def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 8e13242985582..e9770e868034d 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Union, Optional, Any, Tuple, List from tinygrad.dtype import dtypes, DType, ConstType -from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG +from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu from tinygrad.shape.symbolic import sint, Variable from tinygrad.shape.shapetracker import ShapeTracker @@ -17,7 +17,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]= cache_key = 
(device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) if enable_cache and (rret := lazycache.get(cache_key, None)): return rret - ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base) + ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get()) if enable_cache: lazycache[cache_key] = ret return ret @@ -25,8 +25,8 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]= class LazyBuffer: def __init__(self, device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), - base:Optional[LazyBuffer]=None): - self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size + base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None): + self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, dtype, st.shape, st.size, metadata self._base: Optional[LazyBuffer] = None if base is None: # properties on base diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 15082b18ce054..cf9d6f38a7cd1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,6 +1,7 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math, itertools, functools, struct +import dataclasses +import time, math, itertools, functools, struct, sys, inspect from contextlib import ContextDecorator from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set from collections import defaultdict @@ -8,7 +9,7 @@ from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup -from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY +from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA from tinygrad.lazy import LazyBuffer from tinygrad.multi import MultiLazyBuffer from tinygrad.ops import LoadOps, truncate @@ -20,18 +21,19 @@ # **** start with two base classes, Tensor and Function **** class Function: - def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor): + def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None): self.device = device self.needs_input_grad = [t.requires_grad for t in tensors] self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False if self.requires_grad: self.parents = tensors + self.metadata = metadata def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}") def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}") @classmethod def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: - ctx = fxn(x[0].device, *x) + ctx = fxn(x[0].device, *x, metadata=_METADATA.get()) ret = Tensor.__new__(Tensor) ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine @@ -740,7 +742,9 @@ def backward(self) -> Tensor: for t0 in reversed(self._deepwalk()): if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad") + token = _METADATA.set(dataclasses.replace(md, 
backward=True) if (md := t0._ctx.metadata) is not None else None)
       grads = t0._ctx.backward(t0.grad.lazydata)
+      _METADATA.reset(token)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
         for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
       for t, g in zip(t0._ctx.parents, grads):
@@ -3091,3 +3095,36 @@ def custom_random(out:Buffer):
   if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
   else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
   out.copyin(rng_np_buffer.data)
+
+def _metadata_wrapper(fn):
+  def _wrapper(*args, **kwargs):
+    if _METADATA.get() is not None: return fn(*args, **kwargs)
+
+    caller_frame = sys._getframe(frame := 1)
+    caller_module = caller_frame.f_globals.get("__name__", None)
+    caller_func = caller_frame.f_code.co_name
+    if caller_module is None: return fn(*args, **kwargs)
+
+    # if its called from nn we want to step up frames until we are out of nn
+    while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
+      caller_frame = sys._getframe(frame := frame + 1)
+      caller_module = caller_frame.f_globals.get("__name__", None)
+      if caller_module is None: return fn(*args, **kwargs)
+
+    # if its called from a lambda in tinygrad we want to look two more frames up
+    if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
+    caller_module = caller_frame.f_globals.get("__name__", None)
+    if caller_module is None: return fn(*args, **kwargs)
+    caller_func = caller_frame.f_code.co_name
+    caller_lineno = caller_frame.f_lineno
+
+    token = _METADATA.set(Metadata(name=fn.__name__, caller=f"{caller_module}:{caller_lineno}::{caller_func}"))
+    ret = fn(*args, **kwargs)
+    _METADATA.reset(token)
+    return ret
+  return _wrapper
+
+if TRACEMETA >= 1:
+  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
+    if name in ["__class__", "__init__", "__repr__", "backward", "sequential"]: continue
+    setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
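
A usage sketch (not part of the patch itself): the snippet below mirrors the new TestTensorMetadata tests and shows how the recorded metadata can be inspected from user code. It assumes the patch is applied and TRACEMETA is left at its default of 1; the printed names follow Metadata.__str__ above.

    from tinygrad import Tensor
    from tinygrad.engine.schedule import create_schedule

    x = Tensor.rand(3, requires_grad=True)
    y = Tensor.rand(3, requires_grad=True)
    out = x.relu() * y.sigmoid()

    # every LazyBuffer now remembers the Tensor-level op that created it
    print(out.lazydata.metadata)          # __mul__
    print(out.lazydata.srcs[0].metadata)  # relu
    print(out.lazydata.srcs[1].metadata)  # sigmoid

    # a ScheduleItem carries the deduped metadata of everything fused into its kernel
    s = create_schedule([out.lazydata])
    print([str(m) for m in s[-1].metadata])  # ['relu', 'sigmoid', '__mul__']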
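
For reference, here is a self-contained sketch of the contextvars pattern that Function.apply and _metadata_wrapper rely on: only the outermost wrapped call sets the context variable, nested calls see it already populated and skip re-tagging, and resetting by token restores the previous value. The names TRACE, traced, relu and matmul below are illustrative only, not tinygrad APIs.

    import contextvars, functools
    from typing import Optional

    TRACE: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar("TRACE", default=None)

    def traced(fn):
      @functools.wraps(fn)
      def wrapper(*args, **kwargs):
        if TRACE.get() is not None: return fn(*args, **kwargs)  # already inside a traced op
        token = TRACE.set(fn.__name__)
        try: return fn(*args, **kwargs)
        finally: TRACE.reset(token)
      return wrapper

    @traced
    def relu(xs): return [max(x, 0.0) for x in xs]

    @traced
    def matmul(xs):
      # nested call below: TRACE still reports the outermost op, "matmul"
      print("creating buffers under:", TRACE.get())
      return relu(xs)

    matmul([-1.0, 2.0])  # prints: creating buffers under: matmul

The patch applies the same token-reset discipline in Tensor.backward so that gradient LazyBuffers pick up the forward op's Metadata with backward=True set.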