tesseract v5.3.3.20231005
lstmrecognizer.h
Go to the documentation of this file.
1
2// File: lstmrecognizer.h
3// Description: Top-level line recognizer class for LSTM-based networks.
4// Author: Ray Smith
5//
6// (C) Copyright 2013, Google Inc.
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10// http://www.apache.org/licenses/LICENSE-2.0
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
17
18#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
19#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
20
21#include "ccutil.h"
22#include "helpers.h"
23#include "matrix.h"
24#include "network.h"
25#include "networkscratch.h"
26#include "params.h"
27#include "recodebeam.h"
28#include "series.h"
29#include "unicharcompress.h"
30
31class BLOB_CHOICE_IT;
32struct Pix;
33class ROW_RES;
34class ScrollView;
35class TBOX;
36class WERD_RES;
37
38namespace tesseract {
39
40class Dict;
41class ImageData;
42
43// Enum indicating training mode control flags.
47};
48
49// Top-level line recognizer class for LSTM-based networks.
50// Note that a sub-class, LSTMTrainer is used for training.
52public:
54 LSTMRecognizer(const std::string &language_data_path_prefix);
56
57 int NumOutputs() const {
58 return network_->NumOutputs();
59 }
60
61 // Return the training iterations.
62 int training_iteration() const {
63 return training_iteration_;
64 }
65
66 // Return the sample iterations.
67 int sample_iteration() const {
68 return sample_iteration_;
69 }
70
71 // Return the learning rate.
72 float learning_rate() const {
73 return learning_rate_;
74 }
75
77 if (network_ == nullptr) {
78 return LT_NONE;
79 }
80 StaticShape shape;
81 shape = network_->OutputShape(shape);
82 return shape.loss_type();
83 }
84 bool SimpleTextOutput() const {
85 return OutputLossType() == LT_SOFTMAX;
86 }
87 bool IsIntMode() const {
88 return (training_flags_ & TF_INT_MODE) != 0;
89 }
90 // True if recoder_ is active to re-encode text to a smaller space.
91 bool IsRecoding() const {
92 return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
93 }
94 // Returns true if the network is a TensorFlow network.
95 bool IsTensorFlow() const {
96 return network_->type() == NT_TENSORFLOW;
97 }
98 // Returns a vector of layer ids that can be passed to other layer functions
99 // to access a specific layer.
100 std::vector<std::string> EnumerateLayers() const {
101 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
102 auto *series = static_cast<Series *>(network_);
103 std::vector<std::string> layers;
104 series->EnumerateLayers(nullptr, layers);
105 return layers;
106 }
107 // Returns a specific layer from its id (from EnumerateLayers).
108 Network *GetLayer(const std::string &id) const {
109 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
110 ASSERT_HOST(id.length() > 1 && id[0] == ':');
111 auto *series = static_cast<Series *>(network_);
112 return series->GetLayer(&id[1]);
113 }
114 // Returns the learning rate of the layer from its id.
115 float GetLayerLearningRate(const std::string &id) const {
116 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
117 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
118 ASSERT_HOST(id.length() > 1 && id[0] == ':');
119 auto *series = static_cast<Series *>(network_);
120 return series->LayerLearningRate(&id[1]);
121 } else {
122 return learning_rate_;
123 }
124 }
125
126 // Return the network string.
127 const char *GetNetwork() const {
128 return network_str_.c_str();
129 }
130
131 // Return the adam beta.
132 float GetAdamBeta() const {
133 return adam_beta_;
134 }
135
136 // Return the momentum.
137 float GetMomentum() const {
138 return momentum_;
139 }
140
141 // Multiplies the all the learning rate(s) by the given factor.
142 void ScaleLearningRate(double factor) {
143 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
144 learning_rate_ *= factor;
145 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
146 std::vector<std::string> layers = EnumerateLayers();
147 for (auto &layer : layers) {
148 ScaleLayerLearningRate(layer, factor);
149 }
150 }
151 }
152 // Multiplies the learning rate of the layer with id, by the given factor.
153 void ScaleLayerLearningRate(const std::string &id, double factor) {
154 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
155 ASSERT_HOST(id.length() > 1 && id[0] == ':');
156 auto *series = static_cast<Series *>(network_);
157 series->ScaleLayerLearningRate(&id[1], factor);
158 }
159
160 // Set the all the learning rate(s) to the given value.
161 void SetLearningRate(float learning_rate)
162 {
163 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
164 learning_rate_ = learning_rate;
165 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
166 for (auto &id : EnumerateLayers()) {
167 SetLayerLearningRate(id, learning_rate);
168 }
169 }
170 }
171 // Set the learning rate of the layer with id, by the given value.
172 void SetLayerLearningRate(const std::string &id, float learning_rate)
173 {
174 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
175 ASSERT_HOST(id.length() > 1 && id[0] == ':');
176 auto *series = static_cast<Series *>(network_);
177 series->SetLayerLearningRate(&id[1], learning_rate);
178 }
179
180 // Converts the network to int if not already.
182 if ((training_flags_ & TF_INT_MODE) == 0) {
183 network_->ConvertToInt();
184 training_flags_ |= TF_INT_MODE;
185 }
186 }
187
188 // Provides access to the UNICHARSET that this classifier works with.
189 const UNICHARSET &GetUnicharset() const {
190 return ccutil_.unicharset;
191 }
193 return ccutil_.unicharset;
194 }
195 // Provides access to the UnicharCompress that this classifier works with.
197 return recoder_;
198 }
199 // Provides access to the Dict that this classifier works with.
200 const Dict *GetDict() const {
201 return dict_;
202 }
204 return dict_;
205 }
206 // Sets the sample iteration to the given value. The sample_iteration_
207 // determines the seed for the random number generator. The training
208 // iteration is incremented only by a successful training iteration.
209 void SetIteration(int iteration) {
210 sample_iteration_ = iteration;
211 }
212 // Accessors for textline image normalization.
213 int NumInputs() const {
214 return network_->NumInputs();
215 }
216
217 // Return the null char index.
218 int null_char() const {
219 return null_char_;
220 }
221
222 // Loads a model from mgr, including the dictionary only if lang is not null.
223 bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
224
225 // Writes to the given file. Returns false in case of error.
226 // If mgr contains a unicharset and recoder, then they are not encoded to fp.
227 bool Serialize(const TessdataManager *mgr, TFile *fp) const;
228 // Reads from the given file. Returns false in case of error.
229 // If mgr contains a unicharset and recoder, then they are taken from there,
230 // otherwise, they are part of the serialization in fp.
231 bool DeSerialize(const TessdataManager *mgr, TFile *fp);
232 // Loads the charsets from mgr.
233 bool LoadCharsets(const TessdataManager *mgr);
234 // Loads the Recoder.
235 bool LoadRecoder(TFile *fp);
236 // Loads the dictionary if possible from the traineddata file.
237 // Prints a warning message, and returns false but otherwise fails silently
238 // and continues to work without it if loading fails.
239 // Note that dictionary load is independent from DeSerialize, but dependent
240 // on the unicharset matching. This enables training to deserialize a model
241 // from checkpoint or restore without having to go back and reload the
242 // dictionary.
243 bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
244
245 // Recognizes the line image, contained within image_data, returning the
246 // recognized tesseract WERD_RES for the words.
247 // If invert_threshold > 0, tries inverted as well if the normal
248 // interpretation doesn't produce a result which at least reaches
249 // that threshold. The line_box is used for computing the
250 // box_word in the output words. worst_dict_cert is the worst certainty that
251 // will be used in a dictionary word.
252 void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert,
253 const TBOX &line_box, PointerVector<WERD_RES> *words, int lstm_choice_mode = 0,
254 int lstm_choice_amount = 5);
255
256 // Helper computes min and mean best results in the output.
257 void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd);
258 // Recognizes the image_data, returning the labels,
259 // scores, and corresponding pairs of start, end x-coords in coords.
260 // Returned in scale_factor is the reduction factor
261 // between the image and the output coords, for computing bounding boxes.
262 // If re_invert is true, the input is inverted back to its original
263 // photometric interpretation if inversion is attempted but fails to
264 // improve the results. This ensures that outputs contains the correct
265 // forward outputs for the best photometric interpretation.
266 // inputs is filled with the used inputs to the network.
267 bool RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, bool re_invert,
268 bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs);
269
270 // Converts an array of labels to utf-8, whether or not the labels are
271 // augmented with character boundaries.
272 std::string DecodeLabels(const std::vector<int> &labels);
273
274 // Displays the forward results in a window with the characters and
275 // boundaries as determined by the labels and label_coords.
276 void DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
277 const std::vector<int> &label_coords, const char *window_name,
278 ScrollView **window);
279 // Converts the network output to a sequence of labels. Outputs labels, scores
280 // and start xcoords of each char, and each null_char_, with an additional
281 // final xcoord for the end of the output.
282 // The conversion method is determined by internal state.
283 void LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
284 std::vector<int> *xcoords);
285
286protected:
287 // Sets the random seed from the sample_iteration_;
289 int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
290 randomizer_.set_seed(seed);
291 randomizer_.IntRand();
292 }
293
294 // Displays the labels and cuts at the corresponding xcoords.
295 // Size of labels should match xcoords.
296 void DisplayLSTMOutput(const std::vector<int> &labels, const std::vector<int> &xcoords,
297 int height, ScrollView *window);
298
299 // Prints debug output detailing the activation path that is implied by the
300 // xcoords.
301 void DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
302 const std::vector<int> &xcoords);
303
304 // Prints debug output detailing activations and 2nd choice over a range
305 // of positions.
306 void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice,
307 int x_start, int x_end);
308
309 // As LabelsViaCTC except that this function constructs the best path that
310 // contains only legal sequences of subcodes for recoder_.
311 void LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
312 std::vector<int> *xcoords);
313 // Converts the network output to a sequence of labels, with scores, using
314 // the simple character model (each position is a char, and the null_char_ is
315 // mainly intended for tail padding.)
316 void LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
317 std::vector<int> *xcoords);
318
319 // Returns a string corresponding to the label starting at start. Sets *end
320 // to the next start and if non-null, *decoded to the unichar id.
321 const char *DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end, int *decoded);
322
323 // Returns a string corresponding to a given single label id, falling back to
324 // a default of ".." for part of a multi-label unichar-id.
325 const char *DecodeSingleLabel(int label);
326
327protected:
328 // The network hierarchy.
330 // The unicharset. Only the unicharset element is serialized.
331 // Has to be a CCUtil, so Dict can point to it.
333 // For backward compatibility, recoder_ is serialized iff
334 // training_flags_ & TF_COMPRESS_UNICHARSET.
335 // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
337
338 // ==Training parameters that are serialized to provide a record of them.==
339 std::string network_str_;
340 // Flags used to determine the training method of the network.
341 // See enum TrainingFlags above.
343 // Number of actual backward training steps used.
 // Index into training sample set. sample_iteration_ >= training_iteration_.
