Skip to content

Commit

Permalink
Add pinyin level CER
Browse files Browse the repository at this point in the history
  • Loading branch information
iamanigeeit committed Aug 23, 2023
1 parent 254afc0 commit 1d3f98a
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 45 deletions.
72 changes: 37 additions & 35 deletions prosody/en_to_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,41 +340,8 @@ def update(
def convert_hanzi_pinyin(self, hans):
    """Convert Chinese text into a list of regularized pinyin syllables.

    The tone-sandhi handling lives in the module-level ``hans_to_pinyin``
    function so that it can be reused outside this class; this method only
    adds the optional word segmentation in front of it and the pinyin
    regularization behind it.

    Args:
        hans: The Chinese text to convert.

    Returns:
        A list with one ``regularize_pinyin`` result per item produced by
        ``hans_to_pinyin``.
    """
    if self.segment_chinese:
        # The segmenter is called with a one-element batch; take the first
        # (only) result and join its words with single spaces.
        hans = ' '.join(self.wordseg.cut([hans])[0])
    # hans_to_pinyin is a plain module-level function, not a method, so it
    # must not be looked up on ``self``.
    with_sandhi = hans_to_pinyin(hans)
    # Regularize the pinyin
    # (bo -> buo, ju -> jv, lian -> lien, rui -> ruei, sun -> suen, zun -> zuen)
    return [regularize_pinyin(pinyin) for pinyin in with_sandhi]


def find_py_units(self, pinyin_unit):
Expand Down Expand Up @@ -670,3 +637,38 @@ def all_tones(pinyin):
with_tones.extend(p + str(x) for x in range(1,5))
return with_tones


def _align_source_chars(text, pinyin_list):
    """Map every item of ``pinyin_list`` back to the character it came from.

    Returns a list of the same length as ``pinyin_list``. An entry is the
    single source character when the item is a converted syllable, or
    ``None`` when the item is text that was passed through unchanged (such an
    item may span several characters of ``text``).

    The alignment is best-effort: an item is considered "passed through" when
    it literally appears in ``text`` at the current position.
    """
    chars = []
    pos = 0
    for item in pinyin_list:
        if item and text.startswith(item, pos):
            # Unconverted run copied verbatim from the input.
            chars.append(None)
            pos += len(item)
        else:
            chars.append(text[pos] if pos < len(text) else None)
            pos += 1
    return chars


def hans_to_pinyin(hans):
    """Convert Chinese text to tone-numbered pinyin with simple tone sandhi.

    Args:
        hans: A string, or a list of strings (e.g. pre-tokenized ASR output),
            as accepted by ``lazy_pinyin``.

    Returns:
        A list of pinyin items in ``Style.TONE3`` form (tone digit at the
        end, neutral tone written as 5) after applying:

        * alternate third-tone sandhi, so that 3-3-3-3-3 becomes 2-3-2-3-2;
        * 不 -> tone 2 before a tone-4 item;
        * 一 -> tone 2 before a tone-4 item, otherwise tone 4, unless it
          directly follows a Chinese numeral.

        The last item never receives the 不/一 adjustment because it has no
        following item.
    """
    # pypinyin doesn't handle the sandhi properly!
    pinyin_list = lazy_pinyin(hans, Style.TONE3, neutral_tone_with_five=True, tone_sandhi=False)
    num_words = len(pinyin_list)

    # For simplicity convert alternate tone-3 sandhi so that 3-3-3-3-3 becomes 2-3-2-3-2
    # The rules are more complicated but no one implements it correctly
    with_sandhi = []
    i = 0
    while i < num_words:
        pinyin = pinyin_list[i]
        if pinyin.endswith('3'):
            # Collect the whole run of consecutive tone-3 items.
            add = []
            while i < num_words and pinyin_list[i].endswith('3'):
                add.append(pinyin_list[i])
                i += 1
            # Walking backwards from the second-to-last item, raise every
            # other syllable to tone 2; the final one keeps tone 3.
            for j in range(len(add) - 2, -1, -2):
                add[j] = add[j][:-1] + '2'
            with_sandhi.extend(add)
        else:
            with_sandhi.append(pinyin)
            i += 1

    # Adjust for 不 and 一.
    # The pinyin list is not guaranteed to have one item per input character:
    # non-Chinese runs presumably come back as single multi-character items,
    # and ``hans`` may be a list of tokens. Indexing both sequences with the
    # same position can therefore hit the wrong syllable or run off the end,
    # so align the items with their source characters first.
    text = hans if isinstance(hans, str) else ''.join(hans)
    source_chars = _align_source_chars(text, pinyin_list)
    numerals = '〇零一二三四五六七八九十'
    for i in range(len(with_sandhi) - 1):
        han = source_chars[i]
        next_is_tone4 = with_sandhi[i + 1].endswith('4')
        if han == '不' and next_is_tone4:
            with_sandhi[i] = with_sandhi[i][:-1] + '2'
        elif han == '一':
            prev = source_chars[i - 1] if i > 0 else None
            # 一 keeps its citation tone when it follows another numeral.
            if prev is None or prev not in numerals:
                if next_is_tone4:
                    with_sandhi[i] = with_sandhi[i][:-1] + '2'
                else:
                    with_sandhi[i] = with_sandhi[i][:-1] + '4'
    return with_sandhi
69 changes: 59 additions & 10 deletions prosody/eval_cer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"outputs": [],
"source": [
"from pathlib import Path\n",
"import os\n",
"from funasr_onnx import Paraformer\n",
"import jiwer\n",
"import re\n",
Expand All @@ -30,6 +31,7 @@
"PWD = %pwd\n",
"PWD = Path(PWD)\n",
"outputs_dir = PWD / 'outputs'\n",
"os.makedirs(outputs_dir, exist_ok=True)\n",
"jets_dir = outputs_dir / 'tts_train_jets_raw_phn_tacotron_g2p_en_no_space/aishell3'\n",
"nopitch_dir = outputs_dir / 'tts_train_jets_raw_phn_tacotron_g2p_en_no_space/aishell3_nopitch'\n",
"model_dir = (PWD / \"../../paraformer-large/\").resolve()\n",
Expand Down Expand Up @@ -120,11 +122,20 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 2,
"outputs": [],
"source": [
"transcript_file = data_dir / 'test/content.txt'"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"transcript_file = data_dir / 'test/content.txt'\n",
"\n",
"def get_transcripts():\n",
" transcripts = {}\n",
" with open(transcript_file) as f:\n",
Expand All @@ -133,8 +144,17 @@
" transcripts[wav_file] = re.sub(r'[ a-z0-9]', '', transcript)\n",
" return transcripts\n",
"\n",
"transcripts = get_transcripts()\n",
"\n",
"transcripts = get_transcripts()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def eval_wer(transcripts, asr_result_path, wer_path):\n",
" with open(wer_path, 'w') as wer_file:\n",
" wer_file.write('wav_file,gt_len,wer,eng_words\\n')\n",
Expand All @@ -145,8 +165,8 @@
" eng_words = sum([word.isascii() for word in asr_output])\n",
" transcript = transcripts[wav_file]\n",
" gt_len = len(transcript)\n",
" wer = jiwer.wer(reference=' '.join(transcript), hypothesis=' '.join(asr_output))\n",
" wer_file.write(f'{wav_file},{gt_len},{wer},{eng_words}\\n')\n"
" wer = jiwer.wer(truth=' '.join(transcript), hypothesis=' '.join(asr_output))\n",
" wer_file.write(f'{wav_file},{gt_len},{wer},{eng_words}\\n')"
],
"metadata": {
"collapsed": false
Expand Down Expand Up @@ -178,9 +198,38 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"outputs": [],
"source": [
"def eval_cer(transcripts, asr_result_path, cer_path):\n",
" from prosody.en_to_zh import hans_to_pinyin\n",
" with open(cer_path, 'w') as cer_file:\n",
" cer_file.write('wav_file,gt_len,cer,eng_words\\n')\n",
" with open(asr_result_path) as f:\n",
" for line in f:\n",
" wav_file, asr_output = line.strip().split(maxsplit=1)\n",
" asr_output = literal_eval(asr_output)\n",
" eng_words = sum([word.isascii() for word in asr_output])\n",
" transcript = transcripts[wav_file]\n",
" trans_pinyin = ''.join(hans_to_pinyin(transcript))\n",
" gt_len = len(trans_pinyin)\n",
" asr_pinyin = ''.join(hans_to_pinyin(asr_output)).lower()\n",
" cer = jiwer.cer(truth=trans_pinyin, hypothesis=asr_pinyin)\n",
" cer_file.write(f'{wav_file},{gt_len},{cer},{eng_words}\\n')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [],
"source": [
"jets_asr_path = model_dir / 'jets_result.txt'\n",
"jets_cer_path = outputs_dir / 'jets_cer.csv'\n",
"eval_cer(transcripts=transcripts, asr_result_path=jets_asr_path, cer_path=jets_cer_path)"
],
"metadata": {
"collapsed": false
}
Expand Down

0 comments on commit 1d3f98a

Please sign in to comment.