Skip to content

Commit

Permalink
Including DeBertaV2 for zero-shot (guillaume-be#418)
Browse files Browse the repository at this point in the history
* Including DeBertaV2 for zero-shot

* Ignore Clippy warning

---------

Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
  • Loading branch information
jondot and guillaume-be authored Oct 21, 2023
1 parent ce48aa0 commit dc99a30
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification;
use crate::deberta_v2::DebertaV2ForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
Expand Down Expand Up @@ -217,11 +218,14 @@ impl Default for ZeroShotClassificationConfig {
/// The models are using a classification architecture that should be trained on Natural Language Inference.
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
/// to contradiction and the last logit corresponding to entailment.
#[allow(clippy::large_enum_variant)]
pub enum ZeroShotClassificationOption {
/// Bart for Sequence Classification
Bart(BartForSequenceClassification),
/// DeBERTa for Sequence Classification
Deberta(DebertaForSequenceClassification),
/// DeBERTaV2 for Sequence Classification
DebertaV2(DebertaV2ForSequenceClassification),
/// Bert for Sequence Classification
Bert(BertForSequenceClassification),
/// DistilBert for Sequence Classification
Expand Down Expand Up @@ -288,6 +292,17 @@ impl ZeroShotClassificationOption {
))
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = model_config {
Ok(Self::DebertaV2(
DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTaV2!".to_string(),
))
}
}
ModelType::Bert => {
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Bert(
Expand Down Expand Up @@ -413,6 +428,7 @@ impl ZeroShotClassificationOption {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::Deberta(_) => ModelType::Deberta,
Self::DebertaV2(_) => ModelType::DebertaV2,
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
Expand Down Expand Up @@ -474,6 +490,19 @@ impl ZeroShotClassificationOption {
.expect("Error in DeBERTa forward_t")
.logits
}
Self::DebertaV2(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in DeBERTaV2 forward_t")
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
Expand Down

0 comments on commit dc99a30

Please sign in to comment.