
nn.RMSNorm (tinygrad#5272)
the norm itself has no significant value to add as a Tensor method, but we would want Tensor.normalize
chenyuxyz authored Jul 3, 2024
1 parent 9a2a82a commit b2c3a28
Showing 7 changed files with 74 additions and 32 deletions.
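In practice, code that previously did `from extra.models.llama import RMSNorm` now uses `nn.RMSNorm` directly, as the per-file diffs below show. A minimal usage sketch (the shapes and `dim` here are illustrative, not taken from the commit):

```python
from tinygrad import Tensor, nn

# RMSNorm normalizes over the last axis by the root mean square of the activations
# and scales the result with a learned per-feature weight (no bias, no mean subtraction).
norm = nn.RMSNorm(512)            # dim=512; eps defaults to 1e-6
x = Tensor.randn(4, 10, 512)      # (batch, seq_len, dim)
y = norm(x)                       # same shape as the input
print(y.shape)                    # (4, 10, 512)
```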
1 change: 1 addition & 0 deletions docs/nn.md
@@ -10,6 +10,7 @@
 ::: tinygrad.nn.InstanceNorm
 ::: tinygrad.nn.LayerNorm
 ::: tinygrad.nn.LayerNorm2d
+::: tinygrad.nn.RMSNorm
 ::: tinygrad.nn.Embedding

 ## Optimizers
5 changes: 2 additions & 3 deletions examples/mamba.py
@@ -6,7 +6,6 @@
 from tinygrad import Tensor, TinyJit, nn
 from tinygrad.helpers import fetch
 from tinygrad.nn.state import load_state_dict, torch_load
-from extra.models.llama import RMSNorm

 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -234,7 +233,7 @@ class MambaBlock:
   def __init__(self, dim: int, norm_eps: float = 1e-5, rms_norm: bool = True, layer_idx: Optional[int] = None):
     self.mixer = MambaMixer(dim, layer_idx=layer_idx)
     if rms_norm:
-      self.norm = RMSNorm(dim, norm_eps)
+      self.norm = nn.RMSNorm(dim, norm_eps)
     else:
       raise NotImplementedError

@@ -249,7 +248,7 @@ def __init__(self, dim: int, n_layers: int, vocab_size: int, rms_norm: bool = Tr
     self.embedding = nn.Embedding(vocab_size, dim)
     self.layers = [MambaBlock(dim, rms_norm=rms_norm, layer_idx=i) for i in range(n_layers)]
     if rms_norm:
-      self.norm_f = RMSNorm(dim, norm_eps)
+      self.norm_f = nn.RMSNorm(dim, norm_eps)

   def __call__(self, input_ids: Tensor, inference_params=None) -> Any:
     hidden_states = self.embedding(input_ids)
11 changes: 5 additions & 6 deletions examples/openelm.py
@@ -1,7 +1,6 @@
 import json, pprint
 from tinygrad import fetch, nn, Tensor
 from tinygrad.helpers import DEBUG
-from extra.models.llama import RMSNorm # TODO: move to nn

 class FeedForward:
   def __init__(self, model_dim, intermediate_dim):
@@ -26,8 +25,8 @@ def __init__(self, model_dim, num_query_heads, num_kv_heads, head_dim):
     self.qkv_proj = nn.Linear(model_dim, (num_query_heads + num_kv_heads*2) * head_dim, bias=False)
     self.num_query_heads, self.num_kv_heads = num_query_heads, num_kv_heads
     self.head_dim = head_dim
-    self.q_norm = RMSNorm(head_dim)
-    self.k_norm = RMSNorm(head_dim)
+    self.q_norm = nn.RMSNorm(head_dim)
+    self.k_norm = nn.RMSNorm(head_dim)
     self.out_proj = nn.Linear(num_query_heads * head_dim, model_dim, bias=False)

   def __call__(self, x:Tensor) -> Tensor:
@@ -65,8 +64,8 @@ class Layer:
   def __init__(self, model_dim, intermediate_dim, num_query_heads, num_kv_heads, head_dim):
     self.ffn = FeedForward(model_dim, intermediate_dim)
     self.attn = Attention(model_dim, num_query_heads, num_kv_heads, head_dim)
-    self.ffn_norm = RMSNorm(model_dim)
-    self.attn_norm = RMSNorm(model_dim)
+    self.ffn_norm = nn.RMSNorm(model_dim)
+    self.attn_norm = nn.RMSNorm(model_dim)

   def __call__(self, x:Tensor) -> Tensor: # (batch, seq_len, embed_dim)
     x = x + self.attn(self.attn_norm(x))
@@ -84,7 +83,7 @@ def __init__(self, cfg):
     if DEBUG >= 3: pprint.pp(cfg)
     self.layers = [Layer(cfg['model_dim'], make_divisible(int(cfg["model_dim"] * cfg['ffn_multipliers'][i]), cfg['ffn_dim_divisor']),
                    cfg['num_query_heads'][i], cfg['num_kv_heads'][i], cfg['head_dim']) for i in range(cfg['num_transformer_layers'])]
-    self.norm = RMSNorm(cfg['model_dim'])
+    self.norm = nn.RMSNorm(cfg['model_dim'])
     self.token_embeddings = nn.Embedding(cfg['vocab_size'], cfg['model_dim'])

   def __call__(self, tokens:Tensor):
17 changes: 3 additions & 14 deletions extra/models/llama.py
@@ -32,17 +32,6 @@ def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)

-class RMSNorm:
-  def __init__(self, dim, eps=1e-6):
-    self.eps = eps
-    self.weight = Tensor.ones(dim)
-
-  def _norm(self, x:Tensor):
-    return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
-
-  def __call__(self, x:Tensor) -> Tensor:
-    return self._norm(x.float()).cast(x.dtype) * self.weight
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
@@ -98,8 +87,8 @@ class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
-    self.attention_norm = RMSNorm(dim, norm_eps)
-    self.ffn_norm = RMSNorm(dim, norm_eps)
+    self.attention_norm = nn.RMSNorm(dim, norm_eps)
+    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

   def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
@@ -154,7 +143,7 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 class Transformer:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
-    self.norm = RMSNorm(dim, norm_eps)
+    self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
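Both llama.py and openelm.py use the module in the same pre-norm residual pattern visible in the context lines above: normalize, apply the sub-module, add the residual. A stripped-down sketch of that pattern (not the actual model code; the attention and feed-forward parts are identity placeholders just to keep it runnable):

```python
from tinygrad import Tensor, nn

class PreNormBlock:
  def __init__(self, dim:int, norm_eps:float=1e-5):
    self.attention_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
    self.attention = lambda x: x     # placeholder for a real Attention module
    self.feed_forward = lambda x: x  # placeholder for a real FeedForward module

  def __call__(self, x:Tensor) -> Tensor:
    h = x + self.attention(self.attention_norm(x))  # normalize, attend, residual add
    return h + self.feed_forward(self.ffn_norm(h))  # normalize, FFN, residual add

out = PreNormBlock(16)(Tensor.randn(2, 3, 16))
print(out.shape)  # (2, 3, 16)
```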
13 changes: 6 additions & 7 deletions test/test_multitensor.py
@@ -282,26 +282,25 @@ def test_embedding(self):
     np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

   def test_rmsnorm(self):
-    from extra.models.llama import RMSNorm
     B, T, embed_size = 4, 10, 20

-    layer_norm = RMSNorm(embed_size)
+    norm = nn.RMSNorm(embed_size)
     x = Tensor.rand((B, T, embed_size)).contiguous().realize()
-    y = layer_norm(x)
+    y = norm(x)

     # for norm layers, the correct way to shard weights is duplication
-    layer_norm_sharded = RMSNorm(embed_size)
-    layer_norm_sharded.weight.shard_(devices_2, axis=None).realize()
+    norm_sharded = nn.RMSNorm(embed_size)
+    norm_sharded.weight.shard_(devices_2, axis=None).realize()

     # if x is being sharded, then all-reduce is involved
     x_sharded = x.shard(devices_2, axis=2).realize()
-    y_shard = layer_norm_sharded(x_sharded).realize()
+    y_shard = norm_sharded(x_sharded).realize()
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

     # if x is being duplicated, then the operations remain inside each GPU
     # which is the common case
     x_sharded = x.shard(devices_2, axis=None).realize()
-    y_shard = layer_norm_sharded(x_sharded).realize()
+    y_shard = norm_sharded(x_sharded).realize()
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

   # NOTE: this is failing on LLVM CI, no idea why. Works locally.
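The comments in the test carry the key point: a norm's weight is small, so it is duplicated across devices rather than split, and only a sharded input forces cross-device communication (the mean over the last axis becomes an all-reduce). A rough sketch of that setup outside the test harness (the device tuple mirrors how the test builds `devices_2`; adjust it to devices that exist on your machine):

```python
from tinygrad import Tensor, Device, nn

# two logical instances of the default backend (assumed setup)
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

norm = nn.RMSNorm(20)
norm.weight.shard_(devices, axis=None)      # duplicate the weight on every device

x = Tensor.rand((4, 10, 20)).contiguous().realize()
y_dup = norm(x.shard(devices, axis=None)).realize()   # duplicated input: each device works independently
y_split = norm(x.shard(devices, axis=2)).realize()    # input sharded over the feature axis: needs an all-reduce
```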
36 changes: 35 additions & 1 deletion test/test_nn.py
@@ -6,7 +6,7 @@
 from tinygrad.helpers import CI, Context
 from tinygrad.ops import BufferOps
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
-from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm
+from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
 from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -355,6 +355,40 @@ def test_instancenorm_3d(self):
     np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
     np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)

+  def test_rmsnorm(self):
+    class TorchRMSNorm(torch.nn.Module):
+      # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
+      def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+      def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+      def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    B, T, embed_size = 4, 10, 20
+    torch_layer = TorchRMSNorm(embed_size)
+    layer = RMSNorm(embed_size)
+    layer.weight.requires_grad = True
+
+    for _ in range(10):
+      # forward
+      x = Tensor.randn(B, T, embed_size, requires_grad=True)
+      z = layer(x)
+      torch_x = torch.tensor(x.numpy(), requires_grad=True)
+      torch_z = torch_layer(torch_x)
+      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
+
+      # backward
+      z.sum().backward()
+      torch_z.sum().backward(retain_graph=True)
+      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
+      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
+
   def test_embedding(self):
     B, T, embed_size, vocab_size = 4, 10, 20, 28

23 changes: 22 additions & 1 deletion tinygrad/nn/__init__.py
@@ -60,7 +60,6 @@ def __call__(self, x:Tensor):

     return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)

-# TODO: these Conv lines are terrible
 def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
   """
   Applies a 1D convolution over an input signal composed of several input planes.
@@ -282,6 +281,28 @@ class LayerNorm2d(LayerNorm):
   """
   def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

+class RMSNorm:
+  """
+  Applies Root Mean Square Normalization to input.
+
+  - Described: https://paperswithcode.com/method/rmsnorm
+  - Paper: https://arxiv.org/abs/1910.07467
+
+  ```python exec="true" source="above" session="tensor" result="python"
+  norm = nn.RMSNorm(4)
+  t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
+  print(t.numpy())
+  ```
+  ```python exec="true" source="above" session="tensor" result="python"
+  print(norm(t).numpy())
+  ```
+  """
+  def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+
+  def _norm(self, x:Tensor): return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+
+  def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+
 class Embedding:
   """
   A simple lookup table that stores embeddings of a fixed dictionary and size.
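For reference, the computation the new class performs on each length-`dim` vector along the last axis is just the three one-liners above written out (note that `eps` sits inside the square root and, unlike LayerNorm, no mean is subtracted and there is no bias):

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}\, w_i, \qquad d = \texttt{dim},\quad \epsilon = 10^{-6}\ \text{by default},\quad w = \texttt{weight}.$$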
