tesseract v5.3.3.20231005
tesseract::LSTMRecognizer Class Reference

#include <lstmrecognizer.h>

Inheritance diagram for tesseract::LSTMRecognizer:
tesseract::LSTMTrainer

Public Member Functions

 LSTMRecognizer ()
 
 LSTMRecognizer (const std::string &language_data_path_prefix)
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
float learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
std::vector< std::string > EnumerateLayers () const
 
Network * GetLayer (const std::string &id) const
 
float GetLayerLearningRate (const std::string &id) const
 
const char * GetNetwork () const
 
float GetAdamBeta () const
 
float GetMomentum () const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const std::string &id, double factor)
 
void SetLearningRate (float learning_rate)
 
void SetLayerLearningRate (const std::string &id, float learning_rate)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
UNICHARSET & GetUnicharset ()
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
Dict * GetDict ()
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, float invert_threshold, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
std::string DecodeLabels (const std::vector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
 

Protected Member Functions

void SetRandomSeed ()
 
void DisplayLSTMOutput (const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
const char * DecodeLabel (const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
std::string network_str_
 
int32_t training_flags_
 
int32_t training_iteration_
 
int32_t sample_iteration_
 
int32_t null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Detailed Description

Definition at line 51 of file lstmrecognizer.h.

Constructor & Destructor Documentation

◆ LSTMRecognizer() [1/2]

tesseract::LSTMRecognizer::LSTMRecognizer ( )

Definition at line 55 of file lstmrecognizer.cpp.

◆ LSTMRecognizer() [2/2]

tesseract::LSTMRecognizer::LSTMRecognizer ( const std::string &  language_data_path_prefix)

Definition at line 50 of file lstmrecognizer.cpp.

◆ ~LSTMRecognizer()

tesseract::LSTMRecognizer::~LSTMRecognizer ( )

Definition at line 68 of file lstmrecognizer.cpp.

68 {
69 delete network_;
70 delete dict_;
71 delete search_;
72}

Member Function Documentation

◆ ConvertToInt()

void tesseract::LSTMRecognizer::ConvertToInt ( )
inline

Definition at line 181 of file lstmrecognizer.h.

181 {
182 if ((training_flags_ & TF_INT_MODE) == 0) {
183 network_->ConvertToInt();
184 training_flags_ |= TF_INT_MODE;
185 }
186 }
virtual void ConvertToInt()
Definition: network.h:196

◆ DebugActivationPath()

void tesseract::LSTMRecognizer::DebugActivationPath ( const NetworkIO &  outputs,
const std::vector< int > &  labels,
const std::vector< int > &  xcoords 
)
protected

Definition at line 451 of file lstmrecognizer.cpp.

452 {
453 if (xcoords[0] > 0) {
454 DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
455 }
456 unsigned end = 1;
457 for (unsigned start = 0; start < labels.size(); start = end) {
458 if (labels[start] == null_char_) {
459 end = start + 1;
460 DebugActivationRange(outputs, "<null>", null_char_, xcoords[start], xcoords[end]);
461 continue;
462 } else {
463 int decoded;
464 const char *label = DecodeLabel(labels, start, &end, &decoded);
465 DebugActivationRange(outputs, label, labels[start], xcoords[start], xcoords[start + 1]);
466 for (unsigned i = start + 1; i < end; ++i) {
467 DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i], xcoords[i],
468 xcoords[i + 1]);
469 }
470 }
471 }
472}
const char * DecodeSingleLabel(int label)
const char * DecodeLabel(const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)

◆ DebugActivationRange()

void tesseract::LSTMRecognizer::DebugActivationRange ( const NetworkIO &  outputs,
const char *  label,
int  best_choice,
int  x_start,
int  x_end 
)
protected

Definition at line 476 of file lstmrecognizer.cpp.

477 {
478 tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
479 double max_score = 0.0;
480 double mean_score = 0.0;
481 const int width = x_end - x_start;
482 for (int x = x_start; x < x_end; ++x) {
483 const float *line = outputs.f(x);
484 const double score = line[best_choice] * 100.0;
485 if (score > max_score) {
486 max_score = score;
487 }
488 mean_score += score / width;
489 int best_c = 0;
490 double best_score = 0.0;
491 for (int c = 0; c < outputs.NumFeatures(); ++c) {
492 if (c != best_choice && line[c] > best_score) {
493 best_c = c;
494 best_score = line[c];
495 }
496 }
497 tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c, best_score * 100.0);
498 }
499 tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
500}
void tprintf(const char *format,...)
Definition: tprintf.cpp:41

◆ DecodeLabel()

const char * tesseract::LSTMRecognizer::DecodeLabel ( const std::vector< int > &  labels,
unsigned  start,
unsigned *  end,
int *  decoded 
)
protected

Definition at line 560 of file lstmrecognizer.cpp.

561 {
562 *end = start + 1;
563 if (IsRecoding()) {
564 // Decode labels via recoder_.
565 RecodedCharID code;
566 if (labels[start] == null_char_) {
567 if (decoded != nullptr) {
568 code.Set(0, null_char_);
569 *decoded = recoder_.DecodeUnichar(code);
570 }
571 return "<null>";
572 }
573 unsigned index = start;
574 while (index < labels.size() && code.length() < RecodedCharID::kMaxCodeLen) {
575 code.Set(code.length(), labels[index++]);
576 while (index < labels.size() && labels[index] == null_char_) {
577 ++index;
578 }
579 int uni_id = recoder_.DecodeUnichar(code);
580 // If the next label isn't a valid first code, then we need to continue
581 // extending even if we have a valid uni_id from this prefix.
582 if (uni_id != INVALID_UNICHAR_ID &&
583 (index == labels.size() || code.length() == RecodedCharID::kMaxCodeLen ||
584 recoder_.IsValidFirstCode(labels[index]))) {
585 *end = index;
586 if (decoded != nullptr) {
587 *decoded = uni_id;
588 }
589 if (uni_id == UNICHAR_SPACE) {
590 return " ";
591 }
592 return GetUnicharset().get_normed_unichar(uni_id);
593 }
594 }
595 return "<Undecodable>";
596 } else {
597 if (decoded != nullptr) {
598 *decoded = labels[start];
599 }
600 if (labels[start] == null_char_) {
601 return "<null>";
602 }
603 if (labels[start] == UNICHAR_SPACE) {
604 return " ";
605 }
606 return GetUnicharset().get_normed_unichar(labels[start]);
607 }
608}
@ UNICHAR_SPACE
Definition: unicharset.h:36
static const int kMaxCodeLen
bool IsValidFirstCode(int code) const
int DecodeUnichar(const RecodedCharID &code) const
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:859
const UNICHARSET & GetUnicharset() const

◆ DecodeLabels()

std::string tesseract::LSTMRecognizer::DecodeLabels ( const std::vector< int > &  labels)

Definition at line 394 of file lstmrecognizer.cpp.

394 {
395 std::string result;
396 unsigned end = 1;
397 for (unsigned start = 0; start < labels.size(); start = end) {
398 if (labels[start] == null_char_) {
399 end = start + 1;
400 } else {
401 result += DecodeLabel(labels, start, &end, nullptr);
402 }
403 }
404 return result;
405}

◆ DecodeSingleLabel()

const char * tesseract::LSTMRecognizer::DecodeSingleLabel ( int  label)
protected

Definition at line 612 of file lstmrecognizer.cpp.

612 {
613 if (label == null_char_) {
614 return "<null>";
615 }
616 if (IsRecoding()) {
617 // Decode label via recoder_.
618 RecodedCharID code;
619 code.Set(0, label);
620 label = recoder_.DecodeUnichar(code);
621 if (label == INVALID_UNICHAR_ID) {
622 return ".."; // Part of a bigger code.
623 }
624 }
625 if (label == UNICHAR_SPACE) {
626 return " ";
627 }
628 return GetUnicharset().get_normed_unichar(label);
629}

◆ DeSerialize()

bool tesseract::LSTMRecognizer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 133 of file lstmrecognizer.cpp.

133 {
134 delete network_;
135 network_ = Network::CreateFromFile(fp);
136 if (network_ == nullptr) {
137 return false;
138 }
139 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
140 !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
141 if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) {
142 return false;
143 }
144 if (!fp->DeSerialize(network_str_)) {
145 return false;
146 }
147 if (!fp->DeSerialize(&training_flags_)) {
148 return false;
149 }
150 if (!fp->DeSerialize(&training_iteration_)) {
151 return false;
152 }
153 if (!fp->DeSerialize(&sample_iteration_)) {
154 return false;
155 }
156 if (!fp->DeSerialize(&null_char_)) {
157 return false;
158 }
159 if (!fp->DeSerialize(&adam_beta_)) {
160 return false;
161 }
162 if (!fp->DeSerialize(&learning_rate_)) {
163 return false;
164 }
165 if (!fp->DeSerialize(&momentum_)) {
166 return false;
167 }
168 if (include_charsets && !LoadRecoder(fp)) {
169 return false;
170 }
171 if (!include_charsets && !LoadCharsets(mgr)) {
172 return false;
173 }
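174 network_->SetRandomizer(&randomizer_);
175 network_->CacheXScaleFactor(network_->XScaleFactor());
176 return true;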
177}
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
UNICHARSET unicharset
Definition: ccutil.h:61
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
bool LoadCharsets(const TessdataManager *mgr)
virtual int XScaleFactor() const
Definition: network.h:214
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:217
virtual void CacheXScaleFactor(int factor)
Definition: network.h:220
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:145

◆ DisplayForward()

void tesseract::LSTMRecognizer::DisplayForward ( const NetworkIO &  inputs,
const std::vector< int > &  labels,
const std::vector< int > &  label_coords,
const char *  window_name,
ScrollView **  window 
)

Definition at line 411 of file lstmrecognizer.cpp.

413 {
414 Image input_pix = inputs.ToPix();
415 Network::ClearWindow(false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window);
416 int line_height = Network::DisplayImage(input_pix, *window);
417 DisplayLSTMOutput(labels, label_coords, line_height, *window);
418}
void DisplayLSTMOutput(const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350
static int DisplayImage(Image pix, ScrollView *window)
Definition: network.cpp:378

◆ DisplayLSTMOutput()

void tesseract::LSTMRecognizer::DisplayLSTMOutput ( const std::vector< int > &  labels,
const std::vector< int > &  xcoords,
int  height,
ScrollView *  window 
)
protected

Definition at line 422 of file lstmrecognizer.cpp.

424 {
425 int x_scale = network_->XScaleFactor();
426 window->TextAttributes("Arial", height / 4, false, false, false);
427 unsigned end = 1;
428 for (unsigned start = 0; start < labels.size(); start = end) {
429 int xpos = xcoords[start] * x_scale;
430 if (labels[start] == null_char_) {
431 end = start + 1;
432 window->Pen(ScrollView::RED);
433 } else {
434 window->Pen(ScrollView::GREEN);
435 const char *str = DecodeLabel(labels, start, &end, nullptr);
436 if (*str == '\\') {
437 str = "\\\\";
438 }
439 xpos = xcoords[(start + end) / 2] * x_scale;
440 window->Text(xpos, height, str);
441 }
442 window->Line(xpos, 0, xpos, height * 3 / 2);
443 }
444 window->Update();
445}

◆ EnumerateLayers()

std::vector< std::string > tesseract::LSTMRecognizer::EnumerateLayers ( ) const
inline

Definition at line 100 of file lstmrecognizer.h.

100 {
101 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
102 auto *series = static_cast<Series *>(network_);
103 std::vector<std::string> layers;
104 series->EnumerateLayers(nullptr, layers);
105 return layers;
106 }
#define ASSERT_HOST(x)
Definition: errcode.h:54
@ NT_SERIES
Definition: network.h:52
NetworkType type() const
Definition: network.h:110

◆ GetAdamBeta()

float tesseract::LSTMRecognizer::GetAdamBeta ( ) const
inline

Definition at line 132 of file lstmrecognizer.h.

132 {
133 return adam_beta_;
134 }

◆ GetDict() [1/2]

Dict * tesseract::LSTMRecognizer::GetDict ( )
inline

Definition at line 203 of file lstmrecognizer.h.

203 {
204 return dict_;
205 }

◆ GetDict() [2/2]

const Dict * tesseract::LSTMRecognizer::GetDict ( ) const
inline

Definition at line 200 of file lstmrecognizer.h.

200 {
201 return dict_;
202 }

◆ GetLayer()

Network * tesseract::LSTMRecognizer::GetLayer ( const std::string &  id) const
inline

Definition at line 108 of file lstmrecognizer.h.

108 {
109 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
110 ASSERT_HOST(id.length() > 1 && id[0] == ':');
111 auto *series = static_cast<Series *>(network_);
112 return series->GetLayer(&id[1]);
113 }

◆ GetLayerLearningRate()

float tesseract::LSTMRecognizer::GetLayerLearningRate ( const std::string &  id) const
inline

Definition at line 115 of file lstmrecognizer.h.

115 {
116 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
117 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
118 ASSERT_HOST(id.length() > 1 && id[0] == ':');
119 auto *series = static_cast<Series *>(network_);
120 return series->LayerLearningRate(&id[1]);
121 } else {
122 return learning_rate_;
123 }
124 }
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
bool TestFlag(NetworkFlags flag) const
Definition: network.h:146

◆ GetMomentum()

float tesseract::LSTMRecognizer::GetMomentum ( ) const
inline

Definition at line 137 of file lstmrecognizer.h.

137 {
138 return momentum_;
139 }

◆ GetNetwork()

const char * tesseract::LSTMRecognizer::GetNetwork ( ) const
inline

Definition at line 127 of file lstmrecognizer.h.

127 {
128 return network_str_.c_str();
129 }

◆ GetRecoder()

const UnicharCompress & tesseract::LSTMRecognizer::GetRecoder ( ) const
inline

Definition at line 196 of file lstmrecognizer.h.

196 {
197 return recoder_;
198 }

◆ GetUnicharset() [1/2]

UNICHARSET & tesseract::LSTMRecognizer::GetUnicharset ( )
inline

Definition at line 192 of file lstmrecognizer.h.

192 {
193 return ccutil_.unicharset;
194 }

◆ GetUnicharset() [2/2]

const UNICHARSET & tesseract::LSTMRecognizer::GetUnicharset ( ) const
inline

Definition at line 189 of file lstmrecognizer.h.

189 {
190 return ccutil_.unicharset;
191 }

◆ IsIntMode()

bool tesseract::LSTMRecognizer::IsIntMode ( ) const
inline

Definition at line 87 of file lstmrecognizer.h.

87 {
88 return (training_flags_ & TF_INT_MODE) != 0;
89 }

◆ IsRecoding()

bool tesseract::LSTMRecognizer::IsRecoding ( ) const
inline

Definition at line 91 of file lstmrecognizer.h.

91 {
92 return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
93 }
@ TF_COMPRESS_UNICHARSET

◆ IsTensorFlow()

bool tesseract::LSTMRecognizer::IsTensorFlow ( ) const
inline

Definition at line 95 of file lstmrecognizer.h.

95 {
96 return network_->type() == NT_TENSORFLOW;
97 }
@ NT_TENSORFLOW
Definition: network.h:76

◆ LabelsFromOutputs()

void tesseract::LSTMRecognizer::LabelsFromOutputs ( const NetworkIO &  outputs,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)

Definition at line 519 of file lstmrecognizer.cpp.

520 {
521 if (SimpleTextOutput()) {
522 LabelsViaSimpleText(outputs, labels, xcoords);
523 } else {
524 LabelsViaReEncode(outputs, labels, xcoords);
525 }
526}
void LabelsViaReEncode(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
void LabelsViaSimpleText(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)

◆ LabelsViaReEncode()

void tesseract::LSTMRecognizer::LabelsViaReEncode ( const NetworkIO &  output,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)
protected

Definition at line 530 of file lstmrecognizer.cpp.

531 {
532 if (search_ == nullptr) {
533 search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
534 }
535 search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, nullptr);
536 search_->ExtractBestPathAsLabels(labels, xcoords);
537}
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:83
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:201
static constexpr float kMinCertainty
Definition: recodebeam.h:243

◆ LabelsViaSimpleText()

void tesseract::LSTMRecognizer::LabelsViaSimpleText ( const NetworkIO &  output,
std::vector< int > *  labels,
std::vector< int > *  xcoords 
)
protected

Definition at line 542 of file lstmrecognizer.cpp.

543 {
544 labels->clear();
545 xcoords->clear();
546 const int width = output.Width();
547 for (int t = 0; t < width; ++t) {
548 float score = 0.0f;
549 const int label = output.BestLabel(t, &score);
550 if (label != null_char_) {
551 labels->push_back(label);
552 xcoords->push_back(t);
553 }
554 }
555 xcoords->push_back(width);
556}

◆ learning_rate()

float tesseract::LSTMRecognizer::learning_rate ( ) const
inline

Definition at line 72 of file lstmrecognizer.h.

72 {
73 return learning_rate_;
74 }

◆ Load()

bool tesseract::LSTMRecognizer::Load ( const ParamsVectors *  params,
const std::string &  lang,
TessdataManager *  mgr 
)

Definition at line 75 of file lstmrecognizer.cpp.

76 {
77 TFile fp;
78 if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) {
79 return false;
80 }
81 if (!DeSerialize(mgr, &fp)) {
82 return false;
83 }
84 if (lang.empty()) {
85 return true;
86 }
87 // Allow it to run without a dictionary.
88 LoadDictionary(params, lang, mgr);
89 return true;
90}
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)

◆ LoadCharsets()

bool tesseract::LSTMRecognizer::LoadCharsets ( const TessdataManager *  mgr)

Definition at line 180 of file lstmrecognizer.cpp.

180 {
181 TFile fp;
182 if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
183 return false;
184 }
185 if (!ccutil_.unicharset.load_from_file(&fp, false)) {
186 return false;
187 }
188 if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
189 return false;
190 }
191 if (!LoadRecoder(&fp)) {
192 return false;
193 }
194 return true;
195}

◆ LoadDictionary()

bool tesseract::LSTMRecognizer::LoadDictionary ( const ParamsVectors *  params,
const std::string &  lang,
TessdataManager *  mgr 
)

Definition at line 224 of file lstmrecognizer.cpp.

225 {
226 delete dict_;
227 dict_ = new Dict(&ccutil_);
228 dict_->user_words_file.ResetFrom(params);
229 dict_->user_words_suffix.ResetFrom(params);
230 dict_->user_patterns_file.ResetFrom(params);
231 dict_->user_patterns_suffix.ResetFrom(params);
232 dict_->SetupForLoad(Dict::GlobalDawgCache());
233 dict_->LoadLSTM(lang, mgr);
234 if (dict_->FinishLoad()) {
235 return true; // Success.
236 }
237 if (log_level <= 0) {
238 tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str());
239 }
240 delete dict_;
241 dict_ = nullptr;
242 return false;
243}
int log_level
Definition: tprintf.cpp:36
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:172
void LoadLSTM(const std::string &lang, TessdataManager *data_file)
Definition: dict.cpp:291
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:180
bool FinishLoad()
Definition: dict.cpp:357

◆ LoadRecoder()

bool tesseract::LSTMRecognizer::LoadRecoder ( TFile *  fp)

Definition at line 198 of file lstmrecognizer.cpp.

198 {
199 if (IsRecoding()) {
200 if (!recoder_.DeSerialize(fp)) {
201 return false;
202 }
203 RecodedCharID code;
204 recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
205 if (code(0) != UNICHAR_SPACE) {
206 tprintf("Space was garbled in recoding!!\n");
207 return false;
208 }
209 } else {
210 recoder_.SetupPassThrough(GetUnicharset());
211 training_flags_ |= TF_COMPRESS_UNICHARSET;
212 }
213 return true;
214}
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
void SetupPassThrough(const UNICHARSET &unicharset)

◆ null_char()

int tesseract::LSTMRecognizer::null_char ( ) const
inline

Definition at line 218 of file lstmrecognizer.h.

218 {
219 return null_char_;
220 }

◆ NumInputs()

int tesseract::LSTMRecognizer::NumInputs ( ) const
inline

Definition at line 213 of file lstmrecognizer.h.

213 {
214 return network_->NumInputs();
215 }
int NumInputs() const
Definition: network.h:122

◆ NumOutputs()

int tesseract::LSTMRecognizer::NumOutputs ( ) const
inline

Definition at line 57 of file lstmrecognizer.h.

57 {
58 return network_->NumOutputs();
59 }
int NumOutputs() const
Definition: network.h:125

◆ OutputLossType()

LossType tesseract::LSTMRecognizer::OutputLossType ( ) const
inline

Definition at line 76 of file lstmrecognizer.h.

76 {
77 if (network_ == nullptr) {
78 return LT_NONE;
79 }
80 StaticShape shape;
81 shape = network_->OutputShape(shape);
82 return shape.loss_type();
83 }
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:135
LossType loss_type() const
Definition: static_shape.h:65

◆ OutputStats()

void tesseract::LSTMRecognizer::OutputStats ( const NetworkIO &  outputs,
float *  min_output,
float *  mean_output,
float *  sd 
)

Definition at line 295 of file lstmrecognizer.cpp.

296 {
297 const int kOutputScale = INT8_MAX;
298 STATS stats(0, kOutputScale);
299 for (int t = 0; t < outputs.Width(); ++t) {
300 int best_label = outputs.BestLabel(t, nullptr);
301 if (best_label != null_char_) {
302 float best_output = outputs.f(t)[best_label];
303 stats.add(static_cast<int>(kOutputScale * best_output), 1);
304 }
305 }
306 // If the output is all nulls it could be that the photometric interpretation
307 // is wrong, so make it look bad, so the other way can win, even if not great.
308 if (stats.get_total() == 0) {
309 *min_output = 0.0f;
310 *mean_output = 0.0f;
311 *sd = 1.0f;
312 } else {
313 *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
314 *mean_output = stats.mean() / kOutputScale;
315 *sd = stats.sd() / kOutputScale;
316 }
317}

◆ RecognizeLine() [1/2]

bool tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
float  invert_threshold,
bool  debug,
bool  re_invert,
bool  upside_down,
float *  scale_factor,
NetworkIO *  inputs,
NetworkIO *  outputs 
)

Definition at line 321 of file lstmrecognizer.cpp.

324 {
325 // This ensures consistent recognition results.
326 SetRandomSeed();
327 int min_width = network_->XScaleFactor();
328 Image pix = Input::PrepareLSTMInputs(image_data, network_, min_width, &randomizer_, scale_factor);
329 if (pix == nullptr) {
330 tprintf("Line cannot be recognized!!\n");
331 return false;
332 }
333 // Maximum width of image to train on.
334 const int kMaxImageWidth = 128 * pixGetHeight(pix);
335 if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
336 tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix));
337 pix.destroy();
338 return false;
339 }
340 if (upside_down) {
341 pixRotate180(pix, pix);
342 }
343 // Reduction factor from image to coords.
344 *scale_factor = min_width / *scale_factor;
345 inputs->set_int_mode(IsIntMode());
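346 SetRandomSeed();
347 Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
348 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);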
349 // Check for auto inversion.
350 if (invert_threshold > 0.0f) {
351 float pos_min, pos_mean, pos_sd;
352 OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
353 if (pos_mean < invert_threshold) {
354 // Run again inverted and see if it is any better.
355 NetworkIO inv_inputs, inv_outputs;
356 inv_inputs.set_int_mode(IsIntMode());
357 SetRandomSeed();
358 pixInvert(pix, pix);
359 Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, &inv_inputs);
360 network_->Forward(debug, inv_inputs, nullptr, &scratch_space_, &inv_outputs);
361 float inv_min, inv_mean, inv_sd;
362 OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
363 if (inv_mean > pos_mean) {
364 // Inverted did better. Use inverted data.
365 if (debug) {
366 tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
367 pos_sd, inv_min, inv_mean, inv_sd);
368 }
369 *outputs = inv_outputs;
370 *inputs = inv_inputs;
371 } else if (re_invert) {
372 // Inverting was not an improvement, so undo and run again, so the
373 // outputs match the best forward result.
374 SetRandomSeed();
375 network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
376 }
377 }
378 }
379
380 pix.destroy();
381 if (debug) {
382 std::vector<int> labels, coords;
383 LabelsFromOutputs(*outputs, &labels, &coords);
384#ifndef GRAPHICS_DISABLED
385 DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
386#endif
387 DebugActivationPath(*outputs, labels, coords);
388 }
389 return true;
390}
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:81
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:107
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
NetworkScratch scratch_space_
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
bool IsTraining() const
Definition: network.h:113
virtual StaticShape InputShape() const
Definition: network.h:129

◆ RecognizeLine() [2/2]

void tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
float  invert_threshold,
bool  debug,
double  worst_dict_cert,
const TBOX &  line_box,
PointerVector< WERD_RES > *  words,
int  lstm_choice_mode = 0,
int  lstm_choice_amount = 5 
)

Definition at line 247 of file lstmrecognizer.cpp.

251 {
252 NetworkIO outputs;
253 float scale_factor;
254 NetworkIO inputs;
255 if (!RecognizeLine(image_data, invert_threshold, debug, false, false, &scale_factor, &inputs, &outputs)) {
256 return;
257 }
258 if (search_ == nullptr) {
259 search_ = new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
260 }
261 search_->excludedUnichars.clear();
262 search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, &GetUnicharset(),
263 lstm_choice_mode);
264 search_->ExtractBestPathAsWords(line_box, scale_factor, debug, &GetUnicharset(), words,
265 lstm_choice_mode);
266 if (lstm_choice_mode) {
267 search_->extractSymbolChoices(&GetUnicharset());
268 for (int i = 0; i < lstm_choice_amount; ++i) {
269 search_->DecodeSecondaryBeams(outputs, kDictRatio, kCertOffset, worst_dict_cert,
270 &GetUnicharset(), lstm_choice_mode);
271 search_->extractSymbolChoices(&GetUnicharset());
272 }
273 search_->segmentTimestepsByCharacters();
274 unsigned char_it = 0;
275 for (size_t i = 0; i < words->size(); ++i) {
276 for (int j = 0; j < words->at(i)->end; ++j) {
277 if (char_it < search_->ctc_choices.size()) {
278 words->at(i)->CTC_symbol_choices.push_back(search_->ctc_choices[char_it]);
279 }
280 if (char_it < search_->segmentedTimesteps.size()) {
281 words->at(i)->segmented_timesteps.push_back(search_->segmentedTimesteps[char_it]);
282 }
283 ++char_it;
284 }
285 words->at(i)->timesteps =
286 search_->combineSegmentedTimesteps(&words->at(i)->segmented_timesteps);
287 }
288 search_->segmentedTimesteps.clear();
289 search_->ctc_choices.clear();
290 search_->excludedUnichars.clear();
291 }
292}
const double kCertOffset
const double kDictRatio
void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:234
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:112
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float > > > > *segmentedTimesteps)
Definition: recodebeam.cpp:175
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:232
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:409
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:236
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:239

◆ sample_iteration()

int tesseract::LSTMRecognizer::sample_iteration ( ) const
inline

Definition at line 67 of file lstmrecognizer.h.

67 {
68 return sample_iteration_;
69 }

◆ ScaleLayerLearningRate()

void tesseract::LSTMRecognizer::ScaleLayerLearningRate ( const std::string &  id,
double  factor 
)
inline

Definition at line 153 of file lstmrecognizer.h.

153 {
154 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
155 ASSERT_HOST(id.length() > 1 && id[0] == ':');
156 auto *series = static_cast<Series *>(network_);
157 series->ScaleLayerLearningRate(&id[1], factor);
158 }

◆ ScaleLearningRate()

void tesseract::LSTMRecognizer::ScaleLearningRate ( double  factor)
inline

Definition at line 142 of file lstmrecognizer.h.

142 {
143 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
144 learning_rate_ *= factor;
145 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
146 std::vector<std::string> layers = EnumerateLayers();
147 for (auto &layer : layers) {
148 ScaleLayerLearningRate(layer, factor);
149 }
150 }
151 }
void ScaleLayerLearningRate(const std::string &id, double factor)
std::vector< std::string > EnumerateLayers() const

◆ Serialize()

bool tesseract::LSTMRecognizer::Serialize ( const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 93 of file lstmrecognizer.cpp.

93 {
94 bool include_charsets = mgr == nullptr || !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
95 !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
96 if (!network_->Serialize(fp)) {
97 return false;
98 }
99 if (include_charsets && !GetUnicharset().save_to_file(fp)) {
100 return false;
101 }
102 if (!fp->Serialize(network_str_)) {
103 return false;
104 }
105 if (!fp->Serialize(&training_flags_)) {
106 return false;
107 }
108 if (!fp->Serialize(&training_iteration_)) {
109 return false;
110 }
111 if (!fp->Serialize(&sample_iteration_)) {
112 return false;
113 }
114 if (!fp->Serialize(&null_char_)) {
115 return false;
116 }
117 if (!fp->Serialize(&adam_beta_)) {
118 return false;
119 }
120 if (!fp->Serialize(&learning_rate_)) {
121 return false;
122 }
123 if (!fp->Serialize(&momentum_)) {
124 return false;
125 }
126 if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) {
127 return false;
128 }
129 return true;
130 }
bool Serialize(TFile *fp) const
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:158

◆ SetIteration()

void tesseract::LSTMRecognizer::SetIteration ( int  iteration)
inline

Definition at line 209 of file lstmrecognizer.h.

209 {
210 sample_iteration_ = iteration;
211 }

◆ SetLayerLearningRate()

void tesseract::LSTMRecognizer::SetLayerLearningRate ( const std::string &  id,
float  learning_rate 
)
inline

Definition at line 172 of file lstmrecognizer.h.

173 {
174 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
175 ASSERT_HOST(id.length() > 1 && id[0] == ':');
176 auto *series = static_cast<Series *>(network_);
177 series->SetLayerLearningRate(&id[1], learning_rate);
178 }

◆ SetLearningRate()

void tesseract::LSTMRecognizer::SetLearningRate ( float  learning_rate)
inline

Definition at line 161 of file lstmrecognizer.h.

162 {
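163 ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
164 learning_rate_ = learning_rate;
165 if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
166 for (auto &id : EnumerateLayers()) {
167 SetLayerLearningRate(id, learning_rate);
168 }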
169 }
170 }
void SetLayerLearningRate(const std::string &id, float learning_rate)

◆ SetRandomSeed()

void tesseract::LSTMRecognizer::SetRandomSeed ( )
inline protected

Definition at line 288 of file lstmrecognizer.h.

288 {
289 int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
290 randomizer_.set_seed(seed);
291 randomizer_.IntRand();
292 }
int32_t IntRand()
Definition: helpers.h:74
void set_seed(uint64_t seed)
Definition: helpers.h:64

◆ SimpleTextOutput()

bool tesseract::LSTMRecognizer::SimpleTextOutput ( ) const
inline

Definition at line 84 of file lstmrecognizer.h.

84 {
85 return OutputLossType() == LT_SOFTMAX;
86 }
LossType OutputLossType() const

◆ training_iteration()

int tesseract::LSTMRecognizer::training_iteration ( ) const
inline

Definition at line 62 of file lstmrecognizer.h.

62 {
63 return training_iteration_;
64 }

Member Data Documentation

◆ adam_beta_

float tesseract::LSTMRecognizer::adam_beta_
protected

Definition at line 354 of file lstmrecognizer.h.

◆ ccutil_

CCUtil tesseract::LSTMRecognizer::ccutil_
protected

Definition at line 332 of file lstmrecognizer.h.

◆ debug_win_

ScrollView* tesseract::LSTMRecognizer::debug_win_
protected

Definition at line 366 of file lstmrecognizer.h.

◆ dict_

Dict* tesseract::LSTMRecognizer::dict_
protected

Definition at line 360 of file lstmrecognizer.h.

◆ learning_rate_

float tesseract::LSTMRecognizer::learning_rate_
protected

Definition at line 351 of file lstmrecognizer.h.

◆ momentum_

float tesseract::LSTMRecognizer::momentum_
protected

Definition at line 352 of file lstmrecognizer.h.

◆ network_

Network* tesseract::LSTMRecognizer::network_
protected

Definition at line 329 of file lstmrecognizer.h.

◆ network_str_

std::string tesseract::LSTMRecognizer::network_str_
protected

Definition at line 339 of file lstmrecognizer.h.

◆ null_char_

int32_t tesseract::LSTMRecognizer::null_char_
protected

Definition at line 349 of file lstmrecognizer.h.

◆ randomizer_

TRand tesseract::LSTMRecognizer::randomizer_
protected

Definition at line 357 of file lstmrecognizer.h.

◆ recoder_

UnicharCompress tesseract::LSTMRecognizer::recoder_
protected

Definition at line 336 of file lstmrecognizer.h.

◆ sample_iteration_

int32_t tesseract::LSTMRecognizer::sample_iteration_
protected

Definition at line 346 of file lstmrecognizer.h.

◆ scratch_space_

NetworkScratch tesseract::LSTMRecognizer::scratch_space_
protected

Definition at line 358 of file lstmrecognizer.h.

◆ search_

RecodeBeamSearch* tesseract::LSTMRecognizer::search_
protected

Definition at line 362 of file lstmrecognizer.h.

◆ training_flags_

int32_t tesseract::LSTMRecognizer::training_flags_
protected

Definition at line 342 of file lstmrecognizer.h.

◆ training_iteration_

int32_t tesseract::LSTMRecognizer::training_iteration_
protected

Definition at line 344 of file lstmrecognizer.h.


The documentation for this class was generated from the following files: