tesseract  4.00.00dev
tesseract::LSTMRecognizer Class Reference

#include <lstmrecognizer.h>

Inheritance diagram for tesseract::LSTMRecognizer:
tesseract::LSTMTrainer

Public Member Functions

 LSTMRecognizer ()
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
double learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
GenericVector< STRING > EnumerateLayers () const
 
Network * GetLayer (const STRING &id) const
 
float GetLayerLearningRate (const STRING &id) const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const STRING &id, double factor)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const char *lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const char *lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
STRING DecodeLabels (const GenericVector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
 

Protected Member Functions

void SetRandomSeed ()
 
void DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
const char * DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
STRING network_str_
 
inT32 training_flags_
 
inT32 training_iteration_
 
inT32 sample_iteration_
 
inT32 null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Detailed Description

Definition at line 53 of file lstmrecognizer.h.

Constructor & Destructor Documentation

◆ LSTMRecognizer()

tesseract::LSTMRecognizer::LSTMRecognizer ( )

◆ ~LSTMRecognizer()

tesseract::LSTMRecognizer::~LSTMRecognizer ( )

Definition at line 65 of file lstmrecognizer.cpp.

65  {
66  delete network_;
67  delete dict_;
68  delete search_;
69 }
RecodeBeamSearch * search_

Member Function Documentation

◆ ConvertToInt()

void tesseract::LSTMRecognizer::ConvertToInt ( )
inline

Definition at line 131 of file lstmrecognizer.h.

131  {
132  if ((training_flags_ & TF_INT_MODE) == 0) {
135  }
136  }
virtual void ConvertToInt()
Definition: network.h:191

◆ DebugActivationPath()

void tesseract::LSTMRecognizer::DebugActivationPath ( const NetworkIO &  outputs,
const GenericVector< int > &  labels,
const GenericVector< int > &  xcoords 
)
protected

Definition at line 356 of file lstmrecognizer.cpp.

358  {
359  if (xcoords[0] > 0)
360  DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
361  int end = 1;
362  for (int start = 0; start < labels.size(); start = end) {
363  if (labels[start] == null_char_) {
364  end = start + 1;
365  DebugActivationRange(outputs, "<null>", null_char_, xcoords[start],
366  xcoords[end]);
367  continue;
368  } else {
369  int decoded;
370  const char* label = DecodeLabel(labels, start, &end, &decoded);
371  DebugActivationRange(outputs, label, labels[start], xcoords[start],
372  xcoords[start + 1]);
373  for (int i = start + 1; i < end; ++i) {
374  DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i],
375  xcoords[i], xcoords[i + 1]);
376  }
377  }
378  }
379 }
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
int size() const
Definition: genericvector.h:72
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
const char * DecodeSingleLabel(int label)

◆ DebugActivationRange()

void tesseract::LSTMRecognizer::DebugActivationRange ( const NetworkIO &  outputs,
const char *  label,
int  best_choice,
int  x_start,
int  x_end 
)
protected

Definition at line 383 of file lstmrecognizer.cpp.

385  {
386  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
387  double max_score = 0.0;
388  double mean_score = 0.0;
389  int width = x_end - x_start;
390  for (int x = x_start; x < x_end; ++x) {
391  const float* line = outputs.f(x);
392  double score = line[best_choice] * 100.0;
393  if (score > max_score) max_score = score;
394  mean_score += score / width;
395  int best_c = 0;
396  double best_score = 0.0;
397  for (int c = 0; c < outputs.NumFeatures(); ++c) {
398  if (c != best_choice && line[c] > best_score) {
399  best_c = c;
400  best_score = line[c];
401  }
402  }
403  tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c,
404  best_score * 100.0);
405  }
406  tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
407 }
#define tprintf(...)
Definition: tprintf.h:31
const char * DecodeSingleLabel(int label)

◆ DecodeLabel()

const char * tesseract::LSTMRecognizer::DecodeLabel ( const GenericVector< int > &  labels,
int  start,
int *  end,
int *  decoded 
)
protected

Definition at line 469 of file lstmrecognizer.cpp.

