Commit
Add POS tagging and Phrase chunking token classification examples (#6457)

* Add more token classification examples

* POS tagging example

* Phrase chunking example

* PR review fixes

* Add conllu to third party list (used in token classification examples)
vblagoje authored Aug 13, 2020
1 parent f51161e commit eda07ef
Showing 10 changed files with 473 additions and 204 deletions.
1 change: 1 addition & 0 deletions examples/requirements.txt
@@ -15,3 +15,4 @@ pandas
 nlp
 fire
 pytest
+conllu
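
The new conllu dependency is what the POS example uses to read Universal Dependencies treebank files. As a rough sketch (not code from this commit; the tag field is named "upostag" in older conllu releases and "upos" in newer ones), parsing the downloaded data looks like this:

```python
# Hedged sketch: how the conllu package turns a CoNLL-U file into the
# (word, POS tag) pairs the POS example trains on. Not from this commit.
from conllu import parse_incr

with open("train.txt", encoding="utf-8") as f:  # file fetched by run_pos.sh below
    for sentence in parse_incr(f):              # one TokenList per sentence
        words = [token["form"] for token in sentence]
        tags = [token["upostag"] for token in sentence]  # "upos" in newer conllu
        print(list(zip(words, tags)))
```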
1 change: 1 addition & 0 deletions examples/token-classification/run.sh
100644 → 100755
@@ -18,6 +18,7 @@ export SAVE_STEPS=750
 export SEED=1

 python3 run_ner.py \
+--task_type NER \
 --data_dir . \
 --labels ./labels.txt \
 --model_name_or_path $BERT_MODEL \
37 changes: 37 additions & 0 deletions examples/token-classification/run_chunk.sh
@@ -0,0 +1,37 @@
+if ! [ -f ./dev.txt ]; then
+  echo "Downloading CONLL2003 dev dataset...."
+  curl -L -o ./dev.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/valid.txt'
+fi
+
+if ! [ -f ./test.txt ]; then
+  echo "Downloading CONLL2003 test dataset...."
+  curl -L -o ./test.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/test.txt'
+fi
+
+if ! [ -f ./train.txt ]; then
+  echo "Downloading CONLL2003 train dataset...."
+  curl -L -o ./train.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/train.txt'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=chunker-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+python3 run_ner.py \
+--task_type Chunk \
+--data_dir . \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--per_gpu_train_batch_size $BATCH_SIZE \
+--save_steps $SAVE_STEPS \
+--seed $SEED \
+--do_train \
+--do_eval \
+--do_predict
+
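
The chunking example reuses the CoNLL-2003 download but trains on the syntactic-chunk column rather than the NER column. For orientation (an illustration, not code from this commit; the actual column handling lives in the Chunk class in tasks.py, one of the changed files not rendered on this page), each data line carries four whitespace-separated fields:

```python
# Illustration of the CoNLL-2003 four-column format the Chunk task reads.
line = "U.N. NNP I-NP I-ORG"         # word, POS tag, chunk tag, NER tag
word, pos, chunk, ner = line.split()
print(word, chunk)                   # the chunking task pairs the word with I-NP
```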

46 changes: 25 additions & 21 deletions examples/token-classification/run_ner.py
@@ -14,16 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Fine-tuning the library models for named entity recognition on CoNLL-2003. """
-
-
 import logging
 import os
 import sys
 from dataclasses import dataclass, field
+from importlib import import_module
 from typing import Dict, List, Optional, Tuple

 import numpy as np
-from seqeval.metrics import f1_score, precision_score, recall_score
+from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch import nn

 from transformers import (
@@ -36,7 +35,7 @@
     TrainingArguments,
     set_seed,
 )
-from utils_ner import NerDataset, Split, get_labels
+from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask


 logger = logging.getLogger(__name__)
@@ -54,6 +53,9 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
+    task_type: Optional[str] = field(
+        default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
+    )
     tokenizer_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
@@ -113,6 +115,16 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
         )

+    module = import_module("tasks")
+    try:
+        token_classification_task_clazz = getattr(module, model_args.task_type)
+        token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+    except AttributeError:
+        raise ValueError(
+            f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+            f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+        )
+
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -133,7 +145,7 @@
     set_seed(training_args.seed)

     # Prepare CONLL-2003 task
-    labels = get_labels(data_args.labels)
+    labels = token_classification_task.get_labels(data_args.labels)
     label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
     num_labels = len(labels)

@@ -164,7 +176,8 @@

     # Get datasets
     train_dataset = (
-        NerDataset(
+        TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
@@ -177,7 +190,8 @@
         else None
     )
     eval_dataset = (
-        NerDataset(
+        TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
@@ -209,6 +223,7 @@ def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[L
     def compute_metrics(p: EvalPrediction) -> Dict:
         preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
         return {
+            "accuracy_score": accuracy_score(out_label_list, preds_list),
             "precision": precision_score(out_label_list, preds_list),
             "recall": recall_score(out_label_list, preds_list),
             "f1": f1_score(out_label_list, preds_list),
@@ -253,7 +268,8 @@ def compute_metrics(p: EvalPrediction) -> Dict:

     # Predict
     if training_args.do_predict:
-        test_dataset = NerDataset(
+        test_dataset = TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,
@@ -278,19 +294,7 @@ def compute_metrics(p: EvalPrediction) -> Dict:
         if trainer.is_world_master():
             with open(output_test_predictions_file, "w") as writer:
                 with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
-                    example_id = 0
-                    for line in f:
-                        if line.startswith("-DOCSTART-") or line == "" or line == "\n":
-                            writer.write(line)
-                            if not preds_list[example_id]:
-                                example_id += 1
-                        elif preds_list[example_id]:
-                            output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
-                            writer.write(output_line)
-                        else:
-                            logger.warning(
-                                "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
-                            )
+                    token_classification_task.write_predictions_to_file(writer, f, preds_list)

     return results

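run_ner.py no longer hard-codes NER: the --task_type string is resolved to a class in the new tasks.py via import_module/getattr. tasks.py itself is among the changed files not rendered on this page, so the following is only a hedged sketch of the contract the script relies on, with an assumed fallback label set:

```python
# Hedged sketch of a TokenClassificationTask subclass, illustrating the
# contract run_ner.py expects; the commit's real implementations live in
# examples/token-classification/tasks.py, which is not shown above.
from typing import List

from utils_ner import TokenClassificationTask


class NER(TokenClassificationTask):
    """Found by getattr(import_module("tasks"), "NER") when --task_type NER is passed."""

    def get_labels(self, path: str) -> List[str]:
        # One label per line when a labels file is supplied; the fallback
        # CoNLL-2003 label set here is an assumption, not the commit's code.
        if path:
            with open(path) as f:
                return [line for line in f.read().splitlines() if line]
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

    # read_examples_from_file(...), convert_examples_to_features(...) and
    # write_predictions_to_file(...) complete the interface used above.
```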
36 changes: 26 additions & 10 deletions examples/token-classification/run_pl_ner.py
@@ -2,15 +2,17 @@
 import glob
 import logging
 import os
+from argparse import Namespace
+from importlib import import_module

 import numpy as np
 import torch
-from seqeval.metrics import f1_score, precision_score, recall_score
+from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader, TensorDataset

 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
+from utils_ner import TokenClassificationTask


 logger = logging.getLogger(__name__)
@@ -24,10 +26,20 @@ class NERTransformer(BaseTransformer):
     mode = "token-classification"

     def __init__(self, hparams):
-        self.labels = get_labels(hparams.labels)
-        num_labels = len(self.labels)
+        if type(hparams) == dict:
+            hparams = Namespace(**hparams)
+        module = import_module("tasks")
+        try:
+            token_classification_task_clazz = getattr(module, hparams.task_type)
+            self.token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+        except AttributeError:
+            raise ValueError(
+                f"Task {hparams.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+                f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+            )
+        self.labels = self.token_classification_task.get_labels(hparams.labels)
         self.pad_token_label_id = CrossEntropyLoss().ignore_index
-        super().__init__(hparams, num_labels, self.mode)
+        super().__init__(hparams, len(self.labels), self.mode)

     def forward(self, **inputs):
         return self.model(**inputs)
@@ -42,8 +54,8 @@ def training_step(self, batch, batch_num):

         outputs = self(**inputs)
         loss = outputs[0]
-        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
-        return {"loss": loss, "log": tensorboard_logs}
+        # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        return {"loss": loss}

     def prepare_data(self):
         "Called to initialize data. Use the call to construct features"
@@ -55,8 +67,8 @@ def prepare_data(self):
             features = torch.load(cached_features_file)
         else:
             logger.info("Creating features from dataset file at %s", args.data_dir)
-            examples = read_examples_from_file(args.data_dir, mode)
-            features = convert_examples_to_features(
+            examples = self.token_classification_task.read_examples_from_file(args.data_dir, mode)
+            features = self.token_classification_task.convert_examples_to_features(
                 examples,
                 self.labels,
                 args.max_seq_length,
@@ -74,7 +86,7 @@ def prepare_data(self):
                 logger.info("Saving features into cached file %s", cached_features_file)
                 torch.save(features, cached_features_file)

-    def load_dataset(self, mode, batch_size):
+    def get_dataloader(self, mode: int, batch_size: int) -> DataLoader:
         "Load datasets. Called after prepare data."
         cached_features_file = self._feature_file(mode)
         logger.info("Loading features from cached file %s", cached_features_file)
@@ -124,6 +136,7 @@ def _eval_end(self, outputs):

         results = {
             "val_loss": val_loss_mean,
+            "accuracy_score": accuracy_score(out_label_list, preds_list),
             "precision": precision_score(out_label_list, preds_list),
             "recall": recall_score(out_label_list, preds_list),
             "f1": f1_score(out_label_list, preds_list),
@@ -154,6 +167,9 @@ def test_epoch_end(self, outputs):
     def add_model_specific_args(parser, root_dir):
         # Add NER specific options
         BaseTransformer.add_model_specific_args(parser, root_dir)
+        parser.add_argument(
+            "--task_type", default="NER", type=str, help="Task type to fine tune in training (e.g. NER, POS, etc)"
+        )
         parser.add_argument(
             "--max_seq_length",
             default=128,
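Both scripts now report seqeval's accuracy_score alongside precision, recall, and F1. A quick self-contained illustration of how these metrics behave on the nested per-sentence tag lists that align_predictions and _eval_end produce:

```python
# seqeval works on per-sentence lists of tag strings. accuracy_score is
# token-level, while precision/recall/f1 are computed over whole entities.
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score

y_true = [["B-PER", "I-PER", "O"], ["B-LOC", "O"]]
y_pred = [["B-PER", "I-PER", "O"], ["O", "O"]]

print(accuracy_score(y_true, y_pred))   # 4 of 5 tokens correct -> 0.8
print(precision_score(y_true, y_pred))  # 1 predicted entity, 1 correct -> 1.0
print(recall_score(y_true, y_pred))     # 1 of 2 true entities found -> 0.5
print(f1_score(y_true, y_pred))         # harmonic mean -> ~0.67
```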
37 changes: 37 additions & 0 deletions examples/token-classification/run_pos.sh
@@ -0,0 +1,37 @@
+if ! [ -f ./dev.txt ]; then
+  echo "Download dev dataset...."
+  curl -L -o ./dev.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-dev.conllu'
+fi
+
+if ! [ -f ./test.txt ]; then
+  echo "Download test dataset...."
+  curl -L -o ./test.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-test.conllu'
+fi
+
+if ! [ -f ./train.txt ]; then
+  echo "Download train dataset...."
+  curl -L -o ./train.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-train.conllu'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=postagger-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+python3 run_ner.py \
+--task_type POS \
+--data_dir . \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--per_gpu_train_batch_size $BATCH_SIZE \
+--save_steps $SAVE_STEPS \
+--seed $SEED \
+--do_train \
+--do_eval \
+--do_predict
+
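
Note that run_pos.sh passes no --labels file, so the POS task has to supply its own tag inventory (presumably the Universal POS tags). As a hedged sanity check, the set can be derived straight from the downloaded training data:

```python
# Hedged helper, not part of the commit: enumerate the UPOS tags actually
# present in train.txt, which is what the POS task must cover when no
# --labels file is given. The field name depends on the conllu version.
from conllu import parse_incr

with open("train.txt", encoding="utf-8") as f:
    tags = {token["upostag"] for sentence in parse_incr(f) for token in sentence}
print(sorted(t for t in tags if t))  # skips placeholders on multiword token ranges
# expected: the 17 Universal POS tags (ADJ, ADP, ..., VERB, X)
```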

39 changes: 39 additions & 0 deletions examples/token-classification/run_pos_pl.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+if ! [ -f ./dev.txt ]; then
+  echo "Download dev dataset...."
+  curl -L -o ./dev.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-dev.conllu'
+fi
+
+if ! [ -f ./test.txt ]; then
+  echo "Download test dataset...."
+  curl -L -o ./test.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-test.conllu'
+fi
+
+if ! [ -f ./train.txt ]; then
+  echo "Download train dataset...."
+  curl -L -o ./train.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-train.conllu'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=postagger-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+
+# Add parent directory to python path to access lightning_base.py
+export PYTHONPATH="../":"${PYTHONPATH}"
+
+python3 run_pl_ner.py --data_dir ./ \
+--task_type POS \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--train_batch_size $BATCH_SIZE \
+--seed $SEED \
+--gpus 1 \
+--do_train \
+--do_predict
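
Once either variant finishes, the model written to $OUTPUT_DIR is an ordinary Hugging Face checkpoint. A hedged inference sketch (assuming the checkpoint directory name from the scripts above, and that transformers' generic token-classification pipeline, registered as "ner", is acceptable for POS and chunk labels too):

```python
# Hedged usage sketch, not part of the commit: load the fine-tuned
# checkpoint and tag a sentence. pipeline("ner") is transformers' generic
# token-classification pipeline, so it also serves POS and chunk models.
from transformers import pipeline

tagger = pipeline("ner", model="postagger-model", tokenizer="postagger-model")
for item in tagger("The quick brown fox jumps over the lazy dog"):
    print(item["word"], item["entity"], round(item["score"], 3))
```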