tesseract v5.3.3.20231005
mastertrainer_test.cc
Go to the documentation of this file.
1// (C) Copyright 2017, Google Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5// http://www.apache.org/licenses/LICENSE-2.0
6// Unless required by applicable law or agreed to in writing, software
7// distributed under the License is distributed on an "AS IS" BASIS,
8// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9// See the License for the specific language governing permissions and
10// limitations under the License.
11
// Although this is a trivial-looking test, it exercises a lot of code:
// SampleIterator has to correctly iterate over the correct characters, or
// it will fail.
// The canonical and cloud features computed by TrainingSampleSet need to
// be correct, along with the distance caches, organizing samples by font
// and class, indexing of features, distance calculations.
// IntFeatureDist has to work, or the canonical samples won't work.
// MasterTrainer's ability to read tr files and set itself up is tested.
// Finally the serialize/deserialize test ensures that MasterTrainer,
// TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
// enough to reproduce the same results.
23
#include "include_gunit.h"

#include "commontraining.h"
#include "errorcounter.h"
#include "log.h" // for LOG
#include "mastertrainer.h"
#include "shapeclassifier.h"
#include "shapetable.h"
#include "trainingsample.h"
#include "unicharset.h"

#include <cctype>  // for std::isspace
#include <cerrno>  // for errno, ERANGE
#include <cstddef> // for std::size_t
#include <cstdlib> // for std::strtol
#include <limits>  // for std::numeric_limits
#include <locale>  // for std::locale
#include <memory>  // for std::unique_ptr, std::make_unique
#include <string>
#include <utility>
#include <vector>
38
39using namespace tesseract;
40
// Specs of the MockClassifier.
// The kNum*Errs values are cumulative thresholds on the running sample count:
// each one marks the last sample (1-based) that falls in that error band.
// Samples 1..kNumTopNErrs get only a wrong answer.
constexpr int kNumTopNErrs = 10;
// The next 20 samples get the correct answer in third place.
constexpr int kNumTop2Errs = kNumTopNErrs + 20;
// The next 30 samples get the correct answer in second place.
constexpr int kNumTop1Errs = kNumTop2Errs + 30;
// The next 25 samples get the correct answer second, but rated almost as high
// as the wrong top answer.
constexpr int kNumTopTopErrs = kNumTop1Errs + 25;
// Number of samples that receive any answer at all; the rest are rejects.
constexpr int kNumNonReject = 1000;
// Samples whose answer is not a top-1 error.
constexpr int kNumCorrect = kNumNonReject - kNumTop1Errs;
// The total number of answers is given by the number of non-rejects plus
// all the multiple answers: two extra per sample in the top-2 band and one
// extra per sample in each of the top-1 and top-top bands.
constexpr int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
                            (kNumTop1Errs - kNumTop2Errs) + (kNumTopTopErrs - kNumTop1Errs);