470  {
471  *end = start + 1;
472  if (IsRecoding()) {
473  // Decode labels via recoder_.
474  RecodedCharID code;
475  if (labels[start] == null_char_) {
476  if (decoded != NULL) {
477  code.Set(0, null_char_);
478  *decoded = recoder_.DecodeUnichar(code);
479  }
480  return "<null>";
481  }
482  int index = start;
483  while (index < labels.size() &&
484  code.length() < RecodedCharID::kMaxCodeLen) {
485  code.Set(code.length(), labels[index++]);
486  while (index < labels.size() && labels[index] == null_char_) ++index;
487  int uni_id = recoder_.DecodeUnichar(code);
488  // If the next label isn't a valid first code, then we need to continue
489  // extending even if we have a valid uni_id from this prefix.
490  if (uni_id != INVALID_UNICHAR_ID &&
491  (index == labels.size() ||
492  code.length() == RecodedCharID::kMaxCodeLen ||
493  recoder_.IsValidFirstCode(labels[index]))) {
494  *end = index;
495  if (decoded != NULL) *decoded = uni_id;
496  if (uni_id == UNICHAR_SPACE) return " ";
497  return GetUnicharset().get_normed_unichar(uni_id);
498  }
499  }
500  return "<Undecodable>";
501  } else {
502  if (decoded != NULL) *decoded = labels[start];
503  if (labels[start] == null_char_) return "<null>";
504  if (labels[start] == UNICHAR_SPACE) return " ";
505  return GetUnicharset().get_normed_unichar(labels[start]);
506  }
507 }
static const int kMaxCodeLen
int size() const
Definition: genericvector.h:72
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:827
const UNICHARSET & GetUnicharset() const
bool IsValidFirstCode(int code) const
int DecodeUnichar(const RecodedCharID &code) const

◆ DecodeLabels()

STRING tesseract::LSTMRecognizer::DecodeLabels ( const GenericVector< int > &  labels)

Definition at line 298 of file lstmrecognizer.cpp.

298  {
299  STRING result;
300  int end = 1;
301  for (int start = 0; start < labels.size(); start = end) {
302  if (labels[start] == null_char_) {
303  end = start + 1;
304  } else {
305  result += DecodeLabel(labels, start, &end, NULL);
306  }
307  }
308  return result;
309 }
int size() const
Definition: genericvector.h:72
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
Definition: strngs.h:45

◆ DecodeSingleLabel()

const char * tesseract::LSTMRecognizer::DecodeSingleLabel ( int  label)
protected

Definition at line 511 of file lstmrecognizer.cpp.

511  {
512  if (label == null_char_) return "<null>";
513  if (IsRecoding()) {
514  // Decode label via recoder_.
515  RecodedCharID code;
516  code.Set(0, label);
517  label = recoder_.DecodeUnichar(code);
518  if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code.
519  }
520  if (label == UNICHAR_SPACE) return " ";
521  return GetUnicharset().get_normed_unichar(label);
522 }
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:827
const UNICHARSET & GetUnicharset() const
int DecodeUnichar(const RecodedCharID &code) const

◆ DeSerialize()

bool tesseract::LSTMRecognizer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 105 of file lstmrecognizer.cpp.

105  {
106  delete network_;
108  if (network_ == NULL) return false;
109  bool include_charsets = mgr == nullptr ||
110  !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
111  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
112  if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false))
113  return false;
114  if (!network_str_.DeSerialize(fp)) return false;
115  if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1)
116  return false;
117  if (fp->FReadEndian(&training_iteration_, sizeof(training_iteration_), 1) !=
118  1)
119  return false;
120  if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
121  return false;
122  if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false;
123  if (fp->FReadEndian(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
124  if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1)
125  return false;
126  if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false;
127  if (include_charsets && !LoadRecoder(fp)) return false;
128  if (!include_charsets && !LoadCharsets(mgr)) return false;
131  return true;
132 }
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
virtual int XScaleFactor() const
Definition: network.h:209
UNICHARSET unicharset
Definition: ccutil.h:68
virtual void CacheXScaleFactor(int factor)
Definition: network.h:215
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:387
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:203
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:163
bool LoadCharsets(const TessdataManager *mgr)

