forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathops_metal.py
120 lines (110 loc) · 6.73 KB
/
ops_metal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Any, Tuple, Optional
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
def wait_check(cbuf: Any):
  """Block until a Metal command buffer has finished executing.

  Raises RuntimeError carrying the command buffer's error object when
  `cbuf.error()` reports one; returns None otherwise.
  """
  cbuf.waitUntilCompleted()
  err = cbuf.error()
  if err is None: return
  raise RuntimeError(err)
class MetalCompiler(Compiler):
  """Turn Metal shading-language source into serialized metallib bytes.

  device: the MetalDevice whose MTLDevice compiles the source in-process, or None
  to shell out to Xcode's `xcrun metal` / `xcrun metallib` tools instead.
  """
  def __init__(self, device:Optional[MetalDevice]):
    self.device = device
    super().__init__("compile_metal")  # NOTE(review): presumably the compile-cache key; confirm in tinygrad.device.Compiler

  def compile(self, src:str) -> bytes:
    """Compile `src` and return the metallib bytes.

    Raises CompileError when the in-process Metal compiler rejects the source.
    The xcrun path lets subprocess.CalledProcessError propagate unchanged.
    """
    if self.device is not None:
      opts = Metal.MTLCompileOptions.new()
      opts.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
      try:
        lib = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, opts, None))
      except AssertionError as e:
        # unwrap2 appears to report a Metal compile error as AssertionError; re-raise it as CompileError
        raise CompileError(e) from e
      return lib.libraryDataContents().bytes().tobytes()
    # offline path: source -> AIR (LLVM bitcode; run llvm-dis on it to read it) -> metallib, all via stdin/stdout pipes
    air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
    return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
