Add inst training and inference #111

Merged
merged 6 commits into from Jun 13, 2024
716 changes: 485 additions & 231 deletions aria/data/datasets.py

Large diffs are not rendered by default.

33 changes: 30 additions & 3 deletions aria/data/midi.py
@@ -95,7 +95,7 @@ def __init__(
self.tempo_msgs = tempo_msgs
self.pedal_msgs = pedal_msgs
self.instrument_msgs = instrument_msgs
self.note_msgs = note_msgs
self.note_msgs = sorted(note_msgs, key=lambda msg: msg["tick"])
self.ticks_per_beat = ticks_per_beat
self.metadata = metadata

@@ -726,11 +726,13 @@ def meta_maestro_json(
mid: mido.MidiFile, msg_data: dict, composer_names: list, form_names: list
):
if os.path.isfile("maestro.json") is False:
print("maestro.json not found")
return {}

file_name = pathlib.Path(mid.filename).name
with open("maestro.json", "r") as f:
metadata = json.load(f).get(file_name, None)
_file_name_without_ext = os.path.splitext(file_name)[0]
metadata = json.load(f).get(_file_name_without_ext + ".midi", None)
if metadata == None:
return {}

@@ -755,13 +757,38 @@
return res


def meta_listening_model(mid: mido.MidiFile, msg_data: dict, tag_names: list):
if os.path.isfile("listening_model_tags.json") is False:
return {}

file_name = pathlib.Path(mid.filename).name
with open("listening_model_tags.json", "r") as f:
tags = json.load(f).get(file_name, None)
if tags == None:
return {}

valid_tags = []
for tag in tags:
tag_name = tag[0]
if tag_name in tag_names:
valid_tags.append(tag)

return {"listening_model": valid_tags}


def meta_abs_path(mid: mido.MidiFile, msg_data: dict):
return {"abs_path": str(pathlib.Path(mid.filename).absolute())}


def get_metadata_fn(metadata_proc_name: str):
# Add additional metadata processor names to this dict
name_to_fn = {
"composer_filename": meta_composer_filename,
"composer_metamsg": meta_composer_metamsg,
"form_filename": meta_form_filename,
"maestro_csv": meta_maestro_json,
"maestro_json": meta_maestro_json,
"listening_model": meta_listening_model,
"abs_path": meta_abs_path,
}

fn = name_to_fn.get(metadata_proc_name, None)
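For orientation, a minimal sketch of how the metadata dispatch added above might be exercised. The listening_model_tags.json file name and the tag layout (first element of each entry is the tag name) come from the diff; the MIDI path and the tag_names list are hypothetical, and this snippet is not part of the PR.

# Hypothetical usage of get_metadata_fn / meta_listening_model; not part of this PR.
# Assumes listening_model_tags.json maps MIDI file names to lists of tag entries,
# where each entry's first element is the tag name.
import mido
from aria.data.midi import get_metadata_fn

mid = mido.MidiFile("example.mid")  # hypothetical file
metadata_fn = get_metadata_fn("listening_model")
res = metadata_fn(mid, msg_data={}, tag_names=["genre", "style"])  # hypothetical tag names
# res is {"listening_model": [...]} when an entry exists, otherwise {}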
Empty file removed aria/evals/__init__.py
1 change: 1 addition & 0 deletions aria/inference/__init__.py
@@ -0,0 +1 @@
from .model import TransformerLM
250 changes: 250 additions & 0 deletions aria/inference/model.py
@@ -0,0 +1,250 @@
"""Inference implementation with torch-compiler friendly kv-cache."""

import torch
import torch.nn as nn

from torch.nn import functional as F
from aria.model import ModelConfig


class KVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.bfloat16,
):
super().__init__()
self.dtype = dtype
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]

k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class TransformerLM(nn.Module):
def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config
self.max_seq_len = model_config.max_seq_len
self.model = Transformer(model_config)
self.lm_head = nn.Linear(
model_config.d_model, model_config.vocab_size, bias=False
)

def forward(self, idxs: torch.Tensor, input_pos: torch.Tensor):
hidden_states = self.model(idxs=idxs, input_pos=input_pos)
logits = self.lm_head(hidden_states)

return logits

def setup_cache(
self,
batch_size,
max_seq_len=4096,
dtype=torch.bfloat16,
):
# Init cache
for b in self.model.encode_layers:
b.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_seq_len,
n_heads=self.model_config.n_heads,
head_dim=self.model_config.d_model // self.model_config.n_heads,
dtype=dtype,
).cuda()

self.model.freqs_cis = precompute_freqs_cis(
seq_len=max_seq_len,
n_elem=self.model_config.d_model // self.model_config.n_heads,
base=10000,
dtype=dtype,
).cuda()
self.model.causal_mask = torch.tril(
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
).cuda()


class Transformer(nn.Module):
def __init__(self, model_config: ModelConfig) -> None:
super().__init__()
self.model_config = model_config

self.tok_embeddings = nn.Embedding(
num_embeddings=model_config.vocab_size,
embedding_dim=model_config.d_model,
)
self.encode_layers = nn.ModuleList(
TransformerBlock(model_config) for _ in range(model_config.n_layers)
)
self.out_layer_norm = nn.LayerNorm(model_config.d_model)

self.freqs_cis = None
self.causal_mask = None

def forward(
self,
idxs: torch.Tensor,
input_pos: torch.Tensor,
):
assert self.freqs_cis is not None, "Caches must be initialized first"

mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]

x = self.tok_embeddings(idxs)
for layer in self.encode_layers:
x = layer(x, input_pos, freqs_cis, mask)

x = self.out_layer_norm(x)

return x


class TransformerBlock(nn.Module):
def __init__(self, model_config: ModelConfig) -> None:
super().__init__()

self.d_model = model_config.d_model
self.n_heads = model_config.n_heads
self.d_head = self.d_model // self.n_heads
self.max_seq_len = model_config.max_seq_len

# Att
self.mixed_qkv = nn.Linear(
in_features=model_config.d_model,
out_features=3 * model_config.d_model,
bias=False,
)
self.att_proj_linear = nn.Linear(
in_features=model_config.d_model,
out_features=model_config.d_model,
)

# FF
self.ff_linear_1 = nn.Linear(
in_features=model_config.d_model,
out_features=model_config.d_model * model_config.ff_mult,
)
self.ff_linear_2 = nn.Linear(
in_features=model_config.d_model * model_config.ff_mult,
out_features=model_config.d_model,
)
self.ff_activation = nn.GELU()

# Pre layer norms
self.norm1 = nn.LayerNorm(model_config.d_model)
self.norm2 = nn.LayerNorm(model_config.d_model)

# TODO: Fill in args
self.kv_cache = None

def forward(
self,
x: torch.Tensor,
input_pos: torch.Tensor,
freqs_cis: torch.Tensor,
mask: torch.Tensor,
):
assert self.kv_cache is not None, "Cache not initialized"

x += self._att_block(
x=self.norm1(x),
input_pos=input_pos,
freqs_cis=freqs_cis,
mask=mask,
)
x = x + self._ff_block(self.norm2(x))

return x

def get_kv(self, k: torch.Tensor, v: torch.Tensor, input_pos: torch.Tensor):
k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)

return k, v

def _att_block(
self,
x: torch.Tensor,
input_pos: torch.Tensor,
freqs_cis: torch.Tensor,
mask: torch.Tensor,
):

q, k, v = self.mixed_qkv(x).split(
[self.d_model, self.d_model, self.d_model], dim=-1
)

batch_size, seq_len, _ = q.shape
q = q.view(batch_size, seq_len, self.n_heads, self.d_head)
k = k.view(batch_size, seq_len, self.n_heads, self.d_head)
v = v.view(batch_size, seq_len, self.n_heads, self.d_head)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

k, v = self.get_kv(k, v, input_pos=input_pos)
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=mask,
)

# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2).reshape(
batch_size, seq_len, self.n_heads * self.d_head
)

return self.att_proj_linear(wv)

def _ff_block(self, x: torch.Tensor):
return self.ff_linear_2(self.ff_activation(self.ff_linear_1(x)))


def precompute_freqs_cis(
seq_len: int,
n_elem: int,
base: int = 10000,
dtype: torch.dtype = torch.bfloat16,
):
freqs = 1.0 / (
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
)
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)

return cache.to(dtype=dtype)


@torch.jit.script
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
In-place RoPE. Credits to Katherine Crowson.
x shape (b_sz, s_len, n_head, d_head).
cos, sin shape (s_len, d_head // 2).
"""

d = x.shape[-1] // 2
cos = freqs_cis[..., 0][None, :, None]
sin = freqs_cis[..., 1][None, :, None]
x1, x2 = x[..., :d], x[..., d : d * 2]
tmp = x1.clone()
x1.mul_(cos).addcmul_(x2, sin, value=-1)
x2.mul_(cos).addcmul_(tmp, sin, value=1)
return x
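
For orientation, a minimal greedy-decoding sketch of how the new TransformerLM might be driven end to end. The ModelConfig field names are taken from what the inference code reads (d_model, n_heads, n_layers, ff_mult, vocab_size, max_seq_len); the concrete values, the prompt tokens, and the decode length are hypothetical, and the repo's actual config loading and sampling utilities may differ.

# Hypothetical driver for the inference model above; not part of this PR.
import torch

from aria.model import ModelConfig
from aria.inference import TransformerLM

# Hypothetical config values; the real checkpoint/config loading path may differ.
config = ModelConfig(
    d_model=512,
    n_heads=8,
    n_layers=12,
    ff_mult=4,
    vocab_size=20000,
    max_seq_len=4096,
)
model = TransformerLM(config).cuda().to(torch.bfloat16).eval()
model.setup_cache(batch_size=1, max_seq_len=4096, dtype=torch.bfloat16)

prompt = torch.tensor([[1, 2, 3]], device="cuda")  # hypothetical token ids

with torch.no_grad():
    # Prefill: run the whole prompt once, populating every layer's kv-cache.
    input_pos = torch.arange(prompt.shape[1], device="cuda")
    logits = model(prompt, input_pos)
    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode: feed one token at a time, advancing input_pos into the cache.
    generated = [next_tok]
    for pos in range(prompt.shape[1], prompt.shape[1] + 32):
        input_pos = torch.tensor([pos], device="cuda")
        logits = model(next_tok, input_pos)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_tok)

out = torch.cat([prompt, torch.cat(generated, dim=-1)], dim=-1)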