◆ DisplayForward()

void tesseract::LSTMRecognizer::DisplayForward ( const NetworkIO &  inputs,
const GenericVector< int > &  labels,
const GenericVector< int > &  label_coords,
const char *  window_name,
ScrollView **  window 
)

Definition at line 313 of file lstmrecognizer.cpp.

317  {
318 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
319  Pix* input_pix = inputs.ToPix();
320  Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
321  pixGetHeight(input_pix), window);
322  int line_height = Network::DisplayImage(input_pix, *window);
323  DisplayLSTMOutput(labels, label_coords, line_height, *window);
324 #endif // GRAPHICS_DISABLED
325 }
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)

◆ DisplayLSTMOutput()

void tesseract::LSTMRecognizer::DisplayLSTMOutput ( const GenericVector< int > &  labels,
const GenericVector< int > &  xcoords,
int  height,
ScrollView *  window 
)
protected

Definition at line 329 of file lstmrecognizer.cpp.

331  {
332 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
333  int x_scale = network_->XScaleFactor();
334  window->TextAttributes("Arial", height / 4, false, false, false);
335  int end = 1;
336  for (int start = 0; start < labels.size(); start = end) {
337  int xpos = xcoords[start] * x_scale;
338  if (labels[start] == null_char_) {
339  end = start + 1;
340  window->Pen(ScrollView::RED);
341  } else {
342  window->Pen(ScrollView::GREEN);
343  const char* str = DecodeLabel(labels, start, &end, NULL);
344  if (*str == '\\') str = "\\\\";
345  xpos = xcoords[(start + end) / 2] * x_scale;
346  window->Text(xpos, height, str);
347  }
348  window->Line(xpos, 0, xpos, height * 3 / 2);
349  }
350  window->Update();
351 #endif // GRAPHICS_DISABLED
352 }
static void Update()
Definition: scrollview.cpp:715
int size() const
Definition: genericvector.h:72
void Line(int x1, int y1, int x2, int y2)
Definition: scrollview.cpp:538
virtual int XScaleFactor() const
Definition: network.h:209
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
void Pen(Color color)
Definition: scrollview.cpp:726
void Text(int x, int y, const char *mystring)
Definition: scrollview.cpp:658
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
Definition: scrollview.cpp:641

◆ EnumerateLayers()

GenericVector<STRING> tesseract::LSTMRecognizer::EnumerateLayers ( ) const
inline

Definition at line 86 of file lstmrecognizer.h.

86  {
87  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
88  Series* series = static_cast<Series*>(network_);
89  GenericVector<STRING> layers;
90  series->EnumerateLayers(NULL, &layers);
91  return layers;
92  }
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type() const
Definition: network.h:112

◆ GetDict()

const Dict* tesseract::LSTMRecognizer::GetDict ( ) const
inline

Definition at line 143 of file lstmrecognizer.h.

143 { return dict_; }

◆ GetLayer()

Network* tesseract::LSTMRecognizer::GetLayer ( const STRING &  id) const
inline

Definition at line 94 of file lstmrecognizer.h.

94  {
95  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
96  ASSERT_HOST(id.length() > 1 && id[0] == ':');
97  Series* series = static_cast<Series*>(network_);
98  return series->GetLayer(&id[1]);
99  }
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type() const
Definition: network.h:112

◆ GetLayerLearningRate()

float tesseract::LSTMRecognizer::GetLayerLearningRate ( const STRING &  id) const
inline

Definition at line 101 of file lstmrecognizer.h.

101  {
102  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
104  ASSERT_HOST(id.length() > 1 && id[0] == ':');
105  Series* series = static_cast<Series*>(network_);
106  return series->LayerLearningRate(&id[1]);
107  } else {
108  return learning_rate_;
109  }
110  }
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type() const
Definition: network.h:112

◆ GetRecoder()

const UnicharCompress& tesseract::LSTMRecognizer::GetRecoder ( ) const
inline

Definition at line 141 of file lstmrecognizer.h.

141 { return recoder_; }

◆ GetUnicharset()

const UNICHARSET& tesseract::LSTMRecognizer::GetUnicharset ( ) const
inline