class MetalProgram:
  """A kernel loaded from a serialized metallib, callable to dispatch it on a MetalDevice.

  device: owning MetalDevice (supplies the MTLDevice and command queue).
  name:   kernel function name looked up inside the library.
  lib:    metallib bytes (must start with the b"MTLB" magic).
  """
  def __init__(self, device:MetalDevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib
    if DEBUG >= 6: self._disassemble(lib)
    assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
    data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
    self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
    self.fxn = self.library.newFunctionWithName_(name)
    descriptor = Metal.MTLComputePipelineDescriptor.new()
    descriptor.setComputeFunction_(self.fxn)
    descriptor.setSupportIndirectCommandBuffers_(True)
    self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(
      descriptor, Metal.MTLPipelineOption(0), None, None))

  @staticmethod
  def _disassemble(lib:bytes):
    """Best-effort debug aid: run the applegpu disassembler on `lib`, printing a hint if that fails."""
    applegpu_dir = pathlib.Path(__file__).parents[2] / "extra" / "disassemblers" / "applegpu"
    with tempfile.NamedTemporaryFile(delete=True) as shader:
      shader.write(lib)
      shader.flush()
      # argv list + cwd (instead of a shell "cd ... && ..." string) so paths containing spaces or
      # shell metacharacters are passed through intact
      try: ret = subprocess.run(["python3", "compiler_explorer.py", shader.name], cwd=applegpu_dir).returncode
      except OSError: ret = 1  # missing directory / interpreter: same outcome as a failed run
    if ret:
      print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")

  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Encode and commit one dispatch of the kernel.

    bufs: buffers exposing `.buf` / `.offset`, bound at argument indices 0..len(bufs)-1.
    vals: integers passed by value as 4-byte c_int32 at the indices following the buffers.
    global_size / local_size: passed as the threadgroups-per-grid and threads-per-threadgroup
      MTLSizes of dispatchThreadgroups:threadsPerThreadgroup:.
    wait: if True, block until completion (raising RuntimeError on a GPU error) and return
      GPUEndTime() - GPUStartTime(); otherwise queue the command buffer on
      device.mtl_buffers_in_flight and return None.
    """
    max_threads = self.pipeline_state.maxTotalThreadsPerThreadgroup()
    if prod(local_size) > max_threads:
      raise RuntimeError(f"local size {local_size} bigger than {max_threads} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(self.pipeline_state)
    for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a.buf, a.offset, i)
    for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
    encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
    encoder.endEncoding()
    command_buffer.commit()
    if wait:
      wait_check(command_buffer)
      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
    # not waited on here: MetalDevice.synchronize() drains this list later
    self.device.mtl_buffers_in_flight.append(command_buffer)
class MetalBuffer:
  """A Metal buffer handle together with the byte window (offset, size) it refers to."""
  def __init__(self, buf:Any, size:int, offset=0):
    self.buf = buf        # the underlying Metal buffer object (shared between views)
    self.size = size      # number of bytes covered by this handle
    self.offset = offset  # byte offset of this window inside `buf`
class MetalAllocator(LRUAllocator):
  """Allocator handing out shared-storage Metal buffers for one MetalDevice."""
  def __init__(self, device:MetalDevice):
    self.device:MetalDevice = device
    super().__init__()

  def _alloc(self, size:int, options) -> MetalBuffer:
    """Allocate `size` bytes with MTLResourceStorageModeShared; raise MemoryError if Metal returns nil."""
    ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
    if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
    return MetalBuffer(ret, size)

  def _free(self, opaque:MetalBuffer, options): opaque.buf.release()

  def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
    """Copy `sz` bytes from `src` to `dest` with a blit encoded on src_dev's queue.

    Across devices, the source command buffer signals src_dev's timeline event at the
    current timeline value, and an otherwise-empty command buffer on dest_dev waits on
    it. Both command buffers are left in their devices' in-flight lists.
    """
    # NOTE(review): presumably drains pending work on dest before it is overwritten; confirm
    dest_dev.synchronize()
    src_command_buffer = src_dev.mtl_queue.commandBuffer()
    encoder = src_command_buffer.blitCommandEncoder()
    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src.buf, src.offset, dest.buf, dest.offset, sz)
    encoder.endEncoding()
    if src_dev != dest_dev:
      src_command_buffer.encodeSignalEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
      dest_command_buffer = dest_dev.mtl_queue.commandBuffer()
      dest_command_buffer.encodeWaitForEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
      dest_command_buffer.commit()
      dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
      src_dev.timeline_value += 1  # next cross-device transfer uses a fresh event value
    src_command_buffer.commit()
    src_dev.mtl_buffers_in_flight.append(src_command_buffer)

  def from_buffer(self, src:memoryview) -> Optional[Any]:
    """Wrap host memory `src` in a Metal buffer without copying.

    Returns None when Metal declines to wrap the memory (the call returns nil). Apple
    documents page-alignment requirements for no-copy buffers.
    """
    ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, src.nbytes, Metal.MTLResourceStorageModeShared, None)
    # previously a nil result still produced MetalBuffer(None, ...), a handle to no buffer;
    # honour the Optional contract instead
    if not ret: return None
    # keep the memoryview referenced (MetalDevice.synchronize clears this list) since the buffer aliases it
    self.device.mv_in_metal.append(src)
    return MetalBuffer(ret, src.nbytes)

  def as_buffer(self, src:MetalBuffer) -> memoryview:
    """Synchronize the device, then return a host view of `src`'s [offset, offset+size) window."""
    self.device.synchronize()
    return src.buf.contents().as_buffer(src.offset+src.size)[src.offset:]

  def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src

  def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)

  def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
  """tinygrad Compiled backend backed by the system default Metal device."""
  def __init__(self, device:str):
    self.device = Metal.MTLCreateSystemDefaultDevice()
    # fail with a clear message rather than an AttributeError on None below
    if self.device is None: raise RuntimeError("No Metal device available")
    self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
    if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
    self.mtl_buffers_in_flight: List[Any] = []  # committed command buffers not yet waited on
    self.mv_in_metal: List[memoryview] = []     # host memory aliased by no-copy buffers; kept alive until synchronize
    self.timeline_signal = self.device.newSharedEvent()  # event used to order cross-device transfers
    self.timeline_value = 0
    from tinygrad.runtime.graph.metal import MetalGraph
    super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self),
                     functools.partial(MetalProgram, self), MetalGraph)

  def synchronize(self):
    """Wait for every in-flight command buffer, then release the bookkeeping lists.

    Every buffer is waited on even if an earlier one failed, and both lists are
    cleared before the first error (RuntimeError) is re-raised. Previously a
    failure left the lists untouched, so every later synchronize re-raised the same
    error and the pinned memoryviews were never released.
    """
    first_err = None
    for cbuf in self.mtl_buffers_in_flight:
      try: wait_check(cbuf)
      except RuntimeError as e:
        if first_err is None: first_err = e
    # all buffers have completed at this point, so dropping the host memory is safe
    self.mv_in_metal.clear()
    self.mtl_buffers_in_flight.clear()
    if first_err is not None: raise first_err