Skip to content

Commit

Permalink
Fix errors caused by MindSpore 2.0rc1 (mindspore-lab#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvyufeng authored Apr 4, 2023
1 parent d714d02 commit 6270fe8
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ encoder = RNNEncoder(embedding, lstm_layer)

# build head
head = nn.SequentialCell([
nn.Dropout(1 - dropout),
nn.Dropout(p=dropout),
nn.Sigmoid(),
nn.Dense(hidden_size * 2, output_size,
weight_init=HeUniform(math.sqrt(5)),
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/question_answer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ using MindNLP:
self.char_channel_size = char_channel_size
self.word_vocab = word_vocab
self.hidden_size = hidden_size
self.dropout = nn.Dropout(1 - dropout)
self.dropout = nn.Dropout(p=dropout)
self.init_embed = initializer(Uniform(0.001), [char_vocab_size, char_dim])
self.embed = Parameter(self.init_embed, name='embed')
Expand Down
4 changes: 2 additions & 2 deletions examples/question_answer/bidaf_squad_concise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
" self.char_channel_size = char_channel_size\n",
" self.word_vocab = word_vocab\n",
" self.hidden_size = hidden_size\n",
" self.dropout = nn.Dropout(1 - dropout)\n",
" self.dropout = nn.Dropout(p=dropout)\n",
" self.init_embed = initializer(Uniform(0.001), [char_vocab_size, char_dim])\n",
" self.embed = Parameter(self.init_embed, name='embed')\n",
"\n",
Expand Down Expand Up @@ -631,7 +631,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.7.5"
},
"orig_nbformat": 4
},
Expand Down
7 changes: 2 additions & 5 deletions mindnlp/_legacy/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
from mindspore.common import dtype as mstype
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops import constexpr
from packaging import version

MS_COMPATIBLE_VERSION = '1.10.1'

from mindnlp.utils import less_min_api_compatible
tensor_slice = ops.Slice()
cast_ = ops.Cast()
scalar_to_tensor_ = ops.ScalarToTensor()
Expand Down Expand Up @@ -86,7 +83,7 @@ def kl_div(inputs, target, reduction='none', log_target=False):

def split(x, size, axis=0):
"""inner split"""
if version.parse(mindspore.__version__) <= version.parse(MS_COMPATIBLE_VERSION):
if less_min_api_compatible:
num = int(x.shape[axis] / size)
return ops.split(x, axis, num)
return ops.split(x, split_size_or_sections=size, axis=axis)
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/models/albert/albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, config):
self.layer_norm = nn.LayerNorm(
(config.embedding_size,),
epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

def construct(self, input_ids, token_type_ids=None, position_ids=None):
seq_len = input_ids.shape[1]
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/models/bert/bert_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/models/gpt2/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def __init__(self, config):
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.dropout = nn.Dropout(p=classifier_dropout)
self.classifier = nn.Dense(config.hidden_size, config.num_labels)

def construct(self, input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None,
Expand Down
4 changes: 2 additions & 2 deletions mindnlp/models/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def __init__(self, config):

self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(config.summary_first_dropout)
self.first_dropout = nn.Dropout(p=config.summary_first_dropout)

self.last_dropout = Identity()
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)
self.last_dropout = nn.Dropout(p=config.summary_last_dropout)

def construct(self, hidden_states: Tensor, cls_index: Optional[Tensor] = None) -> Tensor:
if self.summary_type == "last":
Expand Down
2 changes: 1 addition & 1 deletion mindnlp/models/xlm/xlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def __init__(self, config):
self.num_labels = config.num_labels

self.transformer = XLMModel(config)
self.dropout = nn.Dropout(config.dropout)
self.dropout = nn.Dropout(p=config.dropout)
self.classifier = nn.Dense(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
Expand Down
15 changes: 11 additions & 4 deletions mindnlp/transforms/tokenizers/bert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
BertTokenizer
"""

from typing import Union
import numpy as np
from mindspore.dataset.transforms.transforms import PyTensorOperation
from mindspore.dataset.text.transforms import Implementation
from mindspore.dataset.text import Vocab as msVocab
from tokenizers.implementations import BertWordPieceTokenizer

from mindnlp.vocab import Vocab
class BertTokenizer(PyTensorOperation):
"""
Tokenizer used for Bert text process.
Expand Down Expand Up @@ -51,10 +53,15 @@ class BertTokenizer(PyTensorOperation):
"""

# @check_decode
def __init__(self, vocab, lower_case:bool = True, return_token = False):
def __init__(self, vocab: Union[msVocab, Vocab], lower_case:bool = True, return_token = False):
super().__init__()
self.tokenizer = BertWordPieceTokenizer(vocab=vocab.vocab(), lowercase=lower_case)
if isinstance(vocab, msVocab):
vocab_dict = vocab.vocab()
elif isinstance(vocab, Vocab):
vocab_dict = vocab.vocab
else:
raise ValueError(f'only support Vocab class from mindspore or mindnlp, but got {vocab}')
self.tokenizer = BertWordPieceTokenizer(vocab=vocab_dict, lowercase=lower_case)
self.return_token = return_token
self.implementation = Implementation.PY

Expand Down
17 changes: 14 additions & 3 deletions mindnlp/utils/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,19 @@

MIN_COMPATIBLE_VERSION = '1.8.1'
MAX_GRAPH_FIRST_VERSION = '1.12.0'
API_COMPATIBLE_VERSION = '1.10.1'

less_min_compatible = version.parse(mindspore.__version__) < version.parse(MIN_COMPATIBLE_VERSION)
less_min_pynative_first = version.parse(mindspore.__version__) <= version.parse(MAX_GRAPH_FIRST_VERSION)
MS_VERSION = mindspore.__version__
MS_VERSION = MS_VERSION.replace('rc', '')

__all__ = ['less_min_compatible', 'less_min_pynative_first']
less_min_minddata_compatible = version.parse(MS_VERSION) <= version.parse(MIN_COMPATIBLE_VERSION)
less_min_compatible = version.parse(MS_VERSION) < version.parse(MIN_COMPATIBLE_VERSION)
less_min_pynative_first = version.parse(MS_VERSION) <= version.parse(MAX_GRAPH_FIRST_VERSION)
less_min_api_compatible = version.parse(MS_VERSION) <= version.parse(API_COMPATIBLE_VERSION)

__all__ = [
'less_min_compatible',
'less_min_pynative_first',
'less_min_api_compatible',
'less_min_minddata_compatible'
]
5 changes: 4 additions & 1 deletion mindnlp/vocab/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def from_pretrained(cls, name="glove.6B.50d", root=DEFAULT_ROOT,

return vocab


@property
def vocab(self):
"""return vocab dict."""
return self._token_dict

pretrained_aliases = {
"glove.6B.50d": "https://download.mindspore.cn/toolkits/mindnlp/vocab/Glove/glove.6B.50d.txt",
Expand Down
11 changes: 8 additions & 3 deletions tests/ut/modules/loss/test_cmrc2018loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ def forward(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d):
loss = cmrc_loss(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d)
return loss

if jit:
forward = ms_jit(forward)
@ms_jit
def forward_jit(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d):
loss = cmrc_loss(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d)
return loss

loss = forward(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d)
if jit:
loss = forward_jit(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d)
else:
loss = forward(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d)
assert loss.shape == ()
16 changes: 12 additions & 4 deletions tests/ut/modules/loss/test_rdrop_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,23 @@ def test_loss(self, jit):
Test RDropLoss loss
"""
r_drop_loss = RDropLoss()
temp_p = Tensor(np.array([1., 0., 1.]), mindspore.float32)
temp_q = Tensor(np.array([0.2, 0.3, 1.1]), mindspore.float32)

@ms_jit
def forward_jit(temp_p, temp_q):
loss = r_drop_loss(temp_p, temp_q)
return loss

def forward(temp_p, temp_q):
loss = r_drop_loss(temp_p, temp_q)
return loss

temp_p = Tensor(np.array([1., 0., 1.]), mindspore.float32)
temp_q = Tensor(np.array([0.2, 0.3, 1.1]), mindspore.float32)


if jit:
forward = ms_jit(forward)
loss = forward(temp_p, temp_q)
loss = forward_jit(temp_p, temp_q)
else:
loss = forward(temp_p, temp_q)

assert np.allclose(loss.asnumpy(), np.array([0.10013707]), 1e-5, 1e-5)
9 changes: 3 additions & 6 deletions tests/ut/transforms/test_add_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
# ============================================================================
"""Test the AddToken"""

from packaging import version
import mindspore
from mindspore.dataset import NumpySlicesDataset
from mindnlp._legacy.transforms import AddToken

MIN_COMPATIBLE_VERSION = '1.8.1'
from mindnlp.utils import less_min_minddata_compatible

def test_addtoken_begin():
"""test addtoken by dataset.map"""
Expand All @@ -37,7 +34,7 @@ def test_addtoken_begin():
# | ['TOKEN', 'a', 'b', 'c', 'd', 'e'] |
# +---------------------------+
data_after = next(dataset.create_tuple_iterator(output_numpy=True))[0]
if version.parse(mindspore.__version__) <= version.parse(MIN_COMPATIBLE_VERSION):
if less_min_minddata_compatible:
assert data_after.tolist() == [b'TOKEN', b'a', b'b', b'c', b'd', b'e']
else:
assert data_after.tolist() == ['TOKEN', 'a', 'b', 'c', 'd', 'e']
Expand All @@ -58,7 +55,7 @@ def test_addtoken_end():
# | ['a', 'b', 'c', 'd', 'e', 'TOKEN'] |
# +---------------------------+
data_after = next(dataset.create_tuple_iterator(output_numpy=True))[0]
if version.parse(mindspore.__version__) <= version.parse(MIN_COMPATIBLE_VERSION):
if less_min_minddata_compatible:
assert data_after.tolist() == [b'a', b'b', b'c', b'd', b'e', b'TOKEN']
else:
assert data_after.tolist() == ['a', 'b', 'c', 'd', 'e', 'TOKEN']
24 changes: 21 additions & 3 deletions tests/ut/transforms/test_bert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
"""Test the BertTokenizer"""

import mindspore as ms
from mindspore.dataset import text
from mindspore.dataset import GeneratorDataset
from mindspore.dataset.text import Vocab as msVocab
from mindnlp import Vocab
from mindnlp.transforms import BertTokenizer

def test_bert_tokenizer():
def test_bert_tokenizer_mindnlp_vocab():
"""test BertTokenizer by dataset.map"""
texts = ['i make a small mistake when i\'m working! 床前明月光']
test_dataset = GeneratorDataset(texts, 'text')
Expand All @@ -28,7 +29,24 @@ def test_bert_tokenizer():
"make", "small", "mistake", "##s", "during", "work", "##ing", "hour", "😀", "😃",
"😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I", "[CLS]", "[SEP]",
"[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"]
vocab = text.Vocab.from_list(vocab_list)
vocab = Vocab(vocab_list)
bert_tokenizer = BertTokenizer(vocab=vocab, lower_case=True, return_token=True)
test_dataset = test_dataset.map(operations=bert_tokenizer)
dataset_after = next(test_dataset.create_tuple_iterator())[0]

assert len(dataset_after) == 19
assert dataset_after.dtype == ms.string

def test_bert_tokenizer_mindspore_vocab():
"""test BertTokenizer by dataset.map"""
texts = ['i make a small mistake when i\'m working! 床前明月光']
test_dataset = GeneratorDataset(texts, 'text')
vocab_list = ["床", "前", "明", "月", "光", "疑", "是", "地", "上", "霜", "举", "头", "望", "低",
"思", "故", "乡","繁", "體", "字", "嘿", "哈", "大", "笑", "嘻", "i", "am", "mak",
"make", "small", "mistake", "##s", "during", "work", "##ing", "hour", "😀", "😃",
"😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I", "[CLS]", "[SEP]",
"[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]"]
vocab = msVocab.from_list(vocab_list)
bert_tokenizer = BertTokenizer(vocab=vocab, lower_case=True, return_token=True)
test_dataset = test_dataset.map(operations=bert_tokenizer)
dataset_after = next(test_dataset.create_tuple_iterator())[0]
Expand Down

0 comments on commit 6270fe8

Please sign in to comment.