forked from parlance/ctcdecode

commit 06af477
SeanNaren committed Nov 17, 2017
1 parent: 2d7ae35

Showing 184 changed files with 1,523 additions and 59,249 deletions.
.gitmodules
@@ -1,3 +1,6 @@
-[submodule "pytorch_ctc/src/third_party/kenlm"]
+[submodule "third_party/kenlm"]
 	path = third_party/kenlm
-	url = https://github.com/kpu/kenlm.git
+	url = https://github.com/luotao1/kenlm.git
+[submodule "third_party/ThreadPool"]
+	path = third_party/ThreadPool
+	url = https://github.com/progschj/ThreadPool.git
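The KenLM submodule moves to luotao1's fork and a ThreadPool submodule is added; the build below assumes both are checked out. A minimal sketch for fetching them after cloning, assuming git is on PATH and the working directory is the repository root:

# Sketch: fetch the submodules declared in .gitmodules above.
# Assumes `git` is installed and this runs from the repository root.
import subprocess

subprocess.check_call(['git', 'submodule', 'update', '--init', '--recursive'])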
build.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+
+import glob
+import os
+import tarfile
+
+import wget
+from torch.utils.ffi import create_extension
+
+# Download/Extract openfst
+dl_path = 'third_party/openfst-1.6.3.tar.gz'
+if not os.path.isfile(dl_path):
+    wget.download('http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz',
+                  out=dl_path)
+tar = tarfile.open(dl_path)
+tar.extractall('third_party/')
+tar.close()
+
+
+# Does gcc compile with this header and library?
+def compile_test(header, library):
+    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
+    command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path \
+              + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
+    return os.system(command) == 0
+
+
+compile_args = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-fPIC', '-w']
+ext_libs = ['stdc++']
+
+if compile_test('zlib.h', 'z'):
+    compile_args.append('-DHAVE_ZLIB')
+    ext_libs.append('z')
+
+if compile_test('bzlib.h', 'bz2'):
+    compile_args.append('-DHAVE_BZLIB')
+    ext_libs.append('bz2')
+
+if compile_test('lzma.h', 'lzma'):
+    compile_args.append('-DHAVE_XZLIB')
+    ext_libs.append('lzma')
+
+third_party_libs = ["kenlm", "openfst-1.6.3/src/include", "ThreadPool"]
+compile_args.extend(['-DINCLUDE_KENLM', '-DKENLM_MAX_ORDER=6'])
+lib_sources = glob.glob('third_party/kenlm/util/*.cc') + glob.glob('third_party/kenlm/lm/*.cc') + glob.glob(
+    'third_party/kenlm/util/double-conversion/*.cc') + glob.glob('third_party/openfst-1.6.3/src/lib/*.cc')
+lib_sources = [fn for fn in lib_sources if not (fn.endswith('main.cc') or fn.endswith('test.cc'))]
+
+third_party_includes = ["third_party/" + lib for lib in third_party_libs]
+ctc_sources = glob.glob('ctcdecode/src/*.cpp')
+ctc_headers = ['ctcdecode/src/binding.h', ]
+
+ffi = create_extension(
+    name='ctcdecode._ext.ctc_decode',
+    package=True,
+    language='c++',
+    headers=ctc_headers,
+    sources=ctc_sources + lib_sources,
+    include_dirs=third_party_includes,
+    with_cuda=False,
+    libraries=ext_libs,
+    extra_compile_args=compile_args
+)
+
+if __name__ == '__main__':
+    ffi.build()
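The compile_test helper above decides which optional compression libraries (zlib, bzip2, lzma) KenLM gets compiled against. A standalone sketch of the same probe, runnable on its own; the /tmp/feature_probe path and the __main__ loop are illustrative, not part of this commit:

# Probe whether g++ can compile an empty program that includes `header`
# and links `library`; exit status 0 means the dependency is usable here.
import os

def compile_test(header, library):
    dummy_path = "/tmp/feature_probe"
    command = ("bash -c \"g++ -include " + header + " -l" + library +
               " -x c++ - <<<'int main() {}' -o " + dummy_path +
               " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\"")
    return os.system(command) == 0

if __name__ == '__main__':
    for header, library in [('zlib.h', 'z'), ('bzlib.h', 'bz2'), ('lzma.h', 'lzma')]:
        print(header, 'found' if compile_test(header, library) else 'missing')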
ctcdecode/__init__.py
@@ -1,139 +1,48 @@
-import torch
-import ctcdecode as ctc
-from torch.utils.ffi import _wrap_function
 from ._ext import ctc_decode
-# from ._ext._ctc_decode import lib as _lib, ffi as _ffi
-#
-# __all__ = []
-#
-#
-# def _import_symbols(locals):
-#     for symbol in dir(_lib):
-#         fn = getattr(_lib, symbol)
-#         new_symbol = "_" + symbol
-#         locals[new_symbol] = _wrap_function(fn, _ffi)
-#         __all__.append(new_symbol)
-#
-#
-# _import_symbols(locals())
+import torch
 
 
-class BaseCTCBeamDecoder(object):
-    def __init__(self, labels, top_paths=1, beam_width=10, blank_index=0, space_index=28):
-        self._labels = labels
-        self._top_paths = top_paths
+class CTCBeamDecoder(object):
+    def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
+                 num_processes=4, blank_id=0):
+        self.cutoff_top_n = cutoff_top_n
         self._beam_width = beam_width
-        self._blank_index = blank_index
-        self._space_index = space_index
-        self._num_classes = len(labels)
-        self._decoder = None
-
-        if blank_index < 0 or blank_index >= self._num_classes:
-            raise ValueError("blank_index must be within num_classes")
-
-        if top_paths < 1 or top_paths > beam_width:
-            raise ValueError("top_paths must be greater than 1 and less than or equal to the beam_width")
-
-    def decode(self, probs, seq_len=None):
-        prob_size = probs.size()
-        max_seq_len = prob_size[0]
-        batch_size = prob_size[1]
-        num_classes = prob_size[2]
-
-        if seq_len is not None and batch_size != seq_len.size(0):
-            raise ValueError("seq_len shape must be a (batch_size) tensor or None")
-
-        seq_len = torch.IntTensor(batch_size).zero_().add_(max_seq_len) if seq_len is None else seq_len
-        output = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
-        scores = torch.FloatTensor(self._top_paths, batch_size)
-        out_seq_len = torch.IntTensor(self._top_paths, batch_size)
-        alignments = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
-        char_probs = torch.FloatTensor(self._top_paths, batch_size, max_seq_len)
-
-        result = ctc_decode.ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
-                                            alignments, char_probs)
-
-        return output, scores, out_seq_len, alignments, char_probs
-
-
-class BaseScorer(object):
-    def __init__(self):
-        self._scorer_type = 0
         self._scorer = None
-
-    def get_scorer_type(self):
-        return self._scorer_type
-
-    def get_scorer(self):
-        return self._scorer
-
-
-class Scorer(BaseScorer):
-    def __init__(self):
-        super(Scorer, self).__init__()
-        self._scorer = ctc_decode.get_base_scorer()
-
-
-class DictScorer(BaseScorer):
-    def __init__(self, labels, trie_path, blank_index=0, space_index=28):
-        super(DictScorer, self).__init__()
-        self._scorer_type = 1
-        self._scorer = ctc_decode.get_dict_scorer(labels, len(labels), space_index, blank_index, trie_path.encode())
-
-    def set_min_unigram_weight(self, weight):
-        if weight is not None:
-            ctc_decode.set_dict_min_unigram_weight(self._scorer, weight)
-
-
-class KenLMScorer(BaseScorer):
-    def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
-        super(KenLMScorer, self).__init__()
-        if ctc_decode.kenlm_enabled() != 1:
-            raise ImportError("ctcdecode not compiled with KenLM support.")
-        self._scorer_type = 2
-        self._scorer = ctc_decode.get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(),
-                                                   trie_path.encode())
-
-    # This is a way to make sure the destructor is called for the C++ object
-    # Frees all the member data items that have allocated memory
-    def __del__(self):
-        ctc_decode.free_kenlm_scorer(self._scorer)
-
-    def set_lm_weight(self, weight):
-        if weight is not None:
-            ctc_decode.set_kenlm_scorer_lm_weight(self._scorer, weight)
-
-    def set_word_weight(self, weight):
-        if weight is not None:
-            ctc_decode.set_kenlm_scorer_wc_weight(self._scorer, weight)
-
-    def set_min_unigram_weight(self, weight):
-        if weight is not None:
-            ctc_decode.set_kenlm_min_unigram_weight(self._scorer, weight)
-
-
-class CTCBeamDecoder(BaseCTCBeamDecoder):
-    def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, space_index=28):
-        super(CTCBeamDecoder, self).__init__(labels, top_paths=top_paths, beam_width=beam_width,
-                                             blank_index=blank_index, space_index=space_index)
-        self._scorer = scorer
-        self._decoder_type = self._scorer.get_scorer_type()
-        self._decoder = ctc_decode.get_ctc_beam_decoder(self._num_classes, top_paths, beam_width, blank_index,
-                                                        self._scorer.get_scorer(), self._decoder_type)
-
-    def set_label_selection_parameters(self, label_size=0, label_margin=-1):
-        ctc_decode.set_label_selection_parameters(self._decoder, label_size, label_margin)
-
-
-def generate_lm_dict(dictionary_path, output_path, labels, kenlm_path=None, blank_index=0, space_index=28):
-    if kenlm_path is not None and ctc_decode.kenlm_enabled() != 1:
-        raise ImportError("ctcdecode not compiled with KenLM support.")
-    result = None
-    if kenlm_path is not None:
-        result = ctc_decode.generate_lm_dict(labels, len(labels), blank_index, space_index, kenlm_path.encode(),
-                                             dictionary_path.encode(), output_path.encode())
-    else:
-        result = ctc_decode.generate_dict(labels, len(labels), blank_index, space_index,
-                                          dictionary_path.encode(), output_path.encode())
-    if result != 0:
-        raise ValueError("Error encountered generating dictionary")
+        self._num_processes = num_processes
+        self._labels = ''.join(labels).encode()
+        self._blank_id = blank_id
+        if model_path:
+            self._scorer = ctc_decode.paddle_get_scorer(alpha, beta, model_path.encode(), self._labels,
+                                                        len(self._labels))
+        self._cutoff_prob = cutoff_prob
+
+    def decode(self, probs):
+        # We expect batch x seq x label_size
+        probs = probs.cpu().float()
+        batch_size, max_seq_len = probs.size(0), probs.size(1)
+        output = torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int()
+        scores = torch.IntTensor(batch_size, self._beam_width).cpu().int()
+        out_seq_len = torch.IntTensor(batch_size, self._beam_width).cpu().int()
+        if self._scorer:
+            ctc_decode.paddle_beam_decode_lm(probs, self._labels, len(self._labels), self._beam_width,
+                                             self._num_processes, self._cutoff_prob, self.cutoff_top_n, self._blank_id,
+                                             self._scorer, output, scores, out_seq_len)
+        else:
+            ctc_decode.paddle_beam_decode(probs, self._labels, len(self._labels), self._beam_width, self._num_processes,
+                                          self._cutoff_prob, self.cutoff_top_n, self._blank_id, output, scores,
+                                          out_seq_len)
+
+        return output, scores, out_seq_len
+
+    def character_based(self):
+        return ctc_decode.is_character_based(self._scorer) if self._scorer else None
+
+    def max_order(self):
+        return ctc_decode.get_max_order(self._scorer) if self._scorer else None
+
+    def dict_size(self):
+        return ctc_decode.get_dict_size(self._scorer) if self._scorer else None
+
+    def reset_params(self, alpha, beta):
+        if self._scorer is not None:
+            ctc_decode.reset_params(self._scorer, alpha, beta)
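The new API drops the Scorer classes and the top_paths/space_index parameters: an optional KenLM binary path plus alpha/beta weights go straight into the constructor, and decode now takes batch-first (batch x seq x num_labels) probabilities. A hypothetical usage sketch; the label set and random input below are made up for illustration:

# Hypothetical usage of the new CTCBeamDecoder; labels and input are illustrative.
import torch
from ctcdecode import CTCBeamDecoder

labels = ['_', 'a', 'b', 'c', ' ']  # '_' is the CTC blank, matching blank_id=0
decoder = CTCBeamDecoder(labels, beam_width=100, blank_id=0)

# decode expects batch x seq x num_labels probabilities (e.g. softmax output).
probs = torch.nn.functional.softmax(torch.randn(1, 50, len(labels)), dim=-1)
output, scores, out_seq_len = decoder.decode(probs)

# Label indices of the best beam for the first utterance, trimmed to length.
best = output[0][0][:out_seq_len[0][0]]
print(''.join(labels[i] for i in best))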