-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [week03] seminar added Co-authored-by: Timofey Smirnov <timothyxp.mail.ru>
- Loading branch information
Showing
7 changed files
with
989 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pytorch | ||
torchaudio | ||
matplotlib | ||
librosa | ||
editdistance |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9bc3c8b2-35fb-44c3-800e-ef9bc13a9231", | ||
"metadata": {}, | ||
"source": [ | ||
"# Decoding CTC output" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4b923207-7a92-47f3-9f8f-1448b2ba4285", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pickle\n", | ||
"import torch\n", | ||
"\n", | ||
"\n", | ||
"# Load precomputed CTC output\n", | ||
"with open('mystery_records.pickle', 'rb') as f:\n", | ||
" batch = pickle.load(f)\n", | ||
"\n", | ||
"# log probabilities of softmax layers [batch_size, T, vocab_size]\n", | ||
"log_probs = batch[\"log_probs\"]\n", | ||
"\n", | ||
"# Dictionary with index to character mapping\n", | ||
"ind2char = batch[\"ind2char\"]\n", | ||
"\n", | ||
"# Special EMPTY token and its index\n", | ||
"EMPTY_TOK = '^'\n", | ||
"EMPTY_IND = 0" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "922cbf65-fbaf-48d5-8605-ea41c3f80590", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def ctc_decode(inds, ind2char):\n", | ||
" # your code here\n", | ||
" raise NotImplementedError\n", | ||
"\n", | ||
"for i, rec in enumerate(log_probs):\n", | ||
" text = ctc_decode(rec.argmax(-1).numpy(), ind2char)\n", | ||
" print(f\"{i}) {text}\")\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3a8c4d50-e633-4e85-8842-a6b50602b70f", | ||
"metadata": {}, | ||
"source": [ | ||
"# Computing WER and CER\n", | ||
"Task: Implement WER and CER metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0ca11f70-ee02-4765-b542-96186781a0b8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# library for fast calculation of edit distance\n", | ||
"import editdistance\n", | ||
"\n", | ||
"def calc_wer(target_text: str, pred_text: str):\n", | ||
" # your code here\n", | ||
" raise NotImplementedError\n", | ||
" \n", | ||
"\n", | ||
"def calc_cer(target_text: str, pred_text: str):\n", | ||
" # your code here\n", | ||
" raise NotImplementedError\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6c391511-7469-4ed8-bd26-057c4fde4717", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"for target, pred, expected_wer, expected_cer in [\n", | ||
" (\"if you can not measure it you can not improve it\", \n", | ||
" \"if you can nt measure t yo can not i\", \n", | ||
" 0.454, 0.25),\n", | ||
" (\"if you cant describe what you are doing as a process you dont know what youre doing\", \n", | ||
" \"if you cant describe what you are doing as a process you dont know what youre doing\", \n", | ||
" 0.0, 0.0),\n", | ||
" (\"one measurement is worth a thousand expert opinions\", \n", | ||
" \"one is worth thousand opinions\", \n", | ||
" 0.375, 0.392)\n", | ||
"]:\n", | ||
" wer = calc_wer(target, pred)\n", | ||
" cer = calc_cer(target, pred)\n", | ||
" assert np.isclose(wer, expected_wer, atol=1e-3), f\"true: {target}, pred: {pred}, expected wer {expected_wer} != your wer {wer}\"\n", | ||
" assert np.isclose(cer, expected_cer, atol=1e-3), f\"true: {target}, pred: {pred}, expected cer {expected_cer} != your cer {cer}\"\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7cefd76b-66d4-4b1e-ae1d-be6b7336a160", | ||
"metadata": {}, | ||
"source": [ | ||
"Task: for each condition below, come up with a pair of target and prediction texts that satisfies it:\n", | ||
"1) WER > 1.0\n", | ||
"2) CER > WER" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "11bceaaf-7b17-466b-ac17-855e4d54cf56", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 1) WER > 1.0\n", | ||
"# your code here\n", | ||
"target, prediction = \"<...>\" , \"<...>\"\n", | ||
"assert calc_wer(target, prediction) > 1.0\n", | ||
"\n", | ||
"# 2) CER > WER\n", | ||
"# your code here\n", | ||
"target, prediction = \"<...>\", \"<...>\"\n", | ||
"assert calc_wer(target, prediction) < calc_cer(target, prediction) \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "31a1fb97-4853-4190-835d-31ead094679c", | ||
"metadata": {}, | ||
"source": [ | ||
"# Beam search\n", | ||
"Task: implement beam-search on CTC outputs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f8e1c37a-93be-47a1-8211-9b47d0721d23", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load precomputed CTC output\n", | ||
"with open('lj_batch.pickle', 'rb') as f:\n", | ||
" batch = pickle.load(f)\n", | ||
"\n", | ||
"# log probabilities of softmax layers [batch_size, T, vocab_size]\n", | ||
"log_probs = batch[\"log_probs\"]\n", | ||
"\n", | ||
"# Dictionary with index to character mapping\n", | ||
"ind2char = batch[\"ind2char\"]\n", | ||
"\n", | ||
"true_texts = batch[\"text\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9ae1f264-33cb-4c4d-b959-823d07843936", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from collections import defaultdict\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"def ctc_beam_search(probs, beam_size, ind2char):\n", | ||
" # your code here\n", | ||
" raise NotImplementedError\n", | ||
"\n", | ||
"bs_results = []\n", | ||
"for log_probs_line in log_probs:\n", | ||
" bs_results.append(ctc_beam_search(log_probs_line.exp().numpy(), 100, ind2char))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9e6d7249-aed1-4ff3-8ce2-20978320ac7d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"for i in range(len(true_texts)):\n", | ||
" beam_search_hypos = bs_results[i][:3]\n", | ||
" true_text = true_texts[i]\n", | ||
" argmax_text = ctc_decode(log_probs[i].numpy().argmax(-1), ind2char)\n", | ||
" print(\"True: \", true_text)\n", | ||
" print(f\"Argmax: {argmax_text} --- (CER: {calc_cer(true_text, argmax_text):.3f})\")\n", | ||
" for ind, (hypo, score) in enumerate(beam_search_hypos):\n", | ||
" print(f\"{ind+1}) '{hypo}' --- (CER: {calc_cer(true_text, hypo):.3f})\")\n", | ||
" print('-' * 100)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "HW-ASR", | ||
"language": "python", | ||
"name": "hw-asr" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |