Skip to content

Commit

Permalink
Chunkable token classification pipeline (huggingface#21771)
Browse files Browse the repository at this point in the history
* Chunkable classification pipeline 

The TokenClassificationPipeline is now able to process sequences longer than 512. No matter the framework, the model, the tokenizer. We just have to pass process_all=True and a stride number (optional). The behavior remains the same if you don't pass these optional parameters. For overlapping parts when using stride above 0, we consider only the max scores for each overlapped token in all chunks where the token is.

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* update with latest black format

* update black format

* Update token_classification.py

* Update token_classification.py

* format correction

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update comments

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* Update token_classification.py

Correct spaces, remove process_all and keep only stride. If stride is provided, the pipeline is applied to the whole text.

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update chunk aggregation

Update the chunk aggregation strategy based on entities aggregation.

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

Remove unnecessary pop from outputs dict

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add chunking tests

* correct formating

* correct formatting

* correct model id for test chunking

* update scores with nested simplify

* Update test_pipelines_token_classification.py

* Update test_pipelines_token_classification.py

* update model to a tiny one

* Update test_pipelines_token_classification.py

* Adding smaller test for chunking.

* Fixup

* Update token_classification.py

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 22, 2023
1 parent f48d331 commit d62e7d8
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 43 deletions.
156 changes: 113 additions & 43 deletions src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
import numpy as np

from ..models.bert.tokenization_bert import BasicTokenizer
from ..utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline
from ..utils import (
ExplicitEnum,
add_end_docstrings,
is_tf_available,
is_torch_available,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset


if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
import tensorflow as tf

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING

Expand Down Expand Up @@ -60,6 +66,9 @@ class AggregationStrategy(ExplicitEnum):
grouped_entities (`bool`, *optional*, defaults to `False`):
DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
same entity together in the predictions or not.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.
Expand All @@ -82,7 +91,7 @@ class AggregationStrategy(ExplicitEnum):
end up with different tags. Word entity will simply be the token with the maximum score.
""",
)
class TokenClassificationPipeline(Pipeline):
class TokenClassificationPipeline(ChunkPipeline):
"""
Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
examples](../task_summary#named-entity-recognition) for more information.
Expand Down Expand Up @@ -139,6 +148,7 @@ def _sanitize_parameters(
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
offset_mapping: Optional[List[Tuple[int, int]]] = None,
stride: Optional[int] = None,
):
preprocess_params = {}
if offset_mapping is not None:
Expand Down Expand Up @@ -174,11 +184,30 @@ def _sanitize_parameters(
):
raise ValueError(
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
' to `"simple"` or use a fast tokenizer.'
)
postprocess_params["aggregation_strategy"] = aggregation_strategy
if ignore_labels is not None:
postprocess_params["ignore_labels"] = ignore_labels
if stride is not None:
if aggregation_strategy == AggregationStrategy.NONE:
raise ValueError(
"`stride` was provided to process all the text but `aggregation_strategy="
f'"{aggregation_strategy}"`, please select another one instead.'
)
else:
if self.tokenizer.is_fast:
tokenizer_params = {
"return_overflowing_tokens": True,
"padding": True,
"stride": stride,
}
preprocess_params["tokenizer_params"] = tokenizer_params
else:
raise ValueError(
"`stride` was provided to process all the text but you're using a slow tokenizer."
" Please use a fast tokenizer."
)
return preprocess_params, {}, postprocess_params

def __call__(self, inputs: Union[str, List[str]], **kwargs):
Expand Down Expand Up @@ -213,29 +242,40 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):

return super().__call__(inputs, **kwargs)

def preprocess(self, sentence, offset_mapping=None):
def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
tokenizer_params = preprocess_params.pop("tokenizer_params", {})
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
model_inputs = self.tokenizer(
inputs = self.tokenizer(
sentence,
return_tensors=self.framework,
truncation=truncation,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
**tokenizer_params,
)
if offset_mapping:
model_inputs["offset_mapping"] = offset_mapping
inputs.pop("overflow_to_sample_mapping", None)
num_chunks = len(inputs["input_ids"])

model_inputs["sentence"] = sentence
for i in range(num_chunks):
if self.framework == "tf":
model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
else:
model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
if offset_mapping is not None:
model_inputs["offset_mapping"] = offset_mapping
model_inputs["sentence"] = sentence if i == 0 else None
model_inputs["is_last"] = i == num_chunks - 1

return model_inputs
yield model_inputs

def _forward(self, model_inputs):
# Forward
special_tokens_mask = model_inputs.pop("special_tokens_mask")
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
is_last = model_inputs.pop("is_last")
if self.framework == "tf":
logits = self.model(model_inputs.data)[0]
logits = self.model(**model_inputs)[0]
else:
output = self.model(**model_inputs)
logits = output["logits"] if isinstance(output, dict) else output[0]
Expand All @@ -245,38 +285,67 @@ def _forward(self, model_inputs):
"special_tokens_mask": special_tokens_mask,
"offset_mapping": offset_mapping,
"sentence": sentence,
"is_last": is_last,
**model_inputs,
}

def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
if ignore_labels is None:
ignore_labels = ["O"]
logits = model_outputs["logits"][0].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()

maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

if self.framework == "tf":
input_ids = input_ids.numpy()
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None

pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in ignore_labels
and entity.get("entity_group", None) not in ignore_labels
]
return entities
all_entities = []
for model_outputs in all_outputs:
logits = model_outputs["logits"][0].numpy()
sentence = all_outputs[0]["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = (
model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
)
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()

maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

if self.framework == "tf":
input_ids = input_ids.numpy()
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None

pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in ignore_labels
and entity.get("entity_group", None) not in ignore_labels
]
all_entities.extend(entities)
num_chunks = len(all_outputs)
if num_chunks > 1:
all_entities = self.aggregate_overlapping_entities(all_entities)
return all_entities

def aggregate_overlapping_entities(self, entities):
if len(entities) == 0:
return entities
entities = sorted(entities, key=lambda x: x["start"])
aggregated_entities = []
previous_entity = entities[0]
for entity in entities:
if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
current_length = entity["end"] - entity["start"]
previous_length = previous_entity["end"] - previous_entity["start"]
if current_length > previous_length:
previous_entity = entity
elif current_length == previous_length and entity["score"] > previous_entity["score"]:
previous_entity = entity
else:
aggregated_entities.append(previous_entity)
previous_entity = entity
aggregated_entities.append(previous_entity)
return aggregated_entities

def gather_pre_entities(
self,
Expand All @@ -290,9 +359,7 @@ def gather_pre_entities(
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
for idx, token_scores in enumerate(scores):
# Filter special_tokens, they should only occur
# at the sentence boundaries since we're not encoding pairs of
# sentences so we don't have to keep track of those.
# Filter special_tokens
if special_tokens_mask[idx]:
continue

Expand All @@ -317,7 +384,10 @@ def gather_pre_entities(
AggregationStrategy.AVERAGE,
AggregationStrategy.MAX,
}:
warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
warnings.warn(
"Tokenizer does not support real words, using fallback heuristic",
UserWarning,
)
is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]

if int(input_ids[idx]) == self.tokenizer.unk_token_id:
Expand Down
121 changes: 121 additions & 0 deletions tests/pipelines/test_pipelines_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,127 @@ def run_aggregation_strategy(self, model, tokenizer):
)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)

@slow
@require_torch
def test_chunking(self):
NER_MODEL = "elastic/distilbert-base-uncased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
tokenizer.model_max_length = 10
stride = 5
sentence = (
"Hugging Face, Inc. is a French company that develops tools for building applications using machine learning. "
"The company, based in New York City was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf."
)

token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="simple", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)

token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="first", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)

token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="max", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)

token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="average", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)

@require_torch
def test_chunking_fast(self):
# Note: We cannot run the test on "conflicts" on the chunking.
# The problem is that the model is random, and thus the results do heavily
# depend on the chunking, so we cannot expect "abcd" and "bcd" to find
# the same entities. We defer to slow tests for this.
pipe = pipeline(model="hf-internal-testing/tiny-bert-for-token-classification")
sentence = "The company, based in New York City was founded in 2016 by French entrepreneurs"

results = pipe(sentence, aggregation_strategy="first")
# This is what this random model gives on the full sentence
self.assertEqual(
nested_simplify(results),
[
# This is 2 actual tokens
{"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
{"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
],
)

# This will force the tokenizer to split after "city was".
pipe.tokenizer.model_max_length = 12
self.assertEqual(
pipe.tokenizer.decode(pipe.tokenizer.encode(sentence, truncation=True)),
"[CLS] the company, based in new york city was [SEP]",
)

stride = 4
results = pipe(sentence, aggregation_strategy="first", stride=stride)
self.assertEqual(
nested_simplify(results),
[
{"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
# This is an extra entity found by this random model, but at least both original
# entities are there
{"end": 58, "entity_group": "MISC", "score": 0.115, "start": 56, "word": "by"},
{"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
],
)

@require_torch
@slow
def test_spanish_bert(self):
Expand Down

0 comments on commit d62e7d8

Please sign in to comment.