Skip to content

Commit

Permalink
Merge pull request tesseract-ocr#2461 from stweil/tensorflow
Browse files Browse the repository at this point in the history
 Support build with Tensorflow
  • Loading branch information
zdenop authored May 25, 2019
2 parents e44c60c + 32dcfd0 commit 8de022a
Show file tree
Hide file tree
Showing 13 changed files with 1,725 additions and 12 deletions.
19 changes: 19 additions & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,25 @@ if test "$enable_opencl" = "yes"; then
])
fi

# Check whether to build with support for TensorFlow.
AC_ARG_WITH([tensorflow],
AS_HELP_STRING([--with-tensorflow],
[support TensorFlow @<:@default=check@:>@]),
[], [with_tensorflow=check])
AM_CONDITIONAL([TENSORFLOW], false)
TENSORFLOW_LIBS=
AS_IF([test "x$with_tensorflow" != xno],
[AC_CHECK_FILE([/usr/include/tensorflow/core/framework/graph.pb.h],
[AC_SUBST([TENSORFLOW_LIBS], ["-lprotobuf -ltensorflow_cc"])
AM_CONDITIONAL([TENSORFLOW], true)
],
[if test "x$with_tensorflow" != xcheck; then
AC_MSG_FAILURE(
[--with-tensorflow was given, but test for libtensorflow-dev failed])
fi
])
])

# https://lists.apple.com/archives/unix-porting/2009/Jan/msg00026.html
m4_define([MY_CHECK_FRAMEWORK],
[AC_CACHE_CHECK([if -framework $1 works],[my_cv_framework_$1],
Expand Down
1 change: 1 addition & 0 deletions src/api/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ tesseract_LDFLAGS = $(OPENCL_LDFLAGS)

tesseract_LDADD += $(LEPTONICA_LIBS)
tesseract_LDADD += $(OPENMP_CXXFLAGS)
tesseract_LDADD += $(TENSORFLOW_LIBS)
tesseract_LDADD += $(libarchive_LIBS)

if T_WIN
Expand Down
3 changes: 1 addition & 2 deletions src/ccutil/unicharset.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// File: unicharset.h
// Description: Unicode character/ligature set class.
// Author: Thomas Kielbus
// Created: Wed Jun 28 17:05:01 PDT 2006
//
// (C) Copyright 2006, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -153,7 +152,7 @@ class UNICHARSET {
// List of strings for the SpecialUnicharCodes. Keep in sync with the enum.
static TESS_API const char* kSpecialUnicharCodes[SPECIAL_UNICHAR_CODES_COUNT];

// ICU 2.0 UCharDirection enum (from third_party/icu/include/unicode/uchar.h)
// ICU 2.0 UCharDirection enum (from icu/include/unicode/uchar.h)
enum Direction {
U_LEFT_TO_RIGHT = 0,
U_RIGHT_TO_LEFT = 1,
Expand Down
9 changes: 9 additions & 0 deletions src/lstm/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ AM_CPPFLAGS += \

AM_CXXFLAGS = $(OPENMP_CXXFLAGS)

if TENSORFLOW
AM_CPPFLAGS += -DINCLUDE_TENSORFLOW
AM_CPPFLAGS += -I/usr/include/tensorflow
endif

if !NO_TESSDATA_PREFIX
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@
endif
Expand Down Expand Up @@ -37,3 +42,7 @@ libtesseract_lstm_la_SOURCES = \
networkbuilder.cpp network.cpp networkio.cpp \
parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp

if TENSORFLOW
libtesseract_lstm_la_SOURCES += tfnetwork.pb.cc
endif
5 changes: 2 additions & 3 deletions src/lstm/tfnetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// Description: Encapsulation of an entire tensorflow graph as a
// Tesseract Network.
// Author: Ray Smith
// Created: Fri Feb 26 09:35:29 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -90,14 +89,14 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
if (!model_proto_.image_widths().empty()) {
TensorShape size_shape{1};
Tensor width_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_wtensor = width_tensor.flat<int64>();
auto eigen_wtensor = width_tensor.flat<tensorflow::int64>();
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
}
if (!model_proto_.image_heights().empty()) {
TensorShape size_shape{1};
Tensor height_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_htensor = height_tensor.flat<int64>();
auto eigen_htensor = height_tensor.flat<tensorflow::int64>();
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
}
Expand Down
18 changes: 15 additions & 3 deletions src/lstm/tfnetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

#include "network.h"
#include "static_shape.h"
#include "tfnetwork.proto.h"
#include "third_party/tensorflow/core/framework/graph.pb.h"
#include "third_party/tensorflow/core/public/session.h"
#include "tfnetwork.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/session.h"

namespace tesseract {

Expand Down Expand Up @@ -69,6 +69,18 @@ class TFNetwork : public Network {
NetworkScratch* scratch, NetworkIO* output) override;

private:
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

int InitFromProto();

// The original network definition for reference.
Expand Down
Loading

0 comments on commit 8de022a

Please sign in to comment.