forked from mindspore-lab/mindnlp
Commit
upload generation files (mindspore-lab#538)
1 parent 3f6723c · commit d5a7637
Showing 12 changed files with 1,038 additions and 7 deletions.
@@ -0,0 +1,28 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Beam constraints
"""

class DisjunctiveConstraint:
    """DisjunctiveConstraint"""
    def __init__(self) -> None:
        pass

class PhrasalConstraint:
    """PhrasalConstraint"""
    def __init__(self) -> None:
        pass
@@ -0,0 +1,33 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Beam search
"""

class BeamScorer:
    """BeamScorer"""
    def __init__(self) -> None:
        pass

class BeamSearchScorer:
    """BeamSearchScorer"""
    def __init__(self) -> None:
        pass

class ConstrainedBeamSearchScorer:
    """ConstrainedBeamSearchScorer"""
    def __init__(self) -> None:
        pass
@@ -0,0 +1,43 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=W0613

"""
Logits process
"""

import inspect
import mindspore

class LogitsProcessorList(list):
    """
    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
    """

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> mindspore.Tensor:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores
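A minimal usage sketch (not part of this commit): a toy processor that bans a single token id, applied through the `LogitsProcessorList` defined above. The `BanTokenProcessor` name and the tensor shapes are illustrative assumptions, not part of the mindnlp API.

import numpy as np
import mindspore

class BanTokenProcessor:
    """Illustrative processor: sets the score of one banned token id to -inf."""
    def __init__(self, banned_id: int):
        self.banned_id = banned_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        scores_np = scores.asnumpy()                  # work on a NumPy copy for simplicity
        scores_np[:, self.banned_id] = float("-inf")
        return mindspore.Tensor(scores_np)

processors = LogitsProcessorList([BanTokenProcessor(banned_id=0)])      # class defined above
input_ids = mindspore.Tensor(np.zeros((1, 4), dtype=np.int64))          # (batch_size, sequence_length)
scores = mindspore.Tensor(np.ones((1, 100), dtype=np.float32))          # (batch_size, vocab_size)
scores = processors(input_ids, scores)                                  # column 0 is now -inf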
@@ -0,0 +1,146 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Stopping criteria
"""
import time
import warnings
from copy import deepcopy
from typing import Optional

import mindspore


STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`mindspore.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        scores (`mindspore.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax.
        kwargs:
            Additional stopping criteria specific kwargs.
    Return:
        `bool`. `False` indicates we should continue, `True` indicates we should stop.
"""

class StoppingCriteria:
    """Abstract base class for all stopping criteria that can be applied during generation."""

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class MaxLengthCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
    in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        return input_ids.shape[-1] >= self.max_length


class MaxNewTokensCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the generated number of tokens exceeds `max_new_tokens`. Keep in
    mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is very
    close to `MaxLengthCriteria` but ignores the number of initial tokens.

    Args:
        start_length (`int`):
            The number of initial tokens.
        max_new_tokens (`int`):
            The maximum number of tokens to generate.
    """

    def __init__(self, start_length: int, max_new_tokens: int):
        warnings.warn(
            "The class `MaxNewTokensCriteria` is deprecated. "
            f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
            "with `max_length = start_length + max_new_tokens` instead.",
            FutureWarning,
        )
        self.start_length = start_length
        self.max_new_tokens = max_new_tokens
        self.max_length = start_length + max_new_tokens

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        return input_ids.shape[-1] >= self.max_length


class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
    time will start being counted when you initialize this class. You can override this by passing an
    `initial_timestamp`.

    Args:
        max_time (`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (`float`, *optional*, defaults to `time.time()`):
            The start of the generation allowed time.
    """

    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        self.max_time = max_time
        self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        return time.time() - self.initial_timestamp > self.max_time


class StoppingCriteriaList(list):
    """StoppingCriteriaList"""
    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        return any(criteria(input_ids, scores) for criteria in self)

    @property
    def max_length(self) -> Optional[int]:
        """return max length"""
        for stopping_criterium in self:
            if isinstance(stopping_criterium, MaxLengthCriteria):
                return stopping_criterium.max_length
            if isinstance(stopping_criterium, MaxNewTokensCriteria):
                return stopping_criterium.max_length
        return None


def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    """validate stopping criteria"""
    stopping_max_length = stopping_criteria.max_length
    new_stopping_criteria = deepcopy(stopping_criteria)
    if stopping_max_length is not None and stopping_max_length != max_length:
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
    elif stopping_max_length is None:
        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
    return new_stopping_criteria
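A minimal usage sketch (not part of this commit) combining the criteria above; the tensor shapes and the `max_length` value of 10 are illustrative assumptions.

import numpy as np
import mindspore

criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=10), MaxTimeCriteria(max_time=30.0)])
criteria = validate_stopping_criteria(criteria, max_length=10)    # warns only if the two max_length values disagree

input_ids = mindspore.Tensor(np.zeros((1, 12), dtype=np.int64))   # pretend 12 tokens have been generated
scores = mindspore.Tensor(np.ones((1, 100), dtype=np.float32))
print(criteria(input_ids, scores))                                # True: 12 >= max_length of 10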
File renamed without changes.