Skip to content

Commit

Permalink
Refactor for class-based decode
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanleary committed Jun 26, 2017
1 parent 4f0101c commit 7094aa0
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 424 deletions.
12 changes: 6 additions & 6 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib, symbol)
locals[symbol] = _wrap_function(fn, _ffi)
__all__.append(symbol)
new_symbol = "_" + symbol
locals[new_symbol] = _wrap_function(fn, _ffi)
__all__.append(new_symbol)


_import_symbols(locals())
Expand All @@ -35,11 +36,10 @@ def beam_decode(probs, labels, seq_len=None, top_paths=1, beam_width=10, blank_i
out_seq_len = torch.IntTensor(top_paths, batch_size)

merge_int = 1 if merge_repeated else 0
scorer = _lib.get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(), trie_path.encode())
scorer = ctc._get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(), trie_path.encode())
decoder = ctc._get_ctc_beam_decoder(num_classes, batch_size, top_paths, beam_width, blank_index, merge_int, scorer, 1)
print(scorer)
result = ctc.ctc_beam_decode(probs, seq_len, output, scores, out_seq_len,
top_paths, beam_width, blank_index, merge_int,
scorer)
result = ctc._ctc_beam_decode(probs, seq_len, output, scores, out_seq_len, decoder, 1)

return output, scores, out_seq_len

Expand Down
91 changes: 59 additions & 32 deletions pytorch_ctc/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include <iostream>
#include "ctc_beam_entry.h"
#include "ctc_beam_scorer.h"
Expand All @@ -10,12 +9,14 @@
#include "util/status.h"
#include "TH.h"
#include "lm/model.hh"
#include "cpu_binding.h"

namespace pytorch {

using pytorch::ctc::KenLMBeamScorer;
using pytorch::ctc::ctc_beam_search::KenLMBeamState;
using pytorch::ctc::Labels;
using pytorch::ctc::Status;

typedef lm::ngram::ProbingModel Model;

Expand Down Expand Up @@ -72,21 +73,43 @@ namespace pytorch {
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* lm_path, const char* trie_path) {

Labels labels(label_str, labels_size, blank_index, space_index);
ctc::KenLMBeamScorer *beam_scorer = new ctc::KenLMBeamScorer(&labels, lm_path, trie_path);
void* ptr = static_cast<void*>(beam_scorer);
std::cout << "kenlm addr (out): " << beam_scorer << ptr << std::endl;
return ptr;
Labels* labels = new Labels(label_str, labels_size, blank_index, space_index);
ctc::KenLMBeamScorer *beam_scorer = new ctc::KenLMBeamScorer(labels, lm_path, trie_path);
return static_cast<void*>(beam_scorer);
}

void* get_base_scorer() {
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = new ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer();
return static_cast<void *>(beam_scorer);
}

void* get_ctc_beam_decoder(int num_classes, int batch_size, int top_paths, int beam_width, int blank_index, int merge_repeated, void *scorer, DecodeType type) {
switch (type) {
case CTC:
{
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = static_cast<ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<> *decoder = new ctc::CTCBeamSearchDecoder<>
(num_classes, beam_width, beam_scorer, batch_size, blank_index, merge_repeated == 1);
return static_cast<void *>(decoder);
}
case CTC_KENLM:
{
ctc::KenLMBeamScorer *beam_scorer = static_cast<ctc::KenLMBeamScorer*>(scorer);
ctc::CTCBeamSearchDecoder<KenLMBeamState> *decoder = new ctc::CTCBeamSearchDecoder<KenLMBeamState>
(num_classes, beam_width, beam_scorer, batch_size, blank_index, merge_repeated == 1);
return static_cast<void *>(decoder);
}
}
return nullptr;
}

int ctc_beam_decode(THFloatTensor *th_probs, THIntTensor *th_seq_len, THIntTensor *th_output,
THFloatTensor *th_scores, THIntTensor *th_out_len, int top_paths,
int beam_width, int blank_index, int merge_repeated,
void* scorer)
THFloatTensor *th_scores, THIntTensor *th_out_len, void *void_decoder, DecodeType type)
{
const int64_t max_time = THFloatTensor_size(th_probs, 0);
const int64_t batch_size = THFloatTensor_size(th_probs, 1);
const int64_t num_classes = THFloatTensor_size(th_probs, 2);
const int64_t top_paths = THIntTensor_size(th_output, 0);

// convert tensors to something the beam scorer can use
// sequence length
Expand All @@ -113,29 +136,33 @@ namespace pytorch {
}

float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, beam_width);


// set up scorer and decoder
// Labels labels(label_str, labels_size, blank_index, space_index);
// ctc::KenLMBeamScorer beam_scorer(&labels, lm_path, trie_path);
std::cout << "Pre cast" << std::endl;
ctc::KenLMBeamScorer* beam_scorer = static_cast<KenLMBeamScorer*>(scorer);
std::cout << "kenlm addr (in): " << scorer <<beam_scorer << std::endl;
std::cout << "post cast" << std::endl;
ctc::CTCBeamSearchDecoder<KenLMBeamState> beam_search(num_classes, beam_width, beam_scorer,
batch_size, blank_index, merge_repeated == 1);
// ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer;
// ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width, &beam_scorer,
// batch_size, blank_index, merge_repeated == 1);
std::cout << "Post beam search init" << std::endl;
ctc::Status stat = beam_search.Decode(seq_len, inputs, &outputs, &scores);
if (!stat.ok()) {
return 0;
Eigen::Map<Eigen::MatrixXf> *scores;

// TODO: this is ugly -- can we better leverage generics somehow?
switch (type) {
case CTC:
{
ctc::CTCBeamSearchDecoder<> *decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores);
if (!stat.ok()) {
return 0;
}
}
break;
case CTC_KENLM:
{
ctc::CTCBeamSearchDecoder<KenLMBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<KenLMBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores);
if (!stat.ok()) {
return 0;
}
}
break;
}
std::cout << "post decode" << std::endl;
std::vector<float> log_probs;

std::vector<float> log_probs;
for (int p=0; p < top_paths; ++p) {
int64_t max_decoded = 0;
int64_t offset = 0;
Expand All @@ -148,14 +175,14 @@ namespace pytorch {
for (int64_t t=0; t < num_decoded; ++t) {
// TODO: this could be more efficient (significant pointer arithmetic every time currently)
THIntTensor_set3d(th_output, p, b, t, p_batch[t]);
THFloatTensor_set2d(th_scores, p, b, scores(b, p));
THFloatTensor_set2d(th_scores, p, b, (*scores)(b, p));
}
}
}
delete scores;
return 1;
}


int generate_lm_trie(const wchar_t* label_str, int size, int blank_index, int space_index,
const char* lm_path, const char* dictionary_path, const char* output_path) {
Labels labels(label_str, size, blank_index, space_index);
Expand Down
24 changes: 20 additions & 4 deletions pytorch_ctc/src/cpu_binding.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
int ctc_beam_decode(THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, THIntTensor *th_out_len, int top_paths,
int beam_width, int blank_index, int merge_repeated,
void *scorer);
/*
 * Tag telling the binding which concrete scorer/decoder classes an opaque
 * pointer refers to. The numeric values are spelled out because they are the
 * same ones the compiler would assign implicitly (0, 1).
 */
typedef enum {
  CTC = 0,       /* CTCBeamSearchDecoder<> with its DefaultBeamScorer */
  CTC_KENLM = 1  /* CTCBeamSearchDecoder<KenLMBeamState> with a KenLMBeamScorer */
} DecodeType;


/* scorers */
/*
 * Allocates a Labels object and a KenLMBeamScorer built from it, lm_path and
 * trie_path, and returns the scorer as an opaque pointer. Neither allocation
 * is released by this call.
 */
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* lm_path, const char* trie_path);
/*
 * Allocates a CTCBeamSearchDecoder<>::DefaultBeamScorer and returns it as an
 * opaque pointer.
 */
void* get_base_scorer();


/* decoders */
/*
 * Allocates a decoder around `scorer` and returns it as an opaque pointer.
 * `type` selects the concrete classes: CTC expects a scorer from
 * get_base_scorer(), CTC_KENLM expects one from get_kenlm_scorer(); the
 * pointer is cast without any check. merge_repeated is treated as true only
 * when it equals 1. Returns a null pointer for any other `type`.
 * top_paths is not read by the implementation in cpu_binding.cpp.
 */
void* get_ctc_beam_decoder(int num_classes, int batch_size, int top_paths, int beam_width, int blank_index,
int merge_repeated, void *scorer, DecodeType type);


/* run decoding */
/*
 * Runs `decoder` over `probs` and writes results into output and scores.
 * `decoder` is cast according to `type`, so `type` must match the value used
 * when the decoder was created. The implementation reads max_time, batch_size
 * and num_classes from dimensions 0, 1 and 2 of probs, and the number of
 * paths from dimension 0 of output.
 * Returns 0 when the decoder reports a non-ok status, 1 otherwise.
 */
int ctc_beam_decode(THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, THIntTensor *th_out_len,
void *decoder, DecodeType type);


/* utilities */
/*
 * NOTE(review): presumably builds a trie from the language model and
 * dictionary and writes it to output_path -- confirm against the definition
 * in cpu_binding.cpp.
 */
int generate_lm_trie(const wchar_t* labels, int size, int blank_index, int space_index,
const char* lm_path, const char* dictionary_path, const char* output_path);
2 changes: 2 additions & 0 deletions pytorch_ctc/src/ctc_beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class CTCBeamSearchDecoder : public CTCDecoder {
label_selection_margin_ = label_selection_margin;
}

// Read-only accessor for the beam_width_ member.
int GetBeamWidth() const {
  return beam_width_;
}

// Reset the beam search
void Reset();

Expand Down
Loading

0 comments on commit 7094aa0

Please sign in to comment.