From 5bc5e2a0b427e1c92057708b44bb6b4094f12439 Mon Sep 17 00:00:00 2001 From: "theraysmith@gmail.com" Date: Thu, 2 Feb 2012 02:57:42 +0000 Subject: [PATCH] Added simultaneous multi-language capability, Added support for ShapeTable in classifier and training, Refactored class pruner, Added new uniform classifier API, Added new training error counter git-svn-id: https://tesseract-ocr.googlecode.com/svn/trunk@650 d0cd1f9f-072b-0410-8dd7-cf729c803f20 --- classify/Makefile.am | 16 +- classify/adaptmatch.cpp | 1129 +++++++++++++++++--------------- classify/blobclass.cpp | 14 +- classify/classify.cpp | 74 +-- classify/classify.h | 194 ++++-- classify/cluster.cpp | 107 ++- classify/cluster.h | 12 +- classify/clusttool.cpp | 41 +- classify/clusttool.h | 2 - classify/cutoffs.cpp | 7 +- classify/errorcounter.cpp | 385 +++++++++++ classify/errorcounter.h | 198 ++++++ classify/extract.cpp | 9 - classify/featdefs.cpp | 128 ++-- classify/featdefs.h | 9 + classify/float2int.cpp | 54 +- classify/float2int.h | 8 - classify/fxdefs.cpp | 11 +- classify/intfeaturedist.cpp | 159 +++++ classify/intfeaturedist.h | 80 +++ classify/intfeaturemap.cpp | 245 +++++++ classify/intfeaturemap.h | 163 +++++ classify/intfeaturespace.cpp | 143 ++++ classify/intfeaturespace.h | 110 ++++ classify/intfx.cpp | 56 +- classify/intfx.h | 30 +- classify/intmatcher.cpp | 603 +++++++++-------- classify/intmatcher.h | 22 +- classify/intproto.cpp | 279 +++----- classify/intproto.h | 96 +-- classify/mastertrainer.cpp | 967 +++++++++++++++++++++++++++ classify/mastertrainer.h | 296 +++++++++ classify/mf.h | 5 +- classify/mfoutline.cpp | 54 +- classify/mfoutline.h | 2 - classify/mfx.cpp | 1 - classify/normmatch.cpp | 65 +- classify/ocrfeatures.cpp | 7 +- classify/ocrfeatures.h | 6 +- classify/picofeat.cpp | 57 ++ classify/picofeat.h | 19 + classify/protos.h | 9 +- classify/sampleiterator.cpp | 262 ++++++++ classify/sampleiterator.h | 195 ++++++ classify/shapeclassifier.h | 95 +++ classify/shapetable.cpp | 452 +++++++++++++ classify/shapetable.h | 227 +++++++ classify/speckle.cpp | 22 +- classify/tessclassifier.cpp | 52 ++ classify/tessclassifier.h | 65 ++ classify/trainingsample.cpp | 311 +++++++++ classify/trainingsample.h | 240 +++++++ classify/trainingsampleset.cpp | 870 ++++++++++++++++++++++++ classify/trainingsampleset.h | 290 ++++++++ 54 files changed, 7463 insertions(+), 1490 deletions(-) create mode 100644 classify/errorcounter.cpp create mode 100644 classify/errorcounter.h create mode 100644 classify/intfeaturedist.cpp create mode 100644 classify/intfeaturedist.h create mode 100644 classify/intfeaturemap.cpp create mode 100644 classify/intfeaturemap.h create mode 100644 classify/intfeaturespace.cpp create mode 100644 classify/intfeaturespace.h create mode 100644 classify/mastertrainer.cpp create mode 100644 classify/mastertrainer.h create mode 100644 classify/sampleiterator.cpp create mode 100644 classify/sampleiterator.h create mode 100644 classify/shapeclassifier.h create mode 100644 classify/shapetable.cpp create mode 100644 classify/shapetable.h create mode 100644 classify/tessclassifier.cpp create mode 100644 classify/tessclassifier.h create mode 100644 classify/trainingsample.cpp create mode 100644 classify/trainingsample.h create mode 100644 classify/trainingsampleset.cpp create mode 100644 classify/trainingsampleset.h diff --git a/classify/Makefile.am b/classify/Makefile.am index 245713ab97..2083eb2e8d 100644 --- a/classify/Makefile.am +++ b/classify/Makefile.am @@ -7,13 +7,15 @@ AM_CPPFLAGS = \ include_HEADERS = \ adaptive.h baseline.h blobclass.h chartoname.h \ classify.h cluster.h clusttool.h cutoffs.h \ - extern.h extract.h \ + errorcounter.h extern.h extract.h \ featdefs.h flexfx.h float2int.h fpoint.h fxdefs.h \ + intfeaturedist.h intfeaturemap.h intfeaturespace.h \ intfx.h intmatcher.h intproto.h kdtree.h \ - mf.h mfdefs.h mfoutline.h mfx.h \ + mastertrainer.h mf.h mfdefs.h mfoutline.h mfx.h \ normfeat.h normmatch.h \ ocrfeatures.h outfeat.h picofeat.h protos.h \ - speckle.h xform2d.h + sampleiterator.h shapeclassifier.h shapetable.h \ + speckle.h tessclassifier.h trainingsample.h trainingsampleset.h xform2d.h if !USING_MULTIPLELIBS noinst_LTLIBRARIES = libtesseract_classify.la @@ -32,12 +34,14 @@ endif libtesseract_classify_la_SOURCES = \ adaptive.cpp adaptmatch.cpp blobclass.cpp \ chartoname.cpp classify.cpp cluster.cpp clusttool.cpp cutoffs.cpp \ - extract.cpp \ + errorcounter.cpp extract.cpp \ featdefs.cpp flexfx.cpp float2int.cpp fpoint.cpp fxdefs.cpp \ + intfeaturedist.cpp intfeaturemap.cpp intfeaturespace.cpp \ intfx.cpp intmatcher.cpp intproto.cpp kdtree.cpp \ - mf.cpp mfdefs.cpp mfoutline.cpp mfx.cpp \ + mastertrainer.cpp mf.cpp mfdefs.cpp mfoutline.cpp mfx.cpp \ normfeat.cpp normmatch.cpp \ ocrfeatures.cpp outfeat.cpp picofeat.cpp protos.cpp \ - speckle.cpp xform2d.cpp + sampleiterator.cpp shapetable.cpp speckle.cpp \ + tessclassifier.cpp trainingsample.cpp trainingsampleset.cpp xform2d.cpp diff --git a/classify/adaptmatch.cpp b/classify/adaptmatch.cpp index e83efc3dc5..6ef4dd10a9 100644 --- a/classify/adaptmatch.cpp +++ b/classify/adaptmatch.cpp @@ -44,13 +44,16 @@ #include "pageres.h" #include "params.h" #include "classify.h" +#include "shapetable.h" +#include "tessclassifier.h" +#include "trainingsample.h" #include "unicharset.h" #include "dict.h" #include "featdefs.h" +#include "genericvector.h" #include #include -#include #include #include #ifdef __UNIX__ @@ -76,10 +79,11 @@ #define WORST_POSSIBLE_RATING (1.0) struct ScoredClass { - CLASS_ID id; + CLASS_ID unichar_id; + int shape_id; FLOAT32 rating; + bool adapted; inT16 config; - inT16 config2; inT16 fontinfo_id; inT16 fontinfo_id2; }; @@ -98,10 +102,11 @@ struct ADAPT_RESULTS { BlobLength = MAX_INT32; NumMatches = 0; HasNonfragment = false; - best_match.id = NO_CLASS; + best_match.unichar_id = NO_CLASS; + best_match.shape_id = -1; best_match.rating = WORST_POSSIBLE_RATING; + best_match.adapted = false; best_match.config = 0; - best_match.config2 = 0; best_match.fontinfo_id = kBlankFontinfoId; best_match.fontinfo_id2 = kBlankFontinfoId; } @@ -170,6 +175,7 @@ namespace tesseract { * */ void Classify::AdaptiveClassifier(TBLOB *Blob, + const DENORM& denorm, BLOB_CHOICE_LIST *Choices, CLASS_PRUNER_RESULTS CPResults) { assert(Choices != NULL); @@ -180,7 +186,7 @@ void Classify::AdaptiveClassifier(TBLOB *Blob, Results->Initialize(); - DoAdaptiveMatch(Blob, Results); + DoAdaptiveMatch(Blob, denorm, Results); if (CPResults != NULL) memcpy(CPResults, Results->CPResults, sizeof(CPResults[0]) * Results->NumMatches); @@ -189,7 +195,7 @@ void Classify::AdaptiveClassifier(TBLOB *Blob, qsort((void *)Results->match, Results->NumMatches, sizeof(ScoredClass), CompareByRating); RemoveExtraPuncs(Results); - ConvertMatchesToChoices(Results, Choices); + ConvertMatchesToChoices(denorm, Blob->bounding_box(), Results, Choices); if (matcher_debug_level >= 1) { cprintf ("AD Matches = "); @@ -201,7 +207,7 @@ void Classify::AdaptiveClassifier(TBLOB *Blob, #ifndef GRAPHICS_DISABLED if (classify_enable_adaptive_debugger) - DebugAdaptiveClassifier(Blob, Results); + DebugAdaptiveClassifier(Blob, denorm, Results); #endif NumClassesOutput += Choices->length(); @@ -211,7 +217,8 @@ void Classify::AdaptiveClassifier(TBLOB *Blob, Choices = new BLOB_CHOICE_LIST(); BLOB_CHOICE_IT temp_it; temp_it.set_to_list(Choices); - temp_it.add_to_end(new BLOB_CHOICE(0, 50.0f, -20.0f, -1, -1, NULL)); + temp_it.add_to_end( + new BLOB_CHOICE(0, 50.0f, -20.0f, -1, -1, NULL, 0, 0, false)); } delete Results; @@ -251,15 +258,19 @@ void Classify::LearnWord(const char* filename, const char *rejmap, float* thresholds = NULL; if (filename == NULL) { // Adaption mode. - if (!EnableLearning || word->best_choice == NULL) + if (!EnableLearning || word->best_choice == NULL || + // If word->best_choice is not recorded at the top of accumulator's + // best choices (which could happen for choices that are + // altered with ReplaceAmbig()) we skip the adaption. + !getDict().CurrentBestChoiceIs(*(word->best_choice))) return; // Can't or won't adapt. NumWordsAdaptedTo++; if (classify_learning_debug_level >= 1) tprintf("\n\nAdapting to word = %s\n", - word->best_choice->debug_string(unicharset).string()); + word->best_choice->debug_string().string()); thresholds = new float[word_len]; - GetAdaptThresholds(word->rebuild_word, *word->best_choice, + GetAdaptThresholds(word->rebuild_word, word->denorm, *word->best_choice, *word->raw_choice, thresholds); } int start_blob = 0; @@ -286,6 +297,9 @@ void Classify::LearnWord(const char* filename, const char *rejmap, if (word->correct_text[ch].length() > 0 && rej_map_char == '1') { float threshold = thresholds != NULL ? thresholds[ch] : 0.0f; + LearnPieces(filename, start_blob, word->best_state[ch], + threshold, CST_WHOLE, word->correct_text[ch].string(), word); + if (word->best_state[ch] > 1 && !disable_character_fragments) { // Check that the character breaks into meaningful fragments // that each match a whole character with at least @@ -296,24 +310,36 @@ void Classify::LearnWord(const char* filename, const char *rejmap, int frag; for (frag = 0; frag < word->best_state[ch]; ++frag) { if (classify_character_fragments_garbage_certainty_threshold < 0) { - garbage |= LooksLikeGarbage(frag_blob); + garbage |= LooksLikeGarbage(word->denorm, frag_blob); } frag_blob = frag_blob->next; } // Learn the fragments. if (!garbage) { - for (frag = 0; frag < word->best_state[ch]; ++frag) { - STRING frag_str = CHAR_FRAGMENT::to_string( - word->correct_text[ch].string(), frag, word->best_state[ch]); - LearnPieces(filename, start_blob + frag, 1, - threshold, CST_FRAGMENT, frag_str.string(), word); + bool pieces_all_natural = word->PiecesAllNatural(start_blob, + word->best_state[ch]); + if (pieces_all_natural || !prioritize_division) { + for (frag = 0; frag < word->best_state[ch]; ++frag) { + GenericVector tokens; + word->correct_text[ch].split(' ', &tokens); + + tokens[0] = CHAR_FRAGMENT::to_string( + tokens[0].string(), frag, word->best_state[ch], + pieces_all_natural); + + STRING full_string; + for (int i = 0; i < tokens.size(); i++) { + full_string += tokens[i]; + if (i != tokens.size() - 1) + full_string += ' '; + } + LearnPieces(filename, start_blob + frag, 1, + threshold, CST_FRAGMENT, full_string.string(), word); + } } } } - LearnPieces(filename, start_blob, word->best_state[ch], - threshold, CST_WHOLE, word->correct_text[ch].string(), word); - // TODO(rays): re-enable this part of the code when we switch to the // new classifier that needs to see examples of garbage. /* @@ -378,12 +404,17 @@ void Classify::LearnPieces(const char* filename, int start, int length, TBLOB* blob = word->chopped_word->blobs; for (int i = 0; i < start; ++i) blob = blob->next; + // Rotate the blob if needed for classification. + const DENORM* denorm = &word->denorm; + TBLOB* rotated_blob = blob->ClassifyNormalizeIfNeeded(&denorm); + if (rotated_blob == NULL) + rotated_blob = blob; // Draw debug windows showing the blob that is being learned if needed. if (strcmp(classify_learn_debug_str.string(), correct_text) == 0) { RefreshDebugWindow(&learn_debug_win_, "LearnPieces", 600, word->chopped_word->bounding_box()); - blob->plot(learn_debug_win_, ScrollView::GREEN, ScrollView::BROWN); + rotated_blob->plot(learn_debug_win_, ScrollView::GREEN, ScrollView::BROWN); learn_debug_win_->Update(); window_wait(learn_debug_win_); } @@ -398,20 +429,25 @@ void Classify::LearnPieces(const char* filename, int start, int length, classify_norm_method.set_value(character); // force char norm spc 30/11/93 tess_bn_matching.set_value(false); // turn it off tess_cn_matching.set_value(false); - LearnBlob(feature_defs_, filename, blob, word->denorm, correct_text); - } else { - if (!unicharset.contains_unichar(correct_text)) { - unicharset.unichar_insert(correct_text); - // TODO(rays) We may need to worry about exceeding MAX_NUM_CLASSES. - // if (unicharset_boxes->size() > MAX_NUM_CLASSES) ... - } + LearnBlob(feature_defs_, filename, rotated_blob, *denorm, + correct_text); + } else if (unicharset.contains_unichar(correct_text)) { UNICHAR_ID class_id = unicharset.unichar_to_id(correct_text); + int font_id = word->fontinfo != NULL + ? fontinfo_table_.get_id(*word->fontinfo) + : 0; if (classify_learning_debug_level >= 1) tprintf("Adapting to char = %s, thr= %g font_id= %d\n", - unicharset.id_to_unichar(class_id), threshold, word->fontinfo_id); + unicharset.id_to_unichar(class_id), threshold, font_id); // If filename is not NULL we are doing recognition // (as opposed to training), so we must have already set word fonts. - AdaptToChar(blob, class_id, word->fontinfo_id, threshold); + AdaptToChar(rotated_blob, *denorm, class_id, font_id, threshold); + } else if (classify_debug_level >= 1) { + tprintf("Can't adapt to %s not in unicharset\n", correct_text); + } + if (rotated_blob != blob) { + delete rotated_blob; + delete denorm; } break_pieces(blob, word->seam_array, start, start + length - 1); @@ -478,6 +514,8 @@ void Classify::EndAdaptiveClassifier() { AllConfigsOff = NULL; TempProtoMask = NULL; } + delete shape_table_; + shape_table_ = NULL; } /* EndAdaptiveClassifier */ @@ -514,8 +552,21 @@ void Classify::InitAdaptiveClassifier(bool load_pre_trained_templates) { ReadIntTemplates(tessdata_manager.GetDataFilePtr()); if (tessdata_manager.DebugLevel() > 0) tprintf("Loaded inttemp\n"); + if (tessdata_manager.SeekToStart(TESSDATA_SHAPE_TABLE)) { + shape_table_ = new ShapeTable(unicharset); + if (!shape_table_->DeSerialize(tessdata_manager.swap(), + tessdata_manager.GetDataFilePtr())) { + tprintf("Error loading shape table!\n"); + delete shape_table_; + shape_table_ = NULL; + } else if (tessdata_manager.DebugLevel() > 0) { + tprintf("Successfully loaded shape table!\n"); + } + } + ASSERT_HOST(tessdata_manager.SeekToStart(TESSDATA_PFFMTABLE)); ReadNewCutoffs(tessdata_manager.GetDataFilePtr(), + tessdata_manager.swap(), tessdata_manager.GetEndOffset(TESSDATA_PFFMTABLE), CharNormCutoffs); if (tessdata_manager.DebugLevel() > 0) tprintf("Loaded pffmtable\n"); @@ -577,7 +628,7 @@ void Classify::InitAdaptiveClassifier(bool load_pre_trained_templates) { } } /* InitAdaptiveClassifier */ -void Classify::ResetAdaptiveClassifier() { +void Classify::ResetAdaptiveClassifierInternal() { if (classify_learning_debug_level > 0) { tprintf("Resetting adaptive classifier (NumAdaptationsFailed=%d)\n", NumAdaptationsFailed); @@ -642,7 +693,7 @@ void Classify::PrintAdaptiveStatistics(FILE *File) { * enabled then it will remain enabled on the first pass. If it is * disabled, then it will remain disabled. This is only put here to * make it very clear that learning is controlled directly by the global - * setting of EnableLearning. + * setting of EnableLearning. * * Globals: * - #EnableLearning @@ -699,6 +750,7 @@ void Classify::SettupPass2() { * @note History: Thu Mar 14 12:49:39 1991, DSJ, Created. */ void Classify::InitAdaptedClass(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, ADAPT_CLASS Class, @@ -764,7 +816,7 @@ void Classify::InitAdaptedClass(TBLOB *Blob, cprintf ("Added new class '%s' with class id %d and %d protos.\n", unicharset.id_to_unichar(ClassId), ClassId, NumFeatures); if (classify_learning_debug_level > 1) - DisplayAdaptedChar(Blob, IClass); + DisplayAdaptedChar(Blob, denorm, IClass); } if (IsEmptyAdaptedClass(Class)) @@ -840,18 +892,13 @@ int Classify::AdaptableWord(TWERD *Word, int BestChoiceLength = BestChoiceWord.length(); float adaptable_score = getDict().segment_penalty_dict_case_ok + ADAPTABLE_WERD_ADJUSTMENT; - return ( // rules that apply in general - simplest to compute first - BestChoiceLength > 0 && - BestChoiceLength == Word->NumBlobs() && - BestChoiceLength <= MAX_ADAPTABLE_WERD_SIZE && ( - (classify_enable_new_adapt_rules && - getDict().CurrentBestChoiceAdjustFactor() <= adaptable_score && - getDict().AlternativeChoicesWorseThan(adaptable_score) && - getDict().CurrentBestChoiceIs(BestChoiceWord)) || - (!classify_enable_new_adapt_rules && // old rules - BestChoiceLength == RawChoiceWord.length() && - ((getDict().valid_word_or_number(BestChoiceWord) && - getDict().case_ok(BestChoiceWord, getDict().getUnicharset())))))); + return // rules that apply in general - simplest to compute first + BestChoiceLength > 0 && + BestChoiceLength == Word->NumBlobs() && + BestChoiceLength <= MAX_ADAPTABLE_WERD_SIZE && + getDict().CurrentBestChoiceAdjustFactor() <= adaptable_score && + getDict().AlternativeChoicesWorseThan(adaptable_score) && + getDict().CurrentBestChoiceIs(BestChoiceWord); } /*---------------------------------------------------------------------------*/ @@ -872,6 +919,7 @@ int Classify::AdaptableWord(TWERD *Word, * @note History: Thu Mar 14 09:36:03 1991, DSJ, Created. */ void Classify::AdaptToChar(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, FLOAT32 Threshold) { @@ -892,7 +940,8 @@ void Classify::AdaptToChar(TBLOB *Blob, Class = AdaptedTemplates->Class[ClassId]; assert(Class != NULL); if (IsEmptyAdaptedClass(Class)) { - InitAdaptedClass(Blob, ClassId, FontinfoId, Class, AdaptedTemplates); + InitAdaptedClass(Blob, denorm, ClassId, FontinfoId, Class, + AdaptedTemplates); } else { IClass = ClassForClassId (AdaptedTemplates->Templates, ClassId); @@ -912,7 +961,7 @@ void Classify::AdaptToChar(TBLOB *Blob, } } im_.Match(IClass, AllProtosOn, MatchingFontConfigs, - NumFeatures, NumFeatures, IntFeatures, 0, + NumFeatures, IntFeatures, &IntResult, classify_adapt_feature_threshold, NO_DEBUG, matcher_debug_separate_windows); FreeBitVector(MatchingFontConfigs); @@ -938,8 +987,9 @@ void Classify::AdaptToChar(TBLOB *Blob, IntResult.Config, TempConfig->NumTimesSeen); if (TempConfigReliable(ClassId, TempConfig)) { - MakePermanent(AdaptedTemplates, ClassId, IntResult.Config, Blob); - UpdateAmbigsGroup(ClassId, Blob); + MakePermanent(AdaptedTemplates, ClassId, IntResult.Config, denorm, + Blob); + UpdateAmbigsGroup(ClassId, denorm, Blob); } } else { @@ -947,7 +997,7 @@ void Classify::AdaptToChar(TBLOB *Blob, cprintf ("Found poor match to temp config %d = %4.1f%%.\n", IntResult.Config, (1.0 - IntResult.Rating) * 100.0); if (classify_learning_debug_level > 2) - DisplayAdaptedChar(Blob, IClass); + DisplayAdaptedChar(Blob, denorm, IClass); } NewTempConfigId = MakeNewTemporaryConfig(AdaptedTemplates, ClassId, @@ -957,13 +1007,13 @@ void Classify::AdaptToChar(TBLOB *Blob, FloatFeatures); if (NewTempConfigId >= 0 && TempConfigReliable(ClassId, TempConfigFor(Class, NewTempConfigId))) { - MakePermanent(AdaptedTemplates, ClassId, NewTempConfigId, Blob); - UpdateAmbigsGroup(ClassId, Blob); + MakePermanent(AdaptedTemplates, ClassId, NewTempConfigId, denorm, Blob); + UpdateAmbigsGroup(ClassId, denorm, Blob); } #ifndef GRAPHICS_DISABLED if (classify_learning_debug_level > 1) { - DisplayAdaptedChar(Blob, IClass); + DisplayAdaptedChar(Blob, denorm, IClass); } #endif } @@ -971,17 +1021,20 @@ void Classify::AdaptToChar(TBLOB *Blob, } } /* AdaptToChar */ -void Classify::DisplayAdaptedChar(TBLOB* blob, INT_CLASS_STRUCT* int_class) { +void Classify::DisplayAdaptedChar(TBLOB* blob, const DENORM& denorm, + INT_CLASS_STRUCT* int_class) { #ifndef GRAPHICS_DISABLED int bloblength = 0; INT_FEATURE_ARRAY features; - CLASS_NORMALIZATION_ARRAY norm_array; - int num_features = GetBaselineFeatures(blob, PreTrainedTemplates, features, + uinT8* norm_array = new uinT8[unicharset.size()]; + int num_features = GetBaselineFeatures(blob, denorm, PreTrainedTemplates, + features, norm_array, &bloblength); + delete [] norm_array; INT_RESULT_STRUCT IntResult; im_.Match(int_class, AllProtosOn, AllConfigsOn, - num_features, num_features, features, 0, + num_features, features, &IntResult, classify_adapt_feature_threshold, NO_DEBUG, matcher_debug_separate_windows); cprintf ("Best match to temp config %d = %4.1f%%.\n", @@ -991,13 +1044,10 @@ void Classify::DisplayAdaptedChar(TBLOB* blob, INT_CLASS_STRUCT* int_class) { ConfigMask = 1 << IntResult.Config; ShowMatchDisplay(); im_.Match(int_class, AllProtosOn, (BIT_VECTOR)&ConfigMask, - num_features, num_features, features, 0, + num_features, features, &IntResult, classify_adapt_feature_threshold, 6 | 0x19, matcher_debug_separate_windows); UpdateMatchDisplay(); - bool adaptive_on = true; - bool pretrained_on = false; - GetClassToDebug("Adapting", &adaptive_on, &pretrained_on); } #endif } @@ -1018,6 +1068,7 @@ void Classify::DisplayAdaptedChar(TBLOB* blob, INT_CLASS_STRUCT* int_class) { * @note History: Thu Mar 14 09:36:03 1991, DSJ, Created. */ void Classify::AdaptToPunc(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, FLOAT32 Threshold) { @@ -1025,7 +1076,7 @@ void Classify::AdaptToPunc(TBLOB *Blob, int i; Results->Initialize(); - CharNormClassifier(Blob, PreTrainedTemplates, Results); + CharNormClassifier(Blob, denorm, PreTrainedTemplates, Results); RemoveBadMatches(Results); if (Results->NumMatches != 1) { @@ -1034,8 +1085,8 @@ void Classify::AdaptToPunc(TBLOB *Blob, unicharset.id_to_unichar(ClassId)); for (i = 0; i < Results->NumMatches; i++) - cprintf("%s", unicharset.id_to_unichar(Results->match[i].id)); - cprintf(")\n"); + tprintf("%s", unicharset.id_to_unichar(Results->match[i].unichar_id)); + tprintf(")\n"); } } else { #ifndef SECURE_NAMES @@ -1043,7 +1094,7 @@ void Classify::AdaptToPunc(TBLOB *Blob, cprintf ("Adapting to punc = %s, thr= %g\n", unicharset.id_to_unichar(ClassId), Threshold); #endif - AdaptToChar(Blob, ClassId, FontinfoId, Threshold); + AdaptToChar(Blob, denorm, ClassId, FontinfoId, Threshold); } delete Results; } /* AdaptToPunc */ @@ -1077,14 +1128,21 @@ void Classify::AdaptToPunc(TBLOB *Blob, */ void Classify::AddNewResult(ADAPT_RESULTS *results, CLASS_ID class_id, + int shape_id, FLOAT32 rating, + bool adapted, int config, - int config2, int fontinfo_id, int fontinfo_id2) { ScoredClass *old_match = FindScoredUnichar(results, class_id); ScoredClass match = - { class_id, rating, config, config2, fontinfo_id, fontinfo_id2 }; + { class_id, + shape_id, + rating, + adapted, + static_cast(config), + static_cast(fontinfo_id), + static_cast(fontinfo_id2) }; if (rating > results->best_match.rating + matcher_bad_match_pad || (old_match && rating >= old_match->rating)) @@ -1093,10 +1151,6 @@ void Classify::AddNewResult(ADAPT_RESULTS *results, if (!unicharset.get_fragment(class_id)) results->HasNonfragment = true; - if (class_id == NO_CLASS || - !ClassForClassId(PreTrainedTemplates, class_id)) - match.config = ~0; - if (old_match) old_match->rating = rating; else @@ -1134,54 +1188,52 @@ void Classify::AddNewResult(ADAPT_RESULTS *results, * @note History: Tue Mar 12 19:40:36 1991, DSJ, Created. */ void Classify::AmbigClassifier(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, ADAPT_CLASS *Classes, UNICHAR_ID *Ambiguities, ADAPT_RESULTS *Results) { int NumFeatures; INT_FEATURE_ARRAY IntFeatures; - CLASS_NORMALIZATION_ARRAY CharNormArray; + uinT8* CharNormArray = new uinT8[unicharset.size()]; INT_RESULT_STRUCT IntResult; CLASS_ID ClassId; AmbigClassifierCalls++; - NumFeatures = GetCharNormFeatures(Blob, Templates, IntFeatures, CharNormArray, + NumFeatures = GetCharNormFeatures(Blob, denorm, Templates, IntFeatures, + NULL, CharNormArray, &(Results->BlobLength), NULL); - if (NumFeatures <= 0) + if (NumFeatures <= 0) { + delete [] CharNormArray; return; + } - if (matcher_debug_level >= 2) - cprintf ("AM Matches = "); + bool debug = matcher_debug_level >= 2 || classify_debug_level > 1; + if (debug) + tprintf("AM Matches = "); + int top = Blob->bounding_box().top(); + int bottom = Blob->bounding_box().bottom(); while (*Ambiguities >= 0) { ClassId = *Ambiguities; im_.SetCharNormMatch(classify_integer_matcher_multiplier); im_.Match(ClassForClassId(Templates, ClassId), AllProtosOn, AllConfigsOn, - Results->BlobLength, NumFeatures, IntFeatures, - CharNormArray[ClassId], &IntResult, + NumFeatures, IntFeatures, + &IntResult, classify_adapt_feature_threshold, NO_DEBUG, matcher_debug_separate_windows); - if (matcher_debug_level >= 2) - cprintf ("%s-%-2d %2.0f ", unicharset.id_to_unichar(ClassId), - IntResult.Config, - IntResult.Rating * 100.0); - - assert(Classes != NULL); - AddNewResult(Results, ClassId, IntResult.Rating, - IntResult.Config, IntResult.Config2, - GetFontinfoId(Classes[ClassId], IntResult.Config), - GetFontinfoId(Classes[ClassId], IntResult.Config2)); + ExpandShapesAndApplyCorrections(NULL, debug, ClassId, bottom, top, 0, + Results->BlobLength, CharNormArray, + IntResult, Results); Ambiguities++; NumAmbigClassesTried++; } - if (matcher_debug_level >= 2) - cprintf ("\n"); - + delete [] CharNormArray; } /* AmbigClassifier */ /*---------------------------------------------------------------------------*/ @@ -1189,8 +1241,8 @@ void Classify::AmbigClassifier(TBLOB *Blob, /// Returns integer matcher results inside CLASS_PRUNER_RESULTS structure. void Classify::MasterMatcher(INT_TEMPLATES templates, inT16 num_features, - INT_FEATURE_ARRAY features, - CLASS_NORMALIZATION_ARRAY norm_factors, + const INT_FEATURE_STRUCT* features, + const uinT8* norm_factors, ADAPT_CLASS* classes, int debug, int num_classes, @@ -1208,94 +1260,130 @@ void Classify::MasterMatcher(INT_TEMPLATES templates, : AllConfigsOn; im_.Match(ClassForClassId(templates, class_id), - protos, configs, final_results->BlobLength, - num_features, features, norm_factors[class_id], + protos, configs, + num_features, features, &int_result, classify_adapt_feature_threshold, debug, matcher_debug_separate_windows); - // Compute class feature corrections. - double miss_penalty = tessedit_class_miss_scale * - int_result.FeatureMisses; - if (matcher_debug_level >= 2 || classify_debug_level > 1) { - cprintf("%s-%-2d %2.1f(CP%2.1f, IM%2.1f + MP%2.1f) ", - unicharset.id_to_unichar(class_id), int_result.Config, - (int_result.Rating + miss_penalty) * 100.0, - results[c].Rating * 100.0, - int_result.Rating * 100.0, miss_penalty * 100.0); - if (c % 4 == 3) - cprintf ("\n"); + bool debug = matcher_debug_level >= 2 || classify_debug_level > 1; + ExpandShapesAndApplyCorrections(classes, debug, class_id, bottom, top, + results[c].Rating, + final_results->BlobLength, norm_factors, + int_result, final_results); + } +} + +// Converts configs to fonts, and if the result is not adapted, and a +// shape_table_ is present, the shape is expanded to include all +// unichar_ids represented, before applying a set of corrections to the +// distance rating in int_result, (see ComputeCorrectedRating.) +// The results are added to the final_results output. +void Classify::ExpandShapesAndApplyCorrections( + ADAPT_CLASS* classes, bool debug, int class_id, int bottom, int top, + float cp_rating, int blob_length, const uinT8* cn_factors, + INT_RESULT_STRUCT& int_result, ADAPT_RESULTS* final_results) { + // Compute the fontinfo_ids. + int fontinfo_id = kBlankFontinfoId; + int fontinfo_id2 = kBlankFontinfoId; + if (classes != NULL) { + // Adapted result. + fontinfo_id = GetFontinfoId(classes[class_id], int_result.Config); + if (int_result.Config2 >= 0) + fontinfo_id2 = GetFontinfoId(classes[class_id], int_result.Config2); + } else { + // Pre-trained result. + fontinfo_id = ClassAndConfigIDToFontOrShapeID(class_id, int_result.Config); + if (int_result.Config2 >= 0) { + fontinfo_id2 = ClassAndConfigIDToFontOrShapeID(class_id, + int_result.Config2); } - // Penalize non-alnums for being vertical misfits. - if (!unicharset.get_isalpha(class_id) && - !unicharset.get_isdigit(class_id) && - norm_factors[class_id] != 0 && classify_misfit_junk_penalty > 0.0) { - int min_bottom, max_bottom, min_top, max_top; - unicharset.get_top_bottom(class_id, &min_bottom, &max_bottom, - &min_top, &max_top); - if (classify_debug_level > 1) { - tprintf("top=%d, vs [%d, %d], bottom=%d, vs [%d, %d]\n", - top, min_top, max_top, bottom, min_bottom, max_bottom); - } - if (top < min_top || top > max_top || - bottom < min_bottom || bottom > max_bottom) { - miss_penalty += classify_misfit_junk_penalty; + if (shape_table_ != NULL) { + // Actually fontinfo_id is an index into the shape_table_ and it + // contains a list of unchar_id/font_id pairs. + int shape_id = fontinfo_id; + const Shape& shape = shape_table_->GetShape(fontinfo_id); + double min_rating = 0.0; + for (int c = 0; c < shape.size(); ++c) { + int unichar_id = shape[c].unichar_id; + fontinfo_id = shape[c].font_ids[0]; + if (shape[c].font_ids.size() > 1) + fontinfo_id2 = shape[c].font_ids[1]; + else if (fontinfo_id2 != kBlankFontinfoId) + fontinfo_id2 = shape_table_->GetShape(fontinfo_id2)[0].font_ids[0]; + double rating = ComputeCorrectedRating(debug, unichar_id, cp_rating, + int_result.Rating, + int_result.FeatureMisses, + bottom, top, blob_length, + cn_factors); + if (c == 0 || rating < min_rating) + min_rating = rating; + if (unicharset.get_enabled(unichar_id)) { + AddNewResult(final_results, unichar_id, shape_id, rating, + classes != NULL, int_result.Config, + fontinfo_id, fontinfo_id2); + } } + int_result.Rating = min_rating; + return; } - int_result.Rating += miss_penalty; - if (int_result.Rating > WORST_POSSIBLE_RATING) - int_result.Rating = WORST_POSSIBLE_RATING; - if (classes != NULL) { - AddNewResult(final_results, class_id, int_result.Rating, - int_result.Config, int_result.Config2, - GetFontinfoId(classes[class_id], int_result.Config), - GetFontinfoId(classes[class_id], int_result.Config2)); - } else { - AddNewResult(final_results, class_id, int_result.Rating, - int_result.Config, int_result.Config2, - kBlankFontinfoId, kBlankFontinfoId); + } + double rating = ComputeCorrectedRating(debug, class_id, cp_rating, + int_result.Rating, + int_result.FeatureMisses, + bottom, top, blob_length, + cn_factors); + if (unicharset.get_enabled(class_id)) { + AddNewResult(final_results, class_id, -1, rating, + classes != NULL, int_result.Config, + fontinfo_id, fontinfo_id2); + } + int_result.Rating = rating; +} +// Applies a set of corrections to the distance im_rating, +// including the cn_correction, miss penalty and additional penalty +// for non-alnums being vertical misfits. Returns the corrected distance. +double Classify::ComputeCorrectedRating(bool debug, int unichar_id, + double cp_rating, double im_rating, + int feature_misses, + int bottom, int top, + int blob_length, + const uinT8* cn_factors) { + // Compute class feature corrections. + double cn_corrected = im_.ApplyCNCorrection(im_rating, blob_length, + cn_factors[unichar_id]); + double miss_penalty = tessedit_class_miss_scale * feature_misses; + double vertical_penalty = 0.0; + // Penalize non-alnums for being vertical misfits. + if (!unicharset.get_isalpha(unichar_id) && + !unicharset.get_isdigit(unichar_id) && + cn_factors[unichar_id] != 0 && classify_misfit_junk_penalty > 0.0) { + int min_bottom, max_bottom, min_top, max_top; + unicharset.get_top_bottom(unichar_id, &min_bottom, &max_bottom, + &min_top, &max_top); + if (debug) { + tprintf("top=%d, vs [%d, %d], bottom=%d, vs [%d, %d]\n", + top, min_top, max_top, bottom, min_bottom, max_bottom); } - - // Add unichars ambiguous with class_id with the same rating as class_id. - if (use_definite_ambigs_for_classifier) { - const UnicharIdVector *definite_ambigs = - getDict().getUnicharAmbigs().OneToOneDefiniteAmbigs(class_id); - int ambigs_size = (definite_ambigs == NULL) ? 0 : definite_ambigs->size(); - for (int ambig = 0; ambig < ambigs_size; ++ambig) { - UNICHAR_ID ambig_class_id = (*definite_ambigs)[ambig]; - // Do not include ambig_class_id if it has permanent adapted templates. - if (classes[class_id]->NumPermConfigs > 0) continue; - ScoredClass* ambig_match = - FindScoredUnichar(final_results, ambig_class_id); - if (matcher_debug_level >= 3) { - tprintf("class: %d definite ambig: %d rating: old %.4f new %.4f\n", - class_id, ambig_class_id, - ambig_match ? ambig_match->rating : WORST_POSSIBLE_RATING, - int_result.Rating); - } - if (ambig_match) { - // ambig_class_id was already added to final_results, - // so just need to modify the rating. - if (int_result.Rating < ambig_match->rating) { - ambig_match->rating = int_result.Rating; - } - } else { - if (classes != NULL) { - AddNewResult( - final_results, ambig_class_id, int_result.Rating, - int_result.Config, int_result.Config2, - GetFontinfoId(classes[class_id], int_result.Config), - GetFontinfoId(classes[class_id], int_result.Config2)); - } else { - AddNewResult(final_results, ambig_class_id, int_result.Rating, - int_result.Config, int_result.Config2, - kBlankFontinfoId, kBlankFontinfoId); - } - } - } + if (top < min_top || top > max_top || + bottom < min_bottom || bottom > max_bottom) { + vertical_penalty = classify_misfit_junk_penalty; } } - if (matcher_debug_level >= 2 || classify_debug_level > 1) - cprintf("\n"); + double result =cn_corrected + miss_penalty + vertical_penalty; + if (result > WORST_POSSIBLE_RATING) + result = WORST_POSSIBLE_RATING; + if (debug) { + tprintf("%s: %2.1f(CP%2.1f, IM%2.1f + CN%.2f(%d) + MP%2.1f + VP%2.1f)\n", + unicharset.id_to_unichar(unichar_id), + result * 100.0, + cp_rating * 100.0, + im_rating * 100.0, + (cn_corrected - im_rating) * 100.0, + cn_factors[unichar_id], + miss_penalty * 100.0, + vertical_penalty * 100.0); + } + return result; } /*---------------------------------------------------------------------------*/ @@ -1317,24 +1405,27 @@ void Classify::MasterMatcher(INT_TEMPLATES templates, * @note History: Tue Mar 12 19:38:03 1991, DSJ, Created. */ UNICHAR_ID *Classify::BaselineClassifier(TBLOB *Blob, + const DENORM& denorm, ADAPT_TEMPLATES Templates, ADAPT_RESULTS *Results) { int NumFeatures; int NumClasses; INT_FEATURE_ARRAY IntFeatures; - CLASS_NORMALIZATION_ARRAY CharNormArray; + uinT8* CharNormArray = new uinT8[unicharset.size()]; CLASS_ID ClassId; BaselineClassifierCalls++; NumFeatures = GetBaselineFeatures( - Blob, Templates->Templates, IntFeatures, CharNormArray, + Blob, denorm, Templates->Templates, IntFeatures, CharNormArray, &(Results->BlobLength)); - if (NumFeatures <= 0) + if (NumFeatures <= 0) { + delete [] CharNormArray; return NULL; + } - NumClasses = ClassPruner(Templates->Templates, NumFeatures, IntFeatures, - CharNormArray, BaselineCutoffs, Results->CPResults); + NumClasses = PruneClasses(Templates->Templates, NumFeatures, IntFeatures, + CharNormArray, BaselineCutoffs, Results->CPResults); NumBaselineClassesTried += NumClasses; @@ -1346,7 +1437,8 @@ UNICHAR_ID *Classify::BaselineClassifier(TBLOB *Blob, Templates->Class, matcher_debug_flags, NumClasses, Blob->bounding_box(), Results->CPResults, Results); - ClassId = Results->best_match.id; + delete [] CharNormArray; + ClassId = Results->best_match.unichar_id; if (ClassId == NO_CLASS) return (NULL); /* this is a bug - maybe should return "" */ @@ -1376,22 +1468,33 @@ UNICHAR_ID *Classify::BaselineClassifier(TBLOB *Blob, * @note History: Tue Mar 12 16:02:52 1991, DSJ, Created. */ int Classify::CharNormClassifier(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, ADAPT_RESULTS *Results) { int NumFeatures; int NumClasses; INT_FEATURE_ARRAY IntFeatures; - CLASS_NORMALIZATION_ARRAY CharNormArray; CharNormClassifierCalls++; - NumFeatures = GetCharNormFeatures(Blob, Templates, IntFeatures, CharNormArray, + uinT8* CharNormArray = new uinT8[unicharset.size()]; + int num_pruner_classes = MAX(unicharset.size(), + PreTrainedTemplates->NumClasses); + uinT8* PrunerNormArray = new uinT8[num_pruner_classes]; + NumFeatures = GetCharNormFeatures(Blob, denorm, Templates, IntFeatures, + PrunerNormArray, CharNormArray, &(Results->BlobLength), NULL); - if (NumFeatures <= 0) + if (NumFeatures <= 0) { + delete [] CharNormArray; + delete [] PrunerNormArray; return 0; + } - NumClasses = ClassPruner(Templates, NumFeatures, IntFeatures, CharNormArray, - CharNormCutoffs, Results->CPResults); + NumClasses = PruneClasses(Templates, NumFeatures, IntFeatures, + PrunerNormArray, + shape_table_ != NULL ? &shapetable_cutoffs_[0] + : CharNormCutoffs, + Results->CPResults); if (tessedit_single_match && NumClasses > 1) NumClasses = 1; @@ -1401,9 +1504,81 @@ int Classify::CharNormClassifier(TBLOB *Blob, MasterMatcher(Templates, NumFeatures, IntFeatures, CharNormArray, NULL, matcher_debug_flags, NumClasses, Blob->bounding_box(), Results->CPResults, Results); + delete [] CharNormArray; + delete [] PrunerNormArray; return NumFeatures; } /* CharNormClassifier */ +// As CharNormClassifier, but operates on a TrainingSample and outputs to +// a GenericVector of ShapeRating without conversion to classes. +int Classify::CharNormTrainingSample(bool pruner_only, + const TrainingSample& sample, + GenericVector* results) { + results->clear(); + ADAPT_RESULTS* adapt_results = new ADAPT_RESULTS(); + adapt_results->Initialize(); + // Compute the bounding box of the features. + int num_features = sample.num_features(); + TBOX blob_box; + for (int f = 0; f < num_features; ++f) { + const INT_FEATURE_STRUCT feature = sample.features()[f]; + TBOX fbox(feature.X, feature.Y, feature.X, feature.Y); + blob_box += fbox; + } + // Compute the char_norm_array from the saved cn_feature. + FEATURE norm_feature = NewFeature(&CharNormDesc); + norm_feature->Params[CharNormY] = sample.cn_feature(CharNormY); + norm_feature->Params[CharNormLength] = sample.cn_feature(CharNormLength); + norm_feature->Params[CharNormRx] = sample.cn_feature(CharNormRx); + norm_feature->Params[CharNormRy] = sample.cn_feature(CharNormRy); + uinT8* char_norm_array = new uinT8[unicharset.size()]; + int num_pruner_classes = MAX(unicharset.size(), + PreTrainedTemplates->NumClasses); + uinT8* pruner_norm_array = new uinT8[num_pruner_classes]; + adapt_results->BlobLength = + static_cast(ActualOutlineLength(norm_feature) * 20 + 0.5); + ComputeCharNormArrays(norm_feature, PreTrainedTemplates, char_norm_array, + pruner_norm_array); + + int num_classes = PruneClasses(PreTrainedTemplates, num_features, + sample.features(), + pruner_norm_array, + shape_table_ != NULL ? &shapetable_cutoffs_[0] + : CharNormCutoffs, + adapt_results->CPResults); + delete [] pruner_norm_array; + if (pruner_only) { + // Convert pruner results to output format. + for (int i = 0; i < num_classes; ++i) { + int class_id = adapt_results->CPResults[i].Class; + int shape_id = class_id; + if (shape_table_ != NULL) { + // All shapes in a class have the same combination of unichars, so + // it doesn't really matter which config we give it, as we aren't + // trying to get the font here. + shape_id = ClassAndConfigIDToFontOrShapeID(class_id, 0); + } + results->push_back( + ShapeRating(shape_id, 1.0f - adapt_results->CPResults[i].Rating)); + } + } else { + im_.SetCharNormMatch(classify_integer_matcher_multiplier); + MasterMatcher(PreTrainedTemplates, num_features, sample.features(), + char_norm_array, + NULL, matcher_debug_flags, num_classes, + blob_box, adapt_results->CPResults, adapt_results); + // Convert master matcher results to output format. + for (int i = 0; i < adapt_results->NumMatches; i++) { + ScoredClass next = adapt_results->match[i]; + results->push_back(ShapeRating(next.shape_id, 1.0f - next.rating)); + } + results->sort(&ShapeRating::SortDescendingRating); + } + delete [] char_norm_array; + delete adapt_results; + return num_features; +} /* CharNormTrainingSample */ + /*---------------------------------------------------------------------------*/ /** @@ -1427,7 +1602,7 @@ void Classify::ClassifyAsNoise(ADAPT_RESULTS *Results) { Rating *= Rating; Rating /= 1.0 + Rating; - AddNewResult(Results, NO_CLASS, Rating, -1, -1, + AddNewResult(Results, NO_CLASS, -1, Rating, false, -1, kBlankFontinfoId, kBlankFontinfoId); } /* ClassifyAsNoise */ } // namespace tesseract @@ -1437,7 +1612,7 @@ void Classify::ClassifyAsNoise(ADAPT_RESULTS *Results) { // Return a pointer to the scored unichar in results, or NULL if not present. ScoredClass *FindScoredUnichar(ADAPT_RESULTS *results, UNICHAR_ID id) { for (int i = 0; i < results->NumMatches; i++) { - if (results->match[i].id == id) + if (results->match[i].unichar_id == id) return &results->match[i]; } return NULL; @@ -1447,7 +1622,8 @@ ScoredClass *FindScoredUnichar(ADAPT_RESULTS *results, UNICHAR_ID id) { // to WORST_POSSIBLE_RATING. ScoredClass ScoredUnichar(ADAPT_RESULTS *results, UNICHAR_ID id) { ScoredClass poor_result = - {id, WORST_POSSIBLE_RATING, -1, -1, kBlankFontinfoId, kBlankFontinfoId}; + {id, -1, WORST_POSSIBLE_RATING, false, -1, + kBlankFontinfoId, kBlankFontinfoId}; ScoredClass *entry = FindScoredUnichar(results, id); return (entry == NULL) ? poor_result : *entry; } @@ -1464,9 +1640,9 @@ int CompareByRating(const void *arg1, // ScoredClass *class1 else if (class1->rating > class2->rating) return 1; - if (class1->id < class2->id) + if (class1->unichar_id < class2->unichar_id) return -1; - else if (class1->id > class2->id) + else if (class1->unichar_id > class2->unichar_id) return 1; return 0; } @@ -1479,7 +1655,8 @@ namespace tesseract { /// sure that there is at least one non-fragmented classification included. /// For each classification result check the unicharset for "definite" /// ambiguities and modify the resulting Choices accordingly. -void Classify::ConvertMatchesToChoices(ADAPT_RESULTS *Results, +void Classify::ConvertMatchesToChoices(const DENORM& denorm, const TBOX& box, + ADAPT_RESULTS *Results, BLOB_CHOICE_LIST *Choices) { assert(Choices != NULL); FLOAT32 Rating; @@ -1488,27 +1665,25 @@ void Classify::ConvertMatchesToChoices(ADAPT_RESULTS *Results, bool contains_nonfrag = false; temp_it.set_to_list(Choices); int choices_length = 0; + // With no shape_table_ maintain the previous MAX_MATCHES as the maximum + // number of returned results, but with a shape_table_ we want to have room + // for at least the biggest shape (which might contain hundreds of Indic + // grapheme fragments) and more, so use double the size of the biggest shape + // if that is more than the default. + int max_matches = MAX_MATCHES; + if (shape_table_ != NULL) { + max_matches = shape_table_->MaxNumUnichars() * 2; + if (max_matches < MAX_MATCHES) + max_matches = MAX_MATCHES; + } for (int i = 0; i < Results->NumMatches; i++) { ScoredClass next = Results->match[i]; int fontinfo_id = next.fontinfo_id; int fontinfo_id2 = next.fontinfo_id2; - if (fontinfo_id == kBlankFontinfoId) { - // ScoredClass next must have come from pre-trained templates, - // so we infer its font information from fontset_table. - int font_set_id = PreTrainedTemplates->Class[next.id]->font_set_id; - if (font_set_id >= 0) { - const FontSet &fs = fontset_table_.get(font_set_id); - if (next.config >= 0 && next.config < fs.size) { - fontinfo_id = fs.configs[next.config]; - } - if (next.config2 >= 0 && next.config2 < fs.size) { - fontinfo_id2 = fs.configs[next.config2]; - } - } - } - bool current_is_frag = (unicharset.get_fragment(next.id) != NULL); - if (temp_it.length()+1 == MAX_MATCHES && + bool adapted = next.adapted; + bool current_is_frag = (unicharset.get_fragment(next.unichar_id) != NULL); + if (temp_it.length()+1 == max_matches && !contains_nonfrag && current_is_frag) { continue; // look for a non-fragmented character to fill the // last spot in Choices if only fragments are present @@ -1525,12 +1700,16 @@ void Classify::ConvertMatchesToChoices(ADAPT_RESULTS *Results, Rating *= rating_scale * Results->BlobLength; Certainty *= -(getDict().certainty_scale); } - temp_it.add_to_end(new BLOB_CHOICE(next.id, Rating, Certainty, + inT16 min_xheight, max_xheight; + denorm.XHeightRange(next.unichar_id, unicharset, box, + &min_xheight, &max_xheight); + temp_it.add_to_end(new BLOB_CHOICE(next.unichar_id, Rating, Certainty, fontinfo_id, fontinfo_id2, - unicharset.get_script(next.id))); + unicharset.get_script(next.unichar_id), + min_xheight, max_xheight, adapted)); contains_nonfrag |= !current_is_frag; // update contains_nonfrag choices_length++; - if (choices_length >= MAX_MATCHES) break; + if (choices_length >= max_matches) break; } Results->NumMatches = choices_length; } // ConvertMatchesToChoices @@ -1549,37 +1728,38 @@ void Classify::ConvertMatchesToChoices(ADAPT_RESULTS *Results, * @note History: Wed Mar 13 16:44:41 1991, DSJ, Created. */ void Classify::DebugAdaptiveClassifier(TBLOB *Blob, + const DENORM& denorm, ADAPT_RESULTS *Results) { + for (int i = 0; i < Results->NumMatches; i++) { + if (Results->match[i].rating < Results->best_match.rating) + Results->best_match = Results->match[i]; + } const char *Prompt = "Left-click in IntegerMatch Window to continue or right click to debug..."; - const char *DebugMode = "All Templates"; - CLASS_ID LastClass = Results->best_match.id; - CLASS_ID ClassId; + CLASS_ID unichar_id = Results->best_match.unichar_id; + int shape_id = Results->best_match.shape_id; bool adaptive_on = true; bool pretrained_on = true; - ShowMatchDisplay(); - cprintf ("\nDebugging class = %s (%s) ...\n", - unicharset.id_to_unichar(LastClass), DebugMode); - ShowBestMatchFor(Blob, LastClass, adaptive_on, pretrained_on); - UpdateMatchDisplay(); - - while ((ClassId = GetClassToDebug(Prompt, &adaptive_on, - &pretrained_on)) != 0) { + const char* debug_mode; + do { if (!pretrained_on) - DebugMode = "Adaptive Templates Only"; + debug_mode = "Adaptive Templates Only"; else if (!adaptive_on) - DebugMode = "PreTrained Templates Only"; + debug_mode = "PreTrained Templates Only"; else - DebugMode = "All Templates"; - LastClass = ClassId; - + debug_mode = "All Templates"; ShowMatchDisplay(); - cprintf ("\nDebugging class = %d = %s (%s) ...\n", - LastClass, unicharset.id_to_unichar(LastClass), DebugMode); - ShowBestMatchFor(Blob, LastClass, adaptive_on, pretrained_on); + tprintf("Debugging class %d = %s in mode %s ...", + unichar_id, unicharset.id_to_unichar(unichar_id), debug_mode); + if (shape_id >= 0 && shape_table_ != NULL) { + tprintf(" from shape %s\n", shape_table_->DebugStr(shape_id).string()); + } + ShowBestMatchFor(Blob, denorm, unichar_id, shape_id, adaptive_on, + pretrained_on, Results); UpdateMatchDisplay(); - } + } while ((unichar_id = GetClassToDebug(Prompt, &adaptive_on, + &pretrained_on, &shape_id)) != 0); } /* DebugAdaptiveClassifier */ #endif @@ -1607,6 +1787,7 @@ void Classify::DebugAdaptiveClassifier(TBLOB *Blob, * @note History: Tue Mar 12 08:50:11 1991, DSJ, Created. */ void Classify::DoAdaptiveMatch(TBLOB *Blob, + const DENORM& denorm, ADAPT_RESULTS *Results) { UNICHAR_ID *Ambiguities; @@ -1615,17 +1796,16 @@ void Classify::DoAdaptiveMatch(TBLOB *Blob, if (AdaptedTemplates->NumPermClasses < matcher_permanent_classes_min || tess_cn_matching) { - CharNormClassifier(Blob, PreTrainedTemplates, Results); - } - else { - Ambiguities = BaselineClassifier(Blob, AdaptedTemplates, Results); + CharNormClassifier(Blob, denorm, PreTrainedTemplates, Results); + } else { + Ambiguities = BaselineClassifier(Blob, denorm, AdaptedTemplates, Results); if ((Results->NumMatches > 0 && MarginalMatch (Results->best_match.rating) && !tess_bn_matching) || Results->NumMatches == 0) { - CharNormClassifier(Blob, PreTrainedTemplates, Results); + CharNormClassifier(Blob, denorm, PreTrainedTemplates, Results); } else if (Ambiguities && *Ambiguities >= 0 && !tess_bn_matching) { - AmbigClassifier(Blob, + AmbigClassifier(Blob, denorm, PreTrainedTemplates, AdaptedTemplates->Class, Ambiguities, @@ -1637,10 +1817,7 @@ void Classify::DoAdaptiveMatch(TBLOB *Blob, // if the results contain only fragments. // TODO(daria): verify that this is better than // just adding a NULL classification. - if (!Results->HasNonfragment) { - Results->NumMatches = 0; - } - if (Results->NumMatches == 0) + if (!Results->HasNonfragment || Results->NumMatches == 0) ClassifyAsNoise(Results); } /* DoAdaptiveMatch */ @@ -1661,7 +1838,6 @@ void Classify::DoAdaptiveMatch(TBLOB *Blob, * @param[out] Thresholds array of thresholds to be filled in * * Globals: - * - classify_enable_new_adapt_rules * - matcher_good_threshold * - matcher_perfect_threshold * - matcher_rating_margin @@ -1671,43 +1847,14 @@ void Classify::DoAdaptiveMatch(TBLOB *Blob, * @note History: Fri May 31 09:22:08 1991, DSJ, Created. */ void Classify::GetAdaptThresholds(TWERD * Word, + const DENORM& denorm, const WERD_CHOICE& BestChoice, const WERD_CHOICE& BestRawChoice, FLOAT32 Thresholds[]) { - TBLOB *Blob; - const char* BestChoice_string = BestChoice.unichar_string().string(); - const char* BestChoice_lengths = BestChoice.unichar_lengths().string(); - const char* BestRawChoice_string = BestRawChoice.unichar_string().string(); - const char* BestRawChoice_lengths = BestRawChoice.unichar_lengths().string(); - - if (classify_enable_new_adapt_rules && /* new rules */ - getDict().CurrentBestChoiceIs(BestChoice)) { - getDict().FindClassifierErrors(matcher_perfect_threshold, - matcher_good_threshold, - matcher_rating_margin, - Thresholds); - } else { /* old rules */ - for (Blob = Word->blobs; - Blob != NULL; - Blob = Blob->next, BestChoice_string += *(BestChoice_lengths++), - BestRawChoice_string += *(BestRawChoice_lengths++), Thresholds++) - if (*(BestChoice_lengths) == *(BestRawChoice_lengths) && - strncmp(BestChoice_string, BestRawChoice_string, - *(BestChoice_lengths)) == 0) - *Thresholds = matcher_good_threshold; - else { - /* the blob was incorrectly classified - find the rating threshold - needed to create a template which will correct the error with - some margin. However, don't waste time trying to make - templates which are too tight. */ - *Thresholds = GetBestRatingFor( - Blob, unicharset.unichar_to_id(BestChoice_string, - *BestChoice_lengths)); - *Thresholds *= (1.0 - matcher_rating_margin); - *Thresholds = ClipToRange( - *Thresholds, matcher_perfect_threshold, matcher_good_threshold); - } - } + getDict().FindClassifierErrors(matcher_perfect_threshold, + matcher_good_threshold, + matcher_rating_margin, + Thresholds); } /* GetAdaptThresholds */ /*---------------------------------------------------------------------------*/ @@ -1728,6 +1875,7 @@ void Classify::GetAdaptThresholds(TWERD * Word, * @note History: Fri Mar 15 08:08:22 1991, DSJ, Created. */ UNICHAR_ID *Classify::GetAmbiguities(TBLOB *Blob, + const DENORM& denorm, CLASS_ID CorrectClass) { ADAPT_RESULTS *Results = new ADAPT_RESULTS(); UNICHAR_ID *Ambiguities; @@ -1735,7 +1883,7 @@ UNICHAR_ID *Classify::GetAmbiguities(TBLOB *Blob, Results->Initialize(); - CharNormClassifier(Blob, PreTrainedTemplates, Results); + CharNormClassifier(Blob, denorm, PreTrainedTemplates, Results); RemoveBadMatches(Results); qsort((void *)Results->match, Results->NumMatches, sizeof(ScoredClass), CompareByRating); @@ -1745,9 +1893,10 @@ UNICHAR_ID *Classify::GetAmbiguities(TBLOB *Blob, Ambiguities = (UNICHAR_ID *) Emalloc (sizeof (UNICHAR_ID) * (Results->NumMatches + 1)); if (Results->NumMatches > 1 || - (Results->NumMatches == 1 && Results->match[0].id != CorrectClass)) { + (Results->NumMatches == 1 && + Results->match[0].unichar_id != CorrectClass)) { for (i = 0; i < Results->NumMatches; i++) - Ambiguities[i] = Results->match[i].id; + Ambiguities[i] = Results->match[i].unichar_id; Ambiguities[i] = -1; } else { Ambiguities[0] = -1; @@ -1757,165 +1906,6 @@ UNICHAR_ID *Classify::GetAmbiguities(TBLOB *Blob, return Ambiguities; } /* GetAmbiguities */ -/*---------------------------------------------------------------------------*/ -/** - * This routine sets up the feature extractor to extract - * baseline normalized pico-features. - * - * The extracted pico-features are converted - * to integer form and placed in IntFeatures. CharNormArray - * is filled with 0's to indicate to the matcher that no - * character normalization adjustment needs to be done. - * - * The total length of all blob outlines - * in baseline normalized units is also returned. - * - * @param Blob blob to extract features from - * @param Templates used to compute char norm adjustments - * @param IntFeatures array to fill with integer features - * @param CharNormArray array to fill with dummy char norm adjustments - * @param BlobLength length of blob in baseline-normalized units - * - * Globals: none - * - * @return Number of pico-features returned (0 if an error occurred) - * @note Exceptions: none - * @note History: Tue Mar 12 17:55:18 1991, DSJ, Created. - */ -int Classify::GetBaselineFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength) { - FEATURE_SET Features; - int NumFeatures; - - if (classify_enable_int_fx) { - return GetIntBaselineFeatures(Blob, Templates, - IntFeatures, CharNormArray, BlobLength); - } - - classify_norm_method.set_value(baseline); - Features = ExtractPicoFeatures(Blob); - - NumFeatures = Features->NumFeatures; - *BlobLength = NumFeatures; - if (NumFeatures > UNLIKELY_NUM_FEAT) { - FreeFeatureSet(Features); - return (0); - } - - ComputeIntFeatures(Features, IntFeatures); - ClearCharNormArray(Templates, CharNormArray); - - FreeFeatureSet(Features); - return NumFeatures; -} /* GetBaselineFeatures */ - -/** - * This routine classifies Blob against both sets of - * templates for the specified class and returns the best - * rating found. - * - * @param Blob blob to get best rating for - * @param ClassId class blob is to be compared to - * - * Globals: - * - PreTrainedTemplates built-in templates - * - AdaptedTemplates current set of adapted templates - * - AllProtosOn dummy mask to enable all protos - * - AllConfigsOn dummy mask to enable all configs - * - * @return Best rating for match of Blob to ClassId. - * @note Exceptions: none - * @note History: Tue Apr 9 09:01:24 1991, DSJ, Created. - */ -FLOAT32 Classify::GetBestRatingFor(TBLOB *Blob, - CLASS_ID ClassId) { - int NumCNFeatures, NumBLFeatures; - INT_FEATURE_ARRAY CNFeatures, BLFeatures; - INT_RESULT_STRUCT CNResult, BLResult; - inT32 BlobLength; - - CNResult.Rating = BLResult.Rating = 1.0; - - if (!LegalClassId(ClassId)) - return 1.0; - - uinT8 *CNAdjust = new uinT8[MAX_NUM_CLASSES]; - uinT8 *BLAdjust = new uinT8[MAX_NUM_CLASSES]; - - if (!UnusedClassIdIn(PreTrainedTemplates, ClassId)) { - NumCNFeatures = GetCharNormFeatures(Blob, PreTrainedTemplates, - CNFeatures, CNAdjust, &BlobLength, - NULL); - if (NumCNFeatures > 0) { - im_.SetCharNormMatch(classify_integer_matcher_multiplier); - im_.Match(ClassForClassId(PreTrainedTemplates, ClassId), - AllProtosOn, AllConfigsOn, - BlobLength, NumCNFeatures, CNFeatures, - CNAdjust[ClassId], &CNResult, - classify_adapt_feature_threshold, NO_DEBUG, - matcher_debug_separate_windows); - } - } - - if (!UnusedClassIdIn(AdaptedTemplates->Templates, ClassId)) { - NumBLFeatures = GetBaselineFeatures(Blob, - AdaptedTemplates->Templates, - BLFeatures, BLAdjust, &BlobLength); - if (NumBLFeatures > 0) { - im_.SetBaseLineMatch(); - im_.Match(ClassForClassId(AdaptedTemplates->Templates, ClassId), - AdaptedTemplates->Class[ClassId]->PermProtos, - AdaptedTemplates->Class[ClassId]->PermConfigs, - BlobLength, NumBLFeatures, BLFeatures, - BLAdjust[ClassId], &BLResult, - classify_adapt_feature_threshold, NO_DEBUG, - matcher_debug_separate_windows); - } - } - - // Clean up. - delete[] CNAdjust; - delete[] BLAdjust; - - return MIN(BLResult.Rating, CNResult.Rating); -} /* GetBestRatingFor */ - -/*---------------------------------------------------------------------------*/ -/** - * This routine sets up the feature extractor to extract - * character normalization features and character normalized - * pico-features. The extracted pico-features are converted - * to integer form and placed in IntFeatures. The character - * normalization features are matched to each class in - * templates and the resulting adjustment factors are returned - * in CharNormArray. The total length of all blob outlines - * in baseline normalized units is also returned. - * - * @param Blob blob to extract features from - * @param Templates used to compute char norm adjustments - * @param IntFeatures array to fill with integer features - * @param CharNormArray array to fill with char norm adjustments - * @param BlobLength length of blob in baseline-normalized units - * - * Globals: none - * - * @return Number of pico-features returned (0 if an error occurred) - * @note Exceptions: none - * @note History: Tue Mar 12 17:55:18 1991, DSJ, Created. - */ -int Classify::GetCharNormFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength, - inT32 *FeatureOutlineIndex) { - return GetIntCharNormFeatures(Blob, Templates, IntFeatures, CharNormArray, - BlobLength, FeatureOutlineIndex); -} /* GetCharNormFeatures */ - /*---------------------------------------------------------------------------*/ /** * This routine calls the integer (Hardware) feature @@ -1942,16 +1932,17 @@ int Classify::GetCharNormFeatures(TBLOB *Blob, * @note Exceptions: none * @note History: Tue May 28 10:40:52 1991, DSJ, Created. */ -int Classify::GetIntBaselineFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength) { +int Classify::GetBaselineFeatures(TBLOB *Blob, + const DENORM& denorm, + INT_TEMPLATES Templates, + INT_FEATURE_ARRAY IntFeatures, + uinT8* CharNormArray, + inT32 *BlobLength) { register INT_FEATURE Src, Dest, End; if (!FeaturesHaveBeenExtracted) { - FeaturesOK = ExtractIntFeat(Blob, denorm_, BaselineFeatures, - CharNormFeatures, &FXInfo); + FeaturesOK = ExtractIntFeat(Blob, denorm, BaselineFeatures, + CharNormFeatures, &FXInfo, NULL); FeaturesHaveBeenExtracted = TRUE; } @@ -1964,10 +1955,10 @@ int Classify::GetIntBaselineFeatures(TBLOB *Blob, Src < End; *Dest++ = *Src++); - ClearCharNormArray(Templates, CharNormArray); + ClearCharNormArray(CharNormArray); *BlobLength = FXInfo.NumBL; return FXInfo.NumBL; -} /* GetIntBaselineFeatures */ +} /* GetBaselineFeatures */ void Classify::ResetFeaturesHaveBeenExtracted() { FeaturesHaveBeenExtracted = FALSE; @@ -1975,9 +1966,9 @@ void Classify::ResetFeaturesHaveBeenExtracted() { // Returns true if the given blob looks too dissimilar to any character // present in the classifier templates. -bool Classify::LooksLikeGarbage(TBLOB *blob) { +bool Classify::LooksLikeGarbage(const DENORM& denorm, TBLOB *blob) { BLOB_CHOICE_LIST *ratings = new BLOB_CHOICE_LIST(); - AdaptiveClassifier(blob, ratings, NULL); + AdaptiveClassifier(blob, denorm, ratings, NULL); BLOB_CHOICE_IT ratings_it(ratings); const UNICHARSET &unicharset = getDict().getUnicharset(); if (classify_debug_character_fragments) { @@ -2025,19 +2016,21 @@ bool Classify::LooksLikeGarbage(TBLOB *blob) { * @note Exceptions: none * @note History: Tue May 28 10:40:52 1991, DSJ, Created. */ -int Classify::GetIntCharNormFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength, - inT32 *FeatureOutlineArray) { +int Classify::GetCharNormFeatures(TBLOB *Blob, + const DENORM& denorm, + INT_TEMPLATES Templates, + INT_FEATURE_ARRAY IntFeatures, + uinT8* PrunerNormArray, + uinT8* CharNormArray, + inT32 *BlobLength, + inT32 *FeatureOutlineArray) { register INT_FEATURE Src, Dest, End; FEATURE NormFeature; FLOAT32 Baseline, Scale; inT32 FeatureOutlineIndex[MAX_NUM_INT_FEATURES]; if (!FeaturesHaveBeenExtracted) { - FeaturesOK = ExtractIntFeat(Blob, denorm_, BaselineFeatures, + FeaturesOK = ExtractIntFeat(Blob, denorm, BaselineFeatures, CharNormFeatures, &FXInfo, FeatureOutlineIndex); FeaturesHaveBeenExtracted = TRUE; @@ -2063,12 +2056,41 @@ int Classify::GetIntCharNormFeatures(TBLOB *Blob, FXInfo.Length * Scale / LENGTH_COMPRESSION; NormFeature->Params[CharNormRx] = FXInfo.Rx * Scale; NormFeature->Params[CharNormRy] = FXInfo.Ry * Scale; - ComputeIntCharNormArray(NormFeature, Templates, CharNormArray); - FreeFeature(NormFeature); - + ComputeCharNormArrays(NormFeature, Templates, CharNormArray, PrunerNormArray); *BlobLength = FXInfo.NumBL; return (FXInfo.NumCN); -} /* GetIntCharNormFeatures */ +} /* GetCharNormFeatures */ + +// Computes the char_norm_array for the unicharset and, if not NULL, the +// pruner_array as appropriate according to the existence of the shape_table. +void Classify::ComputeCharNormArrays(FEATURE_STRUCT* norm_feature, + INT_TEMPLATES_STRUCT* templates, + uinT8* char_norm_array, + uinT8* pruner_array) { + ComputeIntCharNormArray(*norm_feature, char_norm_array); + if (pruner_array != NULL) { + if (shape_table_ == NULL) { + ComputeIntCharNormArray(*norm_feature, pruner_array); + } else { + memset(pruner_array, MAX_UINT8, + templates->NumClasses * sizeof(pruner_array[0])); + // Each entry in the pruner norm array is the MIN of all the entries of + // the corresponding unichars in the CharNormArray. + for (int id = 0; id < templates->NumClasses; ++id) { + int font_set_id = templates->Class[id]->font_set_id; + const FontSet &fs = fontset_table_.get(font_set_id); + for (int config = 0; config < fs.size; ++config) { + const Shape& shape = shape_table_->GetShape(fs.configs[config]); + for (int c = 0; c < shape.size(); ++c) { + if (char_norm_array[shape[c].unichar_id] < pruner_array[id]) + pruner_array[id] = char_norm_array[shape[c].unichar_id]; + } + } + } + } + } + FreeFeature(norm_feature); +} /*---------------------------------------------------------------------------*/ /** @@ -2265,7 +2287,7 @@ PROTO_ID Classify::MakeNewTempProtos(FEATURE_SET Features, * @param ClassId class containing config to be made permanent * @param ConfigId config to be made permanent * @param Blob current blob being adapted to - * + * * Globals: none * * @note Exceptions: none @@ -2274,6 +2296,7 @@ PROTO_ID Classify::MakeNewTempProtos(FEATURE_SET Features, void Classify::MakePermanent(ADAPT_TEMPLATES Templates, CLASS_ID ClassId, int ConfigId, + const DENORM& denorm, TBLOB *Blob) { UNICHAR_ID *Ambigs; TEMP_CONFIG Config; @@ -2289,7 +2312,7 @@ void Classify::MakePermanent(ADAPT_TEMPLATES Templates, Class->NumPermConfigs++; // Initialize permanent config. - Ambigs = GetAmbiguities(Blob, ClassId); + Ambigs = GetAmbiguities(Blob, denorm, ClassId); PERM_CONFIG Perm = (PERM_CONFIG) alloc_struct(sizeof(PERM_CONFIG_STRUCT), "PERM_CONFIG_STRUCT"); Perm->Ambigs = Ambigs; @@ -2324,12 +2347,12 @@ void Classify::MakePermanent(ADAPT_TEMPLATES Templates, * This routine converts TempProto to be permanent if * its proto id is used by the configuration specified in * ProtoKey. - * + * * @param TempProto temporary proto to compare to key * @param ProtoKey defines which protos to make permanent - * + * * Globals: none - * + * * @return TRUE if TempProto is converted, FALSE otherwise * @note Exceptions: none * @note History: Thu Mar 14 18:49:54 1991, DSJ, Created. @@ -2373,12 +2396,12 @@ namespace tesseract { */ void Classify::PrintAdaptiveMatchResults(FILE *File, ADAPT_RESULTS *Results) { for (int i = 0; i < Results->NumMatches; ++i) { - cprintf("%s(%d) %.2f ", - unicharset.debug_str(Results->match[i].id).string(), - Results->match[i].id, + tprintf("%s(%d), shape %d, %.2f ", + unicharset.debug_str(Results->match[i].unichar_id).string(), + Results->match[i].unichar_id, Results->match[i].shape_id, Results->match[i].rating * 100.0); } - printf("\n"); + tprintf("\n"); } /* PrintAdaptiveMatchResults */ /*---------------------------------------------------------------------------*/ @@ -2393,7 +2416,7 @@ void Classify::PrintAdaptiveMatchResults(FILE *File, ADAPT_RESULTS *Results) { * * Globals: * - matcher_bad_match_pad defines a "bad match" - * + * * @note Exceptions: none * @note History: Tue Mar 12 13:51:03 1991, DSJ, Created. */ @@ -2414,15 +2437,16 @@ void Classify::RemoveBadMatches(ADAPT_RESULTS *Results) { for (Next = NextGood = 0; Next < Results->NumMatches; Next++) { if (Results->match[Next].rating <= BadMatchThreshold) { ScoredClass match = Results->match[Next]; - if (!unicharset.get_isalpha(match.id) || - strstr(romans, unicharset.id_to_unichar(match.id)) != NULL) { + if (!unicharset.get_isalpha(match.unichar_id) || + strstr(romans, + unicharset.id_to_unichar(match.unichar_id)) != NULL) { Results->match[NextGood++] = Results->match[Next]; - } else if (unicharset.eq(match.id, "l") && + } else if (unicharset.eq(match.unichar_id, "l") && scored_one.rating >= BadMatchThreshold) { Results->match[NextGood] = scored_one; Results->match[NextGood].rating = match.rating; NextGood++; - } else if (unicharset.eq(match.id, "O") && + } else if (unicharset.eq(match.unichar_id, "O") && scored_zero.rating >= BadMatchThreshold) { Results->match[NextGood] = scored_zero; Results->match[NextGood].rating = match.rating; @@ -2451,7 +2475,7 @@ void Classify::RemoveBadMatches(ADAPT_RESULTS *Results) { * * Globals: * - matcher_bad_match_pad defines a "bad match" - * + * * @note Exceptions: none * @note History: Tue Mar 12 13:51:03 1991, DSJ, Created. */ @@ -2467,12 +2491,14 @@ void Classify::RemoveExtraPuncs(ADAPT_RESULTS *Results) { digit_count = 0; for (Next = NextGood = 0; Next < Results->NumMatches; Next++) { ScoredClass match = Results->match[Next]; - if (strstr(punc_chars, unicharset.id_to_unichar(match.id)) != NULL) { + if (strstr(punc_chars, + unicharset.id_to_unichar(match.unichar_id)) != NULL) { if (punc_count < 2) Results->match[NextGood++] = match; punc_count++; } else { - if (strstr(digit_chars, unicharset.id_to_unichar(match.id)) != NULL) { + if (strstr(digit_chars, + unicharset.id_to_unichar(match.unichar_id)) != NULL) { if (digit_count < 1) Results->match[NextGood++] = match; digit_count++; @@ -2527,9 +2553,12 @@ void Classify::SetAdaptiveThreshold(FLOAT32 Threshold) { * @note History: Fri Mar 22 08:43:52 1991, DSJ, Created. */ void Classify::ShowBestMatchFor(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, + int shape_id, BOOL8 AdaptiveOn, - BOOL8 PreTrainedOn) { + BOOL8 PreTrainedOn, + ADAPT_RESULTS *Results) { int NumCNFeatures = 0, NumBLFeatures = 0; INT_FEATURE_ARRAY CNFeatures, BLFeatures; INT_RESULT_STRUCT CNResult, BLResult; @@ -2549,62 +2578,70 @@ void Classify::ShowBestMatchFor(TBLOB *Blob, uinT8 *CNAdjust = new uinT8[MAX_NUM_CLASSES]; uinT8 *BLAdjust = new uinT8[MAX_NUM_CLASSES]; - if (PreTrainedOn) { - if (UnusedClassIdIn (PreTrainedTemplates, ClassId)) - cprintf ("No built-in templates for class %d = %s\n", - ClassId, unicharset.id_to_unichar(ClassId)); - else { - NumCNFeatures = GetCharNormFeatures( - Blob, PreTrainedTemplates, CNFeatures, CNAdjust, &BlobLength, NULL); - if (NumCNFeatures <= 0) - cprintf ("Illegal blob (char norm features)!\n"); - else { + if (shape_table_ == NULL) + shape_id = ClassId; + else + shape_id = ShapeIDToClassID(shape_id); + if (PreTrainedOn && shape_id >= 0) { + if (UnusedClassIdIn(PreTrainedTemplates, shape_id)) { + tprintf("No built-in templates for class/shape %d\n", shape_id); + } else { + NumCNFeatures = GetCharNormFeatures(Blob, denorm, PreTrainedTemplates, + CNFeatures, NULL, CNAdjust, + &BlobLength, NULL); + if (NumCNFeatures <= 0) { + tprintf("Illegal blob (char norm features)!\n"); + } else { im_.SetCharNormMatch(classify_integer_matcher_multiplier); - im_.Match(ClassForClassId (PreTrainedTemplates, ClassId), + im_.Match(ClassForClassId(PreTrainedTemplates, shape_id), AllProtosOn, AllConfigsOn, - BlobLength, NumCNFeatures, CNFeatures, - CNAdjust[ClassId], &CNResult, + NumCNFeatures, CNFeatures, + &CNResult, classify_adapt_feature_threshold, NO_DEBUG, matcher_debug_separate_windows); - - cprintf ("Best built-in template match is config %2d (%4.1f) (cn=%d)\n", - CNResult.Config, CNResult.Rating * 100.0, CNAdjust[ClassId]); + ExpandShapesAndApplyCorrections(NULL, false, shape_id, + Blob->bounding_box().bottom(), + Blob->bounding_box().top(), + 0, BlobLength, CNAdjust, + CNResult, Results); } } } if (AdaptiveOn) { - if (UnusedClassIdIn (AdaptedTemplates->Templates, ClassId)) - cprintf ("No AD templates for class %d = %s\n", - ClassId, unicharset.id_to_unichar(ClassId)); - else { + if (ClassId < 0 || ClassId >= AdaptedTemplates->Templates->NumClasses) { + tprintf("Invalid adapted class id: %d\n", ClassId); + } else if (UnusedClassIdIn(AdaptedTemplates->Templates, ClassId) || + AdaptedTemplates->Class[ClassId] == NULL || + IsEmptyAdaptedClass(AdaptedTemplates->Class[ClassId])) { + tprintf("No AD templates for class %d = %s\n", + ClassId, unicharset.id_to_unichar(ClassId)); + } else { NumBLFeatures = GetBaselineFeatures(Blob, + denorm, AdaptedTemplates->Templates, BLFeatures, BLAdjust, &BlobLength); if (NumBLFeatures <= 0) - cprintf ("Illegal blob (baseline features)!\n"); + tprintf("Illegal blob (baseline features)!\n"); else { im_.SetBaseLineMatch(); - im_.Match(ClassForClassId - (AdaptedTemplates->Templates, ClassId), + im_.Match(ClassForClassId(AdaptedTemplates->Templates, ClassId), AllProtosOn, AllConfigsOn, - BlobLength, NumBLFeatures, BLFeatures, - BLAdjust[ClassId], &BLResult, + NumBLFeatures, BLFeatures, + &BLResult, classify_adapt_feature_threshold, NO_DEBUG, matcher_debug_separate_windows); - -#ifndef SECURE_NAMES - ADAPT_CLASS Class = AdaptedTemplates->Class[ClassId]; - cprintf ("Best adaptive template match is config %2d (%4.1f) %s\n", - BLResult.Config, BLResult.Rating * 100.0, - ConfigIsPermanent(Class, BLResult.Config) ? "Perm" : "Temp"); -#endif + ExpandShapesAndApplyCorrections( + AdaptedTemplates->Class, false, + ClassId, Blob->bounding_box().bottom(), + Blob->bounding_box().top(), 0, BlobLength, CNAdjust, + BLResult, Results); } } } - cprintf ("\n"); + tprintf("\n"); if (BLResult.Rating < CNResult.Rating) { if (next_config < 0) { ConfigMask = 1 << BLResult.Config; @@ -2616,28 +2653,37 @@ void Classify::ShowBestMatchFor(TBLOB *Blob, classify_norm_method.set_value(baseline); im_.SetBaseLineMatch(); + tprintf("Adaptive Class ID: %d\n", ClassId); im_.Match(ClassForClassId(AdaptedTemplates->Templates, ClassId), AllProtosOn, (BIT_VECTOR) &ConfigMask, - BlobLength, NumBLFeatures, BLFeatures, - BLAdjust[ClassId], &BLResult, + NumBLFeatures, BLFeatures, + &BLResult, classify_adapt_feature_threshold, matcher_debug_flags, matcher_debug_separate_windows); - cprintf ("Adaptive template match for config %2d is %4.1f\n", - BLResult.Config, BLResult.Rating * 100.0); - } - else { + ExpandShapesAndApplyCorrections( + AdaptedTemplates->Class, true, + ClassId, Blob->bounding_box().bottom(), + Blob->bounding_box().top(), 0, BlobLength, CNAdjust, + BLResult, Results); + } else if (shape_id >= 0) { ConfigMask = 1 << CNResult.Config; classify_norm_method.set_value(character); + tprintf("Static Shape ID: %d\n", shape_id); im_.SetCharNormMatch(classify_integer_matcher_multiplier); - im_.Match(ClassForClassId (PreTrainedTemplates, ClassId), + im_.Match(ClassForClassId (PreTrainedTemplates, shape_id), AllProtosOn, (BIT_VECTOR) & ConfigMask, - BlobLength, NumCNFeatures, CNFeatures, - CNAdjust[ClassId], &CNResult, + NumCNFeatures, CNFeatures, + &CNResult, classify_adapt_feature_threshold, matcher_debug_flags, matcher_debug_separate_windows); + ExpandShapesAndApplyCorrections(NULL, true, shape_id, + Blob->bounding_box().bottom(), + Blob->bounding_box().top(), + 0, BlobLength, CNAdjust, + CNResult, Results); } // Clean up. @@ -2645,6 +2691,48 @@ void Classify::ShowBestMatchFor(TBLOB *Blob, delete[] BLAdjust; } /* ShowBestMatchFor */ +// Returns a string for the classifier class_id: either the corresponding +// unicharset debug_str or the shape_table_ debug str. +STRING Classify::ClassIDToDebugStr(const INT_TEMPLATES_STRUCT* templates, + int class_id, int config_id) const { + STRING class_string; + if (templates == PreTrainedTemplates && shape_table_ != NULL) { + int shape_id = ClassAndConfigIDToFontOrShapeID(class_id, config_id); + class_string = shape_table_->DebugStr(shape_id); + } else { + class_string = unicharset.debug_str(class_id); + } + return class_string; +} + +// Converts a classifier class_id index to a shape_table_ index +int Classify::ClassAndConfigIDToFontOrShapeID(int class_id, + int int_result_config) const { + int font_set_id = PreTrainedTemplates->Class[class_id]->font_set_id; + // Older inttemps have no font_ids. + if (font_set_id < 0) + return kBlankFontinfoId; + const FontSet &fs = fontset_table_.get(font_set_id); + ASSERT_HOST(int_result_config >= 0 && int_result_config < fs.size); + return fs.configs[int_result_config]; +} + +// Converts a shape_table_ index to a classifier class_id index (not a +// unichar-id!). Uses a search, so not fast. +int Classify::ShapeIDToClassID(int shape_id) const { + for (int id = 0; id < PreTrainedTemplates->NumClasses; ++id) { + int font_set_id = PreTrainedTemplates->Class[id]->font_set_id; + ASSERT_HOST(font_set_id >= 0); + const FontSet &fs = fontset_table_.get(font_set_id); + for (int config = 0; config < fs.size; ++config) { + if (fs.configs[config] == shape_id) + return id; + } + } + tprintf("Shape %d not found\n", shape_id); + return -1; +} + // Returns true if the given TEMP_CONFIG is good enough to make it // a permanent config. bool Classify::TempConfigReliable(CLASS_ID class_id, @@ -2684,7 +2772,8 @@ bool Classify::TempConfigReliable(CLASS_ID class_id, return true; } -void Classify::UpdateAmbigsGroup(CLASS_ID class_id, TBLOB *Blob) { +void Classify::UpdateAmbigsGroup(CLASS_ID class_id, const DENORM& denorm, + TBLOB *Blob) { const UnicharIdVector *ambigs = getDict().getUnicharAmbigs().ReverseAmbigsForAdaption(class_id); int ambigs_size = (ambigs == NULL) ? 0 : ambigs->size(); @@ -2705,7 +2794,7 @@ void Classify::UpdateAmbigsGroup(CLASS_ID class_id, TBLOB *Blob) { getDict().getUnicharset().debug_str( ambig_class_id).string()); } - MakePermanent(AdaptedTemplates, ambig_class_id, cfg, Blob); + MakePermanent(AdaptedTemplates, ambig_class_id, cfg, denorm, Blob); } } } diff --git a/classify/blobclass.cpp b/classify/blobclass.cpp index a6cc826a2a..93ea3fc039 100644 --- a/classify/blobclass.cpp +++ b/classify/blobclass.cpp @@ -112,11 +112,15 @@ void LearnBlob(const FEATURE_DEFS_STRUCT &FeatureDefs, FILE* FeatureFile, return; } - // label the features with a class name and font name - fprintf (FeatureFile, "\n%s %s ", FontName, BlobText); - - // write micro-features to file and clean up - WriteCharDescription(FeatureDefs, FeatureFile, CharDesc); + if (ValidCharDescription(FeatureDefs, CharDesc)) { + // label the features with a class name and font name + fprintf(FeatureFile, "\n%s %s\n", FontName, BlobText); + + // write micro-features to file and clean up + WriteCharDescription(FeatureDefs, FeatureFile, CharDesc); + } else { + tprintf("Blob learned was invalid!\n"); + } FreeCharDescription(CharDesc); } // LearnBlob diff --git a/classify/classify.cpp b/classify/classify.cpp index eb1e12ced9..15210c4393 100644 --- a/classify/classify.cpp +++ b/classify/classify.cpp @@ -17,49 +17,19 @@ /////////////////////////////////////////////////////////////////////// #include "classify.h" +#include "fontinfo.h" #include "intproto.h" #include "mfoutline.h" #include "scrollview.h" +#include "shapetable.h" #include "unicity_table.h" #include -namespace { - -// Compare FontInfo structures. -bool compare_fontinfo(const FontInfo& fi1, const FontInfo& fi2) { - // The font properties are required to be the same for two font with the same - // name, so there is no need to test them. - // Consequently, querying the table with only its font name as information is - // enough to retrieve its properties. - return strcmp(fi1.name, fi2.name) == 0; -} -// Compare FontSet structures. -bool compare_font_set(const FontSet& fs1, const FontSet& fs2) { - if (fs1.size != fs2.size) - return false; - for (int i = 0; i < fs1.size; ++i) { - if (fs1.configs[i] != fs2.configs[i]) - return false; - } - return true; -} - -void delete_callback(FontInfo f) { - if (f.spacing_vec != NULL) { - f.spacing_vec->delete_data_pointers(); - delete f.spacing_vec; - } - delete[] f.name; -} -void delete_callback_fs(FontSet fs) { - delete[] fs.configs; -} - -} - namespace tesseract { Classify::Classify() - : INT_MEMBER(tessedit_single_match, FALSE, + : BOOL_MEMBER(prioritize_division, FALSE, + "Prioritize blob division over chopping", this->params()), + INT_MEMBER(tessedit_single_match, FALSE, "Top choice only from CP", this->params()), BOOL_MEMBER(classify_enable_learning, true, "Enable adaptive classifier", this->params()), @@ -120,10 +90,6 @@ Classify::Classify() "Penalty to apply when a non-alnum is vertically out of " "its expected textline position", this->params()), - BOOL_MEMBER(classify_enable_int_fx, 1, "Enable integer fx", - this->params()), - BOOL_MEMBER(classify_enable_new_adapt_rules, 1, - "Enable new adaptation rules", this->params()), double_MEMBER(rating_scale, 1.5, "Rating scaling factor", this->params()), double_MEMBER(certainty_scale, 20.0, "Certainty scaling factor", this->params()), @@ -149,28 +115,29 @@ Classify::Classify() "One for the protos and one for the features.", this->params()), STRING_MEMBER(classify_learn_debug_str, "", "Class str to debug learning", this->params()), - INT_INIT_MEMBER(classify_class_pruner_threshold, 229, - "Class Pruner Threshold 0-255: ", this->params()), - INT_INIT_MEMBER(classify_class_pruner_multiplier, 30, - "Class Pruner Multiplier 0-255: ", this->params()), - INT_INIT_MEMBER(classify_cp_cutoff_strength, 7, - "Class Pruner CutoffStrength: ", this->params()), - INT_INIT_MEMBER(classify_integer_matcher_multiplier, 14, - "Integer Matcher Multiplier 0-255: ", this->params()), + INT_MEMBER(classify_class_pruner_threshold, 229, + "Class Pruner Threshold 0-255", this->params()), + INT_MEMBER(classify_class_pruner_multiplier, 30, + "Class Pruner Multiplier 0-255: ", this->params()), + INT_MEMBER(classify_cp_cutoff_strength, 7, + "Class Pruner CutoffStrength: ", this->params()), + INT_MEMBER(classify_integer_matcher_multiplier, 14, + "Integer Matcher Multiplier 0-255: ", this->params()), EnableLearning(true), INT_MEMBER(il1_adaption_test, 0, "Dont adapt to i/I at beginning of word", this->params()), BOOL_MEMBER(classify_bln_numeric_mode, 0, "Assume the input is numbers [0-9].", this->params()), + shape_table_(NULL), dict_(&image_) { fontinfo_table_.set_compare_callback( - NewPermanentTessCallback(compare_fontinfo)); + NewPermanentTessCallback(CompareFontInfo)); fontinfo_table_.set_clear_callback( - NewPermanentTessCallback(delete_callback)); + NewPermanentTessCallback(FontInfoDeleteCallback)); fontset_table_.set_compare_callback( - NewPermanentTessCallback(compare_font_set)); + NewPermanentTessCallback(CompareFontSet)); fontset_table_.set_clear_callback( - NewPermanentTessCallback(delete_callback_fs)); + NewPermanentTessCallback(FontSetDeleteCallback)); AdaptedTemplates = NULL; PreTrainedTemplates = NULL; AllProtosOn = NULL; @@ -198,6 +165,9 @@ Classify::Classify() learn_debug_win_ = NULL; learn_fragmented_word_debug_win_ = NULL; learn_fragments_debug_win_ = NULL; + + CharNormCutoffs = new uinT16[MAX_NUM_CLASSES]; + BaselineCutoffs = new uinT16[MAX_NUM_CLASSES]; } Classify::~Classify() { @@ -205,6 +175,8 @@ Classify::~Classify() { delete learn_debug_win_; delete learn_fragmented_word_debug_win_; delete learn_fragments_debug_win_; + delete[] CharNormCutoffs; + delete[] BaselineCutoffs; } } // namespace tesseract diff --git a/classify/classify.h b/classify/classify.h index b76d16a378..abdceef2d7 100644 --- a/classify/classify.h +++ b/classify/classify.h @@ -24,6 +24,7 @@ #include "classify.h" #include "dict.h" #include "featdefs.h" +#include "fontinfo.h" #include "intfx.h" #include "intmatcher.h" #include "normalis.h" @@ -42,6 +43,9 @@ static const int kBlankFontinfoId = -2; namespace tesseract { +struct ShapeRating; +class ShapeTable; + // How segmented is a blob. In this enum, character refers to a classifiable // unit, but that is too long and character is usually easier to understand. enum CharSegmentationType { @@ -59,27 +63,41 @@ class Classify : public CCStruct { return dict_; } - // Set the denorm for classification. Takes a copy. - void set_denorm(const DENORM* denorm) { - denorm_ = *denorm; + const ShapeTable* shape_table() const { + return shape_table_; } /* adaptive.cpp ************************************************************/ ADAPT_TEMPLATES NewAdaptedTemplates(bool InitFromUnicharset); int GetFontinfoId(ADAPT_CLASS Class, uinT8 ConfigId); - int ClassPruner(INT_TEMPLATES IntTemplates, - inT16 NumFeatures, - INT_FEATURE_ARRAY Features, - CLASS_NORMALIZATION_ARRAY NormalizationFactors, - CLASS_CUTOFF_ARRAY ExpectedNumFeatures, - CLASS_PRUNER_RESULTS Results); - void ReadNewCutoffs(FILE *CutoffFile, inT64 end_offset, + // Runs the class pruner from int_templates on the given features, returning + // the number of classes output in results. + // int_templates Class pruner tables + // num_features Number of features in blob + // features Array of features + // normalization_factors (input) Array of int_templates->NumClasses fudge + // factors from blob normalization process. + // (Indexed by CLASS_INDEX) + // expected_num_features (input) Array of int_templates->NumClasses + // expected number of features for each class. + // (Indexed by CLASS_INDEX) + // results (output) Sorted Array of pruned classes. + // Array must be sized to take the maximum possible + // number of outputs : int_templates->NumClasses. + int PruneClasses(const INT_TEMPLATES_STRUCT* int_templates, + int num_features, + const INT_FEATURE_STRUCT* features, + const uinT8* normalization_factors, + const uinT16* expected_num_features, + CP_RESULT_STRUCT* results); + void ReadNewCutoffs(FILE *CutoffFile, bool swap, inT64 end_offset, CLASS_CUTOFF_ARRAY Cutoffs); void PrintAdaptedTemplates(FILE *File, ADAPT_TEMPLATES Templates); void WriteAdaptedTemplates(FILE *File, ADAPT_TEMPLATES Templates); ADAPT_TEMPLATES ReadAdaptedTemplates(FILE *File); /* normmatch.cpp ************************************************************/ - FLOAT32 ComputeNormMatch(CLASS_ID ClassId, FEATURE Feature, BOOL8 DebugMatch); + FLOAT32 ComputeNormMatch(CLASS_ID ClassId, + const FEATURE_STRUCT& feature, BOOL8 DebugMatch); void FreeNormProtos(); NORM_PROTOS *ReadNormProtos(FILE *File, inT64 end_offset); /* protos.cpp ***************************************************************/ @@ -88,6 +106,7 @@ class Classify : public CCStruct { INT_TEMPLATES CreateIntTemplates(CLASSES FloatProtos, const UNICHARSET& target_unicharset); /* adaptmatch.cpp ***********************************************************/ + // Learn the given word using its chopped_word, seam_array, denorm, // box_word, best_state, and correct_text to learn both correctly and // incorrectly segmented blobs. If filename is not NULL, then LearnBlob @@ -111,36 +130,62 @@ class Classify : public CCStruct { const char* correct_text, WERD_RES *word); void InitAdaptiveClassifier(bool load_pre_trained_templates); void InitAdaptedClass(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, ADAPT_CLASS Class, ADAPT_TEMPLATES Templates); void AdaptToPunc(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, FLOAT32 Threshold); void AmbigClassifier(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, ADAPT_CLASS *Classes, UNICHAR_ID *Ambiguities, ADAPT_RESULTS *Results); void MasterMatcher(INT_TEMPLATES templates, inT16 num_features, - INT_FEATURE_ARRAY features, - CLASS_NORMALIZATION_ARRAY norm_factors, + const INT_FEATURE_STRUCT* features, + const uinT8* norm_factors, ADAPT_CLASS* classes, int debug, int num_classes, const TBOX& blob_box, CLASS_PRUNER_RESULTS results, ADAPT_RESULTS* final_results); - void ConvertMatchesToChoices(ADAPT_RESULTS *Results, + // Converts configs to fonts, and if the result is not adapted, and a + // shape_table_ is present, the shape is expanded to include all + // unichar_ids represented, before applying a set of corrections to the + // distance rating in int_result, (see ComputeCorrectedRating.) + // The results are added to the final_results output. + void ExpandShapesAndApplyCorrections(ADAPT_CLASS* classes, + bool debug, + int class_id, + int bottom, int top, + float cp_rating, + int blob_length, + const uinT8* cn_factors, + INT_RESULT_STRUCT& int_result, + ADAPT_RESULTS* final_results); + // Applies a set of corrections to the distance im_rating, + // including the cn_correction, miss penalty and additional penalty + // for non-alnums being vertical misfits. Returns the corrected distance. + double ComputeCorrectedRating(bool debug, int unichar_id, double cp_rating, + double im_rating, int feature_misses, + int bottom, int top, + int blob_length, const uinT8* cn_factors); + void ConvertMatchesToChoices(const DENORM& denorm, const TBOX& box, + ADAPT_RESULTS *Results, BLOB_CHOICE_LIST *Choices); void AddNewResult(ADAPT_RESULTS *results, - CLASS_ID class_dd, + CLASS_ID class_id, + int shape_id, FLOAT32 rating, + bool adapted, int config, - int config2, int fontinfo_id, int fontinfo_id2); int GetAdaptiveFeatures(TBLOB *Blob, @@ -149,9 +194,11 @@ class Classify : public CCStruct { #ifndef GRAPHICS_DISABLED void DebugAdaptiveClassifier(TBLOB *Blob, + const DENORM& denorm, ADAPT_RESULTS *Results); #endif void GetAdaptThresholds (TWERD * Word, + const DENORM& denorm, const WERD_CHOICE& BestChoice, const WERD_CHOICE& BestRawChoice, FLOAT32 Thresholds[]); @@ -171,30 +218,64 @@ class Classify : public CCStruct { void MakePermanent(ADAPT_TEMPLATES Templates, CLASS_ID ClassId, int ConfigId, + const DENORM& denorm, TBLOB *Blob); void PrintAdaptiveMatchResults(FILE *File, ADAPT_RESULTS *Results); void RemoveExtraPuncs(ADAPT_RESULTS *Results); void RemoveBadMatches(ADAPT_RESULTS *Results); void SetAdaptiveThreshold(FLOAT32 Threshold); void ShowBestMatchFor(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, + int shape_id, BOOL8 AdaptiveOn, - BOOL8 PreTrainedOn); + BOOL8 PreTrainedOn, + ADAPT_RESULTS *Results); + // Returns a string for the classifier class_id: either the corresponding + // unicharset debug_str or the shape_table_ debug str. + STRING ClassIDToDebugStr(const INT_TEMPLATES_STRUCT* templates, + int class_id, int config_id) const; + // Converts a classifier class_id index with a config ID to: + // shape_table_ present: a shape_table_ index OR + // No shape_table_: a font ID. + // Without shape training, each class_id, config pair represents a single + // unichar id/font combination, so this function looks up the corresponding + // font id. + // With shape training, each class_id, config pair represents a single + // shape table index, so the fontset_table stores the shape table index, + // and the shape_table_ must be consulted to obtain the actual unichar_id/ + // font combinations that the shape represents. + int ClassAndConfigIDToFontOrShapeID(int class_id, + int int_result_config) const; + // Converts a shape_table_ index to a classifier class_id index (not a + // unichar-id!). Uses a search, so not fast. + int ShapeIDToClassID(int shape_id) const; UNICHAR_ID *BaselineClassifier(TBLOB *Blob, + const DENORM& denorm, ADAPT_TEMPLATES Templates, ADAPT_RESULTS *Results); int CharNormClassifier(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, ADAPT_RESULTS *Results); + + // As CharNormClassifier, but operates on a TrainingSample and outputs to + // a GenericVector of ShapeRating without conversion to classes. + int CharNormTrainingSample(bool pruner_only, const TrainingSample& sample, + GenericVector* results); UNICHAR_ID *GetAmbiguities(TBLOB *Blob, + const DENORM& denorm, CLASS_ID CorrectClass); void DoAdaptiveMatch(TBLOB *Blob, + const DENORM& denorm, ADAPT_RESULTS *Results); void AdaptToChar(TBLOB *Blob, + const DENORM& denorm, CLASS_ID ClassId, int FontinfoId, FLOAT32 Threshold); - void DisplayAdaptedChar(TBLOB* blob, INT_CLASS_STRUCT* int_class); + void DisplayAdaptedChar(TBLOB* blob, const DENORM& denorm, + INT_CLASS_STRUCT* int_class); int AdaptableWord(TWERD *Word, const WERD_CHOICE &BestChoiceWord, const WERD_CHOICE &RawChoiceWord); @@ -203,55 +284,53 @@ class Classify : public CCStruct { void SettupPass1(); void SettupPass2(); void AdaptiveClassifier(TBLOB *Blob, + const DENORM& denorm, BLOB_CHOICE_LIST *Choices, CLASS_PRUNER_RESULTS cp_results); void ClassifyAsNoise(ADAPT_RESULTS *Results); - void ResetAdaptiveClassifier(); + void ResetAdaptiveClassifierInternal(); int GetBaselineFeatures(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, + uinT8* CharNormArray, inT32 *BlobLength); - FLOAT32 GetBestRatingFor(TBLOB *Blob, - CLASS_ID ClassId); int GetCharNormFeatures(TBLOB *Blob, + const DENORM& denorm, INT_TEMPLATES Templates, INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, + uinT8* PrunerNormArray, + uinT8* CharNormArray, inT32 *BlobLength, inT32 *FeatureOutlineIndex); - int GetIntBaselineFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength); - int GetIntCharNormFeatures(TBLOB *Blob, - INT_TEMPLATES Templates, - INT_FEATURE_ARRAY IntFeatures, - CLASS_NORMALIZATION_ARRAY CharNormArray, - inT32 *BlobLength, - inT32 *FeatureOutlineArray); + // Computes the char_norm_array for the unicharset and, if not NULL, the + // pruner_array as appropriate according to the existence of the shape_table. + // The norm_feature is deleted as it is almost certainly no longer needed. + void ComputeCharNormArrays(FEATURE_STRUCT* norm_feature, + INT_TEMPLATES_STRUCT* templates, + uinT8* char_norm_array, + uinT8* pruner_array); bool TempConfigReliable(CLASS_ID class_id, const TEMP_CONFIG &config); - void UpdateAmbigsGroup(CLASS_ID class_id, TBLOB *Blob); + void UpdateAmbigsGroup(CLASS_ID class_id, const DENORM& denorm, TBLOB *Blob); void ResetFeaturesHaveBeenExtracted(); bool AdaptiveClassifierIsFull() { return NumAdaptationsFailed > 0; } - bool LooksLikeGarbage(TBLOB *blob); + bool LooksLikeGarbage(const DENORM& denorm, TBLOB *blob); void RefreshDebugWindow(ScrollView **win, const char *msg, int y_offset, const TBOX &wbox); /* float2int.cpp ************************************************************/ - void ComputeIntCharNormArray(FEATURE NormFeature, - INT_TEMPLATES Templates, - CLASS_NORMALIZATION_ARRAY CharNormArray); + void ClearCharNormArray(uinT8* char_norm_array); + void ComputeIntCharNormArray(const FEATURE_STRUCT& norm_feature, + uinT8* char_norm_array); void ComputeIntFeatures(FEATURE_SET Features, INT_FEATURE_ARRAY IntFeatures); /* intproto.cpp *************************************************************/ INT_TEMPLATES ReadIntTemplates(FILE *File); void WriteIntTemplates(FILE *File, INT_TEMPLATES Templates, const UNICHARSET& target_unicharset); CLASS_ID GetClassToDebug(const char *Prompt, bool* adaptive_on, - bool* pretrained_on); + bool* pretrained_on, int* shape_id); void ShowMatchDisplay(); /* font detection ***********************************************************/ UnicityTable& get_fontinfo_table() { @@ -271,6 +350,8 @@ class Classify : public CCStruct { // Member variables. // Parameters. + BOOL_VAR_H(prioritize_division, FALSE, + "Prioritize blob division over chopping"); INT_VAR_H(tessedit_single_match, FALSE, "Top choice only from CP"); BOOL_VAR_H(classify_enable_learning, true, "Enable adaptive classifier"); INT_VAR_H(classify_debug_level, 0, "Classify debug level"); @@ -313,8 +394,6 @@ class Classify : public CCStruct { double_VAR_H(classify_misfit_junk_penalty, 0.0, "Penalty to apply when a non-alnum is vertically out of " "its expected textline position"); - BOOL_VAR_H(classify_enable_int_fx, 1, "Enable integer fx"); - BOOL_VAR_H(classify_enable_new_adapt_rules, 1, "Enable new adaptation rules"); double_VAR_H(rating_scale, 1.5, "Rating scaling factor"); double_VAR_H(certainty_scale, 20.0, "Certainty scaling factor"); double_VAR_H(tessedit_class_miss_scale, 0.00390625, @@ -338,7 +417,7 @@ class Classify : public CCStruct { /* intmatcher.cpp **********************************************************/ INT_VAR_H(classify_class_pruner_threshold, 229, - "Class Pruner Threshold 0-255: "); + "Class Pruner Threshold 0-255"); INT_VAR_H(classify_class_pruner_multiplier, 30, "Class Pruner Multiplier 0-255: "); INT_VAR_H(classify_cp_cutoff_strength, 7, @@ -362,17 +441,27 @@ class Classify : public CCStruct { NORM_PROTOS *NormProtos; /* font detection ***********************************************************/ UnicityTable fontinfo_table_; + // Without shape training, each class_id, config pair represents a single + // unichar id/font combination, so each fontset_table_ entry holds font ids + // for each config in the class. + // With shape training, each class_id, config pair represents a single + // shape_table_ index, so the fontset_table_ stores the shape_table_ index, + // and the shape_table_ must be consulted to obtain the actual unichar_id/ + // font combinations that the shape represents. UnicityTable fontset_table_; INT_VAR_H(il1_adaption_test, 0, "Dont adapt to i/I at beginning of word"); BOOL_VAR_H(classify_bln_numeric_mode, 0, "Assume the input is numbers [0-9]."); + protected: IntegerMatcher im_; FEATURE_DEFS_STRUCT feature_defs_; - // Must be set for the classifier to operate. Ususally set in - // Tesseract::recog_word_recursive, being the main word-level entry point. - DENORM denorm_; + // If a shape_table_ is present, it is used to remap classifier output in + // ExpandShapesAndApplyCorrections. font_ids referenced by configs actually + // mean an index to the shape_table_ and the choices returned are *all* the + // shape_table_ entries at that index. + ShapeTable* shape_table_; private: @@ -401,8 +490,17 @@ class Classify : public CCStruct { INT_FEATURE_ARRAY CharNormFeatures; INT_FX_RESULT_STRUCT FXInfo; - CLASS_CUTOFF_ARRAY CharNormCutoffs; - CLASS_CUTOFF_ARRAY BaselineCutoffs; + // Expected number of features in the class pruner, used to penalize + // unknowns that have too few features (like a c being classified as e) so + // it doesn't recognize everything as '@' or '#'. + // CharNormCutoffs is for the static classifier (with no shapetable). + // BaselineCutoffs gets a copy of CharNormCutoffs as an estimate of the real + // value in the adaptive classifier. Both are indexed by unichar_id. + // shapetable_cutoffs_ provides a similar value for each shape in the + // shape_table_ + uinT16* CharNormCutoffs; + uinT16* BaselineCutoffs; + GenericVector shapetable_cutoffs_; ScrollView* learn_debug_win_; ScrollView* learn_fragmented_word_debug_win_; ScrollView* learn_fragments_debug_win_; diff --git a/classify/cluster.cpp b/classify/cluster.cpp index 7c8d08a773..964ab2b7ba 100644 --- a/classify/cluster.cpp +++ b/classify/cluster.cpp @@ -20,6 +20,7 @@ #include "cluster.h" #include "emalloc.h" #include "helpers.h" +#include "matrix.h" #include "tprintf.h" #include "danerror.h" #include "freelist.h" @@ -137,7 +138,7 @@ const double FTable[FTABLE_Y][FTABLE_X] = { dimension of any feature. Since most features are calculated from numbers with a precision no better than 1 in 128, the variance should never be less than the square of this number for parameters whose range is 1. */ -#define MINVARIANCE 0.0001 +#define MINVARIANCE 0.0004 /* define the absolute minimum number of samples which must be present in order to accurately test hypotheses about underlying probability @@ -145,7 +146,6 @@ const double FTable[FTABLE_Y][FTABLE_X] = { before a statistical analysis is attempted; this number should be equal to MINSAMPLES but can be set to a lower number for early testing when very few samples are available. */ -#define MINBUCKETS 5 #define MINSAMPLESPERBUCKET 5 #define MINSAMPLES (MINBUCKETS * MINSAMPLESPERBUCKET) #define MINSAMPLESNEEDED 1 @@ -222,7 +222,6 @@ static const FLOAT64 kNormalMean = BUCKETTABLESIZE / 2; /* define lookup tables used to compute the number of histogram buckets that should be used for a given number of samples. */ #define LOOKUPTABLESIZE 8 -#define MAXBUCKETS 39 #define MAXDEGREESOFFREEDOM MAXBUCKETS static const uinT32 kCountTable[LOOKUPTABLESIZE] = { @@ -349,8 +348,7 @@ BOOL8 DistributionOK(BUCKETS *Buckets); void FreeStatistics(STATISTICS *Statistics); -void FreeBuckets(CLUSTERER* clusterer, - BUCKETS *Buckets); +void FreeBuckets(BUCKETS *Buckets); void FreeCluster(CLUSTER *Cluster); @@ -425,10 +423,11 @@ MakeClusterer (inT16 SampleSize, const PARAM_DESC ParamDesc[]) { // allocate a kd tree to hold the samples Clusterer->KDTree = MakeKDTree (SampleSize, ParamDesc); - // keep a list of histogram buckets to minimize recomputing them - Clusterer->bucket_cache[0] = NIL_LIST; - Clusterer->bucket_cache[1] = NIL_LIST; - Clusterer->bucket_cache[2] = NIL_LIST; + // Initialize cache of histogram buckets to minimize recomputing them. + for (int d = 0; d < DISTRIBUTION_COUNT; ++d) { + for (int c = 0; c < MAXBUCKETS + 1 - MINBUCKETS; ++c) + Clusterer->bucket_cache[d][c] = NULL; + } return Clusterer; } // MakeClusterer @@ -448,8 +447,8 @@ Exceptions: ALREADYCLUSTERED MakeSample can't be called after ClusterSamples has been called History: 5/29/89, DSJ, Created. *****************************************************************************/ -SAMPLE * -MakeSample (CLUSTERER * Clusterer, FLOAT32 Feature[], inT32 CharID) { +SAMPLE* MakeSample(CLUSTERER * Clusterer, const FLOAT32* Feature, + inT32 CharID) { SAMPLE *Sample; int i; @@ -537,9 +536,13 @@ void FreeClusterer(CLUSTERER *Clusterer) { FreeKDTree (Clusterer->KDTree); if (Clusterer->Root != NULL) FreeCluster (Clusterer->Root); - iterate (Clusterer->ProtoList) { - ((PROTOTYPE *) (first_node (Clusterer->ProtoList)))->Cluster = NULL; + // Free up all used buckets structures. + for (int d = 0; d < DISTRIBUTION_COUNT; ++d) { + for (int c = 0; c < MAXBUCKETS + 1 - MINBUCKETS; ++c) + if (Clusterer->bucket_cache[d][c] != NULL) + FreeBuckets(Clusterer->bucket_cache[d][c]); } + memfree(Clusterer); } } // FreeClusterer @@ -662,6 +665,8 @@ FLOAT32 StandardDeviation(PROTOTYPE *Proto, uinT16 Dimension) { case uniform: case D_random: return (Proto->Variance.Elliptical[Dimension]); + case DISTRIBUTION_COUNT: + ASSERT_HOST(!"Distribution count not allowed!"); } } return 0.0f; @@ -1033,7 +1038,6 @@ PROTOTYPE *MakePrototype(CLUSTERER *Clusterer, Config->Confidence); break; } - FreeBuckets(Clusterer, Buckets); FreeStatistics(Statistics); return Proto; } // MakePrototype @@ -1339,10 +1343,6 @@ PROTOTYPE *MakeMixedProto(CLUSTERER *Clusterer, FreePrototype(Proto); Proto = NULL; } - if (UniformBuckets != NULL) - FreeBuckets(Clusterer, UniformBuckets); - if (RandomBuckets != NULL) - FreeBuckets(Clusterer, RandomBuckets); return (Proto); } // MakeMixedProto @@ -1623,6 +1623,7 @@ PROTOTYPE *NewSimpleProto(inT16 N, CLUSTER *Cluster) { Proto->Distrib = NULL; Proto->Significant = TRUE; + Proto->Merged = FALSE; Proto->Style = spherical; Proto->NumSamples = Cluster->SampleCount; Proto->Cluster = Cluster; @@ -1705,17 +1706,18 @@ BUCKETS *GetBuckets(CLUSTERER* clusterer, DISTRIBUTION Distribution, uinT32 SampleCount, FLOAT64 Confidence) { - // search for an old bucket structure with the same number of buckets - LIST *bucket_cache = clusterer->bucket_cache; + // Get an old bucket structure with the same number of buckets. uinT16 NumberOfBuckets = OptimumNumberOfBuckets(SampleCount); - BUCKETS *Buckets = (BUCKETS *) first_node(search( - bucket_cache[(int)Distribution], &NumberOfBuckets, - NumBucketsMatch)); - - // if a matching bucket structure is found, delete it from the list - if (Buckets != NULL) { - bucket_cache[(int) Distribution] = - delete_d(bucket_cache[(int) Distribution], Buckets, ListEntryMatch); + BUCKETS *Buckets = + clusterer->bucket_cache[Distribution][NumberOfBuckets - MINBUCKETS]; + + // If a matching bucket structure is not found, make one and save it. + if (Buckets == NULL) { + Buckets = MakeBuckets(Distribution, SampleCount, Confidence); + clusterer->bucket_cache[Distribution][NumberOfBuckets - MINBUCKETS] = + Buckets; + } else { + // Just adjust the existing buckets. if (SampleCount != Buckets->SampleCount) AdjustBuckets(Buckets, SampleCount); if (Confidence != Buckets->Confidence) { @@ -1725,9 +1727,6 @@ BUCKETS *GetBuckets(CLUSTERER* clusterer, Confidence); } InitBuckets(Buckets); - } else { - // otherwise create a new structure - Buckets = MakeBuckets(Distribution, SampleCount, Confidence); } return Buckets; } // GetBuckets @@ -1770,14 +1769,14 @@ BUCKETS *MakeBuckets(DISTRIBUTION Distribution, BOOL8 Symmetrical; // allocate memory needed for data structure - Buckets = (BUCKETS *) Emalloc(sizeof(BUCKETS)); - Buckets->NumberOfBuckets = OptimumNumberOfBuckets (SampleCount); + Buckets = reinterpret_cast(Emalloc(sizeof(BUCKETS))); + Buckets->NumberOfBuckets = OptimumNumberOfBuckets(SampleCount); Buckets->SampleCount = SampleCount; Buckets->Confidence = Confidence; - Buckets->Count = - (uinT32 *) Emalloc(Buckets->NumberOfBuckets * sizeof (uinT32)); - Buckets->ExpectedCount = - (FLOAT32 *) Emalloc(Buckets->NumberOfBuckets * sizeof (FLOAT32)); + Buckets->Count = reinterpret_cast( + Emalloc(Buckets->NumberOfBuckets * sizeof(uinT32))); + Buckets->ExpectedCount = reinterpret_cast( + Emalloc(Buckets->NumberOfBuckets * sizeof(FLOAT32))); // initialize simple fields Buckets->Distribution = Distribution; @@ -2246,23 +2245,16 @@ void FreeStatistics(STATISTICS *Statistics) { //--------------------------------------------------------------------------- -void FreeBuckets(CLUSTERER* clusterer, BUCKETS *buckets) { +void FreeBuckets(BUCKETS *buckets) { /* ** Parameters: - ** clusterer->bucket_cache - ** distribution-indexed cache of old bucket structures. ** buckets pointer to data structure to be freed ** Operation: - ** This routine places the specified histogram data structure - ** at the front of a list of histograms so that it can be reused - ** later if necessary. A separate list is maintained for each - ** different type of distribution. + ** This routine properly frees the memory used by a BUCKETS. */ - LIST *bucket_cache = clusterer->bucket_cache; - if (buckets != NULL) { - int dist = (int)buckets->Distribution; - bucket_cache[dist] = (LIST) push(bucket_cache[dist], buckets); - } + Efree(buckets->Count); + Efree(buckets->ExpectedCount); + Efree(buckets); } // FreeBuckets @@ -2640,8 +2632,10 @@ CLUSTER * Cluster, FLOAT32 MaxIllegal) } NumCharInCluster--; PercentIllegal = (FLOAT32) NumIllegalInCluster / NumCharInCluster; - if (PercentIllegal > MaxIllegal) + if (PercentIllegal > MaxIllegal) { + destroy(SearchState); return (TRUE); + } } } return (FALSE); @@ -2652,17 +2646,10 @@ CLUSTER * Cluster, FLOAT32 MaxIllegal) // The return value is the sum of norms of the off-diagonal terms of the // product of a and inv. (A measure of the error.) double InvertMatrix(const float* input, int size, float* inv) { - double** U; // The upper triangular array. - double* Umem; - double** U_inv; // The inverse of U. - double* U_invmem; - double** L; // The lower triangular array. - double* Lmem; - // Allocate memory for the 2D arrays. - ALLOC_2D_ARRAY(size, size, Umem, U, double); - ALLOC_2D_ARRAY(size, size, U_invmem, U_inv, double); - ALLOC_2D_ARRAY(size, size, Lmem, L, double); + GENERIC_2D_ARRAY U(size, size, 0.0); + GENERIC_2D_ARRAY U_inv(size, size, 0.0); + GENERIC_2D_ARRAY L(size, size, 0.0); // Initialize the working matrices. U starts as input, L as I and U_inv as O. int row; diff --git a/classify/cluster.h b/classify/cluster.h index 94f9d716cb..53ddf87dad 100644 --- a/classify/cluster.h +++ b/classify/cluster.h @@ -21,6 +21,11 @@ #include "kdtree.h" #include "oldlist.h" +struct BUCKETS; + +#define MINBUCKETS 5 +#define MAXBUCKETS 39 + /*---------------------------------------------------------------------- Types ----------------------------------------------------------------------*/ @@ -51,7 +56,7 @@ typedef struct { // parameters to control clustering } CLUSTERCONFIG; typedef enum { - normal, uniform, D_random + normal, uniform, D_random, DISTRIBUTION_COUNT } DISTRIBUTION; typedef union { @@ -86,7 +91,8 @@ typedef struct { CLUSTER *Root; // ptr to root cluster of cluster tree LIST ProtoList; // list of prototypes inT32 NumChar; // # of characters represented by samples - LIST bucket_cache[3]; // cache of reusable histograms by distribution type + // cache of reusable histograms by distribution type and number of buckets. + BUCKETS* bucket_cache[DISTRIBUTION_COUNT][MAXBUCKETS + 1 - MINBUCKETS]; } CLUSTERER; typedef struct { @@ -103,7 +109,7 @@ typedef struct { --------------------------------------------------------------------------*/ CLUSTERER *MakeClusterer (inT16 SampleSize, const PARAM_DESC ParamDesc[]); -SAMPLE *MakeSample (CLUSTERER * Clusterer, FLOAT32 Feature[], inT32 CharID); +SAMPLE *MakeSample(CLUSTERER * Clusterer, const FLOAT32* Feature, inT32 CharID); LIST ClusterSamples(CLUSTERER *Clusterer, CLUSTERCONFIG *Config); diff --git a/classify/clusttool.cpp b/classify/clusttool.cpp index 8bb4313b4d..f0ebf056d5 100644 --- a/classify/clusttool.cpp +++ b/classify/clusttool.cpp @@ -213,6 +213,8 @@ PROTOTYPE *ReadPrototype(FILE *File, uinT16 N) { Proto->Magnitude.Elliptical[i] = 1.0 / (2.0 * Proto->Variance.Elliptical[i]); break; + case DISTRIBUTION_COUNT: + ASSERT_HOST(!"Distribution count not allowed!"); } Proto->TotalMagnitude *= Proto->Magnitude.Elliptical[i]; } @@ -374,6 +376,8 @@ void WritePrototype(FILE *File, uinT16 N, PROTOTYPE *Proto) { case D_random: fprintf (File, " %9s", "random"); break; + case DISTRIBUTION_COUNT: + ASSERT_HOST(!"Distribution count not allowed!"); } fprintf (File, "\n\t"); WriteNFloats (File, N, Proto->Variance.Elliptical); @@ -392,13 +396,10 @@ Return: None Exceptions: None History: 6/6/89, DSJ, Created. ****************************************************************************/ -void -WriteNFloats (FILE * File, uinT16 N, FLOAT32 Array[]) { - int i; - - for (i = 0; i < N; i++) - fprintf (File, " %9.6f", Array[i]); - fprintf (File, "\n"); +void WriteNFloats(FILE * File, uinT16 N, FLOAT32 Array[]) { + for (int i = 0; i < N; i++) + fprintf(File, " %9.6f", Array[i]); + fprintf(File, "\n"); } // WriteNFloats @@ -479,29 +480,3 @@ void WriteProtoList( } } /* WriteProtoList */ -/** UniformRandomNumber ******************************************************** -Parameters: MMin lower range of uniform distribution - MMax upper range of uniform distribution -Globals: None -Operation: This routine computes a random number which comes from a - uniform distribution over the range from MMin to MMax. -Return: Uniform random number -Exceptions: None -History: 6/6/89, DSJ, Created. -*******************************************************************************/ -FLOAT32 UniformRandomNumber(FLOAT32 MMin, FLOAT32 MMax) { - double fake_drand48(); - FLOAT32 RandomNumber; - - RandomNumber = fake_drand48 (); - return (MMin + (RandomNumber * (MMax - MMin))); -} // UniformRandomNumber - - -/** drand48 ************************************************************* -Cheap replacement for drand48 which is not available on the PC. -**********************************************************************/ - -double fake_drand48() { - return rand () / (RAND_MAX + 1.0); -} diff --git a/classify/clusttool.h b/classify/clusttool.h index c3bd44ea96..a4f3b8351d 100644 --- a/classify/clusttool.h +++ b/classify/clusttool.h @@ -52,8 +52,6 @@ void WriteProtoList( BOOL8 WriteSigProtos, BOOL8 WriteInsigProtos); -FLOAT32 UniformRandomNumber(FLOAT32 MMin, FLOAT32 MMax); - //--------------Global Data Definitions and Declarations--------------------- // define errors that can be trapped #define ILLEGALSAMPLESIZE 5000 diff --git a/classify/cutoffs.cpp b/classify/cutoffs.cpp index baa506820d..536a710caa 100644 --- a/classify/cutoffs.cpp +++ b/classify/cutoffs.cpp @@ -39,7 +39,7 @@ ----------------------------------------------------------------------------**/ /*---------------------------------------------------------------------------*/ namespace tesseract { -void Classify::ReadNewCutoffs(FILE *CutoffFile, inT64 end_offset, +void Classify::ReadNewCutoffs(FILE *CutoffFile, bool swap, inT64 end_offset, CLASS_CUTOFF_ARRAY Cutoffs) { /* ** Parameters: @@ -59,6 +59,11 @@ void Classify::ReadNewCutoffs(FILE *CutoffFile, inT64 end_offset, int Cutoff; int i; + if (shape_table_ != NULL) { + if (!shapetable_cutoffs_.DeSerialize(swap, CutoffFile)) { + tprintf("Error during read of shapetable pffmtable!\n"); + } + } for (i = 0; i < MAX_NUM_CLASSES; i++) Cutoffs[i] = MAX_CUTOFF; diff --git a/classify/errorcounter.cpp b/classify/errorcounter.cpp new file mode 100644 index 0000000000..52d5ac560c --- /dev/null +++ b/classify/errorcounter.cpp @@ -0,0 +1,385 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "errorcounter.h" + +#include "fontinfo.h" +#include "ndminx.h" +#include "sampleiterator.h" +#include "shapeclassifier.h" +#include "shapetable.h" +#include "trainingsample.h" +#include "trainingsampleset.h" +#include "unicity_table.h" + +namespace tesseract { + +// Tests a classifier, computing its error rate. +// See errorcounter.h for description of arguments. +// Iterates over the samples, calling the classifier in normal/silent mode. +// If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate +// report_level is set (4 or greater), it will then call the classifier again +// with a debug flag and a keep_this argument to find out what is going on. +double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier, + int report_level, CountTypes boosting_mode, + const UnicityTable& fontinfo_table, + const GenericVector& page_images, SampleIterator* it, + double* unichar_error, double* scaled_error, STRING* fonts_report) { + int charsetsize = it->shape_table()->unicharset().size(); + int shapesize = it->CompactCharsetSize(); + int fontsize = it->sample_set()->NumFonts(); + ErrorCounter counter(charsetsize, shapesize, fontsize); + GenericVector results; + + clock_t start = clock(); + int total_samples = 0; + double unscaled_error = 0.0; + // Set a number of samples on which to run the classify debug mode. + int error_samples = report_level > 3 ? report_level * report_level : 0; + // Iterate over all the samples, accumulating errors. + for (it->Begin(); !it->AtEnd(); it->Next()) { + TrainingSample* mutable_sample = it->MutableSample(); + int page_index = mutable_sample->page_num(); + Pix* page_pix = 0 <= page_index && page_index < page_images.size() + ? page_images[page_index] : NULL; + // No debug, no keep this. + classifier->ClassifySample(*mutable_sample, page_pix, 0, INVALID_UNICHAR_ID, + &results); + if (mutable_sample->class_id() == 0) { + // This is junk so use the special counter. + counter.AccumulateJunk(*it->shape_table(), results, mutable_sample); + } else if (counter.AccumulateErrors(report_level > 3, boosting_mode, + fontinfo_table, *it->shape_table(), + results, mutable_sample) && + error_samples > 0) { + // Running debug, keep the correct answer, and debug the classifier. + tprintf("Error on sample %d: Classifier debug output:\n", + it->GlobalSampleIndex()); + int keep_this = it->GetSparseClassID(); + classifier->ClassifySample(*mutable_sample, page_pix, 1, keep_this, + &results); + --error_samples; + } + ++total_samples; + } + double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC; + // Create the appropriate error report. + unscaled_error = counter.ReportErrors(report_level, boosting_mode, + fontinfo_table, + *it, unichar_error, fonts_report); + if (scaled_error != NULL) *scaled_error = counter.scaled_error_; + if (report_level > 1) { + // It is useful to know the time in microseconds/char. + tprintf("Errors computed in %.2fs at %.1f μs/char\n", + total_time, 1000000.0 * total_time / total_samples); + } + return unscaled_error; +} + +// Constructor is private. Only anticipated use of ErrorCounter is via +// the static ComputeErrorRate. +ErrorCounter::ErrorCounter(int charsetsize, int shapesize, int fontsize) + : scaled_error_(0.0), unichar_counts_(charsetsize, shapesize, 0) { + Counts empty_counts; + font_counts_.init_to_size(fontsize, empty_counts); +} +ErrorCounter::~ErrorCounter() { +} + +// Accumulates the errors from the classifier results on a single sample. +// Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred. +// boosting_mode selects the type of error to be used for boosting and the +// is_error_ member of sample is set according to whether the required type +// of error occurred. The font_table provides access to font properties +// for error counting and shape_table is used to understand the relationship +// between unichar_ids and shape_ids in the results +bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode, + const UnicityTable& font_table, + const ShapeTable& shape_table, + const GenericVector& results, + TrainingSample* sample) { + int num_results = results.size(); + int res_index = 0; + bool debug_it = false; + int font_id = sample->font_id(); + int unichar_id = sample->class_id(); + sample->set_is_error(false); + if (num_results == 0) { + // Reject. We count rejects as a separate category, but still mark the + // sample as an error in case any training module wants to use that to + // improve the classifier. + sample->set_is_error(true); + ++font_counts_[font_id].n[CT_REJECT]; + } else if (shape_table.GetShape(results[0].shape_id). + ContainsUnicharAndFont(unichar_id, font_id)) { + ++font_counts_[font_id].n[CT_SHAPE_TOP_CORRECT]; + // Unichar and font OK, but count if multiple unichars. + if (shape_table.GetShape(results[0].shape_id).size() > 1) + ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; + } else { + // This is a top shape error. + ++font_counts_[font_id].n[CT_SHAPE_TOP_ERR]; + // Check to see if any font in the top choice has attributes that match. + bool attributes_match = false; + uinT32 font_props = font_table.get(font_id).properties; + const Shape& shape = shape_table.GetShape(results[0].shape_id); + for (int c = 0; c < shape.size() && !attributes_match; ++c) { + for (int f = 0; f < shape[c].font_ids.size(); ++f) { + if (font_table.get(shape[c].font_ids[f]).properties == font_props) { + attributes_match = true; + break; + } + } + } + // TODO(rays) It is easy to add counters for individual font attributes + // here if we want them. + if (!attributes_match) + ++font_counts_[font_id].n[CT_FONT_ATTR_ERR]; + if (boosting_mode == CT_SHAPE_TOP_ERR) sample->set_is_error(true); + // Find rank of correct unichar answer. (Ignoring the font.) + while (res_index < num_results && + !shape_table.GetShape(results[res_index].shape_id). + ContainsUnichar(unichar_id)) { + ++res_index; + } + if (res_index == 0) { + // Unichar OK, but count if multiple unichars. + if (shape_table.GetShape(results[res_index].shape_id).size() > 1) { + ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; + } + } else { + // Count maps from unichar id to shape id. + if (num_results > 0) + ++unichar_counts_(unichar_id, results[0].shape_id); + // This is a unichar error. + ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR]; + if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true); + if (res_index >= MIN(2, num_results)) { + // It is also a 2nd choice unichar error. + ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR]; + if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true); + } + if (res_index >= num_results) { + // It is also a top-n choice unichar error. + ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR]; + if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true); + debug_it = debug; + } + } + } + // Compute mean number of return values and mean rank of correct answer. + font_counts_[font_id].n[CT_NUM_RESULTS] += num_results; + font_counts_[font_id].n[CT_RANK] += res_index; + // If it was an error for boosting then sum the weight. + if (sample->is_error()) { + scaled_error_ += sample->weight(); + } + if (debug_it) { + tprintf("%d results for char %s font %d :", + num_results, shape_table.unicharset().id_to_unichar(unichar_id), + font_id); + for (int i = 0; i < num_results; ++i) { + tprintf(" %.3f/%.3f:%s", + results[i].rating, results[i].font, + shape_table.DebugStr(results[i].shape_id).string()); + } + tprintf("\n"); + return true; + } + return false; +} + +// Accumulates counts for junk. Counts only whether the junk was correctly +// rejected or not. +void ErrorCounter::AccumulateJunk(const ShapeTable& shape_table, + const GenericVector& results, + TrainingSample* sample) { + // For junk we accept no answer, or an explicit shape answer matching the + // class id of the sample. + int num_results = results.size(); + int font_id = sample->font_id(); + int unichar_id = sample->class_id(); + if (num_results > 0 && + !shape_table.GetShape(results[0].shape_id).ContainsUnichar(unichar_id)) { + // This is a junk error. + ++font_counts_[font_id].n[CT_ACCEPTED_JUNK]; + sample->set_is_error(true); + // It counts as an error for boosting too so sum the weight. + scaled_error_ += sample->weight(); + } else { + // Correctly rejected. + ++font_counts_[font_id].n[CT_REJECTED_JUNK]; + sample->set_is_error(false); + } +} + +// Creates a report of the error rate. The report_level controls the detail +// that is reported to stderr via tprintf: +// 0 -> no output. +// >=1 -> bottom-line error rate. +// >=3 -> font-level error rate. +// boosting_mode determines the return value. It selects which (un-weighted) +// error rate to return. +// The fontinfo_table from MasterTrainer provides the names of fonts. +// The it determines the current subset of the training samples. +// If not NULL, the top-choice unichar error rate is saved in unichar_error. +// If not NULL, the report string is saved in fonts_report. +// (Ignoring report_level). +double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode, + const UnicityTable& fontinfo_table, + const SampleIterator& it, + double* unichar_error, + STRING* fonts_report) { + // Compute totals over all the fonts and report individual font results + // when required. + Counts totals; + int fontsize = font_counts_.size(); + for (int f = 0; f < fontsize; ++f) { + // Accumulate counts over fonts. + totals += font_counts_[f]; + STRING font_report; + if (ReportString(font_counts_[f], &font_report)) { + if (fonts_report != NULL) { + *fonts_report += fontinfo_table.get(f).name; + *fonts_report += ": "; + *fonts_report += font_report; + *fonts_report += "\n"; + } + if (report_level > 2) { + // Report individual font error rates. + tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string()); + } + } + } + if (report_level > 0) { + // Report the totals. + STRING total_report; + if (ReportString(totals, &total_report)) { + tprintf("TOTAL Scaled Err=%.4g%%, %s\n", + scaled_error_ * 100.0, total_report.string()); + } + // Report the worst substitution error only for now. + if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) { + const UNICHARSET& unicharset = it.shape_table()->unicharset(); + int charsetsize = unicharset.size(); + int shapesize = it.CompactCharsetSize(); + int worst_uni_id = 0; + int worst_shape_id = 0; + int worst_err = 0; + for (int u = 0; u < charsetsize; ++u) { + for (int s = 0; s < shapesize; ++s) { + if (unichar_counts_(u, s) > worst_err) { + worst_err = unichar_counts_(u, s); + worst_uni_id = u; + worst_shape_id = s; + } + } + } + if (worst_err > 0) { + tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n", + worst_uni_id, unicharset.id_to_unichar(worst_uni_id), + it.shape_table()->DebugStr(worst_shape_id).string(), + worst_err, totals.n[CT_UNICHAR_TOP1_ERR], + 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]); + } + } + } + double rates[CT_SIZE]; + if (!ComputeRates(totals, rates)) + return 0.0; + // Set output values if asked for. + if (unichar_error != NULL) + *unichar_error = rates[CT_UNICHAR_TOP1_ERR]; + return rates[boosting_mode]; +} + +// Sets the report string to a combined human and machine-readable report +// string of the error rates. +// Returns false if there is no data, leaving report unchanged. +bool ErrorCounter::ReportString(const Counts& counts, STRING* report) { + // Compute the error rates. + double rates[CT_SIZE]; + if (!ComputeRates(counts, rates)) + return false; + // Using %.4g%%, the length of the output string should exactly match the + // length of the format string, but in case of overflow, allow for +eddd + // on each number. + const int kMaxExtraLength = 5; // Length of +eddd. + // Keep this format string and the snprintf in sync with the CountTypes enum. + const char* format_str = "ShapeErr=%.4g%%, FontAttr=%.4g%%, " + "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], " + "Multi=%.4g%%, Rej=%.4g%%, " + "Answers=%.3g, Rank=%.3g, " + "OKjunk=%.4g%%, Badjunk=%.4g%%"; + int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1; + char* formatted_str = new char[max_str_len]; + snprintf(formatted_str, max_str_len, format_str, + rates[CT_SHAPE_TOP_ERR] * 100.0, + rates[CT_FONT_ATTR_ERR] * 100.0, + rates[CT_UNICHAR_TOP1_ERR] * 100.0, + rates[CT_UNICHAR_TOP2_ERR] * 100.0, + rates[CT_UNICHAR_TOPN_ERR] * 100.0, + rates[CT_OK_MULTI_UNICHAR] * 100.0, + rates[CT_REJECT] * 100.0, + rates[CT_NUM_RESULTS], + rates[CT_RANK], + 100.0 * rates[CT_REJECTED_JUNK], + 100.0 * rates[CT_ACCEPTED_JUNK]); + *report = formatted_str; + delete [] formatted_str; + // Now append each field of counts with a tab in front so the result can + // be loaded into a spreadsheet. + for (int ct = 0; ct < CT_SIZE; ++ct) + report->add_str_int("\t", counts.n[ct]); + return true; +} + +// Computes the error rates and returns in rates which is an array of size +// CT_SIZE. Returns false if there is no data, leaving rates unchanged. +bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) { + int ok_samples = counts.n[CT_SHAPE_TOP_CORRECT] + counts.n[CT_SHAPE_TOP_ERR] + + counts.n[CT_REJECT]; + int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK]; + if (ok_samples == 0 && junk_samples == 0) { + // There is no data. + return false; + } + // Compute rates for normal chars. + double denominator = static_cast(MAX(ok_samples, 1)); + for (int ct = 0; ct <= CT_RANK; ++ct) + rates[ct] = counts.n[ct] / denominator; + // Compute rates for junk. + denominator = static_cast(MAX(junk_samples, 1)); + for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct) + rates[ct] = counts.n[ct] / denominator; + return true; +} + +ErrorCounter::Counts::Counts() { + memset(n, 0, sizeof(n[0]) * CT_SIZE); +} +// Adds other into this for computing totals. +void ErrorCounter::Counts::operator+=(const Counts& other) { + for (int ct = 0; ct < CT_SIZE; ++ct) + n[ct] += other.n[ct]; +} + + +} // namespace tesseract. + + + + + diff --git a/classify/errorcounter.h b/classify/errorcounter.h new file mode 100644 index 0000000000..618d568783 --- /dev/null +++ b/classify/errorcounter.h @@ -0,0 +1,198 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef THIRD_PARTY_TESSERACT_CLASSIFY_ERRORCOUNTER_H_ +#define THIRD_PARTY_TESSERACT_CLASSIFY_ERRORCOUNTER_H_ + +#include "genericvector.h" +#include "matrix.h" + +struct Pix; +template class UnicityTable; + +namespace tesseract { + +struct FontInfo; +class SampleIterator; +class ShapeClassifier; +class ShapeRating; +class ShapeTable; +class TrainingSample; + +// Enumeration of the different types of error count. +// Error counts work as follows: +// +// Ground truth is a valid unichar-id / font-id pair: +// Number of classifier answers? +// 0 >0 +// CT_REJECT BOTH unichar-id and font-id match top shape? +// __________ yes! no +// CT_SHAPE_TOP_CORRECT CT_SHAPE_TOP_ERR +// | Font attributes match? +// | yes! no +// | | CT_FONT_ATTR_ERROR +// | Top unichar-id matches? +// | yes! no +// Top shape-id has multiple unichars? CT_UNICHAR_TOP1_ERR +// yes! no 2nd shape unichar id matches? +// CT_OK_MULTI_UNICHAR ________ yes! no +// ___________________ _____ CT_UNICHAR_TOP2_ERR +// Any unichar-id matches? +// yes! no +// ______ CT_UNICHAR_TOPN_ERR +// _________________ +// Note that multiple counts may be activated for a single sample! +// +// Ground truth is for a fragment/n-gram that is NOT in the unicharset. +// This is called junk and is expected to be rejected: +// Number of classifier answers? +// 0 >0 +// CT_REJECTED_JUNK CT_ACCEPTED_JUNK +// +// Also, CT_NUM_RESULTS stores the mean number of results, and CT_RANK stores +// the mean rank of the correct result, counting from 0, and with an error +// receiving the number of answers as the correct rank. +// +// Keep in sync with the ReportString function. +enum CountTypes { + CT_SHAPE_TOP_CORRECT, // Top shape id is actually correct. + CT_SHAPE_TOP_ERR, // Top shape id is not correct. + CT_FONT_ATTR_ERR, // Font attributes incorrect, ignoring unichar. + CT_UNICHAR_TOP1_ERR, // Top shape does not contain correct unichar id. + CT_UNICHAR_TOP2_ERR, // Top 2 shapes don't contain correct unichar id. + CT_UNICHAR_TOPN_ERR, // No output shape contains correct unichar id. + CT_OK_MULTI_UNICHAR, // Top shape id has correct unichar id, and others. + CT_REJECT, // Classifier hates this. + CT_NUM_RESULTS, // Number of answers produced. + CT_RANK, // Rank of correct answer. + CT_REJECTED_JUNK, // Junk that was correctly rejected. + CT_ACCEPTED_JUNK, // Junk that was incorrectly classified otherwise. + + CT_SIZE // Number of types for array sizing. +}; + +// Class to encapsulate all the functionality and sub-structures required +// to count errors for an isolated character classifier (ShapeClassifier). +class ErrorCounter { + public: + // Computes and returns the unweighted boosting_mode error rate of the given + // classifier. Can be used for testing, or inside an iterative training + // system, including one that uses boosting. + // report_levels: + // 0 = no output. + // 1 = bottom-line error rate. + // 2 = bottom-line error rate + time. + // 3 = font-level error rate + time. + // 4 = list of all errors + short classifier debug output on 16 errors. + // 5 = list of all errors + short classifier debug output on 25 errors. + // * The boosting_mode determines which error type is used for computing the + // scaled_error output, and setting the is_error flag in the samples. + // * The fontinfo_table is used to get string font names for the debug + // output, and also to count font attributes errors. + // * The page_images vector may contain a Pix* (which may be NULL) for each + // page index assigned to the samples. + // * The it provides encapsulated iteration over some sample set. + // * The outputs unichar_error, scaled_error and totals_report are all + // optional. + // * If not NULL, unichar error gets the top1 unichar error rate. + // * Scaled_error gets the error chosen by boosting_mode weighted by the + // weights on the samples. + // * Fonts_report gets a string summarizing the error rates for each font in + // both human-readable form and as a tab-separated list of error counts. + // The human-readable form is all before the first tab. + // * The return value is the un-weighted version of the scaled_error. + static double ComputeErrorRate(ShapeClassifier* classifier, + int report_level, CountTypes boosting_mode, + const UnicityTable& fontinfo_table, + const GenericVector& page_images, + SampleIterator* it, + double* unichar_error, + double* scaled_error, + STRING* fonts_report); + + private: + // Simple struct to hold an array of counts. + struct Counts { + Counts(); + // Adds other into this for computing totals. + void operator+=(const Counts& other); + + int n[CT_SIZE]; + }; + + // Constructor is private. Only anticipated use of ErrorCounter is via + // the static ComputeErrorRate. + ErrorCounter(int charsetsize, int shapesize, int fontsize); + ~ErrorCounter(); + + // Accumulates the errors from the classifier results on a single sample. + // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred. + // boosting_mode selects the type of error to be used for boosting and the + // is_error_ member of sample is set according to whether the required type + // of error occurred. The font_table provides access to font properties + // for error counting and shape_table is used to understand the relationship + // between unichar_ids and shape_ids in the results + bool AccumulateErrors(bool debug, CountTypes boosting_mode, + const UnicityTable& font_table, + const ShapeTable& shape_table, + const GenericVector& results, + TrainingSample* sample); + + // Accumulates counts for junk. Counts only whether the junk was correctly + // rejected or not. + void AccumulateJunk(const ShapeTable& shape_table, + const GenericVector& results, + TrainingSample* sample); + + // Creates a report of the error rate. The report_level controls the detail + // that is reported to stderr via tprintf: + // 0 -> no output. + // >=1 -> bottom-line error rate. + // >=3 -> font-level error rate. + // boosting_mode determines the return value. It selects which (un-weighted) + // error rate to return. + // The fontinfo_table from MasterTrainer provides the names of fonts. + // The it determines the current subset of the training samples. + // If not NULL, the top-choice unichar error rate is saved in unichar_error. + // If not NULL, the report string is saved in fonts_report. + // (Ignoring report_level). + double ReportErrors(int report_level, CountTypes boosting_mode, + const UnicityTable& fontinfo_table, + const SampleIterator& it, + double* unichar_error, + STRING* fonts_report); + + // Sets the report string to a combined human and machine-readable report + // string of the error rates. + // Returns false if there is no data, leaving report unchanged. + static bool ReportString(const Counts& counts, STRING* report); + + // Computes the error rates and returns in rates which is an array of size + // CT_SIZE. Returns false if there is no data, leaving rates unchanged. + static bool ComputeRates(const Counts& counts, double rates[CT_SIZE]); + + + // Total scaled error used by boosting algorithms. + double scaled_error_; + // Vector indexed by font_id from the samples of error accumulators. + GenericVector font_counts_; + // Counts of the results that map each unichar_id (from samples) to an + // incorrect shape_id. + GENERIC_2D_ARRAY unichar_counts_; +}; + +} // namespace tesseract. + +#endif /* THIRD_PARTY_TESSERACT_CLASSIFY_ERRORCOUNTER_H_ */ diff --git a/classify/extract.cpp b/classify/extract.cpp index 8075d0342b..b8e595a7cf 100644 --- a/classify/extract.cpp +++ b/classify/extract.cpp @@ -29,15 +29,6 @@ typedef CHAR_FEATURES (*CF_FUNC) (); -----------------------------------------------------------------------------*/ void ExtractorStub(); -/*----------------------------------------------------------------------------- - Global Data Definitions and Declarations ------------------------------------------------------------------------------*/ -/** tables to keep track of the different low level feature extractors */ -#define NUM_FX 3 -#define DEFAULT_FX 2 - -int CurrentFx = DEFAULT_FX; - /*----------------------------------------------------------------------------- Public Code -----------------------------------------------------------------------------*/ diff --git a/classify/featdefs.cpp b/classify/featdefs.cpp index fc8a12e1d0..889e614a03 100644 --- a/classify/featdefs.cpp +++ b/classify/featdefs.cpp @@ -30,64 +30,85 @@ #define ILLEGAL_NUM_SETS 3001 #define PICO_FEATURE_LENGTH 0.05 -#define MAX_OUTLINE_FEATURES 100 /*----------------------------------------------------------------------------- Global Data Definitions and Declarations -----------------------------------------------------------------------------*/ -/* define all of the parameters for the MicroFeature type*/ -StartParamDesc (MicroFeatureParams) -DefineParam (0, 0, -0.5, 0.5) -DefineParam (0, 0, -0.25, 0.75) -DefineParam (0, 1, 0.0, 1.0) -DefineParam (1, 0, 0.0, 1.0) +const char* kMicroFeatureType = "mf"; +const char* kCNFeatureType = "cn"; +const char* kIntFeatureType = "if"; +const char* kGeoFeatureType = "tb"; + +// Define all of the parameters for the MicroFeature type. +StartParamDesc(MicroFeatureParams) +DefineParam(0, 0, -0.5, 0.5) +DefineParam(0, 0, -0.25, 0.75) +DefineParam(0, 1, 0.0, 1.0) +DefineParam(1, 0, 0.0, 1.0) DefineParam (0, 1, -0.5, 0.5) DefineParam (0, 1, -0.5, 0.5) EndParamDesc -/* now define the feature type itself (see features.h for info about each - parameter).*/ -DefineFeature (MicroFeatureDesc, 5, 1, 4, 50, "Micro", "mf", MicroFeatureParams) +// Now define the feature type itself (see features.h for parameters). +DefineFeature(MicroFeatureDesc, 5, 1, kMicroFeatureType, MicroFeatureParams) -// define all of the parameters for the PicoFeature type -/* define knob that can be used to adjust pico-feature length */ -FLOAT32 PicoFeatureLength = PICO_FEATURE_LENGTH; -StartParamDesc (PicoFeatParams) -DefineParam (0, 0, -0.25, 0.75) -DefineParam (1, 0, 0.0, 1.0) -DefineParam (0, 0, -0.5, 0.5) +// Define all of the parameters for the NormFeat type. +StartParamDesc (CharNormParams) +DefineParam(0, 0, -0.25, 0.75) +DefineParam(0, 1, 0.0, 1.0) +DefineParam(0, 0, 0.0, 1.0) +DefineParam(0, 0, 0.0, 1.0) EndParamDesc -/* now define the feature type itself (see features.h for info about each - parameter).*/ -DefineFeature (PicoFeatDesc, 2, 1, 1, MAX_UINT8, "Pico", "pf", PicoFeatParams) +// Now define the feature type itself (see features.h for parameters). +DefineFeature(CharNormDesc, 4, 0, kCNFeatureType, CharNormParams) -/* define all of the parameters for the NormFeat type*/ -StartParamDesc (CharNormParams) -DefineParam (0, 0, -0.25, 0.75) -DefineParam (0, 1, 0.0, 1.0) -DefineParam (0, 0, 0.0, 1.0) -DefineParam (0, 0, 0.0, 1.0) +// Define all of the parameters for the IntFeature type +StartParamDesc(IntFeatParams) +DefineParam(0, 0, 0.0, 255.0) +DefineParam(0, 0, 0.0, 255.0) +DefineParam(1, 0, 0.0, 255.0) +EndParamDesc +// Now define the feature type itself (see features.h for parameters). +DefineFeature(IntFeatDesc, 2, 1, kIntFeatureType, IntFeatParams) + +// Define all of the parameters for the GeoFeature type +StartParamDesc(GeoFeatParams) +DefineParam(0, 0, 0.0, 255.0) +DefineParam(0, 0, 0.0, 255.0) +DefineParam(0, 0, 0.0, 255.0) EndParamDesc -/* now define the feature type itself (see features.h for info about each - parameter).*/ -DefineFeature (CharNormDesc, 4, 0, 1, 1, "CharNorm", "cn", CharNormParams) - -// define all of the parameters for the OutlineFeature type -StartParamDesc (OutlineFeatParams) -DefineParam (0, 0, -0.5, 0.5) -DefineParam (0, 0, -0.25, 0.75) -DefineParam (0, 0, 0.0, 1.0) -DefineParam (1, 0, 0.0, 1.0) +// Now define the feature type itself (see features.h for parameters). +DefineFeature(GeoFeatDesc, 3, 0, kGeoFeatureType, GeoFeatParams) + +// Other features used for training the adaptive classifier, but not used +// during normal training, therefore not in the DescDefs array. + +// Define all of the parameters for the PicoFeature type +// define knob that can be used to adjust pico-feature length. +FLOAT32 PicoFeatureLength = PICO_FEATURE_LENGTH; +StartParamDesc(PicoFeatParams) +DefineParam(0, 0, -0.25, 0.75) +DefineParam(1, 0, 0.0, 1.0) +DefineParam(0, 0, -0.5, 0.5) EndParamDesc -/* now define the feature type itself (see features.h for info about each - parameter).*/ -DefineFeature (OutlineFeatDesc, 3, 1, 1, MAX_OUTLINE_FEATURES, "Outline", - "of", OutlineFeatParams) +// Now define the feature type itself (see features.h for parameters). +DefineFeature(PicoFeatDesc, 2, 1, "pf", PicoFeatParams) +// Define all of the parameters for the OutlineFeature type. +StartParamDesc(OutlineFeatParams) +DefineParam(0, 0, -0.5, 0.5) +DefineParam(0, 0, -0.25, 0.75) +DefineParam(0, 0, 0.0, 1.0) +DefineParam(1, 0, 0.0, 1.0) +EndParamDesc +// Now define the feature type itself (see features.h for parameters). +DefineFeature(OutlineFeatDesc, 3, 1, "of", OutlineFeatParams) + +// MUST be kept in-sync with ExtractorDefs in fxdefs.cpp. static const FEATURE_DESC_STRUCT *DescDefs[NUM_FEATURE_TYPES] = { &MicroFeatureDesc, - &PicoFeatDesc, - &OutlineFeatDesc, - &CharNormDesc + &CharNormDesc, + &IntFeatDesc, + &GeoFeatDesc }; /*----------------------------------------------------------------------------- @@ -188,6 +209,27 @@ void WriteCharDescription(const FEATURE_DEFS_STRUCT &FeatureDefs, } } /* WriteCharDescription */ +// Return whether all of the fields of the given feature set +// are well defined (not inf or nan). +bool ValidCharDescription(const FEATURE_DEFS_STRUCT &FeatureDefs, + CHAR_DESC CharDesc) { + bool anything_written = false; + bool well_formed = true; + for (int Type = 0; Type < CharDesc->NumFeatureSets; Type++) { + if (CharDesc->FeatureSets[Type]) { + for (int i = 0; i < CharDesc->FeatureSets[Type]->NumFeatures; i++) { + FEATURE feat = CharDesc->FeatureSets[Type]->Features[i]; + for (int p = 0; p < feat->Type->NumParams; p++) { + if (isnan(feat->Params[p]) || isinf(feat->Params[p])) + well_formed = false; + else + anything_written = true; + } + } + } + } + return anything_written && well_formed; +} /* ValidCharDescription */ /*---------------------------------------------------------------------------*/ /** diff --git a/classify/featdefs.h b/classify/featdefs.h index 1149bf3736..bfb1465471 100644 --- a/classify/featdefs.h +++ b/classify/featdefs.h @@ -25,6 +25,10 @@ /* Enumerate the different types of features currently defined. */ #define NUM_FEATURE_TYPES 4 +extern const char* kMicroFeatureType; +extern const char* kCNFeatureType; +extern const char* kIntFeatureType; +extern const char* kGeoFeatureType; /* define error traps which can be triggered by this module.*/ #define ILLEGAL_SHORT_NAME 2000 @@ -58,6 +62,9 @@ void FreeCharDescription(CHAR_DESC CharDesc); CHAR_DESC NewCharDescription(const FEATURE_DEFS_STRUCT &FeatureDefs); +bool ValidCharDescription(const FEATURE_DEFS_STRUCT &FeatureDefs, + CHAR_DESC CharDesc); + void WriteCharDescription(const FEATURE_DEFS_STRUCT &FeatureDefs, FILE *File, CHAR_DESC CharDesc); @@ -74,4 +81,6 @@ extern const FEATURE_DESC_STRUCT MicroFeatureDesc; extern const FEATURE_DESC_STRUCT PicoFeatDesc; extern const FEATURE_DESC_STRUCT CharNormDesc; extern const FEATURE_DESC_STRUCT OutlineFeatDesc; +extern const FEATURE_DESC_STRUCT IntFeatDesc; +extern const FEATURE_DESC_STRUCT GeoFeatDesc; #endif diff --git a/classify/float2int.cpp b/classify/float2int.cpp index 8c9c58e9fd..87901ca968 100644 --- a/classify/float2int.cpp +++ b/classify/float2int.cpp @@ -22,6 +22,7 @@ #include "normmatch.h" #include "mfoutline.h" #include "classify.h" +#include "helpers.h" #include "picofeat.h" #define MAX_INT_CHAR_NORM (INT_CHAR_NORM_RANGE - 1) @@ -33,63 +34,44 @@ namespace tesseract { /** - * For each class in Templates, clear the corresponding - * entry in CharNormArray. CharNormArray is indexed by class - * indicies (as obtained from Templates) rather than class id's. + * For each class in the unicharset, clears the corresponding + * entry in char_norm_array. char_norm_array is indexed by unichar_id. * * Globals: * - none * - * @param Templates specifies classes currently defined - * @param CharNormArray array to be cleared + * @param char_norm_array array to be cleared * * @note Exceptions: none * @note History: Wed Feb 20 11:20:54 1991, DSJ, Created. */ -void ClearCharNormArray(INT_TEMPLATES Templates, - CLASS_NORMALIZATION_ARRAY CharNormArray) { - int i; - - for (i = 0; i < Templates->NumClasses; i++) { - CharNormArray[i] = 0; - } - +void Classify::ClearCharNormArray(uinT8* char_norm_array) { + memset(char_norm_array, 0, sizeof(*char_norm_array) * unicharset.size()); } /* ClearCharNormArray */ /*---------------------------------------------------------------------------*/ /** - * For each class in Templates, compute the match between - * NormFeature and the normalization protos for that class. - * Convert this number to the range from 0 - 255 and store it - * into CharNormArray. CharNormArray is indexed by class - * indicies (as obtained from Templates) rather than class id's. + * For each class in unicharset, computes the match between + * norm_feature and the normalization protos for that class. + * Converts this number to the range from 0 - 255 and stores it + * into char_norm_array. CharNormArray is indexed by unichar_id. * * Globals: * - none * - * @param NormFeature character normalization feature - * @param Templates specifies classes currently defined - * @param[out] CharNormArray place to put results + * @param norm_feature character normalization feature + * @param[out] char_norm_array place to put results of size unicharset.size() * * @note Exceptions: none * @note History: Wed Feb 20 11:20:54 1991, DSJ, Created. */ -void Classify::ComputeIntCharNormArray( - FEATURE NormFeature, INT_TEMPLATES Templates, - CLASS_NORMALIZATION_ARRAY CharNormArray) { - int i; - int NormAdjust; - - for (i = 0; i < Templates->NumClasses; i++) { - NormAdjust = (int) (INT_CHAR_NORM_RANGE * - ComputeNormMatch (i, NormFeature, FALSE)); - if (NormAdjust < 0) - NormAdjust = 0; - else if (NormAdjust > MAX_INT_CHAR_NORM) - NormAdjust = MAX_INT_CHAR_NORM; - - CharNormArray[i] = NormAdjust; +void Classify::ComputeIntCharNormArray(const FEATURE_STRUCT& norm_feature, + uinT8* char_norm_array) { + for (int i = 0; i < unicharset.size(); i++) { + int norm_adjust = static_cast(INT_CHAR_NORM_RANGE * + ComputeNormMatch(i, norm_feature, FALSE)); + char_norm_array[i] = ClipToRange(norm_adjust, 0, MAX_INT_CHAR_NORM); } } /* ComputeIntCharNormArray */ diff --git a/classify/float2int.h b/classify/float2int.h index bc422e67ba..4f0a21bec9 100644 --- a/classify/float2int.h +++ b/classify/float2int.h @@ -27,12 +27,4 @@ #define INT_FEAT_RANGE 256 #define BASELINE_Y_SHIFT (0.25) -/*----------------------------------------------------------------------------- - Public Function Prototypes ------------------------------------------------------------------------------*/ -namespace tesseract { -void ClearCharNormArray(INT_TEMPLATES Templates, - CLASS_NORMALIZATION_ARRAY CharNormArray); -} // namespace tesseract. - #endif diff --git a/classify/fxdefs.cpp b/classify/fxdefs.cpp index d809c3656f..c4f9cd5599 100644 --- a/classify/fxdefs.cpp +++ b/classify/fxdefs.cpp @@ -27,15 +27,16 @@ -----------------------------------------------------------------------------*/ // Definitions of extractors separated from feature definitions. const FEATURE_EXT_STRUCT MicroFeatureExt = { ExtractMicros }; -const FEATURE_EXT_STRUCT PicoFeatExt = { NULL }; -const FEATURE_EXT_STRUCT OutlineFeatExt = { NULL }; const FEATURE_EXT_STRUCT CharNormExt = { ExtractCharNormFeatures }; +const FEATURE_EXT_STRUCT IntFeatExt = { ExtractIntCNFeatures }; +const FEATURE_EXT_STRUCT GeoFeatExt = { ExtractIntGeoFeatures }; +// MUST be kept in-sync with DescDefs in featdefs.cpp. const FEATURE_EXT_STRUCT* ExtractorDefs[NUM_FEATURE_TYPES] = { &MicroFeatureExt, - &PicoFeatExt, - &OutlineFeatExt, - &CharNormExt + &CharNormExt, + &IntFeatExt, + &GeoFeatExt }; void SetupExtractors(FEATURE_DEFS_STRUCT *FeatureDefs) { diff --git a/classify/intfeaturedist.cpp b/classify/intfeaturedist.cpp new file mode 100644 index 0000000000..02dd228637 --- /dev/null +++ b/classify/intfeaturedist.cpp @@ -0,0 +1,159 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturedist.cpp +// Description: Fast set-difference-based feature distance calculator. +// Created: Thu Sep 01 13:07:30 PDT 2011 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "intfeaturedist.h" +#include "intfeaturemap.h" + +namespace tesseract { + +IntFeatureDist::IntFeatureDist() + : size_(0), total_feature_weight_(0.0), + feature_map_(NULL), features_(NULL), + features_delta_one_(NULL), features_delta_two_(NULL) { +} + +IntFeatureDist::~IntFeatureDist() { + Clear(); +} + +// Initialize the table to the given size of feature space. +void IntFeatureDist::Init(const IntFeatureMap* feature_map) { + size_ = feature_map->sparse_size(); + Clear(); + feature_map_ = feature_map; + features_ = new bool[size_]; + features_delta_one_ = new bool[size_]; + features_delta_two_ = new bool[size_]; + memset(features_, false, size_ * sizeof(features_[0])); + memset(features_delta_one_, false, size_ * sizeof(features_delta_one_[0])); + memset(features_delta_two_, false, size_ * sizeof(features_delta_two_[0])); + total_feature_weight_ = 0.0; +} + +// Setup the map for the given indexed_features that have been indexed by +// feature_map. +void IntFeatureDist::Set(const GenericVector& indexed_features, + int canonical_count, bool value) { + total_feature_weight_ = canonical_count; + for (int i = 0; i < indexed_features.size(); ++i) { + int f = indexed_features[i]; + features_[f] = value; + for (int dir = -kNumOffsetMaps; dir <= kNumOffsetMaps; ++dir) { + if (dir == 0) continue; + int mapped_f = feature_map_->OffsetFeature(f, dir); + if (mapped_f >= 0) { + features_delta_one_[mapped_f] = value; + for (int dir2 = -kNumOffsetMaps; dir2 <= kNumOffsetMaps; ++dir2) { + if (dir2 == 0) continue; + int mapped_f2 = feature_map_->OffsetFeature(mapped_f, dir2); + if (mapped_f2 >= 0) + features_delta_two_[mapped_f2] = value; + } + } + } + } +} + +// Compute the distance between the given feature vector and the last +// Set feature vector. +double IntFeatureDist::FeatureDistance( + const GenericVector& features) const { + int num_test_features = features.size(); + double denominator = total_feature_weight_ + num_test_features; + double misses = denominator; + for (int i = 0; i < num_test_features; ++i) { + int index = features[i]; + double weight = 1.0; + if (features_[index]) { + // A perfect match. + misses -= 2.0 * weight; + } else if (features_delta_one_[index]) { + misses -= 1.5 * weight; + } else if (features_delta_two_[index]) { + // A near miss. + misses -= 1.0 * weight; + } + } + return misses / denominator; +} + +// Compute the distance between the given feature vector and the last +// Set feature vector. +double IntFeatureDist::DebugFeatureDistance( + const GenericVector& features) const { + int num_test_features = features.size(); + double denominator = total_feature_weight_ + num_test_features; + double misses = denominator; + for (int i = 0; i < num_test_features; ++i) { + int index = features[i]; + double weight = 1.0; + INT_FEATURE_STRUCT f = feature_map_->InverseMapFeature(features[i]); + tprintf("Testing feature weight %g:", weight); + f.print(); + if (features_[index]) { + // A perfect match. + misses -= 2.0 * weight; + tprintf("Perfect hit\n"); + } else if (features_delta_one_[index]) { + misses -= 1.5 * weight; + tprintf("-1 hit\n"); + } else if (features_delta_two_[index]) { + // A near miss. + misses -= 1.0 * weight; + tprintf("-2 hit\n"); + } else { + tprintf("Total miss\n"); + } + } + tprintf("Features present:"); + for (int i = 0; i < size_; ++i) { + if (features_[i]) { + INT_FEATURE_STRUCT f = feature_map_->InverseMapFeature(i); + f.print(); + } + } + tprintf("\nMinus one features:"); + for (int i = 0; i < size_; ++i) { + if (features_delta_one_[i]) { + INT_FEATURE_STRUCT f = feature_map_->InverseMapFeature(i); + f.print(); + } + } + tprintf("\nMinus two features:"); + for (int i = 0; i < size_; ++i) { + if (features_delta_two_[i]) { + INT_FEATURE_STRUCT f = feature_map_->InverseMapFeature(i); + f.print(); + } + } + tprintf("\n"); + return misses / denominator; +} + +// Clear all data. +void IntFeatureDist::Clear() { + delete [] features_; + features_ = NULL; + delete [] features_delta_one_; + features_delta_one_ = NULL; + delete [] features_delta_two_; + features_delta_two_ = NULL; +} + +} // namespace tesseract diff --git a/classify/intfeaturedist.h b/classify/intfeaturedist.h new file mode 100644 index 0000000000..8cfcb3bed5 --- /dev/null +++ b/classify/intfeaturedist.h @@ -0,0 +1,80 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturedist.h +// Description: Fast set-difference-based feature distance calculator. +// Created: Thu Sep 01 12:14:30 PDT 2011 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CLASSIFY_INTFEATUREDIST_H_ +#define TESSERACT_CLASSIFY_INTFEATUREDIST_H_ + +#include "genericvector.h" + +namespace tesseract { + +class IntFeatureMap; + +// Feature distance calculator designed to provide a fast distance calculation +// based on set difference between a given feature set and many other feature +// sets in turn. +// Representation of a feature set as an array of bools that are sparsely +// true, and companion arrays that allow fast feature set distance +// calculations with allowance of offsets in position. +// Init is expensive, so for greatest efficiency, to re-initialize for a new +// feature set, use Set(..., false) on the SAME feature set as was used to +// setup with Set(..., true), to return to its initialized state before +// reuse with Set(..., true) on a new feature set. +class IntFeatureDist { + public: + IntFeatureDist(); + ~IntFeatureDist(); + + // Initialize the bool array to the given size of feature space. + // The feature_map is just borrowed, and must exist for the entire + // lifetime of the IntFeatureDist. + void Init(const IntFeatureMap* feature_map); + + // Setup the map for the given indexed_features that have been indexed by + // feature_map. After use, use Set(..., false) to reset to the initial state + // as this is faster than calling Init for sparse spaces. + void Set(const GenericVector& indexed_features, + int canonical_count, bool value); + + // Compute the distance between the given feature vector and the last + // Set feature vector. + double FeatureDistance(const GenericVector& features) const; + double DebugFeatureDistance(const GenericVector& features) const; + + private: + // Clear all data. + void Clear(); + + // Size of the indexed feature space. + int size_; + // Total weight of features currently stored in the maps. + double total_feature_weight_; + // Pointer to IntFeatureMap given at Init to find offset features. + const IntFeatureMap* feature_map_; + // Array of bools indicating presence of a feature. + bool* features_; + // Array indicating the presence of a feature offset by one unit. + bool* features_delta_one_; + // Array indicating the presence of a feature offset by two units. + bool* features_delta_two_; +}; + +} // namespace tesseract + +#endif // TESSERACT_CLASSIFY_INTFEATUREDIST_H_ diff --git a/classify/intfeaturemap.cpp b/classify/intfeaturemap.cpp new file mode 100644 index 0000000000..79c5040e96 --- /dev/null +++ b/classify/intfeaturemap.cpp @@ -0,0 +1,245 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturemap.cpp +// Description: Encapsulation of IntFeatureSpace with IndexMapBiDi +// to provide a subspace mapping and fast feature lookup. +// Created: Tue Oct 26 08:58:30 PDT 2010 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "intfeaturemap.h" + +#include "intfeaturespace.h" +#include "intfx.h" +// These includes do not exist yet, but will be coming soon. +//#include "sampleiterator.h" +//#include "trainingsample.h" +//#include "trainingsampleset.h" + +namespace tesseract { + +const int kMaxOffsetDist = 32; +const double kMinPCLengthIncrease = 1.0 / 1024; + +IntFeatureMap::IntFeatureMap() + : mapping_changed_(true), compact_size_(0) { + for (int dir = 0; dir < kNumOffsetMaps; ++dir) { + offset_plus_[dir] = NULL; + offset_minus_[dir] = NULL; + } +} + +IntFeatureMap::~IntFeatureMap() { + Clear(); +} + +// Pseudo-accessors. +int IntFeatureMap::IndexFeature(const INT_FEATURE_STRUCT& f) const { + return feature_space_.Index(f); +} +int IntFeatureMap::MapFeature(const INT_FEATURE_STRUCT& f) const { + return feature_map_.SparseToCompact(feature_space_.Index(f)); +} +int IntFeatureMap::MapIndexFeature(int index_feature) const { + return feature_map_.SparseToCompact(index_feature); +} +INT_FEATURE_STRUCT IntFeatureMap::InverseIndexFeature(int index_feature) const { + return feature_space_.PositionFromIndex(index_feature); +} +INT_FEATURE_STRUCT IntFeatureMap::InverseMapFeature(int map_feature) const { + int index = feature_map_.CompactToSparse(map_feature); + return feature_space_.PositionFromIndex(index); +} +void IntFeatureMap::DeleteMapFeature(int map_feature) { + feature_map_.Merge(-1, map_feature); + mapping_changed_ = true; +} +bool IntFeatureMap::IsMapFeatureDeleted(int map_feature) const { + return feature_map_.IsCompactDeleted(map_feature); +} + +// Copies the given feature_space and uses it as the index feature map +// from INT_FEATURE_STRUCT. +void IntFeatureMap::Init(const IntFeatureSpace& feature_space) { + feature_space_ = feature_space; + mapping_changed_ = false; + int sparse_size = feature_space_.Size(); + feature_map_.Init(sparse_size, true); + feature_map_.Setup(); + compact_size_ = feature_map_.CompactSize(); + // Initialize look-up tables if needed. + FCOORD dir = FeatureDirection(0); + if (dir.x() == 0.0f && dir.y() == 0.0f) + InitIntegerFX(); + // Compute look-up tables to generate offset features. + for (int dir = 0; dir < kNumOffsetMaps; ++dir) { + delete [] offset_plus_[dir]; + delete [] offset_minus_[dir]; + offset_plus_[dir] = new int[sparse_size]; + offset_minus_[dir] = new int[sparse_size]; + } + for (int dir = 1; dir <= kNumOffsetMaps; ++dir) { + for (int i = 0; i < sparse_size; ++i) { + int offset_index = ComputeOffsetFeature(i, dir); + offset_plus_[dir - 1][i] = offset_index; + offset_index = ComputeOffsetFeature(i, -dir); + offset_minus_[dir - 1][i] = offset_index; + } + } +} + +// Helper to return an offset index feature. In this context an offset +// feature with a dir of +/-1 is a feature of a similar direction, +// but shifted perpendicular to the direction of the feature. An offset +// feature with a dir of +/-2 is feature at the same position, but rotated +// by +/- one [compact] quantum. Returns the index of the generated offset +// feature, or -1 if it doesn't exist. Dir should be in +// [-kNumOffsetMaps, kNumOffsetMaps] to indicate the relative direction. +// A dir of 0 is an identity transformation. +// Both input and output are from the index(sparse) feature space, not +// the mapped/compact feature space, but the offset feature is the minimum +// distance moved from the input to guarantee that it maps to the next +// available quantum in the mapped/compact space. +int IntFeatureMap::OffsetFeature(int index_feature, int dir) const { + if (dir > 0 && dir <= kNumOffsetMaps) + return offset_plus_[dir - 1][index_feature]; + else if (dir < 0 && -dir <= kNumOffsetMaps) + return offset_minus_[-dir - 1][index_feature]; + else if (dir == 0) + return index_feature; + else + return -1; +} + + +//#define EXPERIMENT_ON +#ifdef EXPERIMENT_ON // This code is commented out as SampleIterator and +// TrainingSample are not reviewed/checked in yet, but these functions are a +// useful indicator of how an IntFeatureMap is setup. + +// Computes the features used by the subset of samples defined by +// the iterator and sets up the feature mapping. +// Returns the size of the compacted feature space. +int IntFeatureMap::FindNZFeatureMapping(SampleIterator* it) { + feature_map_.Init(feature_space_.Size(), false); + int total_samples = 0; + for (it->Begin(); !it->AtEnd(); it->Next()) { + const TrainingSample& sample = it->GetSample(); + GenericVector features; + feature_space_.IndexAndSortFeatures(sample.features(), + sample.num_features(), + &features); + int num_features = features.size(); + for (int f = 0; f < num_features; ++f) + feature_map_.SetMap(features[f], true); + ++total_samples; + } + feature_map_.Setup(); + compact_size_ = feature_map_.CompactSize(); + mapping_changed_ = true; + FinalizeMapping(it); + tprintf("%d non-zero features found in %d samples\n", + compact_size_, total_samples); + return compact_size_; +} +#endif + +// After deleting some features, finish setting up the mapping, and map +// all the samples. Returns the size of the compacted feature space. +int IntFeatureMap::FinalizeMapping(SampleIterator* it) { + if (mapping_changed_) { + feature_map_.CompleteMerges(); + compact_size_ = feature_map_.CompactSize(); +#ifdef EXPERIMENT_ON + it->MapSampleFeatures(*this); +#endif + mapping_changed_ = false; + } + return compact_size_; +} + +// Prints the map features from the set in human-readable form. +void IntFeatureMap::DebugMapFeatures( + const GenericVector& map_features) const { + for (int i = 0; i < map_features.size(); ++i) { + INT_FEATURE_STRUCT f = InverseMapFeature(map_features[i]); + f.print(); + } +} + +void IntFeatureMap::Clear() { + for (int dir = 0; dir < kNumOffsetMaps; ++dir) { + delete [] offset_plus_[dir]; + delete [] offset_minus_[dir]; + offset_plus_[dir] = NULL; + offset_minus_[dir] = NULL; + } +} + +// Helper to compute an offset index feature. In this context an offset +// feature with a dir of +/-1 is a feature of a similar direction, +// but shifted perpendicular to the direction of the feature. An offset +// feature with a dir of +/-2 is feature at the same position, but rotated +// by +/- one [compact] quantum. Returns the index of the generated offset +// feature, or -1 if it doesn't exist. Dir should be in +// [-kNumOffsetMaps, kNumOffsetMaps] to indicate the relative direction. +// A dir of 0 is an identity transformation. +// Both input and output are from the index(sparse) feature space, not +// the mapped/compact feature space, but the offset feature is the minimum +// distance moved from the input to guarantee that it maps to the next +// available quantum in the mapped/compact space. +int IntFeatureMap::ComputeOffsetFeature(int index_feature, int dir) const { + INT_FEATURE_STRUCT f = InverseIndexFeature(index_feature); + ASSERT_HOST(IndexFeature(f) == index_feature); + if (dir == 0) { + return index_feature; + } else if (dir == 1 || dir == -1) { + FCOORD feature_dir = FeatureDirection(f.Theta); + FCOORD rotation90(0.0f, 1.0f); + feature_dir.rotate(rotation90); + // Find the nearest existing feature. + for (int m = 1; m < kMaxOffsetDist; ++m) { + double x_pos = f.X + feature_dir.x() * (m * dir); + double y_pos = f.Y + feature_dir.y() * (m * dir); + int x = IntCastRounded(x_pos); + int y = IntCastRounded(y_pos); + if (x >= 0 && x <= MAX_UINT8 && y >= 0 && y <= MAX_UINT8) { + INT_FEATURE_STRUCT offset_f; + offset_f.X = x; + offset_f.Y = y; + offset_f.Theta = f.Theta; + int offset_index = IndexFeature(offset_f); + if (offset_index != index_feature && offset_index >= 0) + return offset_index; // Found one. + } else { + return -1; // Hit the edge of feature space. + } + } + } else if (dir == 2 || dir == -2) { + // Find the nearest existing index_feature. + for (int m = 1; m < kMaxOffsetDist; ++m) { + int theta = f.Theta + m * dir / 2; + INT_FEATURE_STRUCT offset_f; + offset_f.X = f.X; + offset_f.Y = f.Y; + offset_f.Theta = Modulo(theta, 256); + int offset_index = IndexFeature(offset_f); + if (offset_index != index_feature && offset_index >= 0) + return offset_index; // Found one. + } + } + return -1; // Nothing within the max distance. +} + +} // namespace tesseract. diff --git a/classify/intfeaturemap.h b/classify/intfeaturemap.h new file mode 100644 index 0000000000..55c5b5cf5e --- /dev/null +++ b/classify/intfeaturemap.h @@ -0,0 +1,163 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturemap.h +// Description: Encapsulation of IntFeatureSpace with IndexMapBiDi +// to provide a subspace mapping and fast feature lookup. +// Created: Tue Oct 26 08:58:30 PDT 2010 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CLASSIFY_INTFEATUREMAP_H__ +#define TESSERACT_CLASSIFY_INTFEATUREMAP_H__ + +#include "intfeaturespace.h" +#include "indexmapbidi.h" +#include "intproto.h" + +namespace tesseract { + +class SampleIterator; + +// Number of positive and negative offset maps. +static const int kNumOffsetMaps = 2; + +// Class to map a feature space defined by INT_FEATURE_STRUCT to a compact +// down-sampled subspace of actually used features. +// The IntFeatureMap copes with 2 stages of transformation: +// The first step is down-sampling (re-quantization) and converting to a +// single index value from the 3-D input: +// INT_FEATURE_STRUCT <-> index feature (via IntFeatureSpace) and +// the second is a feature-space compaction to map only the feature indices +// that are actually used. This saves space in classifiers that are built +// using the mapped feature space. +// index (sparse) feature <-> map (compact) feature via IndexMapBiDi. +// Although the transformations are reversible, the inverses are lossy and do +// not return the exact input INT_FEATURE_STRUCT, due to the many->one nature +// of both transformations. +class IntFeatureMap { + public: + IntFeatureMap(); + ~IntFeatureMap(); + + // Accessors. + int sparse_size() const { + return feature_space_.Size(); + } + int compact_size() const { + return compact_size_; + } + const IntFeatureSpace& feature_space() const { + return feature_space_; + } + const IndexMapBiDi& feature_map() const { + return feature_map_; + } + + // Pseudo-accessors. + int IndexFeature(const INT_FEATURE_STRUCT& f) const; + int MapFeature(const INT_FEATURE_STRUCT& f) const; + int MapIndexFeature(int index_feature) const; + INT_FEATURE_STRUCT InverseIndexFeature(int index_feature) const; + INT_FEATURE_STRUCT InverseMapFeature(int map_feature) const; + void DeleteMapFeature(int map_feature); + bool IsMapFeatureDeleted(int map_feature) const; + + // Copies the given feature_space and uses it as the index feature map + // from INT_FEATURE_STRUCT. + void Init(const IntFeatureSpace& feature_space); + + // Helper to return an offset index feature. In this context an offset + // feature with a dir of +/-1 is a feature of a similar direction, + // but shifted perpendicular to the direction of the feature. An offset + // feature with a dir of +/-2 is feature at the same position, but rotated + // by +/- one [compact] quantum. Returns the index of the generated offset + // feature, or -1 if it doesn't exist. Dir should be in + // [-kNumOffsetMaps, kNumOffsetMaps] to indicate the relative direction. + // A dir of 0 is an identity transformation. + // Both input and output are from the index(sparse) feature space, not + // the mapped/compact feature space, but the offset feature is the minimum + // distance moved from the input to guarantee that it maps to the next + // available quantum in the mapped/compact space. + int OffsetFeature(int index_feature, int dir) const; + + // Computes the features used by the subset of samples defined by + // the iterator and sets up the feature mapping. + // Returns the size of the compacted feature space. + int FindNZFeatureMapping(SampleIterator* it); + + // After deleting some features, finish setting up the mapping, and map + // all the samples. Returns the size of the compacted feature space. + int FinalizeMapping(SampleIterator* it); + + // Indexes the given array of features to a vector of sorted indices. + void IndexAndSortFeatures(const INT_FEATURE_STRUCT* features, + int num_features, + GenericVector* sorted_features) const { + feature_space_.IndexAndSortFeatures(features, num_features, + sorted_features); + } + // Maps the given array of index/sparse features to an array of map/compact + // features. + // Assumes the input is sorted. The output indices are sorted and uniqued. + // Returns the number of "missed" features, being features that + // don't map to the compact feature space. + int MapIndexedFeatures(const GenericVector& index_features, + GenericVector* map_features) const { + return feature_map_.MapFeatures(index_features, map_features); + } + + // Prints the map features from the set in human-readable form. + void DebugMapFeatures(const GenericVector& map_features) const; + + private: + void Clear(); + + // Helper to compute an offset index feature. In this context an offset + // feature with a dir of +/-1 is a feature of a similar direction, + // but shifted perpendicular to the direction of the feature. An offset + // feature with a dir of +/-2 is feature at the same position, but rotated + // by +/- one [compact] quantum. Returns the index of the generated offset + // feature, or -1 if it doesn't exist. Dir should be in + // [-kNumOffsetMaps, kNumOffsetMaps] to indicate the relative direction. + // A dir of 0 is an identity transformation. + // Both input and output are from the index(sparse) feature space, not + // the mapped/compact feature space, but the offset feature is the minimum + // distance moved from the input to guarantee that it maps to the next + // available quantum in the mapped/compact space. + int ComputeOffsetFeature(int index_feature, int dir) const; + + // True if the mapping has changed since it was last finalized. + bool mapping_changed_; + // Size of the compacted feature space, after unused features are removed. + int compact_size_; + // Feature space quantization definition and indexing from INT_FEATURE_STRUCT. + IntFeatureSpace feature_space_; + // Mapping from indexed feature space to the compacted space with unused + // features mapping to -1. + IndexMapBiDi feature_map_; + // Index tables to map a feature index to the corresponding feature after a + // shift perpendicular to the feature direction, or a rotation in place. + // An entry of -1 indicates that there is no corresponding feature. + // Array of arrays of size feature_space_.Size() owned by this class. + int* offset_plus_[kNumOffsetMaps]; + int* offset_minus_[kNumOffsetMaps]; + + // Don't use default copy and assign! + IntFeatureMap(const IntFeatureMap&); + void operator=(const IntFeatureMap&); +}; + +} // namespace tesseract. + +#endif // TESSERACT_CLASSIFY_INTFEATUREMAP_H__ diff --git a/classify/intfeaturespace.cpp b/classify/intfeaturespace.cpp new file mode 100644 index 0000000000..5f911dcee6 --- /dev/null +++ b/classify/intfeaturespace.cpp @@ -0,0 +1,143 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturespace.cpp +// Description: Indexed feature space based on INT_FEATURE_STRUCT. +// Created: Wed Mar 24 11:21:27 PDT 2010 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "intfeaturespace.h" +#include "intfx.h" + +namespace tesseract { + +IntFeatureSpace::IntFeatureSpace() + : x_buckets_(0), y_buckets_(0), theta_buckets_(0) { +} + +void IntFeatureSpace::Init(uinT8 xbuckets, uinT8 ybuckets, uinT8 thetabuckets) { + x_buckets_ = xbuckets; + y_buckets_ = ybuckets; + theta_buckets_ = thetabuckets; +} + +// Serializes the feature space definition to the given file. +// Returns false on error. +bool IntFeatureSpace::Serialize(FILE* fp) const { + if (fwrite(&x_buckets_, sizeof(x_buckets_), 1, fp) != 1) + return false; + if (fwrite(&y_buckets_, sizeof(y_buckets_), 1, fp) != 1) + return false; + if (fwrite(&theta_buckets_, sizeof(theta_buckets_), 1, fp) != 1) + return false; + return true; +} + +// DeSerializes the feature space definition from the given file. +// If swap is true, the data is big/little-endian swapped. +// Returns false on error. +bool IntFeatureSpace::DeSerialize(bool swap, FILE* fp) { + if (fread(&x_buckets_, sizeof(x_buckets_), 1, fp) != 1) + return false; + if (fread(&y_buckets_, sizeof(y_buckets_), 1, fp) != 1) + return false; + if (fread(&theta_buckets_, sizeof(theta_buckets_), 1, fp) != 1) + return false; + return true; +} + +// Returns an INT_FEATURE_STRUCT corresponding to the given index. +// This is the inverse of the Index member. +INT_FEATURE_STRUCT IntFeatureSpace::PositionFromIndex(int index) const { + return PositionFromBuckets(index / (y_buckets_ * theta_buckets_), + index / theta_buckets_ % y_buckets_, + index % theta_buckets_); +} + +// Bulk calls to Index. Maps the given array of features to a vector of +// inT32 indices in the same order as the input. +void IntFeatureSpace::IndexFeatures(const INT_FEATURE_STRUCT* features, + int num_features, + GenericVector* mapped_features) const { + mapped_features->truncate(0); + for (int f = 0; f < num_features; ++f) + mapped_features->push_back(Index(features[f])); +} + +// Bulk calls to Index. Maps the given array of features to a vector of +// sorted inT32 indices. +void IntFeatureSpace::IndexAndSortFeatures( + const INT_FEATURE_STRUCT* features, int num_features, + GenericVector* sorted_features) const { + sorted_features->truncate(0); + for (int f = 0; f < num_features; ++f) + sorted_features->push_back(Index(features[f])); + sorted_features->sort(); +} + +// Returns a feature space index for the given x,y position in a display +// window, or -1 if the feature is a miss. +int IntFeatureSpace::XYToFeatureIndex(int x, int y) const { + // Round the x,y position to a feature. Search for a valid theta. + INT_FEATURE_STRUCT feature = {static_cast(x), static_cast(y), + 0, 0}; + int index = -1; + for (int theta = 0; theta <= MAX_UINT8 && index < 0; ++theta) { + feature.Theta = theta; + index = Index(feature); + } + if (index < 0) { + tprintf("(%d,%d) does not exist in feature space!\n", x, y); + return -1; + } + feature = PositionFromIndex(index); + tprintf("Click at (%d, %d) ->(%d, %d), ->(%d, %d)\n", + x, y, feature.X, feature.Y, x - feature.X, y - feature.Y); + // Get the relative position of x,y from the rounded feature. + x -= feature.X; + y -= feature.Y; + if (x != 0 || y != 0) { + double angle = atan2(static_cast(y), static_cast(x)) + PI; + angle *= kIntFeatureExtent / (2.0 * PI); + feature.Theta = static_cast(angle + 0.5); + index = Index(feature); + if (index < 0) { + tprintf("Feature failed to map to a valid index:"); + feature.print(); + return -1; + } + feature = PositionFromIndex(index); + } + feature.print(); + return index; +} + +// Returns an INT_FEATURE_STRUCT corresponding to the given bucket coords. +INT_FEATURE_STRUCT IntFeatureSpace::PositionFromBuckets(int x, + int y, + int theta) const { + INT_FEATURE_STRUCT pos = { + static_cast(ClipToRange( + (x * kIntFeatureExtent + kIntFeatureExtent / 2) / x_buckets_, + 0, MAX_UINT8)), + static_cast(ClipToRange( + (y * kIntFeatureExtent + kIntFeatureExtent / 2) / y_buckets_, + 0, MAX_UINT8)), + static_cast(ClipToRange( + DivRounded(theta * kIntFeatureExtent, theta_buckets_), + 0, MAX_UINT8))}; + return pos; +} + +} // namespace tesseract. diff --git a/classify/intfeaturespace.h b/classify/intfeaturespace.h new file mode 100644 index 0000000000..e1e8e6ec9b --- /dev/null +++ b/classify/intfeaturespace.h @@ -0,0 +1,110 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: intfeaturespace.h +// Description: Indexed feature space based on INT_FEATURE_STRUCT. +// Created: Wed Mar 24 10:55:30 PDT 2010 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CLASSIFY_INTFEATURESPACE_H__ +#define TESSERACT_CLASSIFY_INTFEATURESPACE_H__ + +#include "genericvector.h" +#include "intproto.h" + +// Extent of x,y,theta in the input feature space. [0,255]. +const int kIntFeatureExtent = 256; +// Extent of x,y,theta dimensions in the quantized feature space. +const int kBoostXYBuckets = 16; +const int kBoostDirBuckets = 16; + +namespace tesseract { + +class IndexMap; + +// Down-sampling quantization of the INT_FEATURE_STRUCT feature space and +// conversion to a single scalar index value, used as a binary feature space. +class IntFeatureSpace { + public: + IntFeatureSpace(); + // Default copy constructors and assignment OK! + + // Setup the feature space with the given dimensions. + void Init(uinT8 xbuckets, uinT8 ybuckets, uinT8 thetabuckets); + + // Serializes the feature space definition to the given file. + // Returns false on error. + bool Serialize(FILE* fp) const; + + // DeSerializes the feature space definition from the given file. + // If swap is true, the data is big/little-endian swapped. + // Returns false on error. + bool DeSerialize(bool swap, FILE* fp); + + // Returns the total size of the feature space. + int Size() const { + return static_cast(x_buckets_) * y_buckets_ * theta_buckets_; + } + // Returns an INT_FEATURE_STRUCT corresponding to the given index. + // This is the inverse of the Index member. + INT_FEATURE_STRUCT PositionFromIndex(int index) const; + + // Returns a 1-dimensional index corresponding to the given feature value. + // Range is [0, Size()-1]. Inverse of PositionFromIndex member. + int Index(const INT_FEATURE_STRUCT& f) const { + return (XBucket(f.X) * y_buckets_ + YBucket(f.Y)) * theta_buckets_ + + ThetaBucket(f.Theta); + } + // Bulk calls to Index. Maps the given array of features to a vector of + // inT32 indices in the same order as the input. + void IndexFeatures(const INT_FEATURE_STRUCT* features, int num_features, + GenericVector* mapped_features) const; + // Bulk calls to Index. Maps the given array of features to a vector of + // sorted inT32 indices. + void IndexAndSortFeatures(const INT_FEATURE_STRUCT* features, + int num_features, + GenericVector* sorted_features) const; + // Returns a feature space index for the given x,y position in a display + // window, or -1 if the feature is a miss. + int XYToFeatureIndex(int x, int y) const; + + protected: + // Converters to generate indices for individual feature dimensions. + int XBucket(int x) const { + int bucket = x * x_buckets_ / kIntFeatureExtent; + return ClipToRange(bucket, 0, static_cast(x_buckets_) - 1); + } + int YBucket(int y) const { + int bucket = y * y_buckets_ / kIntFeatureExtent; + return ClipToRange(bucket, 0, static_cast(y_buckets_) - 1); + } + // Use DivRounded for theta so that exactly vertical and horizontal are in + // the middle of a bucket. The Modulo takes care of the wrap-around. + int ThetaBucket(int theta) const { + int bucket = DivRounded(theta * theta_buckets_, kIntFeatureExtent); + return Modulo(bucket, theta_buckets_); + } + // Returns an INT_FEATURE_STRUCT corresponding to the given buckets. + INT_FEATURE_STRUCT PositionFromBuckets(int x, int y, int theta) const; + + // Feature space definition - serialized. + uinT8 x_buckets_; + uinT8 y_buckets_; + uinT8 theta_buckets_; +}; + +} // namespace tesseract. + + +#endif // TESSERACT_CLASSIFY_INTFEATURESPACE_H__ diff --git a/classify/intfx.cpp b/classify/intfx.cpp index 240fdfc27b..e17b3e56c6 100644 --- a/classify/intfx.cpp +++ b/classify/intfx.cpp @@ -23,9 +23,13 @@ #include "const.h" #include "helpers.h" #include "ccutil.h" +#include "statistc.h" +#include "trainingsample.h" #ifdef __UNIX__ #endif +using tesseract::TrainingSample; + /**---------------------------------------------------------------------------- Private Function Prototypes ----------------------------------------------------------------------------**/ @@ -55,6 +59,10 @@ INT_VAR(classify_radius_gyr_max_exp, 8, // atan(0.0) ... atan(ATAN_TABLE_SIZE - 1 / ATAN_TABLE_SIZE) // The entries are in binary degrees where a full circle is 256 binary degrees. static uinT8 AtanTable[ATAN_TABLE_SIZE]; +// Look up table for cos and sin to turn the intfx feature angle to a vector. +// Also protected by atan_table_mutex. +static float cos_table[INT_CHAR_NORM_RANGE]; +static float sin_table[INT_CHAR_NORM_RANGE]; // Guards write access to AtanTable so we dont create it more than once. tesseract::CCUtilMutex atan_table_mutex; @@ -71,11 +79,46 @@ void InitIntegerFX() { AtanTable[i] = (uinT8) (atan ((i / (float) ATAN_TABLE_SIZE)) * 128.0 / PI + 0.5); } + for (int i = 0; i < INT_CHAR_NORM_RANGE; ++i) { + cos_table[i] = cos(i * 2 * PI / INT_CHAR_NORM_RANGE + PI); + sin_table[i] = sin(i * 2 * PI / INT_CHAR_NORM_RANGE + PI); + } atan_table_init = true; } atan_table_mutex.Unlock(); } +// Returns a vector representing the direction of a feature with the given +// theta direction in an INT_FEATURE_STRUCT. +FCOORD FeatureDirection(uinT8 theta) { + return FCOORD(cos_table[theta], sin_table[theta]); +} + +TrainingSample* GetIntFeatures(tesseract::NormalizationMode mode, + TBLOB *blob, const DENORM& denorm) { + INT_FEATURE_ARRAY blfeatures; + INT_FEATURE_ARRAY cnfeatures; + INT_FX_RESULT_STRUCT fx_info; + ExtractIntFeat(blob, denorm, blfeatures, cnfeatures, &fx_info, NULL); + TrainingSample* sample = NULL; + if (mode == tesseract::NM_CHAR_ANISOTROPIC) { + int num_features = fx_info.NumCN; + if (num_features > 0) { + sample = TrainingSample::CopyFromFeatures(fx_info, cnfeatures, + num_features); + } + } else if (mode == tesseract::NM_BASELINE) { + int num_features = fx_info.NumBL; + if (num_features > 0) { + sample = TrainingSample::CopyFromFeatures(fx_info, blfeatures, + num_features); + } + } else { + ASSERT_HOST(!"Unsupported normalization mode!"); + } + return sample; +} + /*--------------------------------------------------------------------------*/ // Extract a set of standard-sized features from Blobs and write them out in @@ -101,7 +144,7 @@ int ExtractIntFeat(TBLOB *Blob, const DENORM& denorm, INT_FEATURE_ARRAY BLFeat, INT_FEATURE_ARRAY CNFeat, - INT_FX_RESULT Results, + INT_FX_RESULT_STRUCT* Results, inT32 *FeatureOutlineArray) { TESSLINE *OutLine; @@ -131,6 +174,8 @@ int ExtractIntFeat(TBLOB *Blob, Results->Ry = 0; Results->NumBL = 0; Results->NumCN = 0; + Results->YBottom = MAX_UINT8; + Results->YTop = 0; // Calculate the centroid (Xmean, Ymean) for the blob. // We use centroid (instead of center of bounding box or center of smallest @@ -200,6 +245,8 @@ int ExtractIntFeat(TBLOB *Blob, Iy = 0; NumBLFeatures = 0; OutLine = Blob->outlines; + int min_x = 0; + int max_x = 0; while (OutLine != NULL) { LoopStart = OutLine->loop; Loop = LoopStart; @@ -213,6 +260,11 @@ int ExtractIntFeat(TBLOB *Blob, Loop = Loop->next; NormX = Loop->pos.x - Xmean; NormY = Loop->pos.y; + if (NormY < Results->YBottom) + Results->YBottom = ClipToRange(NormY, 0, MAX_UINT8); + if (NormY > Results->YTop) + Results->YTop = ClipToRange(NormY, 0, MAX_UINT8); + UpdateRange(NormX, &min_x, &max_x); n = 1; if (!Segment->IsHidden()) { @@ -261,6 +313,7 @@ int ExtractIntFeat(TBLOB *Blob, while (Loop != LoopStart); OutLine = OutLine->next; } + Results->Width = max_x - min_x; if (Ix == 0) Ix = 1; if (Iy == 0) @@ -440,6 +493,7 @@ int SaveFeature(INT_FEATURE_ARRAY FeatureArray, Feature->X = ClipToRange(X, 0, 255); Feature->Y = ClipToRange(Y, 0, 255); Feature->Theta = Theta; + Feature->CP_misses = 0; return TRUE; } diff --git a/classify/intfx.h b/classify/intfx.h index 06f008b4a6..30150947de 100644 --- a/classify/intfx.h +++ b/classify/intfx.h @@ -23,31 +23,43 @@ ----------------------------------------------------------------------------**/ #include "blobs.h" #include "intproto.h" +#include "normalis.h" #include class DENORM; -typedef struct -{ - inT32 Length; /* total length of all outlines */ - inT16 Xmean, Ymean; /* center of mass of all outlines */ - inT16 Rx, Ry; /* radius of gyration */ - inT16 NumBL, NumCN; /* number of features extracted */ +namespace tesseract { +class TrainingSample; } - -INT_FX_RESULT_STRUCT, *INT_FX_RESULT; +struct INT_FX_RESULT_STRUCT { + inT32 Length; // total length of all outlines + inT16 Xmean, Ymean; // center of mass of all outlines + inT16 Rx, Ry; // radius of gyration + inT16 NumBL, NumCN; // number of features extracted + inT16 Width; // Width of blob in BLN coords. + uinT8 YBottom; // Bottom of blob in BLN coords. + uinT8 YTop; // Top of blob in BLN coords. +}; /**---------------------------------------------------------------------------- Public Function Prototypes ----------------------------------------------------------------------------**/ void InitIntegerFX(); +// Returns a vector representing the direction of a feature with the given +// theta direction in an INT_FEATURE_STRUCT. +FCOORD FeatureDirection(uinT8 theta); + +tesseract::TrainingSample* GetIntFeatures( + tesseract::NormalizationMode mode, TBLOB *blob, + const DENORM& denorm); + int ExtractIntFeat(TBLOB *Blob, const DENORM& denorm, INT_FEATURE_ARRAY BLFeat, INT_FEATURE_ARRAY CNFeat, - INT_FX_RESULT Results, + INT_FX_RESULT_STRUCT* Results, inT32 *FeatureOutlineArray = 0); uinT8 BinaryAnglePlusPi(inT32 Y, inT32 X); diff --git a/classify/intmatcher.cpp b/classify/intmatcher.cpp index b8eb5de2af..ce1653fee2 100644 --- a/classify/intmatcher.cpp +++ b/classify/intmatcher.cpp @@ -23,8 +23,11 @@ #include "intproto.h" #include "callcpp.h" #include "scrollview.h" +#include "float2int.h" #include "globals.h" +#include "helpers.h" #include "classify.h" +#include "shapetable.h" #include // Include automatically generated configuration file if running autoconf. @@ -35,6 +38,11 @@ /*---------------------------------------------------------------------------- Global Data Definitions and Declarations ----------------------------------------------------------------------------*/ +// Parameters of the sigmoid used to convert similarity to evidence in the +// similarity_evidence_table_ that is used to convert distance metric to an +// 8 bit evidence value in the secondary matcher. (See IntMatcher::Init). +const float IntegerMatcher::kSEExponentialMultiplier = 0.0; +const float IntegerMatcher::kSimilarityCenter = 0.0075; static const uinT8 offset_table[256] = { 255, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, @@ -89,275 +97,360 @@ static const uinT8 next_table[256] = { 0xf8, 0xfc, 0xfc, 0xfe }; -struct ClassPrunerData { - int *class_count_; - int *norm_count_; - int *sort_key_; - int *sort_index_; - int max_classes_; +namespace tesseract { - ClassPrunerData(int max_classes) { - // class_count_ and friends are referenced by indexing off of data in - // class pruner word sized chunks. Each pruner word is of sized - // BITS_PER_WERD and each entry is NUM_BITS_PER_CLASS, so there are - // BITS_PER_WERD / NUM_BITS_PER_CLASS entries. - // See Classify::ClassPruner in intmatcher.cpp. - max_classes_ = RoundUp( +// Encapsulation of the intermediate data and computations made by the class +// pruner. The class pruner implements a simple linear classifier on binary +// features by heavily quantizing the feature space, and applying +// NUM_BITS_PER_CLASS (2)-bit weights to the features. Lack of resolution in +// weights is compensated by a non-constant bias that is dependent on the +// number of features present. +class ClassPruner { + public: + ClassPruner(int max_classes) { + // The unrolled loop in ComputeScores means that the array sizes need to + // be rounded up so that the array is big enough to accommodate the extra + // entries accessed by the unrolling. Each pruner word is of sized + // BITS_PER_WERD and each entry is NUM_BITS_PER_CLASS, so there are + // BITS_PER_WERD / NUM_BITS_PER_CLASS entries. + // See ComputeScores. + max_classes_ = max_classes; + rounded_classes_ = RoundUp( max_classes, WERDS_PER_CP_VECTOR * BITS_PER_WERD / NUM_BITS_PER_CLASS); - class_count_ = new int[max_classes_]; - norm_count_ = new int[max_classes_]; - sort_key_ = new int[max_classes_ + 1]; - sort_index_ = new int[max_classes_ + 1]; - for (int i = 0; i < max_classes_; i++) { + class_count_ = new int[rounded_classes_]; + norm_count_ = new int[rounded_classes_]; + sort_key_ = new int[rounded_classes_ + 1]; + sort_index_ = new int[rounded_classes_ + 1]; + for (int i = 0; i < rounded_classes_; i++) { class_count_[i] = 0; } + pruning_threshold_ = 0; + num_features_ = 0; + num_classes_ = 0; } - ~ClassPrunerData() { + ~ClassPruner() { delete []class_count_; delete []norm_count_; delete []sort_key_; delete []sort_index_; } -}; - -const float IntegerMatcher::kSEExponentialMultiplier = 0.0; -const float IntegerMatcher::kSimilarityCenter = 0.0075; - -/*---------------------------------------------------------------------------- - Public Code -----------------------------------------------------------------------------*/ -/*---------------------------------------------------------------------------*/ -namespace tesseract { -int Classify::ClassPruner(INT_TEMPLATES IntTemplates, - inT16 NumFeatures, - INT_FEATURE_ARRAY Features, - CLASS_NORMALIZATION_ARRAY NormalizationFactors, - CLASS_CUTOFF_ARRAY ExpectedNumFeatures, - CLASS_PRUNER_RESULTS Results) { -/* - ** Parameters: - ** IntTemplates Class pruner tables - ** NumFeatures Number of features in blob - ** Features Array of features - ** NormalizationFactors Array of fudge factors from blob - ** normalization process - ** (by CLASS_INDEX) - ** ExpectedNumFeatures Array of expected number of features - ** for each class - ** (by CLASS_INDEX) - ** Results Sorted Array of pruned classes - ** (by CLASS_ID) - ** Operation: - ** Prune the classes using a modified fast match table. - ** Return a sorted list of classes along with the number - ** of pruned classes in that list. - ** Return: Number of pruned classes. - ** Exceptions: none - ** History: Tue Feb 19 10:24:24 MST 1991, RWM, Created. - */ - uinT32 PrunerWord; - inT32 class_index; //index to class - int Word; - uinT32 *BasePrunerAddress; - uinT32 feature_address; //current feature index - INT_FEATURE feature; //current feature - CLASS_PRUNER *ClassPruner; - int PrunerSet; - int NumPruners; - inT32 feature_index; //current feature - - int MaxNumClasses = IntTemplates->NumClasses; - ClassPrunerData data(IntTemplates->NumClasses); - int *ClassCount = data.class_count_; - int *NormCount = data.norm_count_; - int *SortKey = data.sort_key_; - int *SortIndex = data.sort_index_; - - int out_class; - int MaxCount; - int NumClasses; - FLOAT32 max_rating; //max allowed rating - CLASS_ID class_id; - - /* Update Class Counts */ - NumPruners = IntTemplates->NumClassPruners; - for (feature_index = 0; feature_index < NumFeatures; feature_index++) { - feature = &Features[feature_index]; - feature_address = (((feature->X * NUM_CP_BUCKETS >> 8) * NUM_CP_BUCKETS + - (feature->Y * NUM_CP_BUCKETS >> 8)) * NUM_CP_BUCKETS + - (feature->Theta * NUM_CP_BUCKETS >> 8)) << 1; - ClassPruner = IntTemplates->ClassPruner; - class_index = 0; - - for (PrunerSet = 0; PrunerSet < NumPruners; PrunerSet++, ClassPruner++) { - BasePrunerAddress = (uinT32 *) (*ClassPruner) + feature_address; - - for (Word = 0; Word < WERDS_PER_CP_VECTOR; Word++) { - PrunerWord = *BasePrunerAddress++; - // This inner loop is unrolled to speed up the ClassPruner. - // Currently gcc would not unroll it unless it is set to O3 - // level of optimization or -funroll-loops is specified. - /* - uinT32 class_mask = (1 << NUM_BITS_PER_CLASS) - 1; - for (int bit = 0; bit < BITS_PER_WERD/NUM_BITS_PER_CLASS; bit++) { - ClassCount[class_index++] += PrunerWord & class_mask; - PrunerWord >>= NUM_BITS_PER_CLASS; + // Computes the scores for every class in the character set, by summing the + // weights for each feature and stores the sums internally in class_count_. + void ComputeScores(const INT_TEMPLATES_STRUCT* int_templates, + int num_features, const INT_FEATURE_STRUCT* features) { + num_features_ = num_features; + int num_pruners = int_templates->NumClassPruners; + for (int f = 0; f < num_features; ++f) { + const INT_FEATURE_STRUCT* feature = &features[f]; + // Quantize the feature to NUM_CP_BUCKETS*NUM_CP_BUCKETS*NUM_CP_BUCKETS. + int x = feature->X * NUM_CP_BUCKETS >> 8; + int y = feature->Y * NUM_CP_BUCKETS >> 8; + int theta = feature->Theta * NUM_CP_BUCKETS >> 8; + int class_id = 0; + // Each CLASS_PRUNER_STRUCT only covers CLASSES_PER_CP(32) classes, so + // we need a collection of them, indexed by pruner_set. + for (int pruner_set = 0; pruner_set < num_pruners; ++pruner_set) { + // Look up quantized feature in a 3-D array, an array of weights for + // each class. + const uinT32* pruner_word_ptr = + int_templates->ClassPruners[pruner_set]->p[x][y][theta]; + for (int word = 0; word < WERDS_PER_CP_VECTOR; ++word) { + uinT32 pruner_word = *pruner_word_ptr++; + // This inner loop is unrolled to speed up the ClassPruner. + // Currently gcc would not unroll it unless it is set to O3 + // level of optimization or -funroll-loops is specified. + /* + uinT32 class_mask = (1 << NUM_BITS_PER_CLASS) - 1; + for (int bit = 0; bit < BITS_PER_WERD/NUM_BITS_PER_CLASS; bit++) { + class_count_[class_id++] += pruner_word & class_mask; + pruner_word >>= NUM_BITS_PER_CLASS; + } + */ + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; + pruner_word >>= NUM_BITS_PER_CLASS; + class_count_[class_id++] += pruner_word & CLASS_PRUNER_CLASS_MASK; } - */ - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; - PrunerWord >>= NUM_BITS_PER_CLASS; - ClassCount[class_index++] += PrunerWord & CLASS_PRUNER_CLASS_MASK; } } } - /* Adjust Class Counts for Number of Expected Features */ - for (class_id = 0; class_id < MaxNumClasses; class_id++) { - if (NumFeatures < ExpectedNumFeatures[class_id]) { - int deficit = ExpectedNumFeatures[class_id] - NumFeatures; - ClassCount[class_id] -= ClassCount[class_id] * deficit / - (NumFeatures * classify_cp_cutoff_strength + deficit); + // Adjusts the scores according to the number of expected features. Used + // in lieu of a constant bias, this penalizes classes that expect more + // features than there are present. Thus an actual c will score higher for c + // than e, even though almost all the features match e as well as c, because + // e expects more features to be present. + void AdjustForExpectedNumFeatures(const uinT16* expected_num_features, + int cutoff_strength) { + for (int class_id = 0; class_id < max_classes_; ++class_id) { + if (num_features_ < expected_num_features[class_id]) { + int deficit = expected_num_features[class_id] - num_features_; + class_count_[class_id] -= class_count_[class_id] * deficit / + (num_features_ * cutoff_strength + deficit); + } } - if (!unicharset.get_enabled(class_id)) - ClassCount[class_id] = 0; // This char is disabled! + } - // Do not include character fragments in the class pruner - // results if disable_character_fragments is true. - if (disable_character_fragments && unicharset.get_fragment(class_id)) { - ClassCount[class_id] = 0; + // Zeros the scores for classes disabled in the unicharset. + // Implements the black-list to recognize a subset of the character set. + void DisableDisabledClasses(const UNICHARSET& unicharset) { + for (int class_id = 0; class_id < max_classes_; ++class_id) { + if (!unicharset.get_enabled(class_id)) + class_count_[class_id] = 0; // This char is disabled! } } - /* Adjust Class Counts for Normalization Factors */ - MaxCount = 0; - for (class_id = 0; class_id < MaxNumClasses; class_id++) { - NormCount[class_id] = ClassCount[class_id] - - ((classify_class_pruner_multiplier * NormalizationFactors[class_id]) - >> 8); - if (NormCount[class_id] > MaxCount && - // This additional check is added in order to ensure that - // the classifier will return at least one non-fragmented - // character match. - // TODO(daria): verify that this helps accuracy and does not - // hurt performance. - !unicharset.get_fragment(class_id)) { - MaxCount = NormCount[class_id]; + // Zeros the scores of fragments. + void DisableFragments(const UNICHARSET& unicharset) { + for (int class_id = 0; class_id < max_classes_; ++class_id) { + // Do not include character fragments in the class pruner + // results if disable_character_fragments is true. + if (unicharset.get_fragment(class_id)) { + class_count_[class_id] = 0; + } } } - /* Prune Classes */ - MaxCount *= classify_class_pruner_threshold; - MaxCount >>= 8; - /* Select Classes */ - if (MaxCount < 1) - MaxCount = 1; - NumClasses = 0; - for (class_id = 0; class_id < MaxNumClasses; class_id++) { - if (NormCount[class_id] >= MaxCount) { - NumClasses++; - SortIndex[NumClasses] = class_id; - SortKey[NumClasses] = NormCount[class_id]; + // Normalizes the counts for xheight, putting the normalized result in + // norm_count_. Applies a simple subtractive penalty for incorrect vertical + // position provided by the normalization_factors array, indexed by + // character class, and scaled by the norm_multiplier. + void NormalizeForXheight(int norm_multiplier, + const uinT8* normalization_factors) { + for (int class_id = 0; class_id < max_classes_; class_id++) { + norm_count_[class_id] = class_count_[class_id] - + ((norm_multiplier * normalization_factors[class_id]) >> 8); } } - /* Sort Classes using Heapsort Algorithm */ - if (NumClasses > 1) - HeapSort(NumClasses, SortKey, SortIndex); + // The nop normalization copies the class_count_ array to norm_count_. + void NoNormalization() { + for (int class_id = 0; class_id < max_classes_; class_id++) { + norm_count_[class_id] = class_count_[class_id]; + } + } - if (classify_debug_level > 1) { - cprintf ("CP:%d classes, %d features:\n", NumClasses, NumFeatures); - for (class_id = 0; class_id < NumClasses; class_id++) { - cprintf ("%s:C=%d, E=%d, N=%d, Rat=%d\n", - unicharset.debug_str(SortIndex[NumClasses - class_id]).string(), - ClassCount[SortIndex[NumClasses - class_id]], - ExpectedNumFeatures[SortIndex[NumClasses - class_id]], - SortKey[NumClasses - class_id], - 1010 - 1000 * SortKey[NumClasses - class_id] / - (CLASS_PRUNER_CLASS_MASK * NumFeatures)); + // Prunes the classes using * pruning_factor/256 as a + // threshold for keeping classes. If max_of_non_fragments, then ignore + // fragments in computing the maximum count. + void PruneAndSort(int pruning_factor, bool max_of_non_fragments, + const UNICHARSET& unicharset) { + int max_count = 0; + for (int c = 0; c < max_classes_; ++c) { + if (norm_count_[c] > max_count && + // This additional check is added in order to ensure that + // the classifier will return at least one non-fragmented + // character match. + // TODO(daria): verify that this helps accuracy and does not + // hurt performance. + (!max_of_non_fragments || !unicharset.get_fragment(c))) { + max_count = norm_count_[c]; + } + } + // Prune Classes. + pruning_threshold_ = (max_count * pruning_factor) >> 8; + // Select Classes. + if (pruning_threshold_ < 1) + pruning_threshold_ = 1; + num_classes_ = 0; + for (int class_id = 0; class_id < max_classes_; class_id++) { + if (norm_count_[class_id] >= pruning_threshold_) { + ++num_classes_; + sort_index_[num_classes_] = class_id; + sort_key_[num_classes_] = norm_count_[class_id]; + } } - if (classify_debug_level > 2) { - NumPruners = IntTemplates->NumClassPruners; - for (feature_index = 0; feature_index < NumFeatures; - feature_index++) { - cprintf ("F=%3d,", feature_index); - feature = &Features[feature_index]; - feature_address = - (((feature->X * NUM_CP_BUCKETS >> 8) * NUM_CP_BUCKETS + - (feature->Y * NUM_CP_BUCKETS >> 8)) * NUM_CP_BUCKETS + - (feature->Theta * NUM_CP_BUCKETS >> 8)) << 1; - ClassPruner = IntTemplates->ClassPruner; - class_index = 0; - for (PrunerSet = 0; PrunerSet < NumPruners; - PrunerSet++, ClassPruner++) { - BasePrunerAddress = (uinT32 *) (*ClassPruner) - + feature_address; - - for (Word = 0; Word < WERDS_PER_CP_VECTOR; Word++) { - PrunerWord = *BasePrunerAddress++; - for (class_id = 0; class_id < 16; class_id++, class_index++) { - if (NormCount[class_index] >= MaxCount) - cprintf (" %s=%d,", - unicharset.id_to_unichar(class_index), - PrunerWord & CLASS_PRUNER_CLASS_MASK); - PrunerWord >>= NUM_BITS_PER_CLASS; + + // Sort Classes using Heapsort Algorithm. + if (num_classes_ > 1) + HeapSort(num_classes_, sort_key_, sort_index_); + } + + // Prints debug info on the class pruner matches for the pruned classes only. + void DebugMatch(const Classify& classify, + const INT_TEMPLATES_STRUCT* int_templates, + const INT_FEATURE_STRUCT* features) const { + int num_pruners = int_templates->NumClassPruners; + int max_num_classes = int_templates->NumClasses; + for (int f = 0; f < num_features_; ++f) { + const INT_FEATURE_STRUCT* feature = &features[f]; + tprintf("F=%3d(%d,%d,%d),", f, feature->X, feature->Y, feature->Theta); + // Quantize the feature to NUM_CP_BUCKETS*NUM_CP_BUCKETS*NUM_CP_BUCKETS. + int x = feature->X * NUM_CP_BUCKETS >> 8; + int y = feature->Y * NUM_CP_BUCKETS >> 8; + int theta = feature->Theta * NUM_CP_BUCKETS >> 8; + int class_id = 0; + for (int pruner_set = 0; pruner_set < num_pruners; ++pruner_set) { + // Look up quantized feature in a 3-D array, an array of weights for + // each class. + const uinT32* pruner_word_ptr = + int_templates->ClassPruners[pruner_set]->p[x][y][theta]; + for (int word = 0; word < WERDS_PER_CP_VECTOR; ++word) { + uinT32 pruner_word = *pruner_word_ptr++; + for (int word_class = 0; word_class < 16 && + class_id < max_num_classes; ++word_class, ++class_id) { + if (norm_count_[class_id] >= pruning_threshold_) { + tprintf(" %s=%d,", + classify.ClassIDToDebugStr(int_templates, + class_id, 0).string(), + pruner_word & CLASS_PRUNER_CLASS_MASK); } + pruner_word >>= NUM_BITS_PER_CLASS; } } - cprintf ("\n"); - } - cprintf ("Adjustments:"); - for (class_id = 0; class_id < MaxNumClasses; class_id++) { - if (NormCount[class_id] > MaxCount) - cprintf(" %s=%d,", - unicharset.id_to_unichar(class_id), - -((classify_class_pruner_multiplier * - NormalizationFactors[class_id]) >> 8)); + tprintf("\n"); } - cprintf ("\n"); } } - /* Set Up Results */ - max_rating = 0.0f; - for (class_id = 0, out_class = 0; class_id < NumClasses; class_id++) { - Results[out_class].Class = SortIndex[NumClasses - class_id]; - Results[out_class].Rating = - 1.0 - SortKey[NumClasses - class_id] / - (static_cast(CLASS_PRUNER_CLASS_MASK) * NumFeatures); - out_class++; + // Prints a summary of the pruner result. + void SummarizeResult(const Classify& classify, + const INT_TEMPLATES_STRUCT* int_templates, + const uinT16* expected_num_features, + int norm_multiplier, + const uinT8* normalization_factors) const { + tprintf("CP:%d classes, %d features:\n", num_classes_, num_features_); + for (int i = 0; i < num_classes_; ++i) { + int class_id = sort_index_[num_classes_ - i]; + STRING class_string = classify.ClassIDToDebugStr(int_templates, + class_id, 0); + tprintf("%s:Initial=%d, E=%d, Xht-adj=%d, N=%d, Rat=%.2f\n", + class_string.string(), + class_count_[class_id], + expected_num_features[class_id], + (norm_multiplier * normalization_factors[class_id]) >> 8, + sort_key_[num_classes_ - i], + 100.0 - 100.0 * sort_key_[num_classes_ - i] / + (CLASS_PRUNER_CLASS_MASK * num_features_)); + } + } + + // Copies the pruned, sorted classes into the output results and returns + // the number of classes. + int SetupResults(CP_RESULT_STRUCT* results) const { + for (int c = 0; c < num_classes_; ++c) { + results[c].Class = sort_index_[num_classes_ - c]; + results[c].Rating = 1.0 - sort_key_[num_classes_ - c] / + (static_cast(CLASS_PRUNER_CLASS_MASK) * num_features_); + } + return num_classes_; } - NumClasses = out_class; - return NumClasses; + + private: + // Array[rounded_classes_] of initial counts for each class. + int *class_count_; + // Array[rounded_classes_] of modified counts for each class after normalizing + // for expected number of features, disabled classes, fragments, and xheights. + int *norm_count_; + // Array[rounded_classes_ +1] of pruned counts that gets sorted + int *sort_key_; + // Array[rounded_classes_ +1] of classes corresponding to sort_key_. + int *sort_index_; + // Number of classes in this class pruner. + int max_classes_; + // Rounded up number of classes used for array sizes. + int rounded_classes_; + // Threshold count applied to prune classes. + int pruning_threshold_; + // The number of features used to compute the scores. + int num_features_; + // Final number of pruned classes. + int num_classes_; +}; + +/*---------------------------------------------------------------------------- + Public Code +----------------------------------------------------------------------------*/ +/*---------------------------------------------------------------------------*/ +// Runs the class pruner from int_templates on the given features, returning +// the number of classes output in results. +// int_templates Class pruner tables +// num_features Number of features in blob +// features Array of features +// normalization_factors Array of fudge factors from blob +// normalization process (by CLASS_INDEX) +// expected_num_features Array of expected number of features +// for each class (by CLASS_INDEX) +// results Sorted Array of pruned classes. Must be an array +// of size at least int_templates->NumClasses. +int Classify::PruneClasses(const INT_TEMPLATES_STRUCT* int_templates, + int num_features, + const INT_FEATURE_STRUCT* features, + const uinT8* normalization_factors, + const uinT16* expected_num_features, + CP_RESULT_STRUCT* results) { +/* + ** Operation: + ** Prunes the classes using a modified fast match table. + ** Returns a sorted list of classes along with the number + ** of pruned classes in that list. + ** Return: Number of pruned classes. + ** Exceptions: none + ** History: Tue Feb 19 10:24:24 MST 1991, RWM, Created. + */ + ClassPruner pruner(int_templates->NumClasses); + // Compute initial match scores for all classes. + pruner.ComputeScores(int_templates, num_features, features); + // Adjust match scores for number of expected features. + pruner.AdjustForExpectedNumFeatures(expected_num_features, + classify_cp_cutoff_strength); + // Apply disabled classes in unicharset - only works without a shape_table. + if (shape_table_ == NULL) + pruner.DisableDisabledClasses(unicharset); + // If fragments are disabled, remove them, also only without a shape table. + if (disable_character_fragments && shape_table_ == NULL) + pruner.DisableFragments(unicharset); + + // If we have good x-heights, apply the given normalization factors. + if (normalization_factors != NULL) { + pruner.NormalizeForXheight(classify_class_pruner_multiplier, + normalization_factors); + } else { + pruner.NoNormalization(); + } + // Do the actual pruning and sort the short-list. + pruner.PruneAndSort(classify_class_pruner_threshold, + shape_table_ == NULL, unicharset); + + if (classify_debug_level > 2) { + pruner.DebugMatch(*this, int_templates, features); + } + if (classify_debug_level > 1) { + pruner.SummarizeResult(*this, int_templates, expected_num_features, + classify_class_pruner_multiplier, + normalization_factors); + } + // Convert to the expected output format. + return pruner.SetupResults(results); } } // namespace tesseract @@ -366,10 +459,8 @@ int Classify::ClassPruner(INT_TEMPLATES IntTemplates, void IntegerMatcher::Match(INT_CLASS ClassTemplate, BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, - uinT16 BlobLength, inT16 NumFeatures, - INT_FEATURE_ARRAY Features, - uinT8 NormalizationFactor, + const INT_FEATURE_STRUCT* Features, INT_RESULT Result, int AdaptFeatureThreshold, int Debug, @@ -436,12 +527,11 @@ void IntegerMatcher::Match(INT_CLASS ClassTemplate, tables->UpdateSumOfProtoEvidences(ClassTemplate, ConfigMask, NumFeatures); tables->NormalizeSums(ClassTemplate, NumFeatures, NumFeatures); - BestMatch = FindBestMatch(ClassTemplate, *tables, BlobLength, - NormalizationFactor, Result); + BestMatch = FindBestMatch(ClassTemplate, *tables, Result); #ifndef GRAPHICS_DISABLED if (PrintMatchSummaryOn(Debug)) - DebugBestMatch(BestMatch, Result, BlobLength, NormalizationFactor); + DebugBestMatch(BestMatch, Result); if (MatchDebuggingOn(Debug)) cprintf("Match Complete --------------------------------------------\n"); @@ -718,7 +808,7 @@ int IntegerMatcher::UpdateTablesForFeature( BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, int FeatureNum, - INT_FEATURE Feature, + const INT_FEATURE_STRUCT* Feature, ScratchEvidence *tables, int Debug) { /* @@ -1048,7 +1138,7 @@ void IntegerMatcher::DisplayFeatureDebugInfo( BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, inT16 NumFeatures, - INT_FEATURE_ARRAY Features, + const INT_FEATURE_STRUCT* Features, int AdaptFeatureThreshold, int Debug, bool SeparateDebugWindows) { @@ -1146,8 +1236,6 @@ void ScratchEvidence::NormalizeSums( int IntegerMatcher::FindBestMatch( INT_CLASS ClassTemplate, const ScratchEvidence &tables, - uinT16 BlobLength, - uinT8 NormalizationFactor, INT_RESULT Result) { /* ** Parameters: @@ -1168,7 +1256,7 @@ int IntegerMatcher::FindBestMatch( /* Find best match */ for (int ConfigNum = 0; ConfigNum < ClassTemplate->NumConfigs; ConfigNum++) { int rating = tables.sum_feature_evidence_[ConfigNum]; - if (*classify_debug_level_ > 1) + if (*classify_debug_level_ > 2) cprintf("Config %d, rating=%d\n", ConfigNum, rating); if (rating > BestMatch) { if (BestMatch > 0) { @@ -1186,31 +1274,28 @@ int IntegerMatcher::FindBestMatch( } /* Compute Certainty Rating */ - Result->Rating = ((65536.0 - BestMatch) / 65536.0 * BlobLength + - local_matcher_multiplier_ * NormalizationFactor / 256.0) / - (BlobLength + local_matcher_multiplier_); + Result->Rating = (65536.0 - BestMatch) / 65536.0; return BestMatch; } +// Applies the CN normalization factor to the given rating and returns +// the modified rating. +float IntegerMatcher::ApplyCNCorrection(float rating, int blob_length, + int normalization_factor) { + return (rating * blob_length + + local_matcher_multiplier_ * normalization_factor / 256.0) / + (blob_length + local_matcher_multiplier_); +} + /*---------------------------------------------------------------------------*/ #ifndef GRAPHICS_DISABLED // Print debug information about the best match for the current class. void IntegerMatcher::DebugBestMatch( - int BestMatch, INT_RESULT Result, uinT16 BlobLength, - uinT8 NormalizationFactor) { - cprintf("Rating = %5.1f%% Best Config = %3d\n", - 100.0 * ((*Result).Rating), (int) ((*Result).Config)); - cprintf - ("Matcher Error = %5.1f%% Blob Length = %3d Weight = %4.1f%%\n", - 100.0 * (65536.0 - BestMatch) / 65536.0, (int) BlobLength, - 100.0 * BlobLength / (BlobLength + local_matcher_multiplier_)); - cprintf - ("Char Norm Error = %5.1f%% Norm Strength = %3d Weight = %4.1f%%\n", - 100.0 * NormalizationFactor / 256.0, - local_matcher_multiplier_, - 100.0 * local_matcher_multiplier_ / - (BlobLength + local_matcher_multiplier_)); + int BestMatch, INT_RESULT Result) { + tprintf("Rating = %5.1f%% Best Config = %3d, Distance = %5.1f\n", + 100.0 * Result->Rating, Result->Config, + 100.0 * (65536.0 - BestMatch) / 65536.0); } #endif diff --git a/classify/intmatcher.h b/classify/intmatcher.h index 0d28988fe8..50dbd9796d 100644 --- a/classify/intmatcher.h +++ b/classify/intmatcher.h @@ -56,8 +56,6 @@ struct CP_RESULT_STRUCT { typedef CP_RESULT_STRUCT CLASS_PRUNER_RESULTS[MAX_NUM_CLASSES]; -typedef uinT8 CLASS_NORMALIZATION_ARRAY[MAX_NUM_CLASSES]; - /*---------------------------------------------------------------------------- Variables -----------------------------------------------------------------------------*/ @@ -113,15 +111,18 @@ class IntegerMatcher { void Match(INT_CLASS ClassTemplate, BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, - uinT16 BlobLength, inT16 NumFeatures, - INT_FEATURE_ARRAY Features, - uinT8 NormalizationFactor, + const INT_FEATURE_STRUCT* Features, INT_RESULT Result, int AdaptFeatureThreshold, int Debug, bool SeparateDebugWindows); + // Applies the CN normalization factor to the given rating and returns + // the modified rating. + float ApplyCNCorrection(float rating, int blob_length, + int normalization_factor); + int FindGoodProtos(INT_CLASS ClassTemplate, BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, @@ -148,14 +149,12 @@ class IntegerMatcher { BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, int FeatureNum, - INT_FEATURE Feature, + const INT_FEATURE_STRUCT* Feature, ScratchEvidence *evidence, int Debug); int FindBestMatch(INT_CLASS ClassTemplate, const ScratchEvidence &tables, - uinT16 BlobLength, - uinT8 NormalizationFactor, INT_RESULT Result); #ifndef GRAPHICS_DISABLED @@ -179,15 +178,12 @@ class IntegerMatcher { BIT_VECTOR ProtoMask, BIT_VECTOR ConfigMask, inT16 NumFeatures, - INT_FEATURE_ARRAY Features, + const INT_FEATURE_STRUCT* Features, int AdaptFeatureThreshold, int Debug, bool SeparateDebugWindows); - void DebugBestMatch(int BestMatch, - INT_RESULT Result, - uinT16 BlobLength, - uinT8 NormalizationFactor); + void DebugBestMatch(int BestMatch, INT_RESULT Result); #endif diff --git a/classify/intproto.cpp b/classify/intproto.cpp index 28f52f2271..02674ee0d9 100644 --- a/classify/intproto.cpp +++ b/classify/intproto.cpp @@ -18,19 +18,6 @@ /*----------------------------------------------------------------------------- Include Files and Type Defines -----------------------------------------------------------------------------*/ -#include "helpers.h" -#include "intproto.h" -#include "picofeat.h" -#include "mfoutline.h" -#include "emalloc.h" -#include "const.h" -#include "ndminx.h" -#include "svmnode.h" -#include "globals.h" -#include "classify.h" -#include "genericvector.h" - -//extern GetPicoFeatureLength(); #include #include @@ -39,11 +26,29 @@ #include #endif +#include "classify.h" +#include "const.h" +#include "emalloc.h" +#include "fontinfo.h" +#include "genericvector.h" +#include "globals.h" +#include "helpers.h" +#include "intproto.h" +#include "mfoutline.h" +#include "ndminx.h" +#include "picofeat.h" +#include "shapetable.h" +#include "svmnode.h" + // Include automatically generated configuration file if running autoconf. #ifdef HAVE_CONFIG_H #include "config_auto.h" #endif +using tesseract::FontInfo; +using tesseract::FontSet; +using tesseract::FontSpacingInfo; + /* match debug display constants*/ #define PROTO_PRUNER_SCALE (4.0) @@ -126,7 +131,7 @@ FLOAT32 BucketStart(int Bucket, FLOAT32 Offset, int NumBuckets); FLOAT32 BucketEnd(int Bucket, FLOAT32 Offset, int NumBuckets); void DoFill(FILL_SPEC *FillSpec, - CLASS_PRUNER Pruner, + CLASS_PRUNER_STRUCT* Pruner, register uinT32 ClassMask, register uinT32 ClassCount, register uinT32 WordIndex); @@ -218,7 +223,6 @@ double_VAR(classify_pp_side_pad, 2.5, "Proto Pruner Side Pad"); */ void AddIntClass(INT_TEMPLATES Templates, CLASS_ID ClassId, INT_CLASS Class) { int Pruner; - uinT32 *Word; assert (LegalClassId (ClassId)); if (ClassId != Templates->NumClasses) { @@ -231,13 +235,8 @@ void AddIntClass(INT_TEMPLATES Templates, CLASS_ID ClassId, INT_CLASS Class) { if (Templates->NumClasses > MaxNumClassesIn (Templates)) { Pruner = Templates->NumClassPruners++; - Templates->ClassPruner[Pruner] = - (CLASS_PRUNER) Emalloc (sizeof (CLASS_PRUNER_STRUCT)); - - for (Word = reinterpret_cast(Templates->ClassPruner[Pruner]); - Word < reinterpret_cast(Templates->ClassPruner[Pruner]) + - WERDS_PER_CP; - *Word++ = 0); + Templates->ClassPruners[Pruner] = new CLASS_PRUNER_STRUCT; + memset(Templates->ClassPruners[Pruner], 0, sizeof(CLASS_PRUNER_STRUCT)); } } /* AddIntClass */ @@ -296,14 +295,14 @@ int AddIntProto(INT_CLASS Class) { ProtoSet = (PROTO_SET) Emalloc(sizeof(PROTO_SET_STRUCT)); Class->ProtoSets[ProtoSetId] = ProtoSet; - for (Word = reinterpret_cast(ProtoSet->ProtoPruner); - Word < reinterpret_cast(ProtoSet->ProtoPruner) + WERDS_PER_PP; - *Word++ = 0); + memset(ProtoSet, 0, sizeof(*ProtoSet)); /* reallocate space for the proto lengths and install in class */ Class->ProtoLengths = (uinT8 *)Erealloc(Class->ProtoLengths, MaxNumIntProtosIn(Class) * sizeof(uinT8)); + memset(&Class->ProtoLengths[Index], 0, + sizeof(*Class->ProtoLengths) * (MaxNumIntProtosIn(Class) - Index)); } /* initialize proto so its length is zero and it isn't in any configs */ @@ -335,7 +334,7 @@ void AddProtoToClassPruner (PROTO Proto, CLASS_ID ClassId, */ #define MAX_LEVEL 2 { - CLASS_PRUNER Pruner; + CLASS_PRUNER_STRUCT* Pruner; uinT32 ClassMask; uinT32 ClassCount; uinT32 WordIndex; @@ -636,7 +635,7 @@ INT_TEMPLATES Classify::CreateIntTemplates(CLASSES FloatProtos, /*---------------------------------------------------------------------------*/ #ifndef GRAPHICS_DISABLED -void DisplayIntFeature(INT_FEATURE Feature, FLOAT32 Evidence) { +void DisplayIntFeature(const INT_FEATURE_STRUCT* Feature, FLOAT32 Evidence) { /* ** Parameters: ** Feature pico-feature to be displayed @@ -697,7 +696,6 @@ INT_CLASS NewIntClass(int MaxNumProtos, int MaxNumConfigs) { INT_CLASS Class; PROTO_SET ProtoSet; int i; - register uinT32 *Word; assert(MaxNumConfigs <= MAX_NUM_CONFIGS); @@ -713,17 +711,20 @@ INT_CLASS NewIntClass(int MaxNumProtos, int MaxNumConfigs) { for (i = 0; i < Class->NumProtoSets; i++) { /* allocate space for a proto set, install in class, and initialize */ ProtoSet = (PROTO_SET) Emalloc(sizeof(PROTO_SET_STRUCT)); + memset(ProtoSet, 0, sizeof(*ProtoSet)); Class->ProtoSets[i] = ProtoSet; - for (Word = reinterpret_cast(ProtoSet->ProtoPruner); - Word < reinterpret_cast(ProtoSet->ProtoPruner) + WERDS_PER_PP; - *Word++ = 0); /* allocate space for the proto lengths and install in class */ } if (MaxNumIntProtosIn (Class) > 0) { Class->ProtoLengths = (uinT8 *)Emalloc(MaxNumIntProtosIn (Class) * sizeof (uinT8)); + memset(Class->ProtoLengths, 0, + MaxNumIntProtosIn(Class) * sizeof(*Class->ProtoLengths)); + } else { + Class->ProtoLengths = NULL; } + memset(Class->ConfigLengths, 0, sizeof(Class->ConfigLengths)); return (Class); @@ -776,120 +777,11 @@ void free_int_templates(INT_TEMPLATES templates) { for (i = 0; i < templates->NumClasses; i++) free_int_class(templates->Class[i]); for (i = 0; i < templates->NumClassPruners; i++) - Efree(templates->ClassPruner[i]); + delete templates->ClassPruners[i]; Efree(templates); } -/*---------------------------------------------------------------------------*/ -// Code to read/write Classify::font*table structures. -namespace { -bool read_info(FILE* f, FontInfo* fi, bool swap) { - inT32 size; - if (fread(&size, sizeof(size), 1, f) != 1) return false; - if (swap) - Reverse32(&size); - char* font_name = new char[size + 1]; - fi->name = font_name; - if (fread(font_name, sizeof(*font_name), size, f) != size) return false; - font_name[size] = '\0'; - if (fread(&fi->properties, sizeof(fi->properties), 1, f) != 1) return false; - if (swap) - Reverse32(&fi->properties); - return true; -} - -bool write_info(FILE* f, const FontInfo& fi) { - inT32 size = strlen(fi.name); - if (fwrite(&size, sizeof(size), 1, f) != 1) return false; - if (fwrite(fi.name, sizeof(*fi.name), size, f) != size) return false; - if (fwrite(&fi.properties, sizeof(fi.properties), 1, f) != 1) return false; - return true; -} - -bool read_spacing_info(FILE *f, FontInfo* fi, bool swap) { - inT32 vec_size, kern_size; - if (fread(&vec_size, sizeof(vec_size), 1, f) != 1) return false; - if (swap) Reverse32(&vec_size); - ASSERT_HOST(vec_size >= 0); - if (vec_size == 0) return true; - fi->init_spacing(vec_size); - for (int i = 0; i < vec_size; ++i) { - FontSpacingInfo *fs = new FontSpacingInfo(); - if (fread(&fs->x_gap_before, sizeof(fs->x_gap_before), 1, f) != 1 || - fread(&fs->x_gap_after, sizeof(fs->x_gap_after), 1, f) != 1 || - fread(&kern_size, sizeof(kern_size), 1, f) != 1) { - return false; - } - if (swap) { - ReverseN(&(fs->x_gap_before), sizeof(fs->x_gap_before)); - ReverseN(&(fs->x_gap_after), sizeof(fs->x_gap_after)); - Reverse32(&kern_size); - } - if (kern_size < 0) { // indication of a NULL entry in fi->spacing_vec - delete fs; - continue; - } - if (kern_size > 0 && (!fs->kerned_unichar_ids.DeSerialize(swap, f) || - !fs->kerned_x_gaps.DeSerialize(swap, f))) { - return false; - } - fi->add_spacing(i, fs); - } - return true; -} - -bool write_spacing_info(FILE* f, const FontInfo& fi) { - inT32 vec_size = (fi.spacing_vec == NULL) ? 0 : fi.spacing_vec->size(); - if (fwrite(&vec_size, sizeof(vec_size), 1, f) != 1) return false; - inT16 x_gap_invalid = -1; - for (int i = 0; i < vec_size; ++i) { - FontSpacingInfo *fs = fi.spacing_vec->get(i); - inT32 kern_size = (fs == NULL) ? -1 : fs->kerned_x_gaps.size(); - if (fs == NULL) { - if (fwrite(&(x_gap_invalid), sizeof(x_gap_invalid), 1, f) != 1 || - fwrite(&(x_gap_invalid), sizeof(x_gap_invalid), 1, f) != 1 || - fwrite(&kern_size, sizeof(kern_size), 1, f) != 1) { - return false; - } - } else { - if (fwrite(&(fs->x_gap_before), sizeof(fs->x_gap_before), 1, f) != 1 || - fwrite(&(fs->x_gap_after), sizeof(fs->x_gap_after), 1, f) != 1 || - fwrite(&kern_size, sizeof(kern_size), 1, f) != 1) { - return false; - } - } - if (kern_size > 0 && (!fs->kerned_unichar_ids.Serialize(f) || - !fs->kerned_x_gaps.Serialize(f))) { - return false; - } - } - return true; -} - -bool read_set(FILE* f, FontSet* fs, bool swap) { - if (fread(&fs->size, sizeof(fs->size), 1, f) != 1) return false; - if (swap) - Reverse32(&fs->size); - fs->configs = new int[fs->size]; - for (int i = 0; i < fs->size; ++i) { - if (fread(&fs->configs[i], sizeof(fs->configs[i]), 1, f) != 1) return false; - if (swap) - Reverse32(&fs->configs[i]); - } - return true; -} - -bool write_set(FILE* f, const FontSet& fs) { - if (fwrite(&fs.size, sizeof(fs.size), 1, f) != 1) return false; - for (int i = 0; i < fs.size; ++i) { - if (fwrite(&fs.configs[i], sizeof(fs.configs[i]), 1, f) != 1) return false; - } - return true; -} - -} // namespace. - namespace tesseract { INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { /* @@ -909,7 +801,7 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { int unicharset_size; int version_id = 0; INT_TEMPLATES Templates; - CLASS_PRUNER Pruner; + CLASS_PRUNER_STRUCT* Pruner; INT_CLASS Class; uinT8 *Lengths; PROTO_SET ProtoSet; @@ -919,11 +811,11 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { CLASS_ID class_id, max_class_id; inT16 *IndexFor = new inT16[MAX_NUM_CLASSES]; CLASS_ID *ClassIdFor = new CLASS_ID[MAX_NUM_CLASSES]; - CLASS_PRUNER *TempClassPruner = new CLASS_PRUNER[MAX_NUM_CLASS_PRUNERS]; + CLASS_PRUNER_STRUCT **TempClassPruner = + new CLASS_PRUNER_STRUCT*[MAX_NUM_CLASS_PRUNERS]; uinT32 SetBitsForMask = // word with NUM_BITS_PER_CLASS (1 << NUM_BITS_PER_CLASS) - 1; // set starting at bit 0 uinT32 Mask, NewMask, ClassBits; - uinT32 *Word; int MaxNumConfigs = MAX_NUM_CONFIGS; int WerdsPerConfigVec = WERDS_PER_CONFIG_VEC; @@ -979,9 +871,9 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { /* then read in the class pruners */ for (i = 0; i < Templates->NumClassPruners; i++) { - Pruner = (CLASS_PRUNER) Emalloc(sizeof(CLASS_PRUNER_STRUCT)); + Pruner = new CLASS_PRUNER_STRUCT; if ((nread = - fread((char *) Pruner, 1, sizeof(CLASS_PRUNER_STRUCT), + fread(Pruner, 1, sizeof(CLASS_PRUNER_STRUCT), File)) != sizeof(CLASS_PRUNER_STRUCT)) cprintf("Bad read of inttemp!\n"); if (swap) { @@ -989,7 +881,7 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { for (y = 0; y < NUM_CP_BUCKETS; y++) { for (z = 0; z < NUM_CP_BUCKETS; z++) { for (w = 0; w < WERDS_PER_CP_VECTOR; w++) { - Reverse32(&Pruner[x][y][z][w]); + Reverse32(&Pruner->p[x][y][z][w]); } } } @@ -998,7 +890,7 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { if (version_id < 2) { TempClassPruner[i] = Pruner; } else { - Templates->ClassPruner[i] = Pruner; + Templates->ClassPruners[i] = Pruner; } } @@ -1010,11 +902,8 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { if (ClassIdFor[i] > max_class_id) max_class_id = ClassIdFor[i]; for (i = 0; i <= CPrunerIdFor(max_class_id); i++) { - Templates->ClassPruner[i] = - (CLASS_PRUNER) Emalloc(sizeof(CLASS_PRUNER_STRUCT)); - for (Word = (uinT32 *) (Templates->ClassPruner[i]); - Word < (uinT32 *) (Templates->ClassPruner[i]) + WERDS_PER_CP; - *Word++ = 0); + Templates->ClassPruners[i] = new CLASS_PRUNER_STRUCT; + memset(Templates->ClassPruners[i], 0, sizeof(CLASS_PRUNER_STRUCT)); } // Convert class pruners from the old format (indexed by class index) // to the new format (indexed by class id). @@ -1024,7 +913,7 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { for (y = 0; y < NUM_CP_BUCKETS; y++) for (z = 0; z < NUM_CP_BUCKETS; z++) for (w = 0; w < WERDS_PER_CP_VECTOR; w++) { - if (TempClassPruner[i][x][y][z][w] == 0) + if (TempClassPruner[i]->p[x][y][z][w] == 0) continue; for (b = 0; b < BITS_PER_WERD; b += NUM_BITS_PER_CLASS) { bit_number = i * BITS_PER_CP_VECTOR + w * BITS_PER_WERD + b; @@ -1033,7 +922,7 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { class_id = ClassIdFor[bit_number / NUM_BITS_PER_CLASS]; // Single out NUM_BITS_PER_CLASS bits relating to class_id. Mask = SetBitsForMask << b; - ClassBits = TempClassPruner[i][x][y][z][w] & Mask; + ClassBits = TempClassPruner[i]->p[x][y][z][w] & Mask; // Move these bits to the new position in which they should // appear (indexed corresponding to the class_id). new_i = CPrunerIdFor(class_id); @@ -1047,13 +936,13 @@ INT_TEMPLATES Classify::ReadIntTemplates(FILE *File) { // Copy bits relating to class_id to the correct position // in Templates->ClassPruner. NewMask = SetBitsForMask << new_b; - Templates->ClassPruner[new_i][x][y][z][new_w] &= ~NewMask; - Templates->ClassPruner[new_i][x][y][z][new_w] |= ClassBits; + Templates->ClassPruners[new_i]->p[x][y][z][new_w] &= ~NewMask; + Templates->ClassPruners[new_i]->p[x][y][z][new_w] |= ClassBits; } } } for (i = 0; i < Templates->NumClassPruners; i++) { - Efree (TempClassPruner[i]); + delete TempClassPruner[i]; } } @@ -1217,7 +1106,6 @@ void Classify::ShowMatchDisplay() { ** History: Thu Mar 21 15:47:33 1991, DSJ, Created. */ InitIntMatchWindowIfReqd(); - c_clear_window(IntMatchWindow); if (ProtoDisplayWindow) { ProtoDisplayWindow->Clear(); } @@ -1227,7 +1115,6 @@ void Classify::ShowMatchDisplay() { ClearFeatureSpaceWindow( static_cast(static_cast(classify_norm_method)), IntMatchWindow); - IntMatchWindow->ZoomToRectangle(INT_MIN_X, INT_MIN_Y, INT_MAX_X, INT_MAX_Y); if (ProtoDisplayWindow) { @@ -1299,7 +1186,7 @@ void Classify::WriteIntTemplates(FILE *File, INT_TEMPLATES Templates, /* then write out the class pruners */ for (i = 0; i < Templates->NumClassPruners; i++) - fwrite(Templates->ClassPruner[i], + fwrite(Templates->ClassPruners[i], sizeof(CLASS_PRUNER_STRUCT), 1, File); /* then write out each class */ @@ -1385,7 +1272,7 @@ FLOAT32 BucketEnd(int Bucket, FLOAT32 Offset, int NumBuckets) { /*---------------------------------------------------------------------------*/ void DoFill(FILL_SPEC *FillSpec, - CLASS_PRUNER Pruner, + CLASS_PRUNER_STRUCT* Pruner, register uinT32 ClassMask, register uinT32 ClassCount, register uinT32 WordIndex) { @@ -1421,11 +1308,11 @@ void DoFill(FILL_SPEC *FillSpec, for (Y = FillSpec->YStart; Y <= FillSpec->YEnd; Y++) for (Angle = FillSpec->AngleStart; TRUE; CircularIncrement (Angle, NUM_CP_BUCKETS)) { - OldWord = Pruner[X][Y][Angle][WordIndex]; + OldWord = Pruner->p[X][Y][Angle][WordIndex]; if (ClassCount > (OldWord & ClassMask)) { OldWord &= ~ClassMask; OldWord |= ClassCount; - Pruner[X][Y][Angle][WordIndex] = OldWord; + Pruner->p[X][Y][Angle][WordIndex] = OldWord; } if (Angle == FillSpec->AngleEnd) break; @@ -1543,7 +1430,7 @@ void FillPPLinearBits(uinT32 ParamTable[NUM_PP_BUCKETS][WERDS_PER_PP_VECTOR], #ifndef GRAPHICS_DISABLED namespace tesseract { CLASS_ID Classify::GetClassToDebug(const char *Prompt, bool* adaptive_on, - bool* pretrained_on) { + bool* pretrained_on, int* shape_id) { /* ** Parameters: ** Prompt prompt to print while waiting for input from window @@ -1557,26 +1444,57 @@ CLASS_ID Classify::GetClassToDebug(const char *Prompt, bool* adaptive_on, tprintf("%s\n", Prompt); SVEvent* ev; SVEventType ev_type; + int unichar_id = INVALID_UNICHAR_ID; // Wait until a click or popup event. do { ev = IntMatchWindow->AwaitEvent(SVET_ANY); ev_type = ev->type; if (ev_type == SVET_POPUP) { - if (unicharset.contains_unichar(ev->parameter)) { - if (ev->command_id == IDA_ADAPTIVE) { - *adaptive_on = true; - *pretrained_on = false; - } else if (ev->command_id == IDA_STATIC) { + if (ev->command_id == IDA_SHAPE_INDEX) { + if (shape_table_ != NULL) { + *shape_id = atoi(ev->parameter); *adaptive_on = false; *pretrained_on = true; + if (*shape_id >= 0 && *shape_id < shape_table_->NumShapes()) { + int font_id; + shape_table_->GetFirstUnicharAndFont(*shape_id, &unichar_id, + &font_id); + tprintf("Shape %d, first unichar=%d, font=%d\n", + *shape_id, unichar_id, font_id); + return unichar_id; + } + tprintf("Shape index '%s' not found in shape table\n", ev->parameter); } else { - *adaptive_on = true; - *pretrained_on = true; + tprintf("No shape table loaded!\n"); + } + } else { + if (unicharset.contains_unichar(ev->parameter)) { + unichar_id = unicharset.unichar_to_id(ev->parameter); + if (ev->command_id == IDA_ADAPTIVE) { + *adaptive_on = true; + *pretrained_on = false; + *shape_id = -1; + } else if (ev->command_id == IDA_STATIC) { + *adaptive_on = false; + *pretrained_on = true; + } else { + *adaptive_on = true; + *pretrained_on = true; + } + if (ev->command_id == IDA_ADAPTIVE || shape_table_ == NULL) { + *shape_id = -1; + return unichar_id; + } + for (int s = 0; s < shape_table_->NumShapes(); ++s) { + if (shape_table_->GetShape(s).ContainsUnichar(unichar_id)) { + tprintf("%s\n", shape_table_->DebugStr(s).string()); + } + } + } else { + tprintf("Char class '%s' not found in unicharset", + ev->parameter); } - return unicharset.unichar_to_id(ev->parameter); } - tprintf("Char class '%s' not found in unicharset", - ev->parameter); } delete ev; } while (ev_type != SVET_CLICK); @@ -1916,15 +1834,8 @@ void RenderIntFeature(ScrollView *window, const INT_FEATURE_STRUCT* Feature, // using BinaryAnglePlusPi in intfx.cpp. Dx = (Length / 2.0) * cos((Feature->Theta / 256.0) * 2.0 * PI - PI); Dy = (Length / 2.0) * sin((Feature->Theta / 256.0) * 2.0 * PI - PI); - float x_offset = Dy / 4.0; - float y_offset = -Dx / 4.0; - window->SetCursor(X - Dx, Y - Dy); - window->DrawTo(X + Dx, Y + Dy); - // Draw another copy of the feature offset perpendicualar to its direction. - X += x_offset; - Y += y_offset; - window->SetCursor(X - Dx, Y - Dy); + window->SetCursor(X, Y); window->DrawTo(X + Dx, Y + Dy); } /* RenderIntFeature */ @@ -2047,6 +1958,8 @@ void InitIntMatchWindowIfReqd() { "x", "Class to debug"); popup_menu->AddChild("Debug Both", IDA_BOTH, "x", "Class to debug"); + popup_menu->AddChild("Debug Shape Index", IDA_SHAPE_INDEX, + "0", "Index to debug"); popup_menu->BuildMenu(IntMatchWindow, false); } } diff --git a/classify/intproto.h b/classify/intproto.h index 854045799d..80bf8ac5ef 100644 --- a/classify/intproto.h +++ b/classify/intproto.h @@ -25,7 +25,6 @@ #include "matchdefs.h" #include "mfoutline.h" #include "protos.h" -#include "callcpp.h" #include "scrollview.h" #include "unicharset.h" @@ -72,11 +71,9 @@ * The position of the the bits recorded for each class in the * 4th dimension is determined by using CPrunerWordIndexFor(c), * where c is the corresponding class id. */ -typedef uinT32 CLASS_PRUNER_STRUCT -[NUM_CP_BUCKETS][NUM_CP_BUCKETS][NUM_CP_BUCKETS][WERDS_PER_CP_VECTOR]; - -typedef -uinT32 (*CLASS_PRUNER)[NUM_CP_BUCKETS][NUM_CP_BUCKETS][WERDS_PER_CP_VECTOR]; +struct CLASS_PRUNER_STRUCT { + uinT32 p[NUM_CP_BUCKETS][NUM_CP_BUCKETS][NUM_CP_BUCKETS][WERDS_PER_CP_VECTOR]; +}; typedef struct { @@ -103,86 +100,6 @@ PROTO_SET_STRUCT, *PROTO_SET; typedef uinT32 CONFIG_PRUNER[NUM_PP_PARAMS][NUM_PP_BUCKETS][4]; -// Struct for information about spacing between characters in a particular font. -struct FontSpacingInfo { - inT16 x_gap_before; - inT16 x_gap_after; - GenericVector kerned_unichar_ids; - GenericVector kerned_x_gaps; -}; - -/* - * font_properties contains properties about boldness, italicness, fixed pitch, - * serif, fraktur - */ -struct FontInfo { - FontInfo() : name(NULL), spacing_vec(NULL) {} - ~FontInfo() {} - // Reserves unicharset_size spots in spacing_vec. - void init_spacing(int unicharset_size) { - spacing_vec = new GenericVector(); - spacing_vec->init_to_size(unicharset_size, NULL); - } - // Adds the given pointer to FontSpacingInfo to spacing_vec member - // (FontInfo class takes ownership of the pointer). - // Note: init_spacing should be called before calling this function. - void add_spacing(UNICHAR_ID uch_id, FontSpacingInfo *spacing_info) { - ASSERT_HOST(spacing_vec != NULL && spacing_vec->size() > uch_id); - (*spacing_vec)[uch_id] = spacing_info; - } - - // Returns the pointer to FontSpacingInfo for the given UNICHAR_ID. - const FontSpacingInfo *get_spacing(UNICHAR_ID uch_id) const { - return (spacing_vec == NULL || spacing_vec->size() <= uch_id) ? - NULL : (*spacing_vec)[uch_id]; - } - - // Fills spacing with the value of the x gap expected between the two given - // UNICHAR_IDs. Returns true on success. - bool get_spacing(UNICHAR_ID prev_uch_id, - UNICHAR_ID uch_id, - int *spacing) const { - const FontSpacingInfo *prev_fsi = this->get_spacing(prev_uch_id); - const FontSpacingInfo *fsi = this->get_spacing(uch_id); - if (prev_fsi == NULL || fsi == NULL) return false; - int i = 0; - for (; i < prev_fsi->kerned_unichar_ids.size(); ++i) { - if (prev_fsi->kerned_unichar_ids[i] == uch_id) break; - } - if (i < prev_fsi->kerned_unichar_ids.size()) { - *spacing = prev_fsi->kerned_x_gaps[i]; - } else { - *spacing = prev_fsi->x_gap_after + fsi->x_gap_before; - } - return true; - } - - bool is_italic() const { return properties & 1; } - bool is_bold() const { return (properties & 2) != 0; } - bool is_fixed_pitch() const { return (properties & 4) != 0; } - bool is_serif() const { return (properties & 8) != 0; } - bool is_fraktur() const { return (properties & 16) != 0; } - - char* name; - uinT32 properties; - // Horizontal spacing between characters (indexed by UNICHAR_ID). - GenericVector *spacing_vec; -}; - -// Every class (character) owns a FontSet that represents all the fonts that can -// render this character. -// Since almost all the characters from the same script share the same set of -// fonts, the sets are shared over multiple classes (see -// Classify::fontset_table_). Thus, a class only store an id to a set. -// Because some fonts cannot render just one character of a set, there are a -// lot of FontSet that differ only by one font. Rather than storing directly -// the FontInfo in the FontSet structure, it's better to share FontInfos among -// FontSets (Classify::fontinfo_table_). -struct FontSet { - int size; - int* configs; // FontInfo ids -}; - typedef struct { @@ -203,7 +120,7 @@ typedef struct int NumClasses; int NumClassPruners; INT_CLASS Class[MAX_NUM_CLASSES]; - CLASS_PRUNER ClassPruner[MAX_NUM_CLASS_PRUNERS]; + CLASS_PRUNER_STRUCT* ClassPruners[MAX_NUM_CLASS_PRUNERS]; } @@ -232,6 +149,7 @@ typedef INT_FEATURE_STRUCT INT_FEATURE_ARRAY[MAX_NUM_INT_FEATURES]; enum IntmatcherDebugAction { IDA_ADAPTIVE, IDA_STATIC, + IDA_SHAPE_INDEX, IDA_BOTH }; @@ -255,7 +173,7 @@ enum IntmatcherDebugAction { #define ClassForClassId(T,c) ((T)->Class[c]) #define ClassPrunersFor(T) ((T)->ClassPruner) #define CPrunerIdFor(c) ((c) / CLASSES_PER_CP) -#define CPrunerFor(T,c) ((T)->ClassPruner [CPrunerIdFor (c)]) +#define CPrunerFor(T,c) ((T)->ClassPruners[CPrunerIdFor(c)]) #define CPrunerWordIndexFor(c) (((c) % CLASSES_PER_CP) / CLASSES_PER_CP_WERD) #define CPrunerBitIndexFor(c) (((c) % CLASSES_PER_CP) % CLASSES_PER_CP_WERD) #define CPrunerMaskFor(L,c) (((L)+1) << CPrunerBitIndexFor (c) * NUM_BITS_PER_CLASS) @@ -300,7 +218,7 @@ void UpdateMatchDisplay(); void ConvertConfig(BIT_VECTOR Config, int ConfigId, INT_CLASS Class); -void DisplayIntFeature(INT_FEATURE Feature, FLOAT32 Evidence); +void DisplayIntFeature(const INT_FEATURE_STRUCT* Feature, FLOAT32 Evidence); void DisplayIntProto(INT_CLASS Class, PROTO_ID ProtoId, FLOAT32 Evidence); diff --git a/classify/mastertrainer.cpp b/classify/mastertrainer.cpp new file mode 100644 index 0000000000..bb272fde26 --- /dev/null +++ b/classify/mastertrainer.cpp @@ -0,0 +1,967 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: mastertrainer.cpp +// Description: Trainer to build the MasterClassifier. +// Author: Ray Smith +// Created: Wed Nov 03 18:10:01 PDT 2010 +// +// (C) Copyright 2010, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "mastertrainer.h" +#include +#include +#include "allheaders.h" +#include "boxread.h" +#include "classify.h" +#include "errorcounter.h" +#include "featdefs.h" +#include "sampleiterator.h" +#include "shapeclassifier.h" +#include "shapetable.h" +#include "svmnode.h" + +namespace tesseract { + +// Constants controlling clustering. With a low kMinClusteredShapes and a high +// kMaxUnicharsPerCluster, then kFontMergeDistance is the only limiting factor. +// Min number of shapes in the output. +const int kMinClusteredShapes = 1; +// Max number of unichars in any individual cluster. +const int kMaxUnicharsPerCluster = 2000; +// Mean font distance below which to merge fonts and unichars. +const float kFontMergeDistance = 0.025; + +MasterTrainer::MasterTrainer(NormalizationMode norm_mode, + bool shape_analysis, + bool replicate_samples, + int debug_level) + : norm_mode_(norm_mode), samples_(fontinfo_table_), + junk_samples_(fontinfo_table_), verify_samples_(fontinfo_table_), + charsetsize_(0), + enable_shape_anaylsis_(shape_analysis), + enable_replication_(replicate_samples), + fragments_(NULL), prev_unichar_id_(-1), debug_level_(debug_level) { + fontinfo_table_.set_compare_callback( + NewPermanentTessCallback(CompareFontInfo)); + fontinfo_table_.set_clear_callback( + NewPermanentTessCallback(FontInfoDeleteCallback)); +} + +MasterTrainer::~MasterTrainer() { + delete [] fragments_; + for (int p = 0; p < page_images_.size(); ++p) + pixDestroy(&page_images_[p]); +} + +// WARNING! Serialize/DeSerialize are only partial, providing +// enough data to get the samples back and display them. +// Writes to the given file. Returns false in case of error. +bool MasterTrainer::Serialize(FILE* fp) const { + if (fwrite(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false; + if (!unicharset_.save_to_file(fp)) return false; + if (!feature_space_.Serialize(fp)) return false; + if (!samples_.Serialize(fp)) return false; + if (!junk_samples_.Serialize(fp)) return false; + if (!verify_samples_.Serialize(fp)) return false; + if (!master_shapes_.Serialize(fp)) return false; + if (!flat_shapes_.Serialize(fp)) return false; + if (!fontinfo_table_.write(fp, NewPermanentTessCallback(write_info))) + return false; + if (!fontinfo_table_.write(fp, NewPermanentTessCallback(write_spacing_info))) + return false; + if (!xheights_.Serialize(fp)) return false; + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool MasterTrainer::DeSerialize(bool swap, FILE* fp) { + if (fread(&norm_mode_, sizeof(norm_mode_), 1, fp) != 1) return false; + if (swap) { + ReverseN(&norm_mode_, sizeof(norm_mode_)); + } + if (!unicharset_.load_from_file(fp)) return false; + charsetsize_ = unicharset_.size(); + if (!feature_space_.DeSerialize(swap, fp)) return false; + feature_map_.Init(feature_space_); + if (!samples_.DeSerialize(swap, fp)) return false; + if (!junk_samples_.DeSerialize(swap, fp)) return false; + if (!verify_samples_.DeSerialize(swap, fp)) return false; + if (!master_shapes_.DeSerialize(swap, fp)) return false; + if (!flat_shapes_.DeSerialize(swap, fp)) return false; + if (!fontinfo_table_.read(fp, NewPermanentTessCallback(read_info), swap)) + return false; + if (!fontinfo_table_.read(fp, NewPermanentTessCallback(read_spacing_info), + swap)) + return false; + if (!xheights_.DeSerialize(swap, fp)) return false; + return true; +} + +// Load an initial unicharset, or set one up if the file cannot be read. +void MasterTrainer::LoadUnicharset(const char* filename) { + if (!unicharset_.load_from_file(filename)) { + tprintf("Failed to load unicharset from file %s\n" + "Building unicharset for training from scratch...\n", + filename); + unicharset_.clear(); + // Space character needed to represent NIL_LIST classification. + unicharset_.unichar_insert(" "); + } + charsetsize_ = unicharset_.size(); + delete [] fragments_; + fragments_ = new int[charsetsize_]; + memset(fragments_, 0, sizeof(*fragments_) * charsetsize_); + samples_.LoadUnicharset(filename); + junk_samples_.LoadUnicharset(filename); + verify_samples_.LoadUnicharset(filename); +} + +// Reads the samples and their features from the given .tr format file, +// adding them to the trainer with the font_id from the content of the file. +// See mftraining.cpp for a description of the file format. +// If verification, then these are verification samples, not training. +void MasterTrainer::ReadTrainingSamples(FILE *fp, + const FEATURE_DEFS_STRUCT& feature_defs, + bool verification) { + char buffer[2048]; + int int_feature_type = ShortNameToFeatureType(feature_defs, kIntFeatureType); + int micro_feature_type = ShortNameToFeatureType(feature_defs, + kMicroFeatureType); + int cn_feature_type = ShortNameToFeatureType(feature_defs, kCNFeatureType); + int geo_feature_type = ShortNameToFeatureType(feature_defs, kGeoFeatureType); + + while (fgets(buffer, sizeof(buffer), fp) != NULL) { + if (buffer[0] == '\n') + continue; + + char* space = strchr(buffer, ' '); + if (space == NULL) { + tprintf("Bad format in tr file, reading fontname, unichar\n"); + continue; + } + *space++ = '\0'; + int font_id = GetFontInfoId(buffer); + int page_number; + STRING unichar; + TBOX bounding_box; + if (!ParseBoxFileStr(space, &page_number, &unichar, &bounding_box)) { + tprintf("Bad format in tr file, reading box coords\n"); + continue; + } + CHAR_DESC char_desc = ReadCharDescription(feature_defs, fp); + TrainingSample* sample = new TrainingSample; + sample->set_font_id(font_id); + sample->set_page_num(page_number + page_images_.size()); + sample->set_bounding_box(bounding_box); + sample->ExtractCharDesc(int_feature_type, micro_feature_type, + cn_feature_type, geo_feature_type, char_desc); + AddSample(verification, unichar.string(), sample); + FreeCharDescription(char_desc); + } + charsetsize_ = unicharset_.size(); +} + +// Adds the given single sample to the trainer, setting the classid +// appropriately from the given unichar_str. +void MasterTrainer::AddSample(bool verification, const char* unichar, + TrainingSample* sample) { + if (verification) { + verify_samples_.AddSample(unichar, sample); + prev_unichar_id_ = -1; + } else if (unicharset_.contains_unichar(unichar)) { + if (prev_unichar_id_ >= 0) + fragments_[prev_unichar_id_] = -1; + prev_unichar_id_ = samples_.AddSample(unichar, sample); + if (flat_shapes_.FindShape(prev_unichar_id_, sample->font_id()) < 0) + flat_shapes_.AddShape(prev_unichar_id_, sample->font_id()); + } else { + int junk_id = junk_samples_.AddSample(unichar, sample); + if (prev_unichar_id_ >= 0) { + CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(unichar); + if (frag != NULL && frag->is_natural()) { + if (fragments_[prev_unichar_id_] == 0) + fragments_[prev_unichar_id_] = junk_id; + else if (fragments_[prev_unichar_id_] != junk_id) + fragments_[prev_unichar_id_] = -1; + } + delete frag; + } + prev_unichar_id_ = -1; + } +} + +// Loads all pages from the given tif filename and append to page_images_. +// Must be called after ReadTrainingSamples, as the current number of images +// is used as an offset for page numbers in the samples. +void MasterTrainer::LoadPageImages(const char* filename) { + int page; + Pix* pix; + for (page = 0; (pix = pixReadTiff(filename, page)) != NULL; ++page) { + page_images_.push_back(pix); + } + tprintf("Loaded %d page images from %s\n", page, filename); +} + +// Cleans up the samples after initial load from the tr files, and prior to +// saving the MasterTrainer: +// Remaps fragmented chars if running shape anaylsis. +// Sets up the samples appropriately for class/fontwise access. +// Deletes outlier samples. +void MasterTrainer::PostLoadCleanup() { + if (debug_level_ > 0) + tprintf("PostLoadCleanup...\n"); + if (enable_shape_anaylsis_) + ReplaceFragmentedSamples(); + SampleIterator sample_it; + sample_it.Init(NULL, NULL, true, &verify_samples_); + sample_it.NormalizeSamples(); + verify_samples_.OrganizeByFontAndClass(); + + samples_.IndexFeatures(feature_space_); + // TODO(rays) DeleteOutliers is currently turned off to prove NOP-ness + // against current training. + // samples_.DeleteOutliers(feature_space_, debug_level_ > 0); + samples_.OrganizeByFontAndClass(); + if (debug_level_ > 0) + tprintf("ComputeCanonicalSamples...\n"); + samples_.ComputeCanonicalSamples(feature_map_, debug_level_ > 0); +} + +// Gets the samples ready for training. Use after both +// ReadTrainingSamples+PostLoadCleanup or DeSerialize. +// Re-indexes the features and computes canonical and cloud features. +void MasterTrainer::PreTrainingSetup() { + if (debug_level_ > 0) + tprintf("PreTrainingSetup...\n"); + samples_.IndexFeatures(feature_space_); + samples_.ComputeCanonicalFeatures(); + if (debug_level_ > 0) + tprintf("ComputeCloudFeatures...\n"); + samples_.ComputeCloudFeatures(feature_space_.Size()); +} + +// Sets up the master_shapes_ table, which tells which fonts should stay +// together until they get to a leaf node classifier. +void MasterTrainer::SetupMasterShapes() { + tprintf("Building master shape table\n"); + int num_fonts = samples_.NumFonts(); + + ShapeTable char_shapes_begin_fragment(samples_.unicharset()); + ShapeTable char_shapes_end_fragment(samples_.unicharset()); + ShapeTable char_shapes(samples_.unicharset()); + for (int c = 0; c < samples_.charsetsize(); ++c) { + ShapeTable shapes(samples_.unicharset()); + for (int f = 0; f < num_fonts; ++f) { + if (samples_.NumClassSamples(f, c, true) > 0) + shapes.AddShape(c, f); + } + ClusterShapes(kMinClusteredShapes, 1, kFontMergeDistance, &shapes); + + const CHAR_FRAGMENT *fragment = samples_.unicharset().get_fragment(c); + + if (fragment == NULL) + char_shapes.AppendMasterShapes(shapes); + else if (fragment->is_beginning()) + char_shapes_begin_fragment.AppendMasterShapes(shapes); + else if (fragment->is_ending()) + char_shapes_end_fragment.AppendMasterShapes(shapes); + else + char_shapes.AppendMasterShapes(shapes); + } + ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, + kFontMergeDistance, &char_shapes_begin_fragment); + char_shapes.AppendMasterShapes(char_shapes_begin_fragment); + ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, + kFontMergeDistance, &char_shapes_end_fragment); + char_shapes.AppendMasterShapes(char_shapes_end_fragment); + ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, + kFontMergeDistance, &char_shapes); + master_shapes_.AppendMasterShapes(char_shapes); + tprintf("Master shape_table:%s\n", master_shapes_.SummaryStr().string()); +} + +// Adds the junk_samples_ to the main samples_ set. Junk samples are initially +// fragments and n-grams (all incorrectly segmented characters). +// Various training functions may result in incorrectly segmented characters +// being added to the unicharset of the main samples, perhaps because they +// form a "radical" decomposition of some (Indic) grapheme, or because they +// just look the same as a real character (like rn/m) +// This function moves all the junk samples, to the main samples_ set, but +// desirable junk, being any sample for which the unichar already exists in +// the samples_ unicharset gets the unichar-ids re-indexed to match, but +// anything else gets re-marked as unichar_id 0 (space character) to identify +// it as junk to the error counter. +void MasterTrainer::IncludeJunk() { + // Get ids of fragments in junk_samples_ that replace the dead chars. + const UNICHARSET& junk_set = junk_samples_.unicharset(); + const UNICHARSET& sample_set = samples_.unicharset(); + int num_junks = junk_samples_.num_samples(); + tprintf("Moving %d junk samples to master sample set.\n", num_junks); + for (int s = 0; s < num_junks; ++s) { + TrainingSample* sample = junk_samples_.mutable_sample(s); + int junk_id = sample->class_id(); + const char* junk_utf8 = junk_set.id_to_unichar(junk_id); + int sample_id = sample_set.unichar_to_id(junk_utf8); + if (sample_id == INVALID_UNICHAR_ID) + sample_id = 0; + sample->set_class_id(sample_id); + junk_samples_.extract_sample(s); + samples_.AddSample(sample_id, sample); + } + junk_samples_.DeleteDeadSamples(); + samples_.OrganizeByFontAndClass(); +} + +// Replicates the samples and perturbs them if the enable_replication_ flag +// is set. MUST be used after the last call to OrganizeByFontAndClass on +// the training samples, ie after IncludeJunk if it is going to be used, as +// OrganizeByFontAndClass will eat the replicated samples into the regular +// samples. +void MasterTrainer::ReplicateAndRandomizeSamplesIfRequired() { + if (enable_replication_) { + if (debug_level_ > 0) + tprintf("ReplicateAndRandomize...\n"); + verify_samples_.ReplicateAndRandomizeSamples(); + samples_.ReplicateAndRandomizeSamples(); + samples_.IndexFeatures(feature_space_); + } +} + +// Loads the basic font properties file into fontinfo_table_. +// Returns false on failure. +bool MasterTrainer::LoadFontInfo(const char* filename) { + FILE* fp = fopen(filename, "rb"); + if (fp == NULL) { + fprintf(stderr, "Failed to load font_properties from %s\n", filename); + return false; + } + int italic, bold, fixed, serif, fraktur; + while (!feof(fp)) { + FontInfo fontinfo; + char* font_name = new char[1024]; + fontinfo.name = font_name; + fontinfo.properties = 0; + fontinfo.universal_id = 0; + if (fscanf(fp, "%1024s %i %i %i %i %i\n", font_name, + &italic, &bold, &fixed, &serif, &fraktur) != 6) + continue; + fontinfo.properties = + (italic << 0) + + (bold << 1) + + (fixed << 2) + + (serif << 3) + + (fraktur << 4); + if (!fontinfo_table_.contains(fontinfo)) { + fontinfo_table_.push_back(fontinfo); + } + } + fclose(fp); + return true; +} + +// Loads the xheight font properties file into xheights_. +// Returns false on failure. +bool MasterTrainer::LoadXHeights(const char* filename) { + tprintf("fontinfo table is of size %d\n", fontinfo_table_.size()); + xheights_.init_to_size(fontinfo_table_.size(), -1); + if (filename == NULL) return true; + FILE *f = fopen(filename, "rb"); + if (f == NULL) { + fprintf(stderr, "Failed to load font xheights from %s\n", filename); + return false; + } + tprintf("Reading x-heights from %s ...\n", filename); + FontInfo fontinfo; + fontinfo.properties = 0; // Not used to lookup in the table. + fontinfo.universal_id = 0; + char buffer[1024]; + int xht; + int total_xheight = 0; + int xheight_count = 0; + while (!feof(f)) { + if (fscanf(f, "%1024s %d\n", buffer, &xht) != 2) + continue; + fontinfo.name = buffer; + if (!fontinfo_table_.contains(fontinfo)) continue; + int fontinfo_id = fontinfo_table_.get_id(fontinfo); + xheights_[fontinfo_id] = xht; + total_xheight += xht; + ++xheight_count; + } + if (xheight_count == 0) { + fprintf(stderr, "No valid xheights in %s!\n", filename); + return false; + } + int mean_xheight = DivRounded(total_xheight, xheight_count); + for (int i = 0; i < fontinfo_table_.size(); ++i) { + if (xheights_[i] < 0) + xheights_[i] = mean_xheight; + } + return true; +} // LoadXHeights + +// Reads spacing stats from filename and adds them to fontinfo_table. +bool MasterTrainer::AddSpacingInfo(const char *filename) { + FILE* fontinfo_file = fopen(filename, "rb"); + if (fontinfo_file == NULL) + return true; // We silently ignore missing files! + // Find the fontinfo_id. + int fontinfo_id = GetBestMatchingFontInfoId(filename); + if (fontinfo_id < 0) { + tprintf("No font found matching fontinfo filename %s\n", filename); + fclose(fontinfo_file); + return false; + } + tprintf("Reading spacing from %s for font %d...\n", filename, fontinfo_id); + // TODO(rays) scale should probably be a double, but keep as an int for now + // to duplicate current behavior. + int scale = kBlnXHeight / xheights_[fontinfo_id]; + int num_unichars; + char uch[UNICHAR_LEN]; + char kerned_uch[UNICHAR_LEN]; + int x_gap, x_gap_before, x_gap_after, num_kerned; + ASSERT_HOST(fscanf(fontinfo_file, "%d\n", &num_unichars) == 1); + FontInfo *fi = fontinfo_table_.get_mutable(fontinfo_id); + fi->init_spacing(unicharset_.size()); + FontSpacingInfo *spacing = NULL; + for (int l = 0; l < num_unichars; ++l) { + if (fscanf(fontinfo_file, "%s %d %d %d", + uch, &x_gap_before, &x_gap_after, &num_kerned) != 4) { + tprintf("Bad format of font spacing file %s\n", filename); + fclose(fontinfo_file); + return false; + } + bool valid = unicharset_.contains_unichar(uch); + if (valid) { + spacing = new FontSpacingInfo(); + spacing->x_gap_before = static_cast(x_gap_before * scale); + spacing->x_gap_after = static_cast(x_gap_after * scale); + } + for (int k = 0; k < num_kerned; ++k) { + if (fscanf(fontinfo_file, "%s %d", kerned_uch, &x_gap) != 2) { + tprintf("Bad format of font spacing file %s\n", filename); + fclose(fontinfo_file); + return false; + } + if (!valid || !unicharset_.contains_unichar(kerned_uch)) continue; + spacing->kerned_unichar_ids.push_back( + unicharset_.unichar_to_id(kerned_uch)); + spacing->kerned_x_gaps.push_back(static_cast(x_gap * scale)); + } + if (valid) fi->add_spacing(unicharset_.unichar_to_id(uch), spacing); + } + fclose(fontinfo_file); + return true; +} + +// Returns the font id corresponding to the given font name. +// Returns -1 if the font cannot be found. +int MasterTrainer::GetFontInfoId(const char* font_name) { + FontInfo fontinfo; + // We are only borrowing the string, so it is OK to const cast it. + fontinfo.name = const_cast(font_name); + fontinfo.properties = 0; // Not used to lookup in the table + fontinfo.universal_id = 0; + if (!fontinfo_table_.contains(fontinfo)) { + return -1; + } else { + return fontinfo_table_.get_id(fontinfo); + } +} +// Returns the font_id of the closest matching font name to the given +// filename. It is assumed that a substring of the filename will match +// one of the fonts. If more than one is matched, the longest is returned. +int MasterTrainer::GetBestMatchingFontInfoId(const char* filename) { + int fontinfo_id = -1; + int best_len = 0; + for (int f = 0; f < fontinfo_table_.size(); ++f) { + if (strstr(filename, fontinfo_table_.get(f).name) != NULL) { + int len = strlen(fontinfo_table_.get(f).name); + // Use the longest matching length in case a substring of a font matched. + if (len > best_len) { + best_len = len; + fontinfo_id = f; + } + } + } + return fontinfo_id; +} + +// Sets up a flat shapetable with one shape per class/font combination. +void MasterTrainer::SetupFlatShapeTable(ShapeTable* shape_table) { + // To exactly mimic the results of the previous implementation, the shapes + // must be clustered in order the fonts arrived, and reverse order of the + // characters within each font. + // Get a list of the fonts in the order they appeared. + GenericVector active_fonts; + int num_shapes = flat_shapes_.NumShapes(); + for (int s = 0; s < num_shapes; ++s) { + int font = flat_shapes_.GetShape(s)[0].font_ids[0]; + int f = 0; + for (f = 0; f < active_fonts.size(); ++f) { + if (active_fonts[f] == font) + break; + } + if (f == active_fonts.size()) + active_fonts.push_back(font); + } + // For each font in order, add all the shapes with that font in reverse order. + int num_fonts = active_fonts.size(); + for (int f = 0; f < num_fonts; ++f) { + for (int s = num_shapes - 1; s >= 0; --s) { + int font = flat_shapes_.GetShape(s)[0].font_ids[0]; + if (font == active_fonts[f]) { + shape_table->AddShape(flat_shapes_.GetShape(s)); + } + } + } +} + +// Sets up a Clusterer for mftraining on a single shape_id. +// Call FreeClusterer on the return value after use. +CLUSTERER* MasterTrainer::SetupForClustering( + const ShapeTable& shape_table, + const FEATURE_DEFS_STRUCT& feature_defs, + int shape_id, + int* num_samples) { + + int desc_index = ShortNameToFeatureType(feature_defs, kMicroFeatureType); + int num_params = feature_defs.FeatureDesc[desc_index]->NumParams; + ASSERT_HOST(num_params == MFCount); + CLUSTERER* clusterer = MakeClusterer( + num_params, feature_defs.FeatureDesc[desc_index]->ParamDesc); + + // We want to iterate over the samples of just the one shape. + IndexMapBiDi shape_map; + shape_map.Init(shape_table.NumShapes(), false); + shape_map.SetMap(shape_id, true); + shape_map.Setup(); + // Reverse the order of the samples to match the previous behavior. + GenericVector sample_ptrs; + SampleIterator it; + it.Init(&shape_map, &shape_table, false, &samples_); + for (it.Begin(); !it.AtEnd(); it.Next()) { + sample_ptrs.push_back(&it.GetSample()); + } + int sample_id = 0; + for (int i = sample_ptrs.size() - 1; i >= 0; --i) { + const TrainingSample* sample = sample_ptrs[i]; + int num_features = sample->num_micro_features(); + for (int f = 0; f < num_features; ++f) + MakeSample(clusterer, sample->micro_features()[f], sample_id); + ++sample_id; + } + *num_samples = sample_id; + return clusterer; +} + +// Writes the given float_classes (produced by SetupForFloat2Int) as inttemp +// to the given inttemp_file, and the corresponding pffmtable. +// The unicharset is the original encoding of graphemes, and shape_set should +// match the size of the shape_table, and may possibly be totally fake. +void MasterTrainer::WriteInttempAndPFFMTable(const UNICHARSET& unicharset, + const UNICHARSET& shape_set, + const ShapeTable& shape_table, + CLASS_STRUCT* float_classes, + const char* inttemp_file, + const char* pffmtable_file) { + tesseract::Classify *classify = new tesseract::Classify(); + // Move the fontinfo table to classify. + classify->get_fontinfo_table().move(&fontinfo_table_); + INT_TEMPLATES int_templates = classify->CreateIntTemplates(float_classes, + shape_set); + FILE* fp = fopen(inttemp_file, "wb"); + classify->WriteIntTemplates(fp, int_templates, shape_set); + fclose(fp); + // Now write pffmtable. This is complicated by the fact that the adaptive + // classifier still wants one indexed by unichar-id, but the static + // classifier needs one indexed by its shape class id. + // We put the shapetable_cutoffs in a GenericVector, and compute the + // unicharset cutoffs along the way. + GenericVector shapetable_cutoffs; + GenericVector unichar_cutoffs; + for (int c = 0; c < unicharset.size(); ++c) + unichar_cutoffs.push_back(0); + /* then write out each class */ + for (int i = 0; i < int_templates->NumClasses; ++i) { + INT_CLASS Class = ClassForClassId(int_templates, i); + // Todo: Test with min instead of max + // int MaxLength = LengthForConfigId(Class, 0); + uinT16 max_length = 0; + for (int config_id = 0; config_id < Class->NumConfigs; config_id++) { + // Todo: Test with min instead of max + // if (LengthForConfigId (Class, config_id) < MaxLength) + uinT16 length = Class->ConfigLengths[config_id]; + if (length > max_length) + max_length = Class->ConfigLengths[config_id]; + int shape_id = float_classes[i].font_set.get(config_id); + const Shape& shape = shape_table.GetShape(shape_id); + for (int c = 0; c < shape.size(); ++c) { + int unichar_id = shape[c].unichar_id; + if (length > unichar_cutoffs[unichar_id]) + unichar_cutoffs[unichar_id] = length; + } + } + shapetable_cutoffs.push_back(max_length); + } + fp = fopen(pffmtable_file, "wb"); + shapetable_cutoffs.Serialize(fp); + for (int c = 0; c < unicharset.size(); ++c) { + const char *unichar = unicharset.id_to_unichar(c); + if (strcmp(unichar, " ") == 0) { + unichar = "NULL"; + } + fprintf(fp, "%s %d\n", unichar, unichar_cutoffs[c]); + } + fclose(fp); + free_int_templates(int_templates); +} + +// Generate debug output relating to the canonical distance between the +// two given UTF8 grapheme strings. +void MasterTrainer::DebugCanonical(const char* unichar_str1, + const char* unichar_str2) { + int class_id1 = unicharset_.unichar_to_id(unichar_str1); + int class_id2 = unicharset_.unichar_to_id(unichar_str2); + if (class_id2 == INVALID_UNICHAR_ID) + class_id2 = class_id1; + if (class_id1 == INVALID_UNICHAR_ID) { + tprintf("No unicharset entry found for %s\n", unichar_str1); + return; + } else { + tprintf("Font ambiguities for unichar %d = %s and %d = %s\n", + class_id1, unichar_str1, class_id2, unichar_str2); + } + int num_fonts = samples_.NumFonts(); + const IntFeatureMap& feature_map = feature_map_; + // Iterate the fonts to get the similarity with other fonst of the same + // class. + tprintf(" "); + for (int f = 0; f < num_fonts; ++f) { + if (samples_.NumClassSamples(f, class_id2, false) == 0) + continue; + tprintf("%6d", f); + } + tprintf("\n"); + for (int f1 = 0; f1 < num_fonts; ++f1) { + // Map the features of the canonical_sample. + if (samples_.NumClassSamples(f1, class_id1, false) == 0) + continue; + tprintf("%4d ", f1); + for (int f2 = 0; f2 < num_fonts; ++f2) { + if (samples_.NumClassSamples(f2, class_id2, false) == 0) + continue; + float dist = samples_.ClusterDistance(f1, class_id1, f2, class_id2, + feature_map); + tprintf(" %5.3f", dist); + } + tprintf("\n"); + } + // Build a fake ShapeTable containing all the sample types. + ShapeTable shapes(unicharset_); + for (int f = 0; f < num_fonts; ++f) { + if (samples_.NumClassSamples(f, class_id1, true) > 0) + shapes.AddShape(class_id1, f); + if (class_id1 != class_id2 && + samples_.NumClassSamples(f, class_id2, true) > 0) + shapes.AddShape(class_id2, f); + } +} + +// Debugging for cloud/canonical features. +// Displays a Features window containing: +// If unichar_str2 is in the unicharset, and canonical_font is non-negative, +// displays the canonical features of the char/font combination in red. +// If unichar_str1 is in the unicharset, and cloud_font is non-negative, +// displays the cloud feature of the char/font combination in green. +// The canonical features are drawn first to show which ones have no +// matches in the cloud features. +// Until the features window is destroyed, each click in the features window +// will display the samples that have that feature in a separate window. +void MasterTrainer::DisplaySamples(const char* unichar_str1, int cloud_font, + const char* unichar_str2, + int canonical_font) { + const IntFeatureMap& feature_map = feature_map_; + const IntFeatureSpace& feature_space = feature_map.feature_space(); + ScrollView* f_window = CreateFeatureSpaceWindow("Features", 100, 500); + ClearFeatureSpaceWindow(norm_mode_ == NM_BASELINE ? baseline : character, + f_window); + int class_id2 = samples_.unicharset().unichar_to_id(unichar_str2); + if (class_id2 != INVALID_UNICHAR_ID && canonical_font >= 0) { + const TrainingSample* sample = samples_.GetCanonicalSample(canonical_font, + class_id2); + for (int f = 0; f < sample->num_features(); ++f) { + RenderIntFeature(f_window, &sample->features()[f], ScrollView::RED); + } + } + int class_id1 = samples_.unicharset().unichar_to_id(unichar_str1); + if (class_id1 != INVALID_UNICHAR_ID && cloud_font >= 0) { + const BitVector& cloud = samples_.GetCloudFeatures(cloud_font, class_id1); + for (int f = 0; f < cloud.size(); ++f) { + if (cloud[f]) { + INT_FEATURE_STRUCT feature = + feature_map.InverseIndexFeature(f); + RenderIntFeature(f_window, &feature, ScrollView::GREEN); + } + } + } + f_window->Update(); + ScrollView* s_window = CreateFeatureSpaceWindow("Samples", 100, 500); + SVEventType ev_type; + do { + SVEvent* ev; + // Wait until a click or popup event. + ev = f_window->AwaitEvent(SVET_ANY); + ev_type = ev->type; + if (ev_type == SVET_CLICK) { + int feature_index = feature_space.XYToFeatureIndex(ev->x, ev->y); + if (feature_index >= 0) { + // Iterate samples and display those with the feature. + Shape shape; + shape.AddToShape(class_id1, cloud_font); + s_window->Clear(); + samples_.DisplaySamplesWithFeature(feature_index, shape, + feature_space, ScrollView::GREEN, + s_window); + s_window->Update(); + } + } + delete ev; + } while (ev_type != SVET_DESTROY); +} + +// Tests the given test_classifier on the internal samples. +// See TestClassifier for details. +void MasterTrainer::TestClassifierOnSamples(int report_level, + bool replicate_samples, + ShapeClassifier* test_classifier, + STRING* report_string) { + TestClassifier(report_level, replicate_samples, &samples_, + test_classifier, report_string); +} + +// Tests the given test_classifier on the given samples +// report_levels: +// 0 = no output. +// 1 = bottom-line error rate. +// 2 = bottom-line error rate + time. +// 3 = font-level error rate + time. +// 4 = list of all errors + short classifier debug output on 16 errors. +// 5 = list of all errors + short classifier debug output on 25 errors. +// If replicate_samples is true, then the test is run on an extended test +// sample including replicated and systematically perturbed samples. +// If report_string is non-NULL, a summary of the results for each font +// is appended to the report_string. +double MasterTrainer::TestClassifier(int report_level, + bool replicate_samples, + TrainingSampleSet* samples, + ShapeClassifier* test_classifier, + STRING* report_string) { + SampleIterator sample_it; + sample_it.Init(NULL, test_classifier->GetShapeTable(), replicate_samples, + samples); + if (report_level > 0) { + int num_samples = 0; + for (sample_it.Begin(); !sample_it.AtEnd(); sample_it.Next()) + ++num_samples; + tprintf("Iterator has charset size of %d/%d, %d shapes, %d samples\n", + sample_it.SparseCharsetSize(), sample_it.CompactCharsetSize(), + test_classifier->GetShapeTable()->NumShapes(), num_samples); + tprintf("Testing %sREPLICATED:\n", replicate_samples ? "" : "NON-"); + } + double unichar_error = 0.0; + ErrorCounter::ComputeErrorRate(test_classifier, report_level, + CT_SHAPE_TOP_ERR, fontinfo_table_, + page_images_, &sample_it, &unichar_error, + NULL, report_string); + return unichar_error; +} + +// Returns the average (in some sense) distance between the two given +// shapes, which may contain multiple fonts and/or unichars. +float MasterTrainer::ShapeDistance(const ShapeTable& shapes, int s1, int s2) { + const IntFeatureMap& feature_map = feature_map_; + const Shape& shape1 = shapes.GetShape(s1); + const Shape& shape2 = shapes.GetShape(s2); + int num_chars1 = shape1.size(); + int num_chars2 = shape2.size(); + float dist_sum = 0.0f; + int dist_count = 0; + if (num_chars1 > 1 || num_chars2 > 1) { + // In the multi-char case try to optimize the calculation by computing + // distances between characters of matching font where possible. + for (int c1 = 0; c1 < num_chars1; ++c1) { + for (int c2 = 0; c2 < num_chars2; ++c2) { + dist_sum += samples_.UnicharDistance(shape1[c1], shape2[c2], + true, feature_map); + ++dist_count; + } + } + } else { + // In the single unichar case, there is little alternative, but to compute + // the squared-order distance between pairs of fonts. + dist_sum = samples_.UnicharDistance(shape1[0], shape2[0], + false, feature_map); + ++dist_count; + } + return dist_sum / dist_count; +} + +// Replaces samples that are always fragmented with the corresponding +// fragment samples. +void MasterTrainer::ReplaceFragmentedSamples() { + if (fragments_ == NULL) return; + // Remove samples that are replaced by fragments. Each class that was + // always naturally fragmented should be replaced by its fragments. + int num_samples = samples_.num_samples(); + for (int s = 0; s < num_samples; ++s) { + TrainingSample* sample = samples_.mutable_sample(s); + if (fragments_[sample->class_id()] > 0) + samples_.KillSample(sample); + } + samples_.DeleteDeadSamples(); + + // Get ids of fragments in junk_samples_ that replace the dead chars. + const UNICHARSET& frag_set = junk_samples_.unicharset(); +#if 0 + // TODO(rays) The original idea was to replace only graphemes that were + // always naturally fragmented, but that left a lot of the Indic graphemes + // out. Determine whether we can go back to that idea now that spacing + // is fixed in the training images, or whether this code is obsolete. + bool* good_junk = new bool[frag_set.size()]; + memset(good_junk, 0, sizeof(*good_junk) * frag_set.size()); + for (int dead_ch = 1; dead_ch < unicharset_.size(); ++dead_ch) { + int frag_ch = fragments_[dead_ch]; + if (frag_ch <= 0) continue; + const char* frag_utf8 = frag_set.id_to_unichar(frag_ch); + CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8); + // Mark the chars for all parts of the fragment as good in good_junk. + for (int part = 0; part < frag->get_total(); ++part) { + frag->set_pos(part); + int good_ch = frag_set.unichar_to_id(frag->to_string().string()); + if (good_ch != INVALID_UNICHAR_ID) + good_junk[good_ch] = true; // We want this one. + } + } +#endif + // For now just use all the junk that was from natural fragments. + // Get samples of fragments in junk_samples_ that replace the dead chars. + int num_junks = junk_samples_.num_samples(); + for (int s = 0; s < num_junks; ++s) { + TrainingSample* sample = junk_samples_.mutable_sample(s); + int junk_id = sample->class_id(); + const char* frag_utf8 = frag_set.id_to_unichar(junk_id); + CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8); + if (frag != NULL && frag->is_natural()) { + junk_samples_.extract_sample(s); + samples_.AddSample(frag_set.id_to_unichar(junk_id), sample); + } + } + junk_samples_.DeleteDeadSamples(); + junk_samples_.OrganizeByFontAndClass(); + samples_.OrganizeByFontAndClass(); + unicharset_.clear(); + unicharset_.AppendOtherUnicharset(samples_.unicharset()); + // delete [] good_junk; + // Fragments_ no longer needed? + delete [] fragments_; + fragments_ = NULL; +} + +// Runs a hierarchical agglomerative clustering to merge shapes in the given +// shape_table, while satisfying the given constraints: +// * End with at least min_shapes left in shape_table, +// * No shape shall have more than max_shape_unichars in it, +// * Don't merge shapes where the distance between them exceeds max_dist. +const float kInfiniteDist = 999.0f; +void MasterTrainer::ClusterShapes(int min_shapes, int max_shape_unichars, + float max_dist, ShapeTable* shapes) { + int num_shapes = shapes->NumShapes(); + int max_merges = num_shapes - min_shapes; + GenericVector* shape_dists = + new GenericVector[num_shapes]; + float min_dist = kInfiniteDist; + int min_s1 = 0; + int min_s2 = 0; + tprintf("Computing shape distances..."); + for (int s1 = 0; s1 < num_shapes; ++s1) { + for (int s2 = s1 + 1; s2 < num_shapes; ++s2) { + ShapeDist dist(s1, s2, ShapeDistance(*shapes, s1, s2)); + shape_dists[s1].push_back(dist); + if (dist.distance < min_dist) { + min_dist = dist.distance; + min_s1 = s1; + min_s2 = s2; + } + } + tprintf(" %d", s1); + } + tprintf("\n"); + int num_merged = 0; + while (num_merged < max_merges && min_dist < max_dist) { + tprintf("Distance = %f: ", min_dist); + int num_unichars = shapes->MergedUnicharCount(min_s1, min_s2); + shape_dists[min_s1][min_s2 - min_s1 - 1].distance = kInfiniteDist; + if (num_unichars > max_shape_unichars) { + tprintf("Merge of %d and %d with %d would exceed max of %d unichars\n", + min_s1, min_s2, num_unichars, max_shape_unichars); + } else { + shapes->MergeShapes(min_s1, min_s2); + shape_dists[min_s2].clear(); + ++num_merged; + + for (int s = 0; s < min_s1; ++s) { + if (!shape_dists[s].empty()) { + shape_dists[s][min_s1 - s - 1].distance = + ShapeDistance(*shapes, s, min_s1); + shape_dists[s][min_s2 - s -1].distance = kInfiniteDist; + } + } + for (int s2 = min_s1 + 1; s2 < num_shapes; ++s2) { + if (shape_dists[min_s1][s2 - min_s1 - 1].distance < kInfiniteDist) + shape_dists[min_s1][s2 - min_s1 - 1].distance = + ShapeDistance(*shapes, min_s1, s2); + } + for (int s = min_s1 + 1; s < min_s2; ++s) { + if (!shape_dists[s].empty()) { + shape_dists[s][min_s2 - s - 1].distance = kInfiniteDist; + } + } + } + min_dist = kInfiniteDist; + for (int s1 = 0; s1 < num_shapes; ++s1) { + for (int i = 0; i < shape_dists[s1].size(); ++i) { + if (shape_dists[s1][i].distance < min_dist) { + min_dist = shape_dists[s1][i].distance; + min_s1 = s1; + min_s2 = s1 + 1 + i; + } + } + } + } + tprintf("Stopped with %d merged, min dist %f\n", num_merged, min_dist); + delete [] shape_dists; + if (debug_level_ > 1) { + for (int s1 = 0; s1 < num_shapes; ++s1) { + if (shapes->MasterDestinationIndex(s1) == s1) { + tprintf("Master shape:%s\n", shapes->DebugStr(s1).string()); + } + } + } +} + + +} // namespace tesseract. diff --git a/classify/mastertrainer.h b/classify/mastertrainer.h new file mode 100644 index 0000000000..90d396a9be --- /dev/null +++ b/classify/mastertrainer.h @@ -0,0 +1,296 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: mastertrainer.h +// Description: Trainer to build the MasterClassifier. +// Author: Ray Smith +// Created: Wed Nov 03 18:07:01 PDT 2010 +// +// (C) Copyright 2010, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_TRAINING_MASTERTRAINER_H__ +#define TESSERACT_TRAINING_MASTERTRAINER_H__ + +/**---------------------------------------------------------------------------- + Include Files and Type Defines +----------------------------------------------------------------------------**/ +#include "classify.h" +#include "cluster.h" +#include "intfx.h" +#include "elst.h" +#include "featdefs.h" +#include "fontinfo.h" +#include "indexmapbidi.h" +#include "intfeaturespace.h" +#include "intfeaturemap.h" +#include "intmatcher.h" +#include "params.h" +#include "shapetable.h" +#include "trainingsample.h" +#include "trainingsampleset.h" +#include "unicharset.h" + +namespace tesseract { + +class ShapeClassifier; + +// Simple struct to hold the distance between two shapes during clustering. +struct ShapeDist { + ShapeDist() : shape1(0), shape2(0), distance(0.0f) {} + ShapeDist(int s1, int s2, float dist) + : shape1(s1), shape2(s2), distance(dist) {} + + // Sort operator to sort in ascending order of distance. + bool operator<(const ShapeDist& other) const { + return distance < other.distance; + } + + int shape1; + int shape2; + float distance; +}; + +// Class to encapsulate training processes that use the TrainingSampleSet. +// Initially supports shape clustering and mftrainining. +// Other important features of the MasterTrainer are conditioning the data +// by outlier elimination, replication with perturbation, and serialization. +class MasterTrainer { + public: + MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, + bool replicate_samples, int debug_level); + ~MasterTrainer(); + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Loads an initial unicharset, or sets one up if the file cannot be read. + void LoadUnicharset(const char* filename); + + // Sets the feature space definition. + void SetFeatureSpace(const IntFeatureSpace& fs) { + feature_space_ = fs; + feature_map_.Init(fs); + } + + // Reads the samples and their features from the given file, + // adding them to the trainer with the font_id from the content of the file. + // If verification, then these are verification samples, not training. + void ReadTrainingSamples(FILE *fp, + const FEATURE_DEFS_STRUCT& feature_defs, + bool verification); + + // Adds the given single sample to the trainer, setting the classid + // appropriately from the given unichar_str. + void AddSample(bool verification, const char* unichar_str, + TrainingSample* sample); + + // Loads all pages from the given tif filename and append to page_images_. + // Must be called after ReadTrainingSamples, as the current number of images + // is used as an offset for page numbers in the samples. + void LoadPageImages(const char* filename); + + // Cleans up the samples after initial load from the tr files, and prior to + // saving the MasterTrainer: + // Remaps fragmented chars if running shape anaylsis. + // Sets up the samples appropriately for class/fontwise access. + // Deletes outlier samples. + void PostLoadCleanup(); + + // Gets the samples ready for training. Use after both + // ReadTrainingSamples+PostLoadCleanup or DeSerialize. + // Re-indexes the features and computes canonical and cloud features. + void PreTrainingSetup(); + + // Sets up the master_shapes_ table, which tells which fonts should stay + // together until they get to a leaf node classifier. + void SetupMasterShapes(); + + // Adds the junk_samples_ to the main samples_ set. Junk samples are initially + // fragments and n-grams (all incorrectly segmented characters). + // Various training functions may result in incorrectly segmented characters + // being added to the unicharset of the main samples, perhaps because they + // form a "radical" decomposition of some (Indic) grapheme, or because they + // just look the same as a real character (like rn/m) + // This function moves all the junk samples, to the main samples_ set, but + // desirable junk, being any sample for which the unichar already exists in + // the samples_ unicharset gets the unichar-ids re-indexed to match, but + // anything else gets re-marked as unichar_id 0 (space character) to identify + // it as junk to the error counter. + void IncludeJunk(); + + // Replicates the samples and perturbs them if the enable_replication_ flag + // is set. MUST be used after the last call to OrganizeByFontAndClass on + // the training samples, ie after IncludeJunk if it is going to be used, as + // OrganizeByFontAndClass will eat the replicated samples into the regular + // samples. + void ReplicateAndRandomizeSamplesIfRequired(); + + // Loads the basic font properties file into fontinfo_table_. + // Returns false on failure. + bool LoadFontInfo(const char* filename); + + // Loads the xheight font properties file into xheights_. + // Returns false on failure. + bool LoadXHeights(const char* filename); + + // Reads spacing stats from filename and adds them to fontinfo_table. + // Returns false on failure. + bool AddSpacingInfo(const char *filename); + + // Returns the font id corresponding to the given font name. + // Returns -1 if the font cannot be found. + int GetFontInfoId(const char* font_name); + // Returns the font_id of the closest matching font name to the given + // filename. It is assumed that a substring of the filename will match + // one of the fonts. If more than one is matched, the longest is returned. + int GetBestMatchingFontInfoId(const char* filename); + + // Sets up a flat shapetable with one shape per class/font combination. + void SetupFlatShapeTable(ShapeTable* shape_table); + + // Sets up a Clusterer for mftraining on a single shape_id. + // Call FreeClusterer on the return value after use. + CLUSTERER* SetupForClustering(const ShapeTable& shape_table, + const FEATURE_DEFS_STRUCT& feature_defs, + int shape_id, int* num_samples); + + // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp + // to the given inttemp_file, and the corresponding pffmtable. + // The unicharset is the original encoding of graphemes, and shape_set should + // match the size of the shape_table, and may possibly be totally fake. + void WriteInttempAndPFFMTable(const UNICHARSET& unicharset, + const UNICHARSET& shape_set, + const ShapeTable& shape_table, + CLASS_STRUCT* float_classes, + const char* inttemp_file, + const char* pffmtable_file); + + const UNICHARSET& unicharset() const { + return samples_.unicharset(); + } + TrainingSampleSet* GetSamples() { + return &samples_; + } + const ShapeTable& master_shapes() const { + return master_shapes_; + } + + // Generates debug output relating to the canonical distance between the + // two given UTF8 grapheme strings. + void DebugCanonical(const char* unichar_str1, const char* unichar_str2); + // Debugging for cloud/canonical features. + // Displays a Features window containing: + // If unichar_str2 is in the unicharset, and canonical_font is non-negative, + // displays the canonical features of the char/font combination in red. + // If unichar_str1 is in the unicharset, and cloud_font is non-negative, + // displays the cloud feature of the char/font combination in green. + // The canonical features are drawn first to show which ones have no + // matches in the cloud features. + // Until the features window is destroyed, each click in the features window + // will display the samples that have that feature in a separate window. + void DisplaySamples(const char* unichar_str1, int cloud_font, + const char* unichar_str2, int canonical_font); + + // Tests the given test_classifier on the internal samples. + // See TestClassifier for details. + void TestClassifierOnSamples(int report_level, + bool replicate_samples, + ShapeClassifier* test_classifier, + STRING* report_string); + // Tests the given test_classifier on the given samples + // report_levels: + // 0 = no output. + // 1 = bottom-line error rate. + // 2 = bottom-line error rate + time. + // 3 = font-level error rate + time. + // 4 = list of all errors + short classifier debug output on 16 errors. + // 5 = list of all errors + short classifier debug output on 25 errors. + // If replicate_samples is true, then the test is run on an extended test + // sample including replicated and systematically perturbed samples. + // If report_string is non-NULL, a summary of the results for each font + // is appended to the report_string. + double TestClassifier(int report_level, + bool replicate_samples, + TrainingSampleSet* samples, + ShapeClassifier* test_classifier, + STRING* report_string); + + // Returns the average (in some sense) distance between the two given + // shapes, which may contain multiple fonts and/or unichars. + // This function is public to facilitate testing. + float ShapeDistance(const ShapeTable& shapes, int s1, int s2); + + private: + // Replaces samples that are always fragmented with the corresponding + // fragment samples. + void ReplaceFragmentedSamples(); + + // Runs a hierarchical agglomerative clustering to merge shapes in the given + // shape_table, while satisfying the given constraints: + // * End with at least min_shapes left in shape_table, + // * No shape shall have more than max_shape_unichars in it, + // * Don't merge shapes where the distance between them exceeds max_dist. + void ClusterShapes(int min_shapes, int max_shape_unichars, + float max_dist, ShapeTable* shape_table); + + private: + NormalizationMode norm_mode_; + // Character set we are training for. + UNICHARSET unicharset_; + // Original feature space. Subspace mapping is contained in feature_map_. + IntFeatureSpace feature_space_; + TrainingSampleSet samples_; + TrainingSampleSet junk_samples_; + TrainingSampleSet verify_samples_; + // Master shape table defines what fonts stay together until the leaves. + ShapeTable master_shapes_; + // Flat shape table has each unichar/font id pair in a separate shape. + ShapeTable flat_shapes_; + // Font metrics gathered from multiple files. + UnicityTable fontinfo_table_; + // Array of xheights indexed by font ids in fontinfo_table_; + GenericVector xheights_; + + // Non-serialized data initialized by other means or used temporarily + // during loading of training samples. + // Number of different class labels in unicharset_. + int charsetsize_; + // Flag to indicate that we are running shape analysis and need fragments + // fixing. + bool enable_shape_anaylsis_; + // Flag to indicate that sample replication is required. + bool enable_replication_; + // Flag to indicate that junk should be included in samples_. + bool include_junk_; + // Array of classids of fragments that replace the correctly segmented chars. + int* fragments_; + // Classid of previous correctly segmented sample that was added. + int prev_unichar_id_; + // Debug output control. + int debug_level_; + // Feature map used to construct reduced feature spaces for compact + // classifiers. + IntFeatureMap feature_map_; + // Vector of Pix pointers used for classifiers that need the image. + // Indexed by page_num_ in the samples. + // These images are owned by the trainer and need to be pixDestroyed. + GenericVector page_images_; +}; + +} // namespace tesseract. + +#endif diff --git a/classify/mf.h b/classify/mf.h index d65bc9c189..0f5e3f64b7 100644 --- a/classify/mf.h +++ b/classify/mf.h @@ -26,8 +26,11 @@ typedef enum { MFXPosition, MFYPosition, - MFLength, MFDirection, MFBulge1, MFBulge2 + MFLength, MFDirection, MFBulge1, MFBulge2, + MFCount // For array sizes. } MF_PARAM_NAME; + +typedef float MicroFeature[MFCount]; /*---------------------------------------------------------------------------- Private Function Prototypes -----------------------------------------------------------------------------*/ diff --git a/classify/mfoutline.cpp b/classify/mfoutline.cpp index e70ddf1599..5903f5d3a1 100644 --- a/classify/mfoutline.cpp +++ b/classify/mfoutline.cpp @@ -76,7 +76,8 @@ MFOUTLINE ConvertOutline(TESSLINE *outline) { EdgePoint = NextPoint; } while (EdgePoint != StartPoint); - MakeOutlineCircular(MFOutline); + if (MFOutline != NULL) + MakeOutlineCircular(MFOutline); return MFOutline; } @@ -95,7 +96,8 @@ LIST ConvertOutlines(TESSLINE *outline, while (outline != NULL) { mf_outline = ConvertOutline(outline); - mf_outlines = push(mf_outlines, mf_outline); + if (mf_outline != NULL) + mf_outlines = push(mf_outlines, mf_outline); outline = outline->next; } return mf_outlines; @@ -404,54 +406,6 @@ void Classify::NormalizeOutlines(LIST Outlines, } /* NormalizeOutlines */ } // namespace tesseract -/*---------------------------------------------------------------------------*/ -void SmearExtremities(MFOUTLINE Outline, FLOAT32 XScale, FLOAT32 YScale) { -/* - ** Parameters: - ** Outline outline whose extremities are to be smeared - ** XScale factor used to normalize outline in x dir - ** YScale factor used to normalize outline in y dir - ** Globals: none - ** Operation: - ** This routine smears the extremities of the specified outline. - ** It does this by adding a random number between - ** -0.5 and 0.5 pixels (that is why X/YScale are needed) to - ** the x and y position of the point. This is done so that - ** the discrete nature of the original scanned image does not - ** affect the statistical clustering used during training. - ** Return: none - ** Exceptions: none - ** History: 1/11/90, DSJ, Created. - */ - MFEDGEPT *Current; - MFOUTLINE EdgePoint; - FLOAT32 MinXSmear; - FLOAT32 MaxXSmear; - FLOAT32 MinYSmear; - FLOAT32 MaxYSmear; - - if (Outline != NIL_LIST) { - MinXSmear = -0.5 * XScale; - MaxXSmear = 0.5 * XScale; - MinYSmear = -0.5 * YScale; - MaxYSmear = 0.5 * YScale; - EdgePoint = Outline; - do { - Current = PointAt (EdgePoint); - if (Current->ExtremityMark) { - Current->Point.x += - UniformRandomNumber(MinXSmear, MaxXSmear); - Current->Point.y += - UniformRandomNumber(MinYSmear, MaxYSmear); - } - - EdgePoint = NextPointAfter (EdgePoint); - } - while (EdgePoint != Outline); - } -} /* SmearExtremities */ - - /**---------------------------------------------------------------------------- Private Code ----------------------------------------------------------------------------**/ diff --git a/classify/mfoutline.h b/classify/mfoutline.h index 3efacc8eb8..71dd310b81 100644 --- a/classify/mfoutline.h +++ b/classify/mfoutline.h @@ -114,8 +114,6 @@ MFOUTLINE NextExtremity(MFOUTLINE EdgePoint); void NormalizeOutline(MFOUTLINE Outline, FLOAT32 XOrigin); -void SmearExtremities(MFOUTLINE Outline, FLOAT32 XScale, FLOAT32 YScale); - /*---------------------------------------------------------------------------- Private Function Prototypes -----------------------------------------------------------------------------*/ diff --git a/classify/mfx.cpp b/classify/mfx.cpp index 85029ffef8..a053a051f9 100644 --- a/classify/mfx.cpp +++ b/classify/mfx.cpp @@ -102,7 +102,6 @@ CHAR_FEATURES BlobMicroFeatures(TBLOB *Blob, const DENORM& denorm) { Outline = (MFOUTLINE) first_node (RemainingOutlines); FindDirectionChanges(Outline, classify_min_slope, classify_max_slope); MarkDirectionChanges(Outline); - SmearExtremities(Outline, XScale, YScale); MicroFeatures = ConvertToMicroFeatures (Outline, MicroFeatures); } FreeOutlines(Outlines); diff --git a/classify/normmatch.cpp b/classify/normmatch.cpp index a693a7348e..7b71f0dfaa 100644 --- a/classify/normmatch.cpp +++ b/classify/normmatch.cpp @@ -70,7 +70,8 @@ const double kWidthErrorWeighting = 0.125; ----------------------------------------------------------------------------**/ /*---------------------------------------------------------------------------*/ namespace tesseract { -FLOAT32 Classify::ComputeNormMatch(CLASS_ID ClassId, FEATURE Feature, +FLOAT32 Classify::ComputeNormMatch(CLASS_ID ClassId, + const FEATURE_STRUCT& feature, BOOL8 DebugMatch) { /* ** Parameters: @@ -96,12 +97,12 @@ FLOAT32 Classify::ComputeNormMatch(CLASS_ID ClassId, FEATURE Feature, /* handle requests for classification as noise */ if (ClassId == NO_CLASS) { /* kludge - clean up constants and make into control knobs later */ - Match = (Feature->Params[CharNormLength] * - Feature->Params[CharNormLength] * 500.0 + - Feature->Params[CharNormRx] * - Feature->Params[CharNormRx] * 8000.0 + - Feature->Params[CharNormRy] * - Feature->Params[CharNormRy] * 8000.0); + Match = (feature.Params[CharNormLength] * + feature.Params[CharNormLength] * 500.0 + + feature.Params[CharNormRx] * + feature.Params[CharNormRx] * 8000.0 + + feature.Params[CharNormRy] * + feature.Params[CharNormRy] * 8000.0); return (1.0 - NormEvidenceOf (Match)); } @@ -109,38 +110,48 @@ FLOAT32 Classify::ComputeNormMatch(CLASS_ID ClassId, FEATURE Feature, Protos = NormProtos->Protos[ClassId]; if (DebugMatch) { - cprintf ("\nFeature = "); - WriteFeature(stdout, Feature); + tprintf("\nChar norm for class %s\n", unicharset.id_to_unichar(ClassId)); } ProtoId = 0; iterate(Protos) { Proto = (PROTOTYPE *) first_node (Protos); - Delta = Feature->Params[CharNormY] - Proto->Mean[CharNormY]; + Delta = feature.Params[CharNormY] - Proto->Mean[CharNormY]; Match = Delta * Delta * Proto->Weight.Elliptical[CharNormY]; - Delta = Feature->Params[CharNormRx] - Proto->Mean[CharNormRx]; + if (DebugMatch) { + tprintf("YMiddle: Proto=%g, Delta=%g, Var=%g, Dist=%g\n", + Proto->Mean[CharNormY], Delta, + Proto->Weight.Elliptical[CharNormY], Match); + } + Delta = feature.Params[CharNormRx] - Proto->Mean[CharNormRx]; Match += Delta * Delta * Proto->Weight.Elliptical[CharNormRx]; + if (DebugMatch) { + tprintf("Height: Proto=%g, Delta=%g, Var=%g, Dist=%g\n", + Proto->Mean[CharNormRx], Delta, + Proto->Weight.Elliptical[CharNormRx], Match); + } // Ry is width! See intfx.cpp. - Delta = Feature->Params[CharNormRy] - Proto->Mean[CharNormRy]; + Delta = feature.Params[CharNormRy] - Proto->Mean[CharNormRy]; + if (DebugMatch) { + tprintf("Width: Proto=%g, Delta=%g, Var=%g\n", + Proto->Mean[CharNormRy], Delta, + Proto->Weight.Elliptical[CharNormRy]); + } Delta = Delta * Delta * Proto->Weight.Elliptical[CharNormRy]; Delta *= kWidthErrorWeighting; Match += Delta; + if (DebugMatch) { + tprintf("Total Dist=%g, scaled=%g, sigmoid=%g, penalty=%g\n", + Match, Match / classify_norm_adj_midpoint, + NormEvidenceOf(Match), 256 * (1 - NormEvidenceOf(Match))); + } if (Match < BestMatch) BestMatch = Match; - if (DebugMatch) { - cprintf ("Proto %1d = ", ProtoId); - WriteNFloats (stdout, NormProtos->NumParams, Proto->Mean); - cprintf (" var = "); - WriteNFloats (stdout, NormProtos->NumParams, - Proto->Variance.Elliptical); - cprintf (" match = "); - PrintNormMatch (stdout, NormProtos->NumParams, Proto, Feature); - } ProtoId++; } - return (1.0 - NormEvidenceOf (BestMatch)); + return 1.0 - NormEvidenceOf(BestMatch); } /* ComputeNormMatch */ void Classify::FreeNormProtos() { @@ -230,7 +241,7 @@ NORM_PROTOS *Classify::ReadNormProtos(FILE *File, inT64 end_offset) { */ NORM_PROTOS *NormProtos; int i; - char unichar[UNICHAR_LEN + 1]; + char unichar[2 * UNICHAR_LEN + 1]; UNICHAR_ID unichar_id; LIST Protos; int NumProtos; @@ -256,8 +267,12 @@ NORM_PROTOS *Classify::ReadNormProtos(FILE *File, inT64 end_offset) { Protos = push_last (Protos, ReadPrototype (File, NormProtos->NumParams)); NormProtos->Protos[unichar_id] = Protos; - } else - cprintf("Error: unichar %s in normproto file is not in unichar set.\n"); + } else { + cprintf("Error: unichar %s in normproto file is not in unichar set.\n", + unichar); + for (i = 0; i < NumProtos; i++) + FreePrototype(ReadPrototype (File, NormProtos->NumParams)); + } SkipNewline(File); } return (NormProtos); diff --git a/classify/ocrfeatures.cpp b/classify/ocrfeatures.cpp index d2f5a70c0a..06dc5b0c85 100644 --- a/classify/ocrfeatures.cpp +++ b/classify/ocrfeatures.cpp @@ -231,12 +231,11 @@ void WriteFeature(FILE *File, FEATURE Feature) { for (i = 0; i < Feature->Type->NumParams; i++) { #ifndef WIN32 - assert (!isnan(Feature->Params[i])); + assert(!isnan(Feature->Params[i])); #endif - fprintf (File, " %12g", Feature->Params[i]); + fprintf(File, " %g", Feature->Params[i]); } - fprintf (File, "\n"); - + fprintf(File, "\n"); } /* WriteFeature */ diff --git a/classify/ocrfeatures.h b/classify/ocrfeatures.h index 8b5c3f43be..8ca9e59752 100644 --- a/classify/ocrfeatures.h +++ b/classify/ocrfeatures.h @@ -100,11 +100,9 @@ const PARAM_DESC Name[] = { Macro for describing a new feature. The parameters of the macro are as follows: -DefineFeature (Name, NumLinear, NumCircular, - MinFeatPerChar, MaxFeatPerChar, - LongName, ShortName, ParamName) +DefineFeature (Name, NumLinear, NumCircular, ShortName, ParamName) ----------------------------------------------------------------------*/ -#define DefineFeature(Name, NL, NC, Min, Max, LN, SN, PN) \ +#define DefineFeature(Name, NL, NC, SN, PN) \ const FEATURE_DESC_STRUCT Name = { \ ((NL) + (NC)), SN, PN}; diff --git a/classify/picofeat.cpp b/classify/picofeat.cpp index 95aac54fb7..9f2a4ead36 100644 --- a/classify/picofeat.cpp +++ b/classify/picofeat.cpp @@ -27,6 +27,7 @@ #include "mfoutline.h" #include "ocrfeatures.h" #include "params.h" +#include "trainingsample.h" #include #include @@ -221,3 +222,59 @@ void NormalizePicoX(FEATURE_SET FeatureSet) { Feature->Params[PicoFeatX] -= Origin; } } /* NormalizePicoX */ + +/*---------------------------------------------------------------------------*/ +FEATURE_SET ExtractIntCNFeatures(TBLOB *blob, const DENORM& denorm) { +/* + ** Parameters: + ** blob blob to extract features from + ** denorm normalization/denormalization parameters. + ** Return: Integer character-normalized features for blob. + ** Exceptions: none + ** History: 8/8/2011, rays, Created. + */ + tesseract::TrainingSample* sample = GetIntFeatures( + tesseract::NM_CHAR_ANISOTROPIC, blob, denorm); + if (sample == NULL) return NULL; + + int num_features = sample->num_features(); + const INT_FEATURE_STRUCT* features = sample->features(); + FEATURE_SET feature_set = NewFeatureSet(num_features); + for (int f = 0; f < num_features; ++f) { + FEATURE feature = NewFeature(&IntFeatDesc); + + feature->Params[IntX] = features[f].X; + feature->Params[IntY] = features[f].Y; + feature->Params[IntDir] = features[f].Theta; + AddFeature(feature_set, feature); + } + delete sample; + + return feature_set; +} /* ExtractIntCNFeatures */ + +/*---------------------------------------------------------------------------*/ +FEATURE_SET ExtractIntGeoFeatures(TBLOB *blob, const DENORM& denorm) { +/* + ** Parameters: + ** blob blob to extract features from + ** denorm normalization/denormalization parameters. + ** Return: Geometric (top/bottom/width) features for blob. + ** Exceptions: none + ** History: 8/8/2011, rays, Created. + */ + tesseract::TrainingSample* sample = GetIntFeatures( + tesseract::NM_CHAR_ANISOTROPIC, blob, denorm); + if (sample == NULL) return NULL; + + FEATURE_SET feature_set = NewFeatureSet(1); + FEATURE feature = NewFeature(&IntFeatDesc); + + feature->Params[GeoBottom] = sample->geo_feature(GeoBottom); + feature->Params[GeoTop] = sample->geo_feature(GeoTop); + feature->Params[GeoWidth] = sample->geo_feature(GeoWidth); + AddFeature(feature_set, feature); + delete sample; + + return feature_set; +} /* ExtractIntGeoFeatures */ diff --git a/classify/picofeat.h b/classify/picofeat.h index 773b472265..db63d21197 100644 --- a/classify/picofeat.h +++ b/classify/picofeat.h @@ -24,6 +24,22 @@ #include "ocrfeatures.h" #include "params.h" +// Enum for the order/type of params in IntFeatDesc. +enum IntParams { + IntX, // x-position (0-255). + IntY, // y-position (0-255). + IntDir // Direction (0-255, circular). +}; + +// Enum for the order/type of params in GeoFeatDesc. +enum GeoParams { + GeoBottom, // Bounding box bottom in baseline space (0-255). + GeoTop, // Bounding box top in baseline space (0-255). + GeoWidth, // Bounding box width in baseline space (0-255). + + GeoCount // Number of geo features. +}; + typedef enum { PicoFeatY, PicoFeatDir, PicoFeatX } PICO_FEAT_PARAM_NAME; @@ -42,6 +58,9 @@ extern double_VAR_H(classify_pico_feature_length, 0.05, "Pico Feature Length"); ----------------------------------------------------------------------------**/ #define GetPicoFeatureLength() (PicoFeatureLength) +FEATURE_SET ExtractIntCNFeatures(TBLOB *Blob, const DENORM& denorm); +FEATURE_SET ExtractIntGeoFeatures(TBLOB *Blob, const DENORM& denorm); + /**---------------------------------------------------------------------------- Global Data Definitions and Declarations ----------------------------------------------------------------------------**/ diff --git a/classify/protos.h b/classify/protos.h index 1c61997d7c..6bc43e8e5f 100644 --- a/classify/protos.h +++ b/classify/protos.h @@ -51,8 +51,11 @@ typedef struct } PROTO_STRUCT; typedef PROTO_STRUCT *PROTO; -typedef struct -{ +struct CLASS_STRUCT { + CLASS_STRUCT() + : NumProtos(0), MaxNumProtos(0), Prototypes(NULL), + NumConfigs(0), MaxNumConfigs(0), Configurations(NULL) { + } inT16 NumProtos; inT16 MaxNumProtos; PROTO Prototypes; @@ -60,7 +63,7 @@ typedef struct inT16 MaxNumConfigs; CONFIGS Configurations; UnicityTableEqEq font_set; -} CLASS_STRUCT; +}; typedef CLASS_STRUCT *CLASS_TYPE; typedef CLASS_STRUCT *CLASSES; diff --git a/classify/sampleiterator.cpp b/classify/sampleiterator.cpp new file mode 100644 index 0000000000..5f06a288b9 --- /dev/null +++ b/classify/sampleiterator.cpp @@ -0,0 +1,262 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "sampleiterator.h" + +#include "indexmapbidi.h" +#include "shapetable.h" +#include "trainingsample.h" +#include "trainingsampleset.h" + +namespace tesseract { + +// ================== SampleIterator Implementation ================= + +SampleIterator::SampleIterator() + : charset_map_(NULL), + shape_table_(NULL), + sample_set_(NULL), + randomize_(false), + owned_shape_table_(NULL) { + num_shapes_ = 0; + Begin(); +} + +SampleIterator::~SampleIterator() { + Clear(); +} + +void SampleIterator::Clear() { + delete owned_shape_table_; + owned_shape_table_ = NULL; +} + +// See class comment for arguments. +void SampleIterator::Init(const IndexMapBiDi* charset_map, + const ShapeTable* shape_table, + bool randomize, + TrainingSampleSet* sample_set) { + Clear(); + charset_map_ = charset_map; + shape_table_ = shape_table; + sample_set_ = sample_set; + randomize_ = randomize; + if (shape_table_ == NULL && charset_map_ != NULL) { + // The caller wishes to iterate by class. The easiest way to do this + // is to create a dummy shape_table_ that we will own. + int num_fonts = sample_set_->NumFonts(); + owned_shape_table_ = new ShapeTable(sample_set_->unicharset()); + int charsetsize = sample_set_->unicharset().size(); + for (int c = 0; c < charsetsize; ++c) { + // We always add a shape for each character to keep the index in sync + // with the unichar_id. + int shape_id = owned_shape_table_->AddShape(c, 0); + for (int f = 1; f < num_fonts; ++f) { + if (sample_set_->NumClassSamples(f, c, true) > 0) { + owned_shape_table_->AddToShape(shape_id, c, f); + } + } + } + shape_table_ = owned_shape_table_; + } + if (shape_table_ != NULL) { + num_shapes_ = shape_table_->NumShapes(); + } else { + num_shapes_ = randomize ? sample_set_->num_samples() + : sample_set_->num_raw_samples(); + } + Begin(); +} + +// Iterator functions designed for use with a simple for loop: +// for (it.Begin(); !it.AtEnd(); it.Next()) { +// const TrainingSample& sample = it.GetSample(); +// } +void SampleIterator::Begin() { + shape_index_ = -1; + shape_char_index_ = 0; + num_shape_chars_ = 0; + shape_font_index_ = 0; + num_shape_fonts_ = 0; + sample_index_ = 0; + num_samples_ = 0; + // Find the first indexable sample. + Next(); +} + +bool SampleIterator::AtEnd() const { + return shape_index_ >= num_shapes_; +} + +const TrainingSample& SampleIterator::GetSample() const { + if (shape_table_ != NULL) { + const UnicharAndFonts* shape_entry = GetShapeEntry(); + int char_id = shape_entry->unichar_id; + int font_id = shape_entry->font_ids[shape_font_index_]; + return *sample_set_->GetSample(font_id, char_id, sample_index_); + } else { + return *sample_set_->GetSample(shape_index_); + } +} + +TrainingSample* SampleIterator::MutableSample() const { + if (shape_table_ != NULL) { + const UnicharAndFonts* shape_entry = GetShapeEntry(); + int char_id = shape_entry->unichar_id; + int font_id = shape_entry->font_ids[shape_font_index_]; + return sample_set_->MutableSample(font_id, char_id, sample_index_); + } else { + return sample_set_->mutable_sample(shape_index_); + } +} + +// Returns the total index (from the original set of samples) of the current +// sample. +int SampleIterator::GlobalSampleIndex() const { + if (shape_table_ != NULL) { + const UnicharAndFonts* shape_entry = GetShapeEntry(); + int char_id = shape_entry->unichar_id; + int font_id = shape_entry->font_ids[shape_font_index_]; + return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_); + } else { + return shape_index_; + } +} + +// Returns the index of the current sample in compact charset space, so +// in a 2-class problem between x and y, the returned indices will all be +// 0 or 1, and have nothing to do with the unichar_ids. +// If the charset_map_ is NULL, then this is equal to GetSparseClassID(). +int SampleIterator::GetCompactClassID() const { + return charset_map_ != NULL ? charset_map_->SparseToCompact(shape_index_) + : GetSparseClassID(); +} +// Returns the index of the current sample in sparse charset space, so +// in a 2-class problem between x and y, the returned indices will all be +// x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids +// with a shape_table_. +int SampleIterator::GetSparseClassID() const { + return shape_table_ != NULL ? shape_index_ : GetSample().class_id(); +} + +// Moves on to the next indexable sample. If the end is reached, leaves +// the state such that AtEnd() is true. +void SampleIterator::Next() { + if (shape_table_ != NULL) { + // Next sample in this class/font combination. + ++sample_index_; + if (sample_index_ < num_samples_) + return; + // Next font in this class in this shape. + sample_index_ = 0; + do { + ++shape_font_index_; + if (shape_font_index_ >= num_shape_fonts_) { + // Next unichar in this shape. + shape_font_index_ = 0; + ++shape_char_index_; + if (shape_char_index_ >= num_shape_chars_) { + // Find the next shape that is mapped in the charset_map_. + shape_char_index_ = 0; + do { + ++shape_index_; + } while (shape_index_ < num_shapes_ && + charset_map_ != NULL && + charset_map_->SparseToCompact(shape_index_) < 0); + if (shape_index_ >= num_shapes_) + return; // The end. + num_shape_chars_ = shape_table_->GetShape(shape_index_).size(); + } + } + const UnicharAndFonts* shape_entry = GetShapeEntry(); + num_shape_fonts_ = shape_entry->font_ids.size(); + int char_id = shape_entry->unichar_id; + int font_id = shape_entry->font_ids[shape_font_index_]; + num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_); + } while (num_samples_ == 0); + } else { + // We are just iterating over the samples. + ++shape_index_; + } +} + +// Returns the size of the compact charset space. +int SampleIterator::CompactCharsetSize() const { + return charset_map_ != NULL ? charset_map_->CompactSize() + : SparseCharsetSize(); +} + +// Returns the size of the sparse charset space. +int SampleIterator::SparseCharsetSize() const { + return charset_map_ != NULL + ? charset_map_->SparseSize() + : (shape_table_ != NULL ? shape_table_->NumShapes() + : sample_set_->charsetsize()); +} + +// Apply the supplied feature_space/feature_map transform to all samples +// accessed by this iterator. +void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) { + for (Begin(); !AtEnd(); Next()) { + TrainingSample* sample = MutableSample(); + sample->MapFeatures(feature_map); + } +} + +// Adjust the weights of all the samples to be uniform in the given charset. +// Returns the number of samples in the iterator. +int SampleIterator::UniformSamples() { + int num_good_samples = 0; + for (Begin(); !AtEnd(); Next()) { + TrainingSample* sample = MutableSample(); + sample->set_weight(1.0); + ++num_good_samples; + } + NormalizeSamples(); + return num_good_samples; +} + +// Normalize the weights of all the samples in the charset_map so they sum +// to 1. Returns the minimum assigned sample weight. +double SampleIterator::NormalizeSamples() { + double total_weight = 0.0; + int sample_count = 0; + for (Begin(); !AtEnd(); Next()) { + const TrainingSample& sample = GetSample(); + total_weight += sample.weight(); + ++sample_count; + } + // Normalize samples. + double min_assigned_sample_weight = 1.0; + if (total_weight > 0.0) { + for (Begin(); !AtEnd(); Next()) { + TrainingSample* sample = MutableSample(); + double weight = sample->weight() / total_weight; + if (weight < min_assigned_sample_weight) + min_assigned_sample_weight = weight; + sample->set_weight(weight); + } + } + return min_assigned_sample_weight; +} + +// Helper returns the current UnicharAndFont shape_entry. +const UnicharAndFonts* SampleIterator::GetShapeEntry() const { + const Shape& shape = shape_table_->GetShape(shape_index_); + return &shape[shape_char_index_]; +} + +} // namespace tesseract. + diff --git a/classify/sampleiterator.h b/classify/sampleiterator.h new file mode 100644 index 0000000000..fd22d31631 --- /dev/null +++ b/classify/sampleiterator.h @@ -0,0 +1,195 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + + +#ifndef TESSERACT_CLASSIFY_SAMPLEITERATOR_H_ +#define TESSERACT_CLASSIFY_SAMPLEITERATOR_H_ + +namespace tesseract { + +class IndexMapBiDi; +class IntFeatureMap; +class ShapeTable; +class TrainingSample; +class TrainingSampleSet; +class UnicharAndFonts; + +// Iterator class to encapsulate the complex iteration involved in getting +// all samples of all shapes needed for a classification problem. +// +// =====INPUTS TO Init FUNCTION===== +// The charset_map defines a subset of the sample_set classes (with a NULL +// shape_table, or the shape_table classes if not NULL.) +// +// The shape_table (if not NULL) defines the mapping from shapes to +// font_id/class_id pairs. Each shape is a list of unichar_id and font lists. +// +// The sample_set holds the samples and provides indexed access to samples +// of font_id/class_id pairs. +// +// If randomize is true, the samples are perturbed slightly, but the +// perturbation is guaranteed to be the same for multiple identical +// iterations. +// +// =====DIFFERENT COMBINATIONS OF INPUTS===== +// NULL shape_table: +// Without a shape_table, everything works in UNICHAR_IDs. +// +// NULL shape_table, NULL charset_map: +// Iterations simply run over the samples in the order the samples occur in the +// input files. +// GetCompactClassID and GetSparseClassID both return the sample UNICHAR_ID. +// +// NULL shape_table, non-NULL charset_map: +// When shape_table is NULL, the charset_map indexes unichar_ids directly, +// and an iteration returns all samples of all chars in the charset_map, which +// is a subset of the full unicharset. +// The iteration will be in groups of the same unichar_id, in the order +// defined by the charset_map. +// GetCompactClassID returns the charset_map index of a sample, and +// GetSparseClassID returns the sample UNICHAR_ID. +// +// Non-NULL shape_table: +// With a shape_table, samples are grouped according to the shape_table, so +// multiple UNICHAR_IDs and fonts may be grouped together, and everything +// works in shape_ids. +// +// Non-NULL shape_table, NULL charset_map. +// Iterations simply run over the samples in the order of shape_id. +// GetCompactClassID and GetSparseClassID both return the shape_id. +// (If you want the unichar_id or font_id, the sample still has them.) +// +// Non-NULL shape_table, non-NULL charset_map. +// When shape_table is not NULL, the charset_map indexes and subsets shapes in +// the shape_table, and iterations will be in shape_table order, not +// charset_map order. +// GetCompactClassID returns the charset_map index of a shape, and +// GetSparseClassID returns the shape_id. +// +// =====What is SampleIterator good for?===== +// Inside a classifier training module, the SampleIterator has abstracted away +// all the different modes above. +// Use the following iteration to train your classifier: +// for (it.Begin(); !it.AtEnd(); it.Next()) { +// const TrainingSample& sample = it.GetSample(); +// int class_id = it.GetCompactClassID(); +// Your classifier may or may not be dealing with a shape_table, and may be +// dealing with some subset of the character/shape set. It doesn't need to +// know and shouldn't care. It is just learning shapes with compact class ids +// in the range [0, it.CompactCharsetSize()). +class SampleIterator { + public: + SampleIterator(); + ~SampleIterator(); + + void Clear(); + + // See class comment for arguments. + void Init(const IndexMapBiDi* charset_map, + const ShapeTable* shape_table, + bool randomize, + TrainingSampleSet* sample_set); + + // Iterator functions designed for use with a simple for loop: + // for (it.Begin(); !it.AtEnd(); it.Next()) { + // const TrainingSample& sample = it.GetSample(); + // int class_id = it.GetCompactClassID(); + // ... + // } + void Begin(); + bool AtEnd() const; + const TrainingSample& GetSample() const; + TrainingSample* MutableSample() const; + // Returns the total index (from the original set of samples) of the current + // sample. + int GlobalSampleIndex() const; + // Returns the index of the current sample in compact charset space, so + // in a 2-class problem between x and y, the returned indices will all be + // 0 or 1, and have nothing to do with the unichar_ids. + // If the charset_map_ is NULL, then this is equal to GetSparseClassID(). + int GetCompactClassID() const; + // Returns the index of the current sample in sparse charset space, so + // in a 2-class problem between x and y, the returned indices will all be + // x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids + // with a shape_table_. + int GetSparseClassID() const; + // Moves on to the next indexable sample. If the end is reached, leaves + // the state such that AtEnd() is true. + void Next(); + + // Returns the size of the compact charset space. + int CompactCharsetSize() const; + // Returns the size of the sparse charset space. + int SparseCharsetSize() const; + + const IndexMapBiDi& charset_map() const { + return *charset_map_; + } + const ShapeTable* shape_table() const { + return shape_table_; + } + // Sample set operations. + const TrainingSampleSet* sample_set() const { + return sample_set_; + } + + // A set of functions that do something to all the samples accessed by the + // iterator, as it is currently setup. + + // Apply the supplied feature_space/feature_map transform to all samples + // accessed by this iterator. + void MapSampleFeatures(const IntFeatureMap& feature_map); + + // Adjust the weights of all the samples to be uniform in the given charset. + // Returns the number of samples in the iterator. + int UniformSamples(); + + // Normalize the weights of all the samples defined by the iterator so they + // sum to 1. Returns the minimum assigned sample weight. + double NormalizeSamples(); + + private: + // Helper returns the current UnicharAndFont shape_entry. + const UnicharAndFonts* GetShapeEntry() const; + + // Map to subset the actual charset space. + const IndexMapBiDi* charset_map_; + // Shape table to recombine character classes into shapes + const ShapeTable* shape_table_; + // The samples to iterate over. + TrainingSampleSet* sample_set_; + // Flag to control randomizing the sample features. + bool randomize_; + // Shape table owned by this used to iterate character classes. + ShapeTable* owned_shape_table_; + + // Top-level iteration. Shape index in sparse charset_map space. + int shape_index_; + int num_shapes_; + // Index to the character class within a shape. + int shape_char_index_; + int num_shape_chars_; + // Index to the font within a shape/class pair. + int shape_font_index_; + int num_shape_fonts_; + // The lowest level iteration. sample_index_/num_samples_ counts samples + // in the current shape/class/font combination. + int sample_index_; + int num_samples_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_CLASSIFY_SAMPLEITERATOR_H_ diff --git a/classify/shapeclassifier.h b/classify/shapeclassifier.h new file mode 100644 index 0000000000..08808127f1 --- /dev/null +++ b/classify/shapeclassifier.h @@ -0,0 +1,95 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: shapeclassifier.h +// Description: Base interface class for classifiers that return a +// shape index. +// Author: Ray Smith +// Created: Tue Sep 13 11:26:32 PDT 2011 +// +// (C) Copyright 2011, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CLASSIFY_SHAPECLASSIFIER_H_ +#define TESSERACT_CLASSIFY_SHAPECLASSIFIER_H_ + +template class GenericVector; +struct Pix; + +namespace tesseract { + +class ShapeTable; +class TrainingSample; + +// Classifier result from a low-level classification is an index into some +// ShapeTable and a rating. +struct ShapeRating { + ShapeRating() : shape_id(0), rating(0.0f), raw(0.0f), font(0.0f) {} + ShapeRating(int s, float r) + : shape_id(s), rating(r), raw(1.0f), font(0.0f) {} + + // Sort function to sort ratings appropriately by descending rating. + static int SortDescendingRating(const void* t1, const void* t2) { + const ShapeRating* a = reinterpret_cast(t1); + const ShapeRating* b = reinterpret_cast(t2); + if (a->rating > b->rating) { + return -1; + } else if (a->rating < b->rating) { + return 1; + } else { + return a->shape_id - b->shape_id; + } + } + + // Index into some shape table indicates the class of the answer. + int shape_id; + // Rating from classifier with 1.0 perfect and 0.0 impossible. + // Call it a probability if you must. + float rating; + // Subsidiary rating that a classifier may use internally. + float raw; + // Subsidiary rating that a classifier may use internally. + float font; +}; + +// Interface base class for classifiers that produce ShapeRating results. +class ShapeClassifier { + public: + virtual ~ShapeClassifier() {} + + // Classifies the given [training] sample, writing to results. + // If page_pix is not NULL, the overriding function may call + // sample.GetSamplePix(padding, page_pix) to get an image of the sample + // padded (with real image data) by the given padding to extract features + // from the image of the character. Other members of TrainingSample: + // features(), micro_features(), cn_feature(), geo_feature() may be used + // to get the appropriate tesseract features. + // If debug is non-zero, then various degrees of classifier dependent debug + // information is provided. + // If keep_this (a shape index) is >= 0, then the results should always + // contain keep_this, and (if possible) anything of intermediate confidence. + // (Used for answering "Why didn't it get that right?" questions.) + // The return value is the number of classes saved in results. + // NOTE that overriding functions MUST clear results unless the classifier + // is working with a team of such classifiers. + virtual int ClassifySample(const TrainingSample& sample, Pix* page_pix, + int debug, int keep_this, + GenericVector* results) = 0; + + // Provides access to the ShapeTable that this classifier works with. + virtual const ShapeTable* GetShapeTable() const = 0; +}; + +} // namespace tesseract. + +#endif // TESSERACT_CLASSIFY_SHAPECLASSIFIER_H_ diff --git a/classify/shapetable.cpp b/classify/shapetable.cpp new file mode 100644 index 0000000000..9d45652622 --- /dev/null +++ b/classify/shapetable.cpp @@ -0,0 +1,452 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: shapetable.cpp +// Description: Class to map a classifier shape index to unicharset +// indices and font indices. +// Author: Ray Smith +// Created: Tue Nov 02 15:31:32 PDT 2010 +// +// (C) Copyright 2010, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "shapetable.h" + +#include "intfeaturespace.h" +#include "strngs.h" +#include "unicharset.h" + +namespace tesseract { + +// Writes to the given file. Returns false in case of error. +bool UnicharAndFonts::Serialize(FILE* fp) { + inT32 uni_id = unichar_id; + if (fwrite(&uni_id, sizeof(uni_id), 1, fp) != 1) return false; + if (!font_ids.Serialize(fp)) return false; + return true; +} +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool UnicharAndFonts::DeSerialize(bool swap, FILE* fp) { + inT32 uni_id; + if (fread(&uni_id, sizeof(uni_id), 1, fp) != 1) return false; + if (swap) + ReverseN(&uni_id, sizeof(uni_id)); + unichar_id = uni_id; + if (!font_ids.DeSerialize(swap, fp)) return false; + return true; +} + +// Sort function to sort a pair of UnicharAndFonts by unichar_id. +int UnicharAndFonts::SortByUnicharId(const void* v1, const void* v2) { + const UnicharAndFonts* p1 = reinterpret_cast(v1); + const UnicharAndFonts* p2 = reinterpret_cast(v2); + return p1->unichar_id - p2->unichar_id; +} + +// Writes to the given file. Returns false in case of error. +bool Shape::Serialize(FILE* fp) { + if (fwrite(&unichars_sorted_, sizeof(unichars_sorted_), 1, fp) != 1) + return false; + if (!unichars_.SerializeClasses(fp)) return false; + return true; +} +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool Shape::DeSerialize(bool swap, FILE* fp) { + if (fread(&unichars_sorted_, sizeof(unichars_sorted_), 1, fp) != 1) + return false; + if (!unichars_.DeSerializeClasses(swap, fp)) return false; + return true; +} + +// Adds a font_id for the given unichar_id. If the unichar_id is not +// in the shape, it is added. +void Shape::AddToShape(int unichar_id, int font_id) { + for (int c = 0; c < unichars_.size(); ++c) { + if (unichars_[c].unichar_id == unichar_id) { + // Found the unichar in the shape table. + GenericVector& font_list = unichars_[c].font_ids; + for (int f = 0; f < font_list.size(); ++f) { + if (font_list[f] == font_id) + return; // Font is already there. + } + font_list.push_back(font_id); + return; + } + } + // Unichar_id is not in shape, so add it to shape. + unichars_.push_back(UnicharAndFonts(unichar_id, font_id)); + unichars_sorted_ = unichars_.size() <= 1; +} + +// Adds everything in other to this. +void Shape::AddShape(const Shape& other) { + for (int c = 0; c < other.unichars_.size(); ++c) { + for (int f = 0; f < other.unichars_[c].font_ids.size(); ++f) { + AddToShape(other.unichars_[c].unichar_id, + other.unichars_[c].font_ids[f]); + } + } + unichars_sorted_ = unichars_.size() <= 1; +} + +// Returns true if the shape contains the given unichar_id, font_id pair. +bool Shape::ContainsUnicharAndFont(int unichar_id, int font_id) const { + for (int c = 0; c < unichars_.size(); ++c) { + if (unichars_[c].unichar_id == unichar_id) { + // Found the unichar, so look for the font. + GenericVector& font_list = unichars_[c].font_ids; + for (int f = 0; f < font_list.size(); ++f) { + if (font_list[f] == font_id) + return true; + } + return false; + } + } + return false; +} + +// Returns true if the shape contains the given unichar_id, ignoring font. +bool Shape::ContainsUnichar(int unichar_id) const { + for (int c = 0; c < unichars_.size(); ++c) { + if (unichars_[c].unichar_id == unichar_id) { + return true; + } + } + return false; +} + +// Returns true if the shape contains the given font, ignoring unichar_id. +bool Shape::ContainsFont(int font_id) const { + for (int c = 0; c < unichars_.size(); ++c) { + GenericVector& font_list = unichars_[c].font_ids; + for (int f = 0; f < font_list.size(); ++f) { + if (font_list[f] == font_id) + return true; + } + } + return false; +} + +// Returns true if this is a subset (including equal) of other. +bool Shape::IsSubsetOf(const Shape& other) const { + for (int c = 0; c < unichars_.size(); ++c) { + int unichar_id = unichars_[c].unichar_id; + const GenericVector& font_list = unichars_[c].font_ids; + for (int f = 0; f < font_list.size(); ++f) { + if (!other.ContainsUnicharAndFont(unichar_id, font_list[f])) + return false; + } + } + return true; +} + +// Returns true if the lists of unichar ids are the same in this and other, +// ignoring fonts. +// NOT const, as it will sort the unichars on demand. +bool Shape::IsEqualUnichars(Shape* other) { + if (unichars_.size() != other->unichars_.size()) return false; + if (!unichars_sorted_) SortUnichars(); + if (!other->unichars_sorted_) other->SortUnichars(); + for (int c = 0; c < unichars_.size(); ++c) { + if (unichars_[c].unichar_id != other->unichars_[c].unichar_id) + return false; + } + return true; +} + +// Sorts the unichars_ vector by unichar. +void Shape::SortUnichars() { + unichars_.sort(UnicharAndFonts::SortByUnicharId); + unichars_sorted_ = true; +} + +ShapeTable::ShapeTable() : unicharset_(NULL) { +} +ShapeTable::ShapeTable(const UNICHARSET& unicharset) + : unicharset_(&unicharset) { +} + +// Writes to the given file. Returns false in case of error. +bool ShapeTable::Serialize(FILE* fp) const { + if (!shape_table_.Serialize(fp)) return false; + return true; +} +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool ShapeTable::DeSerialize(bool swap, FILE* fp) { + if (!shape_table_.DeSerialize(swap, fp)) return false; + return true; +} + +// Returns a string listing the classes/fonts in a shape. +STRING ShapeTable::DebugStr(int shape_id) const { + if (shape_id < 0 || shape_id >= shape_table_.size()) + return STRING("INVALID_UNICHAR_ID"); + const Shape& shape = GetShape(shape_id); + STRING result; + result.add_str_int("Shape", shape_id); + for (int c = 0; c < shape.size(); ++c) { + result.add_str_int(" c_id=", shape[c].unichar_id); + result += "="; + result += unicharset_->id_to_unichar(shape[c].unichar_id); + result.add_str_int(", ", shape[c].font_ids.size()); + result += " fonts ="; + for (int f = 0; f < shape[c].font_ids.size(); ++f) { + result.add_str_int(" ", shape[c].font_ids[f]); + } + } + return result; +} + +// Returns a debug string summarizing the table. +STRING ShapeTable::SummaryStr() const { + int max_unichars = 0; + int num_multi_shapes = 0; + int num_master_shapes = 0; + for (int s = 0; s < shape_table_.size(); ++s) { + if (MasterDestinationIndex(s) != s) continue; + ++num_master_shapes; + int shape_size = GetShape(s).size(); + if (shape_size > 1) + ++num_multi_shapes; + if (shape_size > max_unichars) + max_unichars = shape_size; + } + STRING result; + result.add_str_int("Number of shapes = ", num_master_shapes); + result.add_str_int(" max unichars = ", max_unichars); + result.add_str_int(" number with multiple unichars = ", num_multi_shapes); + return result; +} + + +// Adds a new shape starting with the given unichar_id and font_id. +// Returns the assigned index. +int ShapeTable::AddShape(int unichar_id, int font_id) { + int index = shape_table_.size(); + Shape* shape = new Shape; + shape->AddToShape(unichar_id, font_id); + shape_table_.push_back(shape); + return index; +} + +// Adds a copy of the given shape. +// Returns the assigned index. +int ShapeTable::AddShape(const Shape& other) { + int index = shape_table_.size(); + Shape* shape = new Shape(other); + shape_table_.push_back(shape); + return index; +} + +// Removes the shape given by the shape index. +void ShapeTable::DeleteShape(int shape_id) { + delete shape_table_[shape_id]; + shape_table_[shape_id] = NULL; + shape_table_.remove(shape_id); +} + +// Adds a font_id to the given existing shape index for the given +// unichar_id. If the unichar_id is not in the shape, it is added. +void ShapeTable::AddToShape(int shape_id, int unichar_id, int font_id) { + Shape& shape = *shape_table_[shape_id]; + shape.AddToShape(unichar_id, font_id); +} + +// Adds the given shape to the existing shape with the given index. +void ShapeTable::AddShapeToShape(int shape_id, const Shape& other) { + Shape& shape = *shape_table_[shape_id]; + shape.AddShape(other); +} + +// Returns the id of the shape that contains the given unichar and font. +// If not found, returns -1. +// If font_id < 0, the font_id is ignored and the first shape that matches +// the unichar_id is returned. +int ShapeTable::FindShape(int unichar_id, int font_id) const { + for (int s = 0; s < shape_table_.size(); ++s) { + const Shape& shape = GetShape(s); + for (int c = 0; c < shape.size(); ++c) { + if (shape[c].unichar_id == unichar_id) { + if (font_id < 0) + return s; // We don't care about the font. + for (int f = 0; f < shape[c].font_ids.size(); ++f) { + if (shape[c].font_ids[f] == font_id) + return s; + } + } + } + } + return -1; +} + +// Returns the first unichar_id and font_id in the given shape. +void ShapeTable::GetFirstUnicharAndFont(int shape_id, + int* unichar_id, int* font_id) const { + const UnicharAndFonts& unichar_and_fonts = (*shape_table_[shape_id])[0]; + *unichar_id = unichar_and_fonts.unichar_id; + *font_id = unichar_and_fonts.font_ids[0]; +} + +// Expands all the classes/fonts in the shape individually to build +// a ShapeTable. +int ShapeTable::BuildFromShape(const Shape& shape, + const ShapeTable& master_shapes) { + int num_masters = 0; + for (int u_ind = 0; u_ind < shape.size(); ++u_ind) { + for (int f_ind = 0; f_ind < shape[u_ind].font_ids.size(); ++f_ind) { + int c = shape[u_ind].unichar_id; + int f = shape[u_ind].font_ids[f_ind]; + if (FindShape(c, f) < 0) { + int shape_id = AddShape(c, f); + int master_id = master_shapes.FindShape(c, f); + if (master_id >= 0 && shape.size() > 1) { + const Shape& master = master_shapes.GetShape(master_id); + if (master.IsSubsetOf(shape) && !shape.IsSubsetOf(master)) { + // Add everything else from the master shape. + shape_table_[shape_id]->AddShape(master); + ++num_masters; + } + } + } + } + } + return num_masters; +} + +// Returns true if the shapes are already merged. +bool ShapeTable::AlreadyMerged(int shape_id1, int shape_id2) { + return MasterDestinationIndex(shape_id1) == MasterDestinationIndex(shape_id2); +} + +// Returns true if any shape contains multiple unichars. +bool ShapeTable::AnyMultipleUnichars() { + int num_shapes = NumShapes(); + for (int s1 = 0; s1 < num_shapes; ++s1) { + if (MasterDestinationIndex(s1) != s1) continue; + if (GetShape(s1).size() > 1) + return true; + } + return false; +} + +// Returns the maximum number of unichars over all shapes. +int ShapeTable::MaxNumUnichars() const { + int max_num_unichars = 0; + int num_shapes = NumShapes(); + for (int s = 0; s < num_shapes; ++s) { + if (GetShape(s).size() > max_num_unichars) + max_num_unichars = GetShape(s).size(); + } + return max_num_unichars; +} + + +// Merges shapes with a common unichar over the [start, end) interval. +// Assumes single unichar per shape. +void ShapeTable::ForceFontMerges(int start, int end) { + for (int s1 = start; s1 < end; ++s1) { + if (MasterDestinationIndex(s1) == s1 && GetShape(s1).size() == 1) { + int unichar_id = GetShape(s1)[0].unichar_id; + for (int s2 = s1 + 1; s2 < end; ++s2) { + if (MasterDestinationIndex(s2) == s2 && GetShape(s2).size() == 1 && + unichar_id == GetShape(s2)[0].unichar_id) { + MergeShapes(s1, s2); + } + } + } + } + ShapeTable compacted(*unicharset_); + compacted.AppendMasterShapes(*this); + *this = compacted; +} + +// Returns the number of unichars in the master shape. +int ShapeTable::MasterUnicharCount(int shape_id) const { + int master_id = MasterDestinationIndex(shape_id); + return GetShape(master_id).size(); +} + +// Returns the sum of the font counts in the master shape. +int ShapeTable::MasterFontCount(int shape_id) const { + int master_id = MasterDestinationIndex(shape_id); + const Shape& shape = GetShape(master_id); + int font_count = 0; + for (int c = 0; c < shape.size(); ++c) { + font_count += shape[c].font_ids.size(); + } + return font_count; +} + +// Returns the number of unichars that would result from merging the shapes. +int ShapeTable::MergedUnicharCount(int shape_id1, int shape_id2) const { + // Do it the easy way for now. + int master_id1 = MasterDestinationIndex(shape_id1); + int master_id2 = MasterDestinationIndex(shape_id2); + Shape combined_shape(*shape_table_[master_id1]); + combined_shape.AddShape(*shape_table_[master_id2]); + return combined_shape.size(); +} + +// Merges two shape_ids, leaving shape_id2 marked as merged. +void ShapeTable::MergeShapes(int shape_id1, int shape_id2) { + int master_id1 = MasterDestinationIndex(shape_id1); + int master_id2 = MasterDestinationIndex(shape_id2); + // Point master_id2 (and all merged shapes) to master_id1. + shape_table_[master_id2]->set_destination_index(master_id1); + // Add all the shapes of master_id2 to master_id1. + shape_table_[master_id1]->AddShape(*shape_table_[master_id2]); + tprintf("Merged shape %d->%d, %d->%d, now with %d unichars: %s\n", + shape_id1, master_id1, shape_id2, master_id2, + shape_table_[master_id1]->size(), + DebugStr(master_id1).string()); +} + +// Returns the destination of this shape, (if merged), taking into account +// the fact that the destination may itself have been merged. +int ShapeTable::MasterDestinationIndex(int shape_id) const { + int dest_id = shape_table_[shape_id]->destination_index(); + if (dest_id == shape_id || dest_id < 0) + return shape_id; // Is master already. + int master_id = shape_table_[dest_id]->destination_index(); + if (master_id == dest_id || master_id < 0) + return dest_id; // Dest is the master and shape_id points to it. + master_id = MasterDestinationIndex(master_id); + return master_id; +} + +// Appends the master shapes from other to this. +void ShapeTable::AppendMasterShapes(const ShapeTable& other) { + for (int s = 0; s < other.shape_table_.size(); ++s) { + if (other.shape_table_[s]->destination_index() < 0) { + AddShape(*other.shape_table_[s]); + } + } +} + +// Returns the number of master shapes remaining after merging. +int ShapeTable::NumMasterShapes() const { + int num_shapes = 0; + for (int s = 0; s < shape_table_.size(); ++s) { + if (shape_table_[s]->destination_index() < 0) + ++num_shapes; + } + return num_shapes; +} + + +} // namespace tesseract + + diff --git a/classify/shapetable.h b/classify/shapetable.h new file mode 100644 index 0000000000..375f334eb9 --- /dev/null +++ b/classify/shapetable.h @@ -0,0 +1,227 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: shapetable.h +// Description: Class to map a classifier shape index to unicharset +// indices and font indices. +// Author: Ray Smith +// Created: Thu Oct 28 17:46:32 PDT 2010 +// +// (C) Copyright 2010, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_CLASSIFY_SHAPETABLE_H_ +#define TESSERACT_CLASSIFY_SHAPETABLE_H_ + +#include "genericvector.h" +#include "intmatcher.h" + +class STRING; +class UNICHARSET; + +namespace tesseract { + +// Simple struct to hold a set of fonts associated with a single unichar-id. +// A vector of UnicharAndFonts makes a shape. +struct UnicharAndFonts { + UnicharAndFonts() : unichar_id(0) { + } + UnicharAndFonts(int uni_id, int font_id) : unichar_id(uni_id) { + font_ids.push_back(font_id); + } + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp); + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Sort function to sort a pair of UnicharAndFonts by unichar_id. + static int SortByUnicharId(const void* v1, const void* v2); + + GenericVector font_ids; + int unichar_id; +}; + +// A Shape is a collection of unichar-ids and a list of fonts associated with +// each, organized as a vector of UnicharAndFonts. Conceptually a Shape is +// a classifiable unit, and represents a group of characters or parts of +// characters that have a similar or identical shape. Shapes/ShapeTables may +// be organized hierarchically from identical shapes at the leaves to vaguely +// similar shapes near the root. +class Shape { + public: + Shape() : destination_index_(-1) {} + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp); + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + int destination_index() const { + return destination_index_; + } + void set_destination_index(int index) { + destination_index_ = index; + } + int size() const { + return unichars_.size(); + } + // Returns a UnicharAndFonts entry for the given index, which must be + // in the range [0, size()). + const UnicharAndFonts& operator[](int index) const { + return unichars_[index]; + } + // Adds a font_id for the given unichar_id. If the unichar_id is not + // in the shape, it is added. + void AddToShape(int unichar_id, int font_id); + // Adds everything in other to this. + void AddShape(const Shape& other); + // Returns true if the shape contains the given unichar_id, font_id pair. + bool ContainsUnicharAndFont(int unichar_id, int font_id) const; + // Returns true if the shape contains the given unichar_id, ignoring font. + bool ContainsUnichar(int unichar_id) const; + // Returns true if the shape contains the given font, ignoring unichar_id. + bool ContainsFont(int font_id) const; + // Returns true if this is a subset (including equal) of other. + bool IsSubsetOf(const Shape& other) const; + // Returns true if the lists of unichar ids are the same in this and other, + // ignoring fonts. + // NOT const, as it will sort the unichars on demand. + bool IsEqualUnichars(Shape* other); + + private: + // Sorts the unichars_ vector by unichar. + void SortUnichars(); + + // Flag indicates that the unichars are sorted, allowing faster set + // operations with another shape. + bool unichars_sorted_; + // If this Shape is part of a ShapeTable the destiation_index_ is the index + // of some other shape in the ShapeTable with which this shape is merged. + int destination_index_; + // Array of unichars, each with a set of fonts. Each unichar has at most + // one entry in the vector. + GenericVector unichars_; +}; + +// ShapeTable is a class to encapsulate the triple indirection that is +// used here. +// ShapeTable is a vector of shapes. +// Each shape is a vector of UnicharAndFonts representing the set of unichars +// that the shape represents. +// Each UnicharAndFonts also lists the fonts of the unichar_id that were +// mapped to the shape during training. +class ShapeTable { + public: + ShapeTable(); + // The UNICHARSET reference supplied here, or in set_unicharset below must + // exist for the entire life of the ShapeTable. It is used only by DebugStr. + explicit ShapeTable(const UNICHARSET& unicharset); + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Accessors. + int NumShapes() const { + return shape_table_.size(); + } + const UNICHARSET& unicharset() const { + return *unicharset_; + } + // Shapetable takes a pointer to the UNICHARSET, so it must persist for the + // entire life of the ShapeTable. + void set_unicharset(const UNICHARSET& unicharset) { + unicharset_ = &unicharset; + } + // Returns a string listing the classes/fonts in a shape. + STRING DebugStr(int shape_id) const; + // Returns a debug string summarizing the table. + STRING SummaryStr() const; + + // Adds a new shape starting with the given unichar_id and font_id. + // Returns the assigned index. + int AddShape(int unichar_id, int font_id); + // Adds a copy of the given shape. + // Returns the assigned index. + int AddShape(const Shape& other); + // Removes the shape given by the shape index. All indices above are changed! + void DeleteShape(int shape_id); + // Adds a font_id to the given existing shape index for the given + // unichar_id. If the unichar_id is not in the shape, it is added. + void AddToShape(int shape_id, int unichar_id, int font_id); + // Adds the given shape to the existing shape with the given index. + void AddShapeToShape(int shape_id, const Shape& other); + // Returns the id of the shape that contains the given unichar and font. + // If not found, returns -1. + // If font_id < 0, the font_id is ignored and the first shape that matches + // the unichar_id is returned. + int FindShape(int unichar_id, int font_id) const; + // Returns the first unichar_id and font_id in the given shape. + void GetFirstUnicharAndFont(int shape_id, + int* unichar_id, int* font_id) const; + + // Accessors for the Shape with the given shape_id. + const Shape& GetShape(int shape_id) const { + return *shape_table_[shape_id]; + } + Shape* MutableShape(int shape_id) { + return shape_table_[shape_id]; + } + + // Expands all the classes/fonts in the shape individually to build + // a ShapeTable. + int BuildFromShape(const Shape& shape, const ShapeTable& master_shapes); + + // Returns true if the shapes are already merged. + bool AlreadyMerged(int shape_id1, int shape_id2); + // Returns true if any shape contains multiple unichars. + bool AnyMultipleUnichars(); + // Returns the maximum number of unichars over all shapes. + int MaxNumUnichars() const; + // Merges shapes with a common unichar over the [start, end) interval. + // Assumes single unichar per shape. + void ForceFontMerges(int start, int end); + // Returns the number of unichars in the master shape. + int MasterUnicharCount(int shape_id) const; + // Returns the sum of the font counts in the master shape. + int MasterFontCount(int shape_id) const; + // Returns the number of unichars that would result from merging the shapes. + int MergedUnicharCount(int shape_id1, int shape_id2) const; + // Merges two shape_ids, leaving shape_id2 marked as merged. + void MergeShapes(int shape_id1, int shape_id2); + // Appends the master shapes from other to this. + // Used to create a clean ShapeTable from a merged one, or to create a + // copy of a ShapeTable. + void AppendMasterShapes(const ShapeTable& other); + // Returns the number of master shapes remaining after merging. + int NumMasterShapes() const; + // Returns the destination of this shape, (if merged), taking into account + // the fact that the destination may itself have been merged. + // For a non-merged shape, returns the input shape_id. + int MasterDestinationIndex(int shape_id) const; + + private: + // Pointer to a provided unicharset used only by the Debugstr member. + const UNICHARSET* unicharset_; + // Vector of pointers to the Shapes in this ShapeTable. + PointerVector shape_table_; +}; + +} // namespace tesseract. + +#endif // TESSERACT_CLASSIFY_SHAPETABLE_H_ diff --git a/classify/speckle.cpp b/classify/speckle.cpp index 0d35d689bc..4284440253 100644 --- a/classify/speckle.cpp +++ b/classify/speckle.cpp @@ -70,7 +70,7 @@ void AddLargeSpeckleTo(BLOB_CHOICE_LIST *Choices) { if (Choices->length() == 0) { blob_choice = new BLOB_CHOICE(0, speckle_small_certainty + speckle_large_penalty, - speckle_small_certainty, -1, -1, NULL); + speckle_small_certainty, -1, -1, NULL, 0, 0, false); temp_it.add_to_end(blob_choice); return; } @@ -81,7 +81,7 @@ void AddLargeSpeckleTo(BLOB_CHOICE_LIST *Choices) { blob_choice = temp_it.data(); // pick the worst choice temp_it.add_to_end( new BLOB_CHOICE(0, blob_choice->rating() + speckle_large_penalty, - blob_choice->certainty(), -1, -1, NULL)); + blob_choice->certainty(), -1, -1, NULL, 0, 0, false)); } /* AddLargeSpeckleTo */ @@ -100,18 +100,8 @@ void AddLargeSpeckleTo(BLOB_CHOICE_LIST *Choices) { * * @return TRUE if Blob is speckle, FALSE otherwise. */ -BOOL8 LargeSpeckle(TBLOB *Blob) { - double speckle_size; - TPOINT TopLeft; - TPOINT BottomRight; - - speckle_size = BASELINE_SCALE * speckle_large_max_size; - blob_bounding_box(Blob, &TopLeft, &BottomRight); - - if (TopLeft.y - BottomRight.y < speckle_size && - BottomRight.x - TopLeft.x < speckle_size) - return (TRUE); - else - return (FALSE); - +BOOL8 LargeSpeckle(TBLOB *blob) { + double speckle_size = BASELINE_SCALE * speckle_large_max_size; + TBOX bbox = blob->bounding_box(); + return (bbox.width() < speckle_size && bbox.height() < speckle_size); } /* LargeSpeckle */ diff --git a/classify/tessclassifier.cpp b/classify/tessclassifier.cpp new file mode 100644 index 0000000000..f7735d8a27 --- /dev/null +++ b/classify/tessclassifier.cpp @@ -0,0 +1,52 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: tessclassifier.cpp +// Description: Tesseract implementation of a ShapeClassifier. +// Author: Ray Smith +// Created: Tue Nov 22 14:16:25 PST 2011 +// +// (C) Copyright 2011, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "tessclassifier.h" + +#include "classify.h" +#include "trainingsample.h" + +namespace tesseract { + +// Classifies the given [training] sample, writing to results. +// See ShapeClassifier for a full description. +int TessClassifier::ClassifySample(const TrainingSample& sample, + Pix* page_pix, int debug, int keep_this, + GenericVector* results) { + if (debug) { + classify_->matcher_debug_level.set_value(debug ? 2 : 0); + classify_->matcher_debug_flags.set_value(debug ? 25 : 0); + classify_->classify_debug_level.set_value(debug ? 3 : 0); + } else { + classify_->classify_debug_level.set_value(debug ? 2 : 0); + } + classify_->CharNormTrainingSample(pruner_only_, sample, results); + return results->size(); +} + +// Provides access to the ShapeTable that this classifier works with. +const ShapeTable* TessClassifier::GetShapeTable() const { + return classify_->shape_table(); +} + +} // namespace tesseract + + diff --git a/classify/tessclassifier.h b/classify/tessclassifier.h new file mode 100644 index 0000000000..f2483b7a06 --- /dev/null +++ b/classify/tessclassifier.h @@ -0,0 +1,65 @@ +// Copyright 2011 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +/////////////////////////////////////////////////////////////////////// +// File: tessclassifier.h +// Description: Tesseract implementation of a ShapeClassifier. +// Author: Ray Smith +// Created: Tue Nov 22 14:10:45 PST 2011 +// +// (C) Copyright 2011, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef THIRD_PARTY_TESSERACT_CLASSIFY_TESSCLASSIFIER_H_ +#define THIRD_PARTY_TESSERACT_CLASSIFY_TESSCLASSIFIER_H_ + +#include "shapeclassifier.h" + +namespace tesseract { + +class Classify; +class TrainingSample; + +// Tesseract implementation of a ShapeClassifier. +// Due to limitations in the content of TrainingSample, this currently +// only works for the static classifier and only works if the ShapeTable +// in classify is not NULL. +class TessClassifier : public ShapeClassifier { + public: + TessClassifier(bool pruner_only, tesseract::Classify* classify) + : pruner_only_(pruner_only), classify_(classify) {} + virtual ~TessClassifier() {} + + // Classifies the given [training] sample, writing to results. + // See ShapeClassifier for a full description. + virtual int ClassifySample(const TrainingSample& sample, Pix* page_pix, + int debug, int keep_this, + GenericVector* results); + // Provides access to the ShapeTable that this classifier works with. + virtual const ShapeTable* GetShapeTable() const; + + private: + // Indicates that this classifier is to use just the ClassPruner, or the + // full classifier if false. + bool pruner_only_; + // Borrowed pointer to the actual Tesseract classifier. + tesseract::Classify* classify_; +}; + + +} // namespace tesseract + + + + + +#endif /* THIRD_PARTY_TESSERACT_CLASSIFY_TESSCLASSIFIER_H_ */ diff --git a/classify/trainingsample.cpp b/classify/trainingsample.cpp new file mode 100644 index 0000000000..d712d8c261 --- /dev/null +++ b/classify/trainingsample.cpp @@ -0,0 +1,311 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "trainingsample.h" + +#include +#include "allheaders.h" +#include "helpers.h" +#include "intfeaturemap.h" +#include "normfeat.h" +#include "shapetable.h" + +namespace tesseract { + +ELISTIZE(TrainingSample) + +// Center of randomizing operations. +const int kRandomizingCenter = 128; + +// Randomizing factors. +const int TrainingSample::kYShiftValues[kSampleYShiftSize] = { + 6, 3, -3, -6, 0 +}; +const double TrainingSample::kScaleValues[kSampleScaleSize] = { + 1.0625, 0.9375, 1.0 +}; + +TrainingSample::~TrainingSample() { + delete [] features_; + delete [] micro_features_; +} + +// WARNING! Serialize/DeSerialize do not save/restore the "cache" data +// members, which is mostly the mapped features, and the weight. +// It is assumed these can all be reconstructed from what is saved. +// Writes to the given file. Returns false in case of error. +bool TrainingSample::Serialize(FILE* fp) const { + if (fwrite(&class_id_, sizeof(class_id_), 1, fp) != 1) return false; + if (fwrite(&font_id_, sizeof(font_id_), 1, fp) != 1) return false; + if (fwrite(&page_num_, sizeof(page_num_), 1, fp) != 1) return false; + if (!bounding_box_.Serialize(fp)) return false; + if (fwrite(&num_features_, sizeof(num_features_), 1, fp) != 1) return false; + if (fwrite(&num_micro_features_, sizeof(num_micro_features_), 1, fp) != 1) + return false; + if (fwrite(features_, sizeof(*features_), num_features_, fp) != num_features_) + return false; + if (fwrite(micro_features_, sizeof(*micro_features_), num_micro_features_, + fp) != num_micro_features_) + return false; + if (fwrite(cn_feature_, sizeof(*cn_feature_), kNumCNParams, fp) != + kNumCNParams) return false; + if (fwrite(geo_feature_, sizeof(*geo_feature_), GeoCount, fp) != GeoCount) + return false; + return true; +} + +// Creates from the given file. Returns NULL in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +TrainingSample* TrainingSample::DeSerializeCreate(bool swap, FILE* fp) { + TrainingSample* sample = new TrainingSample; + if (sample->DeSerialize(swap, fp)) return sample; + delete sample; + return NULL; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool TrainingSample::DeSerialize(bool swap, FILE* fp) { + if (fread(&class_id_, sizeof(class_id_), 1, fp) != 1) return false; + if (fread(&font_id_, sizeof(font_id_), 1, fp) != 1) return false; + if (fread(&page_num_, sizeof(page_num_), 1, fp) != 1) return false; + if (!bounding_box_.DeSerialize(swap, fp)) return false; + if (fread(&num_features_, sizeof(num_features_), 1, fp) != 1) return false; + if (fread(&num_micro_features_, sizeof(num_micro_features_), 1, fp) != 1) + return false; + if (swap) { + ReverseN(&class_id_, sizeof(class_id_)); + ReverseN(&num_features_, sizeof(num_features_)); + ReverseN(&num_micro_features_, sizeof(num_micro_features_)); + } + delete [] features_; + features_ = new INT_FEATURE_STRUCT[num_features_]; + if (fread(features_, sizeof(*features_), num_features_, fp) != num_features_) + return false; + delete [] micro_features_; + micro_features_ = new MicroFeature[num_micro_features_]; + if (fread(micro_features_, sizeof(*micro_features_), num_micro_features_, + fp) != num_micro_features_) + return false; + if (fread(cn_feature_, sizeof(*cn_feature_), kNumCNParams, fp) != + kNumCNParams) return false; + if (fread(geo_feature_, sizeof(*geo_feature_), GeoCount, fp) != GeoCount) + return false; + return true; +} + +// Saves the given features into a TrainingSample. +TrainingSample* TrainingSample::CopyFromFeatures( + const INT_FX_RESULT_STRUCT& fx_info, const INT_FEATURE_STRUCT* features, + int num_features) { + TrainingSample* sample = new TrainingSample; + sample->num_features_ = num_features; + sample->features_ = new INT_FEATURE_STRUCT[num_features]; + memcpy(sample->features_, features, num_features * sizeof(features[0])); + sample->geo_feature_[GeoBottom] = fx_info.YBottom; + sample->geo_feature_[GeoTop] = fx_info.YTop; + sample->geo_feature_[GeoWidth] = fx_info.Width; + sample->features_are_indexed_ = false; + sample->features_are_mapped_ = false; + return sample; +} + +// Constructs and returns a copy randomized by the method given by +// the randomizer index. If index is out of [0, kSampleRandomSize) then +// an exact copy is returned. +TrainingSample* TrainingSample::RandomizedCopy(int index) const { + TrainingSample* sample = Copy(); + if (index >= 0 && index < kSampleRandomSize) { + ++index; // Remove the first combination. + int yshift = kYShiftValues[index / kSampleScaleSize]; + double scaling = kScaleValues[index % kSampleScaleSize]; + for (int i = 0; i < num_features_; ++i) { + double result = (features_[i].X - kRandomizingCenter) * scaling; + result += kRandomizingCenter; + sample->features_[i].X = ClipToRange(static_cast(result + 0.5), 0, + MAX_UINT8); + result = (features_[i].Y - kRandomizingCenter) * scaling; + result += kRandomizingCenter + yshift; + sample->features_[i].Y = ClipToRange(static_cast(result + 0.5), 0, + MAX_UINT8); + } + } + return sample; +} + +// Constructs and returns an exact copy. +TrainingSample* TrainingSample::Copy() const { + TrainingSample* sample = new TrainingSample; + sample->class_id_ = class_id_; + sample->font_id_ = font_id_; + sample->weight_ = weight_; + sample->sample_index_ = sample_index_; + sample->num_features_ = num_features_; + if (num_features_ > 0) { + sample->features_ = new INT_FEATURE_STRUCT[num_features_]; + memcpy(sample->features_, features_, num_features_ * sizeof(features_[0])); + } + sample->num_micro_features_ = num_micro_features_; + if (num_micro_features_ > 0) { + sample->micro_features_ = new MicroFeature[num_micro_features_]; + memcpy(sample->micro_features_, micro_features_, + num_micro_features_ * sizeof(micro_features_[0])); + } + memcpy(sample->cn_feature_, cn_feature_, sizeof(*cn_feature_) * kNumCNParams); + memcpy(sample->geo_feature_, geo_feature_, sizeof(*geo_feature_) * GeoCount); + return sample; +} + +// Extracts the needed information from the CHAR_DESC_STRUCT. +void TrainingSample::ExtractCharDesc(int int_feature_type, + int micro_type, + int cn_type, + int geo_type, + CHAR_DESC_STRUCT* char_desc) { + // Extract the INT features. + if (features_ != NULL) delete [] features_; + FEATURE_SET_STRUCT* char_features = char_desc->FeatureSets[int_feature_type]; + if (char_features == NULL) { + tprintf("Error: no features to train on of type %s\n", + kIntFeatureType); + num_features_ = 0; + features_ = NULL; + } else { + num_features_ = char_features->NumFeatures; + features_ = new INT_FEATURE_STRUCT[num_features_]; + for (int f = 0; f < num_features_; ++f) { + features_[f].X = + static_cast(char_features->Features[f]->Params[IntX]); + features_[f].Y = + static_cast(char_features->Features[f]->Params[IntY]); + features_[f].Theta = + static_cast(char_features->Features[f]->Params[IntDir]); + features_[f].CP_misses = 0; + } + } + // Extract the Micro features. + if (micro_features_ != NULL) delete [] micro_features_; + char_features = char_desc->FeatureSets[micro_type]; + if (char_features == NULL) { + tprintf("Error: no features to train on of type %s\n", + kMicroFeatureType); + num_micro_features_ = 0; + micro_features_ = NULL; + } else { + num_micro_features_ = char_features->NumFeatures; + micro_features_ = new MicroFeature[num_micro_features_]; + for (int f = 0; f < num_micro_features_; ++f) { + for (int d = 0; d < MFCount; ++d) { + micro_features_[f][d] = char_features->Features[f]->Params[d]; + } + } + } + // Extract the CN feature. + char_features = char_desc->FeatureSets[cn_type]; + if (char_features == NULL) { + tprintf("Error: no CN feature to train on.\n"); + } else { + ASSERT_HOST(char_features->NumFeatures == 1); + cn_feature_[CharNormY] = char_features->Features[0]->Params[CharNormY]; + cn_feature_[CharNormLength] = + char_features->Features[0]->Params[CharNormLength]; + cn_feature_[CharNormRx] = char_features->Features[0]->Params[CharNormRx]; + cn_feature_[CharNormRy] = char_features->Features[0]->Params[CharNormRy]; + } + // Extract the Geo feature. + char_features = char_desc->FeatureSets[geo_type]; + if (char_features == NULL) { + tprintf("Error: no Geo feature to train on.\n"); + } else { + ASSERT_HOST(char_features->NumFeatures == 1); + geo_feature_[GeoBottom] = char_features->Features[0]->Params[GeoBottom]; + geo_feature_[GeoTop] = char_features->Features[0]->Params[GeoTop]; + geo_feature_[GeoWidth] = char_features->Features[0]->Params[GeoWidth]; + } + features_are_indexed_ = false; + features_are_mapped_ = false; +} + +// Sets the mapped_features_ from the features_ using the provided +// feature_space to the indexed versions of the features. +void TrainingSample::IndexFeatures(const IntFeatureSpace& feature_space) { + GenericVector indexed_features; + feature_space.IndexAndSortFeatures(features_, num_features_, + &mapped_features_); + features_are_indexed_ = true; + features_are_mapped_ = false; +} + +// Sets the mapped_features_ from the features using the provided +// feature_map. +void TrainingSample::MapFeatures(const IntFeatureMap& feature_map) { + GenericVector indexed_features; + feature_map.feature_space().IndexAndSortFeatures(features_, num_features_, + &indexed_features); + feature_map.MapIndexedFeatures(indexed_features, &mapped_features_); + features_are_indexed_ = false; + features_are_mapped_ = true; +} + +// Returns a pix representing the sample. (Int features only.) +Pix* TrainingSample::RenderToPix(const UNICHARSET* unicharset) const { + Pix* pix = pixCreate(kIntFeatureExtent, kIntFeatureExtent, 1); + for (int f = 0; f < num_features_; ++f) { + int start_x = features_[f].X; + int start_y = kIntFeatureExtent - features_[f].Y; + double dx = cos((features_[f].Theta / 256.0) * 2.0 * PI - PI); + double dy = -sin((features_[f].Theta / 256.0) * 2.0 * PI - PI); + for (int i = 0; i <= 5; ++i) { + int x = static_cast(start_x + dx * i); + int y = static_cast(start_y + dy * i); + if (x >= 0 && x < 256 && y >= 0 && y < 256) + pixSetPixel(pix, x, y, 1); + } + } + if (unicharset != NULL) + pixSetText(pix, unicharset->id_to_unichar(class_id_)); + return pix; +} + +// Displays the features in the given window with the given color. +void TrainingSample::DisplayFeatures(ScrollView::Color color, + ScrollView* window) const { + for (int f = 0; f < num_features_; ++f) { + RenderIntFeature(window, &features_[f], color); + } +} + +// Returns a pix of the original sample image. The pix is padded all round +// by padding wherever possible. +// The returned Pix must be pixDestroyed after use. +// If the input page_pix is NULL, NULL is returned. +Pix* TrainingSample::GetSamplePix(int padding, Pix* page_pix) const { + if (page_pix == NULL) + return NULL; + int page_width = pixGetWidth(page_pix); + int page_height = pixGetHeight(page_pix); + TBOX padded_box = bounding_box(); + padded_box.pad(padding, padding); + // Clip the padded_box to the limits of the page + TBOX page_box(0, 0, page_width, page_height); + padded_box &= page_box; + Box* box = boxCreate(page_box.left(), page_height - page_box.top(), + page_box.width(), page_box.height()); + Pix* sample_pix = pixClipRectangle(page_pix, box, NULL); + boxDestroy(&box); + return sample_pix; +} + +} // namespace tesseract diff --git a/classify/trainingsample.h b/classify/trainingsample.h new file mode 100644 index 0000000000..821bbed998 --- /dev/null +++ b/classify/trainingsample.h @@ -0,0 +1,240 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_TRAINING_TRAININGSAMPLE_H__ +#define TESSERACT_TRAINING_TRAININGSAMPLE_H__ + +#include "elst.h" +#include "featdefs.h" +#include "intfx.h" +#include "intmatcher.h" +#include "matrix.h" +#include "mf.h" +#include "picofeat.h" +#include "shapetable.h" +#include "unicharset.h" + +struct Pix; + +namespace tesseract { + +class IntFeatureMap; +class IntFeatureSpace; +class ShapeTable; + +// Number of elements of cn_feature_. +static const int kNumCNParams = 4; +// Number of ways to shift the features when randomizing. +static const int kSampleYShiftSize = 5; +// Number of ways to scale the features when randomizing. +static const int kSampleScaleSize = 3; +// Total number of different ways to manipulate the features when randomizing. +// The first and last combinations are removed to avoid an excessive +// top movement (first) and an identity transformation (last). +// WARNING: To avoid patterned duplication of samples, be sure to keep +// kSampleRandomSize prime! +// Eg with current values (kSampleYShiftSize = 5 and TkSampleScaleSize = 3) +// kSampleRandomSize is 13, which is prime. +static const int kSampleRandomSize = kSampleYShiftSize * kSampleScaleSize - 2; +// ASSERT_IS_PRIME(kSampleRandomSize) !! + +class TrainingSample : public ELIST_LINK { + public: + TrainingSample() + : class_id_(INVALID_UNICHAR_ID), font_id_(0), page_num_(0), + num_features_(0), num_micro_features_(0), + features_(NULL), micro_features_(NULL), weight_(1.0), + max_dist_(0.0), sample_index_(0), + features_are_indexed_(false), features_are_mapped_(false), + is_error_(false) { + } + ~TrainingSample(); + + // Saves the given features into a TrainingSample. The features are copied, + // so may be deleted afterwards. Delete the return value after use. + static TrainingSample* CopyFromFeatures(const INT_FX_RESULT_STRUCT& fx_info, + const INT_FEATURE_STRUCT* features, + int num_features); + // Constructs and returns a copy "randomized" by the method given by + // the randomizer index. If index is out of [0, kSampleRandomSize) then + // an exact copy is returned. + TrainingSample* RandomizedCopy(int index) const; + // Constructs and returns an exact copy. + TrainingSample* Copy() const; + + // WARNING! Serialize/DeSerialize do not save/restore the "cache" data + // members, which is mostly the mapped features, and the weight. + // It is assumed these can all be reconstructed from what is saved. + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp) const; + // Creates from the given file. Returns NULL in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + static TrainingSample* DeSerializeCreate(bool swap, FILE* fp); + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Extracts the needed information from the CHAR_DESC_STRUCT. + void ExtractCharDesc(int feature_type, int micro_type, + int cn_type, int geo_type, + CHAR_DESC_STRUCT* char_desc); + + // Sets the mapped_features_ from the features_ using the provided + // feature_space to the indexed versions of the features. + void IndexFeatures(const IntFeatureSpace& feature_space); + // Sets the mapped_features_ from the features_ using the provided + // feature_map. + void MapFeatures(const IntFeatureMap& feature_map); + + // Returns a pix representing the sample. (Int features only.) + Pix* RenderToPix(const UNICHARSET* unicharset) const; + // Displays the features in the given window with the given color. + void DisplayFeatures(ScrollView::Color color, ScrollView* window) const; + + // Returns a pix of the original sample image. The pix is padded all round + // by padding wherever possible. + // The returned Pix must be pixDestroyed after use. + // If the input page_pix is NULL, NULL is returned. + Pix* GetSamplePix(int padding, Pix* page_pix) const; + + // Accessors. + UNICHAR_ID class_id() const { + return class_id_; + } + void set_class_id(int id) { + class_id_ = id; + } + int font_id() const { + return font_id_; + } + void set_font_id(int id) { + font_id_ = id; + } + int page_num() const { + return page_num_; + } + void set_page_num(int page) { + page_num_ = page; + } + const TBOX& bounding_box() const { + return bounding_box_; + } + void set_bounding_box(const TBOX& box) { + bounding_box_ = box; + } + int num_features() const { + return num_features_; + } + const INT_FEATURE_STRUCT* features() const { + return features_; + } + int num_micro_features() const { + return num_micro_features_; + } + const MicroFeature* micro_features() const { + return micro_features_; + } + float cn_feature(int index) const { + return cn_feature_[index]; + } + int geo_feature(int index) const { + return geo_feature_[index]; + } + double weight() const { + return weight_; + } + void set_weight(double value) { + weight_ = value; + } + double max_dist() const { + return max_dist_; + } + void set_max_dist(double value) { + max_dist_ = value; + } + int sample_index() const { + return sample_index_; + } + void set_sample_index(int value) { + sample_index_ = value; + } + bool features_are_mapped() const { + return features_are_mapped_; + } + const GenericVector& mapped_features() const { + ASSERT_HOST(features_are_mapped_); + return mapped_features_; + } + const GenericVector& indexed_features() const { + ASSERT_HOST(features_are_indexed_); + return mapped_features_; + } + bool is_error() const { + return is_error_; + } + void set_is_error(bool value) { + is_error_ = value; + } + + private: + // Unichar id that this sample represents. There obviously must be a + // reference UNICHARSET somewhere. Usually in TrainingSampleSet. + UNICHAR_ID class_id_; + // Font id in which this sample was printed. Refers to a fontinfo_table_ in + // MasterTrainer. + int font_id_; + // Number of page that the sample came from. + int page_num_; + // Bounding box of sample in original image. + TBOX bounding_box_; + // Number of INT_FEATURE_STRUCT in features_ array. + int num_features_; + // Number of MicroFeature in micro_features_ array. + int num_micro_features_; + // Array of features. + INT_FEATURE_STRUCT* features_; + // Array of features. + MicroFeature* micro_features_; + // The one and only CN feature. Indexed by NORM_PARAM_NAME enum. + float cn_feature_[kNumCNParams]; + // The one and only geometric feature. (Aims at replacing cn_feature_). + // Indexed by GeoParams enum in picofeat.h + int geo_feature_[GeoCount]; + + // Non-serialized cache data. + // Weight used for boosting training. + double weight_; + // Maximum distance to other samples of same class/font used in computing + // the canonical sample. + double max_dist_; + // Global index of this sample. + int sample_index_; + // Indexed/mapped features, as indicated by the bools below. + GenericVector mapped_features_; + bool features_are_indexed_; + bool features_are_mapped_; + // True if the last classification was an error by the current definition. + bool is_error_; + + // Randomizing factors. + static const int kYShiftValues[kSampleYShiftSize]; + static const double kScaleValues[kSampleScaleSize]; +}; + +ELISTIZEH(TrainingSample) + +} // namespace tesseract + +#endif // TESSERACT_TRAINING_TRAININGSAMPLE_H__ diff --git a/classify/trainingsampleset.cpp b/classify/trainingsampleset.cpp new file mode 100644 index 0000000000..2e7f77da4d --- /dev/null +++ b/classify/trainingsampleset.cpp @@ -0,0 +1,870 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#include "trainingsampleset.h" +#include "allheaders.h" +#include "boxread.h" +#include "fontinfo.h" +#include "indexmapbidi.h" +#include "intfeaturedist.h" +#include "intfeaturemap.h" +#include "intfeaturespace.h" +#include "shapetable.h" +#include "trainingsample.h" +#include "unicity_table.h" + +namespace tesseract { + +const int kTestChar = -1; // 37; +// Max number of distances to compute the squared way +const int kSquareLimit = 25; +// Prime numbers for subsampling distances. +const int kPrime1 = 17; +const int kPrime2 = 13; +// Min samples from which to start discarding outliers. +const int kMinOutlierSamples = 5; + +TrainingSampleSet::FontClassInfo::FontClassInfo() + : num_raw_samples(0), canonical_sample(-1), canonical_dist(0.0f) { +} + +// Writes to the given file. Returns false in case of error. +bool TrainingSampleSet::FontClassInfo::Serialize(FILE* fp) const { + if (fwrite(&num_raw_samples, sizeof(num_raw_samples), 1, fp) != 1) + return false; + if (fwrite(&canonical_sample, sizeof(canonical_sample), 1, fp) != 1) + return false; + if (fwrite(&canonical_dist, sizeof(canonical_dist), 1, fp) != 1) return false; + if (!samples.Serialize(fp)) return false; + return true; +} +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool TrainingSampleSet::FontClassInfo::DeSerialize(bool swap, FILE* fp) { + if (fread(&num_raw_samples, sizeof(num_raw_samples), 1, fp) != 1) + return false; + if (fread(&canonical_sample, sizeof(canonical_sample), 1, fp) != 1) + return false; + if (fread(&canonical_dist, sizeof(canonical_dist), 1, fp) != 1) return false; + if (!samples.DeSerialize(swap, fp)) return false; + if (swap) { + ReverseN(&num_raw_samples, sizeof(num_raw_samples)); + ReverseN(&canonical_sample, sizeof(canonical_sample)); + ReverseN(&canonical_dist, sizeof(canonical_dist)); + } + return true; +} + +TrainingSampleSet::TrainingSampleSet(const UnicityTable& font_table) + : num_raw_samples_(0), unicharset_size_(0), + font_class_array_(NULL), fontinfo_table_(font_table) { +} + +TrainingSampleSet::~TrainingSampleSet() { + delete font_class_array_; +} + +// Writes to the given file. Returns false in case of error. +bool TrainingSampleSet::Serialize(FILE* fp) const { + if (!samples_.Serialize(fp)) return false; + if (!unicharset_.save_to_file(fp)) return false; + if (!font_id_map_.Serialize(fp)) return false; + inT8 not_null = font_class_array_ != NULL; + if (fwrite(¬_null, sizeof(not_null), 1, fp) != 1) return false; + if (not_null) { + if (!font_class_array_->SerializeClasses(fp)) return false; + } + return true; +} + +// Reads from the given file. Returns false in case of error. +// If swap is true, assumes a big/little-endian swap is needed. +bool TrainingSampleSet::DeSerialize(bool swap, FILE* fp) { + if (!samples_.DeSerialize(swap, fp)) return false; + num_raw_samples_ = samples_.size(); + if (!unicharset_.load_from_file(fp)) return false; + if (!font_id_map_.DeSerialize(swap, fp)) return false; + if (font_class_array_ != NULL) { + delete font_class_array_; + font_class_array_ = NULL; + } + inT8 not_null; + if (fread(¬_null, sizeof(not_null), 1, fp) != 1) return false; + if (not_null) { + FontClassInfo empty; + font_class_array_ = new GENERIC_2D_ARRAY(1, 1 , empty); + if (!font_class_array_->DeSerializeClasses(swap, fp)) return false; + } + unicharset_size_ = unicharset_.size(); + return true; +} + +// Load an initial unicharset, or set one up if the file cannot be read. +void TrainingSampleSet::LoadUnicharset(const char* filename) { + if (!unicharset_.load_from_file(filename)) { + tprintf("Failed to load unicharset from file %s\n" + "Building unicharset for boosting from scratch...\n", + filename); + unicharset_.clear(); + // Space character needed to represent NIL_LIST classification. + unicharset_.unichar_insert(" "); + } + unicharset_size_ = unicharset_.size(); +} + +// Adds a character sample to this sample set. +// If the unichar is not already in the local unicharset, it is added. +// Returns the unichar_id of the added sample, from the local unicharset. +int TrainingSampleSet::AddSample(const char* unichar, TrainingSample* sample) { + if (!unicharset_.contains_unichar(unichar)) { + unicharset_.unichar_insert(unichar); + if (unicharset_.size() > MAX_NUM_CLASSES) { + tprintf("Error: Size of unicharset in TrainingSampleSet::AddSample is " + "greater than MAX_NUM_CLASSES\n"); + return -1; + } + } + UNICHAR_ID char_id = unicharset_.unichar_to_id(unichar); + AddSample(char_id, sample); + return char_id; +} + +// Adds a character sample to this sample set with the given unichar_id, +// which must correspond to the local unicharset (in this). +void TrainingSampleSet::AddSample(int unichar_id, TrainingSample* sample) { + sample->set_class_id(unichar_id); + samples_.push_back(sample); + num_raw_samples_ = samples_.size(); + unicharset_size_ = unicharset_.size(); +} + +// Returns the number of samples for the given font,class pair. +// If randomize is true, returns the number of samples accessible +// with randomizing on. (Increases the number of samples if small.) +// OrganizeByFontAndClass must have been already called. +int TrainingSampleSet::NumClassSamples(int font_id, int class_id, + bool randomize) const { + ASSERT_HOST(font_class_array_ != NULL); + if (font_id < 0 || class_id < 0 || + font_id >= font_id_map_.SparseSize() || class_id >= unicharset_size_) { + // There are no samples because the font or class doesn't exist. + return 0; + } + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) + return 0; // The font has no samples. + if (randomize) + return (*font_class_array_)(font_index, class_id).samples.size(); + else + return (*font_class_array_)(font_index, class_id).num_raw_samples; +} + +// Gets a sample by its index. +const TrainingSample* TrainingSampleSet::GetSample(int index) const { + return samples_[index]; +} + +// Gets a sample by its font, class, index. +// OrganizeByFontAndClass must have been already called. +const TrainingSample* TrainingSampleSet::GetSample(int font_id, int class_id, + int index) const { + ASSERT_HOST(font_class_array_ != NULL); + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) return NULL; + int sample_index = (*font_class_array_)(font_index, class_id).samples[index]; + return samples_[sample_index]; +} + +// Get a sample by its font, class, index. Does not randomize. +// OrganizeByFontAndClass must have been already called. +TrainingSample* TrainingSampleSet::MutableSample(int font_id, int class_id, + int index) { + ASSERT_HOST(font_class_array_ != NULL); + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) return NULL; + int sample_index = (*font_class_array_)(font_index, class_id).samples[index]; + return samples_[sample_index]; +} + +// Returns a string debug representation of the given sample: +// font, unichar_str, bounding box, page. +STRING TrainingSampleSet::SampleToString(const TrainingSample& sample) const { + STRING boxfile_str; + MakeBoxFileStr(unicharset_.id_to_unichar(sample.class_id()), + sample.bounding_box(), sample.page_num(), &boxfile_str); + return STRING(fontinfo_table_.get(sample.font_id()).name) + " " + boxfile_str; +} + +// Gets the combined set of features used by all the samples of the given +// font/class combination. +const BitVector& TrainingSampleSet::GetCloudFeatures( + int font_id, int class_id) const { + int font_index = font_id_map_.SparseToCompact(font_id); + ASSERT_HOST(font_index >= 0); + return (*font_class_array_)(font_index, class_id).cloud_features; +} +// Gets the indexed features of the canonical sample of the given +// font/class combination. +const GenericVector& TrainingSampleSet::GetCanonicalFeatures( + int font_id, int class_id) const { + int font_index = font_id_map_.SparseToCompact(font_id); + ASSERT_HOST(font_index >= 0); + return (*font_class_array_)(font_index, class_id).canonical_features; +} + +// Returns the distance between the given UniCharAndFonts pair. +// If matched_fonts, only matching fonts, are considered, unless that yields +// the empty set. +// OrganizeByFontAndClass must have been already called. +float TrainingSampleSet::UnicharDistance(const UnicharAndFonts& uf1, + const UnicharAndFonts& uf2, + bool matched_fonts, + const IntFeatureMap& feature_map) { + int num_fonts1 = uf1.font_ids.size(); + int c1 = uf1.unichar_id; + int num_fonts2 = uf2.font_ids.size(); + int c2 = uf2.unichar_id; + double dist_sum = 0.0; + int dist_count = 0; + bool debug = false; + if (matched_fonts) { + // Compute distances only where fonts match. + for (int i = 0; i < num_fonts1; ++i) { + int f1 = uf1.font_ids[i]; + for (int j = 0; j < num_fonts2; ++j) { + int f2 = uf2.font_ids[j]; + if (f1 == f2) { + dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map); + ++dist_count; + } + } + } + } else if (num_fonts1 * num_fonts2 <= kSquareLimit) { + // Small enough sets to compute all the distances. + for (int i = 0; i < num_fonts1; ++i) { + int f1 = uf1.font_ids[i]; + for (int j = 0; j < num_fonts2; ++j) { + int f2 = uf2.font_ids[j]; + dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map); + if (debug) { + tprintf("Cluster dist %d %d %d %d = %g\n", + f1, c1, f2, c2, + ClusterDistance(f1, c1, f2, c2, feature_map)); + } + ++dist_count; + } + } + } else { + // Subsample distances, using the largest set once, and stepping through + // the smaller set so as to ensure that all the pairs are different. + int increment = kPrime1 != num_fonts2 ? kPrime1 : kPrime2; + int index = 0; + int num_samples = MAX(num_fonts1, num_fonts2); + for (int i = 0; i < num_samples; ++i, index += increment) { + int f1 = uf1.font_ids[i % num_fonts1]; + int f2 = uf2.font_ids[index % num_fonts2]; + if (debug) { + tprintf("Cluster dist %d %d %d %d = %g\n", + f1, c1, f2, c2, ClusterDistance(f1, c1, f2, c2, feature_map)); + } + dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map); + ++dist_count; + } + } + if (dist_count == 0) { + if (matched_fonts) + return UnicharDistance(uf1, uf2, false, feature_map); + return 0.0f; + } + return dist_sum / dist_count; +} + +// Returns the distance between the given pair of font/class pairs. +// Finds in cache or computes and caches. +// OrganizeByFontAndClass must have been already called. +float TrainingSampleSet::ClusterDistance(int font_id1, int class_id1, + int font_id2, int class_id2, + const IntFeatureMap& feature_map) { + ASSERT_HOST(font_class_array_ != NULL); + int font_index1 = font_id_map_.SparseToCompact(font_id1); + int font_index2 = font_id_map_.SparseToCompact(font_id2); + if (font_index1 < 0 || font_index2 < 0) + return 0.0f; + FontClassInfo& fc_info = (*font_class_array_)(font_index1, class_id1); + if (font_id1 == font_id2) { + // Special case cache for speed. + if (fc_info.unichar_distance_cache.size() == 0) + fc_info.unichar_distance_cache.init_to_size(unicharset_size_, -1.0f); + if (fc_info.unichar_distance_cache[class_id2] < 0) { + // Distance has to be calculated. + float result = ComputeClusterDistance(font_id1, class_id1, + font_id2, class_id2, + feature_map); + fc_info.unichar_distance_cache[class_id2] = result; + // Copy to the symmetric cache entry. + FontClassInfo& fc_info2 = (*font_class_array_)(font_index2, class_id2); + if (fc_info2.unichar_distance_cache.size() == 0) + fc_info2.unichar_distance_cache.init_to_size(unicharset_size_, -1.0f); + fc_info2.unichar_distance_cache[class_id1] = result; + } + return fc_info.unichar_distance_cache[class_id2]; + } else if (class_id1 == class_id2) { + // Another special-case cache for equal class-id. + if (fc_info.font_distance_cache.size() == 0) + fc_info.font_distance_cache.init_to_size(font_id_map_.CompactSize(), + -1.0f); + if (fc_info.font_distance_cache[font_index2] < 0) { + // Distance has to be calculated. + float result = ComputeClusterDistance(font_id1, class_id1, + font_id2, class_id2, + feature_map); + fc_info.font_distance_cache[font_index2] = result; + // Copy to the symmetric cache entry. + FontClassInfo& fc_info2 = (*font_class_array_)(font_index2, class_id2); + if (fc_info2.font_distance_cache.size() == 0) + fc_info2.font_distance_cache.init_to_size(font_id_map_.CompactSize(), + -1.0f); + fc_info2.font_distance_cache[font_index1] = result; + } + return fc_info.font_distance_cache[font_index2]; + } + // Both font and class are different. Linear search for class_id2/font_id2 + // in what is a hopefully short list of distances. + int cache_index = 0; + while (cache_index < fc_info.distance_cache.size() && + (fc_info.distance_cache[cache_index].unichar_id != class_id2 || + fc_info.distance_cache[cache_index].font_id != font_id2)) + ++cache_index; + if (cache_index == fc_info.distance_cache.size()) { + // Distance has to be calculated. + float result = ComputeClusterDistance(font_id1, class_id1, + font_id2, class_id2, + feature_map); + FontClassDistance fc_dist = { class_id2, font_id2, result }; + fc_info.distance_cache.push_back(fc_dist); + // Copy to the symmetric cache entry. We know it isn't there already, as + // we always copy to the symmetric entry. + FontClassInfo& fc_info2 = (*font_class_array_)(font_index2, class_id2); + fc_dist.unichar_id = class_id1; + fc_dist.font_id = font_id1; + fc_info2.distance_cache.push_back(fc_dist); + } + return fc_info.distance_cache[cache_index].distance; +} + +// Computes the distance between the given pair of font/class pairs. +float TrainingSampleSet::ComputeClusterDistance( + int font_id1, int class_id1, int font_id2, int class_id2, + const IntFeatureMap& feature_map) const { + int dist = ReliablySeparable(font_id1, class_id1, font_id2, class_id2, + feature_map, false); + dist += ReliablySeparable(font_id2, class_id2, font_id1, class_id1, + feature_map, false); + int denominator = GetCanonicalFeatures(font_id1, class_id1).size(); + denominator += GetCanonicalFeatures(font_id2, class_id2).size(); + return static_cast(dist) / denominator; +} + +// Helper to add a feature and its near neighbors to the good_features. +// levels indicates how many times to compute the offset features of what is +// already there. This is done by iteration rather than recursion. +static void AddNearFeatures(const IntFeatureMap& feature_map, int f, int levels, + GenericVector* good_features) { + int prev_num_features = 0; + good_features->push_back(f); + int num_features = 1; + for (int level = 0; level < levels; ++level) { + for (int i = prev_num_features; i < num_features; ++i) { + int feature = (*good_features)[i]; + for (int dir = -kNumOffsetMaps; dir <= kNumOffsetMaps; ++dir) { + if (dir == 0) continue; + int f1 = feature_map.OffsetFeature(feature, dir); + if (f1 >= 0) { + good_features->push_back(f1); + } + } + } + prev_num_features = num_features; + num_features = good_features->size(); + } +} + +// Returns the number of canonical features of font/class 2 for which +// neither the feature nor any of its near neighbors occurs in the cloud +// of font/class 1. Each such feature is a reliable separation between +// the classes, ASSUMING that the canonical sample is sufficiently +// representative that every sample has a feature near that particular +// feature. To check that this is so on the fly would be prohibitively +// expensive, but it might be possible to pre-qualify the canonical features +// to include only those for which this assumption is true. +// ComputeCanonicalFeatures and ComputeCloudFeatures must have been called +// first, or the results will be nonsense. +int TrainingSampleSet::ReliablySeparable(int font_id1, int class_id1, + int font_id2, int class_id2, + const IntFeatureMap& feature_map, + bool thorough) const { + int result = 0; + const TrainingSample* sample2 = GetCanonicalSample(font_id2, class_id2); + if (sample2 == NULL) + return 0; // There are no canonical features. + const GenericVector& canonical2 = GetCanonicalFeatures(font_id2, + class_id2); + const BitVector& cloud1 = GetCloudFeatures(font_id1, class_id1); + if (cloud1.size() == 0) + return canonical2.size(); // There are no cloud features. + + // Find a canonical2 feature that is not in cloud1. + for (int f = 0; f < canonical2.size(); ++f) { + int feature = canonical2[f]; + if (cloud1[feature]) + continue; + // Gather the near neighbours of f. + GenericVector good_features; + AddNearFeatures(feature_map, feature, 1, &good_features); + // Check that none of the good_features are in the cloud. + int i; + for (i = 0; i < good_features.size(); ++i) { + int good_f = good_features[i]; + if (cloud1[good_f]) { + break; + } + } + if (i < good_features.size()) + continue; // Found one in the cloud. + ++result; + } + return result; +} + +// Returns the total index of the requested sample. +// OrganizeByFontAndClass must have been already called. +int TrainingSampleSet::GlobalSampleIndex(int font_id, int class_id, + int index) const { + ASSERT_HOST(font_class_array_ != NULL); + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) return -1; + return (*font_class_array_)(font_index, class_id).samples[index]; +} + +// Gets the canonical sample for the given font, class pair. +// ComputeCanonicalSamples must have been called first. +const TrainingSample* TrainingSampleSet::GetCanonicalSample( + int font_id, int class_id) const { + ASSERT_HOST(font_class_array_ != NULL); + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) return NULL; + int sample_index = (*font_class_array_)(font_index, + class_id).canonical_sample; + return sample_index >= 0 ? samples_[sample_index] : NULL; +} + +// Gets the max distance for the given canonical sample. +// ComputeCanonicalSamples must have been called first. +float TrainingSampleSet::GetCanonicalDist(int font_id, int class_id) const { + ASSERT_HOST(font_class_array_ != NULL); + int font_index = font_id_map_.SparseToCompact(font_id); + if (font_index < 0) return 0.0f; + if ((*font_class_array_)(font_index, class_id).canonical_sample >= 0) + return (*font_class_array_)(font_index, class_id).canonical_dist; + else + return 0.0f; +} + +// Generates indexed features for all samples with the supplied feature_space. +void TrainingSampleSet::IndexFeatures(const IntFeatureSpace& feature_space) { + for (int s = 0; s < samples_.size(); ++s) + samples_[s]->IndexFeatures(feature_space); +} + +// Delete outlier samples with few features that are shared with others. +// IndexFeatures must have been called already. +void TrainingSampleSet::DeleteOutliers(const IntFeatureSpace& feature_space, + bool debug) { + if (font_class_array_ == NULL) + OrganizeByFontAndClass(); + Pixa* pixa = NULL; + if (debug) + pixa = pixaCreate(0); + GenericVector feature_counts; + int fs_size = feature_space.Size(); + int font_size = font_id_map_.CompactSize(); + for (int font_index = 0; font_index < font_size; ++font_index) { + for (int c = 0; c < unicharset_size_; ++c) { + // Create a histogram of the features used by all samples of this + // font/class combination. + feature_counts.init_to_size(fs_size, 0); + FontClassInfo& fcinfo = (*font_class_array_)(font_index, c); + int sample_count = fcinfo.samples.size(); + if (sample_count < kMinOutlierSamples) + continue; + for (int i = 0; i < sample_count; ++i) { + int s = fcinfo.samples[i]; + const GenericVector& features = samples_[s]->indexed_features(); + for (int f = 0; f < features.size(); ++f) { + ++feature_counts[features[f]]; + } + } + for (int i = 0; i < sample_count; ++i) { + int s = fcinfo.samples[i]; + const TrainingSample& sample = *samples_[s]; + const GenericVector& features = sample.indexed_features(); + // A feature that has a histogram count of 1 is only used by this + // sample, making it 'bad'. All others are 'good'. + int good_features = 0; + int bad_features = 0; + for (int f = 0; f < features.size(); ++f) { + if (feature_counts[features[f]] > 1) + ++good_features; + else + ++bad_features; + } + // If more than 1/3 features are bad, then this is an outlier. + if (bad_features * 2 > good_features) { + tprintf("Deleting outlier sample of %s, %d good, %d bad\n", + SampleToString(sample).string(), + good_features, bad_features); + if (debug) { + pixaAddPix(pixa, sample.RenderToPix(&unicharset_), L_INSERT); + // Add the previous sample as well, so it is easier to see in + // the output what is wrong with this sample. + int t; + if (i == 0) + t = fcinfo.samples[1]; + else + t = fcinfo.samples[i - 1]; + const TrainingSample &csample = *samples_[t]; + pixaAddPix(pixa, csample.RenderToPix(&unicharset_), L_INSERT); + } + // Mark the sample for deletion. + KillSample(samples_[s]); + } + } + } + } + // Truly delete all bad samples and renumber everything. + DeleteDeadSamples(); + if (pixa != NULL) { + Pix* pix = pixaDisplayTiledInRows(pixa, 1, 2600, 1.0, 0, 10, 10); + pixaDestroy(&pixa); + pixWrite("outliers.png", pix, IFF_PNG); + pixDestroy(&pix); + } +} + +// Marks the given sample index for deletion. +// Deletion is actually completed by DeleteDeadSamples. +void TrainingSampleSet::KillSample(TrainingSample* sample) { + sample->set_sample_index(-1); +} + +// Deletes all samples with zero features marked by KillSample. +void TrainingSampleSet::DeleteDeadSamples() { + samples_.compact( + NewPermanentTessCallback(this, &TrainingSampleSet::DeleteableSample)); + num_raw_samples_ = samples_.size(); + // Samples must be re-organized now we have deleted a few. +} + +// Callback function returns true if the given sample is to be deleted, due +// to having a negative classid. +bool TrainingSampleSet::DeleteableSample(const TrainingSample* sample) { + return sample == NULL || sample->class_id() < 0; +} + +static Pix* DebugSample(const UNICHARSET& unicharset, + TrainingSample* sample) { + tprintf("\nOriginal features:\n"); + for (int i = 0; i < sample->num_features(); ++i) { + sample->features()[i].print(); + } + if (sample->features_are_mapped()) { + tprintf("\nMapped features:\n"); + for (int i = 0; i < sample->mapped_features().size(); ++i) { + tprintf("%d ", sample->mapped_features()[i]); + } + tprintf("\n"); + } + return sample->RenderToPix(&unicharset); +} + +// Construct an array to access the samples by font,class pair. +void TrainingSampleSet::OrganizeByFontAndClass() { + // Font indexes are sparse, so we used a map to compact them, so we can + // have an efficient 2-d array of fonts and character classes. + SetupFontIdMap(); + int compact_font_size = font_id_map_.CompactSize(); + // Get a 2-d array of generic vectors. + if (font_class_array_ != NULL) + delete font_class_array_; + FontClassInfo empty; + font_class_array_ = new GENERIC_2D_ARRAY( + compact_font_size, unicharset_size_, empty); + for (int s = 0; s < samples_.size(); ++s) { + int font_id = samples_[s]->font_id(); + int class_id = samples_[s]->class_id(); + if (font_id < 0 || font_id >= font_id_map_.SparseSize()) { + tprintf("Font id = %d/%d, class id = %d/%d on sample %d\n", + font_id, font_id_map_.SparseSize(), class_id, unicharset_size_, + s); + } + ASSERT_HOST(font_id >= 0 && font_id < font_id_map_.SparseSize()); + ASSERT_HOST(class_id >= 0 && class_id < unicharset_size_); + int font_index = font_id_map_.SparseToCompact(font_id); + (*font_class_array_)(font_index, class_id).samples.push_back(s); + } + // Set the num_raw_samples member of the FontClassInfo, to set the boundary + // between the raw samples and the replicated ones. + for (int f = 0; f < compact_font_size; ++f) { + for (int c = 0; c < unicharset_size_; ++c) + (*font_class_array_)(f, c).num_raw_samples = + (*font_class_array_)(f, c).samples.size(); + } + // This is the global number of samples and also marks the boundary between + // real and replicated samples. + num_raw_samples_ = samples_.size(); +} + +// Constructs the font_id_map_ which maps real font_ids (sparse) to a compact +// index for the font_class_array_. +void TrainingSampleSet::SetupFontIdMap() { + // Number of samples for each font_id. + GenericVector font_counts; + for (int s = 0; s < samples_.size(); ++s) { + int font_id = samples_[s]->font_id(); + while (font_id >= font_counts.size()) + font_counts.push_back(0); + ++font_counts[font_id]; + } + font_id_map_.Init(font_counts.size(), false); + for (int f = 0; f < font_counts.size(); ++f) { + font_id_map_.SetMap(f, font_counts[f] > 0); + } + font_id_map_.Setup(); +} + + +// Finds the sample for each font, class pair that has least maximum +// distance to all the other samples of the same font, class. +// OrganizeByFontAndClass must have been already called. +void TrainingSampleSet::ComputeCanonicalSamples(const IntFeatureMap& map, + bool debug) { + ASSERT_HOST(font_class_array_ != NULL); + IntFeatureDist f_table; + if (debug) tprintf("feature table size %d\n", map.sparse_size()); + f_table.Init(&map); + int worst_s1 = 0; + int worst_s2 = 0; + double global_worst_dist = 0.0; + // Compute distances independently for each font and char index. + int font_size = font_id_map_.CompactSize(); + for (int font_index = 0; font_index < font_size; ++font_index) { + int font_id = font_id_map_.CompactToSparse(font_index); + for (int c = 0; c < unicharset_size_; ++c) { + int samples_found = 0; + FontClassInfo& fcinfo = (*font_class_array_)(font_index, c); + if (fcinfo.samples.size() == 0 || + (kTestChar >= 0 && c != kTestChar)) { + fcinfo.canonical_sample = -1; + fcinfo.canonical_dist = 0.0f; + if (debug) tprintf("Skipping class %d\n", c); + continue; + } + // The canonical sample will be the one with the min_max_dist, which + // is the sample with the lowest maximum distance to all other samples. + double min_max_dist = 2.0; + // We keep track of the farthest apart pair (max_s1, max_s2) which + // are max_max_dist apart, so we can see how bad the variability is. + double max_max_dist = 0.0; + int max_s1 = 0; + int max_s2 = 0; + fcinfo.canonical_sample = fcinfo.samples[0]; + fcinfo.canonical_dist = 0.0f; + for (int i = 0; i < fcinfo.samples.size(); ++i) { + int s1 = fcinfo.samples[i]; + const GenericVector& features1 = samples_[s1]->indexed_features(); + f_table.Set(features1, features1.size(), true); + double max_dist = 0.0; + // Run the full squared-order search for similar samples. It is still + // reasonably fast because f_table.FeatureDistance is fast, but we + // may have to reconsider if we start playing with too many samples + // of a single char/font. + for (int j = 0; j < fcinfo.samples.size(); ++j) { + int s2 = fcinfo.samples[j]; + if (samples_[s2]->class_id() != c || + samples_[s2]->font_id() != font_id || + s2 == s1) + continue; + GenericVector features2 = samples_[s2]->indexed_features(); + double dist = f_table.FeatureDistance(features2); + int height = samples_[s2]->geo_feature(GeoTop) - + samples_[s2]->geo_feature(GeoBottom); + if (dist == 1.0 && height > 64) { + // TODO(rays) rethink this when the polygonal approximation goes. + // Currently it is possible for dots and other small characters + // to be completely different, even within the same class. + f_table.DebugFeatureDistance(features2); + } + if (dist > max_dist) { + max_dist = dist; + if (dist > max_max_dist) { + max_s1 = s1; + max_s2 = s2; + } + } + } + // Using Set(..., false) is far faster than re initializing, due to + // the sparseness of the feature space. + f_table.Set(features1, features1.size(), false); + samples_[s1]->set_max_dist(max_dist); + ++samples_found; + if (max_dist < min_max_dist) { + fcinfo.canonical_sample = s1; + fcinfo.canonical_dist = max_dist; + } + UpdateRange(max_dist, &min_max_dist, &max_max_dist); + } + if (max_max_dist > global_worst_dist) { + // Keep a record of the worst pair over all characters/fonts too. + global_worst_dist = max_max_dist; + worst_s1 = max_s1; + worst_s2 = max_s2; + } + if (debug) { + tprintf("Found %d samples of class %d=%s, font %d, " + "dist range [%g, %g], worst pair= %s, %s\n", + samples_found, c, unicharset_.debug_str(c).string(), + font_index, min_max_dist, max_max_dist, + SampleToString(*samples_[max_s1]).string(), + SampleToString(*samples_[max_s2]).string()); + } + } + } + if (debug) { + tprintf("Global worst dist = %g, between sample %d and %d\n", + global_worst_dist, worst_s1, worst_s2); + Pix* pix1 = DebugSample(unicharset_, samples_[worst_s1]); + Pix* pix2 = DebugSample(unicharset_, samples_[worst_s2]); + pixOr(pix1, pix1, pix2); + pixWrite("worstpair.png", pix1, IFF_PNG); + pixDestroy(&pix1); + pixDestroy(&pix2); + } +} + +// Replicates the samples to a minimum frequency defined by +// 2 * kSampleRandomSize, or for larger counts duplicates all samples. +// After replication, the replicated samples are perturbed slightly, but +// in a predictable and repeatable way. +// Use after OrganizeByFontAndClass(). +void TrainingSampleSet::ReplicateAndRandomizeSamples() { + ASSERT_HOST(font_class_array_ != NULL); + int font_size = font_id_map_.CompactSize(); + for (int font_index = 0; font_index < font_size; ++font_index) { + for (int c = 0; c < unicharset_size_; ++c) { + FontClassInfo& fcinfo = (*font_class_array_)(font_index, c); + int sample_count = fcinfo.samples.size(); + int min_samples = 2 * MAX(kSampleRandomSize, sample_count); + if (sample_count > 0 && sample_count < min_samples) { + int base_count = sample_count; + for (int base_index = 0; sample_count < min_samples; ++sample_count) { + int src_index = fcinfo.samples[base_index++]; + if (base_index >= base_count) base_index = 0; + TrainingSample* sample = samples_[src_index]->RandomizedCopy( + sample_count % kSampleRandomSize); + int sample_index = samples_.size(); + sample->set_sample_index(sample_index); + samples_.push_back(sample); + fcinfo.samples.push_back(sample_index); + } + } + } + } +} + +// Caches the indexed features of the canonical samples. +// ComputeCanonicalSamples must have been already called. +// TODO(rays) see note on ReliablySeparable and try restricting the +// canonical features to those that truly represent all samples. +void TrainingSampleSet::ComputeCanonicalFeatures() { + ASSERT_HOST(font_class_array_ != NULL); + int font_size = font_id_map_.CompactSize(); + for (int font_index = 0; font_index < font_size; ++font_index) { + int font_id = font_id_map_.CompactToSparse(font_index); + for (int c = 0; c < unicharset_size_; ++c) { + int num_samples = NumClassSamples(font_id, c, false); + if (num_samples == 0) + continue; + const TrainingSample* sample = GetCanonicalSample(font_id, c); + FontClassInfo& fcinfo = (*font_class_array_)(font_index, c); + fcinfo.canonical_features = sample->indexed_features(); + } + } +} + +// Computes the combined set of features used by all the samples of each +// font/class combination. Use after ReplicateAndRandomizeSamples. +void TrainingSampleSet::ComputeCloudFeatures(int feature_space_size) { + ASSERT_HOST(font_class_array_ != NULL); + int font_size = font_id_map_.CompactSize(); + for (int font_index = 0; font_index < font_size; ++font_index) { + int font_id = font_id_map_.CompactToSparse(font_index); + for (int c = 0; c < unicharset_size_; ++c) { + int num_samples = NumClassSamples(font_id, c, false); + if (num_samples == 0) + continue; + FontClassInfo& fcinfo = (*font_class_array_)(font_index, c); + fcinfo.cloud_features.Init(feature_space_size); + for (int s = 0; s < num_samples; ++s) { + const TrainingSample* sample = GetSample(font_id, c, s); + const GenericVector& sample_features = sample->indexed_features(); + for (int i = 0; i < sample_features.size(); ++i) + fcinfo.cloud_features.SetBit(sample_features[i]); + } + } + } +} + +// Adds all fonts of the given class to the shape. +void TrainingSampleSet::AddAllFontsForClass(int class_id, Shape* shape) const { + for (int f = 0; f < font_id_map_.CompactSize(); ++f) { + int font_id = font_id_map_.CompactToSparse(f); + shape->AddToShape(class_id, font_id); + } +} + +// Display the samples with the given indexed feature that also match +// the given shape. +void TrainingSampleSet::DisplaySamplesWithFeature(int f_index, + const Shape& shape, + const IntFeatureSpace& space, + ScrollView::Color color, + ScrollView* window) const { + for (int s = 0; s < num_raw_samples(); ++s) { + const TrainingSample* sample = GetSample(s); + if (shape.ContainsUnichar(sample->class_id())) { + GenericVector indexed_features; + space.IndexAndSortFeatures(sample->features(), sample->num_features(), + &indexed_features); + for (int f = 0; f < indexed_features.size(); ++f) { + if (indexed_features[f] == f_index) { + sample->DisplayFeatures(color, window); + } + } + } + } +} + + +} // namespace tesseract. diff --git a/classify/trainingsampleset.h b/classify/trainingsampleset.h new file mode 100644 index 0000000000..4ff4e86e01 --- /dev/null +++ b/classify/trainingsampleset.h @@ -0,0 +1,290 @@ +// Copyright 2010 Google Inc. All Rights Reserved. +// Author: rays@google.com (Ray Smith) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////// + +#ifndef TESSERACT_TRAINING_TRAININGSAMPLESET_H__ +#define TESSERACT_TRAINING_TRAININGSAMPLESET_H__ + +#include "bitvector.h" +#include "genericvector.h" +#include "indexmapbidi.h" +#include "matrix.h" +#include "shapetable.h" +#include "trainingsample.h" + +class UNICHARSET; +template class UnicityTable; + +namespace tesseract { + +struct FontInfo; +class IntFeatureMap; +class IntFeatureSpace; +class TrainingSample; +class UnicharAndFonts; + +// Collection of TrainingSample used for training or testing a classifier. +// Provides several useful methods to operate on the collection as a whole, +// including outlier detection and deletion, providing access by font and +// class, finding the canonical sample, finding the "cloud" features (OR of +// all features in all samples), replication of samples, caching of distance +// metrics. +class TrainingSampleSet { + public: + explicit TrainingSampleSet(const UnicityTable& fontinfo_table); + ~TrainingSampleSet(); + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Accessors + int num_samples() const { + return samples_.size(); + } + int num_raw_samples() const { + return num_raw_samples_; + } + int NumFonts() const { + return font_id_map_.SparseSize(); + } + const UNICHARSET& unicharset() const { + return unicharset_; + } + int charsetsize() const { + return unicharset_size_; + } + + // Loads an initial unicharset, or sets one up if the file cannot be read. + void LoadUnicharset(const char* filename); + + // Adds a character sample to this sample set. + // If the unichar is not already in the local unicharset, it is added. + // Returns the unichar_id of the added sample, from the local unicharset. + int AddSample(const char* unichar, TrainingSample* sample); + // Adds a character sample to this sample set with the given unichar_id, + // which must correspond to the local unicharset (in this). + void AddSample(int unichar_id, TrainingSample* sample); + + // Returns the number of samples for the given font,class pair. + // If randomize is true, returns the number of samples accessible + // with randomizing on. (Increases the number of samples if small.) + // OrganizeByFontAndClass must have been already called. + int NumClassSamples(int font_id, int class_id, bool randomize) const; + + // Gets a sample by its index. + const TrainingSample* GetSample(int index) const; + + // Gets a sample by its font, class, index. + // OrganizeByFontAndClass must have been already called. + const TrainingSample* GetSample(int font_id, int class_id, int index) const; + + // Get a sample by its font, class, index. Does not randomize. + // OrganizeByFontAndClass must have been already called. + TrainingSample* MutableSample(int font_id, int class_id, int index); + + // Returns a string debug representation of the given sample: + // font, unichar_str, bounding box, page. + STRING SampleToString(const TrainingSample& sample) const; + + // Gets the combined set of features used by all the samples of the given + // font/class combination. + const BitVector& GetCloudFeatures(int font_id, int class_id) const; + // Gets the indexed features of the canonical sample of the given + // font/class combination. + const GenericVector& GetCanonicalFeatures(int font_id, + int class_id) const; + + // Returns the distance between the given UniCharAndFonts pair. + // If matched_fonts, only matching fonts, are considered, unless that yields + // the empty set. + // OrganizeByFontAndClass must have been already called. + float UnicharDistance(const UnicharAndFonts& uf1, const UnicharAndFonts& uf2, + bool matched_fonts, const IntFeatureMap& feature_map); + + // Returns the distance between the given pair of font/class pairs. + // Finds in cache or computes and caches. + // OrganizeByFontAndClass must have been already called. + float ClusterDistance(int font_id1, int class_id1, + int font_id2, int class_id2, + const IntFeatureMap& feature_map); + + // Computes the distance between the given pair of font/class pairs. + float ComputeClusterDistance(int font_id1, int class_id1, + int font_id2, int class_id2, + const IntFeatureMap& feature_map) const; + + // Returns the number of canonical features of font/class 2 for which + // neither the feature nor any of its near neighbors occurs in the cloud + // of font/class 1. Each such feature is a reliable separation between + // the classes, ASSUMING that the canonical sample is sufficiently + // representative that every sample has a feature near that particular + // feature. To check that this is so on the fly would be prohibitively + // expensive, but it might be possible to pre-qualify the canonical features + // to include only those for which this assumption is true. + // ComputeCanonicalFeatures and ComputeCloudFeatures must have been called + // first, or the results will be nonsense. + int ReliablySeparable(int font_id1, int class_id1, + int font_id2, int class_id2, + const IntFeatureMap& feature_map, + bool thorough) const; + + + // Returns the total index of the requested sample. + // OrganizeByFontAndClass must have been already called. + int GlobalSampleIndex(int font_id, int class_id, int index) const; + + // Gets the canonical sample for the given font, class pair. + // ComputeCanonicalSamples must have been called first. + const TrainingSample* GetCanonicalSample(int font_id, int class_id) const; + // Gets the max distance for the given canonical sample. + // ComputeCanonicalSamples must have been called first. + float GetCanonicalDist(int font_id, int class_id) const; + + // Returns a mutable pointer to the sample with the given index. + TrainingSample* mutable_sample(int index) { + return samples_[index]; + } + // Gets ownership of the sample with the given index, removing it from this. + TrainingSample* extract_sample(int index) { + TrainingSample* sample = samples_[index]; + samples_[index] = NULL; + return sample; + } + + // Generates indexed features for all samples with the supplied feature_space. + void IndexFeatures(const IntFeatureSpace& feature_space); + + // Delete outlier samples with few features that are shared with others. + // IndexFeatures must have been called already. + void DeleteOutliers(const IntFeatureSpace& feature_space, bool debug); + + // Marks the given sample for deletion. + // Deletion is actually completed by DeleteDeadSamples. + void KillSample(TrainingSample* sample); + + // Deletes all samples with a negative sample index marked by KillSample. + // Must be called before OrganizeByFontAndClass, and OrganizeByFontAndClass + // must be called after as the samples have been renumbered. + void DeleteDeadSamples(); + + // Callback function returns true if the given sample is to be deleted, due + // to having a negative classid. + bool DeleteableSample(const TrainingSample* sample); + + // Construct an array to access the samples by font,class pair. + void OrganizeByFontAndClass(); + + // Constructs the font_id_map_ which maps real font_ids (sparse) to a compact + // index for the font_class_array_. + void SetupFontIdMap(); + + // Finds the sample for each font, class pair that has least maximum + // distance to all the other samples of the same font, class. + // OrganizeByFontAndClass must have been already called. + void ComputeCanonicalSamples(const IntFeatureMap& map, bool debug); + + // Replicates the samples to a minimum frequency defined by + // 2 * kSampleRandomSize, or for larger counts duplicates all samples. + // After replication, the replicated samples are perturbed slightly, but + // in a predictable and repeatable way. + // Use after OrganizeByFontAndClass(). + void ReplicateAndRandomizeSamples(); + + // Caches the indexed features of the canonical samples. + // ComputeCanonicalSamples must have been already called. + void ComputeCanonicalFeatures(); + // Computes the combined set of features used by all the samples of each + // font/class combination. Use after ReplicateAndRandomizeSamples. + void ComputeCloudFeatures(int feature_space_size); + + // Adds all fonts of the given class to the shape. + void AddAllFontsForClass(int class_id, Shape* shape) const; + + // Display the samples with the given indexed feature that also match + // the given shape. + void DisplaySamplesWithFeature(int f_index, const Shape& shape, + const IntFeatureSpace& feature_space, + ScrollView::Color color, + ScrollView* window) const; + + private: + // Struct to store a triplet of unichar, font, distance in the distance cache. + struct FontClassDistance { + int unichar_id; + int font_id; // Real font id. + float distance; + }; + // Simple struct to store information related to each font/class combination. + struct FontClassInfo { + FontClassInfo(); + + // Writes to the given file. Returns false in case of error. + bool Serialize(FILE* fp) const; + // Reads from the given file. Returns false in case of error. + // If swap is true, assumes a big/little-endian swap is needed. + bool DeSerialize(bool swap, FILE* fp); + + // Number of raw samples. + inT32 num_raw_samples; + // Index of the canonical sample. + inT32 canonical_sample; + // Max distance of the canonical sample from any other. + float canonical_dist; + // Sample indices for the samples, including replicated. + GenericVector samples; + + // Non-serialized cache data. + // Indexed features of the canonical sample. + GenericVector canonical_features; + // The mapped features of all the samples. + BitVector cloud_features; + + // Caches for ClusterDistance. + // Caches for other fonts but matching this unichar. -1 indicates not set. + // Indexed by compact font index from font_id_map_. + GenericVector font_distance_cache; + // Caches for other unichars but matching this font. -1 indicates not set. + GenericVector unichar_distance_cache; + // Cache for the rest (non matching font and unichar.) + // A cache of distances computed by ReliablySeparable. + GenericVector distance_cache; + }; + + PointerVector samples_; + // Number of samples before replication/randomization. + int num_raw_samples_; + // Character set we are training for. + UNICHARSET unicharset_; + // Character set size to which the 2-d arrays below refer. + int unicharset_size_; + // Map to allow the font_class_array_ below to be compact. + // The sparse space is the real font_id, used in samples_ . + // The compact space is an index to font_class_array_ + IndexMapBiDi font_id_map_; + // A 2-d array of FontClassInfo holding information related to each + // (font_id, class_id) pair. + GENERIC_2D_ARRAY* font_class_array_; + + // Reference to the fontinfo_table_ in MasterTrainer. Provides names + // for font_ids in the samples. Not serialized! + const UnicityTable& fontinfo_table_; +}; + +} // namespace tesseract. + + +#endif // TRAININGSAMPLESETSET_H_