Skip to content

Commit

Permalink
- Fixed RoBERTa confic checks for sentence classification (guillaume-…
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be authored Sep 10, 2022
1 parent e4a2a10 commit 59c0e66
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ All notable changes to this project will be documented in this file. The format
## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)

## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.

## [0.18.0] - 2022-07-24
## Added
- Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net).
Expand Down
8 changes: 4 additions & 4 deletions src/pipelines/sequence_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,24 +283,24 @@ impl SequenceClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::Roberta(
RobertaForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Roberta!".to_string(),
"You can only supply a RobertaConfig for Roberta!".to_string(),
))
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::XLMRoberta(
RobertaForSequenceClassification::new(p, config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Roberta!".to_string(),
"You can only supply a RobertaConfig for Roberta!".to_string(),
))
}
}
Expand Down

0 comments on commit 59c0e66

Please sign in to comment.