[doc] update cache doc in runtime (wenet-e2e#1014)
robin1001 authored Apr 6, 2022
1 parent 223f5c7 commit 7ebc37e
docs/runtime.md: 14 additions & 8 deletions

# Runtime for WeNet

WeNet runtime uses the [Unified Two Pass (U2)](https://arxiv.org/pdf/2102.01547.pdf) framework for inference. U2 has the following advantages:
* **Unified**: U2 unifies the streaming and non-streaming models in a simple way, and our runtime is also unified, so you can easily balance latency and accuracy by changing `chunk_size` (described in the following section).
* **Accurate**: U2 achieves better accuracy through joint CTC training.
* **Fast**: Our runtime uses the attention rescoring based decoding method described in U2, which is much faster than traditional autoregressive beam search.

### Interface Design

We use LibTorch to implement the U2 runtime in WeNet, and we export several interfaces in the PyTorch Python code
via `@torch.jit.export` (see [asr_model.py](https://github.com/wenet-e2e/wenet/tree/main/wenet/transformer/asr_model.py)),
which are required and used by the C++ runtime in [torch_asr_model.cc](https://github.com/wenet-e2e/wenet/tree/main/runtime/server/x86/decoder/torch_asr_model.cc).
Here we list the interfaces and give a brief introduction.

| interface | description |
|----------------------------------|-----------------------------------------|
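
For illustration, here is a minimal sketch of the export mechanism: methods decorated with `@torch.jit.export` stay callable on the scripted module after the C++ runtime loads it with `torch::jit::load()`. The toy model and the `subsampling_rate` method below only demonstrate the pattern; they are not WeNet's actual code.

```python
import torch

class ToyASRModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 128)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.proj(feats)

    @torch.jit.export
    def subsampling_rate(self) -> int:
        # Hypothetical constant; a real model would derive this from its
        # subsampling module.
        return 4

scripted = torch.jit.script(ToyASRModel())
scripted.save("toy_asr_model.pt")  # the C++ runtime loads this via torch::jit::load()
```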

### Cache in Details

For the streaming scenario, the *Shared Encoder* module works in an incremental way: the computation of the current chunk requires the inputs and outputs of all the history chunks. We implement this incremental computation with caches. Overall, two types of cache are used in our runtime.

* `att_cache`: the attention cache of the *Shared Encoder* (Conformer/Transformer) module.
* `cnn_cache`: the CNN cache of the *Shared Encoder*, which caches the left context for the causal CNN computation in Conformer.

Please see [encoder.py:forward_chunk()](https://github.com/wenet-e2e/wenet/tree/main/wenet/transformer/encoder.py) and [torch_asr_model.cc](https://github.com/wenet-e2e/wenet/tree/main/runtime/server/x86/decoder/torch_asr_model.cc) for details of the caches.
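
To make the cache flow concrete, here is a toy sketch of chunk-by-chunk encoding; the shapes, the fixed left-context width of 3, and the function body are made up and only stand in for the real `forward_chunk()`:

```python
import torch

def forward_chunk(chunk: torch.Tensor, att_cache: torch.Tensor, cnn_cache: torch.Tensor):
    out = chunk                                           # placeholder for the real encoder output
    new_att_cache = torch.cat([att_cache, chunk], dim=1)  # attention history grows with every chunk
    new_cnn_cache = chunk[:, -3:, :]                      # causal CNN only needs a fixed left context
    return out, new_att_cache, new_cnn_cache

att_cache = torch.zeros(1, 0, 80)     # empty attention history at stream start
cnn_cache = torch.zeros(1, 3, 80)     # zero left context at stream start
feats = torch.randn(1, 64, 80)        # (batch, frames, feature_dim)
for chunk in feats.split(16, dim=1):  # four 16-frame chunks
    out, att_cache, cnn_cache = forward_chunk(chunk, att_cache, cnn_cache)
```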

In practice, CNN is also used for subsampling, so we should handle the CNN cache in subsampling as well.
However, the subsampling module contains several CNN layers with different left contexts, right contexts, and strides, which makes it tricky to implement the CNN cache in subsampling directly.
In our implementation, we simply overlap the inputs to avoid a subsampling CNN cache.
This is simple and straightforward, and the additional cost is negligible since the subsampling CNN accounts for only a very small fraction of the whole computation.
The following picture shows how it works, where the blue part marks the overlap between the current and the previous inputs.

![Overlap input for Subsampling CNN](images/subsampling_overalp.gif)
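
In code, the overlap trick looks roughly like the following toy sketch; the overlap width and shapes are made up, and the real width depends on the receptive field of the subsampling CNNs:

```python
import torch

overlap = 6                            # assumed: covers the subsampling CNNs' left context
prev_tail = torch.zeros(1, overlap, 80)
feats = torch.randn(1, 96, 80)         # (batch, frames, feature_dim)

for new_frames in feats.split(16, dim=1):
    window = torch.cat([prev_tail, new_frames], dim=1)  # prepend the previous tail
    # out = subsampling(window)  # run the subsampling CNNs on the overlapped window
    prev_tail = window[:, -overlap:, :]                 # keep the tail for the next chunk
```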