347 // Index in softmax of null character. May take the value UNICHAR_BROKEN or
348 // ccutil_.unicharset.size().
349 int32_t null_char_;
350 // Learning rate and momentum multipliers of deltas in backprop.
353 // Smoothing factor for 2nd moment of gradients.
355
356 // === NOT SERIALIZED.
359 // Language model (optional) to use with the beam search.
361 // Beam search held between uses to optimize memory allocation/use.
363
364 // == Debugging parameters.==
365 // Recognition debug display window.
367};
368
369} // namespace tesseract.
370
371#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define ASSERT_HOST(x)
Definition: errcode.h:54
@ TBOX
@ TF_COMPRESS_UNICHARSET
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:205
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:236
@ NT_TENSORFLOW
Definition: network.h:76
@ NT_SERIES
Definition: network.h:52
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
LossType OutputLossType() const
void SetLayerLearningRate(const std::string &id, float learning_rate)
UNICHARSET & GetUnicharset()
const UnicharCompress & GetRecoder() const
NetworkScratch scratch_space_
void SetIteration(int iteration)
RecodeBeamSearch * search_
const Dict * GetDict() const
void ScaleLearningRate(double factor)
void ScaleLayerLearningRate(const std::string &id, double factor)
std::vector< std::string > EnumerateLayers() const
float GetLayerLearningRate(const std::string &id) const
const char * GetNetwork() const
Network * GetLayer(const std::string &id) const
const UNICHARSET & GetUnicharset() const
void SetLearningRate(float learning_rate)
float LayerLearningRate(const char *id)
Definition: plumbing.h:110
void EnumerateLayers(const std::string *prefix, std::vector< std::string > &layers) const
Definition: plumbing.cpp:144
void SetLayerLearningRate(const char *id, float learning_rate)
Definition: plumbing.h:123
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:116
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:161
LossType loss_type() const
Definition: static_shape.h:65
#define TESS_API
Definition: export.h:32