{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is adapted from [this one](https://github.com/fastai/fastai_docs/blob/master/dev_course/dl2/translation_transformer.ipynb) created by Sylvain Gugger.\n",
"\n",
"See also [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) from Harvard NLP."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Attention and the Transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nvidia AI researcher [Chip Huyen](https://huyenchip.com/) wrote a great post, [Top 8 trends from ICLR 2019](https://huyenchip.com/2019/05/12/top-8-trends-from-iclr-2019.html), in which one of the trends is that *RNN is losing its luster with researchers*.\n",
"\n",
"There's good reason for this. RNNs can be a pain: parallelization can be tricky and they can be difficult to debug. Since language is recursive, it seemed like RNNs were a good conceptual fit with NLP, but recently methods using *attention* have been achieving state-of-the-art results on NLP tasks.\n",
"\n",
"This is still an area of very active research. For instance, a recent paper, [Pay Less Attention with Lightweight and Dynamic Convolutions](https://arxiv.org/abs/1901.10430), showed that convolutions can beat attention on some tasks, including English-to-German translation. More research is needed on the various strengths of RNNs, CNNs, and transformers/attention, and perhaps on approaches that combine the best of each."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/jhoward/.fastai/data/giga-fren/cc.en.300.bin'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/data_save.pkl'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/models'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/questions_easy.csv'),\n",
" PosixPath('/home/jhoward/.fastai/data/giga-fren/cc.fr.300.bin')]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = Config().data_path()/'giga-fren'\n",
"path.ls()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We reuse the same functions as in the translation notebook to load our data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):\n",
"    \"Function that collects samples and adds padding. Flips token order if needed.\"\n",
"    samples = to_data(samples)\n",
"    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n",
"    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n",
"    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n",
"    if backwards: pad_first = not pad_first\n",
"    for i,s in enumerate(samples):\n",
"        if pad_first:\n",
"            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n",
"        else:\n",
"            res_x[i, :len(s[0])],res_y[i, :len(s[1])] = LongTensor(s[0]),LongTensor(s[1])\n",
"    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n",
"    return res_x, res_y"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqDataBunch(TextDataBunch):\n",
"    \"Create a `TextDataBunch` suitable for training a seq2seq model.\"\n",
"    @classmethod\n",
"    def create(cls, train_ds, valid_ds, test_ds=None, path='.', bs=32, val_bs=None, pad_idx=1,\n",
"               dl_tfms=None, pad_first=False, device=None, no_check=False, backwards=False, **dl_kwargs):\n",
"        \"Transform the `datasets` into a `DataBunch` for seq2seq training. Passes `**dl_kwargs` on to `DataLoader()`.\"\n",
"        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
"        val_bs = ifnone(val_bs, bs)\n",
"        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n",
"        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n",
"        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n",
"        dataloaders = [train_dl]\n",
"        for ds in datasets[1:]:\n",
"            lengths = [len(t) for t in ds.x.items]\n",
"            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n",
"            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n",
"        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqTextList(TextList):\n",
"    _bunch = Seq2SeqDataBunch\n",
"    _label_cls = TextList"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Refer to notebook 7-seq2seq-translation for the code we used to create, process, and save this data."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"data = load_data(path)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"  <thead>\n",
"    <tr>\n",
"      <th>text</th>\n",
"      <th>target</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>xxbos quelles questions devraient être traitées respectivement au niveau international et au niveau national , ou quelle xxunk devrait être établie entre la réglementation internationale et la réglementation nationale ?</td>\n",
"      <td>xxbos which issues should be dealt with internationally and which nationally , or what division should be made between international regulation and national regulation ?</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>xxbos comment la culture et les arts y vivent - ils , et comment la société civile les prend - elle en considération dans le développement de la ville ?</td>\n",
"      <td>xxbos where do art and culture fit in , and how is civil society taking them into consideration in developing the city ?</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>xxbos qu’arrivera - t - il si les entreprises canadiennes et les gouvernements ne se xxunk pas sur la question du conflit entre le travail et la vie personnelle ?</td>\n",
"      <td>xxbos what will likely happen if canadian organizations and governments do not deal with the issue of work – life conflict ?</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>xxbos qu'adviendra - t - il de l'examen de rendement et de la rémunération au rendement de xxunk et comment se xxunk - t - il à cet égard ?</td>\n",
"      <td>xxbos what happens to xxunk 's evaluation review and performance pay and how will this make him feel ?</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>xxbos quels avantages prévoit - on en général pour la région du delta de beaufort par suite de la signature de xxunk / d'une entente future sur l'autonomie gouvernementale ?</td>\n",
"      <td>xxbos what benefits to the beaufort delta region generally are expected as a result of this aip / future self - government ?</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transformer model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Transformer model](images/Transformer.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shifting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We add a transform to the dataloader that shifts the targets right by one position and adds a padding token at the beginning."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"v = data.vocab"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.stoi['xxpad']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def shift_tfm(b):\n",
"    x,y = b\n",
"    y = F.pad(y, (1, 0), value=1)\n",
"    return [x,y[:,:-1]], y[:,1:]"
]
},
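{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the shift concrete, here is a minimal sketch that applies the same `shift_tfm` to a single toy batch. The token ids are made up for illustration; the only value taken from the notebook is the padding index 1 (`v.stoi['xxpad']`).\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def shift_tfm(b):\n",
"    # Same transform as in the notebook: prepend the pad token (index 1) to the targets.\n",
"    x, y = b\n",
"    y = F.pad(y, (1, 0), value=1)\n",
"    # The decoder sees the shifted targets; the loss is computed on the unshifted ones.\n",
"    return [x, y[:, :-1]], y[:, 1:]\n",
"\n",
"x = torch.tensor([[5, 6, 7]])      # toy source ids (illustrative only)\n",
"y = torch.tensor([[2, 8, 9, 3]])   # toy target ids (illustrative only)\n",
"(enc_inp, dec_inp), target = shift_tfm((x, y))\n",
"print(dec_inp)  # tensor([[1, 2, 8, 9]])\n",
"print(target)   # tensor([[2, 8, 9, 3]])\n",
"```\n",
"\n",
"At each position the decoder input holds the token that precedes the one to predict, so the model never receives the token it is being asked to produce."
]
},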
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data.add_tfm(shift_tfm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The input and output embeddings are traditional PyTorch embeddings (and we can use pretrained vectors if we want to). The transformer model isn't a recurrent one, so it has no idea of the relative positions of the words. To help it with that, the paper's authors add a positional encoding to the input embeddings: sines and cosines of the position at a range of frequencies."
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000, 0.0667, 0.1333, 0.2000, 0.2667, 0.3333, 0.4000, 0.4667, 0.5333,\n",
" 0.6000, 0.6667, 0.7333, 0.8000, 0.8667, 0.9333])"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = 30\n",
"torch.arange(0., d, 2.)/d"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"    \"Encode the position with a sinusoid.\"\n",
"    def __init__(self, d):\n",
"        super().__init__()\n",
"        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))\n",
"\n",
"    def forward(self, pos):\n",
"        inp = torch.ger(pos, self.freq)\n",
"        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)\n",
"        return enc"
]
},
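{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the encoding above, the sketch below restates it as a plain function and inspects the result. The function name `positional_encoding` and the sizes (100 positions, `d = 20`) are assumptions chosen for illustration, and the outer product is written with broadcasting instead of `torch.ger`.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def positional_encoding(pos, d):\n",
"    # One frequency per pair of dimensions, decreasing geometrically from 1 towards 1/10000.\n",
"    freq = 1 / (10000 ** (torch.arange(0., d, 2.) / d))\n",
"    # Outer product of positions and frequencies: shape (len(pos), d // 2).\n",
"    inp = pos[:, None] * freq[None, :]\n",
"    # First half of the encoding holds the sines, second half the cosines.\n",
"    return torch.cat([inp.sin(), inp.cos()], dim=-1)\n",
"\n",
"enc = positional_encoding(torch.arange(0, 100).float(), 20)\n",
"print(enc.shape)  # torch.Size([100, 20])\n",
"print(enc[0])     # position 0: ten zeros (sin 0) followed by ten ones (cos 0)\n",
"```\n",
"\n",
"Every position gets a distinct vector with values in [-1, 1], which can be added to the token embeddings without any learned parameters."
]
},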
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"inputs, targets, outputs = get_predictions(learn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj pendant que xxunk les activités requises pour maintenir mon xxunk physique , est - ce que je xxunk de la protection d’un régime d’assurance ou de pension ?,\n",
" Text xxbos xxmaj while i go about maintaining this high degree of fitness , am i protected under an insurance or pension plan ?,\n",
" Text xxbos xxmaj while i do to the my physical physical of physical , do i aware by the pension plan service plan ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[10],targets[10],outputs[10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelles sont les conséquences sur la recherche , la mise en pratique et les politiques en ce qui a trait à l'ac ?,\n",
" Text xxbos xxmaj what are the xxunk for xxup kt research , practice / policy ?,\n",
" Text xxbos xxmaj what are the implications implications research kt , , policy and policies in)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[700],targets[700],outputs[700]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelle est la position des xxmaj états - xxmaj unis , du xxmaj canada et de la xxup xxunk à ce propos ?,\n",
" Text xxbos xxmaj where do the xxup us , xxmaj canada and xxup xxunk stand ?,\n",
" Text xxbos xxmaj what is xxmaj xxup us xxmaj xxmaj united and the xxunk fit in)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[701],targets[701],outputs[701]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quels sont les atouts particuliers du xxmaj canada en recherche sur l'obésité sur la scène internationale ?,\n",
" Text xxbos xxmaj what are the unique xxmaj canadian strengths in obesity research that set xxmaj canada apart on an international front ?,\n",
" Text xxbos xxmaj what are xxmaj specific strengths canada strengths in obesity - ? are up canada ? from international international stage ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[2500],targets[2500],outputs[2500]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelles sont les répercussions politiques à long terme de cette révolution scientifique mondiale ?,\n",
" Text xxbos xxmaj what are some of the long - term policy implications of this global knowledge revolution ?,\n",
" Text xxbos xxmaj what are the long the long - term policies implications of this global scientific ? ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[4002],targets[4002],outputs[4002]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Label smoothing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They point out in the paper that using label smoothing helped them get better BLEU and accuracy, even though it made the loss worse."
]
},
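{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, here is a minimal sketch of one common formulation of label-smoothed cross-entropy: a weighted mix of the usual negative log-likelihood and the average negative log-probability over all classes. The function name `label_smoothing_ce` and the value `eps=0.1` are assumptions for illustration; this is not claimed to be the exact code behind `LabelSmoothingCrossEntropy`.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def label_smoothing_ce(logits, target, eps=0.1):\n",
"    # logits: (batch, n_classes), target: (batch,) of class indices.\n",
"    log_probs = F.log_softmax(logits, dim=-1)\n",
"    # Standard cross-entropy term on the true class.\n",
"    nll = F.nll_loss(log_probs, target)\n",
"    # Cross-entropy against a uniform distribution over all classes.\n",
"    uniform = -log_probs.mean(dim=-1).mean()\n",
"    return (1 - eps) * nll + eps * uniform\n",
"\n",
"torch.manual_seed(0)\n",
"logits = torch.randn(6, 5)\n",
"target = torch.randint(0, 5, (6,))\n",
"print(label_smoothing_ce(logits, target, eps=0.0))  # matches the plain cross-entropy below\n",
"print(F.cross_entropy(logits, target))\n",
"print(label_smoothing_ce(logits, target, eps=0.1))\n",
"```\n",
"\n",
"Because the uniform term penalizes putting all of the probability mass on one class, the reported loss can be higher than with plain cross-entropy even when accuracy and BLEU improve."
]
},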
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"model = Transformer(len(data.train_ds.x.vocab.itos), len(data.train_ds.y.vocab.itos), d_model=256)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, model, metrics=[accuracy, CorpusBLEU(len(data.train_ds.y.vocab.itos))], \n",
" loss_func=FlattenedLoss(LabelSmoothingCrossEntropy, axis=-1))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
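"<table>\n",
"  <thead>\n",
"    <tr>\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>bleu</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>3.281034</td>\n",
"      <td>3.357356</td>\n",
"      <td>0.621848</td>\n",
"      <td>0.458009</td>\n",
"      <td>01:45</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>2.872045</td>\n",
"      <td>2.923340</td>\n",
"      <td>0.690921</td>\n",
"      <td>0.510376</td>\n",
"      <td>01:47</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>2.598603</td>\n",
"      <td>2.653438</td>\n",
"      <td>0.729291</td>\n",
"      <td>0.545735</td>\n",
"      <td>01:49</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td>\n",
"      <td>2.407944</td>\n",
"      <td>2.514847</td>\n",
"      <td>0.748187</td>\n",
"      <td>0.567057</td>\n",
"      <td>01:46</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>4</td>\n",
"      <td>2.195246</td>\n",
"      <td>2.403729</td>\n",
"      <td>0.766409</td>\n",
"      <td>0.592165</td>\n",
"      <td>01:50</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>5</td>\n",
"      <td>2.095695</td>\n",
"      <td>2.362098</td>\n",
"      <td>0.776127</td>\n",
"      <td>0.604666</td>\n",
"      <td>01:48</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>6</td>\n",
"      <td>1.999303</td>\n",
"      <td>2.358647</td>\n",
"      <td>0.779535</td>\n",
"      <td>0.609675</td>\n",
"      <td>01:47</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>7</td>\n",
"      <td>1.923621</td>\n",
"      <td>2.359421</td>\n",
"      <td>0.780211</td>\n",
"      <td>0.610871</td>\n",
"      <td>01:47</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"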
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(8, 5e-4, div_factor=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.fit_one_cycle(8, 5e-4, div_factor=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Quels sont les atouts particuliers du Canada en recherche sur l'obésité sur la scène internationale ?\n",
"What are Specific strengths canada strengths in obesity - ? are up canada ? from international international stage ?\n",
"Quelles sont les répercussions politiques à long terme de cette révolution scientifique mondiale ?\n",
"What are the long the long - term policies implications of this global scientific ? ?\n"
]
}
],
"source": [
"print(\"Quels sont les atouts particuliers du Canada en recherche sur l'obésité sur la scène internationale ?\")\n",
"print(\"What are Specific strengths canada strengths in obesity - ? are up canada ? from international international stage ?\")\n",
"print(\"Quelles sont les répercussions politiques à long terme de cette révolution scientifique mondiale ?\")\n",
"print(\"What are the long the long - term policies implications of this global scientific ? ?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelle distance y a - t - il entre le point le plus rapproché de la surface à xxunk et la position d’utilisation habituelle du tube radiogène ?,\n",
" Text xxbos xxmaj what is the distance between the nearest point of the area to be shielded and the usual operational position of the x - ray tube ?,\n",
" Text xxbos xxmaj what is the xxmaj between the xxmaj and of the xxmaj ? the ? and the most ? ? of the xxmaj - ray tube ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[10],targets[10],outputs[10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quels types de présentations xxmaj santé xxmaj canada xxunk - t - il dans le format ectd à compter du 1er septembre ?,\n",
" Text xxbos xxmaj what kind of submission types will xxmaj health xxmaj canada accept on xxmaj september 1 , 2004 in ectd format ?,\n",
" Text xxbos xxmaj what is of information is of be canadian xxmaj canada take ? the canadian ? , and ? the format ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[700],targets[700],outputs[700]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelles sont les trois caractéristiques qui vous incitent le plus à investir dans votre région ( xxup nommez - xxup les ) ?,\n",
" Text xxbos xxmaj what are the three most attractive features about investing in your region ( xxup name xxup it ) ?,\n",
" Text xxbos xxmaj what is the main main important concerns of the in the country ? xxup xxunk , xxunk ) ?)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[701],targets[701],outputs[701]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos xxmaj quelles actions avez - vous prises et quel en a été le résultat ?,\n",
" Text xxbos xxmaj what were your actions and the outcomes ?,\n",
" Text xxbos xxmaj what is the targets ? how main of)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[4001],targets[4001],outputs[4001]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test leakage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we change a token in the targets at position n, it shouldn't impact the predictions before that."
]
},
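{
"cell_type": "markdown",
"metadata": {},
"source": [
"The property stated above is what a causal (lower-triangular) attention mask is meant to guarantee. The following is a minimal, self-contained sketch of that idea with a toy single-head self-attention that has no learned weights; the function `causal_self_attention` and the sizes used are assumptions for illustration, not the model trained in this notebook.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"def causal_self_attention(x):\n",
"    # x: (seq_len, d). Each position may attend only to itself and earlier positions.\n",
"    seq_len, d = x.shape\n",
"    scores = x @ x.t() / d ** 0.5\n",
"    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()\n",
"    # Future positions get -inf, so their softmax weight is zero.\n",
"    scores = scores.masked_fill(~mask, float('-inf'))\n",
"    return torch.softmax(scores, dim=-1) @ x\n",
"\n",
"x1 = torch.randn(8, 4)\n",
"x2 = x1.clone()\n",
"x2[5] = torch.randn(4)  # change the input at position 5 only\n",
"y1, y2 = causal_self_attention(x1), causal_self_attention(x2)\n",
"print((y1[:5] - y2[:5]).abs().max())  # positions before 5 should be unaffected\n",
"print((y1[5:] - y2[5:]).abs().max())  # position 5 onward is expected to change\n",
"```\n",
"\n",
"The cells that follow run the same kind of check against the trained model: modify one target token and compare the predictions that come before it."
]
},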
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.model.eval();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xb,yb = data.one_batch(cpu=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inp1,out1 = xb[0][:1],xb[1][:1]\n",
"inp2,out2 = inp1.clone(),out1.clone()\n",
"out2[0,15] = 10"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y1 = learn.model(inp1, out1)\n",
"y2 = learn.model(inp2, out2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0., device='cuda:0', grad_fn=)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(y1[0,:15] - y2[0,:15]).abs().mean()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}