Add seq length and arbitrary blank index support
ryanleary authored and Ryan Leary committed Jun 19, 2017
1 parent 3d4452a commit 05886dd
Showing 6 changed files with 38 additions and 23 deletions.
pytorch_ctc/__init__.py (30 changes: 22 additions & 8 deletions)
@@ -1,26 +1,40 @@
import torch
import pytorch_ctc as ctc
from torch.utils.ffi import _wrap_function
from ._ctc_decode import lib as _lib, ffi as _ffi

from cffi import FFI
ffi = FFI()
from ._ext.ctc_decode import ctc_beam_decode
__all__ = []


def beam_decode(probs, seq_len=None, top_paths=1, beam_width=10, merge_repeated=True):
def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib, symbol)
locals[symbol] = _wrap_function(fn, _ffi)
__all__.append(symbol)


_import_symbols(locals())


def beam_decode(probs, seq_len=None, top_paths=1, beam_width=10, blank_index=0, merge_repeated=True):
prob_size = probs.size()
max_seq_len = prob_size[0]
batch_size = prob_size[1]
num_classes = prob_size[2]

if seq_len is not None and batch_size != seq_len.size():
raise ValueError("seq_len shape must be a 1xbatch_size tensor or None")
if blank_index < 0 or blank_index >= num_classes:
raise ValueError("blank_index must be within num_classes")
if seq_len is not None and batch_size != seq_len.size(0):
raise ValueError("seq_len shape must be a (batch_size) tensor or None")
if top_paths < 1 or top_paths > beam_width:
raise ValueError("top_paths must be greater than 1 and less than or equal to the beam_width")

seq_len = torch.IntTensor(batch_size).zero_().add_(max_seq_len) if seq_len is None else seq_len
output = torch.IntTensor(top_paths, batch_size, max_seq_len)
scores = torch.FloatTensor(top_paths, batch_size)
out_seq_len = torch.IntTensor(top_paths, batch_size)

merge_int = 1 if merge_repeated else 0
result = ctc_beam_decode(probs, seq_len, output, scores, top_paths, beam_width, merge_int)
result = ctc.ctc_beam_decode(probs, seq_len, output, scores, out_seq_len, top_paths, beam_width, blank_index, merge_int)

return output, scores
return output, scores, out_seq_len
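
The new keyword arguments change the Python call surface: seq_len lets each batch entry carry its own valid length, blank_index picks which class is treated as the CTC blank, and a third return value reports how many labels each decoded path contains. A minimal usage sketch against the new signature (shapes follow the checks above; the tensor values and the 29-class label set are made up for illustration):

    import torch
    import pytorch_ctc

    # probs: (max_seq_len, batch_size, num_classes), time-major network outputs
    # (values here are random placeholders)
    probs = torch.FloatTensor(50, 4, 29).uniform_()
    # per-utterance valid lengths, shape (batch_size,)
    seq_len = torch.IntTensor([50, 42, 37, 50])

    # blank_index=28 assumes the blank was placed last in this hypothetical label set
    output, scores, out_seq_len = pytorch_ctc.beam_decode(
        probs, seq_len=seq_len, top_paths=3, beam_width=25,
        blank_index=28, merge_repeated=True)

    # output:      (top_paths, batch_size, max_seq_len) padded label ids
    # scores:      (top_paths, batch_size) path scores
    # out_seq_len: (top_paths, batch_size) number of valid labels per path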
pytorch_ctc/src/cpu_binding.cpp (9 changes: 4 additions & 5 deletions)
@@ -11,7 +11,8 @@ namespace pytorch {
extern "C"
{
int ctc_beam_decode(THFloatTensor *th_probs, THIntTensor *th_seq_len, THIntTensor *th_output,
THFloatTensor *th_scores, int top_paths, int beam_width, int merge_repeated)
THFloatTensor *th_scores, THIntTensor *th_out_len, int top_paths,
int beam_width, int blank_index, int merge_repeated)
{
const int64_t max_time = THFloatTensor_size(th_probs, 0);
const int64_t batch_size = THFloatTensor_size(th_probs, 1);
@@ -47,7 +48,7 @@

// initialize beam search class
ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width, &beam_scorer,
batch_size, merge_repeated == 1);
batch_size, blank_index, merge_repeated == 1);

ctc::Status stat = beam_search.Decode(seq_len, inputs, &outputs, &scores);
if (!stat.ok()) {
@@ -63,14 +64,12 @@ namespace pytorch {
int64_t num_decoded = p_batch.size();

max_decoded = std::max(max_decoded, num_decoded);
THIntTensor_set2d(th_out_len, p, b, num_decoded);
for (int64_t t=0; t < num_decoded; ++t) {
// TODO: this could be more efficient (significant pointer arithmetic every time currently)
THIntTensor_set3d(th_output, p, b, t, p_batch[t]);
THFloatTensor_set2d(th_scores, p, b, scores(b, p));
}
for (int64_t t = num_decoded; t < max_time; ++t) {
THIntTensor_set3d(th_output, p, b, t, -1);
}
}
}
return 1;
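
With the loop that padded the tail of each decoded path with -1 removed, th_output no longer carries an end-of-sequence marker; the per-path length now comes from th_out_len, and entries past that length are left unspecified. A hedged sketch of how a caller might trim the padded output on the Python side (the helper name is illustrative, not part of the library):

    def trim_decoded(output, out_seq_len, path=0, batch=0):
        # output:      IntTensor (top_paths, batch_size, max_seq_len), padded
        # out_seq_len: IntTensor (top_paths, batch_size), valid label counts
        n = int(out_seq_len[path][batch])
        return output[path][batch][:n].tolist()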
pytorch_ctc/src/cpu_binding.h (3 changes: 2 additions & 1 deletion)
@@ -1,2 +1,3 @@
int ctc_beam_decode(THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, int top_paths, int beam_width, int merge_repeated);
THFloatTensor *scores, THIntTensor *th_out_len, int top_paths,
int beam_width, int blank_index, int merge_repeated);
pytorch_ctc/src/ctc_beam_search.h (8 changes: 4 additions & 4 deletions)
@@ -79,8 +79,8 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// standard beam search.
CTCBeamSearchDecoder(int num_classes, int beam_width,
BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
bool merge_repeated = false)
: CTCDecoder(num_classes, batch_size, merge_repeated),
int blank_index = 0, bool merge_repeated = false)
: CTCDecoder(num_classes, batch_size, blank_index, merge_repeated),
beam_width_(beam_width),
leaves_(beam_width),
// TODO: ADD CHECK_NOTNULL BACK
@@ -296,7 +296,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
}

if (!b->HasChildren()) {
b->PopulateChildren(num_classes_ - 1);
b->PopulateChildren(num_classes_);
}

for (BeamEntry& c : *b->Children()) {
@@ -345,7 +345,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {

// This beam root, and all of its children, will be in memory until
// the next reset.
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1));
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_, -1));
beam_root_->newp.total = 0.0; // ln(1)
beam_root_->newp.blank = 0.0; // ln(1)

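
PopulateChildren now receives num_classes_ rather than num_classes_ - 1: the old value worked only because the blank was assumed to be the last class and could be dropped by shrinking the range. A small Python sketch of the label bookkeeping this changes (just to show the sets involved; the real handling lives in the C++ Step loop):

    num_classes = 29

    # before: blank assumed to be the last class, children covered 0..num_classes-2
    old_children = list(range(num_classes - 1))

    # after: blank_index may be any class, so the full label range is populated and
    # the blank has to be identified by value rather than by position
    blank_index = 0
    new_children = list(range(num_classes))
    non_blank = [label for label in new_children if label != blank_index]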
pytorch_ctc/src/ctc_decoder.h (8 changes: 4 additions & 4 deletions)
@@ -34,9 +34,9 @@ class CTCDecoder {
typedef std::vector<std::vector<int>> Output;
typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;

CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
CTCDecoder(int num_classes, int batch_size, int blank_index, bool merge_repeated)
: num_classes_(num_classes),
blank_index_(num_classes - 1),
blank_index_(blank_index),
batch_size_(batch_size),
merge_repeated_(merge_repeated) {}

@@ -65,8 +65,8 @@ class CTCDecoder {
// algorithm, selecting at each timestep the most likely class at each timestep.
class CTCGreedyDecoder : public CTCDecoder {
public:
CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
: CTCDecoder(num_classes, batch_size, merge_repeated) {}
CTCGreedyDecoder(int num_classes, int batch_size, int blank_index, bool merge_repeated)
: CTCDecoder(num_classes, batch_size, blank_index, merge_repeated) {}

Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
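
CTCDecoder and CTCGreedyDecoder now take the blank index as a constructor argument instead of hard-coding num_classes - 1. As a reference for what an arbitrary blank index means, here is a toy greedy CTC collapse in Python (not the library's C++ implementation, just the standard collapse-repeats-then-drop-blanks rule parameterized the same way):

    def greedy_ctc_collapse(best_path, blank_index, merge_repeated=True):
        # best_path: list of per-frame argmax labels
        out = []
        prev = None
        for label in best_path:
            if merge_repeated and label == prev:
                continue
            if label != blank_index:
                out.append(label)
            prev = label
        return out

    # blank at index 0 instead of the last class:
    # [0, 3, 3, 0, 5, 5, 5, 0] -> [3, 5]
    assert greedy_ctc_collapse([0, 3, 3, 0, 5, 5, 5, 0], blank_index=0) == [3, 5]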
setup.py (3 changes: 2 additions & 1 deletion)
@@ -10,7 +10,7 @@
headers = ['pytorch_ctc/src/cpu_binding.h']

ffi = create_extension(
name='pytorch_ctc._ext.ctc_decode',
name='ctc_decode',
package=True,
language='c++',
headers=headers,
@@ -19,6 +19,7 @@
extra_compile_args=['-std=c++11', '-fPIC', '-w']
)
ffi = ffi.distutils_extension()
ffi.name = 'pytorch_ctc._ctc_decode'

setup(
name="pytorch_ctc",