Definition at line 139 of file lstmrecognizer.h.

139 { return ccutil_.unicharset; }
UNICHARSET unicharset
Definition: ccutil.h:68

◆ IsIntMode()

bool tesseract::LSTMRecognizer::IsIntMode ( ) const
inline

Definition at line 77 of file lstmrecognizer.h.

◆ IsRecoding()

bool tesseract::LSTMRecognizer::IsRecoding ( ) const
inline

Definition at line 79 of file lstmrecognizer.h.

◆ IsTensorFlow()

bool tesseract::LSTMRecognizer::IsTensorFlow ( ) const
inline

Definition at line 83 of file lstmrecognizer.h.

83 { return network_->type() == NT_TENSORFLOW; }
NetworkType type() const
Definition: network.h:112

◆ LabelsFromOutputs()

void tesseract::LSTMRecognizer::LabelsFromOutputs ( const NetworkIO &  outputs,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)

Definition at line 424 of file lstmrecognizer.cpp.

426  {
427  if (SimpleTextOutput()) {
428  LabelsViaSimpleText(outputs, labels, xcoords);
429  } else {
430  LabelsViaReEncode(outputs, labels, xcoords);
431  }
432 }
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)

◆ LabelsViaReEncode()

void tesseract::LSTMRecognizer::LabelsViaReEncode ( const NetworkIO &  output,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)
protected

Definition at line 436 of file lstmrecognizer.cpp.

438  {
439  if (search_ == NULL) {
440  search_ =
441  new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
442  }
443  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL);
444  search_->ExtractBestPathAsLabels(labels, xcoords);
445 }
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
Definition: recodebeam.cpp:76
RecodeBeamSearch * search_
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:100
static const float kMinCertainty
Definition: recodebeam.h:213

◆ LabelsViaSimpleText()

void tesseract::LSTMRecognizer::LabelsViaSimpleText ( const NetworkIO &  output,
GenericVector< int > *  labels,
GenericVector< int > *  xcoords 
)
protected

Definition at line 450 of file lstmrecognizer.cpp.

452  {
453  labels->truncate(0);
454  xcoords->truncate(0);
455  int width = output.Width();
456  for (int t = 0; t < width; ++t) {
457  float score = 0.0f;
458  int label = output.BestLabel(t, &score);
459  if (label != null_char_) {
460  labels->push_back(label);
461  xcoords->push_back(t);
462  }
463  }
464  xcoords->push_back(width);
465 }
void truncate(int size)
int push_back(T object)

◆ learning_rate()

double tesseract::LSTMRecognizer::learning_rate ( ) const
inline

Definition at line 67 of file lstmrecognizer.h.

67  {
68  return learning_rate_;
69  }

◆ Load()

bool tesseract::LSTMRecognizer::Load ( const char *  lang,
TessdataManager *  mgr 
)

Definition at line 72 of file lstmrecognizer.cpp.

72  {
73  TFile fp;
74  if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) return false;
75  if (!DeSerialize(mgr, &fp)) return false;
76  if (lang == nullptr) return true;
77  // Allow it to run without a dictionary.
78  LoadDictionary(lang, mgr);
79  return true;
80 }
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const char *lang, TessdataManager *mgr)

◆ LoadCharsets()

bool tesseract::LSTMRecognizer::LoadCharsets ( const TessdataManager *  mgr)

Definition at line 135 of file lstmrecognizer.cpp.

135  {
136  TFile fp;
137  if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
138  if (!ccutil_.unicharset.load_from_file(&fp, false)) return false;
139  if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
140  if (!LoadRecoder(&fp)) return false;
141  return true;
142 }
UNICHARSET unicharset
Definition: ccutil.h:68
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:387

◆ LoadDictionary()

bool tesseract::LSTMRecognizer::LoadDictionary ( const char *  lang,
TessdataManager *  mgr 
)

Definition at line 168 of file lstmrecognizer.cpp.

