Skip to content

Commit

Permalink
Update and extend tests for spacy v3.5.4 (#392)
Browse files Browse the repository at this point in the history
* Update test for compatibility with spacy v3.5.4+

* Extend tests for sourced/replaced listeners
  • Loading branch information
adrianeboyd authored Jun 29, 2023
1 parent d3b532a commit 2c4c845
Showing 1 changed file with 123 additions and 1 deletion.
124 changes: 123 additions & 1 deletion spacy_transformers/tests/test_pipeline_component.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from packaging.version import Version
import torch
import spacy
from spacy.language import Language
from spacy.training.example import Example
from spacy.util import make_tempdir
Expand Down Expand Up @@ -195,7 +197,6 @@ def test_transformer_pipeline_tagger_senter_listener():
tagger.add_label(tag)

# Check that the Transformer component finds its listeners
assert transformer.listeners == []
optimizer = nlp.initialize(lambda: train_examples)
assert tagger_trf in transformer.listeners

Expand Down Expand Up @@ -492,3 +493,124 @@ def test_no_update_listener_in_predict():

transformer.predict(docs)
assert listener._backprop is not None


@pytest.mark.skipif(
    Version(spacy.__version__) < Version("3.5.4"), reason="Bug fixed in spaCy v3.5.4"
)
def test_source_replace_listeners():
    """Test that a pipeline with a transformer+tagger+senter and some replaced
    listeners runs and trains properly.

    The test proceeds in three stages:

    1. Build a transformer + tagger + senter pipeline from an inline config,
       initialize it, run two update steps and save it to disk.
    2. Build a second pipeline that sources all three components from the
       saved one, sets ``replace_listeners`` on the tagger and senter, and
       adds a fresh ``ner`` component listening to the transformer. Check
       the resulting config and the transformer's listening components.
    3. Save and reload the second pipeline and repeat the same checks on
       the reloaded pipeline.
    """
    # The config is parsed from this inline string so that the test is
    # self-contained and the pipeline under test is the one shown here.
    orig_config_str = """
    [nlp]
    lang = "en"
    pipeline = ["transformer","tagger","senter"]

    [components]

    [components.senter]
    factory = "senter"

    [components.senter.model]
    @architectures = "spacy.Tagger.v1"
    nO = null

    [components.senter.model.tok2vec]
    @architectures = "spacy-transformers.TransformerListener.v1"
    grad_factor = 1.0
    upstream = "transformer"

    [components.senter.model.tok2vec.pooling]
    @layers = "reduce_mean.v1"

    [components.tagger]
    factory = "tagger"

    [components.tagger.model]
    @architectures = "spacy.Tagger.v1"
    nO = null

    [components.tagger.model.tok2vec]
    @architectures = "spacy-transformers.TransformerListener.v1"
    grad_factor = 1.0
    upstream = "transformer"

    [components.tagger.model.tok2vec.pooling]
    @layers = "reduce_mean.v1"

    [components.transformer]
    factory = "transformer"

    [components.transformer.model]
    @architectures = "spacy-transformers.TransformerModel.v3"
    name = "distilbert-base-uncased"
    """
    orig_config = Config().from_str(orig_config_str)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.pipe_names == ["transformer", "tagger", "senter"]

    tagger = nlp.get_pipe("tagger")
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
        for tag in t[1]["tags"]:
            tagger.add_label(tag)
    optimizer = nlp.initialize(lambda: train_examples)
    assert nlp.get_pipe("transformer").listening_components == ["tagger", "senter"]
    for _ in range(2):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    # Architecture name expected in the tagger/senter config once their
    # listeners have been replaced.
    replaced_arch = "spacy-transformers.Tok2VecTransformer.v3"

    with make_tempdir() as dir_path:
        nlp.to_disk(dir_path)
        base_model = str(dir_path)
        new_config = {
            "nlp": {
                "lang": "en",
                "pipeline": ["transformer", "tagger", "senter", "ner"],
            },
            "components": {
                "transformer": {"source": base_model},
                "tagger": {
                    "source": base_model,
                    "replace_listeners": ["model.tok2vec"],
                },
                "senter": {
                    "source": base_model,
                    "replace_listeners": ["model.tok2vec"],
                },
                "ner": {
                    "factory": "ner",
                    "model": {
                        "@architectures": "spacy.TransitionBasedParser.v2",
                        "state_type": "ner",
                        "tok2vec": {
                            "@architectures": "spacy-transformers.TransformerListener.v1",
                            "grad_factor": 1.0,
                            "upstream": "transformer",
                            "pooling": {"@layers": "reduce_mean.v1"},
                        },
                    },
                },
            },
        }
        # Must run while the temp dir still exists: the components are
        # sourced from the pipeline saved at ``base_model``.
        new_nlp = util.load_model_from_config(new_config, auto_fill=True)
        for component in ("tagger", "senter"):
            assert (
                new_nlp.config["components"][component]["model"]["tok2vec"][
                    "@architectures"
                ]
                == replaced_arch
            )
        # Only the newly added ner component is still listening.
        assert new_nlp.get_pipe("transformer").listening_components == ["ner"]

        with make_tempdir() as new_dir_path:
            new_nlp.to_disk(new_dir_path)
            new_nlp_re = spacy.load(new_dir_path)
            # Check the *reloaded* pipeline, not the in-memory one, so the
            # save/load round trip is actually verified.
            for component in ("tagger", "senter"):
                assert (
                    new_nlp_re.config["components"][component]["model"]["tok2vec"][
                        "@architectures"
                    ]
                    == replaced_arch
                )
            assert new_nlp_re.get_pipe("transformer").listening_components == [
                "ner"
            ]

0 comments on commit 2c4c845

Please sign in to comment.