Skip to content

Commit

Permalink
add parallel ctc decoder using thread pool from wenet
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Jun 24, 2022
1 parent 84d9751 commit 2d96bc1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_library(ctc_prefix_beam_search STATIC
src/wenet/runtime/core/decoder/ctc_prefix_beam_search.cc
src/wenet/runtime/core/decoder/context_graph.cc
src/wenet/runtime/core/utils/utils.cc
src/wenet/runtime/core/utils/thread_pool.h
src/wenet/runtime/core/utils/string.cc
)
target_link_libraries(ctc_prefix_beam_search PUBLIC fst)
Expand Down
38 changes: 38 additions & 0 deletions src/ctc_decoder.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ctc_decoder.h"
#include <vector>

// data shape: [bs, max_time, num_classes]
// decoded shape: [bs, top_paths, max_time] ignore id: -1
Expand Down Expand Up @@ -55,3 +56,40 @@ ctc_beam_search_decoder(uintptr_t pdata, int max_time, int bs, int num_classes,
}
return std::move(results);
}

std::vector<DecodeResult>
parallel_ctc_beam_search_decoder(uintptr_t pdata, int max_time, int bs,
int num_classes, uintptr_t psequence_length,
int n_sequence_length, int beam_width,
int top_path) {
if (max_time <= 0 || bs <= 0 || num_classes <= 0 || n_sequence_length <= 0 ||
beam_width <= 0 || top_path <= 0) {
return results;
}

auto *data = reinterpret_cast<float *>(pdata);
auto *sequence_length = reinterpret_cast<std::int64_t *>(psequence_length);

{
// TODO: test for now change in future
auto thread_num = bs;
wenet::ThreadPool pool(thread_num);
std::vector<std::future<std::vector<std::vector<DecodeResult>>>> res;
for (int i = 0; i < bs; i++) {
auto next =
reinterpret_cast<uintptr_t>(data + i * max_time * num_classes);
auto next_sequence_length =
reinterpret_cast<uintptr_t>(sequence_length + i);

res.emplace_back(pool.enqueue(ctc_beam_search_decoder, next, max_time, 1,
num_classes, next_sequence_length, 1,
beam_width, top_path));
}

// future result
std::vector<DecodeResult> batch_results;
for (int i =0 ; i < bs; i++){
batch_results.emplace_back(res[i].get(0)[0]);
}
}
}
8 changes: 6 additions & 2 deletions src/ctc_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#define CTC_DECODER_H_

#include <cinttypes>
#include <cstdio>
#include <iostream>
#include <vector>

struct DecodeResult {
Expand All @@ -19,6 +17,12 @@ std::vector<DecodeResult> ctc_beam_search_decoder(uintptr_t pdata, int max_time,
int n_sequence_length,
int beam_width, int top_path);

std::vector<DecodeResult>
parallel_ctc_beam_search_decoder(uintptr_t pdata, int max_time, int bs,
int num_classes, uintptr_t psequence_length,
int n_sequence_length, int beam_width,
int top_path);

std::vector<int64_t> EditDistance(uintptr_t sdata, int s_max_seq_len,
uintptr_t s_sequence_length, uintptr_t tdata,
int t_max_seq_len,
Expand Down

0 comments on commit 2d96bc1

Please sign in to comment.