forked from mindspore-lab/mindnlp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add glue dataset (mindspore-lab#496)
- Loading branch information
Showing
3 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ | |
""" | ||
|
||
from .hf_imdb import HF_IMDB, HF_IMDB_Process | ||
from .hf_glue import HF_GLUE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright 2022 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
""" | ||
Hugging Face GLUE load function | ||
""" | ||
# pylint: disable=C0103 | ||
import os | ||
from typing import Union, Tuple | ||
from datasets import load_dataset as hf_load | ||
from mindspore.dataset import GeneratorDataset | ||
from mindnlp.dataset.register import load_dataset | ||
from mindnlp.configs import DEFAULT_ROOT | ||
|
||
|
||
class HFglue:
    """
    Hugging Face GLUE dataset source.

    Adapts one Hugging Face GLUE split (an iterable of example dicts) to the
    random-access interface expected by ``mindspore.dataset.GeneratorDataset``.
    Each ``__getitem__`` yields a tuple in the order
    ``(*text_columns, label, idx)``, matching the ``column_names`` that
    ``HF_GLUE`` passes to ``GeneratorDataset`` for the same task name.

    Args:
        dataset_list: Iterable of example dicts (one HF GLUE split).
        name (str): GLUE task name ('cola', 'sst2', 'mrpc', 'stsb', 'rte',
            'wnli', 'qqp', 'qnli', 'ax', or any 'mnli*' variant).

    Raises:
        ValueError: If `name` is not a recognized GLUE task.
    """

    # NOTE: the original implementation returned only (idx, label) from
    # __getitem__, which mismatched the 3-4 column_names used by HF_GLUE;
    # rows are now materialized in full column order.
    def __init__(self, dataset_list, name) -> None:
        self.dataset_list = dataset_list
        # Per-task text column keys; output order is (*text_keys, label, idx).
        if name in ('cola', 'sst2'):
            self._text_keys = ('sentence',)
        elif name in ('mrpc', 'stsb', 'rte', 'wnli'):
            self._text_keys = ('sentence1', 'sentence2')
        elif name == 'qqp':
            self._text_keys = ('question1', 'question2')
        elif name.startswith('mnli') or name == 'ax':
            self._text_keys = ('premise', 'hypothesis')
        elif name == 'qnli':
            self._text_keys = ('question', 'sentence')
        else:
            # Fail fast: an unknown task would otherwise surface later as a
            # confusing KeyError or column-count mismatch.
            raise ValueError(f"Unknown GLUE task name: {name!r}")
        self._rows = []
        self._load()

    def _load(self):
        """Materialize every example as a tuple in column order."""
        for example in self.dataset_list:
            texts = tuple(example[key] for key in self._text_keys)
            self._rows.append(texts + (example['label'], example['idx']))

    def __getitem__(self, index):
        return self._rows[index]

    def __len__(self):
        return len(self._rows)
|
||
|
||
@load_dataset.register
def HF_GLUE(
    name: str,
    root: str = DEFAULT_ROOT,
    split: Union[Tuple[str], str] = ("train", "test"),
    shuffle=True,
):
    r"""
    Load the Hugging Face GLUE dataset.

    Args:
        name (str): GLUE task name, one of: 'cola', 'sst2', 'mrpc', 'stsb',
            'rte', 'wnli', 'qqp', 'qnli', 'ax', or any 'mnli*' variant.
        root (str): Directory where the datasets are cached.
            Default: ~/.mindnlp.
        split (str|Tuple[str]): Split or splits to be returned.
            Default: ('train', 'test').
        shuffle (bool): Whether to shuffle the dataset.
            Default: True.

    Returns:
        - **datasets_list** (list) - A list of loaded datasets.
          If only one split is requested, such as 'train',
          that dataset is returned directly instead of a list.

    Raises:
        ValueError: If `name` is not a recognized GLUE task.

    Examples:
        >>> dataset_train, dataset_test = HF_GLUE("cola", split=("train", "test"))
        >>> train_iter = dataset_train.create_tuple_iterator()
        >>> print(next(train_iter))
    """
    cache_dir = os.path.join(root, "datasets", "hf_datasets", "GLUE")
    # Column order must match the tuples yielded by the HFglue source.
    if name in ('cola', 'sst2'):
        column_names = ['sentence', 'label', 'idx']
    elif name in ('mrpc', 'stsb', 'rte', 'wnli'):
        column_names = ['sentence1', 'sentence2', 'label', 'idx']
    elif name == "qqp":
        column_names = ['question1', 'question2', 'label', 'idx']
    elif name.startswith("mnli") or name == "ax":
        column_names = ['premise', 'hypothesis', 'label', 'idx']
    elif name == "qnli":
        column_names = ['question', 'sentence', 'label', 'idx']
    else:
        # Fail fast instead of hitting an UnboundLocalError on column_names.
        raise ValueError(f"Unknown GLUE task name: {name!r}")

    # Normalize `split` to a list so hf_load returns one dataset per entry.
    mode_list = [split] if isinstance(split, str) else list(split)

    ds_list = hf_load('glue', name, split=mode_list, cache_dir=cache_dir)
    datasets_list = [
        GeneratorDataset(
            source=HFglue(every_ds, name),
            column_names=column_names, shuffle=shuffle)
        for every_ds in ds_list
    ]
    if len(mode_list) == 1:
        return datasets_list[0]
    return datasets_list
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
# Copyright 2022 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
""" | ||
Test HF_GLUE | ||
""" | ||
|
||
import os | ||
import unittest | ||
import shutil | ||
import pytest | ||
from mindnlp.dataset import HF_GLUE | ||
|
||
|
||
class TestHFGLUE(unittest.TestCase):
    r"""
    Test HF_GLUE
    """

    @classmethod
    def setUpClass(cls):
        cls.root = os.path.join(os.path.expanduser("~"), ".mindnlp")

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root)

    def _assert_split_sizes(self, name, num_lines, splits):
        """Load *splits* of GLUE task *name* and check each dataset's size."""
        loaded = HF_GLUE(name=name, root=self.root, split=splits)
        if isinstance(splits, str):
            loaded, splits = [loaded], (splits,)
        for split_name, dataset in zip(splits, loaded):
            assert dataset.get_dataset_size() == num_lines[split_name]

    @pytest.mark.download
    def test_hf_glue_cola(self):
        """Test glue_cola"""
        self._assert_split_sizes(
            "cola",
            {"train": 8551, "test": 1063, "validation": 1043},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_sst2(self):
        """Test glue_sst2"""
        self._assert_split_sizes(
            "sst2",
            {"train": 67349, "test": 1821, "validation": 872},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_mrpc(self):
        """Test glue_mrpc"""
        self._assert_split_sizes(
            "mrpc",
            {"train": 3668, "test": 1725, "validation": 408},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_qqp(self):
        """Test glue_qqp"""
        self._assert_split_sizes(
            "qqp",
            {"train": 363846, "test": 390965, "validation": 40430},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_stsb(self):
        """Test glue_stsb"""
        self._assert_split_sizes(
            "stsb",
            {"train": 5749, "test": 1379, "validation": 1500},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_mnli(self):
        """Test glue_mnli"""
        self._assert_split_sizes(
            "mnli",
            {
                "train": 392702,
                "validation_matched": 9815,
                "validation_mismatched": 9832,
                "test_matched": 9796,
                "test_mismatched": 9847,
            },
            (
                "train",
                "validation_matched",
                "validation_mismatched",
                "test_matched",
                "test_mismatched",
            ),
        )

    @pytest.mark.download
    def test_hf_glue_mnli_mismatched(self):
        """Test glue_mnli_mismatched"""
        self._assert_split_sizes(
            "mnli_mismatched",
            {"validation": 9832, "test": 9847},
            ("validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_mnli_matched(self):
        """Test glue_mnli_matched"""
        self._assert_split_sizes(
            "mnli_matched",
            {"validation": 9815, "test": 9796},
            ("validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_qnli(self):
        """Test glue_qnli"""
        self._assert_split_sizes(
            "qnli",
            {"train": 104743, "test": 5463, "validation": 5463},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_rte(self):
        """Test glue_rte"""
        self._assert_split_sizes(
            "rte",
            {"train": 2490, "test": 3000, "validation": 277},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_wnli(self):
        """Test glue_wnli"""
        self._assert_split_sizes(
            "wnli",
            {"train": 635, "test": 146, "validation": 71},
            ("train", "validation", "test"),
        )

    @pytest.mark.download
    def test_hf_glue_ax(self):
        """Test glue_ax"""
        self._assert_split_sizes("ax", {"test": 1104}, "test")