forked from huggingface/trl
Commit
replace simpletransformers training and use datasets for data loading
leandro committed Jan 1, 2022
1 parent 750f5fd · commit 919a1eb
Showing 12 changed files with 333 additions and 803 deletions.
@@ -0,0 +1,291 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train IMDb Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!huggingface-cli login"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load IMDb dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"imdb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    unsupervised: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 50000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n",
       " 'text': Value(dtype='string', id=None)}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds['train'].features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Pretrained DistilBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "model_name = \"distilbert-base-uncased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## Prepocess Data" | ||
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ddef2e0d4a04e12ad7513950158236c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4b1392a042614a1682b6f62642262446",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7f130baafab4493bfe185fa7f3a9fe9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def tokenize(examples):\n",
    "    outputs = tokenizer(examples['text'], truncation=True)\n",
    "    return outputs\n",
    "\n",
    "tokenized_ds = ds.map(tokenize, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    metric = load_metric(\"accuracy\")\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(num_train_epochs=1,\n",
    "                                  output_dir=\"distilbert-imdb\",\n",
    "                                  push_to_hub=True,\n",
    "                                  per_device_train_batch_size=16,\n",
    "                                  per_device_eval_batch_size=16,\n",
    "                                  evaluation_strategy=\"epoch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorWithPadding(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model=model, tokenizer=tokenizer,\n",
    "                  data_collator=data_collator,\n",
    "                  args=training_args,\n",
    "                  train_dataset=tokenized_ds[\"train\"],\n",
    "                  eval_dataset=tokenized_ds[\"test\"],\n",
    "                  compute_metrics=compute_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model and Push to Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
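
For reference, the notebook's cells condense to roughly the standalone script below. This is a minimal sketch, not part of the commit: it assumes datasets/transformers versions contemporary with this commit (where load_metric still ships with datasets and TrainingArguments accepts evaluation_strategy), and it drops push_to_hub so it runs without Hub credentials.

# Minimal sketch of the notebook's pipeline as one script (an illustration,
# not part of the commit). Assumes era-appropriate versions of datasets
# (load_metric still available) and transformers (evaluation_strategy arg).
import numpy as np
from datasets import load_dataset, load_metric
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding, Trainer, TrainingArguments)

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Tokenize without padding; the collator pads each batch dynamically instead.
ds = load_dataset("imdb")
tokenized_ds = ds.map(lambda batch: tokenizer(batch["text"], truncation=True),
                      batched=True)

metric = load_metric("accuracy")

def compute_metrics(eval_preds):
    # eval_preds is a (logits, labels) pair supplied by the Trainer.
    logits, labels = eval_preds
    return metric.compute(predictions=np.argmax(logits, axis=-1),
                          references=labels)

training_args = TrainingArguments(output_dir="distilbert-imdb",
                                  num_train_epochs=1,
                                  per_device_train_batch_size=16,
                                  per_device_eval_batch_size=16,
                                  evaluation_strategy="epoch")

trainer = Trainer(model=model,
                  tokenizer=tokenizer,
                  args=training_args,
                  data_collator=DataCollatorWithPadding(tokenizer),
                  train_dataset=tokenized_ds["train"],
                  eval_dataset=tokenized_ds["test"],
                  compute_metrics=compute_metrics)
trainer.train()

Note that the tokenize step deliberately skips padding: DataCollatorWithPadding pads each batch to its own longest sequence at collation time, which is cheaper than padding every example to a global maximum length up front.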