tesseract v5.3.3.20231005
lstmtrainer.cpp
///////////////////////////////////////////////////////////////////////
// File: lstmtrainer.cpp
// Description: Top-level line trainer class for LSTM-based networks.
// Author: Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
# include "config_auto.h"
#endif

#include <cmath>
#include <iomanip> // for std::setprecision
#include <locale>  // for std::locale::classic
#include <string>
#include "lstmtrainer.h"

#include <allheaders.h>
#include "boxread.h"
#include "ctc.h"
#include "imagedata.h"
#include "input.h"
#include "networkbuilder.h"
#include "ratngs.h"
#include "recodebeam.h"
#ifdef INCLUDE_TENSORFLOW
# include "tfnetwork.h"
#endif
#include "tprintf.h"

namespace tesseract {

// Min actual error rate increase to constitute divergence.
const double kMinDivergenceRate = 50.0;
// Min iterations since last best before acting on a stall.
const int kMinStallIterations = 10000;
// Fraction of current char error rate that sub_trainer_ has to be ahead
// before we declare the sub_trainer_ a success and switch to it.
const double kSubTrainerMarginFraction = 3.0 / 128;
// Factor to reduce learning rate on divergence.
const double kLearningRateDecay = M_SQRT1_2;
// LR adjustment iterations.
const int kNumAdjustmentIterations = 100;
// How often to add data to the error_graph_.
const int kErrorGraphInterval = 1000;
// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
// Min percent error rate to consider start-up phase over.
const int kMinStartedErrorRate = 75;
// Error rate at which to transition to stage 1.
const double kStageTransitionThreshold = 10.0;
// Confidence beyond which the truth is more likely wrong than the recognizer.
const double kHighConfidence = 0.9375; // 15/16.
// Fraction of weight sign-changing total to constitute a definite improvement.
const double kImprovementFraction = 15.0 / 16.0;
// Fraction of last written best to make it worth writing another.
const double kBestCheckpointFraction = 31.0 / 32.0;
#ifndef GRAPHICS_DISABLED
// Scale factor for display of target activations of CTC.
const int kTargetXScale = 5;
const int kTargetYScale = 100;
#endif // !GRAPHICS_DISABLED

LSTMTrainer::LSTMTrainer()
    : randomly_rotate_(false), training_data_(0), sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = 0;
}

LSTMTrainer::LSTMTrainer(const char *model_base, const char *checkpoint_name,
                         int debug_interval, int64_t max_memory)
    : randomly_rotate_(false),
      training_data_(max_memory),
      sub_trainer_(nullptr) {
  EmptyConstructor();
  debug_interval_ = debug_interval;
  model_base_ = model_base;
  checkpoint_name_ = checkpoint_name;
}

LSTMTrainer::~LSTMTrainer() {
#ifndef GRAPHICS_DISABLED
  delete align_win_;
  delete target_win_;
  delete ctc_win_;
  delete recon_win_;
#endif
}

// Tries to deserialize a trainer from the given file and silently returns
// false in case of failure.
bool LSTMTrainer::TryLoadingCheckpoint(const char *filename,
                                       const char *old_traineddata) {
  std::vector<char> data;
  if (!LoadDataFromFile(filename, &data)) {
    return false;
  }
  tprintf("Loaded file %s, unpacking...\n", filename);
  if (!ReadTrainingDump(data, *this)) {
    return false;
  }
  if (IsIntMode()) {
    tprintf("Error, %s is an integer (fast) model, cannot continue training\n",
            filename);
    return false;
  }
  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
       network_->NumOutputs() == recoder_.code_range()) ||
      filename == old_traineddata) {
    return true; // Normal checkpoint load complete.
  }
  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
          recoder_.code_range());
  if (old_traineddata == nullptr || *old_traineddata == '\0') {
    tprintf("Must supply the old traineddata for code conversion!\n");
    return false;
  }
  TessdataManager old_mgr;
  ASSERT_HOST(old_mgr.Init(old_traineddata));
  TFile fp;
  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
    return false;
  }
  UNICHARSET old_chset;
  if (!old_chset.load_from_file(&fp, false)) {
    return false;
  }
  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
    return false;
  }
  UnicharCompress old_recoder;
  if (!old_recoder.DeSerialize(&fp)) {
    return false;
  }
  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
  // Set the null_char_ to the new value.
  int old_null_char = null_char_;
  SetNullChar();
  // Map the softmax(s) in the network.
  network_->RemapOutputs(old_recoder.code_range(), code_map);
  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
  return true;
}
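
// Example: a minimal resume-or-convert sketch (hypothetical driver code, not
// part of this file; the paths are illustrative). With old_traineddata ==
// nullptr this is a plain resume; passing the old traineddata enables the
// code-mapping path above when the output code range has changed.
#if 0
LSTMTrainer trainer("/tmp/model/base", "/tmp/model/checkpoint", 0, 0);
if (!trainer.TryLoadingCheckpoint("/tmp/model/checkpoint", nullptr)) {
  tprintf("No usable checkpoint; start from a fresh network instead.\n");
}
#endif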

// Initializes the trainer with a network_spec in the network description.
// net_flags control network behavior according to the NetworkFlags enum.
// There isn't really much difference between them - only where the effects
// are implemented.
// For other args see NetworkBuilder::InitNetwork.
// Note: Be sure to call InitCharSet before InitNetwork!
bool LSTMTrainer::InitNetwork(const char *network_spec, int append_index,
                              int net_flags, float weight_range,
                              float learning_rate, float momentum,
                              float adam_beta) {
  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec);
  adam_beta_ = adam_beta;
  learning_rate_ = learning_rate;
  momentum_ = momentum;
  SetNullChar();
  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
                                   append_index, net_flags, weight_range,
                                   &randomizer_, &network_)) {
    return false;
  }
  network_str_ += network_spec;
  tprintf("Built network:%s from request %s\n", network_->spec().c_str(),
          network_spec);
  tprintf(
      "Training parameters:\n Debug interval = %d,"
      " weights = %g, learning rate = %g, momentum=%g\n",
      debug_interval_, weight_range, learning_rate_, momentum_);
  tprintf("null char=%d\n", null_char_);
  return true;
}
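
// Example: building a fresh network from a VGSL spec (a sketch only; the spec
// string and hyperparameters below are illustrative, not canonical). The
// charset must already be initialized (InitCharSet) so that
// recoder_.code_range() is valid when this runs.
#if 0
trainer.InitNetwork("[1,36,0,1 Ct3,3,16 Mp3,3 Lfys48 Lfx96 Lrx96 Lfx192 O1c111]",
                    /*append_index=*/-1, /*net_flags=*/0,
                    /*weight_range=*/0.1f, /*learning_rate=*/1e-3f,
                    /*momentum=*/0.5f, /*adam_beta=*/0.999f);
#endif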

// Initializes a trainer from a serialized TFNetworkModel proto.
// Returns the global step of TensorFlow graph or 0 if failed.
#ifdef INCLUDE_TENSORFLOW
int LSTMTrainer::InitTensorFlowNetwork(const std::string &tf_proto) {
  delete network_;
  TFNetwork *tf_net = new TFNetwork("TensorFlow");
  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
  if (training_iteration_ == 0) {
    tprintf("InitFromProtoStr failed!!\n");
    return 0;
  }
  network_ = tf_net;
  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
  return training_iteration_;
}
#endif

// Resets all the iteration counters for fine tuning or training a head,
// where we want the error reporting to reset.
void LSTMTrainer::InitIterations() {
  sample_iteration_ = 0;
  training_iteration_ = 0;
  learning_iteration_ = 0;
  prev_sample_iteration_ = 0;
  best_error_rate_ = 100.0;
  best_iteration_ = 0;
  worst_error_rate_ = 0.0;
  worst_iteration_ = 0;
  stall_iteration_ = kMinStallIterations;
  best_error_history_.clear();
  best_error_iterations_.clear();
  improvement_steps_ = kMinStallIterations;
  perfect_delay_ = 0;
  last_perfect_training_iteration_ = 0;
  for (int i = 0; i < ET_COUNT; ++i) {
    best_error_rates_[i] = 100.0;
    worst_error_rates_[i] = 0.0;
    error_buffers_[i].clear();
    error_buffers_[i].resize(kRollingBufferSize_);
    error_rates_[i] = 100.0;
  }
  error_rate_of_last_saved_best_ = kMinStartedErrorRate;
}

// If the training sample is usable, grid searches for the optimal
// dict_ratio/cert_offset, and returns the results in a string of space-
// separated triplets of ratio,offset=worderr.
Trainability LSTMTrainer::GridSearchDictParams(
    const ImageData *trainingdata, int iteration, double min_dict_ratio,
    double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
    double cert_offset_step, double max_cert_offset, std::string &results) {
  sample_iteration_ = iteration;
  NetworkIO fwd_outputs, targets;
  Trainability result =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr) {
    return result;
  }

  // Encode/decode the truth to get the normalization.
  std::vector<int> truth_labels, ocr_labels, xcoords;
  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
  // NO-dict error.
  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
                               nullptr);
  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
                     nullptr);
  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
  std::string truth_text = DecodeLabels(truth_labels);
  std::string ocr_text = DecodeLabels(ocr_labels);
  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
  results += "0,0=" + std::to_string(baseline_error);

  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
    for (double c = min_cert_offset; c < max_cert_offset;
         c += cert_offset_step) {
      search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
                    nullptr);
      search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
      truth_text = DecodeLabels(truth_labels);
      ocr_text = DecodeLabels(ocr_labels);
      // This is destructive on both strings.
      double word_error = ComputeWordError(&truth_text, &ocr_text);
      if ((r == min_dict_ratio && c == min_cert_offset) ||
          !std::isfinite(word_error)) {
        std::string t = DecodeLabels(truth_labels);
        std::string o = DecodeLabels(ocr_labels);
        tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
                t.c_str(), o.c_str(), word_error, truth_labels[0]);
      }
      results += " " + std::to_string(r);
      results += "," + std::to_string(c);
      results += "=" + std::to_string(word_error);
    }
  }
  return result;
}
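
// Example invocation (a sketch with illustrative ranges; trainingdata is a
// hypothetical sample): sweeps dict_ratio in [0, 2) by 0.25 and cert_offset
// in [-8, 0) by 1, appending a "ratio,offset=worderr" triplet to results for
// every grid point.
#if 0
std::string results;
trainer.GridSearchDictParams(trainingdata, /*iteration=*/0,
                             /*min_dict_ratio=*/0.0, /*dict_ratio_step=*/0.25,
                             /*max_dict_ratio=*/2.0, /*min_cert_offset=*/-8.0,
                             /*cert_offset_step=*/1.0, /*max_cert_offset=*/0.0,
                             results);
#endif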

// Provides output on the distribution of weight values.
void LSTMTrainer::DebugNetwork() {
  network_->DebugWeights();
}

// Loads a set of lstmf files that were created using the lstm.train config to
// tesseract into memory ready for training. Returns false if nothing was
// loaded.
bool LSTMTrainer::LoadAllTrainingData(const std::vector<std::string> &filenames,
                                      CachingStrategy cache_strategy,
                                      bool randomly_rotate) {
  randomly_rotate_ = randomly_rotate;
  training_data_.Clear();
  return training_data_.LoadDocuments(filenames, cache_strategy,
                                      LoadDataFromFile);
}

// Keeps track of best and locally worst char error_rate and launches tests
// using tester, when a new min or max is reached.
// Writes checkpoints at appropriate times and builds and returns a log message
// to indicate progress. Returns false if nothing interesting happened.
bool LSTMTrainer::MaintainCheckpoints(const TestCallback &tester,
                                      std::stringstream &log_msg) {
  PrepareLogMsg(log_msg);
  double error_rate = CharError();
  int iteration = learning_iteration();
  if (iteration >= stall_iteration_ &&
      error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
      best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
    // It hasn't got any better in a long while, and is a margin worse than the
    // best, so go back to the best model and try a different learning rate.
    StartSubtrainer(log_msg);
  }
  SubTrainerResult sub_trainer_result = STR_NONE;
  if (sub_trainer_ != nullptr) {
    sub_trainer_result = UpdateSubtrainer(log_msg);
    if (sub_trainer_result == STR_REPLACED) {
      // Reset the inputs, as we have overwritten *this.
      error_rate = CharError();
      iteration = learning_iteration();
      PrepareLogMsg(log_msg);
    }
  }
  bool result = true; // Something interesting happened.
  std::vector<char> rec_model_data;
  if (error_rate < best_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New best BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
    // just overwrote *this. In either case, we have finished with it.
    sub_trainer_.reset();
    stall_iteration_ = learning_iteration() + kMinStallIterations;
    if (TransitionTrainingStage(kStageTransitionThreshold)) {
      log_msg << " Transitioned to stage " << CurrentTrainingStage();
    }
    SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
      std::string best_model_name = DumpFilename();
      if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
        log_msg << " failed to write best model:";
      } else {
        log_msg << " wrote best model:";
        error_rate_of_last_saved_best_ = best_error_rate_;
      }
      log_msg << best_model_name;
    }
  } else if (error_rate > worst_error_rate_) {
    SaveRecognitionDump(&rec_model_data);
    log_msg << " New worst BCER = " << error_rate;
    log_msg << UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
    if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
        best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
      // Error rate has ballooned. Go back to the best model.
      log_msg << "\nDivergence! ";
      // Copy best_trainer_ before reading it, as it will get overwritten.
      std::vector<char> revert_data(best_trainer_);
      if (ReadTrainingDump(revert_data, *this)) {
        LogIterations("Reverted to", log_msg);
        ReduceLearningRates(this, log_msg);
      } else {
        LogIterations("Failed to Revert at", log_msg);
      }
      // If it fails again, we will wait twice as long before reverting again.
      stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
      // Re-save the best trainer with the new learning rates and stall
      // iteration.
      SaveTrainingDump(NO_BEST_TRAINER, *this, &best_trainer_);
    }
  } else {
    // Something interesting happened only if the sub_trainer_ was trained.
    result = sub_trainer_result != STR_NONE;
  }
  if (checkpoint_name_.length() > 0) {
    // Write a current checkpoint.
    std::vector<char> checkpoint;
    if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
        !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
      log_msg << " failed to write checkpoint.";
    } else {
      log_msg << " wrote checkpoint.";
    }
  }
  return result;
}
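
// Example: the shape of the outer training loop (a hypothetical driver,
// modeled on the lstmtraining tool; max_iterations is illustrative): train
// kNumPagesPerBatch lines, then let MaintainCheckpoints handle logging,
// checkpointing, sub-trainer trials and divergence recovery.
#if 0
while (trainer.training_iteration() < max_iterations) {
  for (int i = 0; i < kNumPagesPerBatch; ++i) {
    trainer.TrainOnLine(&trainer, /*batch=*/false);
  }
  std::stringstream log_msg;
  trainer.MaintainCheckpoints(/*tester=*/nullptr, log_msg);
  tprintf("%s\n", log_msg.str().c_str());
}
#endif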

// Builds a string containing a progress message with current error rates.
void LSTMTrainer::PrepareLogMsg(std::stringstream &log_msg) const {
  LogIterations("At", log_msg);
  log_msg << std::fixed << std::setprecision(3)
          << ", mean rms=" << error_rates_[ET_RMS]
          << "%, delta=" << error_rates_[ET_DELTA]
          << "%, BCER train=" << error_rates_[ET_CHAR_ERROR]
          << "%, BWER train=" << error_rates_[ET_WORD_RECERR]
          << "%, skip ratio=" << error_rates_[ET_SKIP_RATIO] << "%,";
}

// Appends <intro_str> iteration learning_iteration()/training_iteration()/
// sample_iteration() to the log_msg.
void LSTMTrainer::LogIterations(const char *intro_str,
                                std::stringstream &log_msg) const {
  log_msg << intro_str
          << " iteration " << learning_iteration()
          << "/" << training_iteration()
          << "/" << sample_iteration();
}

// Returns true and increments the training_stage_ if the error rate has just
// passed through the given threshold for the first time.
bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
  if (best_error_rate_ < error_threshold &&
      training_stage_ + 1 < num_training_stages_) {
    ++training_stage_;
    return true;
  }
  return false;
}

// Writes to the given file. Returns false in case of error.
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
                            const TessdataManager *mgr, TFile *fp) const {
  if (!LSTMRecognizer::Serialize(mgr, fp)) {
    return false;
  }
  if (!fp->Serialize(&learning_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->Serialize(&last_perfect_training_iteration_)) {
    return false;
  }
  for (const auto &error_buffer : error_buffers_) {
    if (!fp->Serialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&training_stage_)) {
    return false;
  }
  uint8_t amount = serialize_amount;
  if (!fp->Serialize(&amount)) {
    return false;
  }
  if (serialize_amount == LIGHT) {
    return true; // We are done.
  }
  if (!fp->Serialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&best_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->Serialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->Serialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->Serialize(best_model_data_)) {
    return false;
  }
  if (!fp->Serialize(worst_model_data_)) {
    return false;
  }
  if (serialize_amount != NO_BEST_TRAINER && !fp->Serialize(best_trainer_)) {
    return false;
  }
  std::vector<char> sub_data;
  if (sub_trainer_ != nullptr &&
      !SaveTrainingDump(LIGHT, *sub_trainer_, &sub_data)) {
    return false;
  }
  if (!fp->Serialize(sub_data)) {
    return false;
  }
  if (!fp->Serialize(best_error_history_)) {
    return false;
  }
  if (!fp->Serialize(best_error_iterations_)) {
    return false;
  }
  return fp->Serialize(&improvement_steps_);
}

// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager *mgr, TFile *fp) {
  if (!LSTMRecognizer::DeSerialize(mgr, fp)) {
    return false;
  }
  if (!fp->DeSerialize(&learning_iteration_)) {
    // Special case. If we successfully decoded the recognizer, but fail here
    // then it means we were just given a recognizer, so issue a warning and
    // allow it.
    tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
    learning_iteration_ = 0;
    network_->SetEnableTraining(TS_ENABLED);
    return true;
  }
  if (!fp->DeSerialize(&prev_sample_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&perfect_delay_)) {
    return false;
  }
  if (!fp->DeSerialize(&last_perfect_training_iteration_)) {
    return false;
  }
  for (auto &error_buffer : error_buffers_) {
    if (!fp->DeSerialize(error_buffer)) {
      return false;
    }
  }
  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&training_stage_)) {
    return false;
  }
  uint8_t amount;
  if (!fp->DeSerialize(&amount)) {
    return false;
  }
  if (amount == LIGHT) {
    return true; // Don't read the rest.
  }
  if (!fp->DeSerialize(&best_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&best_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rate_)) {
    return false;
  }
  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
    return false;
  }
  if (!fp->DeSerialize(&worst_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(&stall_iteration_)) {
    return false;
  }
  if (!fp->DeSerialize(best_model_data_)) {
    return false;
  }
  if (!fp->DeSerialize(worst_model_data_)) {
    return false;
  }
  if (amount != NO_BEST_TRAINER && !fp->DeSerialize(best_trainer_)) {
    return false;
  }
  std::vector<char> sub_data;
  if (!fp->DeSerialize(sub_data)) {
    return false;
  }
  if (sub_data.empty()) {
    sub_trainer_ = nullptr;
  } else {
    sub_trainer_ = std::make_unique<LSTMTrainer>();
    if (!ReadTrainingDump(sub_data, *sub_trainer_)) {
      return false;
    }
  }
  if (!fp->DeSerialize(best_error_history_)) {
    return false;
  }
  if (!fp->DeSerialize(best_error_iterations_)) {
    return false;
  }
  return fp->DeSerialize(&improvement_steps_);
}

// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
// learning rates (by scaling reduction, or layer specific, according to
// NF_LAYER_SPECIFIC_LR).
void LSTMTrainer::StartSubtrainer(std::stringstream &log_msg) {
  sub_trainer_ = std::make_unique<LSTMTrainer>();
  if (!ReadTrainingDump(best_trainer_, *sub_trainer_)) {
    log_msg << " Failed to revert to previous best for trial!";
    sub_trainer_.reset();
  } else {
    log_msg << " Trial sub_trainer_ from iteration "
            << sub_trainer_->training_iteration();
    // Reduce learning rate so it doesn't diverge this time.
    sub_trainer_->ReduceLearningRates(this, log_msg);
    // If it fails again, we will wait twice as long before reverting again.
    int stall_offset =
        learning_iteration() - sub_trainer_->learning_iteration();
    stall_iteration_ = learning_iteration() + 2 * stall_offset;
    sub_trainer_->stall_iteration_ = stall_iteration_;
    // Re-save the best trainer with the new learning rates and stall iteration.
    SaveTrainingDump(NO_BEST_TRAINER, *sub_trainer_, &best_trainer_);
  }
}

// While the sub_trainer_ is behind the current training iteration and its
// training error is at least kSubTrainerMarginFraction better than the
// current training error, trains the sub_trainer_, and returns STR_UPDATED if
// it did anything. If it catches up, and has a better error rate than the
// current best, as well as a margin over the current error rate, then the
// trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
// returned. STR_NONE is returned if the subtrainer wasn't good enough to
// receive any training iterations.
SubTrainerResult LSTMTrainer::UpdateSubtrainer(std::stringstream &log_msg) {
  double training_error = CharError();
  double sub_error = sub_trainer_->CharError();
  double sub_margin = (training_error - sub_error) / sub_error;
  if (sub_margin >= kSubTrainerMarginFraction) {
    log_msg << " sub_trainer=" << sub_error
            << " margin=" << 100.0 * sub_margin << "\n";
    // Catch up to current iteration.
    int end_iteration = training_iteration();
    while (sub_trainer_->training_iteration() < end_iteration &&
           sub_margin >= kSubTrainerMarginFraction) {
      int target_iteration =
          sub_trainer_->training_iteration() + kNumPagesPerBatch;
      while (sub_trainer_->training_iteration() < target_iteration) {
        sub_trainer_->TrainOnLine(this, false);
      }
      std::stringstream batch_log("Sub:");
      batch_log.imbue(std::locale::classic());
      sub_trainer_->PrepareLogMsg(batch_log);
      batch_log << "\n";
      tprintf("UpdateSubtrainer:%s", batch_log.str().c_str());
      log_msg << batch_log.str();
      sub_error = sub_trainer_->CharError();
      sub_margin = (training_error - sub_error) / sub_error;
    }
    if (sub_error < best_error_rate_ &&
        sub_margin >= kSubTrainerMarginFraction) {
      // The sub_trainer_ has won the race to a new best. Switch to it.
      std::vector<char> updated_trainer;
      SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
      ReadTrainingDump(updated_trainer, *this);
      log_msg << " Sub trainer wins at iteration "
              << training_iteration() << "\n";
      return STR_REPLACED;
    }
    return STR_UPDATED;
  }
  return STR_NONE;
}
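
// Worked example of the margin test above (illustrative numbers): with
// training BCER 5.0% and sub_trainer_ BCER 4.8%, sub_margin =
// (5.0 - 4.8) / 4.8 = 0.0417, which exceeds kSubTrainerMarginFraction
// (3/128 = 0.0234), so the sub_trainer_ receives training iterations; it
// replaces *this only if it also beats best_error_rate_ while keeping that
// margin.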

// Reduces network learning rates, either for everything, or for layers
// independently, according to NF_LAYER_SPECIFIC_LR.
void LSTMTrainer::ReduceLearningRates(LSTMTrainer *samples_trainer,
                                      std::stringstream &log_msg) {
  if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
    int num_reduced = ReduceLayerLearningRates(
        kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
    log_msg << "\nReduced learning rate on layers: " << num_reduced;
  } else {
    ScaleLearningRate(kLearningRateDecay);
    log_msg << "\nReduced learning rate to :" << learning_rate_;
  }
  log_msg << "\n";
}

// Considers reducing the learning rate independently for each layer down by
// factor(<1), or leaving it the same, by double-training the given number of
// samples and minimizing the amount of changing of sign of weight updates.
// Even if it looks like all weights should remain the same, an adjustment
// will be made to guarantee a different result when reverting to an old best.
// Returns the number of layer learning rates that were reduced.
int LSTMTrainer::ReduceLayerLearningRates(TFloat factor, int num_samples,
                                          LSTMTrainer *samples_trainer) {
  enum WhichWay {
    LR_DOWN, // Learning rate will go down by factor.
    LR_SAME, // Learning rate will stay the same.
    LR_COUNT // Size of arrays.
  };
  std::vector<std::string> layers = EnumerateLayers();
  int num_layers = layers.size();
  std::vector<int> num_weights(num_layers);
  std::vector<TFloat> bad_sums[LR_COUNT];
  std::vector<TFloat> ok_sums[LR_COUNT];
  for (int i = 0; i < LR_COUNT; ++i) {
    bad_sums[i].resize(num_layers, 0.0);
    ok_sums[i].resize(num_layers, 0.0);
  }
  auto momentum_factor = 1 / (1 - momentum_);
  std::vector<char> orig_trainer;
  samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer);
  for (int i = 0; i < num_layers; ++i) {
    Network *layer = GetLayer(layers[i]);
    num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
  }
  int iteration = sample_iteration();
  for (int s = 0; s < num_samples; ++s) {
    // Which way will we modify the learning rate?
    for (int ww = 0; ww < LR_COUNT; ++ww) {
      // Transfer momentum to learning rate and adjust by the ww factor.
      auto ww_factor = momentum_factor;
      if (ww == LR_DOWN) {
        ww_factor *= factor;
      }
      // Make a copy of *this, so we can mess about without damaging anything.
      LSTMTrainer copy_trainer;
      samples_trainer->ReadTrainingDump(orig_trainer, copy_trainer);
      // Clear the updates, doing nothing else.
      copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
      // Adjust the learning rate in each layer.
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
      }
      copy_trainer.SetIteration(iteration);
      // Train on the sample, but keep the update in updates_ instead of
      // applying to the weights.
      const ImageData *trainingdata =
          copy_trainer.TrainOnLine(samples_trainer, true);
      if (trainingdata == nullptr) {
        continue;
      }
      // We'll now use this trainer again for each layer.
      std::vector<char> updated_trainer;
      samples_trainer->SaveTrainingDump(LIGHT, copy_trainer, &updated_trainer);
      for (int i = 0; i < num_layers; ++i) {
        if (num_weights[i] == 0) {
          continue;
        }
        LSTMTrainer layer_trainer;
        samples_trainer->ReadTrainingDump(updated_trainer, layer_trainer);
        Network *layer = layer_trainer.GetLayer(layers[i]);
        // Update the weights in just the layer, using Adam if enabled.
        layer->Update(0.0, momentum_, adam_beta_,
                      layer_trainer.training_iteration_ + 1);
        // Zero the updates matrix again.
        layer->Update(0.0, 0.0, 0.0, 0);
        // Train again on the same sample, again holding back the updates.
        layer_trainer.TrainOnLine(trainingdata, true);
        // Count the sign changes in the updates in layer vs in copy_trainer.
        float before_bad = bad_sums[ww][i];
        float before_ok = ok_sums[ww][i];
        layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
                                &ok_sums[ww][i], &bad_sums[ww][i]);
        float bad_frac =
            bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
        if (bad_frac > 0.0f) {
          bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
        }
      }
    }
    ++iteration;
  }
  int num_lowered = 0;
  for (int i = 0; i < num_layers; ++i) {
    if (num_weights[i] == 0) {
      continue;
    }
    Network *layer = GetLayer(layers[i]);
    float lr = GetLayerLearningRate(layers[i]);
    TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
    TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
    TFloat frac_down = bad_sums[LR_DOWN][i] / total_down;
    TFloat frac_same = bad_sums[LR_SAME][i] / total_same;
    tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
            lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
    if (frac_down < frac_same * kImprovementFraction) {
      tprintf(" REDUCED\n");
      ScaleLayerLearningRate(layers[i], factor);
      ++num_lowered;
    } else {
      tprintf(" SAME\n");
    }
  }
  if (num_lowered == 0) {
    // Just lower everything to make sure.
    for (int i = 0; i < num_layers; ++i) {
      if (num_weights[i] > 0) {
        ScaleLayerLearningRate(layers[i], factor);
        ++num_lowered;
      }
    }
  }
  return num_lowered;
}
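
// Worked example of the per-layer decision above (illustrative numbers): if
// the reduced rate flips the sign of 40% of a layer's weight updates between
// the two training passes while the unchanged rate flips 45%, then
// 0.40 < 0.45 * kImprovementFraction (15/16), i.e. 0.40 < 0.42, so that
// layer's learning rate is reduced by factor.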

// Converts the string to integer class labels, with appropriate null_char_s
// in between if not in SimpleTextOutput mode. Returns false on failure.
/* static */
bool LSTMTrainer::EncodeString(const std::string &str,
                               const UNICHARSET &unicharset,
                               const UnicharCompress *recoder, bool simple_text,
                               int null_char, std::vector<int> *labels) {
  if (str.c_str() == nullptr || str.length() <= 0) {
    tprintf("Empty truth string!\n");
    return false;
  }
  unsigned err_index;
  std::vector<int> internal_labels;
  labels->clear();
  if (!simple_text) {
    labels->push_back(null_char);
  }
  std::string cleaned = unicharset.CleanupString(str.c_str());
  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
                               &err_index)) {
    bool success = true;
    for (auto internal_label : internal_labels) {
      if (recoder != nullptr) {
        // Re-encode labels via recoder.
        RecodedCharID code;
        int len = recoder->EncodeUnichar(internal_label, &code);
        if (len > 0) {
          for (int j = 0; j < len; ++j) {
            labels->push_back(code(j));
            if (!simple_text) {
              labels->push_back(null_char);
            }
          }
        } else {
          success = false;
          err_index = 0;
          break;
        }
      } else {
        labels->push_back(internal_label);
        if (!simple_text) {
          labels->push_back(null_char);
        }
      }
    }
    if (success) {
      return true;
    }
  }
  tprintf("Encoding of string failed! Failure bytes:");
  while (err_index < cleaned.size()) {
    tprintf(" %x", cleaned[err_index++] & 0xff);
  }
  tprintf("\n");
  return false;
}
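
// Example (a sketch over a hypothetical charset): if 'a' and 'b' recode to
// the single codes 5 and 6 and null_char is 0, then with simple_text=false
// EncodeString("ab", ...) produces the null-interleaved labels
// {0, 5, 0, 6, 0}.
#if 0
std::vector<int> labels;
LSTMTrainer::EncodeString("ab", unicharset, &recoder, /*simple_text=*/false,
                          /*null_char=*/0, &labels);
#endif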

// Performs forward-backward on the given trainingdata.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::TrainOnLine(const ImageData *trainingdata,
                                      bool batch) {
  NetworkIO fwd_outputs, targets;
  Trainability trainable =
      PrepareForBackward(trainingdata, &fwd_outputs, &targets);
  ++sample_iteration_;
  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
    return trainable; // Sample was unusable.
  }
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  // Run backprop on the output.
  NetworkIO bp_deltas;
  if (network_->IsTraining() &&
      (trainable != PERFECT ||
       training_iteration() >
           last_perfect_training_iteration_ + perfect_delay_)) {
    network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
    network_->Update(learning_rate_, batch ? -1.0f : momentum_, adam_beta_,
                     training_iteration_ + 1);
  }
#ifndef GRAPHICS_DISABLED
  if (debug_interval_ == 1 && debug_win_ != nullptr) {
    delete debug_win_->AwaitEvent(SVET_CLICK);
  }
#endif // !GRAPHICS_DISABLED
  // Roll the memory of past means.
  RollErrorBuffers();
  return trainable;
}

// Prepares the ground truth, runs forward, and prepares the targets.
// Returns a Trainability enum to indicate the suitability of the sample.
Trainability LSTMTrainer::PrepareForBackward(const ImageData *trainingdata,
                                             NetworkIO *fwd_outputs,
                                             NetworkIO *targets) {
  if (trainingdata == nullptr) {
    tprintf("Null trainingdata.\n");
    return UNENCODABLE;
  }
  // Ensure repeatability of random elements even across checkpoints.
  bool debug =
      debug_interval_ > 0 && training_iteration() % debug_interval_ == 0;
  std::vector<int> truth_labels;
  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
    tprintf("Can't encode transcription: '%s' in language '%s'\n",
            trainingdata->transcription().c_str(),
            trainingdata->language().c_str());
    return UNENCODABLE;
  }
  bool upside_down = false;
  if (randomly_rotate_) {
    // This ensures consistent training results.
    SetRandomSeed();
    upside_down = randomizer_.SignedRand(1.0) > 0.0;
    if (upside_down) {
      // Modify the truth labels to match the rotation:
      // Apart from space and null, increment the label. This changes the
      // script-id to the same script-id but upside-down.
      // The labels need to be reversed in order, as the first is now the last.
      for (auto &truth_label : truth_labels) {
        if (truth_label != UNICHAR_SPACE && truth_label != null_char_) {
          ++truth_label;
        }
      }
      std::reverse(truth_labels.begin(), truth_labels.end());
    }
  }
  unsigned w = 0;
  while (w < truth_labels.size() &&
         (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) {
    ++w;
  }
  if (w == truth_labels.size()) {
    tprintf("Blank transcription: %s\n", trainingdata->transcription().c_str());
    return UNENCODABLE;
  }
  float image_scale;
  NetworkIO inputs;
  bool invert = trainingdata->boxes().empty();
  if (!RecognizeLine(*trainingdata, invert ? 0.5f : 0.0f, debug, invert,
                     upside_down, &image_scale, &inputs, fwd_outputs)) {
    tprintf("Image %s not trainable\n", trainingdata->imagefilename().c_str());
    return UNENCODABLE;
  }
  targets->Resize(*fwd_outputs, network_->NumOutputs());
  LossType loss_type = OutputLossType();
  if (loss_type == LT_SOFTMAX) {
    if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
      tprintf("Compute simple targets failed for %s!\n",
              trainingdata->imagefilename().c_str());
      return UNENCODABLE;
    }
  } else if (loss_type == LT_CTC) {
    if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
      tprintf("Compute CTC targets failed for %s!\n",
              trainingdata->imagefilename().c_str());
      return UNENCODABLE;
    }
  } else {
    tprintf("Logistic outputs not implemented yet!\n");
    return UNENCODABLE;
  }
  std::vector<int> ocr_labels;
  std::vector<int> xcoords;
  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
  // CTC does not produce correct target labels to begin with.
  if (loss_type != LT_CTC) {
    LabelsFromOutputs(*targets, &truth_labels, &xcoords);
  }
  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
                         *targets)) {
    tprintf("Input width was %d\n", inputs.Width());
    return UNENCODABLE;
  }
  std::string ocr_text = DecodeLabels(ocr_labels);
  std::string truth_text = DecodeLabels(truth_labels);
  targets->SubtractAllFromFloat(*fwd_outputs);
  if (debug_interval_ != 0) {
    if (truth_text != ocr_text) {
      tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
              ocr_text.c_str());
    }
  }
  double char_error = ComputeCharError(truth_labels, ocr_labels);
  double word_error = ComputeWordError(&truth_text, &ocr_text);
  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
  if (debug_interval_ != 0) {
    tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
            trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
  }
  if (delta_error == 0.0) {
    return PERFECT;
  }
  if (targets->AnySuspiciousTruth(kHighConfidence)) {
    return HI_PRECISION_ERR;
  }
  return TRAINABLE;
}

// Writes the trainer to memory, so that the current training state can be
// restored. *this must always be the master trainer that retains the only
// copy of the training data and language model. trainer is the model that is
// actually serialized.
bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
                                   const LSTMTrainer &trainer,
                                   std::vector<char> *data) const {
  TFile fp;
  fp.OpenWrite(data);
  return trainer.Serialize(serialize_amount, &mgr_, &fp);
}

// Restores the model to *this.
bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager *mgr,
                                        const char *data, int size) {
  if (size == 0) {
    tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
    return false;
  }
  TFile fp;
  fp.Open(data, size);
  return DeSerialize(mgr, &fp);
}

// Writes the full recognition traineddata to the given filename.
bool LSTMTrainer::SaveTraineddata(const char *filename) {
  std::vector<char> recognizer_data;
  SaveRecognitionDump(&recognizer_data);
  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
                      recognizer_data.size());
  return mgr_.SaveFile(filename, SaveDataToFile);
}

// Writes the recognizer to memory, so that it can be used for testing later.
void LSTMTrainer::SaveRecognitionDump(std::vector<char> *data) const {
  TFile fp;
  fp.OpenWrite(data);
  network_->SetEnableTraining(TS_TEMP_DISABLE);
  ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp));
  network_->SetEnableTraining(TS_RE_ENABLE);
}

// Returns a suitable filename for a training dump, based on the model_base_,
// best_error_rate_, best_iteration_ and training_iteration_.
std::string LSTMTrainer::DumpFilename() const {
  std::stringstream filename;
  filename.imbue(std::locale::classic());
  filename << model_base_ << std::fixed << std::setprecision(3)
           << "_" << best_error_rate_
           << "_" << best_iteration_
           << "_" << training_iteration_
           << ".checkpoint";
  return filename.str();
}
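
// Example (illustrative values): with model_base_ "/tmp/model/base",
// best_error_rate_ 1.499, best_iteration_ 5000 and training_iteration_ 12000,
// this returns "/tmp/model/base_1.499_5000_12000.checkpoint".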

// Fills the whole error buffer of the given type with the given value.
void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
  for (int i = 0; i < kRollingBufferSize_; ++i) {
    error_buffers_[type][i] = new_error;
  }
  error_rates_[type] = 100.0 * new_error;
}

// Helper generates a map from each current recoder_ code (ie softmax index)
// to the corresponding old_recoder code, or -1 if there isn't one.
std::vector<int> LSTMTrainer::MapRecoder(
    const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const {
  int num_new_codes = recoder_.code_range();
  int num_new_unichars = GetUnicharset().size();
  std::vector<int> code_map(num_new_codes, -1);
  for (int c = 0; c < num_new_codes; ++c) {
    int old_code = -1;
    // Find all new unichar_ids that recode to something that includes c.
    // The <= is to include the null char, which may be beyond the unicharset.
    for (int uid = 0; uid <= num_new_unichars; ++uid) {
      RecodedCharID codes;
      int length = recoder_.EncodeUnichar(uid, &codes);
      int code_index = 0;
      while (code_index < length && codes(code_index) != c) {
        ++code_index;
      }
      if (code_index == length) {
        continue;
      }
      // The old unicharset must have the same unichar.
      int old_uid =
          uid < num_new_unichars
              ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
              : old_chset.size() - 1;
      if (old_uid == INVALID_UNICHAR_ID) {
        continue;
      }
      // The encoding of old_uid at the same code_index is the old code.
      RecodedCharID old_codes;
      if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
        old_code = old_codes(code_index);
        break;
      }
    }
    code_map[c] = old_code;
  }
  return code_map;
}

// Private version of InitCharSet above finishes the job after initializing
// the mgr_ data member.
void LSTMTrainer::InitCharSet() {
  EmptyConstructor();
  training_flags_ = TF_COMPRESS_UNICHARSET;
  // Initialize the unicharset and recoder.
  if (!LoadCharsets(&mgr_)) {
    ASSERT_HOST(
        "Must provide a traineddata containing lstm_unicharset and"
        " lstm_recoder!\n" != nullptr);
  }
  SetNullChar();
}

// Helper computes and sets the null_char_.
void LSTMTrainer::SetNullChar() {
  null_char_ = (GetUnicharset().has_special_codes()) ? UNICHAR_BROKEN
                                                     : GetUnicharset().size();
  RecodedCharID code;
  recoder_.EncodeUnichar(null_char_, &code);
  null_char_ = code(0);
}

// Factored sub-constructor sets up reasonable default values.
void LSTMTrainer::EmptyConstructor() {
#ifndef GRAPHICS_DISABLED
  align_win_ = nullptr;
  target_win_ = nullptr;
  ctc_win_ = nullptr;
  recon_win_ = nullptr;
#endif
  checkpoint_iteration_ = 0;
  training_stage_ = 0;
  num_training_stages_ = 2;
  InitIterations();
}

// Outputs the string and periodically displays the given network inputs
// as an image in the given window, and the corresponding labels at the
// corresponding x_starts.
// Returns false if the truth string is empty.
bool LSTMTrainer::DebugLSTMTraining(const NetworkIO &inputs,
                                    const ImageData &trainingdata,
                                    const NetworkIO &fwd_outputs,
                                    const std::vector<int> &truth_labels,
                                    const NetworkIO &outputs) {
  const std::string &truth_text = DecodeLabels(truth_labels);
  if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {
    tprintf("Empty truth string at decode time!\n");
    return false;
  }
  if (debug_interval_ != 0) {
    // Get class labels, xcoords and string.
    std::vector<int> labels;
    std::vector<int> xcoords;
    LabelsFromOutputs(outputs, &labels, &xcoords);
    std::string text = DecodeLabels(labels);
    tprintf("Iteration %d: GROUND TRUTH : %s\n", training_iteration(),
            truth_text.c_str());
    if (truth_text != text) {
      tprintf("Iteration %d: ALIGNED TRUTH : %s\n", training_iteration(),
              text.c_str());
    }
    if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
      tprintf("TRAINING activation path for truth string %s\n",
              truth_text.c_str());
      DebugActivationPath(outputs, labels, xcoords);
#ifndef GRAPHICS_DISABLED
      DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
      if (OutputLossType() == LT_CTC) {
        DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
        DisplayTargets(outputs, "CTC Targets", &target_win_);
      }
#endif
    }
  }
  return true;
}

#ifndef GRAPHICS_DISABLED

// Displays the network targets as a line graph.
void LSTMTrainer::DisplayTargets(const NetworkIO &targets,
                                 const char *window_name, ScrollView **window) {
  int width = targets.Width();
  int num_features = targets.NumFeatures();
  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
                       window);
  for (int c = 0; c < num_features; ++c) {
    int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
    (*window)->Pen(static_cast<ScrollView::Color>(color));
    int start_t = -1;
    for (int t = 0; t < width; ++t) {
      double target = targets.f(t)[c];
      target *= kTargetYScale;
      if (target >= 1) {
        if (start_t < 0) {
          (*window)->SetCursor(t - 1, 0);
          start_t = t;
        }
        (*window)->DrawTo(t, target);
      } else if (start_t >= 0) {
        (*window)->DrawTo(t, 0);
        (*window)->DrawTo(start_t - 1, 0);
        start_t = -1;
      }
    }
    if (start_t >= 0) {
      (*window)->DrawTo(width, 0);
      (*window)->DrawTo(start_t - 1, 0);
    }
  }
  (*window)->Update();
}

#endif // !GRAPHICS_DISABLED

// Builds a no-compromises target where the first positions should be the
// truth labels and the rest is padded with the null_char_.
bool LSTMTrainer::ComputeTextTargets(const NetworkIO &outputs,
                                     const std::vector<int> &truth_labels,
                                     NetworkIO *targets) {
  if (truth_labels.size() > targets->Width()) {
    tprintf("Error: transcription %s too long to fit into target of width %d\n",
            DecodeLabels(truth_labels).c_str(), targets->Width());
    return false;
  }
  int i = 0;
  for (auto truth_label : truth_labels) {
    targets->SetActivations(i, truth_label, 1.0);
    ++i;
  }
  for (i = truth_labels.size(); i < targets->Width(); ++i) {
    targets->SetActivations(i, null_char_, 1.0);
  }
  return true;
}

// Builds a target using standard CTC. truth_labels should be pre-padded with
// nulls wherever desired. They don't have to be between all labels.
// outputs is input-output, as it gets clipped to minimum probability.
bool LSTMTrainer::ComputeCTCTargets(const std::vector<int> &truth_labels,
                                    NetworkIO *outputs, NetworkIO *targets) {
  // Bottom-clip outputs to a minimum probability.
  CTC::NormalizeProbs(outputs);
  return CTC::ComputeCTCTargets(truth_labels, null_char_,
                                outputs->float_array(), targets);
}

// Computes network errors, and stores the results in the rolling buffers,
// along with the supplied text_error.
// Returns the delta error of the current sample (not running average.)
double LSTMTrainer::ComputeErrorRates(const NetworkIO &deltas,
                                      double char_error, double word_error) {
  UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
  // Delta error is the fraction of timesteps with >0.5 error in the top choice
  // score. If zero, then the top choice characters are guaranteed correct,
  // even when there is residue in the RMS error.
  double delta_error = ComputeWinnerError(deltas);
  UpdateErrorBuffer(delta_error, ET_DELTA);
  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
  // Skip ratio measures the difference between sample_iteration_ and
  // training_iteration_, which reflects the number of unusable samples,
  // usually due to unencodable truth text, or the text not fitting in the
  // space for the output.
  double skip_count = sample_iteration_ - prev_sample_iteration_;
  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
  return delta_error;
}

// Computes the network activation RMS error rate.
double LSTMTrainer::ComputeRMSError(const NetworkIO &deltas) {
  double total_error = 0.0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float *class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      double error = class_errs[c];
      total_error += error * error;
    }
  }
  return sqrt(total_error / (width * num_classes));
}

// Computes network activation winner error rate. (Number of values that are
// in error by >= 0.5 divided by number of time-steps.) More closely related
// to final character error than RMS, but still directly calculable from
// just the deltas. Because of the binary nature of the targets, zero winner
// error is a sufficient but not necessary condition for zero char error.
double LSTMTrainer::ComputeWinnerError(const NetworkIO &deltas) {
  int num_errors = 0;
  int width = deltas.Width();
  int num_classes = deltas.NumFeatures();
  for (int t = 0; t < width; ++t) {
    const float *class_errs = deltas.f(t);
    for (int c = 0; c < num_classes; ++c) {
      float abs_delta = std::fabs(class_errs[c]);
      // TODO(rays) Filtering cases where the delta is very large to cut out
      // GT errors doesn't work. Find a better way or get better truth.
      if (0.5 <= abs_delta) {
        ++num_errors;
      }
    }
  }
  return static_cast<double>(num_errors) / width;
}

// Computes a very simple bag of chars char error rate.
double LSTMTrainer::ComputeCharError(const std::vector<int> &truth_str,
                                     const std::vector<int> &ocr_str) {
  std::vector<int> label_counts(NumOutputs());
  unsigned truth_size = 0;
  for (auto ch : truth_str) {
    if (ch != null_char_) {
      ++label_counts[ch];
      ++truth_size;
    }
  }
  for (auto ch : ocr_str) {
    if (ch != null_char_) {
      --label_counts[ch];
    }
  }
  unsigned char_errors = 0;
  for (auto label_count : label_counts) {
    char_errors += abs(label_count);
  }
  // Limit BCER to interval [0,1] and avoid division by zero.
  if (truth_size <= char_errors) {
    return (char_errors == 0) ? 0.0 : 1.0;
  }
  return static_cast<double>(char_errors) / truth_size;
}
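
// Worked example (illustrative): truth "aab" gives counts {a:2, b:1} and
// truth_size 3; OCR "abb" decrements them to {a:1, b:-1}; the summed
// magnitudes give char_errors = 2, so the BCER is 2/3.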

// Computes word recall error rate using a very simple bag of words algorithm.
// NOTE that this is destructive on both input strings.
double LSTMTrainer::ComputeWordError(std::string *truth_str,
                                     std::string *ocr_str) {
  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
  std::vector<std::string> truth_words = split(*truth_str, ' ');
  if (truth_words.empty()) {
    return 0.0;
  }
  std::vector<std::string> ocr_words = split(*ocr_str, ' ');
  StrMap word_counts;
  for (const auto &truth_word : truth_words) {
    std::string truth_word_string(truth_word.c_str());
    auto it = word_counts.find(truth_word_string);
    if (it == word_counts.end()) {
      word_counts.insert(std::make_pair(truth_word_string, 1));
    } else {
      ++it->second;
    }
  }
  for (const auto &ocr_word : ocr_words) {
    std::string ocr_word_string(ocr_word.c_str());
    auto it = word_counts.find(ocr_word_string);
    if (it == word_counts.end()) {
      word_counts.insert(std::make_pair(ocr_word_string, -1));
    } else {
      --it->second;
    }
  }
  int word_recall_errs = 0;
  for (const auto &word_count : word_counts) {
    if (word_count.second > 0) {
      word_recall_errs += word_count.second;
    }
  }
  return static_cast<double>(word_recall_errs) / truth_words.size();
}
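
// Worked example (illustrative): truth "the cat sat" vs OCR "the mat" leaves
// counts {the:0, cat:1, sat:1, mat:-1}; only the positive residues (missed
// truth words) count, so word_recall_errs = 2 and the error is 2/3.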

// Updates the error buffer and corresponding mean of the given type with
// the new_error.
void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
  int index = training_iteration_ % kRollingBufferSize_;
  error_buffers_[type][index] = new_error;
  // Compute the mean error.
  int mean_count =
      std::min<int>(training_iteration_ + 1, error_buffers_[type].size());
  double buffer_sum = 0.0;
  for (int i = 0; i < mean_count; ++i) {
    buffer_sum += error_buffers_[type][i];
  }
  double mean = buffer_sum / mean_count;
  // Trim precision to 1/1000 of 1%.
  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
}

// Rolls error buffers and reports the current means.
void LSTMTrainer::RollErrorBuffers() {
  prev_sample_iteration_ = sample_iteration_;
  if (NewSingleError(ET_DELTA) > 0.0) {
    ++learning_iteration_;
  } else {
    last_perfect_training_iteration_ = training_iteration_;
  }
  ++training_iteration_;
  if (debug_interval_ != 0) {
    tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
            error_rates_[ET_RMS], error_rates_[ET_DELTA],
            error_rates_[ET_CHAR_ERROR], error_rates_[ET_WORD_RECERR],
            error_rates_[ET_SKIP_RATIO]);
  }
}

// Given that error_rate is either a new min or max, updates the best/worst
// error rates, and record of progress.
// Tester is an externally supplied callback function that tests on some
// data set with a given model and records the error rates in a graph.
std::string LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
                                          const std::vector<char> &model_data,
                                          const TestCallback &tester) {
  if (error_rate > best_error_rate_ &&
      iteration < best_iteration_ + kErrorGraphInterval) {
    // Too soon to record a new point.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
    } else {
      return "";
    }
  }
  std::string result;
  // NOTE: there are 2 asymmetries here:
  // 1. We are computing the global minimum, but the local maximum in between.
  // 2. If the tester returns an empty string, indicating that it is busy,
  //    call it repeatedly on new local maxima to test the previous min, but
  //    not the other way around, as there is little point testing the maxima
  //    between very frequent minima.
  if (error_rate < best_error_rate_) {
    // This is a new (global) minimum.
    if (tester != nullptr && !worst_model_data_.empty()) {
      mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                          worst_model_data_.size());
      result = tester(worst_iteration_, worst_error_rates_, mgr_,
                      CurrentTrainingStage());
      worst_model_data_.clear();
      best_model_data_ = model_data;
    }
    best_error_rate_ = error_rate;
    memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
    best_iteration_ = iteration;
    best_error_history_.push_back(error_rate);
    best_error_iterations_.push_back(iteration);
    // Compute 2% decay time.
    double two_percent_more = error_rate + 2.0;
    int i;
    for (i = best_error_history_.size() - 1;
         i >= 0 && best_error_history_[i] < two_percent_more; --i) {
    }
    int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
    improvement_steps_ = iteration - old_iteration;
    tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
            improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
            old_iteration);
  } else if (error_rate > best_error_rate_) {
    // This is a new (local) maximum.
    if (tester != nullptr) {
      if (!best_model_data_.empty()) {
        mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
                            best_model_data_.size());
        result = tester(best_iteration_, best_error_rates_, mgr_,
                        CurrentTrainingStage());
      } else if (!worst_model_data_.empty()) {
        // Allow for multiple data points with "worst" error rate.
        mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
                            worst_model_data_.size());
        result = tester(worst_iteration_, worst_error_rates_, mgr_,
                        CurrentTrainingStage());
      }
      if (result.length() > 0) {
        best_model_data_.clear();
      }
      worst_model_data_ = model_data;
    }
  }
  worst_error_rate_ = error_rate;
  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
  worst_iteration_ = iteration;
  return result;
}

} // namespace tesseract.