Skip to content

Commit

Permalink
Simplify function LoadTrainingData and fix mastertrainer_test
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Sep 17, 2021
1 parent 2b4c359 commit 6380451
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/training/classifier_tester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ int main(int argc, char **argv) {
tesseract::CheckSharedLibraryVersion();
ParseArguments(&argc, &argv);
std::string file_prefix;
auto trainer = tesseract::LoadTrainingData(argc, argv, false, nullptr, file_prefix);
auto trainer = tesseract::LoadTrainingData(argv + 1, false, nullptr, file_prefix);
tesseract::TessBaseAPI *api;
// Decode the classifier string.
tesseract::ShapeClassifier *shape_classifier =
Expand Down
4 changes: 1 addition & 3 deletions src/training/cntraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ int main(int argc, char *argv[]) {
// Set the global Config parameters before parsing the command line.
Config = CNConfig;

const char *PageName;
LIST CharList = NIL_LIST;
CLUSTERER *Clusterer = nullptr;
LIST ProtoList = NIL_LIST;
Expand All @@ -118,8 +117,7 @@ int main(int argc, char *argv[]) {

ParseArguments(&argc, &argv);
int num_fonts = 0;
int tessoptind = 1;
while ((PageName = GetNextFilename(argc, argv, tessoptind)) != nullptr) {
for (const char *PageName = *++argv; PageName != nullptr; PageName = *++argv) {
printf("Reading %s ...\n", PageName);
FILE *TrainingPage = fopen(PageName, "rb");
ASSERT_HOST(TrainingPage);
Expand Down
27 changes: 3 additions & 24 deletions src/training/common/commontraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void WriteShapeTable(const std::string &file_prefix, const ShapeTable &shape_tab
* If shape_table is not nullptr, but failed to load, make a fake flat one,
* as shape clustering was not run.
*/
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char *const *argv, bool replication,
std::unique_ptr<MasterTrainer> LoadTrainingData(const char *const *filelist, bool replication,
ShapeTable **shape_table, std::string &file_prefix) {
InitFeatureDefs(&feature_defs);
InitIntegerFX();
Expand Down Expand Up @@ -236,10 +236,8 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char *const *arg
}
}
trainer->SetFeatureSpace(fs);
const char *page_name;
// Load training data from .tr files on the command line.
int tessoptind = 1;
while ((page_name = GetNextFilename(argc, argv, tessoptind)) != nullptr) {
// Load training data from .tr files in filelist (terminated by nullptr).
for (const char *page_name = *filelist++; page_name != nullptr; page_name = *filelist++) {
tprintf("Reading %s ...\n", page_name);
trainer->ReadTrainingSamples(page_name, feature_defs, false);

Expand Down Expand Up @@ -291,25 +289,6 @@ std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char *const *arg
return trainer;
}

/*---------------------------------------------------------------------------*/
/**
 * Hands out the command line arguments one at a time.
 *
 * Each call returns argv[tessoptind] and then advances tessoptind by one,
 * so repeated calls walk through the remaining arguments. Once tessoptind
 * has reached argc, nullptr is returned and tessoptind is left unchanged.
 *
 * This routine should only be called after all option arguments have been
 * parsed and removed with ParseArguments.
 *
 * @param argc        Number of entries in argv.
 * @param argv        Argument vector to read from.
 * @param tessoptind  In/out index of the next argument to return; it is
 *                    incremented whenever an argument is returned.
 * @return Next command line argument or nullptr if none remain.
 */
const char *GetNextFilename(int argc, const char *const *argv, int &tessoptind) {
  // Nothing left to hand out: report exhaustion without touching the index.
  if (tessoptind >= argc) {
    return nullptr;
  }
  const char *next_argument = argv[tessoptind];
  ++tessoptind;
  return next_argument;
} /* GetNextFilename */

/*---------------------------------------------------------------------------*/
/**
* This routine searches through a list of labeled lists to find
Expand Down
5 changes: 1 addition & 4 deletions src/training/common/commontraining.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,9 @@ void WriteShapeTable(const std::string &file_prefix, const ShapeTable &shape_tab
// If shape_table is not nullptr, but failed to load, make a fake flat one,
// as shape clustering was not run.
TESS_COMMON_TRAINING_API
std::unique_ptr<MasterTrainer> LoadTrainingData(int argc, const char *const *argv, bool replication,
std::unique_ptr<MasterTrainer> LoadTrainingData(const char *const *filelist, bool replication,
ShapeTable **shape_table, std::string &file_prefix);

TESS_COMMON_TRAINING_API
const char *GetNextFilename(int argc, const char *const *argv, int &tessoptind);

LABELEDLIST FindList(tesseract::LIST List, const std::string &Label);

TESS_COMMON_TRAINING_API
Expand Down
2 changes: 1 addition & 1 deletion src/training/mftraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ int main(int argc, char **argv) {
ShapeTable *shape_table = nullptr;
std::string file_prefix;
// Load the training data.
auto trainer = tesseract::LoadTrainingData(argc, argv, false, &shape_table, file_prefix);
auto trainer = tesseract::LoadTrainingData(argv + 1, false, &shape_table, file_prefix);
if (trainer == nullptr) {
return 1; // Failed.
}
Expand Down
2 changes: 1 addition & 1 deletion src/training/shapeclustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int main(int argc, char **argv) {
ParseArguments(&argc, &argv);

std::string file_prefix;
auto trainer = tesseract::LoadTrainingData(argc, argv, false, nullptr, file_prefix);
auto trainer = tesseract::LoadTrainingData(argv + 1, false, nullptr, file_prefix);

if (!trainer) {
return 1;
Expand Down
5 changes: 2 additions & 3 deletions unittest/mastertrainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,11 @@ class MasterTrainerTest : public testing::Test {
FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
const char *argv[] = {tr_file_name.c_str()};
int argc = 1;
const char *filelist[] = {tr_file_name.c_str(), nullptr};
std::string file_prefix;
delete shape_table_;
shape_table_ = nullptr;
master_trainer_ = LoadTrainingData(argc, argv, false, &shape_table_, file_prefix);
master_trainer_ = LoadTrainingData(filelist, false, &shape_table_, file_prefix);
EXPECT_TRUE(master_trainer_ != nullptr);
EXPECT_TRUE(shape_table_ != nullptr);
}
Expand Down

0 comments on commit 6380451

Please sign in to comment.