Skip to content

Commit

Permalink
week3 (#19)
Browse files Browse the repository at this point in the history
* [week03] seminar added

Co-authored-by: Timofey Smirnov <timothyxp.mail.ru>
  • Loading branch information
timothyxp authored Sep 29, 2022
1 parent 6db8f4a commit 05f5c78
Show file tree
Hide file tree
Showing 7 changed files with 989 additions and 0 deletions.
Binary file added week03/LJ001-0001.wav
Binary file not shown.
Binary file added week03/h001_Bedroom_65.wav
Binary file not shown.
Binary file added week03/lj_batch.pickle
Binary file not shown.
Binary file added week03/mystery_records.pickle
Binary file not shown.
5 changes: 5 additions & 0 deletions week03/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pytorch
torchaudio
matplotlib
librosa
editdistance
757 changes: 757 additions & 0 deletions week03/seminar03_1.ipynb

Large diffs are not rendered by default.

227 changes: 227 additions & 0 deletions week03/seminar03_2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9bc3c8b2-35fb-44c3-800e-ef9bc13a9231",
"metadata": {},
"source": [
"# Decoding CTC output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b923207-7a92-47f3-9f8f-1448b2ba4285",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import torch\n",
"\n",
"\n",
"# Load precomputed CTC output\n",
"with open('mystery_records.pickle', 'rb') as f:\n",
" batch = pickle.load(f)\n",
"\n",
"# log probabilities of softmax layers [batch_size, T, vocab_size]\n",
"log_probs = batch[\"log_probs\"]\n",
"\n",
"# Dictionary with index to character mapping\n",
"ind2char = batch[\"ind2char\"]\n",
"\n",
"# Index of special EMPTY token\n",
"EMPTY_TOK = '^'\n",
"EMPTY_IND = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "922cbf65-fbaf-48d5-8605-ea41c3f80590",
"metadata": {},
"outputs": [],
"source": [
"def ctc_decode(inds, ind2char):\n",
" # your code here\n",
" raise NotImplementedError\n",
"\n",
"for i, rec in enumerate(log_probs):\n",
" text = ctc_decode(rec.argmax(-1).numpy(), ind2char)\n",
" print(f\"{i}) {text}\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "3a8c4d50-e633-4e85-8842-a6b50602b70f",
"metadata": {},
"source": [
"# Computing WER and CER\n",
"Task: Implemet WER and CER metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ca11f70-ee02-4765-b542-96186781a0b8",
"metadata": {},
"outputs": [],
"source": [
"# library for fast quick calculation of edit distance\n",
"import editdistance\n",
"\n",
"def calc_wer(target_text: str, pred_text: str):\n",
" # your code here\n",
" raise NotImplementedError\n",
" \n",
"\n",
"def calc_cer(target_text: str, pred_text: str):\n",
" # your code here\n",
" raise NotImplementedError\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c391511-7469-4ed8-bd26-057c4fde4717",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"for target, pred, expected_wer, expected_cer in [\n",
" (\"if you can not measure it you can not improve it\", \n",
" \"if you can nt measure t yo can not i\", \n",
" 0.454, 0.25),\n",
" (\"if you cant describe what you are doing as a process you dont know what youre doing\", \n",
" \"if you cant describe what you are doing as a process you dont know what youre doing\", \n",
" 0.0, 0.0),\n",
" (\"one measurement is worth a thousand expert opinions\", \n",
" \"one is worth thousand opinions\", \n",
" 0.375, 0.392)\n",
"]:\n",
" wer = calc_wer(target, pred)\n",
" cer = calc_cer(target, pred)\n",
" assert np.isclose(wer, expected_wer, atol=1e-3), f\"true: {target}, pred: {pred}, expected wer {expected_wer} != your wer {wer}\"\n",
" assert np.isclose(cer, expected_cer, atol=1e-3), f\"true: {target}, pred: {pred}, expected cer {expected_cer} != your cer {cer}\"\n",
" "
]
},
{
"cell_type": "markdown",
"id": "7cefd76b-66d4-4b1e-ae1d-be6b7336a160",
"metadata": {},
"source": [
"Task: come up with such a pair of target-prediction texts, so the\n",
"1) WER > 1.0\n",
"2) CER > WER"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11bceaaf-7b17-466b-ac17-855e4d54cf56",
"metadata": {},
"outputs": [],
"source": [
"# 1) WER > 1.0\n",
"# your code here\n",
"target, prediction = \"<...>\" , \"<...>\"\n",
"assert calc_wer(target, prediction) > 1.0\n",
"\n",
"# 2) CER > WER\n",
"# your code here\n",
"target, prediction = \"<...>\", \"<...>\"\n",
"assert calc_wer(target, prediction) < calc_cer(target, prediction) \n"
]
},
{
"cell_type": "markdown",
"id": "31a1fb97-4853-4190-835d-31ead094679c",
"metadata": {},
"source": [
"# Beam search\n",
"Task: implement beam-search on CTC outputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8e1c37a-93be-47a1-8211-9b47d0721d23",
"metadata": {},
"outputs": [],
"source": [
"# Load precomputed CTC output\n",
"with open('lj_batch.pickle', 'rb') as f:\n",
" batch = pickle.load(f)\n",
"\n",
"# log probabilities of softmax layers [batch_size, T, vocab_size]\n",
"log_probs = batch[\"log_probs\"]\n",
"\n",
"# Dictionary with index to character mapping\n",
"ind2char = batch[\"ind2char\"]\n",
"\n",
"true_texts = batch[\"text\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ae1f264-33cb-4c4d-b959-823d07843936",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"from tqdm import tqdm\n",
"\n",
"def ctc_beam_search(probs, beam_size, ind2char):\n",
" # your code here\n",
" raise NotImplementedError\n",
"\n",
"bs_results = []\n",
"for log_probs_line in log_probs:\n",
" bs_results.append(ctc_beam_search(log_probs_line.exp().numpy(), 100, ind2char))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e6d7249-aed1-4ff3-8ce2-20978320ac7d",
"metadata": {},
"outputs": [],
"source": [
"\n",
"for i in range(len(true_texts)):\n",
" beam_search_hypos = bs_results[i][:3]\n",
" true_text = true_texts[i]\n",
" argmax_text = ctc_decode(log_probs[i].numpy().argmax(-1), ind2char)\n",
" print(\"True: \", true_text)\n",
" print(f\"Argmax: {argmax_text} --- (CER: {calc_cer(true_text, argmax_text):.3f})\")\n",
" for ind, (hypo, score) in enumerate(beam_search_hypos):\n",
" print(f\"{ind+1}) '{hypo}' --- (CER: {calc_cer(true_text, hypo):.3f})\")\n",
" print('-' * 100)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "HW-ASR",
"language": "python",
"name": "hw-asr"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 05f5c78

Please sign in to comment.