Convert to new PyTorch APIs

KarimTarabishy committed Dec 20, 2018
1 parent 140b458 commit 7dde7b2
Showing 3 changed files with 118 additions and 106 deletions.
build.py (37 changes: 17 additions & 20 deletions)
@@ -6,7 +6,8 @@
 import warnings
 
 import wget
-from torch.utils.ffi import create_extension
+import setuptools
+from torch.utils.cpp_extension import CppExtension, include_paths
 
 
 def download_extract(url, dl_path):
@@ -40,9 +41,8 @@ def compile_test(header, library):
     return os.system(command) == 0
 
 
-compile_args = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-fPIC', '-std=c99', '-w']
-ext_libs = ['stdc++']
-
+compile_args = ['-O3', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-fPIC']
+ext_libs = []
 if compile_test('zlib.h', 'z'):
     compile_args.append('-DHAVE_ZLIB')
     ext_libs.append('z')
@@ -63,19 +63,16 @@ def compile_test(header, library):
 
 third_party_includes = [os.path.realpath(os.path.join("third_party", lib)) for lib in third_party_libs]
 ctc_sources = glob.glob('ctcdecode/src/*.cpp')
-ctc_headers = ['ctcdecode/src/binding.h', ]
-
-ffi = create_extension(
-    name='ctcdecode._ext.ctc_decode',
-    package=True,
-    language='c++',
-    headers=ctc_headers,
-    sources=ctc_sources + lib_sources,
-    include_dirs=third_party_includes,
-    with_cuda=False,
-    libraries=ext_libs,
-    extra_compile_args=compile_args
-)
-
-if __name__ == '__main__':
-    ffi.build()
+
+
+extension = CppExtension(
+    name='ctcdecode._ext.ctc_decode',
+    package=True,
+    with_cuda=False,
+    sources=ctc_sources + lib_sources,
+    include_dirs=third_party_includes + include_paths(),
+    libraries=ext_libs,
+    extra_compile_args=compile_args,
+    language='c++')
+
+
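For context: torch.utils.ffi was deprecated and removed around PyTorch 1.0, which is why the cffi-driven ffi.build() flow above stopped working; torch.utils.cpp_extension is its replacement. A minimal sketch of the new pattern, with placeholder names (my_ext, my_ext.cpp) that are not part of this repository:

# Hedged sketch of the CppExtension/BuildExtension pattern this commit adopts.
# "my_ext" and "my_ext.cpp" are illustrative placeholders only.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_ext",
    ext_modules=[
        CppExtension(
            name="my_ext",              # import path of the compiled module
            sources=["my_ext.cpp"],     # C++ sources containing the bindings
            extra_compile_args=["-std=c++11"],
        )
    ],
    # BuildExtension supplies PyTorch's include paths and ABI flags at build time.
    cmdclass={"build_ext": BuildExtension},
)

Two details worth noting: CppExtension already appends include_paths() to whatever include_dirs it is given, so the explicit third_party_includes + include_paths() above is redundant but harmless; and package=True/with_cuda=False are leftovers from the old create_extension signature that setuptools' Extension does not define (it merely warns about unknown options).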
ctcdecode/src/binding.cpp (176 changes: 96 additions & 80 deletions)
@@ -2,7 +2,7 @@
 #include <iostream>
 #include <string>
 #include <vector>
-#include "TH.h"
+#include <torch/torch.h>
 #include "scorer.h"
 #include "ctc_beam_search_decoder.h"
 #include "utf8.h"
@@ -22,8 +22,8 @@ int utf8_to_utf8_char_vec(const char* labels, std::vector<std::string>& new_vocab)
     while (str_i < end);
 }
 
-int beam_decode(THFloatTensor *th_probs,
-                THIntTensor *th_seq_lens,
+int beam_decode(at::Tensor th_probs,
+                at::Tensor th_seq_lens,
                 const char* labels,
                 int vocab_size,
                 size_t beam_size,
@@ -32,29 +32,32 @@ int beam_decode(THFloatTensor *th_probs,
                 size_t cutoff_top_n,
                 size_t blank_id,
                 void *scorer,
-                THIntTensor *th_output,
-                THIntTensor *th_timesteps,
-                THFloatTensor *th_scores,
-                THIntTensor *th_out_length)
+                at::Tensor th_output,
+                at::Tensor th_timesteps,
+                at::Tensor th_scores,
+                at::Tensor th_out_length)
 {
     std::vector<std::string> new_vocab;
     utf8_to_utf8_char_vec(labels, new_vocab);
     Scorer *ext_scorer = NULL;
     if (scorer != NULL) {
         ext_scorer = static_cast<Scorer *>(scorer);
     }
-    const int64_t max_time = THFloatTensor_size(th_probs, 1);
-    const int64_t batch_size = THFloatTensor_size(th_probs, 0);
-    const int64_t num_classes = THFloatTensor_size(th_probs, 2);
+    const int64_t max_time = th_probs.size(1);
+    const int64_t batch_size = th_probs.size(0);
+    const int64_t num_classes = th_probs.size(2);
 
     std::vector<std::vector<std::vector<double>>> inputs;
+    auto prob_accessor = th_probs.accessor<float, 3>();
+    auto seq_len_accessor = th_seq_lens.accessor<int, 1>();
+
     for (int b=0; b < batch_size; ++b) {
         // avoid a crash by ensuring that an erroneous seq_len doesn't have us try to access memory we shouldn't
-        int seq_len = std::min(THIntTensor_get1d(th_seq_lens, b), (int)max_time);
+        int seq_len = std::min((int)seq_len_accessor[b], (int)max_time);
         std::vector<std::vector<double>> temp (seq_len, std::vector<double>(num_classes));
         for (int t=0; t < seq_len; ++t) {
             for (int n=0; n < num_classes; ++n) {
-                float val = THFloatTensor_get3d(th_probs, b, t, n);
+                float val = prob_accessor[b][t][n];
                 temp[t][n] = val;
             }
         }
@@ -63,6 +66,11 @@ int beam_decode(THFloatTensor *th_probs,
 
     std::vector<std::vector<std::pair<double, Output>>> batch_results =
         ctc_beam_search_decoder_batch(inputs, new_vocab, beam_size, num_processes, cutoff_prob, cutoff_top_n, blank_id, ext_scorer);
+    auto outputs_accessor = th_output.accessor<int, 3>();
+    auto timesteps_accessor = th_timesteps.accessor<int, 3>();
+    auto scores_accessor = th_scores.accessor<float, 2>();
+    auto out_length_accessor = th_out_length.accessor<int, 2>();
+
 
     for (int b = 0; b < batch_results.size(); ++b){
         std::vector<std::pair<double, Output>> results = batch_results[b];
@@ -72,84 +80,92 @@
             std::vector<int> output_tokens = output.tokens;
             std::vector<int> output_timesteps = output.timesteps;
             for (int t = 0; t < output_tokens.size(); ++t){
-                THIntTensor_set3d(th_output, b, p, t, output_tokens[t]); // fill output tokens
-                THIntTensor_set3d(th_timesteps, b, p, t, output_timesteps[t]); // fill timesteps tokens
+                outputs_accessor[b][p][t] = output_tokens[t]; // fill output tokens
+                timesteps_accessor[b][p][t] = output_timesteps[t];
             }
-            THFloatTensor_set2d(th_scores, b, p, n_path_result.first); // fill path scores
-            THIntTensor_set2d(th_out_length, b, p, output_tokens.size());
+            scores_accessor[b][p] = n_path_result.first;
+            out_length_accessor[b][p] = output_tokens.size();
         }
     }
     return 1;
 }
 
 
-extern "C"
-{
-#include "binding.h"
-int paddle_beam_decode(THFloatTensor *th_probs,
-                       THIntTensor *th_seq_lens,
-                       const char* labels,
-                       int vocab_size,
-                       size_t beam_size,
-                       size_t num_processes,
-                       double cutoff_prob,
-                       size_t cutoff_top_n,
-                       size_t blank_id,
-                       THIntTensor *th_output,
-                       THIntTensor *th_timesteps,
-                       THFloatTensor *th_scores,
-                       THIntTensor *th_out_length){
-
-    return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
-                       cutoff_prob, cutoff_top_n, blank_id,NULL, th_output, th_timesteps, th_scores, th_out_length);
-}
 
-int paddle_beam_decode_lm(THFloatTensor *th_probs,
-                          THIntTensor *th_seq_lens,
-                          const char* labels,
-                          int vocab_size,
-                          size_t beam_size,
-                          size_t num_processes,
-                          double cutoff_prob,
-                          size_t cutoff_top_n,
-                          size_t blank_id,
-                          void *scorer,
-                          THIntTensor *th_output,
-                          THIntTensor *th_timesteps,
-                          THFloatTensor *th_scores,
-                          THIntTensor *th_out_length){
-
-    return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
-                       cutoff_prob, cutoff_top_n, blank_id,scorer, th_output, th_timesteps, th_scores, th_out_length);
-}
+int paddle_beam_decode(at::Tensor th_probs,
+                       at::Tensor th_seq_lens,
+                       const char* labels,
+                       int vocab_size,
+                       size_t beam_size,
+                       size_t num_processes,
+                       double cutoff_prob,
+                       size_t cutoff_top_n,
+                       size_t blank_id,
+                       at::Tensor th_output,
+                       at::Tensor th_timesteps,
+                       at::Tensor th_scores,
+                       at::Tensor th_out_length){
+
+    return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
+                       cutoff_prob, cutoff_top_n, blank_id, NULL, th_output, th_timesteps, th_scores, th_out_length);
+}
 
-void* paddle_get_scorer(double alpha,
-                        double beta,
-                        const char* lm_path,
-                        const char* labels,
-                        int vocab_size) {
-    std::vector<std::string> new_vocab;
-    utf8_to_utf8_char_vec(labels, new_vocab);
-    Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab);
-    return static_cast<void*>(scorer);
-}
+int paddle_beam_decode_lm(at::Tensor th_probs,
+                          at::Tensor th_seq_lens,
+                          const char* labels,
+                          int vocab_size,
+                          size_t beam_size,
+                          size_t num_processes,
+                          double cutoff_prob,
+                          size_t cutoff_top_n,
+                          size_t blank_id,
+                          void *scorer,
+                          at::Tensor th_output,
+                          at::Tensor th_timesteps,
+                          at::Tensor th_scores,
+                          at::Tensor th_out_length){
 
-int is_character_based(void *scorer){
-    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
-    return ext_scorer->is_character_based();
-}
-size_t get_max_order(void *scorer){
-    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
-    return ext_scorer->get_max_order();
-}
-size_t get_dict_size(void *scorer){
-    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
-    return ext_scorer->get_dict_size();
-}
+    return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes,
+                       cutoff_prob, cutoff_top_n, blank_id, scorer, th_output, th_timesteps, th_scores, th_out_length);
+}
 
-void reset_params(void *scorer, double alpha, double beta){
-    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
-    ext_scorer->reset_params(alpha, beta);
-}
 
+void* paddle_get_scorer(double alpha,
+                        double beta,
+                        const char* lm_path,
+                        const char* labels,
+                        int vocab_size) {
+    std::vector<std::string> new_vocab;
+    utf8_to_utf8_char_vec(labels, new_vocab);
+    Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab);
+    return static_cast<void*>(scorer);
+}
+
+int is_character_based(void *scorer){
+    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
+    return ext_scorer->is_character_based();
+}
+size_t get_max_order(void *scorer){
+    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
+    return ext_scorer->get_max_order();
+}
+size_t get_dict_size(void *scorer){
+    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
+    return ext_scorer->get_dict_size();
+}
+
+void reset_params(void *scorer, double alpha, double beta){
+    Scorer *ext_scorer = static_cast<Scorer *>(scorer);
+    ext_scorer->reset_params(alpha, beta);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("paddle_beam_decode", &paddle_beam_decode, "paddle_beam_decode");
+    m.def("paddle_beam_decode_lm", &paddle_beam_decode_lm, "paddle_beam_decode_lm");
+    m.def("paddle_get_scorer", &paddle_get_scorer, "paddle_get_scorer");
+    m.def("is_character_based", &is_character_based, "is_character_based");
+    m.def("get_max_order", &get_max_order, "get_max_order");
+    m.def("get_dict_size", &get_dict_size, "get_dict_size");
+    m.def("reset_params", &reset_params, "reset_params");
 }
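The accessor<T, N>() calls above fix the dtype and rank contract for the Python caller: float32 probabilities of shape (batch, time, classes), int32 sequence lengths, int32 token/timestep buffers of shape (batch, beams, time), and float32 scores plus int32 lengths of shape (batch, beams). A hedged sketch of that calling convention follows; the shapes and decoder settings are illustrative, and in practice the repo's CTCBeamDecoder wrapper drives this binding rather than user code:

# Hedged sketch of calling the raw binding; shapes and decoder settings
# (beam_size, cutoff_top_n, ...) are illustrative values only.
import torch
from ctcdecode._ext import ctc_decode

batch, max_time, num_classes, beam_size = 2, 50, 29, 16
probs = torch.rand(batch, max_time, num_classes)            # float32, rank 3 -> accessor<float, 3>
probs = probs / probs.sum(dim=2, keepdim=True)              # decoder works on per-frame probabilities
seq_lens = torch.full((batch,), max_time, dtype=torch.int)  # int32, rank 1 -> accessor<int, 1>

# Output buffers must match the accessor dtypes/ranks in binding.cpp.
output = torch.zeros(batch, beam_size, max_time, dtype=torch.int)
timesteps = torch.zeros(batch, beam_size, max_time, dtype=torch.int)
scores = torch.zeros(batch, beam_size, dtype=torch.float)
out_lens = torch.zeros(batch, beam_size, dtype=torch.int)

labels = "_abcdefghijklmnopqrstuvwxyz' "  # placeholder 29-symbol label set, blank first
ctc_decode.paddle_beam_decode(probs, seq_lens, labels, len(labels), beam_size,
                              4,    # num_processes
                              1.0,  # cutoff_prob
                              40,   # cutoff_top_n
                              0,    # blank_id
                              output, timesteps, scores, out_lens)

Results are written in place: out_lens[b][p] gives how many entries of output[b][p] are valid for the p-th best hypothesis of batch item b.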
setup.py (11 changes: 5 additions & 6 deletions)
@@ -3,6 +3,7 @@
 import os
 
 from setuptools import setup, find_packages, distutils
+from torch.utils.cpp_extension import BuildExtension
 
 this_file = os.path.dirname(__file__)
 
@@ -39,19 +40,17 @@ def _single_compile(obj):
 # hack compile to support parallel compiling
 distutils.ccompiler.CCompiler.compile = parallelCCompile
+import build
 
 setup(
     name="ctcdecode",
     version="0.3",
     description="CTC Decoder for PyTorch based on Paddle Paddle's implementation",
     url="https://github.com/parlance/ctcdecode",
     author="Ryan Leary",
     author_email="ryanleary@gmail.com",
-    install_requires=["cffi>=1.0.0"],
-    setup_requires=["cffi>=1.0.0", "wget"],
     # Exclude the build files.
     packages=find_packages(exclude=["build"]),
     ext_package="",
-    cffi_modules=[
-        os.path.join(this_file, "build.py:ffi")
-    ]
+    ext_modules=[build.extension],
+    cmdclass={'build_ext': BuildExtension}
 )
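With cffi_modules gone, compilation now happens in the standard build_ext step, which BuildExtension customizes. A hedged smoke test of the scorer bindings after building; the model path and alpha/beta values are placeholders, and the void* scorer handle round-trips through Python as an opaque capsule:

# Hedged sketch: build first with `python setup.py install` (or
# `python setup.py build_ext --inplace`), then exercise the scorer API.
# "lm.binary" and the alpha/beta values stand in for a real KenLM model.
from ctcdecode._ext import ctc_decode

labels = "_abcdefghijklmnopqrstuvwxyz' "
scorer = ctc_decode.paddle_get_scorer(2.0, 0.4, "lm.binary", labels, len(labels))

print(ctc_decode.is_character_based(scorer))  # 1 for character-level LMs, else 0
print(ctc_decode.get_max_order(scorer))       # n-gram order of the loaded model
ctc_decode.reset_params(scorer, 1.5, 0.3)     # retune LM weights without reloading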
