tesseract v5.3.3.20231005
lstm_test.h
Go to the documentation of this file.
1// (C) Copyright 2017, Google Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5// http://www.apache.org/licenses/LICENSE-2.0
6// Unless required by applicable law or agreed to in writing, software
7// distributed under the License is distributed on an "AS IS" BASIS,
8// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9// See the License for the specific language governing permissions and
10// limitations under the License.
11
12#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
13#define TESSERACT_UNITTEST_LSTM_TEST_H_
14
15#include <memory>
16#include <string>
17#include <utility>
18
19#include "include_gunit.h"
20
21#include "helpers.h"
22#include "tprintf.h"
23
24#include "functions.h"
25#include "lang_model_helpers.h"
26#include "log.h" // for LOG
27#include "lstmtrainer.h"
28#include "unicharset.h"
29
30namespace tesseract {
31
// Iteration budgets shared by the trainer tests. A DEBUG_DETAIL build uses a
// much smaller budget than the normal build.
#if DEBUG_DETAIL == 0
constexpr int kTrainerIterations = 600; // Total iterations run by each trainer.
constexpr int kBatchIterations = 100;   // Iterations between accuracy checks.
#else
constexpr int kTrainerIterations = 2; // Total iterations run by each trainer.
constexpr int kBatchIterations = 1;   // Iterations between accuracy checks.
#endif
43
44// The fixture for testing LSTMTrainer.
46protected:
47 void SetUp() override {
48 std::locale::global(std::locale(""));
50 }
51
52 LSTMTrainerTest() = default;
53 std::string TestDataNameToPath(const std::string &name) {
54 return file::JoinPath(TESTDATA_DIR, "" + name);
55 }
56 std::string TessDataNameToPath(const std::string &name) {
57 return file::JoinPath(TESSDATA_DIR, "" + name);
58 }
59 std::string TestingNameToPath(const std::string &name) {
60 return file::JoinPath(TESTING_DIR, "" + name);
61 }
62
63 void SetupTrainerEng(const std::string &network_spec, const std::string &model_name, bool recode,
64 bool adam) {
65 SetupTrainer(network_spec, model_name, "eng/eng.unicharset", "eng.Arial.exp0.lstmf", recode,
66 adam, 5e-4, false, "eng");
67 }
68 void SetupTrainer(const std::string &network_spec, const std::string &model_name,
69 const std::string &unicharset_file, const std::string &lstmf_file, bool recode,
70 bool adam, float learning_rate, bool layer_specific, const std::string &kLang) {
71 // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
72 std::string unicharset_name = TestDataNameToPath(unicharset_file);
73 UNICHARSET unicharset;
74 ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
75 std::string script_dir = file::JoinPath(LANGDATA_DIR, "");
76 std::vector<std::string> words;
77 EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir, kLang, !recode,
78 words, words, words, false, nullptr, nullptr));
79 std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
80 std::string checkpoint_path = model_path + "_checkpoint";
81 trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
82 trainer_->InitCharSet(
83 file::JoinPath(FLAGS_test_tmpdir, kLang, kLang) + ".traineddata");
84 int net_mode = adam ? NF_ADAM : 0;
85 // Adam needs a higher learning rate, due to not multiplying the effective
86 // rate by 1/(1-momentum).
87 if (adam) {
88 learning_rate *= 20.0f;
89 }
90 if (layer_specific) {
91 net_mode |= NF_LAYER_SPECIFIC_LR;
92 }
94 trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95 std::vector<std::string> filenames;
96 filenames.emplace_back(TestDataNameToPath(lstmf_file).c_str());
97 EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
98 LOG(INFO) << "Setup network:" << model_name << "\n";
99 }
100 // Trains for a given number of iterations and returns the char error rate.
101 double TrainIterations(int max_iterations) {
102 int iteration = trainer_->training_iteration();
103 int iteration_limit = iteration + max_iterations;
104 double best_error = 100.0;
105 do {
106 std::stringstream log_str;
107 int target_iteration = iteration + kBatchIterations;
108 // Train a few.
109 double mean_error = 0.0;
110 while (iteration < target_iteration && iteration < iteration_limit) {
111 trainer_->TrainOnLine(trainer_.get(), false);
112 iteration = trainer_->training_iteration();
113 mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
114 }
115 trainer_->MaintainCheckpoints(nullptr, log_str);
116 iteration = trainer_->training_iteration();
117 mean_error *= 100.0 / kBatchIterations;
118 if (mean_error < best_error) {
119 best_error = mean_error;
120 }
121 } while (iteration < iteration_limit);
122 LOG(INFO) << "Trainer error rate = " << best_error << "\n";
123 return best_error;
124 }
125 // Tests for a given number of iterations and returns the char error rate.
126 double TestIterations(int max_iterations) {
127 CHECK_GT(max_iterations, 0);
128 int iteration = trainer_->sample_iteration();
129 double mean_error = 0.0;
130 int error_count = 0;
131 while (error_count < max_iterations) {
132 const ImageData &trainingdata =
133 *trainer_->mutable_training_data()->GetPageBySerial(iteration);
134 NetworkIO fwd_outputs, targets;
135 if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) != UNENCODABLE) {
136 mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
137 ++error_count;
138 }
139 trainer_->SetIteration(++iteration);
140 }
141 mean_error *= 100.0 / max_iterations;
142 LOG(INFO) << "Tester error rate = " << mean_error << "\n";
143 return mean_error;
144 }
145 // Tests that the current trainer_ can be converted to int mode and still gets
146 // within 1% of the error rate. Returns the increase in error from float to
147 // int.
148 double TestIntMode(int test_iterations) {
149 std::vector<char> trainer_data;
150 EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, *trainer_, &trainer_data));
151 // Get the error on the next few iterations in float mode.
152 double float_err = TestIterations(test_iterations);
153 // Restore the dump, convert to int and test error on that.
154 EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, *trainer_));
155 trainer_->ConvertToInt();
156 double int_err = TestIterations(test_iterations);
157 EXPECT_LT(int_err, float_err + 1.0);
158 return int_err - float_err;
159 }
160 // Sets up a trainer with the given language and given recode+ctc condition.
161 // It then verifies that the given str encodes and decodes back to the same
162 // string.
163 void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode) {
164 std::string unicharset_name = lang + "/" + lang + ".unicharset";
165 std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
166 SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name, lstmf_name, recode, true,
167 5e-4f, true, lang);
168 std::vector<int> labels;
169 EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
170 std::string decoded = trainer_->DecodeLabels(labels);
171 std::string decoded_str(&decoded[0], decoded.length());
172 EXPECT_EQ(str, decoded_str);
173 }
174 // Calls TestEncodeDeode with both recode on and off.
175 void TestEncodeDecodeBoth(const std::string &lang, const std::string &str) {
176 TestEncodeDecode(lang, str, false);
177 TestEncodeDecode(lang, str, true);
178 }
179
180 std::unique_ptr<LSTMTrainer> trainer_;
181};
182
183} // namespace tesseract.
184
#endif // TESSERACT_UNITTEST_LSTM_TEST_H_
@ LOG
@ INFO
Definition: log.h:28
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2043
#define EXPECT_TRUE(condition)
Definition: gtest.h:1982
#define ASSERT_TRUE(condition)
Definition: gtest.h:1990
#define EXPECT_LT(val1, val2)
Definition: gtest.h:2049
#define CHECK_GT(test, value)
Definition: include_gunit.h:81
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:45
const int kBatchIterations
Definition: lstm_test.h:36
const int kTrainerIterations
Definition: lstm_test.h:34
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
@ NF_ADAM
Definition: network.h:86
@ CS_SEQUENTIAL
Definition: imagedata.h:49
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:62
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
double TestIntMode(int test_iterations)
Definition: lstm_test.h:148
void SetUp() override
Definition: lstm_test.h:47
std::unique_ptr< LSTMTrainer > trainer_
Definition: lstm_test.h:180
std::string TessDataNameToPath(const std::string &name)
Definition: lstm_test.h:56
double TrainIterations(int max_iterations)
Definition: lstm_test.h:101
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
Definition: lstm_test.h:163
double TestIterations(int max_iterations)
Definition: lstm_test.h:126
std::string TestingNameToPath(const std::string &name)
Definition: lstm_test.h:59
void TestEncodeDecodeBoth(const std::string &lang, const std::string &str)
Definition: lstm_test.h:175
void SetupTrainerEng(const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
Definition: lstm_test.h:63
std::string TestDataNameToPath(const std::string &name)
Definition: lstm_test.h:53
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)
Definition: lstm_test.h:68