168  {
169  delete dict_;
170  dict_ = new Dict(&ccutil_);
172  dict_->LoadLSTM(lang, mgr);
173  if (dict_->FinishLoad()) return true; // Success.
174  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
175  lang);
176  delete dict_;
177  dict_ = NULL;
178  return false;
179 }
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:206
bool FinishLoad()
Definition: dict.cpp:328
#define tprintf(...)
Definition: tprintf.h:31
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:198
void LoadLSTM(const STRING &lang, TessdataManager *data_file)
Definition: dict.cpp:307

◆ LoadRecoder()

bool tesseract::LSTMRecognizer::LoadRecoder ( TFile *  fp)

Definition at line 145 of file lstmrecognizer.cpp.

145  {
146  if (IsRecoding()) {
147  if (!recoder_.DeSerialize(fp)) return false;
148  RecodedCharID code;
150  if (code(0) != UNICHAR_SPACE) {
151  tprintf("Space was garbled in recoding!!\n");
152  return false;
153  }
154  } else {
157  }
158  return true;
159 }
#define tprintf(...)
Definition: tprintf.h:31
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
void SetupPassThrough(const UNICHARSET &unicharset)
const UNICHARSET & GetUnicharset() const

◆ null_char()

int tesseract::LSTMRecognizer::null_char ( ) const
inline

Definition at line 154 of file lstmrecognizer.h.

154 { return null_char_; }

◆ NumInputs()

int tesseract::LSTMRecognizer::NumInputs ( ) const
inline

Definition at line 151 of file lstmrecognizer.h.

151  {
152  return network_->NumInputs();
153  }
int NumInputs() const
Definition: network.h:120

◆ NumOutputs()

int tesseract::LSTMRecognizer::NumOutputs ( ) const
inline

Definition at line 58 of file lstmrecognizer.h.

58  {
59  return network_->NumOutputs();
60  }
int NumOutputs() const
Definition: network.h:123

◆ OutputLossType()

LossType tesseract::LSTMRecognizer::OutputLossType ( ) const
inline

Definition at line 70 of file lstmrecognizer.h.

70  {
71  if (network_ == nullptr) return LT_NONE;
72  StaticShape shape;
73  shape = network_->OutputShape(shape);
74  return shape.loss_type();
75  }
LossType loss_type() const
Definition: static_shape.h:48
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133

◆ OutputStats()

void tesseract::LSTMRecognizer::OutputStats ( const NetworkIO &  outputs,
float *  min_output,
float *  mean_output,
float *  sd 
)

Definition at line 203 of file lstmrecognizer.cpp.

204  {
205  const int kOutputScale = MAX_INT8;
206  STATS stats(0, kOutputScale + 1);
207  for (int t = 0; t < outputs.Width(); ++t) {
208  int best_label = outputs.BestLabel(t, NULL);
209  if (best_label != null_char_) {
210  float best_output = outputs.f(t)[best_label];
211  stats.add(static_cast<int>(kOutputScale * best_output), 1);
212  }
213  }
214  // If the output is all nulls it could be that the photometric interpretation
215  // is wrong, so make it look bad, so the other way can win, even if not great.
216  if (stats.get_total() == 0) {
217  *min_output = 0.0f;
218  *mean_output = 0.0f;
219  *sd = 1.0f;
220  } else {
221  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
222  *mean_output = stats.mean() / kOutputScale;
223  *sd = stats.sd() / kOutputScale;
224  }
225 }
#define MAX_INT8
Definition: host.h:60
Definition: statistc.h:33

◆ RecognizeLine() [1/2]

void tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
double  worst_dict_cert,
const TBOX &  line_box,
PointerVector< WERD_RES > *  words 
)

Definition at line 183 of file lstmrecognizer.cpp.

186  {
187  NetworkIO outputs;
188  float scale_factor;
189  NetworkIO inputs;
190  if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor,
191  &inputs, &outputs))
192  return;
193  if (search_ == NULL) {
194  search_ =
195  new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
196  }
197  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
198  search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
199  &GetUnicharset(), words);
200 }
const double kCertOffset
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
Definition: recodebeam.cpp:76
RecodeBeamSearch * search_
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words)
const double kDictRatio
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words)
Definition: recodebeam.cpp:138
const UNICHARSET & GetUnicharset() const

