Skip to content

Commit

Permalink
upload generation files (mindspore-lab#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
Geaming2002 authored May 29, 2023
1 parent 3f6723c commit d5a7637
Show file tree
Hide file tree
Showing 12 changed files with 1,038 additions and 7 deletions.
3 changes: 2 additions & 1 deletion mindnlp/abc/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
"""

from .pretrained_config import PreTrainedConfig
from .generation_config import GenerationConfig

__all__ = ['PreTrainedConfig']
__all__ = ['PreTrainedConfig', 'GenerationConfig']
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""

from typing import Dict, Any
from mindnlp.abc import PreTrainedConfig

class GenerationConfig:
"""
Expand Down Expand Up @@ -145,7 +144,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
return config

@classmethod
def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig":
def from_model_config(cls, model_config) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
[`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
Expand Down
776 changes: 776 additions & 0 deletions mindnlp/abc/mixins/generation_mixin.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion mindnlp/abc/models/pretrained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from mindnlp.configs import HF_MODEL_URL_BASE
from mindnlp.utils.download import cached_path
from mindnlp.abc.configs import PreTrainedConfig
from mindnlp.abc.configs import PreTrainedConfig, GenerationConfig
from mindnlp.abc.mixins import CellUtilMixin, GenerationMixin

_init_weights = True
Expand All @@ -42,11 +42,13 @@ class PreTrainedModel(nn.Cell, CellUtilMixin, GenerationMixin):
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
main_input_name = "input_ids"

def __init__(self, config):
super().__init__(config)
# Save config in model
self.config = config
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

def post_init(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@
"""
Generation
"""
from .generation_config import GenerationConfig
from .beam_constraints import *
from .beam_search import *
from .logits_process import *
from .stopping_criteria import *
from .utils import *
28 changes: 28 additions & 0 deletions mindnlp/generation/beam_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Beam constraints
"""

class DisjunctiveConstraint:
    """Stub for the disjunctive beam constraint; no behaviour is implemented yet."""

    def __init__(self) -> None:
        """Create an instance; the constructor takes no arguments and stores nothing."""

class PhrasalConstraint:
    """Stub for the phrasal beam constraint; no behaviour is implemented yet."""

    def __init__(self) -> None:
        """Create an instance; the constructor takes no arguments and stores nothing."""
33 changes: 33 additions & 0 deletions mindnlp/generation/beam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Beam search
"""

class BeamScorer:
    """Stub for the beam scorer; no behaviour is implemented yet."""

    def __init__(self) -> None:
        """Create an instance; the constructor takes no arguments and stores nothing."""

class BeamSearchScorer:
    """Stub for the beam-search scorer; no behaviour is implemented yet."""

    def __init__(self) -> None:
        """Create an instance; the constructor takes no arguments and stores nothing."""

class ConstrainedBeamSearchScorer:
    """Stub for the constrained beam-search scorer; no behaviour is implemented yet."""

    def __init__(self) -> None:
        """Create an instance; the constructor takes no arguments and stores nothing."""
43 changes: 43 additions & 0 deletions mindnlp/generation/logits_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=W0613

"""
Logits process
"""

import inspect
import mindspore

class LogitsProcessorList(list):
    """
    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
    """

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> mindspore.Tensor:
        """
        Run every processor in list order, feeding each one the scores returned by the previous one.

        Args:
            input_ids: Passed unchanged as the first argument of every processor.
            scores: Initial scores; replaced by each processor's return value in turn.
            kwargs: Extra keyword arguments. They are forwarded (all of them) only to processors whose
                `__call__` declares more than two parameters.

        Returns:
            The scores returned by the last processor, or `scores` unchanged when the list is empty.

        Raises:
            ValueError: If a processor declares an extra parameter that has no default value and that
                parameter is missing from `kwargs`.
        """
        for processor in self:
            # `processor.__call__` is inspected as a bound method, so `self` is not part of the parameters.
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) <= 2:
                scores = processor(input_ids, scores)
                continue

            # Only parameters that the caller really has to supply are treated as required:
            # `*args` / `**kwargs` catch-alls and parameters with a default value may be omitted.
            missing = [
                param.name
                for param in list(function_args.values())[2:]
                if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
                and param.default is inspect.Parameter.empty
                and param.name not in kwargs
            ]
            if missing:
                raise ValueError(
                    f"Make sure that all the required parameters: {list(function_args.keys())} for "
                    f"{processor.__class__} are passed to the logits processor."
                )
            scores = processor(input_ids, scores, **kwargs)
        return scores
146 changes: 146 additions & 0 deletions mindnlp/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Stopping criteria
"""
import time
import warnings
from copy import deepcopy
from typing import Optional

import mindspore



# Documentation template describing the arguments and return value of a stopping-criteria `__call__`.
# The tensor types name `mindspore.Tensor`, matching the annotations used by the callables in this module.
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`mindspore.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        scores (`mindspore.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax.
        kwargs:
            Additional stopping criteria specific kwargs.
    Return:
        `bool`. `False` indicates we should continue, `True` indicates we should stop.
"""


class StoppingCriteria:
    """Base class for generation stopping criteria; subclasses are expected to override `__call__`."""

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        """Always raise `NotImplementedError`: the base class implements no criterion of its own."""
        message = "StoppingCriteria needs to be subclassed"
        raise NotImplementedError(message)


class MaxLengthCriteria(StoppingCriteria):
    """
    Stopping criterion that fires once the sequence length (the last dimension of `input_ids`) reaches
    `max_length`. Keep in mind for decoder-only type of transformers, this will include the initial
    prompted tokens.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        """Return `True` when `input_ids.shape[-1]` is greater than or equal to `max_length`."""
        # `scores` and `kwargs` are accepted for interface compatibility but are not consulted.
        current_length = input_ids.shape[-1]
        return current_length >= self.max_length


class MaxNewTokensCriteria(StoppingCriteria):
    """
    Deprecated stopping criterion expressed as a budget of newly generated tokens on top of an initial
    length. Internally it stops at `start_length + max_new_tokens`, i.e. it behaves like
    `MaxLengthCriteria` with that combined length. Constructing it emits a `FutureWarning`.

    Args:
        start_length (`int`):
            The number of initial tokens.
        max_new_tokens (`int`):
            The maximum number of tokens to generate.
    """

    def __init__(self, start_length: int, max_new_tokens: int):
        deprecation_notice = (
            "The class `MaxNewTokensCriteria` is deprecated. "
            f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
            "with `max_length = start_length + max_new_tokens` instead."
        )
        warnings.warn(deprecation_notice, FutureWarning)
        self.start_length = start_length
        self.max_new_tokens = max_new_tokens
        # Total length at which generation stops.
        self.max_length = start_length + max_new_tokens

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        """Return `True` when `input_ids.shape[-1]` is greater than or equal to the combined length."""
        # `scores` and `kwargs` are accepted for interface compatibility but are not consulted.
        current_length = input_ids.shape[-1]
        return current_length >= self.max_length


class MaxTimeCriteria(StoppingCriteria):
    """
    Stopping criterion based on wall-clock time: it fires once more than `max_time` seconds have elapsed
    since the reference timestamp. By default the reference is the moment the instance is created; pass
    `initial_timestamp` to use a different starting point.

    Args:
        max_time (`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (`float`, *optional*, defaults to `time.time()`):
            The start of the generation allowed time.
    """

    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        self.max_time = max_time
        if initial_timestamp is None:
            # No explicit start given: start counting now.
            initial_timestamp = time.time()
        self.initial_timestamp = initial_timestamp

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        """Return `True` when the elapsed time strictly exceeds `max_time`."""
        # None of the arguments are consulted; only the clock matters.
        elapsed = time.time() - self.initial_timestamp
        return elapsed > self.max_time


class StoppingCriteriaList(list):
    """A list of stopping criteria that can itself be called as a single combined criterion."""

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        """
        Evaluate the contained criteria in order and report whether any of them asks to stop.

        Args:
            input_ids: Passed unchanged to every criterion.
            scores: Passed unchanged to every criterion.
            kwargs: Extra keyword arguments, forwarded unchanged to every criterion.

        Returns:
            `True` as soon as one criterion returns a truthy value (later criteria are then not
            evaluated), `False` otherwise, including when the list is empty.
        """
        # Forward the keyword arguments instead of silently dropping them.
        return any(criteria(input_ids, scores, **kwargs) for criteria in self)

    @property
    def max_length(self) -> Optional[int]:
        """
        Return the `max_length` of the first `MaxLengthCriteria` or `MaxNewTokensCriteria` found in the
        list, or `None` when the list contains neither.
        """
        for criteria in self:
            if isinstance(criteria, (MaxLengthCriteria, MaxNewTokensCriteria)):
                return criteria.max_length
        return None


def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    """
    Return a deep copy of `stopping_criteria` that is reconciled with `max_length`.

    When the list reports no `max_length`, a `MaxLengthCriteria(max_length=max_length)` is appended to the
    copy. When it reports one that differs from `max_length`, a `UserWarning` is emitted and the copy is
    returned without modification. The list passed in is never mutated.
    """
    configured_max_length = stopping_criteria.max_length
    validated = deepcopy(stopping_criteria)

    if configured_max_length is None:
        validated.append(MaxLengthCriteria(max_length=max_length))
        return validated

    if configured_max_length != max_length:
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
    return validated
File renamed without changes.
1 change: 0 additions & 1 deletion mindnlp/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from mindnlp.modules.crf import CRF
from mindnlp.modules.loss import RDropLoss, CMRC2018Loss
from mindnlp.modules.rnns import *
from mindnlp.modules.generation import *
from mindnlp.modules.accumulator import *

if less_min_pynative_first:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py


version = '0.1.1'
cur_dir = os.path.dirname(os.path.realpath(__file__))
pkg_dir = os.path.join(cur_dir, 'build')


def clean():
# pylint: disable=unused-argument
def readonly_handler(func, path, execinfo):
Expand Down

0 comments on commit d5a7637

Please sign in to comment.