Skip to content

Commit

Permalink
add glue dataset (mindspore-lab#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
Iron-Boyy authored May 18, 2023
1 parent 73a0f44 commit 0e086e0
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 0 deletions.
1 change: 1 addition & 0 deletions mindnlp/dataset/hf_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
"""

from .hf_imdb import HF_IMDB, HF_IMDB_Process
from .hf_glue import HF_GLUE
134 changes: 134 additions & 0 deletions mindnlp/dataset/hf_datasets/hf_glue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Hugging Face GLUE load function
"""
# pylint: disable=C0103
import os
from typing import Union, Tuple
from datasets import load_dataset as hf_load
from mindspore.dataset import GeneratorDataset
from mindnlp.dataset.register import load_dataset
from mindnlp.configs import DEFAULT_ROOT


class HFglue:
    """
    Hugging Face GLUE dataset source.

    Wraps an iterable of per-example dicts (as yielded by
    ``datasets.load_dataset('glue', name, ...)``) and exposes each example
    as a tuple whose field order matches the ``column_names`` that
    ``HF_GLUE`` passes to ``GeneratorDataset`` for the same task.
    """

    def __init__(self, dataset_list, name) -> None:
        """
        Args:
            dataset_list: iterable of per-example dicts.
            name (str): GLUE task name ('cola', 'sst2', 'mrpc', 'stsb',
                'qqp', 'mnli*', 'ax', 'qnli', 'rte', 'wnli').

        Raises:
            ValueError: If `name` is not a known GLUE task.
        """
        self.dataset_list = dataset_list
        self._name = name
        # Every GLUE task carries a label and an example index.
        self._label, self._idx = [], []
        if name in ('cola', 'sst2'):
            self._sentence = []
        elif name in ('mrpc', 'stsb', 'rte', 'wnli'):
            self._sentence1, self._sentence2 = [], []
        elif name == "qqp":
            self._question1, self._question2 = [], []
        elif name.startswith("mnli") or name == "ax":
            self._premise, self._hypothesis = [], []
        elif name == "qnli":
            self._question, self._sentence = [], []
        else:
            # Fail fast instead of silently producing a broken source.
            raise ValueError(f"unknown GLUE task name: {name}")
        self._load(name)

    def _load(self, name):
        """Copy every example's fields into the per-column lists."""
        for every_dict in self.dataset_list:
            self._label.append(every_dict['label'])
            self._idx.append(every_dict['idx'])
            if name in ('cola', 'sst2'):
                self._sentence.append(every_dict['sentence'])
            elif name in ('mrpc', 'stsb', 'rte', 'wnli'):
                self._sentence1.append(every_dict['sentence1'])
                self._sentence2.append(every_dict['sentence2'])
            elif name == "qqp":
                self._question1.append(every_dict['question1'])
                self._question2.append(every_dict['question2'])
            elif name.startswith("mnli") or name == "ax":
                self._premise.append(every_dict['premise'])
                self._hypothesis.append(every_dict['hypothesis'])
            elif name == "qnli":
                self._sentence.append(every_dict['sentence'])
                self._question.append(every_dict['question'])

    def __getitem__(self, index):
        """Return one example as a tuple in the task's column order."""
        name = self._name
        if name in ('cola', 'sst2'):
            return self._sentence[index], self._label[index], self._idx[index]
        if name in ('mrpc', 'stsb', 'rte', 'wnli'):
            return (self._sentence1[index], self._sentence2[index],
                    self._label[index], self._idx[index])
        if name == "qqp":
            return (self._question1[index], self._question2[index],
                    self._label[index], self._idx[index])
        if name == "qnli":
            return (self._question[index], self._sentence[index],
                    self._label[index], self._idx[index])
        # mnli / mnli_matched / mnli_mismatched / ax
        return (self._premise[index], self._hypothesis[index],
                self._label[index], self._idx[index])

    def __len__(self):
        return len(self._label)


@load_dataset.register
def HF_GLUE(
    name: str,
    root: str = DEFAULT_ROOT,
    split: Union[Tuple[str], str] = ("train", "test"),
    shuffle=True,
):
    r"""
    Load the Hugging Face GLUE dataset.

    Args:
        name (str): GLUE task name ('cola', 'sst2', 'mrpc', 'stsb', 'qqp',
            'mnli', 'mnli_matched', 'mnli_mismatched', 'qnli', 'rte',
            'wnli' or 'ax').
        root (str): Directory where the datasets are saved.
            Default: ~/.mindnlp.
        split (str|Tuple[str]): Split or splits to be returned.
            Default: ('train', 'test').
        shuffle (bool): Whether to shuffle the dataset.
            Default: True.

    Returns:
        - **datasets_list** (list) - A list of loaded datasets.
          If only one split is requested, such as 'train',
          that dataset is returned instead of a list of datasets.

    Raises:
        ValueError: If `name` is not a known GLUE task.

    Examples:
        >>> split = ('train', 'test')
        >>> dataset_train, dataset_test = HF_GLUE(name="cola", split=split)
        >>> train_iter = dataset_train.create_tuple_iterator()
        >>> print(next(train_iter))
    """
    cache_dir = os.path.join(root, "datasets", "hf_datasets", "GLUE")
    # Column order must match the tuples produced by HFglue.__getitem__.
    if name in ('cola', 'sst2'):
        column_names = ['sentence', 'label', 'idx']
    elif name in ('mrpc', 'stsb', 'rte', 'wnli'):
        column_names = ['sentence1', 'sentence2', 'label', 'idx']
    elif name == "qqp":
        column_names = ['question1', 'question2', 'label', 'idx']
    elif name.startswith("mnli") or name == "ax":
        column_names = ['premise', 'hypothesis', 'label', 'idx']
    elif name == "qnli":
        column_names = ['question', 'sentence', 'label', 'idx']
    else:
        # Fail fast instead of hitting an undefined `column_names` below.
        raise ValueError(f"unknown GLUE task name: {name}")

    mode_list = [split] if isinstance(split, str) else list(split)

    ds_list = hf_load('glue', name, split=mode_list, cache_dir=cache_dir)
    datasets_list = [
        GeneratorDataset(source=HFglue(every_ds, name),
                         column_names=column_names, shuffle=shuffle)
        for every_ds in ds_list
    ]
    if len(mode_list) == 1:
        return datasets_list[0]
    return datasets_list
215 changes: 215 additions & 0 deletions tests/ut/dataset/test_hfglue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test HF_GLUE
"""

import os
import unittest
import shutil
import pytest
from mindnlp.dataset import HF_GLUE


class TestHFGLUE(unittest.TestCase):
    r"""
    Test HF_GLUE: each case downloads one GLUE task and checks that every
    requested split has the expected number of examples.
    """

    @classmethod
    def setUpClass(cls):
        cls.root = os.path.join(os.path.expanduser("~"), ".mindnlp")

    @classmethod
    def tearDownClass(cls):
        # ignore_errors: the cache dir may not exist if every download
        # test was deselected, and rmtree would otherwise raise.
        shutil.rmtree(cls.root, ignore_errors=True)

    def _check_sizes(self, name, num_lines):
        """Load each split listed in `num_lines` and assert its size."""
        splits = tuple(num_lines)
        datasets = HF_GLUE(name=name, root=self.root, split=splits)
        if len(splits) == 1:
            # HF_GLUE returns a bare dataset for a single split.
            datasets = [datasets]
        for split, dataset in zip(splits, datasets):
            assert dataset.get_dataset_size() == num_lines[split]

    @pytest.mark.download
    def test_hf_glue_cola(self):
        """Test glue_cola"""
        self._check_sizes("cola", {
            "train": 8551,
            "test": 1063,
            "validation": 1043,
        })

    @pytest.mark.download
    def test_hf_glue_sst2(self):
        """Test glue_sst2"""
        self._check_sizes("sst2", {
            "train": 67349,
            "test": 1821,
            "validation": 872,
        })

    @pytest.mark.download
    def test_hf_glue_mrpc(self):
        """Test glue_mrpc"""
        self._check_sizes("mrpc", {
            "train": 3668,
            "test": 1725,
            "validation": 408,
        })

    @pytest.mark.download
    def test_hf_glue_qqp(self):
        """Test glue_qqp"""
        self._check_sizes("qqp", {
            "train": 363846,
            "test": 390965,
            "validation": 40430,
        })

    @pytest.mark.download
    def test_hf_glue_stsb(self):
        """Test glue_stsb"""
        self._check_sizes("stsb", {
            "train": 5749,
            "test": 1379,
            "validation": 1500,
        })

    @pytest.mark.download
    def test_hf_glue_mnli(self):
        """Test glue_mnli"""
        self._check_sizes("mnli", {
            "train": 392702,
            "validation_matched": 9815,
            "validation_mismatched": 9832,
            "test_matched": 9796,
            "test_mismatched": 9847,
        })

    @pytest.mark.download
    def test_hf_glue_mnli_mismatched(self):
        """Test glue_mnli_mismatched"""
        self._check_sizes("mnli_mismatched", {
            "validation": 9832,
            "test": 9847,
        })

    @pytest.mark.download
    def test_hf_glue_mnli_matched(self):
        """Test glue_mnli_matched"""
        self._check_sizes("mnli_matched", {
            "validation": 9815,
            "test": 9796,
        })

    @pytest.mark.download
    def test_hf_glue_qnli(self):
        """Test glue_qnli"""
        self._check_sizes("qnli", {
            "train": 104743,
            "test": 5463,
            "validation": 5463,
        })

    @pytest.mark.download
    def test_hf_glue_rte(self):
        """Test glue_rte"""
        self._check_sizes("rte", {
            "train": 2490,
            "test": 3000,
            "validation": 277,
        })

    @pytest.mark.download
    def test_hf_glue_wnli(self):
        """Test glue_wnli"""
        self._check_sizes("wnli", {
            "train": 635,
            "test": 146,
            "validation": 71,
        })

    @pytest.mark.download
    def test_hf_glue_ax(self):
        """Test glue_ax"""
        self._check_sizes("ax", {
            "test": 1104,
        })

0 comments on commit 0e086e0

Please sign in to comment.