Skip to content

Commit

Permalink
Add DeBERTa head models (huggingface#9691)
Browse files Browse the repository at this point in the history
* Add DebertaForMaskedLM, DebertaForTokenClassification, DebertaForQuestionAnswering

* Add docs and fix quality

* Fix Deberta not having pooler
  • Loading branch information
NielsRogge authored Jan 20, 2021
1 parent a7b62fe commit d1370d2
Show file tree
Hide file tree
Showing 7 changed files with 686 additions and 253 deletions.
21 changes: 21 additions & 0 deletions docs/source/model_doc/deberta.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,29 @@ DebertaPreTrainedModel
:members:


DebertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForMaskedLM
:members:


DebertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForSequenceClassification
:members:


DebertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForTokenClassification
:members:


DebertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DebertaForQuestionAnswering
:members:
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,10 @@
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
]
)
_import_structure["models.distilbert"].extend(
Expand Down Expand Up @@ -1527,7 +1530,10 @@
)
from .models.deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
DebertaPreTrainedModel,
)
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@
CamembertModel,
)
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
from ..deberta.modeling_deberta import DebertaForSequenceClassification, DebertaModel
from ..deberta.modeling_deberta import (
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
)
from ..distilbert.modeling_distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
Expand Down Expand Up @@ -378,6 +384,7 @@
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
]
)

Expand Down Expand Up @@ -426,6 +433,7 @@
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
]
)

Expand Down Expand Up @@ -503,6 +511,7 @@
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
(MPNetConfig, MPNetForQuestionAnswering),
(DebertaConfig, DebertaForQuestionAnswering),
]
)

Expand Down Expand Up @@ -533,6 +542,7 @@
(FlaubertConfig, FlaubertForTokenClassification),
(FunnelConfig, FunnelForTokenClassification),
(MPNetConfig, MPNetForTokenClassification),
(DebertaConfig, DebertaForTokenClassification),
]
)

Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/deberta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
]


Expand All @@ -42,7 +45,10 @@
if is_torch_available():
from .modeling_deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
DebertaPreTrainedModel,
)
Expand Down
Loading

0 comments on commit d1370d2

Please sign in to comment.