In the previous post of the Gemma explained series, we discussed the latest Gemma 2 architecture. In this post, you will explore the RecurrentGemma architecture. Let’s get started!
RecurrentGemma is based on Griffin, a hybrid model that mixes gated linear recurrences with local sliding-window attention. This design reduces computation and memory use and makes the model better suited for long-context prompts.
However, it comes with the downside of reduced needle-in-a-haystack performance due to the fixed-size state of the Griffin architecture. While it is possible to provide the entire text of a book as input, this approach may not be optimal. Recurrent Neural Networks (RNNs) can struggle to learn long-range dependencies in exceedingly long sequences, and the model has a limited context window, meaning it can only effectively consider a certain number of preceding tokens when making predictions.
Moreover, recurrent models have not yet received as much attention in terms of inference-time optimizations as their transformer counterparts, and there is less research and community support available than for the well-established transformer architecture.
So, this model will be highly valuable in scenarios where you are concerned about exhausting your LLM’s context window. By prioritizing the most recent information and strategically discarding older data, RecurrentGemma ensures that the LLM’s performance remains strong as the context expands.
Below is the architecture diagram for the RecurrentGemma 2B model.
Griffin follows the same residual pattern and MLP block as other Transformer baselines. However, unlike both the MQA Transformer baseline and the Hawk model, Griffin uses a blend of recurrent and MQA blocks.
Griffin uses a layered structure that alternates two residual blocks containing a recurrent block with one residual block that incorporates the local MQA attention block.
The core parameters of the architecture are summarized in the table below.
Non-embedding parameters are distributed throughout the hidden layers of the model, in components like attention mechanisms and feedforward networks.
Note: The model’s name, “2B”, comes from this parameter count.
Embedding parameters are usually found in a dedicated layer called the embedding layer. This layer is responsible for mapping discrete tokens (like words or characters) into continuous vector representations (embeddings).
Note: 0.7B can be calculated as 256k (vocabulary size) x 2560 (model width).
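As a quick sanity check, that figure follows directly from the two numbers above:

vocab_size = 256_000   # Gemma vocabulary size
model_width = 2_560    # RecurrentGemma 2B model width

embedding_params = vocab_size * model_width
print(f"{embedding_params:,}")  # 655,360,000, i.e. roughly 0.7B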
Model width refers to the size of the hidden layers in the model, determining the model’s capacity to represent complex patterns, just like in the base Gemma models.
Recurrent neural network (RNN) width is the size of the hidden state maintained by the Real-Gated Linear Recurrent Unit (RG-LRU). Unlike traditional Transformers, the recurrent block maintains a fixed-size internal state, regardless of the input length. This allows RecurrentGemma to process longer sequences with less memory, making it more efficient for tasks like generating long articles or code.
It’s the same concept as the feedforward hidden dimension in the base Gemma models. RecurrentGemma applies an expansion factor of 3, resulting in an MLP dimension of 7680 (calculated as 2560 x 3).
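As a rough illustration, the MLP block corresponds to a gated (GeGLU-style) feedforward layer like the one sketched below. This is a minimal sketch based on the gate_proj, up_proj, and down_proj layers visible in the model printout further down, not the exact library implementation.

import torch.nn as nn

class GatedMlpSketch(nn.Module):
    # Minimal sketch of the RecurrentGemma MLP block: width 2560, expansion factor 3.
    def __init__(self, width=2560, expansion_factor=3):
        super().__init__()
        hidden = width * expansion_factor            # 7680
        self.gate_proj = nn.Linear(width, hidden)
        self.up_proj = nn.Linear(width, hidden)
        self.down_proj = nn.Linear(hidden, width)
        self.act_fn = nn.GELU(approximate="tanh")    # GeLU (tanh approximation)

    def forward(self, x):
        # Gated feedforward: activate the gate branch, scale the up branch, project back down.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))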
The state maintained by RecurrentGemma has a finite size and does not grow for sequences longer than the local attention window of 2k tokens. This means that while the maximum length of samples that Gemma can generate autoregressively is limited by the host system’s memory capacity, RecurrentGemma can generate sequences of arbitrary length.
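A back-of-the-envelope comparison makes this concrete. The sketch below uses illustrative numbers (one key/value head of size 256 per layer, 26 layers, 2 bytes per value) to contrast a cache that grows with the sequence against one capped at the 2k-token local window:

def cache_bytes(seq_len, layers=26, kv_heads=1, head_dim=256, dtype_bytes=2):
    # Keys and values cached for every position seen so far (grows linearly).
    return 2 * layers * kv_heads * head_dim * dtype_bytes * seq_len

def capped_cache_bytes(seq_len, window=2048, **kwargs):
    # With a fixed-size state / local window, memory stops growing past the window.
    return cache_bytes(min(seq_len, window), **kwargs)

for n in (2_048, 32_768, 262_144):
    print(f"{n:>7} tokens: {cache_bytes(n) / 1e6:8.1f} MB growing vs {capped_cache_bytes(n) / 1e6:6.1f} MB capped")

With that in mind, here is the full module structure of the RecurrentGemma 2B model: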
RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      :
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)
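This printout can be reproduced with the Hugging Face transformers library. The snippet below is a sketch: it assumes a recent transformers release with RecurrentGemma support and access to the google/recurrentgemma-2b checkpoint on the Hub.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/recurrentgemma-2b"  # assumed checkpoint id; requires accepting the model license
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

print(model)  # prints the RecurrentGemmaForCausalLM structure shown above

inputs = tokenizer("Griffin mixes gated linear recurrences with", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))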
The embedding layer (embed_tokens) takes the input text as a sequence of tokens and maps each token to a continuous vector representation of size 2560. It has a vocabulary size of 256,000, the same as the base Gemma models.
There are 26 decoder layers in total, grouped into repeating patterns.
The model begins with two residual blocks containing a recurrent block (layers 0-1), followed by a residual block with local attention (layer 2). This pattern of two recurrent blocks and one attention block then alternates until the final layer (25).
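Written out explicitly, the ordering implied by the printout looks like this (a small sketch, where every third layer is the local-attention layer):

num_layers = 26

# Griffin pattern: two recurrent (RG-LRU) residual blocks, then one local-MQA residual block.
pattern = ["recurrent" if i % 3 != 2 else "local_attention" for i in range(num_layers)]

print(pattern[:6])   # ['recurrent', 'recurrent', 'local_attention', 'recurrent', 'recurrent', 'local_attention']
print(pattern[-3:])  # ['local_attention', 'recurrent', 'recurrent']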
In the recurrent block (temporal mixing block), the model takes the input of dimension 2560 (the model width) and applies two linear layers with output dimension 2560 (the RNN width) in parallel, creating two branches.
On the first branch (right side), it applies a small separable Conv1D layer with a temporal filter dimension of 4, followed by the RG-LRU (Real-Gated Linear Recurrent Unit) layer.
On the second branch (left side), it applies a GeLU nonlinearity.
It then merges the branches by element-wise multiplication and applies a final linear layer with output dimension 2560 (the model width), as sketched below.
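Putting those steps together, the temporal mixing block looks roughly like the sketch below. This is a simplified illustration, not the library implementation: the RG-LRU here is a naive per-timestep loop (the real model uses a fused scan with a particular parameterization and initialization), and caching for generation is omitted.

import torch
import torch.nn as nn

class RgLruSketch(nn.Module):
    # Simplified Real-Gated Linear Recurrent Unit: a gated linear recurrence
    # whose hidden state has a fixed size, independent of the sequence length.
    def __init__(self, width=2560):
        super().__init__()
        self.input_gate = nn.Linear(width, width)
        self.recurrence_gate = nn.Linear(width, width)
        self.log_decay = nn.Parameter(torch.zeros(width))  # per-channel decay parameter

    def forward(self, x):                                  # x: (batch, seq_len, width)
        batch, seq_len, width = x.shape
        h = x.new_zeros(batch, width)                      # fixed-size recurrent state
        outputs = []
        for t in range(seq_len):                           # naive loop; the real layer uses a scan
            r = torch.sigmoid(self.recurrence_gate(x[:, t]))
            i = torch.sigmoid(self.input_gate(x[:, t]))
            a = torch.sigmoid(self.log_decay) ** (8.0 * r) # gated decay in (0, 1)
            h = a * h + torch.sqrt(1.0 - a**2) * (i * x[:, t])
            outputs.append(h)
        return torch.stack(outputs, dim=1)

class RecurrentBlockSketch(nn.Module):
    # Simplified Griffin recurrent (temporal mixing) block.
    def __init__(self, width=2560, kernel_size=4):
        super().__init__()
        self.linear_y = nn.Linear(width, width)            # gate branch
        self.linear_x = nn.Linear(width, width)            # recurrent branch
        self.conv_1d = nn.Conv1d(width, width, kernel_size,
                                 padding=kernel_size - 1, groups=width)
        self.rg_lru = RgLruSketch(width)
        self.linear_out = nn.Linear(width, width)
        self.act_fn = nn.GELU(approximate="tanh")

    def forward(self, x):                                  # x: (batch, seq_len, width)
        y = self.act_fn(self.linear_y(x))                  # left branch: GeLU gate
        h = self.linear_x(x)                               # right branch
        h = self.conv_1d(h.transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)  # causal depthwise conv
        h = self.rg_lru(h)                                 # fixed-size recurrence over time
        return self.linear_out(h * y)                      # merge branches, project to model width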
After applying RMSNorm, the MLP block follows.
After the two residual blocks with a recurrent block (layers 0-1), a residual block with local MQA (layer 2) follows. One of the key disadvantages of global attention is that its computational complexity grows quadratically with the sequence length. To address this, RecurrentGemma uses local sliding-window attention, which allows each position to attend only to a fixed number of tokens in the past.
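A small sketch of such a sliding-window causal mask (illustrative, not the library’s actual masking code): query position i may attend to key positions j with i - window < j <= i.

import torch

def sliding_window_causal_mask(seq_len, window=2048):
    # True where attention is allowed: causal (j <= i) and within the local window.
    i = torch.arange(seq_len).unsqueeze(1)   # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(6, window=3).long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])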
In the local MQA block (temporal mixing block), the model takes the input of dimension 2560 (the model width). It uses linear projections (q_proj, k_proj, v_proj, o_proj) to create query, key, value, and output representations. Note that out_features for k_proj and v_proj is 256 because all query heads share a single key/value head of size 256, while q_proj and o_proj use 10 heads (256 x 10 = 2560) in parallel.
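The head arithmetic behind those shapes (a quick check against the printout above):

model_width = 2560
head_dim = 256

num_query_heads = model_width // head_dim   # 10 query heads
num_kv_heads = 1                            # multi-query attention: one shared key/value head

print(num_query_heads * head_dim)           # 2560 -> out_features of q_proj and o_proj
print(num_kv_heads * head_dim)              # 256  -> out_features of k_proj and v_proj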
It incorporates rotary_emb (RecurrentGemmaRotaryEmbedding) for rotary positional embeddings (RoPE) just like the base Gemma models.
Applying RMSNorm and the MLP block works the same way as in the previous residual block.
In this article, you learned about RecurrentGemma.
In the next post, you will explore PaliGemma which is a lightweight open vision-language model (VLM).
Stay tuned and thank you for reading!