"""
Backends in `einops` are organized to meet the following requirements
- backends are not imported unless those are actually needed, because
- backends may not be installed
- importing all available backends will drive to significant memory footprint
- backends may by present but installed with errors (but never used),
importing may drive to crashes
- backend should be either symbolic or imperative (tensorflow is for both, but that causes problems)
- this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used
"""
import sys
import warnings
__author__ = 'Alex Rogozhnikov'
_backends: dict = {}
_debug_importing = False
def get_backend(tensor) -> 'AbstractBackend':
"""
Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
If needed, imports package and creates backend
"""
for framework_name, backend in _backends.items():
if backend.is_appropriate_type(tensor):
return backend
# Find backend subclasses recursively
backend_subclasses = []
backends = AbstractBackend.__subclasses__()
while backends:
backend = backends.pop()
backends += backend.__subclasses__()
backend_subclasses.append(backend)
for BackendSubclass in backend_subclasses:
if _debug_importing:
print('Testing for subclass of ', BackendSubclass)
if BackendSubclass.framework_name not in _backends:
            # create a backend only if its framework module was already imported by the user
            # (a tensor of that framework cannot exist otherwise); einops never imports frameworks itself
if BackendSubclass.framework_name in sys.modules:
if _debug_importing:
print('Imported backend for ', BackendSubclass.framework_name)
backend = BackendSubclass()
_backends[backend.framework_name] = backend
if backend.is_appropriate_type(tensor):
return backend
raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
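
# Illustrative usage (assuming numpy is installed and already imported by the caller):
# the first call resolves and caches NumpyBackend; later calls reuse the cached instance.
#
#   import numpy as np
#   backend = get_backend(np.zeros((2, 3)))   # -> NumpyBackend instance
#   backend.shape(np.zeros((2, 3)))           # -> (2, 3)
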
class AbstractBackend:
""" Base backend class, major part of methods are only for debugging purposes. """
framework_name: str
def is_appropriate_type(self, tensor):
""" helper method should recognize tensors it can handle """
raise NotImplementedError()
def from_numpy(self, x):
raise NotImplementedError("framework doesn't support imperative execution")
def to_numpy(self, x):
raise NotImplementedError("framework doesn't support imperative execution")
def create_symbol(self, shape):
raise NotImplementedError("framework doesn't support symbolic computations")
def eval_symbol(self, symbol, input_dict):
raise NotImplementedError("framework doesn't support symbolic computations")
def arange(self, start, stop):
        # supplementary method used only in testing, so only a CPU version needs to be implemented
raise NotImplementedError("framework doesn't implement arange")
def shape(self, x):
"""shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
return x.shape
def reshape(self, x, shape):
return x.reshape(shape)
def transpose(self, x, axes):
return x.transpose(axes)
def reduce(self, x, operation, axes):
return getattr(x, operation)(axis=axes)
def stack_on_zeroth_dimension(self, tensors: list):
raise NotImplementedError()
def add_axis(self, x, new_position):
raise NotImplementedError()
def add_axes(self, x, n_axes, pos2len):
repeats = [1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return self.tile(x, tuple(repeats))
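    # Example (illustrative): for x of shape (2, 3), add_axes(x, n_axes=3, pos2len={1: 4})
    # inserts a new axis at position 1 (shape becomes (2, 1, 3)) and tiles it to length 4,
    # producing shape (2, 4, 3).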
def tile(self, x, repeats):
"""repeats is a number of """
raise NotImplementedError()
def concat(self, tensors, axis: int):
""" concatenates tensors along axis.
Assume identical across tensors: devices, dtypes and shapes except selected axis."""
raise NotImplementedError()
def is_float_type(self, x):
# some backends (torch) can't compute average for non-floating types.
# Decided to drop average for all backends if type is not floating
raise NotImplementedError()
def layers(self):
raise NotImplementedError("backend does not provide layers")
def __repr__(self):
return "<einops backend for {}>".format(self.framework_name)
def einsum(self, pattern, *x):
raise NotImplementedError("backend does not support einsum")
class UnknownSize:
""" pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """
def __floordiv__(self, other):
return self
def __eq__(self, other):
return True # we don't know actual size
def __mul__(self, other):
return self
def __rmul__(self, other):
return self
def __hash__(self):
return hash(None)
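
# Illustrative behaviour of UnknownSize (follows directly from the operators above):
# arithmetic returns the same pseudo-symbol and equality is always True, so shape checks
# pass even when the actual size cannot be verified.
#
#   d = UnknownSize()
#   assert (d * 2) is d and (d // 3) is d and d == 17
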
class NumpyBackend(AbstractBackend):
framework_name = 'numpy'
def __init__(self):
import numpy
self.np = numpy
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.np.ndarray)
def from_numpy(self, x):
return x
def to_numpy(self, x):
return x
def arange(self, start, stop):
return self.np.arange(start, stop)
def stack_on_zeroth_dimension(self, tensors: list):
return self.np.stack(tensors)
def tile(self, x, repeats):
return self.np.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.np.concatenate(tensors, axis=axis)
def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
def add_axis(self, x, new_position):
return self.np.expand_dims(x, new_position)
def einsum(self, pattern, *x):
return self.np.einsum(pattern, *x)
class JaxBackend(NumpyBackend):
framework_name = 'jax'
def __init__(self):
super(JaxBackend, self).__init__()
self.onp = self.np
import jax.numpy
self.np = jax.numpy
def from_numpy(self, x):
return self.np.asarray(x)
def to_numpy(self, x):
return self.onp.asarray(x)
class GluonBackend(AbstractBackend):
framework_name = 'mxnet.ndarray'
def __init__(self):
import mxnet
self.mx = mxnet
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.mx.nd.NDArray)
def from_numpy(self, x):
if len(x.shape) == 0:
x = x[None] # poor support of scalars in mxnet, otherwise mxnet can't attach gradients
var = self.mx.nd.array(x, dtype=x.dtype)
var.attach_grad()
return var
def to_numpy(self, x):
return self.mx.nd.NDArray.asnumpy(x)
def reshape(self, x, shape):
if len(shape) == 0:
return x # poor support of scalars in mxnet
return x.reshape(shape)
def arange(self, start, stop):
return self.mx.nd.arange(start, stop)
def stack_on_zeroth_dimension(self, tensors: list):
return self.mx.nd.stack(*tensors)
def tile(self, x, repeats):
return self.mx.nd.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.mx.nd.concat(*tensors, dim=axis)
def add_axis(self, x, new_position):
return self.mx.nd.expand_dims(x, new_position)
def is_float_type(self, x):
return 'float' in str(x.dtype)
def layers(self):
from .layers import gluon
return gluon
class MXNetBackend(AbstractBackend):
framework_name = 'mxnet.symbol'
def __init__(self):
import mxnet
self.mx = mxnet
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.mx.symbol.Symbol)
def create_symbol(self, shape, dtype='float32'):
# mxnet accepts zeros as undefined dimensions
shape = tuple(0 if d is None else d for d in shape)
var = self.mx.symbol.Variable('input', shape=shape, dtype=dtype)
return var
def eval_symbol(self, symbol, input_dict):
args = {var.name: self.mx.nd.array(val) for var, val in input_dict}
ex = symbol.bind(ctx=self.mx.cpu(), args=args)
ex.forward()
return ex.outputs[0].asnumpy()
def shape(self, x):
# mxnet has problems with shape inference - it does not provide shape symbols
# shape_array seems to be impossible to use in shape inference
# infer_shape_partial returns empty tuple if was not able to infer shape
# reductions such as sum can't return scalars, but return 1-element vectors
shape = x.infer_shape_partial()[1][0]
if len(shape) == 0:
warnings.warn('mxnet inferred shape to be (), which probably means it could not be inferred')
shape = tuple(UnknownSize() if d == 0 else d for d in shape)
return shape
def reshape(self, x, shape):
if len(shape) == 0:
return x # poor support of scalars in mxnet
if any(isinstance(dimension, UnknownSize) for dimension in shape):
from einops import EinopsError
raise EinopsError("Mxnet couldn't infer all dimensions statically, please provide those with axes_lengths")
return x.reshape(shape)
def arange(self, start, stop):
return self.mx.symbol.arange(start, stop)
def stack_on_zeroth_dimension(self, tensors: list):
return self.mx.symbol.stack(*tensors)
def tile(self, x, repeats):
return self.mx.symbol.tile(x, repeats)
def concat(self, tensors, axis: int):
        return self.mx.symbol.concat(*tensors, dim=axis)
def add_axis(self, x, new_position):
return self.mx.symbol.expand_dims(x, new_position)
def is_float_type(self, x):
return 'float' in str(x.infer_type()[1][0])
def layers(self):
from .layers import gluon
return gluon
class TorchBackend(AbstractBackend):
framework_name = 'torch'
def __init__(self):
import torch
self.torch = torch
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.torch.Tensor)
def from_numpy(self, x):
variable = self.torch.from_numpy(x)
if self.is_float_type(variable):
# attach grad only to floating types
variable.requires_grad = True
return variable
def to_numpy(self, x):
return x.detach().cpu().numpy()
def arange(self, start, stop):
return self.torch.arange(start, stop, dtype=self.torch.int64)
def reduce(self, x, operation, reduced_axes):
if operation == 'min':
return x.amin(dim=reduced_axes)
elif operation == 'max':
return x.amax(dim=reduced_axes)
elif operation == 'sum':
return x.sum(dim=reduced_axes)
elif operation == 'mean':
return x.mean(dim=reduced_axes)
elif operation == 'prod':
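            # torch.prod reduces only a single dim at a time, so iterate over the axes
            # from last to first so that the remaining axis indices stay valid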
for i in list(sorted(reduced_axes))[::-1]:
x = x.prod(dim=i)
return x
else:
            raise NotImplementedError('Unknown reduction: {}'.format(operation))
def transpose(self, x, axes):
return x.permute(axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.torch.stack(tensors)
def add_axes(self, x, n_axes, pos2len):
repeats = [-1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return x.expand(repeats)
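    # Design note: unlike the base-class add_axes (which tiles and therefore copies data),
    # expand with -1 broadcasts the existing axes as a view without allocating new memory;
    # only the newly inserted axes get explicit lengths.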
def tile(self, x, repeats):
return x.repeat(repeats)
def concat(self, tensors, axis: int):
return self.torch.concat(tensors, dim=axis)
def add_axis(self, x, new_position):
return self.torch.unsqueeze(x, new_position)
def is_float_type(self, x):
return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]
def layers(self):
from .layers import torch
return torch
def einsum(self, pattern, *x):
return self.torch.einsum(pattern, *x)
class CupyBackend(AbstractBackend):
framework_name = 'cupy'
def __init__(self):
import cupy
self.cupy = cupy
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.cupy.ndarray)
def from_numpy(self, x):
return self.cupy.asarray(x)
def to_numpy(self, x):
return self.cupy.asnumpy(x)
def arange(self, start, stop):
return self.cupy.arange(start, stop)
def stack_on_zeroth_dimension(self, tensors: list):
return self.cupy.stack(tensors)
def tile(self, x, repeats):
return self.cupy.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.cupy.concatenate(tensors, axis=axis)
def add_axis(self, x, new_position):
return self.cupy.expand_dims(x, new_position)
def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
def einsum(self, pattern, *x):
return self.cupy.einsum(pattern, *x)
class ChainerBackend(AbstractBackend):
framework_name = 'chainer'
def __init__(self):
import chainer
import numpy
self.numpy = numpy
self.chainer = chainer
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.chainer.Variable)
def from_numpy(self, x):
return self.chainer.Variable(x.astype('float32'))
def to_numpy(self, x):
if isinstance(x, self.chainer.Variable):
x = x.data
return x
def arange(self, start, stop):
return self.numpy.arange(start, stop)
def reduce(self, x, operation, axes):
return getattr(self.chainer.functions, operation)(x, axis=axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.chainer.functions.stack(tensors)
def tile(self, x, repeats):
return self.chainer.functions.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.chainer.functions.concat(tensors, axis=axis)
def add_axis(self, x, new_position):
return self.chainer.functions.expand_dims(x, new_position)
def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
def layers(self):
from .layers import chainer
return chainer
def einsum(self, pattern, *x):
return self.chainer.functions.einsum(pattern, *x)
class HashableTuple:
"""Overcomes non-hashability of symbolic elements"""
def __init__(self, elements: tuple):
self.elements = elements
def __iter__(self):
for x in self.elements:
yield x
def __len__(self):
return len(self.elements)
def __getitem__(self, item):
return self.elements[item]
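
# Note: HashableTuple deliberately defines no __hash__ or __eq__, so it falls back to
# identity-based hashing; that is what allows a shape containing unhashable symbolic
# elements to be used where a hashable value is expected.
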
class TensorflowBackend(AbstractBackend):
framework_name = 'tensorflow'
def __init__(self):
import tensorflow
self.tf = tensorflow
def is_appropriate_type(self, tensor):
return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))
def from_numpy(self, x):
assert self.tf.executing_eagerly()
return self.tf.convert_to_tensor(x)
def to_numpy(self, x):
assert self.tf.executing_eagerly()
return x.numpy()
def arange(self, start, stop):
return self.tf.range(start, stop)
def shape(self, x):
if self.tf.executing_eagerly():
return tuple(UnknownSize() if d is None else int(d) for d in x.shape)
else:
static_shape = x.shape.as_list()
tf_shape = self.tf.shape(x)
# use the static shape where known, otherwise use the TF shape components
shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)])
try:
hash(shape)
return shape
except:
# unhashable symbols in shape. Wrap tuple to be hashable.
return HashableTuple(shape)
def reduce(self, x, operation, axes):
return getattr(self.tf, 'reduce_' + operation)(x, axis=axes)
def reshape(self, x, shape):
return self.tf.reshape(x, shape)
def transpose(self, x, axes):
return self.tf.transpose(x, axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.tf.stack(tensors)
def tile(self, x, repeats):
return self.tf.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.tf.concat(tensors, axis=axis)
def add_axis(self, x, new_position):
return self.tf.expand_dims(x, new_position)
def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
def layers(self):
from .layers import tensorflow
return tensorflow
def einsum(self, pattern, *x):
return self.tf.einsum(pattern, *x)
class KerasBackend(AbstractBackend):
framework_name = 'tensorflow.keras'
def __init__(self):
import tensorflow as tf
self.tf = tf
self.keras = tf.keras
self.K = tf.keras.backend
def is_appropriate_type(self, tensor):
return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor)
def create_symbol(self, shape):
return self.keras.Input(batch_shape=shape)
def eval_symbol(self, symbol, input_dict):
model = self.keras.models.Model([var for (var, _) in input_dict], symbol)
return model.predict_on_batch([val for (_, val) in input_dict])
def arange(self, start, stop):
return self.K.arange(start, stop)
def shape(self, x):
shape = self.K.shape(x) # tf tensor
return HashableTuple(tuple(shape))
def reduce(self, x, operation, axes):
return getattr(self.K, operation)(x, axis=axes)
def reshape(self, x, shape):
return self.K.reshape(x, shape)
def transpose(self, x, axes):
return self.K.permute_dimensions(x, axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.K.stack(tensors)
def tile(self, x, repeats):
return self.K.tile(x, repeats)
def concat(self, tensors, axis: int):
return self.K.concatenate(tensors, axis=axis)
def add_axis(self, x, new_position):
return self.K.expand_dims(x, new_position)
def is_float_type(self, x):
return 'float' in self.K.dtype(x)
def layers(self):
from .layers import keras
return keras
class OneFlowBackend(AbstractBackend):
framework_name = "oneflow"
def __init__(self):
import oneflow as flow
self.flow = flow
print("using oneflow")
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.flow.Tensor)
def from_numpy(self, x):
variable = self.flow.from_numpy(x)
if self.is_float_type(variable):
# attach grad only to floating types
variable.requires_grad = True
return variable
def to_numpy(self, x):
return x.detach().cpu().numpy()
def arange(self, start, stop):
return self.flow.arange(start, stop, dtype=self.flow.int64)
def reduce(self, x, operation, reduced_axes):
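        # oneflow reductions here take a single dim (min/max also return (values, indices)),
        # so axes are reduced one at a time from last to first to keep indices valid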
for axis in sorted(reduced_axes, reverse=True):
if operation == 'min':
x, _ = x.min(dim=axis)
elif operation == 'max':
x, _ = x.max(dim=axis)
elif operation in ['sum', 'mean', 'prod']:
x = getattr(x, operation)(dim=axis)
else:
                raise NotImplementedError('Unknown reduction: {}'.format(operation))
return x
def transpose(self, x, axes):
return x.permute(axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.flow.stack(tensors)
def add_axes(self, x, n_axes, pos2len):
repeats = [-1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return x.expand(*repeats)
def tile(self, x, repeats):
return x.repeat(repeats)
def concat(self, tensors, axis: int):
return self.flow.concat(tensors, dim=axis)
def add_axis(self, x, new_position):
return self.flow.unsqueeze(x, new_position)
def is_float_type(self, x):
return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]
def layers(self):
from .layers import oneflow
return oneflow
def einsum(self, pattern, *x):
return self.flow.einsum(pattern, *x)