tesseract v5.3.3.20231005
lstmrecognizer.cpp
Go to the documentation of this file.
1
2// File: lstmrecognizer.cpp
3// Description: Top-level line recognizer class for LSTM-based networks.
4// Author: Ray Smith
5//
6// (C) Copyright 2013, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18// Include automatically generated configuration file if running autoconf.
19#ifdef HAVE_CONFIG_H
20# include "config_auto.h"
21#endif
22
23#include "lstmrecognizer.h"
24
25#include <allheaders.h>
26#include "dict.h"
27#include "genericheap.h"
28#include "helpers.h"
29#include "imagedata.h"
30#include "input.h"
31#include "lstm.h"
32#include "normalis.h"
33#include "pageres.h"
34#include "ratngs.h"
35#include "recodebeam.h"
36#include "scrollview.h"
37#include "statistc.h"
38#include "tprintf.h"
39
40#include <unordered_set>
41#include <vector>
42
43namespace tesseract {
44
// Ratio handed to the beam search as dict_ratio, weighing dictionary words
// against non-dictionary words.
const double kDictRatio = 2.25;
// Certainty offset handed to the beam search as cert_offset, giving the
// dictionary a chance to win.
const double kCertOffset = -0.085;
49
50LSTMRecognizer::LSTMRecognizer(const std::string &language_data_path_prefix)
52 ccutil_.language_data_path_prefix = language_data_path_prefix;
53}
54
56 : network_(nullptr)
57 , training_flags_(0)
58 , training_iteration_(0)
59 , sample_iteration_(0)
60 , null_char_(UNICHAR_BROKEN)
61 , learning_rate_(0.0f)
62 , momentum_(0.0f)
63 , adam_beta_(0.0f)
64 , dict_(nullptr)
65 , search_(nullptr)
66 , debug_win_(nullptr) {}
67
69 delete network_;
70 delete dict_;
71 delete search_;
72}
73
74// Loads a model from mgr, including the dictionary only if lang is not null.
75bool LSTMRecognizer::Load(const ParamsVectors *params, const std::string &lang,
76 TessdataManager *mgr) {
77 TFile fp;
78 if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) {
79 return false;
80 }
81 if (!DeSerialize(mgr, &fp)) {
82 return false;
83 }
84 if (lang.empty()) {
85 return true;
86 }
87 // Allow it to run without a dictionary.
88 LoadDictionary(params, lang, mgr);
89 return true;
90}
91
92// Writes to the given file. Returns false in case of error.
93bool LSTMRecognizer::Serialize(const TessdataManager *mgr, TFile *fp) const {
94 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
96 if (!network_->Serialize(fp)) {
97 return false;
98 }
99 if (include_charsets && !GetUnicharset().save_to_file(fp)) {
100 return false;
101 }
102 if (!fp->Serialize(network_str_)) {
103 return false;
104 }
105 if (!fp->Serialize(&training_flags_)) {
106 return false;
107 }
108 if (!fp->Serialize(&training_iteration_)) {
109 return false;
110 }
111 if (!fp->Serialize(&sample_iteration_)) {
112 return false;
113 }
114 if (!fp->Serialize(&null_char_)) {
115 return false;
116 }
117 if (!fp->Serialize(&adam_beta_)) {
118 return false;
119 }
120 if (!fp->Serialize(&learning_rate_)) {
121 return false;
122 }
123 if (!fp->Serialize(&momentum_)) {
124 return false;
125 }
126 if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) {
127 return false;
128 }
129 return true;
130}
131
132// Reads from the given file. Returns false in case of error.
134 delete network_;
136 if (network_ == nullptr) {
137 return false;
138 }
139 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
141 if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) {
142 return false;
143 }
144 if (!fp->DeSerialize(network_str_)) {
145 return false;
146 }
147 if (!fp->DeSerialize(&training_flags_)) {
148 return false;
149 }
150 if (!fp->DeSerialize(&training_iteration_)) {
151 return false;
152 }
153 if (!fp->DeSerialize(&sample_iteration_)) {
154 return false;
155 }
156 if (!fp->DeSerialize(&null_char_)) {
157 return false;
158 }
159 if (!fp->DeSerialize(&adam_beta_)) {
160 return false;
161 }
162 if (!fp->DeSerialize(&learning_rate_)) {
163 return false;
164 }
165 if (!fp->DeSerialize(&momentum_)) {
166 return false;
167 }
168 if (include_charsets && !LoadRecoder(fp)) {
169 return false;
170 }
171 if (!include_charsets && !LoadCharsets(mgr)) {
172 return false;
173 }
176 return true;
177}
178
179// Loads the charsets from mgr.
181 TFile fp;
182 if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
183 return false;
184 }
185 if (!ccutil_.unicharset.load_from_file(&fp, false)) {
186 return false;
187 }
188 if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
189 return false;
190 }
191 if (!LoadRecoder(&fp)) {
192 return false;
193 }
194 return true;
195}
196
197// Loads the Recoder.
199 if (IsRecoding()) {
200 if (!recoder_.DeSerialize(fp)) {
201 return false;
202 }
203 RecodedCharID code;
205 if (code(0) != UNICHAR_SPACE) {
206 tprintf("Space was garbled in recoding!!\n");
207 return false;
208 }
209 } else {
212 }
213 return true;
214}
215
216// Loads the dictionary if possible from the traineddata file.
217// Prints a warning message, and returns false but otherwise fails silently
218// and continues to work without it if loading fails.
219// Note that dictionary load is independent from DeSerialize, but dependent
220// on the unicharset matching. This enables training to deserialize a model
221// from checkpoint or restore without having to go back and reload the
222// dictionary.
223// Some parameters have to be passed in (from langdata/config/api via Tesseract)
224bool LSTMRecognizer::LoadDictionary(const ParamsVectors *params, const std::string &lang,
225 TessdataManager *mgr) {
226 delete dict_;
227 dict_ = new Dict(&ccutil_);
228 dict_->user_words_file.ResetFrom(params);
229 dict_->user_words_suffix.ResetFrom(params);
230 dict_->user_patterns_file.ResetFrom(params);
231 dict_->user_patterns_suffix.ResetFrom(params);
233 dict_->LoadLSTM(lang, mgr);
234 if (dict_->FinishLoad()) {
235 return true; // Success.
236 }
237 if (log_level <= 0) {
238 tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str());
239 }
240 delete dict_;
241 dict_ = nullptr;
242 return false;
243}
244
245// Recognizes the line image, contained within image_data, returning the
246// ratings matrix and matching box_word for each WERD_RES in the output.
248 float invert_threshold, bool debug,
249 double worst_dict_cert, const TBOX &line_box,
250 PointerVector<WERD_RES> *words, int lstm_choice_mode,
251 int lstm_choice_amount) {
252 NetworkIO outputs;
253 float scale_factor;
254 NetworkIO inputs;
255 if (!RecognizeLine(image_data, invert_threshold, debug, false, false, &scale_factor, &inputs, &outputs)) {
256 return;
257 }
258 if (search_ == nullptr) {
260 }
261 search_->excludedUnichars.clear();
262 search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, &GetUnicharset(),
263 lstm_choice_mode);
264 search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words,
265 lstm_choice_mode);
266 if (lstm_choice_mode) {
268 for (int i = 0; i < lstm_choice_amount; ++i) {
269 search_->DecodeSecondaryBeams(outputs, kDictRatio, kCertOffset, worst_dict_cert,
270 &GetUnicharset(), lstm_choice_mode);
272 }
274 unsigned char_it = 0;
275 for (size_t i = 0; i < words->size(); ++i) {
276 for (int j = 0; j < words->at(i)->end; ++j) {
277 if (char_it < search_->ctc_choices.size()) {
278 words->at(i)->CTC_symbol_choices.push_back(search_->ctc_choices[char_it]);
279 }
280 if (char_it < search_->segmentedTimesteps.size()) {
281 words->at(i)->segmented_timesteps.push_back(search_->segmentedTimesteps[char_it]);
282 }
283 ++char_it;
284 }
285 words->at(i)->timesteps =
286 search_->combineSegmentedTimesteps(&words->at(i)->segmented_timesteps);
287 }
289 search_->ctc_choices.clear();
290 search_->excludedUnichars.clear();
291 }
292}
293
294// Helper computes min and mean best results in the output.
295void LSTMRecognizer::OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output,
296 float *sd) {
297 const int kOutputScale = INT8_MAX;
298 STATS stats(0, kOutputScale);
299 for (int t = 0; t < outputs.Width(); ++t) {
300 int best_label = outputs.BestLabel(t, nullptr);
301 if (best_label != null_char_) {
302 float best_output = outputs.f(t)[best_label];
303 stats.add(static_cast<int>(kOutputScale * best_output), 1);
304 }
305 }
306 // If the output is all nulls it could be that the photometric interpretation
307 // is wrong, so make it look bad, so the other way can win, even if not great.
308 if (stats.get_total() == 0) {
309 *min_output = 0.0f;
310 *mean_output = 0.0f;
311 *sd = 1.0f;
312 } else {
313 *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
314 *mean_output = stats.mean() / kOutputScale;
315 *sd = stats.sd() / kOutputScale;
316 }
317}
318
319// Recognizes the image_data, returning the labels,
320// scores, and corresponding pairs of start, end x-coords in coords.
322 float invert_threshold, bool debug,
323 bool re_invert, bool upside_down, float *scale_factor,
324 NetworkIO *inputs, NetworkIO *outputs) {
325 // This ensures consistent recognition results.
327 int min_width = network_->XScaleFactor();
328 Image pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor);
329 if (pix == nullptr) {
330 tprintf("Line cannot be recognized!!\n");
331 return false;
332 }
333 // Maximum width of image to train on.
334 const int kMaxImageWidth = 128 * pixGetHeight(pix);
335 if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
336 tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix));
337 pix.destroy();
338 return false;
339 }
340 if (upside_down) {
341 pixRotate180(pix, pix);
342 }
343 // Reduction factor from image to coords.
344 *scale_factor = min_width / *scale_factor;
345 inputs->set_int_mode(IsIntMode());
348 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
349 // Check for auto inversion.
350 if (invert_threshold > 0.0f) {
351 float pos_min, pos_mean, pos_sd;
352 OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
353 if (pos_mean < invert_threshold) {
354 // Run again inverted and see if it is any better.
355 NetworkIO inv_inputs, inv_outputs;
356 inv_inputs.set_int_mode(IsIntMode());
358 pixInvert(pix, pix);
360 network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
361 float inv_min, inv_mean, inv_sd;
362 OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
363 if (inv_mean > pos_mean) {
364 // Inverted did better. Use inverted data.
365 if (debug) {
366 tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
367 pos_sd, inv_min, inv_mean, inv_sd);
368 }
369 *outputs = inv_outputs;
370 *inputs = inv_inputs;
371 } else if (re_invert) {
372 // Inverting was not an improvement, so undo and run again, so the
373 // outputs match the best forward result.
375 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
376 }
377 }
378 }
379
380 pix.destroy();
381 if (debug) {
382 std::vector<int> labels, coords;
383 LabelsFromOutputs(*outputs, &labels, &coords);
384#ifndef GRAPHICS_DISABLED
385 DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
386#endif
387 DebugActivationPath(*outputs, labels, coords);
388 }
389 return true;
390}
391
392// Converts an array of labels to utf-8, whether or not the labels are
393// augmented with character boundaries.
394std::string LSTMRecognizer::DecodeLabels(const std::vector<int> &labels) {
395 std::string result;
396 unsigned end = 1;
397 for (unsigned start = 0; start < labels.size(); start = end) {
398 if (labels[start] == null_char_) {
399 end = start + 1;
400 } else {
401 result += DecodeLabel(labels, start, &end, nullptr);
402 }
403 }
404 return result;
405}
406
407#ifndef GRAPHICS_DISABLED
408
409// Displays the forward results in a window with the characters and
410// boundaries as determined by the labels and label_coords.
411void LSTMRecognizer::DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
412 const std::vector<int> &label_coords, const char *window_name,
413 ScrollView **window) {
414 Image input_pix = inputs.ToPix();
415 Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window);
416 int line_height = Network::DisplayImage(input_pix, *window);
417 DisplayLSTMOutput(labels, label_coords, line_height, *window);
418}
419
420// Displays the labels and cuts at the corresponding xcoords.
421// Size of labels should match xcoords.
422void LSTMRecognizer::DisplayLSTMOutput(const std::vector<int> &labels,
423 const std::vector<int> &xcoords, int height,
424 ScrollView *window) {
425 int x_scale = network_->XScaleFactor();
426 window->TextAttributes("Arial", height / 4, false, false, false);
427 unsigned end = 1;
428 for (unsigned start = 0; start < labels.size(); start = end) {
429 int xpos = xcoords[start] * x_scale;
430 if (labels[start] == null_char_) {
431 end = start + 1;
432 window->Pen(ScrollView::RED);
433 } else {
434 window->Pen(ScrollView::GREEN);
435 const char *str = DecodeLabel(labels, start, &end, nullptr);
436 if (*str == '\\') {
437 str = "\\\\";
438 }
439 xpos = xcoords[(start + end) / 2] * x_scale;
440 window->Text(xpos, height, str);
441 }
442 window->Line(xpos, 0, xpos, height * 3 / 2);
443 }
444 window->Update();
445}
446
447#endif // !GRAPHICS_DISABLED
448
449// Prints debug output detailing the activation path that is implied by the
450// label_coords.
451void LSTMRecognizer::DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
452 const std::vector<int> &xcoords) {
453 if (xcoords[0] > 0) {
454 DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
455 }
456 unsigned end = 1;
457 for (unsigned start = 0; start < labels.size(); start = end) {
458 if (labels[start] == null_char_) {
459 end = start + 1;
460 DebugActivationRange(outputs, "<null>", null_char_, xcoords[start], xcoords[end]);
461 continue;
462 } else {
463 int decoded;
464 const char *label = DecodeLabel(labels, start, &end, &decoded);
465 DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]);
466 for (unsigned i = start + 1; i < end; ++i) {
467 DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i],
468 xcoords[i + 1]);
469 }
470 }
471 }
472}
473
474// Prints debug output detailing activations and 2nd choice over a range
475// of positions.
476void LSTMRecognizer::DebugActivationRange(const NetworkIO &outputs, const char *label,
477 int best_choice, int x_start, int x_end) {
478 tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
479 double max_score = 0.0;
480 double mean_score = 0.0;
481 const int width = x_end - x_start;
482 for (int x = x_start; x < x_end; ++x) {
483 const float *line = outputs.f(x);
484 const double score = line[best_choice] * 100.0;
485 if (score > max_score) {
486 max_score = score;
487 }
488 mean_score += score / width;
489 int best_c = 0;
490 double best_score = 0.0;
491 for (int c = 0; c < outputs.NumFeatures(); ++c) {
492 if (c != best_choice && line[c] > best_score) {
493 best_c = c;
494 best_score = line[c];
495 }
496 }
497 tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0);
498 }
499 tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
500}
501
502// Helper returns true if the null_char is the winner at t, and it beats the
503// null_threshold, or the next choice is space, in which case we will use the
504// null anyway.
505#if 0 // TODO: unused, remove if still unused after 2020.
506static bool NullIsBest(const NetworkIO& output, float null_thr,
507 int null_char, int t) {
508 if (output.f(t)[null_char] >= null_thr) return true;
509 if (output.BestLabel(t, null_char, null_char, nullptr) != UNICHAR_SPACE)
510 return false;
511 return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
512}
513#endif
514
515// Converts the network output to a sequence of labels. Outputs labels, scores
516// and start xcoords of each char, and each null_char_, with an additional
517// final xcoord for the end of the output.
518// The conversion method is determined by internal state.
519void LSTMRecognizer::LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
520 std::vector<int> *xcoords) {
521 if (SimpleTextOutput()) {
522 LabelsViaSimpleText(outputs, labels, xcoords);
523 } else {
524 LabelsViaReEncode(outputs, labels, xcoords);
525 }
526}
527
528// As LabelsViaCTC except that this function constructs the best path that
529// contains only legal sequences of subcodes for CJK.
530void LSTMRecognizer::LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
531 std::vector<int> *xcoords) {
532 if (search_ == nullptr) {
534 }
536 search_->ExtractBestPathAsLabels(labels, xcoords);
537}
538
539// Converts the network output to a sequence of labels, with scores, using
540// the simple character model (each position is a char, and the null_char_ is
541// mainly intended for tail padding.)
542void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
543 std::vector<int> *xcoords) {
544 labels->clear();
545 xcoords->clear();
546 const int width = output.Width();
547 for (int t = 0; t < width; ++t) {
548 float score = 0.0f;
549 const int label = output.BestLabel(t, &score);
550 if (label != null_char_) {
551 labels->push_back(label);
552 xcoords->push_back(t);
553 }
554 }
555 xcoords->push_back(width);
556}
557
558// Returns a string corresponding to the label starting at start. Sets *end
559// to the next start and if non-null, *decoded to the unichar id.
560const char *LSTMRecognizer::DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end,
561 int *decoded) {
562 *end = start + 1;
563 if (IsRecoding()) {
564 // Decode labels via recoder_.
565 RecodedCharID code;
566 if (labels[start] == null_char_) {
567 if (decoded != nullptr) {
568 code.Set(0, null_char_);
569 *decoded = recoder_.DecodeUnichar(code);
570 }
571 return "<null>";
572 }
573 unsigned index = start;
574 while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) {
575 code.Set(code.length(), labels[index++]);
576 while (index < labels.size() && labels[index] == null_char_) {
577 ++index;
578 }
579 int uni_id = recoder_.DecodeUnichar(code);
580 // If the next label isn't a valid first code, then we need to continue
581 // extending even if we have a valid uni_id from this prefix.
582 if (uni_id != INVALID_UNICHAR_ID &&
583 (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen ||
584 recoder_.IsValidFirstCode(labels[index]))) {
585 *end = index;
586 if (decoded != nullptr) {
587 *decoded = uni_id;
588 }
589 if (uni_id == UNICHAR_SPACE) {
590 return " ";
591 }
592 return GetUnicharset().get_normed_unichar(uni_id);
593 }
594 }
595 return "<Undecodable>";
596 } else {
597 if (decoded != nullptr) {
598 *decoded = labels[start];
599 }
600 if (labels[start] == null_char_) {
601 return "<null>";
602 }
603 if (labels[start] == UNICHAR_SPACE) {
604 return " ";
605 }
606 return GetUnicharset().get_normed_unichar(labels[start]);
607 }
608}
609
610// Returns a string corresponding to a given single label id, falling back to
611// a default of ".." for part of a multi-label unichar-id.
612const char *LSTMRecognizer::DecodeSingleLabel(int label) {
613 if (label == null_char_) {
614 return "<null>";
615 }
616 if (IsRecoding()) {
617 // Decode label via recoder_.
618 RecodedCharID code;
619 code.Set(0, label);
620 label = recoder_.DecodeUnichar(code);
621 if (label == INVALID_UNICHAR_ID) {
622 return ".."; // Part of a bigger code.
623 }
624 }
625 if (label == UNICHAR_SPACE) {
626 return " ";
627 }
628 return GetUnicharset().get_normed_unichar(label);
629}
630
631} // namespace tesseract.
@ TF_COMPRESS_UNICHARSET
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
const double kCertOffset
@ UNICHAR_SPACE
Definition: unicharset.h:36
@ UNICHAR_BROKEN
Definition: unicharset.h:38
const double kDictRatio
int log_level
Definition: tprintf.cpp:36
void destroy()
Definition: image.cpp:32
void add(int32_t value, int32_t count)
Definition: statistc.cpp:99
int32_t get_total() const
Definition: statistc.h:85
int32_t min_bucket() const
Definition: statistc.cpp:204
double sd() const
Definition: statistc.cpp:148
double mean() const
Definition: statistc.cpp:132
std::string language_data_path_prefix
Definition: ccutil.h:60
UNICHARSET unicharset
Definition: ccutil.h:61
T & at(int index) const
Definition: genericvector.h:89
unsigned size() const
Definition: genericvector.h:70
bool DeSerialize(std::string &data)
Definition: serialis.cpp:94
bool Serialize(const std::string &data)
Definition: serialis.cpp:107
bool GetComponent(TessdataType type, TFile *fp)
bool IsComponentAvailable(TessdataType type) const
void Set(int index, int value)
static const int kMaxCodeLen
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
bool IsValidFirstCode(int code) const
void SetupPassThrough(const UNICHARSET &unicharset)
int DecodeUnichar(const RecodedCharID &code) const
bool Serialize(TFile *fp) const
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:859
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:172
void LoadLSTM(const std::string &lang, TessdataManager *data_file)
Definition: dict.cpp:291
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:180
bool FinishLoad()
Definition: dict.cpp:357
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:81
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:107
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
std::string DecodeLabels(const std::vector< int > &labels)
NetworkScratch scratch_space_
void LabelsViaReEncode(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
const char * DecodeSingleLabel(int label)
bool LoadCharsets(const TessdataManager *mgr)
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
const char * DecodeLabel(const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
RecodeBeamSearch * search_
void LabelsViaSimpleText(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayLSTMOutput(const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
bool Serialize(const TessdataManager *mgr, TFile *fp) const
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
const UNICHARSET & GetUnicharset() const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
virtual int XScaleFactor() const
Definition: network.h:214
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:217
bool IsTraining() const
Definition: network.h:113
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158
virtual void CacheXScaleFactor(int factor)
Definition: network.h:220
static int DisplayImage(Image pix, ScrollView *window)
Definition: network.cpp:378
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:145
virtual StaticShape InputShape() const
Definition: network.h:129
Image ToPix() const
Definition: networkio.cpp:300
float * f(int t)
Definition: networkio.h:110
int Width() const
Definition: networkio.h:102
void set_int_mode(bool is_quantized)
Definition: networkio.h:125
int NumFeatures() const
Definition: networkio.h:106
int BestLabel(int t, float *score) const
Definition: networkio.h:165
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:83
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:234
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:112
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float > > > > *segmentedTimesteps)
Definition: recodebeam.cpp:175
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:232
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:409
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:236
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:201
static constexpr float kMinCertainty
Definition: recodebeam.h:243
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:239
void Line(int x1, int y1, int x2, int y2)
Definition: scrollview.cpp:498
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
Definition: scrollview.cpp:610
void Text(int x, int y, const char *mystring)
Definition: scrollview.cpp:635
void Pen(Color color)
Definition: scrollview.cpp:710
static void Update()
Definition: scrollview.cpp:700