replace simpletransformers training and use datasets for data loading
leandro committed Jan 1, 2022
1 parent 750f5fd commit 919a1eb
Showing 12 changed files with 333 additions and 803 deletions.
497 changes: 0 additions & 497 deletions nbs/03-bert-imdb-training.ipynb

This file was deleted.

291 changes: 291 additions & 0 deletions nbs/03-distilbert-imdb-training.ipynb
@@ -0,0 +1,291 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train IMDb Classifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load IMDb dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset, load_metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = load_dataset(\"imdb\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 25000\n",
" })\n",
" test: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 25000\n",
" })\n",
" unsupervised: Dataset({\n",
" features: ['text', 'label'],\n",
" num_rows: 50000\n",
" })\n",
"})"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n",
" 'text': Value(dtype='string', id=None)}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds['train'].features"
]
},
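{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an added illustration, not part of the original notebook), we can peek at a single raw training example before tokenization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# show the label and the first 200 characters of one review\n",
"sample = ds['train'][0]\n",
"sample['label'], sample['text'][:200]"
]
},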
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Pretrained DistilBERT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
"\n",
"model_name = \"distilbert-base-uncased\"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
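{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before mapping over the full dataset, it helps to see what the tokenizer produces for a short string (an added example, not part of the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the tokenizer returns input_ids and an attention_mask\n",
"tokenizer(\"This movie was great!\")"
]
},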
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepocess Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ddef2e0d4a04e12ad7513950158236c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/25 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b1392a042614a1682b6f62642262446",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/25 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a7f130baafab4493bfe185fa7f3a9fe9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/50 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def tokenize(examples):\n",
" outputs = tokenizer(examples['text'], truncation=True)\n",
" return outputs\n",
"\n",
"tokenized_ds = ds.map(tokenize, batched=True)"
]
},
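{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the map, each split should carry the tokenizer outputs alongside the original columns (an added check, assuming the map above ran successfully):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# expect ['text', 'label', 'input_ids', 'attention_mask']\n",
"tokenized_ds['train'].column_names"
]
},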
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare Trainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments, Trainer, DataCollatorWithPadding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def compute_metrics(eval_preds):\n",
" metric = load_metric(\"accuracy\")\n",
" logits, labels = eval_preds\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
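{
"cell_type": "markdown",
"metadata": {},
"source": [
"To sketch how the `Trainer` will call this function, we can feed it a pair of made-up logits and labels (dummy values, added for illustration only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# two examples, two classes; both argmax predictions match the labels, so accuracy should be 1.0\n",
"dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2]])\n",
"dummy_labels = np.array([1, 0])\n",
"compute_metrics((dummy_logits, dummy_labels))"
]
},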
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(num_train_epochs=1,\n",
" output_dir=\"distilbert-imdb\",\n",
" push_to_hub=True,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" evaluation_strategy=\"epoch\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_collator = DataCollatorWithPadding(tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(model=model, tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" args=training_args,\n",
" train_dataset=tokenized_ds[\"train\"],\n",
" eval_dataset=tokenized_ds[\"test\"], \n",
" compute_metrics=compute_metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train Model and Push to Hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.push_to_hub()"
]
},
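{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once training finishes, the fine-tuned model can be loaded back for inference. The sketch below assumes the local `distilbert-imdb` output directory from `TrainingArguments` still holds the final checkpoint; to load from the Hub instead, substitute your own `<username>/distilbert-imdb` repo id (a placeholder, not a verified model name):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# \"distilbert-imdb\" here is the local output_dir used above; swap in your Hub repo id to load remotely\n",
"classifier = pipeline(\"text-classification\", model=\"distilbert-imdb\")\n",
"classifier(\"I loved this film, the pacing was perfect.\")"
]
}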
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}