◆ RecognizeLine() [2/2]

bool tesseract::LSTMRecognizer::RecognizeLine ( const ImageData &  image_data,
bool  invert,
bool  debug,
bool  re_invert,
bool  upside_down,
float *  scale_factor,
NetworkIO *  inputs,
NetworkIO *  outputs 
)

Definition at line 229 of file lstmrecognizer.cpp.

232  {
233  // Maximum width of image to train on.
234  const int kMaxImageWidth = 2560;
235  // This ensures consistent recognition results.
236  SetRandomSeed();
237  int min_width = network_->XScaleFactor();
238  Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
239  &randomizer_, scale_factor);
240  if (pix == NULL) {
241  tprintf("Line cannot be recognized!!\n");
242  return false;
243  }
244  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
245  tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
246  pixGetHeight(pix));
247  pixDestroy(&pix);
248  return false;
249  }
250  if (upside_down) pixRotate180(pix, pix);
251  // Reduction factor from image to coords.
252  *scale_factor = min_width / *scale_factor;
253  inputs->set_int_mode(IsIntMode());
254  SetRandomSeed();
256  network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
257  // Check for auto inversion.
258  float pos_min, pos_mean, pos_sd;
259  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
260  if (invert && pos_min < 0.5) {
261  // Run again inverted and see if it is any better.
262  NetworkIO inv_inputs, inv_outputs;
263  inv_inputs.set_int_mode(IsIntMode());
264  SetRandomSeed();
265  pixInvert(pix, pix);
267  &inv_inputs);
268  network_->Forward(debug, inv_inputs, NULL, &scratch_space_, &inv_outputs);
269  float inv_min, inv_mean, inv_sd;
270  OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
271  if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
272  // Inverted did better. Use inverted data.
273  if (debug) {
274  tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
275  pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
276  }
277  *outputs = inv_outputs;
278  *inputs = inv_inputs;
279  } else if (re_invert) {
280  // Inverting was not an improvement, so undo and run again, so the
281  // outputs match the best forward result.
282  SetRandomSeed();
283  network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
284  }
285  }
286  pixDestroy(&pix);
287  if (debug) {
288  GenericVector<int> labels, coords;
289  LabelsFromOutputs(*outputs, &labels, &coords);
290  DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
291  DebugActivationPath(*outputs, labels, coords);
292  }
293  return true;
294 }
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:117
virtual int XScaleFactor() const
Definition: network.h:209
#define tprintf(...)
Definition: tprintf.h:31
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: network.h:262
virtual StaticShape InputShape() const
Definition: network.h:127
bool IsTraining() const
Definition: network.h:115
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:89
NetworkScratch scratch_space_

◆ sample_iteration()

int tesseract::LSTMRecognizer::sample_iteration ( ) const
inline

Definition at line 64 of file lstmrecognizer.h.

64  {
65  return sample_iteration_;
66  }

◆ ScaleLayerLearningRate()

void tesseract::LSTMRecognizer::ScaleLayerLearningRate ( const STRING &  id,
double  factor 
)
inline

Definition at line 123 of file lstmrecognizer.h.

123  {
124  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
125  ASSERT_HOST(id.length() > 1 && id[0] == ':');
126  Series* series = static_cast<Series*>(network_);
127  series->ScaleLayerLearningRate(&id[1], factor);
128  }
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type() const
Definition: network.h:112

◆ ScaleLearningRate()

void tesseract::LSTMRecognizer::ScaleLearningRate ( double  factor)
inline

Definition at line 112 of file lstmrecognizer.h.

112  {
113  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
114  learning_rate_ *= factor;
117  for (int i = 0; i < layers.size(); ++i) {
118  ScaleLayerLearningRate(layers[i], factor);
119  }
120  }
121  }
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int size() const
Definition: genericvector.h:72
void ScaleLayerLearningRate(const STRING &id, double factor)
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type() const
Definition: network.h:112
GenericVector< STRING > EnumerateLayers() const

◆ Serialize()

