tesseract v5.3.3.20231005
lstmtraining.cpp
Go to the documentation of this file.
1
// File:        lstmtraining.cpp
// Description: Training program for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
17
#include <cerrno>  // for errno
#include <climits> // for INT_MAX
#include <cstdio>  // for fopen, fclose, remove
#include <cstring> // for strerror
#include <locale>  // for std::locale::classic
#include <sstream> // for std::stringstream
#if defined(__USE_GNU)
#  include <cfenv> // for feenableexcept
#endif
#include "commontraining.h"
#include "fileio.h" // for LoadFileLinesToStrings
#include "lstmtester.h"
#include "lstmtrainer.h"
#include "params.h"
#include "tprintf.h"
30
31using namespace tesseract;
32
33static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
34static STRING_PARAM_FLAG(net_spec, "", "Network specification");
35static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
36static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples between perfect ones.");
37static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
38static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
39static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
40static BOOL_PARAM_FLAG(reset_learning_rate, false,
41 "Resets all stored learning rates to the value specified by --learning_rate.");
42static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
43static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
44static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
45static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
46static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
47static STRING_PARAM_FLAG(train_listfile, "",
48 "File listing training files in lstmf training format.");
49static STRING_PARAM_FLAG(eval_listfile, "", "File listing eval files in lstmf training format.");
50#if defined(__USE_GNU)
51static BOOL_PARAM_FLAG(debug_float, false, "Raise error on certain float errors.");
52#endif
53static BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model.");
54static BOOL_PARAM_FLAG(convert_to_int, false, "Convert the recognition model to an integer model.");
55static BOOL_PARAM_FLAG(sequential_training, false,
56 "Use the training files sequentially instead of round-robin.");
57static INT_PARAM_FLAG(append_index, -1,
58 "Index in continue_from Network at which to"
59 " attach the new network defined by net_spec");
60static BOOL_PARAM_FLAG(debug_network, false, "Get info on distribution of weight values");
61static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
62static STRING_PARAM_FLAG(traineddata, "", "Combined Dawgs/Unicharset/Recoder for language model");
63static STRING_PARAM_FLAG(old_traineddata, "",
64 "When changing the character set, this specifies the old"
65 " character set that is to be replaced");
66static BOOL_PARAM_FLAG(randomly_rotate, false,
67 "Train OSD and randomly turn training samples upside-down");
68
69// Number of training images to train between calls to MaintainCheckpoints.
70const int kNumPagesPerBatch = 100;
71
72// Apart from command-line flags, input is a collection of lstmf files, that
73// were previously created using tesseract with the lstm.train config file.
74// The program iterates over the inputs, feeding the data to the network,
75// until the error rate reaches a specified target or max_iterations is reached.
76int main(int argc, char **argv) {
77 tesseract::CheckSharedLibraryVersion();
78 ParseArguments(&argc, &argv);
79#if defined(__USE_GNU)
80 if (FLAGS_debug_float) {
81 // Raise SIGFPE for unwanted floating point calculations.
82 feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
83 }
84#endif
85 if (FLAGS_model_output.empty()) {
86 tprintf("Must provide a --model_output!\n");
87 return EXIT_FAILURE;
88 }
89 if (FLAGS_traineddata.empty()) {
90 tprintf("Must provide a --traineddata see training documentation\n");
91 return EXIT_FAILURE;
92 }
93
94 // Check write permissions.
95 std::string test_file = FLAGS_model_output.c_str();
96 test_file += "_wtest";
97 FILE *f = fopen(test_file.c_str(), "wb");
98 if (f != nullptr) {
99 fclose(f);
100 if (remove(test_file.c_str()) != 0) {
101 tprintf("Error, failed to remove %s: %s\n", test_file.c_str(), strerror(errno));
102 return EXIT_FAILURE;
103 }
104 } else {
105 tprintf("Error, model output cannot be written: %s\n", strerror(errno));
106 return EXIT_FAILURE;
107 }
108
109 // Setup the trainer.
110 std::string checkpoint_file = FLAGS_model_output.c_str();
111 checkpoint_file += "_checkpoint";
112 std::string checkpoint_bak = checkpoint_file + ".bak";
113 tesseract::LSTMTrainer trainer(FLAGS_model_output.c_str(), checkpoint_file.c_str(),
114 FLAGS_debug_interval,
115 static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
116 if (!trainer.InitCharSet(FLAGS_traineddata.c_str())) {
117 tprintf("Error, failed to read %s\n", FLAGS_traineddata.c_str());
118 return EXIT_FAILURE;
119 }
120
121 // Reading something from an existing model doesn't require many flags,
122 // so do it now and exit.
123 if (FLAGS_stop_training || FLAGS_debug_network) {
124 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
125 tprintf("Failed to read continue from: %s\n", FLAGS_continue_from.c_str());
126 return EXIT_FAILURE;
127 }
128 if (FLAGS_debug_network) {
129 trainer.DebugNetwork();
130 } else {
131 if (FLAGS_convert_to_int) {
132 trainer.ConvertToInt();
133 }
134 if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
135 tprintf("Failed to write recognition model : %s\n", FLAGS_model_output.c_str());
136 }
137 }
138 return EXIT_SUCCESS;
139 }
140
141 // Get the list of files to process.
142 if (FLAGS_train_listfile.empty()) {
143 tprintf("Must supply a list of training filenames! --train_listfile\n");
144 return EXIT_FAILURE;
145 }
146 std::vector<std::string> filenames;
147 if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(), &filenames)) {
148 tprintf("Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str());
149 return EXIT_FAILURE;
150 }
151
152 // Checkpoints always take priority if they are available.
153 if (trainer.TryLoadingCheckpoint(checkpoint_file.c_str(), nullptr) ||
154 trainer.TryLoadingCheckpoint(checkpoint_bak.c_str(), nullptr)) {
155 tprintf("Successfully restored trainer from %s\n", checkpoint_file.c_str());
156 } else {
157 if (!FLAGS_continue_from.empty()) {
158 // Load a past model file to improve upon.
159 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
160 FLAGS_append_index >= 0 ? FLAGS_continue_from.c_str()
161 : FLAGS_old_traineddata.c_str())) {
162 tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
163 return EXIT_FAILURE;
164 }
165 tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
166 if (FLAGS_reset_learning_rate) {
167 trainer.SetLearningRate(FLAGS_learning_rate);
168 tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
169 }
170 trainer.InitIterations();
171 }
172 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
173 if (FLAGS_append_index >= 0) {
174 tprintf("Appending a new network to an old one!!");
175 if (FLAGS_continue_from.empty()) {
176 tprintf("Must set --continue_from for appending!\n");
177 return EXIT_FAILURE;
178 }
179 }
180 // We are initializing from scratch.
181 if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode,
182 FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum,
183 FLAGS_adam_beta)) {
184 tprintf("Failed to create network from spec: %s\n", FLAGS_net_spec.c_str());
185 return EXIT_FAILURE;
186 }
187 trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
188 }
189 }
190 if (!trainer.LoadAllTrainingData(
191 filenames,
192 FLAGS_sequential_training ? tesseract::CS_SEQUENTIAL : tesseract::CS_ROUND_ROBIN,
193 FLAGS_randomly_rotate)) {
194 tprintf("Load of images failed!!\n");
195 return EXIT_FAILURE;
196 }
197
198 tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
199 tesseract::TestCallback tester_callback = nullptr;
200 if (!FLAGS_eval_listfile.empty()) {
201 using namespace std::placeholders; // for _1, _2, _3...
202 if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
203 tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
204 return EXIT_FAILURE;
205 }
206 tester_callback = std::bind(&tesseract::LSTMTester::RunEvalAsync, &tester, _1, _2, _3, _4);
207 }
208
209 int max_iterations = FLAGS_max_iterations;
210 if (max_iterations < 0) {
211 // A negative value is interpreted as epochs
212 max_iterations = filenames.size() * (-max_iterations);
213 } else if (max_iterations == 0) {
214 // "Infinite" iterations.
215 max_iterations = INT_MAX;
216 }
217
218 do {
219 // Train a few.
220 int iteration = trainer.training_iteration();
221 for (int target_iteration = iteration + kNumPagesPerBatch;
222 iteration < target_iteration && iteration < max_iterations;
223 iteration = trainer.training_iteration()) {
224 trainer.TrainOnLine(&trainer, false);
225 }
226 std::stringstream log_str;
227 log_str.imbue(std::locale::classic());
228 trainer.MaintainCheckpoints(tester_callback, log_str);
229 tprintf("%s\n", log_str.str().c_str());
230 } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
231 (trainer.training_iteration() < max_iterations));
232 tprintf("Finished! Selected model with minimal training error rate (BCER) = %g\n",
233 trainer.best_error_rate());
234 return EXIT_SUCCESS;
235} /* main */
#define DOUBLE_PARAM_FLAG(name, val, comment)
#define BOOL_PARAM_FLAG(name, val, comment)
#define INT_PARAM_FLAG(name, val, comment)
#define STRING_PARAM_FLAG(name, val, comment)
int main(int argc, char **argv)
const int kNumPagesPerBatch
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
void ParseArguments(int *argc, char ***argv)
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:78
@ CS_SEQUENTIAL
Definition: imagedata.h:49
@ CS_ROUND_ROBIN
Definition: imagedata.h:54
bool LoadFileLinesToStrings(const char *filename, std::vector< std::string > *lines)
Definition: fileio.h:32
void SetLearningRate(float learning_rate)
std::string RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
Definition: lstmtester.cpp:51
bool LoadAllEvalData(const char *filenames_file)
Definition: lstmtester.cpp:30
bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg)
bool LoadAllTrainingData(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:100
double best_error_rate() const
Definition: lstmtrainer.h:139
bool SaveTraineddata(const char *filename)
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
bool InitNetwork(const char *network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)