Skip to content

Commit

Permalink
Refactor extension structure
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanleary authored and Ryan Leary committed Jun 19, 2017
1 parent f8b6636 commit e669fc1
Show file tree
Hide file tree
Showing 155 changed files with 30 additions and 17 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
build
_ext
pytorch_ctc
*.pyc

# Byte-compiled / optimized / DLL files
Expand Down
12 changes: 0 additions & 12 deletions ctc_decode/__init__.py

This file was deleted.

26 changes: 26 additions & 0 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch

from cffi import FFI
ffi = FFI()
from ._ext.ctc_decode import ctc_beam_decode


def beam_decode(probs, seq_len=None, top_paths=1, beam_width=10, merge_repeated=True):
prob_size = probs.size()
max_seq_len = prob_size[0]
batch_size = prob_size[1]
num_classes = prob_size[2]

if seq_len is not None and batch_size != seq_len.size():
raise ValueError("seq_len shape must be a 1xbatch_size tensor or None")
if top_paths < 1 or top_paths > beam_width:
raise ValueError("top_paths must be greater than 1 and less than or equal to the beam_width")

seq_len = torch.IntTensor(batch_size).zero_().add_(max_seq_len) if seq_len is None else seq_len
output = torch.IntTensor(top_paths, batch_size, max_seq_len)
scores = torch.FloatTensor(top_paths, batch_size)

merge_int = 1 if merge_repeated else 0
result = ctc_beam_decode(probs, seq_len, output, scores, top_paths, beam_width, merge_int)

return output, scores
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from distutils.core import setup
from torch.utils.ffi import create_extension

sources = ['src/cpu_binding.cpp', 'src/util/status.cpp']
headers = ['src/cpu_binding.h']
sources = ['pytorch_ctc/src/cpu_binding.cpp', 'pytorch_ctc/src/util/status.cpp']
headers = ['pytorch_ctc/src/cpu_binding.h']

ffi = create_extension(
name='ctc_decode',
name='pytorch_ctc._ext.ctc_decode',
package=True,
language='c++',
headers=headers,
sources=sources,
with_cuda=False,
extra_compile_args=['-std=c++11', '-fPIC', '-w']
)
ffi = ffi.distutils_extension()
ffi.name = 'pytorch_ctc._ctc_decode'

setup(
name="pytorch_ctc",
Expand Down

0 comments on commit e669fc1

Please sign in to comment.