Having looked at google-research's blog post and TensorFlow implementation of fast attention (FAVOR+), I was left scratching my head about the causal attention implementation. This repository investigates a simpler version.
- Causal attention can be concisely expressed mathematically using properties of low-rank matrices and Hadamard products (see the Theory section below).
- ops/v1.py provides significantly simpler implementations that use neither loops over tensors nor custom gradients.
- The implementations are much shorter (3 and 2 lines vs. 25 and 22 in the original), making them much easier to reason about.
- Jit-compiling these operations is significantly faster than jit-compiling the originals (ops/v0.py).
- Results are identical between the two implementations.
- Computation time measured with google-benchmark is indistinguishable.
The google-ai blog post provides a visualisation of causal attention. It's not immediately apparent to me what's going on here, and looking at the code (originally here, but with the relevant part included here for convenience) doesn't make things any clearer.
My implementation (v1) takes a different approach.
The task we consider is to compute the causal numerator

$$N = (Q K^T \odot C) V,$$

where $Q, K \in \mathbb{R}^{L \times M}$ are the (primed) query and key matrices, $V \in \mathbb{R}^{L \times D}$ is the value matrix, and $C \in \mathbb{R}^{L \times L}$ is the lower-triangular causal mask of ones, $C_{ij} = [i \ge j]$. Row by row,

$$N_i = \sum_{j \le i} (q_i^T k_j)\, v_j = \Big(\sum_{j \le i} v_j k_j^T\Big) q_i,$$

so the bracketed term is just a cumulative sum of the rank-one outer products $v_j k_j^T$. Note it is neither efficient nor necessary to explicitly construct any of the matrices above.
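This cumulative-sum identity is easy to check numerically. A minimal numpy sketch with toy sizes (batch and head dimensions omitted; all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
L, M, D = 5, 3, 2  # sequence length, feature dim, value dim (toy sizes)
qs = rng.standard_normal((L, M))
ks = rng.standard_normal((L, M))
vs = rng.standard_normal((L, D))

# Explicit construction: masked attention matrix times values.
mask = np.tril(np.ones((L, L)))     # lower-triangular causal mask
explicit = (qs @ ks.T * mask) @ vs  # [L, D]

# Cumulative-sum form: prefix sums of the outer products v_j k_j^T.
rhs = np.cumsum(vs[:, :, None] * ks[:, None, :], axis=0)  # [L, D, M]
cumsum_form = np.einsum("ldm,lm->ld", rhs, qs)

assert np.allclose(explicit, cumsum_form)
```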
The implementation looks slightly different from the maths above because we use tf.linalg.matvec.
```python
def causal_numerator(qs: tf.Tensor, ks: tf.Tensor, vs: tf.Tensor):
    """Computes not-normalized FAVOR causal attention A_{masked}V.

    Args:
        qs: query_prime tensor of the shape [L,B,H,M].
        ks: key_prime tensor of the shape [L,B,H,M].
        vs: value tensor of the shape [L,B,H,D].

    Returns:
        Not-normalized FAVOR causal attention A_{masked}V.
    """
    # rhs = tf.einsum('lbhm,lbhd->lbhdm', ks, vs)
    rhs = tf.expand_dims(ks, axis=-2) * tf.expand_dims(vs, axis=-1)  # [L,B,H,D,M]
    rhs = tf.cumsum(rhs, axis=0)
    # return tf.einsum('lbhm,lbhdm->lbhd', qs, rhs)
    return tf.linalg.matvec(rhs, qs)
```
That's a 3-line implementation, as opposed to the 25 used in the original.
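The one-shot cumsum formulation can also be sanity-checked against a sequential loop that carries a running sum of outer products, which is roughly the shape of the original v0 code. A numpy sketch with toy sizes (batch and head dimensions omitted):

```python
import numpy as np

rng = np.random.default_rng(2)
L, M, D = 4, 3, 2  # toy sizes
qs = rng.standard_normal((L, M))
ks = rng.standard_normal((L, M))
vs = rng.standard_normal((L, D))

# Loop formulation: carry a running sum of outer products and
# contract with each query in turn.
state = np.zeros((D, M))
loop_result = np.zeros((L, D))
for i in range(L):
    state = state + np.outer(vs[i], ks[i])
    loop_result[i] = state @ qs[i]

# Cumsum formulation (v1): the same prefix sums, computed in one shot.
rhs = np.cumsum(vs[:, :, None] * ks[:, None, :], axis=0)  # [L, D, M]
v1_result = np.einsum("ldm,lm->ld", rhs, qs)

assert np.allclose(loop_result, v1_result)
```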
The causal denominator function is conceptually the same as the numerator, except with the values replaced by a ones vector:
```python
def causal_denominator(qs, ks):
    """Computes FAVOR normalizer in causal attention.

    Args:
        qs: query_prime tensor of the shape [L,B,H,M].
        ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
        FAVOR normalizer in causal attention.
    """
    rhs = tf.cumsum(ks, axis=0)
    return tf.einsum("lbhm,lbhm->lbh", qs, rhs)
```
That's 2 lines compared to 22 in the original.
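The denominator can be checked numerically in the same way: the running sum of keys dotted with each query should match the row sums of the explicitly masked attention matrix. A numpy sketch with toy sizes (batch and head dimensions omitted):

```python
import numpy as np

rng = np.random.default_rng(1)
L, M = 5, 3  # toy sizes
qs = rng.standard_normal((L, M))
ks = rng.standard_normal((L, M))

# Explicit construction: row sums of the masked attention matrix.
mask = np.tril(np.ones((L, L)))
explicit = (qs @ ks.T * mask) @ np.ones(L)  # [L]

# Cumsum formulation: dot each query with the prefix sum of keys.
cumsum_form = np.einsum("lm,lm->l", qs, np.cumsum(ks, axis=0))

assert np.allclose(explicit, cumsum_form)
```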
The following benchmarks were run on my fairly old laptop with a 1050-Ti. Note that I trust the timings from google-benchmark a lot more than those from tfbm / tf.test.Benchmark, but maybe I'm just misinterpreting them. gbenchmark.py uses google-benchmark; its output is a lot simpler than tf.test.Benchmark's, which is part of why I trust these timings more.
Take-aways:
- There isn't much difference between v0 and v1 implementations in terms of computation time
- v1 implementations warm-up significantly faster
- jit compilation significantly reduces forward time on cpu, but is negligible on gpu
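For reference, jit compilation in these benchmarks amounts to wrapping the op in tf.function with jit_compile=True. A minimal sketch, assuming TF 2.x (toy shapes; the wrapper here is illustrative, not the benchmark harness itself):

```python
import tensorflow as tf


@tf.function(jit_compile=True)  # XLA-compiled; drop jit_compile=True for plain graph mode
def causal_numerator(qs, ks, vs):
    # Prefix sums of the rank-one outer products k_l v_l^T, contracted with qs.
    rhs = tf.cumsum(tf.expand_dims(ks, axis=-2) * tf.expand_dims(vs, axis=-1), axis=0)
    return tf.linalg.matvec(rhs, qs)


L, B, H, M, D = 8, 2, 2, 4, 3  # toy sizes
qs = tf.random.normal((L, B, H, M))
ks = tf.random.normal((L, B, H, M))
vs = tf.random.normal((L, B, H, D))
out = causal_numerator(qs, ks, vs)  # first call triggers tracing + XLA compilation
```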
```
python gbenchmark.py
--------------------------------------------------------------
Benchmark                    Time             CPU   Iterations
--------------------------------------------------------------
v0_forward-cpu         5403096 ns      364764 ns         1000
v1_forward-cpu         5419832 ns      365650 ns         1000
v0_backward-cpu         268558 ns      238634 ns         2896
v1_backward-cpu         267089 ns      235842 ns         2937
v0_forward-gpu          288531 ns      241580 ns         2874
v1_forward-gpu          285695 ns      238078 ns         2908
v0_backward-gpu         268220 ns      237309 ns         2869
v1_backward-gpu         268324 ns      240429 ns         2751
v0_forward-cpu-jit      299143 ns      271613 ns         2516
v1_forward-cpu-jit      291873 ns      269618 ns         2538
v0_backward-cpu-jit     303150 ns      275359 ns         2483
v1_backward-cpu-jit     303948 ns      276806 ns         2482
v0_forward-gpu-jit      278147 ns      277842 ns         2450
v1_forward-gpu-jit      276128 ns      275956 ns         2523
v0_backward-gpu-jit     256809 ns      256798 ns         2706
v1_backward-gpu-jit     252543 ns      252537 ns         2769
Warmup time for v0_forward-cpu: 6.56445574760437
Warmup time for v1_forward-cpu: 0.1015627384185791
Warmup time for v0_backward-cpu: 22.0670325756073
Warmup time for v1_backward-cpu: 0.08140373229980469
Warmup time for v0_forward-gpu: 6.233572244644165
Warmup time for v1_forward-gpu: 0.028412342071533203
Warmup time for v0_backward-gpu: 22.226712226867676
Warmup time for v1_backward-gpu: 0.051419734954833984
Warmup time for v0_forward-cpu-jit: 6.481787443161011
Warmup time for v1_forward-cpu-jit: 0.05790424346923828
Warmup time for v0_backward-cpu-jit: 24.72081184387207
Warmup time for v1_backward-cpu-jit: 0.09151363372802734
Warmup time for v0_forward-gpu-jit: 8.328083515167236
Warmup time for v1_forward-gpu-jit: 0.08592033386230469
Warmup time for v0_backward-gpu-jit: 24.7033634185791
Warmup time for v1_backward-gpu-jit: 0.12377095222473145
```
benchmark.py requires tfbm, a wrapper around tf.test.Benchmark implementations.
```
pip install git+https://github.com/jackd/tfbm
```
```
Results for cls=Forward
Uniform results:
+--------+---------+-------+
| run_id |   cls   | iters |
+--------+---------+-------+
|  NOW   | Forward |   10  |
+--------+---------+-------+
Varied results:
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|        test        | wall_time (us) | device | max_mem_GPU_0_bfc (Mb) | max_mem_cpu (b) | max_mem_gpu_host_bfc (b) | xla_jit |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_forward_xla_gpu |    218.511     |  gpu   |          ---           |     192.000     |          49.000          |   True  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_forward_xla_gpu |    260.711     |  gpu   |          ---           |     192.000     |          49.000          |   True  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_forward_xla_cpu |    272.274     |  cpu   |          ---           |     192.000     |          49.000          |   True  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_forward_xla_cpu |    284.910     |  cpu   |          ---           |     192.000     |          49.000          |   True  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v0_forward_gpu   |   20148.039    |  gpu   |        2563.000        |     192.000     |         1088.000         |  False  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v1_forward_gpu   |   58524.966    |  gpu   |        516.000         |     192.000     |          64.000          |  False  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v0_forward_cpu   |   100526.690   |  cpu   |          ---           |  1776322300.000 |           ---            |  False  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v1_forward_cpu   |   233323.097   |  cpu   |          ---           |  541065220.000  |           ---            |  False  |
+--------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+

Results for cls=Backward
Uniform results:
+--------+----------+-------+
| run_id |   cls    | iters |
+--------+----------+-------+
|  NOW   | Backward |   10  |
+--------+----------+-------+
Varied results:
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|         test        | wall_time (us) | device | max_mem_GPU_0_bfc (Mb) | max_mem_cpu (b) | max_mem_gpu_host_bfc (b) | xla_jit |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_backward_xla_gpu |    225.782     |  gpu   |          ---           |     192.000     |          49.000          |   True  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_backward_xla_gpu |    231.147     |  gpu   |          ---           |     192.000     |          49.000          |   True  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v1_backward_xla_cpu |    239.730     |  cpu   |          ---           |     192.000     |          49.000          |   True  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v0_backward_gpu   |   47311.544    |  gpu   |        2641.594        |     192.000     |         1084.000         |  False  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
| v0_backward_xla_cpu |   77256.083    |  cpu   |          ---           |  288358548.000  |          49.000          |   True  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v1_backward_gpu   |   110480.189   |  gpu   |        788.031         |     192.000     |          68.000          |  False  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v0_backward_cpu   |   260773.897   |  cpu   |          ---           |  1936089700.000 |           ---            |  False  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
|   v1_backward_cpu   |   524332.523   |  cpu   |          ---           |  1092648988.000 |           ---            |  False  |
+---------------------+----------------+--------+------------------------+-----------------+--------------------------+---------+
```