Skip to content

Commit

Permalink
[cli] add first wenet cli (wenet-e2e#2047)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Oct 12, 2023
1 parent e6525ad commit f926af1
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 0 deletions.
15 changes: 15 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[metadata]
name = wenet
version = 0.0.0
license = Apache Software License
description = End to end speech recognition toolkit
long_description = file: README.md
classifiers =
License :: OSI Approved :: Apache Software License
Operating System :: OS Independent
Programming Language :: Python :: 3

[options]
packages = find:
include_package_data = True
python_requires = >= 3.8
15 changes: 15 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from setuptools import setup, find_packages

# Runtime dependencies, pinned to the exact versions listed here.
requirements = [
    "torch==1.10.0",
    "torchaudio==0.10.0",
]

# Register the package and expose `wenet` as a console command that
# dispatches to wenet.cli.transcribe:main.
setup(
    name="wenet",
    packages=find_packages(),
    install_requires=requirements,
    entry_points={
        "console_scripts": [
            "wenet = wenet.cli.transcribe:main",
        ],
    },
)
Empty file added wenet/cli/__init__.py
Empty file.
108 changes: 108 additions & 0 deletions wenet/cli/hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tarfile
from pathlib import Path
from urllib.request import urlretrieve

import tqdm


def download(url: str, dest: str, only_child=True):
    """Download a tar archive from ``url`` into ``dest`` and unpack it there.

    The archive is stored as ``dest/<last "/"-separated component of url>``
    and is left in place after extraction.

    Args:
        url: location of the archive to fetch.
        dest: existing directory receiving the archive and its contents.
        only_child: when True, members stored below a directory of the
            archive are written directly into ``dest`` under their base name
            (the directory structure is flattened) and top-level members are
            skipped. When False, the archive is extracted unchanged.

    Raises:
        AssertionError: if ``dest`` does not exist.
    """
    # Local import: only the streaming copy below needs it.
    import shutil

    assert os.path.exists(dest), 'dest {} does not exist'.format(dest)
    print('Downloading {} to {}'.format(url, dest))

    def progress_hook(t):
        """Adapt progress bar ``t`` to the ``urlretrieve`` reporthook API."""
        # Kept in a list so the nested closure can rebind the value.
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            # b: blocks transferred so far, bsize: block size in bytes,
            # tsize: total size, or None / -1 when it is unknown.
            if tsize not in (None, -1):
                t.total = tsize
            # The bar wants the increment, not the running total.
            displayed = t.update((b - last_b[0]) * bsize)
            last_b[0] = b
            return displayed

        return update_to

    # *.tar.gz
    name = url.split("/")[-1]
    tar_path = os.path.join(dest, name)
    with tqdm.tqdm(unit='B',
                   unit_scale=True,
                   unit_divisor=1024,
                   miniters=1,
                   desc=(name)) as t:
        urlretrieve(url,
                    filename=tar_path,
                    reporthook=progress_hook(t),
                    data=None)
        # Make the bar finish at 100% whatever total was reported.
        t.total = t.n

    with tarfile.open(tar_path) as f:
        if not only_child:
            # NOTE(review): extractall() trusts the member paths stored in
            # the archive; only use this branch with trusted archives.
            f.extractall(dest)
            return

        for tarinfo in f:
            if "/" not in tarinfo.name:
                continue
            # extractfile() returns None for members without a payload
            # (directories, devices, ...); skip them instead of crashing.
            fileobj = f.extractfile(tarinfo)
            if fileobj is None:
                continue
            member_name = os.path.basename(tarinfo.name)
            # Stream the member to disk rather than loading it into memory.
            with fileobj, open(os.path.join(dest, member_name),
                               "wb") as writer:
                shutil.copyfileobj(fileobj, writer)


class Hub(object):
    """Lookup table of pretrained wenet runtime models, keyed by language."""
    # TODO(Mddct): make assets class to support other language
    Assets = {
        # wenetspeech
        "chinese":
        "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/chs.tar.gz",
        # gigaspeech
        "english":
        "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/en.tar.gz",
    }

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_model_by_lang(lang: str) -> str:
        """Return the local model directory for ``lang``, fetching if needed.

        Models are cached in ``<home>/.wenet/<lang>``. The directory counts
        as populated once it contains both ``final.zip`` and ``units.txt``;
        otherwise the archive registered in ``Hub.Assets`` is downloaded and
        unpacked into it.

        Args:
            lang: a key of ``Hub.Assets``.

        Returns:
            Path of the model directory.

        An unsupported language prints an error and exits the process with
        status 1.
        """
        model_url = Hub.Assets.get(lang)
        if model_url is None:
            print('ERROR: Unsupported language {} !!!'.format(lang))
            sys.exit(1)

        # NOTE(Mddct): model_dir structure
        # Path.home()/.wenet
        #   - <lang>
        #       - units.txt
        #       - final.zip
        model_dir = os.path.join(Path.home(), ".wenet", lang)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # TODO(Mddct): model metadata
        required_files = {"final.zip", "units.txt"}
        cached_files = set(os.listdir(model_dir))
        if not required_files <= cached_files:
            download(model_url, model_dir, only_child=True)
        return model_dir
70 changes: 70 additions & 0 deletions wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.utils.common import remove_duplicates_and_blank
from wenet.utils.file_utils import read_symbol_table


def get_args():
    """Parse the command line of the transcription CLI.

    Returns:
        The parsed namespace with ``audio_file`` (required positional) and
        ``language`` (``'chinese'`` or ``'english'``, default ``'chinese'``).
    """
    supported_languages = ['chinese', 'english']

    cli = argparse.ArgumentParser(description='')
    cli.add_argument('audio_file', help='audio file to transcribe')
    cli.add_argument('--language',
                     default='chinese',
                     choices=supported_languages,
                     help='language type')
    return cli.parse_args()


def main():
    """Transcribe one audio file and print the recognized text.

    Steps: parse the command line, make sure the model for the requested
    language is available locally, compute 80-bin fbank features, run the
    TorchScript encoder over the whole utterance, take the best CTC token
    of every frame, collapse repeats/blanks and print the joined symbols.
    """
    args = get_args()
    model_dir = Hub.get_model_by_lang(args.language)
    model_path = os.path.join(model_dir, 'final.zip')
    units_path = os.path.join(model_dir, 'units.txt')
    model = torch.jit.load(model_path)
    symbol_table = read_symbol_table(units_path)
    # Inverted so the decoded ids can be looked up directly.
    char_dict = {v: k for k, v in symbol_table.items()}

    # Feature extraction below is configured for this sample rate.
    # NOTE(review): presumably the rate the released models were trained
    # on - confirm against the model configuration.
    target_rate = 16000
    waveform, sample_rate = torchaudio.load(args.audio_file, normalize=False)
    waveform = waveform.to(torch.float)
    if sample_rate != target_rate:
        # Without resampling, fbank would treat the samples as 16 kHz audio
        # and produce features on the wrong frequency scale.
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate)(waveform)
    feats = kaldi.fbank(waveform,
                        num_mel_bins=80,
                        frame_length=25,
                        frame_shift=10,
                        energy_floor=0.0,
                        sample_frequency=target_rate)
    # Add a leading batch dimension.
    feats = feats.unsqueeze(0)

    # Pure inference: no autograd bookkeeping needed.
    with torch.no_grad():
        encoder_out, _, _ = model.forward_encoder_chunk(feats, 0, -1)
        ctc_probs = model.ctc_activation(encoder_out)
    _, topk_index = ctc_probs.topk(1, dim=2)
    # reshape(-1) always yields a list, even for a single output frame,
    # where squeeze() would collapse to a scalar.
    topk_index = topk_index.reshape(-1).tolist()
    hyp = remove_duplicates_and_blank(topk_index)
    result = ''.join(char_dict[x] for x in hyp)
    print(result)


if __name__ == "__main__":
    main()

0 comments on commit f926af1

Please sign in to comment.