52
53#ifndef DISABLED_LEGACY_ENGINE
// Parses str as a signed integer using strtol's base-0 rules (decimal,
// 0x-prefixed hexadecimal or 0-prefixed octal).
// Returns true and stores the value in *pResult only when at least one digit
// was consumed, nothing but whitespace follows the number, and the value fits
// in an int. Otherwise returns false; *pResult (when non-null) is then set to
// 0 so that callers never read an indeterminate value.
static bool safe_strto32(const std::string &str, int *pResult) {
  if (pResult == nullptr) {
    return false;
  }
  *pResult = 0;

  const char *begin = str.c_str();
  char *end = nullptr;
  errno = 0; // strtol reports overflow only through errno.
  const long parsed = std::strtol(begin, &end, 0);
  if (end == begin) {
    // No conversion was performed: empty or non-numeric input.
    return false;
  }
  // Trailing whitespace (e.g. the newline that ends a report line) is
  // tolerated; any other trailing character makes the field invalid.
  for (; *end != '\0'; ++end) {
    if (!std::isspace(static_cast<unsigned char>(*end))) {
      return false;
    }
  }
  // Reject values outside the range of long (ERANGE) or of int.
  if (errno == ERANGE || parsed < std::numeric_limits<int>::min() ||
      parsed > std::numeric_limits<int>::max()) {
    return false;
  }
  *pResult = static_cast<int>(parsed);
  return true;
}
59#endif
60
61// Mock ShapeClassifier that cheats by looking at the correct answer, and
62// creates a specific pattern of errors that can be tested.
64public:
65 explicit MockClassifier(ShapeTable *shape_table)
66 : shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
67 // Add a false font answer to the shape table. We pick a random unichar_id,
68 // add a new shape for it with a false font. Font must actually exist in
69 // the font table, but not match anything in the first 1000 samples.
70 false_unichar_id_ = 67;
71 false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
72 }
73 ~MockClassifier() override = default;
74
75 // Classifies the given [training] sample, writing to results.
76 // If debug is non-zero, then various degrees of classifier dependent debug
77 // information is provided.
78 // If keep_this (a shape index) is >= 0, then the results should always
79 // contain keep_this, and (if possible) anything of intermediate confidence.
80 // The return value is the number of classes saved in results.
81 int ClassifySample(const TrainingSample &sample, Image page_pix, int debug, UNICHAR_ID keep_this,
82 std::vector<ShapeRating> *results) override {
83 results->clear();
84 // Everything except the first kNumNonReject is a reject.
85 if (++num_done_ > kNumNonReject) {
86 return 0;
87 }
88
89 int class_id = sample.class_id();
90 int font_id = sample.font_id();
91 int shape_id = shape_table_->FindShape(class_id, font_id);
92 // Get ids of some wrong answers.
93 int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
94 int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
95 if (num_done_ <= kNumTopNErrs) {
96 // The first kNumTopNErrs are top-n errors.
97 results->push_back(ShapeRating(wrong_id1, 1.0f));
98 } else if (num_done_ <= kNumTop2Errs) {
99 // The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
100 results->push_back(ShapeRating(wrong_id1, 1.0f));
101 results->push_back(ShapeRating(wrong_id2, 0.875f));
102 results->push_back(ShapeRating(shape_id, 0.75f));
103 } else if (num_done_ <= kNumTop1Errs) {
104 // The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
105 results->push_back(ShapeRating(wrong_id1, 1.0f));
106 results->push_back(ShapeRating(shape_id, 0.8f));
107 } else if (num_done_ <= kNumTopTopErrs) {
108 // The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
109 // is not correct, but do not count as a top-1 error because the rating
110 // is close enough to the top answer.
111 results->push_back(ShapeRating(wrong_id1, 1.0f));
112 results->push_back(ShapeRating(shape_id, 0.99f));
113 } else if (!done_bad_font_ && class_id == false_unichar_id_) {
114 // There is a single character with a bad font.
115 results->push_back(ShapeRating(false_shape_, 1.0f));
116 done_bad_font_ = true;
117 } else {
118 // Everything else is correct.
119 results->push_back(ShapeRating(shape_id, 1.0f));
120 }
121 return results->size();
122 }
123 // Provides access to the ShapeTable that this classifier works with.
124 const ShapeTable *GetShapeTable() const override {
125 return shape_table_;
126 }
127
128private:
129 // Borrowed pointer to the ShapeTable.
130 ShapeTable *shape_table_;
131 // Unichar_id of a random character that occurs after the first 60 samples.
132 int false_unichar_id_;
133 // Shape index of prepared false answer for false_unichar_id.
134 int false_shape_;
135 // The number of classifications we have processed.
136 int num_done_;
137 // True after the false font has been emitted.
138 bool done_bad_font_;
139};
140
// Lower bound that the shape distances between '1' and each of 'l' and 'I'
// must exceed for the fixture's I/l/1 check to pass.
constexpr double kMin1lDistance = 0.25;
142
143// The fixture for testing Tesseract.
145#ifndef DISABLED_LEGACY_ENGINE
146protected:
147 void SetUp() override {
148 std::locale::global(std::locale(""));
150 }
151
152 std::string TestDataNameToPath(const std::string &name) {
153 return file::JoinPath(TESTING_DIR, name);
154 }
155 std::string TmpNameToPath(const std::string &name) {
156 return file::JoinPath(FLAGS_test_tmpdir, name);
157 }
158
160 shape_table_ = nullptr;
161 master_trainer_ = nullptr;
162 }
164 delete shape_table_;
165 }
166
167 // Initializes the master_trainer_ and shape_table_.
168 // if load_from_tmp, then reloads a master trainer that was saved by a
169 // previous call in which it was false.
171 FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
172 FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
173 FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
174 FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
175 std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
176 const char *filelist[] = {tr_file_name.c_str(), nullptr};
177 std::string file_prefix;
178 delete shape_table_;
179 shape_table_ = nullptr;
180 master_trainer_ = LoadTrainingData(filelist, false, &shape_table_, file_prefix);
181 EXPECT_TRUE(master_trainer_ != nullptr);
182 EXPECT_TRUE(shape_table_ != nullptr);
183 }
184
185 // EXPECTs that the distance between I and l in Arial is 0 and that the
186 // distance to 1 is significantly not 0.
187 void VerifyIl1() {
188 // Find the font id for Arial.
189 int font_id = master_trainer_->GetFontInfoId("Arial");
190 EXPECT_GE(font_id, 0);
191 // Track down the characters we are interested in.
192 int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
193 EXPECT_GT(unichar_I, 0);
194 int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
195 EXPECT_GT(unichar_l, 0);
196 int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
197 EXPECT_GT(unichar_1, 0);
198 // Now get the shape ids.
199 int shape_I = shape_table_->FindShape(unichar_I, font_id);
200 EXPECT_GE(shape_I, 0);
201 int shape_l = shape_table_->FindShape(unichar_l, font_id);
202 EXPECT_GE(shape_l, 0);
203 int shape_1 = shape_table_->FindShape(unichar_1, font_id);
204 EXPECT_GE(shape_1, 0);
205
206 float dist_I_l = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
207 // No tolerance here. We expect that I and l should match exactly.
208 EXPECT_EQ(0.0f, dist_I_l);
209 float dist_l_I = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
210 // BOTH ways.
211 EXPECT_EQ(0.0f, dist_l_I);
212
213 // l/1 on the other hand should be distinct.
214 float dist_l_1 = master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
215 EXPECT_GT(dist_l_1, kMin1lDistance);
216 float dist_1_l = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
217 EXPECT_GT(dist_1_l, kMin1lDistance);
218
219 // So should I/1.
220 float dist_I_1 = master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
221 EXPECT_GT(dist_I_1, kMin1lDistance);
222 float dist_1_I = master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
223 EXPECT_GT(dist_1_I, kMin1lDistance);
224 }
225
226 // Objects declared here can be used by all tests in the test case for Foo.
228 std::unique_ptr<MasterTrainer> master_trainer_;
229#endif
230};
231
232// Tests that the MasterTrainer correctly loads its data and reaches the correct
233// conclusion over the distance between Arial I l and 1.
235#ifdef DISABLED_LEGACY_ENGINE
236 // Skip test because LoadTrainingData is missing.
237 GTEST_SKIP();
238#else
239 // Initialize the master_trainer_ and load the Arial tr file.
240 LoadMasterTrainer();
241 VerifyIl1();
242#endif
243}
244
245// Tests the ErrorCounter using a MockClassifier to check that it counts
246// error categories correctly.
247TEST_F(MasterTrainerTest, ErrorCounterTest) {
248#ifdef DISABLED_LEGACY_ENGINE
249 // Skip test because LoadTrainingData is missing.
250 GTEST_SKIP();
251#else
252 // Initialize the master_trainer_ from the saved tmp file.
253 LoadMasterTrainer();
254 // Add the space character to the shape_table_ if not already present to
255 // count junk.
256 if (shape_table_->FindShape(0, -1) < 0) {
257 shape_table_->AddShape(0, 0);
258 }
259 // Make a mock classifier.
260 auto shape_classifier = std::make_unique<MockClassifier>(shape_table_);
261 // Get the accuracy report.
262 std::string accuracy_report;
263 master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0, false,
264 shape_classifier.get(), &accuracy_report);
265 LOG(INFO) << accuracy_report.c_str();
266 std::string result_string = accuracy_report.c_str();
267 std::vector<std::string> results = split(result_string, '\t');
268 EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
269 int result_values[tesseract::CT_SIZE];
270 for (int i = 0; i < tesseract::CT_SIZE; ++i) {
271 EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
272 }
273 // These tests are more-or-less immune to additions to the number of
274 // categories or changes in the training data.
275 int num_samples = master_trainer_->GetSamples()->num_raw_samples();
276 EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
277 EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
278 EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
279 EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
280 EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
281 EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
282 // Each of the TOPTOP errs also counts as a multi-unichar.
283 EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs, result_values[tesseract::CT_OK_MULTI_UNICHAR]);
284 EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
285 EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
286#endif
287}
TEST_F(MasterTrainerTest, Il1Test)
const double kMin1lDistance
@ LOG
@ INFO
Definition: log.h:28
#define GTEST_SKIP()
Definition: gtest.h:1889
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2043
#define EXPECT_GT(val1, val2)
Definition: gtest.h:2053
#define EXPECT_GE(val1, val2)
Definition: gtest.h:2051
#define EXPECT_TRUE(condition)
Definition: gtest.h:1982
int UNICHAR_ID
Definition: unichar.h:34
std::unique_ptr< MasterTrainer > LoadTrainingData(const char *const *filelist, bool replication, ShapeTable **shape_table, std::string &file_prefix)
@ CT_UNICHAR_TOPN_ERR
Definition: errorcounter.h:76
@ CT_UNICHAR_TOP_OK
Definition: errorcounter.h:70
@ CT_UNICHAR_TOP1_ERR
Definition: errorcounter.h:74
@ CT_UNICHAR_TOP2_ERR
Definition: errorcounter.h:75
@ CT_UNICHAR_TOPTOP_ERR
Definition: errorcounter.h:77
@ CT_FONT_ATTR_ERR
Definition: errorcounter.h:82
@ CT_OK_MULTI_UNICHAR
Definition: errorcounter.h:78
@ CT_NUM_RESULTS
Definition: errorcounter.h:84
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:43
unsigned AddShape(int unichar_id, int font_id)
Definition: shapetable.cpp:351
int FindShape(int unichar_id, int font_id) const
Definition: shapetable.cpp:400
UNICHAR_ID class_id() const
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
~MockClassifier() override=default
MockClassifier(ShapeTable *shape_table)
const ShapeTable * GetShapeTable() const override
int ClassifySample(const TrainingSample &sample, Image page_pix, int debug, UNICHAR_ID keep_this, std::vector< ShapeRating > *results) override
std::string TestDataNameToPath(const std::string &name)
std::string TmpNameToPath(const std::string &name)
void SetUp() override
std::unique_ptr< MasterTrainer > master_trainer_