// tesseract v5.3.3.20231005 — lstmtrainer.h (extracted from the generated
// documentation page for this file).
// File:        lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19#define TESSERACT_LSTM_LSTMTRAINER_H_
20
21#include "export.h"
22
23#include "imagedata.h" // for DocumentCache
24#include "lstmrecognizer.h"
25#include "rect.h"
26
27#include <functional> // for std::function
28#include <sstream> // for std::stringstream
29
30namespace tesseract {
31
32class LSTM;
33class LSTMTester;
34class LSTMTrainer;
35class Parallel;
36class Reversed;
37class Softmax;
38class Series;
39
// Enum for the types of errors that are counted.
enum ErrorTypes {
  ET_RMS,         // RMS activation error.
  ET_DELTA,       // Number of big errors in deltas.
  ET_WORD_RECERR, // Output text string word recall error.
  ET_CHAR_ERROR,  // Output text string total char error.
  ET_SKIP_RATIO,  // Fraction of samples skipped.
  ET_COUNT        // For array sizing.
};
// Enum for the trainability_ flags.
enum Trainability {
  TRAINABLE,        // Non-zero delta error.
  PERFECT,          // Zero delta error.
  UNENCODABLE,      // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR, // Hi confidence disagreement.
  NOT_BOXED,        // Early in training and has no character boxes.
};
// Enum to define the amount of data to get serialized.
enum SerializeAmount {
  LIGHT,           // Minimal data for remote training.
  NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
  FULL,            // All data including best_trainer_.
};
// Enum to indicate how the sub_trainer_ training went.
enum SubTrainerResult {
  STR_NONE,    // Did nothing as not good enough.
  STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED // Subtrainer replaced *this.
};
72
73class LSTMTrainer;
74// Function to compute and record error rates on some external test set(s).
75// Args are: iteration, mean errors, model, training stage.
76// Returns a string containing logging information about the tests.
77using TestCallback = std::function<std::string(int, const double *,
78 const TessdataManager &, int)>;
79
80// Trainer class for LSTM networks. Most of the effort is in creating the
81// ideal target outputs from the transcription. A box file is used if it is
82// available, otherwise estimates of the char widths from the unicharset are
83// used to guide a DP search for the best fit to the transcription.
84class TESS_UNICHARSET_TRAINING_API LSTMTrainer : public LSTMRecognizer {
85public:
87 LSTMTrainer(const char *model_base, const char *checkpoint_name,
88 int debug_interval, int64_t max_memory);
89 virtual ~LSTMTrainer();
90
91 // Tries to deserialize a trainer from the given file and silently returns
92 // false in case of failure. If old_traineddata is not null, then it is
93 // assumed that the character set is to be re-mapped from old_traineddata to
94 // the new, with consequent change in weight matrices etc.
95 bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata);
96
97 // Initializes the character set encode/decode mechanism directly from a
98 // previously setup traineddata containing dawgs, UNICHARSET and
99 // UnicharCompress. Note: Call before InitNetwork!
100 bool InitCharSet(const std::string &traineddata_path) {
101 bool success = mgr_.Init(traineddata_path.c_str());
102 if (success) {
103 InitCharSet();
104 }
105 return success;
106 }
107 void InitCharSet(const TessdataManager &mgr) {
108 mgr_ = mgr;
109 InitCharSet();
110 }
111
112 // Initializes the trainer with a network_spec in the network description
113 // net_flags control network behavior according to the NetworkFlags enum.
114 // There isn't really much difference between them - only where the effects
115 // are implemented.
116 // For other args see NetworkBuilder::InitNetwork.
117 // Note: Be sure to call InitCharSet before InitNetwork!
118 bool InitNetwork(const char *network_spec, int append_index, int net_flags,
119 float weight_range, float learning_rate, float momentum,
120 float adam_beta);
121 // Initializes a trainer from a serialized TFNetworkModel proto.
122 // Returns the global step of TensorFlow graph or 0 if failed.
123 // Building a compatible TF graph: See tfnetwork.proto.
124 int InitTensorFlowNetwork(const std::string &tf_proto);
125 // Resets all the iteration counters for fine tuning or training a head,
126 // where we want the error reporting to reset.
127 void InitIterations();
128
129 // Accessors.
130 double ActivationError() const {
131 return error_rates_[ET_DELTA];
132 }
133 double CharError() const {
134 return error_rates_[ET_CHAR_ERROR];
135 }
136 const double *error_rates() const {
137 return error_rates_;
138 }
139 double best_error_rate() const {
140 return best_error_rate_;
141 }
142 int best_iteration() const {
143 return best_iteration_;
144 }
145 int learning_iteration() const {
146 return learning_iteration_;
147 }
148 int32_t improvement_steps() const {
149 return improvement_steps_;
150 }
151 void set_perfect_delay(int delay) {
152 perfect_delay_ = delay;
153 }
154 const std::vector<char> &best_trainer() const {
155 return best_trainer_;
156 }
157 // Returns the error that was just calculated by PrepareForBackward.
159 return error_buffers_[type][training_iteration() % kRollingBufferSize_];
160 }
161 // Returns the error that was just calculated by TrainOnLine. Since
162 // TrainOnLine rolls the error buffers, this is one further back than
163 // NewSingleError.
165 return error_buffers_[type]
166 [(training_iteration() + kRollingBufferSize_ - 1) %
167 kRollingBufferSize_];
168 }
170 return training_data_;
171 }
173 return &training_data_;
174 }
175
176 // If the training sample is usable, grid searches for the optimal
177 // dict_ratio/cert_offset, and returns the results in a string of space-
178 // separated triplets of ratio,offset=worderr.
179 Trainability GridSearchDictParams(
180 const ImageData *trainingdata, int iteration, double min_dict_ratio,
181 double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
182 double cert_offset_step, double max_cert_offset, std::string &results);
183
184 // Provides output on the distribution of weight values.
185 void DebugNetwork();
186
187 // Loads a set of lstmf files that were created using the lstm.train config to
188 // tesseract into memory ready for training. Returns false if nothing was
189 // loaded.
190 bool LoadAllTrainingData(const std::vector<std::string> &filenames,
191 CachingStrategy cache_strategy,
192 bool randomly_rotate);
193
194 // Keeps track of best and locally worst error rate, using internally computed
195 // values. See MaintainCheckpointsSpecific for more detail.
196 bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg);
197 // Keeps track of best and locally worst error_rate (whatever it is) and
198 // launches tests using rec_model, when a new min or max is reached.
199 // Writes checkpoints using train_model at appropriate times and builds and
200 // returns a log message to indicate progress. Returns false if nothing
201 // interesting happened.
202 bool MaintainCheckpointsSpecific(int iteration,
203 const std::vector<char> *train_model,
204 const std::vector<char> *rec_model,
205 TestCallback tester, std::stringstream &log_msg);
206 // Builds a progress message with current error rates.
207 void PrepareLogMsg(std::stringstream &log_msg) const;
208 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
209 // sample_iteration() to the log_msg.
210 void LogIterations(const char *intro_str, std::stringstream &log_msg) const;
211
212 // TODO(rays) Add curriculum learning.
213 // Returns true and increments the training_stage_ if the error rate has just
214 // passed through the given threshold for the first time.
215 bool TransitionTrainingStage(float error_threshold);
216 // Returns the current training stage.
218 return training_stage_;
219 }
220
221 // Writes to the given file. Returns false in case of error.
222 bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr,
223 TFile *fp) const;
224 // Reads from the given file. Returns false in case of error.
225 bool DeSerialize(const TessdataManager *mgr, TFile *fp);
226
227 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
228 // learning rates (by scaling reduction, or layer specific, according to
229 // NF_LAYER_SPECIFIC_LR).
230 void StartSubtrainer(std::stringstream &log_msg);
231 // While the sub_trainer_ is behind the current training iteration and its
232 // training error is at least kSubTrainerMarginFraction better than the
233 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
234 // it did anything. If it catches up, and has a better error rate than the
235 // current best, as well as a margin over the current error rate, then the
236 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
237 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
238 // receive any training iterations.
239 SubTrainerResult UpdateSubtrainer(std::stringstream &log_msg);
240 // Reduces network learning rates, either for everything, or for layers
241 // independently, according to NF_LAYER_SPECIFIC_LR.
242 void ReduceLearningRates(LSTMTrainer *samples_trainer, std::stringstream &log_msg);
243 // Considers reducing the learning rate independently for each layer down by
244 // factor(<1), or leaving it the same, by double-training the given number of
245 // samples and minimizing the amount of changing of sign of weight updates.
246 // Even if it looks like all weights should remain the same, an adjustment
247 // will be made to guarantee a different result when reverting to an old best.
248 // Returns the number of layer learning rates that were reduced.
249 int ReduceLayerLearningRates(TFloat factor, int num_samples,
250 LSTMTrainer *samples_trainer);
251
252 // Converts the string to integer class labels, with appropriate null_char_s
253 // in between if not in SimpleTextOutput mode. Returns false on failure.
254 bool EncodeString(const std::string &str, std::vector<int> *labels) const {
255 return EncodeString(str, GetUnicharset(),
256 IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(),
257 null_char_, labels);
258 }
259 // Static version operates on supplied unicharset, encoder, simple_text.
260 static bool EncodeString(const std::string &str, const UNICHARSET &unicharset,
261 const UnicharCompress *recoder, bool simple_text,
262 int null_char, std::vector<int> *labels);
263
264 // Performs forward-backward on the given trainingdata.
265 // Returns the sample that was used or nullptr if the next sample was deemed
266 // unusable. samples_trainer could be this or an alternative trainer that
267 // holds the training samples.
268 const ImageData *TrainOnLine(LSTMTrainer *samples_trainer, bool batch) {
269 int sample_index = sample_iteration();
270 const ImageData *image =
271 samples_trainer->training_data_.GetPageBySerial(sample_index);
272 if (image != nullptr) {
273 Trainability trainable = TrainOnLine(image, batch);
274 if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
275 return nullptr; // Sample was unusable.
276 }
277 } else {
278 ++sample_iteration_;
279 }
280 return image;
281 }
282 Trainability TrainOnLine(const ImageData *trainingdata, bool batch);
283
284 // Prepares the ground truth, runs forward, and prepares the targets.
285 // Returns a Trainability enum to indicate the suitability of the sample.
286 Trainability PrepareForBackward(const ImageData *trainingdata,
287 NetworkIO *fwd_outputs, NetworkIO *targets);
288
289 // Writes the trainer to memory, so that the current training state can be
290 // restored. *this must always be the master trainer that retains the only
291 // copy of the training data and language model. trainer is the model that is
292 // actually serialized.
293 bool SaveTrainingDump(SerializeAmount serialize_amount,
294 const LSTMTrainer &trainer,
295 std::vector<char> *data) const;
296
297 // Reads previously saved trainer from memory. *this must always be the
298 // master trainer that retains the only copy of the training data and
299 // language model. trainer is the model that is restored.
300 bool ReadTrainingDump(const std::vector<char> &data,
301 LSTMTrainer &trainer) const {
302 if (data.empty()) {
303 return false;
304 }
305 return ReadSizedTrainingDump(&data[0], data.size(), trainer);
306 }
307 bool ReadSizedTrainingDump(const char *data, int size,
308 LSTMTrainer &trainer) const {
309 return trainer.ReadLocalTrainingDump(&mgr_, data, size);
310 }
311 // Restores the model to *this.
312 bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data,
313 int size);
314
315 // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
317
318 // Writes the full recognition traineddata to the given filename.
319 bool SaveTraineddata(const char *filename);
320
321 // Writes the recognizer to memory, so that it can be used for testing later.
322 void SaveRecognitionDump(std::vector<char> *data) const;
323
324 // Returns a suitable filename for a training dump, based on the model_base_,
325 // the iteration and the error rates.
326 std::string DumpFilename() const;
327
328 // Fills the whole error buffer of the given type with the given value.
329 void FillErrorBuffer(double new_error, ErrorTypes type);
330 // Helper generates a map from each current recoder_ code (ie softmax index)
331 // to the corresponding old_recoder code, or -1 if there isn't one.
332 std::vector<int> MapRecoder(const UNICHARSET &old_chset,
333 const UnicharCompress &old_recoder) const;
334
335protected:
336 // Private version of InitCharSet above finishes the job after initializing
337 // the mgr_ data member.
338 void InitCharSet();
339 // Helper computes and sets the null_char_.
340 void SetNullChar();
341
342 // Factored sub-constructor sets up reasonable default values.
343 void EmptyConstructor();
344
345 // Outputs the string and periodically displays the given network inputs
346 // as an image in the given window, and the corresponding labels at the
347 // corresponding x_starts.
348 // Returns false if the truth string is empty.
349 bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata,
350 const NetworkIO &fwd_outputs,
351 const std::vector<int> &truth_labels,
352 const NetworkIO &outputs);
353 // Displays the network targets as line a line graph.
354 void DisplayTargets(const NetworkIO &targets, const char *window_name,
355 ScrollView **window);
356
357 // Builds a no-compromises target where the first positions should be the
358 // truth labels and the rest is padded with the null_char_.
359 bool ComputeTextTargets(const NetworkIO &outputs,
360 const std::vector<int> &truth_labels,
361 NetworkIO *targets);
362
363 // Builds a target using standard CTC. truth_labels should be pre-padded with
364 // nulls wherever desired. They don't have to be between all labels.
365 // outputs is input-output, as it gets clipped to minimum probability.
366 bool ComputeCTCTargets(const std::vector<int> &truth_labels,
367 NetworkIO *outputs, NetworkIO *targets);
368
369 // Computes network errors, and stores the results in the rolling buffers,
370 // along with the supplied text_error.
371 // Returns the delta error of the current sample (not running average.)
372 double ComputeErrorRates(const NetworkIO &deltas, double char_error,
373 double word_error);
374
375 // Computes the network activation RMS error rate.
376 double ComputeRMSError(const NetworkIO &deltas);
377
378 // Computes network activation winner error rate. (Number of values that are
379 // in error by >= 0.5 divided by number of time-steps.) More closely related
380 // to final character error than RMS, but still directly calculable from
381 // just the deltas. Because of the binary nature of the targets, zero winner
382 // error is a sufficient but not necessary condition for zero char error.
383 double ComputeWinnerError(const NetworkIO &deltas);
384
385 // Computes a very simple bag of chars char error rate.
386 double ComputeCharError(const std::vector<int> &truth_str,
387 const std::vector<int> &ocr_str);
388 // Computes a very simple bag of words word recall error rate.
389 // NOTE that this is destructive on both input strings.
390 double ComputeWordError(std::string *truth_str, std::string *ocr_str);
391
392 // Updates the error buffer and corresponding mean of the given type with
393 // the new_error.
394 void UpdateErrorBuffer(double new_error, ErrorTypes type);
395
396 // Rolls error buffers and reports the current means.
397 void RollErrorBuffers();
398
399 // Given that error_rate is either a new min or max, updates the best/worst
400 // error rates, and record of progress.
401 std::string UpdateErrorGraph(int iteration, double error_rate,
402 const std::vector<char> &model_data,
403 const TestCallback &tester);
404
405protected:
406#ifndef GRAPHICS_DISABLED
407 // Alignment display window.
409 // CTC target display window.
411 // CTC output display window.
413 // Reconstructed image window.
415#endif
416 // How often to display a debug image.
418 // Iteration at which the last checkpoint was dumped.
420 // Basename of files to save best models to.
421 std::string model_base_;
422 // Checkpoint filename.
423 std::string checkpoint_name_;
424 // Training data.
427 // Name to use when saving best_trainer_.
428 std::string best_model_name_;
429 // Number of available training stages.
431
432 // ===Serialized data to ensure that a restart produces the same results.===
433 // These members are only serialized when serialize_amount != LIGHT.
434 // Best error rate so far.
436 // Snapshot of all error rates at best_iteration_.
437 double best_error_rates_[ET_COUNT];
438 // Iteration of best_error_rate_.
440 // Worst error rate since best_error_rate_.
442 // Snapshot of all error rates at worst_iteration_.
443 double worst_error_rates_[ET_COUNT];
444 // Iteration of worst_error_rate_.
446 // Iteration at which the process will be thought stalled.
448 // Saved recognition models for computing test error for graph points.
449 std::vector<char> best_model_data_;
450 std::vector<char> worst_model_data_;
451 // Saved trainer for reverting back to last known best.
452 std::vector<char> best_trainer_;
453 // A subsidiary trainer running with a different learning rate until either
454 // *this or sub_trainer_ hits a new best.
455 std::unique_ptr<LSTMTrainer> sub_trainer_;
456 // Error rate at which last best model was dumped.
458 // Current stage of training.
460 // History of best error rate against iteration. Used for computing the
461 // number of steps to each 2% improvement.
462 std::vector<double> best_error_history_;
463 std::vector<int32_t> best_error_iterations_;
464 // Number of iterations since the best_error_rate_ was 2% more than it is now.
466 // Number of iterations that yielded a non-zero delta error and thus provided
467 // significant learning. learning_iteration_ <= training_iteration_.
468 // learning_iteration_ is used to measure rate of learning progress.
470 // Saved value of sample_iteration_ before looking for the next sample.
472 // How often to include a PERFECT training sample in backprop.
473 // A PERFECT training sample is used if the current
474 // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
475 // so with perfect_delay_ == 0, all samples are used, and with
476 // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
478 // Value of training_iteration_ at which the last PERFECT training sample
479 // was used in back prop.
481 // Rolling buffers storing recent training errors are indexed by
482 // training_iteration % kRollingBufferSize_.
483 static const int kRollingBufferSize_ = 1000;
484 std::vector<double> error_buffers_[ET_COUNT];
485 // Rounded mean percent trailing training errors in the buffers.
486 double error_rates_[ET_COUNT]; // RMS training error.
487 // Traineddata file with optional dawgs + UNICHARSET and recoder.
489};
490
491} // namespace tesseract.
492
493#endif // TESSERACT_LSTM_LSTMTRAINER_H_
@ ET_WORD_RECERR
Definition: lstmtrainer.h:44
@ ET_SKIP_RATIO
Definition: lstmtrainer.h:46
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:45
@ HI_PRECISION_ERR
Definition: lstmtrainer.h:55
@ STR_REPLACED
Definition: lstmtrainer.h:70
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:205
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:236
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:78
double TFloat
Definition: tesstypes.h:39
CachingStrategy
Definition: imagedata.h:42
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:62
type
Definition: upload.py:458
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:317
std::vector< int32_t > best_error_iterations_
Definition: lstmtrainer.h:463
bool MaintainCheckpointsSpecific(int iteration, const std::vector< char > *train_model, const std::vector< char > *rec_model, TestCallback tester, std::stringstream &log_msg)
std::vector< char > worst_model_data_
Definition: lstmtrainer.h:450
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
ScrollView * target_win_
Definition: lstmtrainer.h:410
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:254
const double * error_rates() const
Definition: lstmtrainer.h:136
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:100
int InitTensorFlowNetwork(const std::string &tf_proto)
std::string model_base_
Definition: lstmtrainer.h:421
std::string best_model_name_
Definition: lstmtrainer.h:428
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:158
double CharError() const
Definition: lstmtrainer.h:133
std::vector< char > best_trainer_
Definition: lstmtrainer.h:452
double best_error_rate() const
Definition: lstmtrainer.h:139
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:164
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:172
const std::vector< char > & best_trainer() const
Definition: lstmtrainer.h:154
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:457
ScrollView * recon_win_
Definition: lstmtrainer.h:414
int learning_iteration() const
Definition: lstmtrainer.h:145
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
std::string checkpoint_name_
Definition: lstmtrainer.h:423
ScrollView * ctc_win_
Definition: lstmtrainer.h:412
int CurrentTrainingStage() const
Definition: lstmtrainer.h:217
double ActivationError() const
Definition: lstmtrainer.h:130
std::vector< char > best_model_data_
Definition: lstmtrainer.h:449
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:307
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:107
DocumentCache training_data_
Definition: lstmtrainer.h:426
std::unique_ptr< LSTMTrainer > sub_trainer_
Definition: lstmtrainer.h:455
const DocumentCache & training_data() const
Definition: lstmtrainer.h:169
bool ReadTrainingDump(const std::vector< char > &data, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:300
int32_t improvement_steps() const
Definition: lstmtrainer.h:148
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268
std::vector< double > best_error_history_
Definition: lstmtrainer.h:462
int best_iteration() const
Definition: lstmtrainer.h:142
TessdataManager mgr_
Definition: lstmtrainer.h:488
ScrollView * align_win_
Definition: lstmtrainer.h:408