From bd8c16d4e26c177ea2e126b1187f7187161f1c9c Mon Sep 17 00:00:00 2001 From: ByT5 Team Date: Tue, 7 Sep 2021 12:11:58 -0700 Subject: [PATCH] Adding grapheme-to-phoneme and morphological inflection tasks. PiperOrigin-RevId: 395301991 --- byt5/sigmorphon.py | 272 +++++++++++++++++++++++++++++++++++++++++++++ byt5/tasks_test.py | 52 +++++---- 2 files changed, 301 insertions(+), 23 deletions(-) create mode 100644 byt5/sigmorphon.py diff --git a/byt5/sigmorphon.py b/byt5/sigmorphon.py new file mode 100644 index 0000000..06b4cbf --- /dev/null +++ b/byt5/sigmorphon.py @@ -0,0 +1,272 @@ +# Copyright 2021 The ByT5 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add Tasks to registry.""" +import functools +import random +from byt5.tasks import DEFAULT_BYTE_OUTPUT_FEATURES +from byt5.tasks import DEFAULT_MT5_OUTPUT_FEATURES +import numpy +import t5.data +from t5.data import preprocessors + +# Place downloaded data from https://sigmorphon.github.io/sharedtasks/2020 in +# the following directory. +SIGMORPHON_DIR = None + +FEATURE_MAP = { + "byt5": DEFAULT_BYTE_OUTPUT_FEATURES, + "mt5": DEFAULT_MT5_OUTPUT_FEATURES +} + +# ====================== SIGMORPHON-2020 TASK-1 ==================== +# Task 1: Multilingual Grapheme-to-Phoneme Conversion +# Please see website https://sigmorphon.github.io/sharedtasks/2020/task1/ +# for details. + + +def get_2020_task1_preprocessor(language): + return [ + functools.partial( + preprocessors.preprocess_tsv, + inputs_format=f' {language} ' + '{0}', + targets_format='{1}', + num_fields=2), + ] + + +def metrics_task1_2020(targets, predictions): + """Computes word error rate and edit distance metrics.""" + + def edit_distance(x, y) -> int: + # Implementation from + # https://github.com/sigmorphon/2020/blob/master/task1/evaluation/evallib.py + idim = len(x) + 1 + jdim = len(y) + 1 + table = numpy.zeros((idim, jdim), dtype=numpy.uint8) + table[1:, 0] = 1 + table[0, 1:] = 1 + for i in range(1, idim): + for j in range(1, jdim): + if x[i - 1] == y[j - 1]: + table[i][j] = table[i - 1][j - 1] + else: + c1 = table[i - 1][j] + c2 = table[i][j - 1] + c3 = table[i - 1][j - 1] + table[i][j] = min(c1, c2, c3) + 1 + return int(table[-1][-1]) + + # Word-level measures. + correct = 0 + incorrect = 0 + # Label-level measures. + total_edits = 0 + total_length = 0 + for gold, hypo in zip(targets, predictions): + edits = edit_distance(gold, hypo) + length = len(gold) + if edits == 0: + correct += 1 + else: + incorrect += 1 + total_edits += edits + total_length += length + wer = incorrect / (correct + incorrect) + ler = 100 * total_edits / total_length + return {'wer': wer, 'ler': ler} + + +langs = [ + 'arm', 'bul', 'fre', 'geo', 'hin', 'hun', 'ice', 'kor', 'lit', 'gre', 'ady', + 'dut', 'jpn', 'rum', 'vie' +] +year = '2020' +task = 'task1' +data_dir = f'{SIGMORPHON_DIR}/{year}/{task}/data/' + +for lang in langs: + for prefix, output_features in FEATURE_MAP.items(): + t5.data.TaskRegistry.add( + f'{prefix}_sigmorphon_{year}_{task}.{lang}', + t5.data.TextLineTask, + text_preprocessor=get_2020_task1_preprocessor(lang), + output_features=output_features, + split_to_filepattern={ + 'train': f'{data_dir}/train/{lang}_train.tsv', + 'validation': f'{data_dir}/dev/{lang}_dev.tsv', + 'test': f'{data_dir}/test/{lang}_test.tsv', + }, + metric_fns=[metrics_task1_2020]) + +for prefix in ['mt5', 'byt5']: + t5.data.MixtureRegistry.add( + f'{prefix}_sigmorphon_{year}_{task}', + [f'{prefix}_sigmorphon_{year}_{task}.{lang}' for lang in langs], + default_rate=1.) + +# ====================== SIGMORPHON-2020 TASK-0 ==================== +# Task 0: Typologically Diverse Morphological Inflection +# Please see website https://sigmorphon.github.io/sharedtasks/2020/task0/ +# for details. + + +def get_2020_task0_preprocessor(language): + return [ + functools.partial( + preprocessors.preprocess_tsv, + inputs_format=f'{language}' + ' {0} ' + 'form={2}', + targets_format='{1}', + num_fields=3), + ] + + +def metrics_task0_2020(targets, predictions): + """Calculates exact match and edit distance based metrics.""" + + def distance(str1, str2): + """Levenshtein distance.""" + # Implementation from + # https://github.com/sigmorphon2020/task0-data/blob/master/evaluate.py + m = numpy.zeros([len(str2) + 1, len(str1) + 1]) + for x in range(1, len(str2) + 1): + m[x][0] = m[x - 1][0] + 1 + for y in range(1, len(str1) + 1): + m[0][y] = m[0][y - 1] + 1 + for x in range(1, len(str2) + 1): + for y in range(1, len(str1) + 1): + if str1[y - 1] == str2[x - 1]: + dg = 0 + else: + dg = 1 + m[x][y] = min(m[x - 1][y] + 1, m[x][y - 1] + 1, m[x - 1][y - 1] + dg) + return int(m[len(str2)][len(str1)]) + + correct, dist, total = 0., 0., 0. + for target, prediction in zip(targets, predictions): + if target == prediction: + correct += 1 + dist += distance(target, prediction) + total += 1 + return { + 'accuracy': round(correct / total * 100, 2), + 'distance': round(dist / total, 2) + } + + +surprise_lang_path_prefix = [ + 'SURPRISE-LANGUAGES/Afro-Asiatic/mlt', 'SURPRISE-LANGUAGES/Germanic/gsw', + 'SURPRISE-LANGUAGES/Nilo-Sahan/dje', 'SURPRISE-LANGUAGES/Romance/frm', + 'SURPRISE-LANGUAGES/Indo-Aryan/urd', 'SURPRISE-LANGUAGES/Uralic/kpv', + 'SURPRISE-LANGUAGES/Sino-Tibetan/bod', 'SURPRISE-LANGUAGES/Germanic/nno', + 'SURPRISE-LANGUAGES/Uralic/olo', 'SURPRISE-LANGUAGES/Romance/fur', + 'SURPRISE-LANGUAGES/Romance/cat', 'SURPRISE-LANGUAGES/Afro-Asiatic/syc', + 'SURPRISE-LANGUAGES/Algic/cre', 'SURPRISE-LANGUAGES/Turkic/kir', + 'SURPRISE-LANGUAGES/Uralic/lud', 'SURPRISE-LANGUAGES/Uralic/udm', + 'SURPRISE-LANGUAGES/Iranian/pus', 'SURPRISE-LANGUAGES/Romance/ast', + 'SURPRISE-LANGUAGES/Germanic/gml', 'SURPRISE-LANGUAGES/Turkic/bak', + 'SURPRISE-LANGUAGES/Indo-Aryan/hin', 'SURPRISE-LANGUAGES/Iranian/fas', + 'SURPRISE-LANGUAGES/Niger-Congo/sna', 'SURPRISE-LANGUAGES/Romance/xno', + 'SURPRISE-LANGUAGES/Romance/vec', 'SURPRISE-LANGUAGES/Dravidian/kan', + 'SURPRISE-LANGUAGES/Afro-Asiatic/orm', 'SURPRISE-LANGUAGES/Turkic/uzb', + 'SURPRISE-LANGUAGES/Uto-Aztecan/ood', 'SURPRISE-LANGUAGES/Turkic/tuk', + 'SURPRISE-LANGUAGES/Iranian/tgk', 'SURPRISE-LANGUAGES/Romance/lld', + 'SURPRISE-LANGUAGES/Turkic/kaz', 'SURPRISE-LANGUAGES/Indo-Aryan/ben', + 'SURPRISE-LANGUAGES/Siouan/dak', 'SURPRISE-LANGUAGES/Romance/glg', + 'SURPRISE-LANGUAGES/Turkic/kjh', 'SURPRISE-LANGUAGES/Turkic/crh', + 'SURPRISE-LANGUAGES/Indo-Aryan/san', 'SURPRISE-LANGUAGES/Dravidian/tel', + 'SURPRISE-LANGUAGES/Tungusic/evn', 'SURPRISE-LANGUAGES/Turkic/aze', + 'SURPRISE-LANGUAGES/Uralic/vro', 'SURPRISE-LANGUAGES/Turkic/uig', + 'SURPRISE-LANGUAGES/Australian/mwf' +] +development_lang_path_prefix = [ + 'DEVELOPMENT-LANGUAGES/germanic/swe', 'DEVELOPMENT-LANGUAGES/germanic/ang', + 'DEVELOPMENT-LANGUAGES/oto-manguean/azg', + 'DEVELOPMENT-LANGUAGES/uralic/vep', 'DEVELOPMENT-LANGUAGES/niger-congo/lin', + 'DEVELOPMENT-LANGUAGES/niger-congo/nya', + 'DEVELOPMENT-LANGUAGES/germanic/frr', 'DEVELOPMENT-LANGUAGES/uralic/vot', + 'DEVELOPMENT-LANGUAGES/austronesian/mlg', + 'DEVELOPMENT-LANGUAGES/oto-manguean/ctp', + 'DEVELOPMENT-LANGUAGES/oto-manguean/otm', + 'DEVELOPMENT-LANGUAGES/oto-manguean/ote', + 'DEVELOPMENT-LANGUAGES/uralic/fin', + 'DEVELOPMENT-LANGUAGES/oto-manguean/cpa', + 'DEVELOPMENT-LANGUAGES/austronesian/mao', + 'DEVELOPMENT-LANGUAGES/uralic/mdf', 'DEVELOPMENT-LANGUAGES/germanic/dan', + 'DEVELOPMENT-LANGUAGES/niger-congo/gaa', + 'DEVELOPMENT-LANGUAGES/oto-manguean/cly', + 'DEVELOPMENT-LANGUAGES/uralic/mhr', 'DEVELOPMENT-LANGUAGES/niger-congo/zul', + 'DEVELOPMENT-LANGUAGES/uralic/krl', 'DEVELOPMENT-LANGUAGES/niger-congo/kon', + 'DEVELOPMENT-LANGUAGES/oto-manguean/czn', + 'DEVELOPMENT-LANGUAGES/germanic/gmh', 'DEVELOPMENT-LANGUAGES/uralic/izh', + 'DEVELOPMENT-LANGUAGES/austronesian/ceb', + 'DEVELOPMENT-LANGUAGES/germanic/nob', + 'DEVELOPMENT-LANGUAGES/austronesian/tgl', + 'DEVELOPMENT-LANGUAGES/austronesian/hil', + 'DEVELOPMENT-LANGUAGES/niger-congo/lug', + 'DEVELOPMENT-LANGUAGES/niger-congo/sot', + 'DEVELOPMENT-LANGUAGES/niger-congo/swa', + 'DEVELOPMENT-LANGUAGES/germanic/isl', + 'DEVELOPMENT-LANGUAGES/oto-manguean/pei', + 'DEVELOPMENT-LANGUAGES/uralic/sme', 'DEVELOPMENT-LANGUAGES/germanic/nld', + 'DEVELOPMENT-LANGUAGES/niger-congo/aka', + 'DEVELOPMENT-LANGUAGES/germanic/eng', + 'DEVELOPMENT-LANGUAGES/oto-manguean/zpv', + 'DEVELOPMENT-LANGUAGES/uralic/est', 'DEVELOPMENT-LANGUAGES/uralic/liv', + 'DEVELOPMENT-LANGUAGES/oto-manguean/xty', + 'DEVELOPMENT-LANGUAGES/germanic/deu', 'DEVELOPMENT-LANGUAGES/uralic/myv' +] +year = '2020' +task = 'task0' +data_dir = f'{SIGMORPHON_DIR}/{year}/task0-data/' +langs = [ + path_prefix.split('/')[-1] + for path_prefix in surprise_lang_path_prefix + development_lang_path_prefix +] +random.shuffle(langs) +path_prefixes = surprise_lang_path_prefix + development_lang_path_prefix + +for prefix, output_features in FEATURE_MAP.items(): + for path_prefix in path_prefixes: + lang = path_prefix.split('/')[-1] + split_to_filepattern = { + 'train': f'{data_dir}/{path_prefix}.trn', + 'validation': f'{data_dir}/{path_prefix}.dev', + 'test': f'{data_dir}/GOLD-TEST/{lang}.tst', + } + t5.data.TaskRegistry.add( + f'{prefix}_sigmorphon_{year}_{task}.{lang}', + t5.data.TextLineTask, + text_preprocessor=get_2020_task0_preprocessor(lang), + output_features=output_features, + split_to_filepattern=split_to_filepattern, + metric_fns=[metrics_task0_2020]) + + + t5.data.TaskRegistry.add( + f'{prefix}_sigmorphon_{year}_{task}.all', + t5.data.TextLineTask, + text_preprocessor=preprocessors.preprocess_tsv, + output_features=output_features, + split_to_filepattern={ + 'test': f'{data_dir}/test.tsv', + 'validation': f'{data_dir}/validation.tsv' + }, + metric_fns=[metrics_task0_2020]) + +for prefix in ['mt5', 'byt5']: + t5.data.MixtureRegistry.add( + f'{prefix}_sigmorphon_{year}_{task}', + [f'{prefix}_sigmorphon_{year}_{task}.{lang}' for lang in langs], + default_rate=1.) diff --git a/byt5/tasks_test.py b/byt5/tasks_test.py index b28bd16..3e400e1 100644 --- a/byt5/tasks_test.py +++ b/byt5/tasks_test.py @@ -29,9 +29,11 @@ _SEQUENCE_LENGTH = {'inputs': 128, 'targets': 128} _TASKS = [ - 'byt5_wiki.en', + 'byt5_dakshina_single_word_translit_indic2latin.bn', + 'byt5_dakshina_word_translit_latin2indic_lang_prefix.bn', + 'byt5_gem_xsum', 'byt5_mc4.en', - 'char_t5_mc4.en', + 'byt5_sigmorphon_2020_task1.ar', 'byt5_super_glue_boolq_v102', 'byt5_super_glue_cb_v102', 'byt5_super_glue_copa_v102', @@ -42,39 +44,43 @@ 'byt5_super_glue_wsc_v102_simple_eval', 'byt5_super_glue_wsc_v102_simple_train', 'byt5_tweetqa', - 'byt5_dakshina_word_translit_latin2indic_lang_prefix.bn', - 'byt5_dakshina_single_word_translit_indic2latin.bn', - 'mt5_dakshina_word_translit_latin2indic_lang_prefix.bn', - 'mt5_dakshina_single_word_translit_indic2latin.bn', + 'byt5_wiki.en', 'byt5_wmt15_enfr_v003', 'byt5_wmt16_enro_v003', 'byt5_wmt_t2t_ende_v003', - 'byt5_gem_xsum' + 'char_t5_mc4.en', + 'mt5_dakshina_single_word_translit_indic2latin.bn', + 'mt5_dakshina_word_translit_latin2indic_lang_prefix.bn', + 'mt5_sigmorphon_2020_task0.dje' ] _MIXTURES = [ - 'byt5_xnli_zeroshot', + 'byt5_dak_wrdtrnslit_ind2lat', + 'byt5_dak_wrdtrnslit_lat2ind', + 'byt5_dak_wrdtrnslit_lat2ind_lp', + 'byt5_glue_v002_proportional', + 'byt5_mlqa_translate_train', + 'byt5_mlqa_zeroshot', + 'byt5_ner_multilingual', + 'byt5_ner_zeroshot', + 'byt5_pawsx_translate_train', 'byt5_pawsx_zeroshot', + 'byt5_sigmorphon_2020_task0', + 'byt5_sigmorphon_2020_task1', + 'byt5_super_glue_v102_proportional', 'byt5_tydiqa', + 'byt5_tydiqa_translate_train', 'byt5_tydiqa_zeroshot', - 'byt5_xquad_zeroshot', - 'byt5_mlqa_zeroshot', + 'byt5_wikilingua', 'byt5_xnli_translate_train', - 'byt5_pawsx_translate_train', - 'byt5_tydiqa_translate_train', + 'byt5_xnli_zeroshot', 'byt5_xquad_translate_train', - 'byt5_mlqa_translate_train', - 'byt5_super_glue_v102_proportional', - 'byt5_glue_v002_proportional', - 'byt5_wikilingua', - 'byt5_dak_wrdtrnslit_lat2ind_lp', - 'mt5_dak_wrdtrnslit_lat2ind_lp', - 'byt5_dak_wrdtrnslit_lat2ind', - 'mt5_dak_wrdtrnslit_lat2ind', - 'byt5_dak_wrdtrnslit_ind2lat', + 'byt5_xquad_zeroshot', 'mt5_dak_wrdtrnslit_ind2lat', - 'byt5_ner_zeroshot', - 'byt5_ner_multilingual' + 'mt5_dak_wrdtrnslit_lat2ind', + 'mt5_dak_wrdtrnslit_lat2ind_lp', + 'mt5_sigmorphon_2020_task0', + 'mt5_sigmorphon_2020_task1' ]