add GenerationConfig and ut (mindspore-lab#484)
Geaming2002 authored May 7, 2023
1 parent c336e96 commit 1b079d4
Showing 5 changed files with 237 additions and 0 deletions.
1 change: 1 addition & 0 deletions mindnlp/modules/__init__.py
@@ -26,6 +26,7 @@
from mindnlp.modules.crf import CRF
from mindnlp.modules.loss import RDropLoss, CMRC2018Loss
from mindnlp.modules.rnns import *
from mindnlp.modules.generation import *

if less_min_pynative_first:
    from mindnlp._legacy.nn.transformer import Transformer, TransformerDecoder, TransformerEncoder, \
18 changes: 18 additions & 0 deletions mindnlp/modules/generation/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Generation
"""
from .generation_config import GenerationConfig
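
With this export (together with the wildcard import added to mindnlp/modules/__init__.py above), the class should be importable from either package level — a quick illustrative check:

    from mindnlp.modules import GenerationConfig
    from mindnlp.modules.generation import GenerationConfig  # equivalent, more explicit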
174 changes: 174 additions & 0 deletions mindnlp/modules/generation/generation_config.py
@@ -0,0 +1,174 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint:disable=R0902
"""
GenerationConfig
"""

from typing import Dict, Any
from mindnlp.abc import PreTrainedConfig

class GenerationConfig:
    """
    Class that holds a configuration for a generation task.
    """
    def __init__(self, **kwargs):
        # Parameters that control the length of the output
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        self.min_length = kwargs.pop("min_length", 0)
        self.min_new_tokens = kwargs.pop("min_new_tokens", None)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.max_time = kwargs.pop("max_time", None)

        # Parameters that control the generation strategy used
        self.do_sample = kwargs.pop("do_sample", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.penalty_alpha = kwargs.pop("penalty_alpha", None)
        self.use_cache = kwargs.pop("use_cache", True)

        # Parameters for manipulation of the model output logits
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.typical_p = kwargs.pop("typical_p", 1.0)
        self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
        self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.force_words_ids = kwargs.pop("force_words_ids", None)
        self.renormalize_logits = kwargs.pop("renormalize_logits", False)
        self.constraints = kwargs.pop("constraints", None)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
        self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
        self.suppress_tokens = kwargs.pop("suppress_tokens", None)
        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)

        # Special tokens that can be used at generation time
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Generation parameters exclusive to encoder-decoder models
        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # From model config
        self._from_model_config = kwargs.pop("from_model_config", False)

    def set_from_model_config(self, value: bool):
        """set _from_model_config"""
        assert isinstance(value, bool), "value must be of type bool"
        self._from_model_config = value

    def to_dict(self) -> Dict[str, Any]:
        """Serializes this instance to a plain Python dictionary; `from_model_config` below relies on it."""
        return dict(self.__dict__)

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Return the kwargs that were not consumed, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Those arguments may be passed along for our internal telemetry.
        # We remove them so they don't appear in `return_unused_kwargs`.
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        # The commit hash might have been updated in `config_dict`; we don't want the kwargs to erase that update.
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        config = cls(**config_dict)
        unused_kwargs = config.update(**kwargs)

        # logger.info(f"Generate config {config}")
        if return_unused_kwargs:
            return config, unused_kwargs
        return config
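
    # Illustrative example of `from_dict` (hypothetical values):
    #     config, unused = GenerationConfig.from_dict(
    #         {"max_new_tokens": 32}, foo="bar", return_unused_kwargs=True)
    #     # config.max_new_tokens == 32; unused == {"foo": "bar"}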

    @classmethod
    def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a [`PreTrainedConfig`]. This function is useful to convert legacy
        [`PreTrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].

        Args:
            model_config (`PreTrainedConfig`):
                The model config that will be used to instantiate the generation config.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
        """
        config_dict = model_config.to_dict()
        config = cls.from_dict(config_dict, return_unused_kwargs=False)

        # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
        # generation config.
        for decoder_name in ("decoder", "generator"):
            if decoder_name in config_dict:
                default_generation_config = GenerationConfig()
                decoder_config = config_dict[decoder_name]
                for attr in config.to_dict().keys():
                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
                        setattr(config, attr, decoder_config[attr])

        config.set_from_model_config(True)
        return config
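
A minimal end-to-end sketch of the new API (illustrative only; it assumes `T5Config`, the model config also used in the unit tests below):

    from mindnlp.models.t5 import T5Config
    from mindnlp.modules import GenerationConfig

    # Build a config directly from keyword arguments
    gen_config = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)

    # Or derive one from an existing model config
    gen_config = GenerationConfig.from_model_config(T5Config())

    # Update fields in place; keys that match no attribute are returned unused
    unused = gen_config.update(num_beams=4, not_a_real_option=1)
    assert gen_config.num_beams == 4
    assert unused == {"not_a_real_option": 1}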
Empty file.
44 changes: 44 additions & 0 deletions tests/ut/modules/generation/test_generation_config.py
@@ -0,0 +1,44 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test GenerationConfig
"""

import unittest
from mindnlp.modules import GenerationConfig
from mindnlp.models.t5 import T5Config

class TestGenerationConfig(unittest.TestCase):
    r"""
    Test module.generation GenerationConfig
    """
    def test_generation_config_from_model_config(self):
        """test GenerationConfig.from_model_config()"""
        config = T5Config()
        generation_config = GenerationConfig.from_model_config(config)
        assert config.eos_token_id == generation_config.eos_token_id

    def test_generation_config_from_dict(self):
        """test GenerationConfig.from_dict()"""
        config_dict = T5Config().__dict__
        generation_config = GenerationConfig.from_dict(config_dict)
        assert config_dict['eos_token_id'] == generation_config.eos_token_id

    def test_generation_config_update(self):
        """test GenerationConfig.update()"""
        config = T5Config()
        generation_config = GenerationConfig.from_model_config(config)
        generation_config.update(eos_token_id=666)
        assert generation_config.eos_token_id == 666
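
These tests should run under the repository's standard unit-test setup; assuming pytest is the runner, something like:

    pytest tests/ut/modules/generation/test_generation_config.py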
