Make max_length optional (guillaume-be#296)
* Made `max_length` an optional argument for generation methods and pipelines

* Updated changelog
guillaume-be authored Nov 15, 2022
1 parent 5d2b107 commit 05367b4
Showing 31 changed files with 216 additions and 147 deletions.
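
Not part of the commit itself: a minimal sketch of what the change means at a call site. `max_length` becomes an `Option<i64>`, so callers pass `Some(n)` for a hard cap or `None` for no cap (assumes `GenerateConfig`'s existing `Default` implementation):

```rust
use rust_bert::pipelines::generation_utils::GenerateConfig;

fn main() {
    // Old behavior, now spelled explicitly: cap generated sequences at 30 tokens.
    let capped = GenerateConfig {
        max_length: Some(30),
        ..Default::default()
    };

    // New with this commit: no length cap; generation stops on an EOS token.
    let uncapped = GenerateConfig {
        max_length: None,
        ..Default::default()
    };

    assert_eq!(capped.max_length, Some(30));
    assert_eq!(uncapped.max_length, None);
}
```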
10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -8,13 +8,15 @@ All notable changes to this project will be documented in this file. The format
 - Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
 
 ## Changed
-- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
-- (BREAKING) `merges_resource` now optional for all pipelines
-- Allow mixing local and remote resources in pipelines
-- Upgraded to `torch` 1.13 (via `tch` 0.9.0)
+- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).
+- (BREAKING) `merges_resource` now optional for all pipelines.
+- Allow mixing local and remote resources in pipelines.
+- Upgraded to `torch` 1.13 (via `tch` 0.9.0).
+- (BREAKING) Made the `max_length` argument for generation methods and pipelines optional.
 
 ## Fixed
 - Fixed configuration check for RoBERTa models for sentence classification.
+- Fixed a bug causing the input prompt to be truncated for text generation if the prompt length was longer than `max_length`
 
 ## [0.18.0] - 2022-07-24
 ## Added
2 changes: 1 addition & 1 deletion benches/generation_benchmark.rs
@@ -21,7 +21,7 @@ fn create_text_generation_model() -> TextGenerationModel {
             Gpt2MergesResources::GPT2,
         ))),
         min_length: 0,
-        max_length: 30,
+        max_length: Some(30),
         do_sample: true,
         early_stopping: false,
         num_beams: 5,
2 changes: 1 addition & 1 deletion examples/generation_gpt2.rs
@@ -19,7 +19,7 @@ fn main() -> anyhow::Result<()> {
     // Set-up model
     let generate_config = TextGenerationConfig {
         model_type: ModelType::GPT2,
-        max_length: 30,
+        max_length: Some(30),
         do_sample: false,
         num_beams: 1,
         temperature: 1.0,
2 changes: 1 addition & 1 deletion examples/generation_gpt_neo.rs
@@ -43,7 +43,7 @@ fn main() -> anyhow::Result<()> {
         vocab_resource,
         merges_resource: Some(merges_resource),
         min_length: 10,
-        max_length: 32,
+        max_length: Some(32),
         do_sample: false,
         early_stopping: true,
         num_beams: 4,
2 changes: 1 addition & 1 deletion examples/generation_reformer.rs
@@ -40,7 +40,7 @@ fn main() -> anyhow::Result<()> {
         vocab_resource,
         merges_resource: None,
         min_length: 100,
-        max_length: 100,
+        max_length: Some(100),
         do_sample: true,
         early_stopping: false,
         num_beams: 3,
2 changes: 1 addition & 1 deletion examples/generation_xlnet.rs
@@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
         config_resource,
         vocab_resource,
         merges_resource: None,
-        max_length: 32,
+        max_length: Some(32),
         do_sample: false,
         num_beams: 3,
         temperature: 1.0,
2 changes: 1 addition & 1 deletion examples/summarization_bart.rs
@@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
         num_beams: 1,
         length_penalty: 1.0,
         min_length: 56,
-        max_length: 142,
+        max_length: Some(142),
         device: Device::Cpu,
         ..Default::default()
     };
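
The same shape applies to pipeline configs; a hedged sketch (assuming `SummarizationConfig`'s `Default` implementation, which this diff does not show) of opting out of the length cap entirely:

```rust
use rust_bert::pipelines::summarization::SummarizationConfig;

fn main() {
    // `None` now expresses "no cap on summary length" instead of a large sentinel.
    let config = SummarizationConfig {
        min_length: 56,
        max_length: None,
        ..Default::default()
    };
    assert_eq!(config.max_length, None);
}
```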
16 changes: 10 additions & 6 deletions src/bart/bart_model.rs
@@ -1055,7 +1055,7 @@ impl BartGenerator {
     /// # let weights_path = &home.as_path().join("model.ot");
     /// let device = Device::cuda_if_available();
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
@@ -1183,16 +1183,18 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
         &self,
         scores: &mut Tensor,
         current_length: i64,
-        max_length: i64,
+        max_length: Option<i64>,
         forced_bos_token_id: Option<i64>,
     ) {
         if current_length == 1 {
             self.force_token_id_generation(
                 scores,
                 &[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
             );
-        } else if current_length == max_length - 1 {
-            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+        } else if let Some(max_length) = max_length {
+            if current_length == max_length - 1 {
+                self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+            }
         }
     }
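
The nested `if let` above is the behavioral core of the change: the forced-EOS branch only fires when a cap is present. A self-contained sketch of just that predicate (the helper name is hypothetical, not rust-bert API):

```rust
// Hypothetical helper isolating the predicate implemented by the nested `if let`.
fn should_force_eos(current_length: i64, max_length: Option<i64>) -> bool {
    match max_length {
        // One step before the cap, the generator forces an EOS token.
        Some(max_length) => current_length == max_length - 1,
        // With no cap, EOS is never forced by length.
        None => false,
    }
}

fn main() {
    assert!(should_force_eos(29, Some(30)));
    assert!(!should_force_eos(29, None));
}
```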

@@ -1231,15 +1233,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
     fn encode_prompt_text<S>(
         &self,
         prompt_text: &[S],
-        max_len: i64,
+        max_len: Option<i64>,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
         S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
             prompt_text,
-            max_len as usize,
+            max_len
+                .map(|max_len| max_len as usize)
+                .unwrap_or(usize::MAX),
             &TruncationStrategy::LongestFirst,
             0,
         );
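
This hunk relates to the prompt-truncation fix noted in the changelog: with no cap, the tokenizer's length limit becomes `usize::MAX`, so prompts are no longer cut down to `max_length`. A standalone illustration of the `Option`-to-`usize` mapping (nothing below is rust-bert API):

```rust
fn main() {
    // `None` becomes usize::MAX, which the tokenizer treats as "no truncation".
    let to_cap = |max_len: Option<i64>| max_len.map(|max_len| max_len as usize).unwrap_or(usize::MAX);

    assert_eq!(to_cap(Some(128)), 128);
    assert_eq!(to_cap(None), usize::MAX);
}
```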
2 changes: 1 addition & 1 deletion src/gpt2/gpt2_model.rs
@@ -695,7 +695,7 @@ impl GPT2Generator {
     /// use rust_bert::pipelines::generation_utils::GenerateConfig;
     ///
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
2 changes: 1 addition & 1 deletion src/gpt_neo/gpt_neo_model.rs
@@ -670,7 +670,7 @@ impl GptNeoGenerator {
     /// use rust_bert::pipelines::generation_utils::GenerateConfig;
     ///
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
16 changes: 10 additions & 6 deletions src/m2m_100/m2m_100_model.rs
@@ -604,7 +604,7 @@ impl M2M100Generator {
     /// # let weights_path = &home.as_path().join("model.ot");
     /// let device = Device::cuda_if_available();
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
@@ -734,13 +734,15 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
         &self,
         scores: &mut Tensor,
         current_length: i64,
-        max_length: i64,
+        max_length: Option<i64>,
         forced_bos_token_id: Option<i64>,
     ) {
         if current_length == 1 {
             self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
-        } else if current_length == max_length - 1 {
-            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+        } else if let Some(max_length) = max_length {
+            if current_length == max_length - 1 {
+                self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+            }
         }
     }

@@ -779,15 +781,17 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
     fn encode_prompt_text<S>(
         &self,
         prompt_text: &[S],
-        max_len: i64,
+        max_len: Option<i64>,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
         S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
             prompt_text,
-            max_len as usize,
+            max_len
+                .map(|max_len| max_len as usize)
+                .unwrap_or(usize::MAX),
             &TruncationStrategy::LongestFirst,
             0,
         );
16 changes: 10 additions & 6 deletions src/marian/marian_model.rs
@@ -824,7 +824,7 @@ impl MarianGenerator {
     /// # let weights_path = &home.as_path().join("model.ot");
     /// let device = Device::cuda_if_available();
     /// let generate_config = GenerateConfig {
-    ///     max_length: 512,
+    ///     max_length: Some(512),
     ///     do_sample: true,
     ///     num_beams: 6,
     ///     temperature: 1.0,
@@ -956,7 +956,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
         &self,
         scores: &mut Tensor,
         current_length: i64,
-        max_length: i64,
+        max_length: Option<i64>,
         _forced_bos_token_id: Option<i64>,
     ) {
         let _ = scores.index_fill_(
@@ -966,8 +966,10 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
                 .to_device(scores.device()),
             f64::NEG_INFINITY,
         );
-        if current_length == max_length - 1 {
-            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+        if let Some(max_length) = max_length {
+            if current_length == max_length - 1 {
+                self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+            }
         }
     }

@@ -1006,15 +1008,17 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
     fn encode_prompt_text<S>(
         &self,
         prompt_text: &[S],
-        max_len: i64,
+        max_len: Option<i64>,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
         S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
             prompt_text,
-            max_len as usize,
+            max_len
+                .map(|max_len| max_len as usize)
+                .unwrap_or(usize::MAX),
             &TruncationStrategy::LongestFirst,
             0,
         );
16 changes: 10 additions & 6 deletions src/mbart/mbart_model.rs
@@ -862,7 +862,7 @@ impl MBartGenerator {
     /// # let weights_path = &home.as_path().join("model.ot");
     /// let device = Device::cuda_if_available();
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
@@ -983,13 +983,15 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
         &self,
         scores: &mut Tensor,
         current_length: i64,
-        max_length: i64,
+        max_length: Option<i64>,
         forced_bos_token_id: Option<i64>,
     ) {
         if current_length == 1 {
             self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
-        } else if current_length == max_length - 1 {
-            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+        } else if let Some(max_length) = max_length {
+            if current_length == max_length - 1 {
+                self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+            }
         }
     }

@@ -1028,15 +1030,17 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
     fn encode_prompt_text<S>(
         &self,
         prompt_text: &[S],
-        max_len: i64,
+        max_len: Option<i64>,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
         S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
             prompt_text,
-            max_len as usize,
+            max_len
+                .map(|max_len| max_len as usize)
+                .unwrap_or(usize::MAX),
             &TruncationStrategy::LongestFirst,
             0,
         );
2 changes: 1 addition & 1 deletion src/openai_gpt/openai_gpt_model.rs
@@ -457,7 +457,7 @@ impl OpenAIGenerator {
     /// use rust_bert::openai_gpt::OpenAIGenerator;
     /// use rust_bert::pipelines::generation_utils::GenerateConfig;
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
16 changes: 10 additions & 6 deletions src/pegasus/pegasus_model.rs
@@ -585,7 +585,7 @@ impl PegasusConditionalGenerator {
     /// # let weights_path = &home.as_path().join("model.ot");
     /// let device = Device::cuda_if_available();
     /// let generate_config = GenerateConfig {
-    ///     max_length: 30,
+    ///     max_length: Some(30),
     ///     do_sample: true,
     ///     num_beams: 5,
     ///     temperature: 1.1,
@@ -710,11 +710,13 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
         &self,
         scores: &mut Tensor,
         current_length: i64,
-        max_length: i64,
+        max_length: Option<i64>,
         _forced_bos_token_id: Option<i64>,
     ) {
-        if current_length == max_length - 1 {
-            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+        if let Some(max_length) = max_length {
+            if current_length == max_length - 1 {
+                self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+            }
         }
     }

@@ -753,15 +755,17 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
     fn encode_prompt_text<S>(
         &self,
         prompt_text: &[S],
-        max_len: i64,
+        max_len: Option<i64>,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
         S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
             prompt_text,
-            max_len as usize,
+            max_len
+                .map(|max_len| max_len as usize)
+                .unwrap_or(usize::MAX),
             &TruncationStrategy::LongestFirst,
             0,
         );