bool tesseract::LSTMRecognizer::Serialize ( const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 83 of file lstmrecognizer.cpp.

83  {
84  bool include_charsets = mgr == nullptr ||
85  !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
86  !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
87  if (!network_->Serialize(fp)) return false;
88  if (include_charsets && !GetUnicharset().save_to_file(fp)) return false;
89  if (!network_str_.Serialize(fp)) return false;
90  if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1)
91  return false;
92  if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1)
93  return false;
94  if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
95  return false;
96  if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
97  if (fp->FWrite(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
98  if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
99  if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
100  if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false;
101  return true;
102 }
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
bool save_to_file(const char *const filename) const
Definition: unicharset.h:347
bool Serialize(TFile *fp) const
const UNICHARSET & GetUnicharset() const
bool Serialize(FILE *fp) const
Definition: strngs.cpp:148

◆ SetIteration()

void tesseract::LSTMRecognizer::SetIteration ( int  iteration)
inline

Definition at line 147 of file lstmrecognizer.h.

147  {
148  sample_iteration_ = iteration;
149  }

◆ SetRandomSeed()

void tesseract::LSTMRecognizer::SetRandomSeed ( )
inline protected

Definition at line 225 of file lstmrecognizer.h.

225  {
226  inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
227  randomizer_.set_seed(seed);
229  }
void set_seed(uinT64 seed)
Definition: helpers.h:45
inT32 IntRand()
Definition: helpers.h:55
int64_t inT64
Definition: host.h:40

◆ SimpleTextOutput()

bool tesseract::LSTMRecognizer::SimpleTextOutput ( ) const
inline

Definition at line 76 of file lstmrecognizer.h.

◆ training_iteration()

int tesseract::LSTMRecognizer::training_iteration ( ) const
inline

Definition at line 61 of file lstmrecognizer.h.

61  {
62  return training_iteration_;
63  }

Member Data Documentation

◆ adam_beta_

float tesseract::LSTMRecognizer::adam_beta_
protected

Definition at line 295 of file lstmrecognizer.h.

◆ ccutil_

CCUtil tesseract::LSTMRecognizer::ccutil_
protected

Definition at line 273 of file lstmrecognizer.h.

◆ debug_win_

ScrollView* tesseract::LSTMRecognizer::debug_win_
protected

Definition at line 307 of file lstmrecognizer.h.

◆ dict_

Dict* tesseract::LSTMRecognizer::dict_
protected

Definition at line 301 of file lstmrecognizer.h.

◆ learning_rate_

float tesseract::LSTMRecognizer::learning_rate_
protected

Definition at line 292 of file lstmrecognizer.h.

◆ momentum_

float tesseract::LSTMRecognizer::momentum_
protected

Definition at line 293 of file lstmrecognizer.h.

◆ network_

Network* tesseract::LSTMRecognizer::network_
protected

Definition at line 270 of file lstmrecognizer.h.

◆ network_str_

STRING tesseract::LSTMRecognizer::network_str_
protected

Definition at line 280 of file lstmrecognizer.h.

◆ null_char_

inT32 tesseract::LSTMRecognizer::null_char_
protected

Definition at line 290 of file lstmrecognizer.h.

◆ randomizer_

TRand tesseract::LSTMRecognizer::randomizer_
protected

Definition at line 298 of file lstmrecognizer.h.

◆ recoder_

UnicharCompress tesseract::LSTMRecognizer::recoder_
protected

Definition at line 277 of file lstmrecognizer.h.

◆ sample_iteration_

inT32 tesseract::LSTMRecognizer::sample_iteration_
protected

Definition at line 287 of file lstmrecognizer.h.

◆ scratch_space_

NetworkScratch tesseract::LSTMRecognizer::scratch_space_
protected

Definition at line 299 of file lstmrecognizer.h.

◆ search_

RecodeBeamSearch* tesseract::LSTMRecognizer::search_
protected

Definition at line 303 of file lstmrecognizer.h.

◆ training_flags_

inT32 tesseract::LSTMRecognizer::training_flags_
protected

Definition at line 283 of file lstmrecognizer.h.

◆ training_iteration_

inT32 tesseract::LSTMRecognizer::training_iteration_
protected

Definition at line 285 of file lstmrecognizer.h.


The documentation for this class was generated from the following files: