tesseract  4.00.00dev
tesseract::LSTMTrainer Class Reference

#include <lstmtrainer.h>

Inheritance diagram for tesseract::LSTMTrainer:
tesseract::LSTMRecognizer

Public Member Functions

 LSTMTrainer ()
 
 LSTMTrainer (FileReader file_reader, FileWriter file_writer, CheckPointReader checkpoint_reader, CheckPointWriter checkpoint_writer, const char *model_base, const char *checkpoint_name, int debug_interval, inT64 max_memory)
 
virtual ~LSTMTrainer ()
 
bool TryLoadingCheckpoint (const char *filename, const char *old_traineddata)
 
void InitCharSet (const string &traineddata_path)
 
void InitCharSet (const TessdataManager &mgr)
 
bool InitNetwork (const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
 
int InitTensorFlowNetwork (const std::string &tf_proto)
 
void InitIterations ()
 
double ActivationError () const
 
double CharError () const
 
const double * error_rates () const
 
double best_error_rate () const
 
int best_iteration () const
 
int learning_iteration () const
 
int improvement_steps () const
 
void set_perfect_delay (int delay)
 
const GenericVector< char > & best_trainer () const
 
double NewSingleError (ErrorTypes type) const
 
double LastSingleError (ErrorTypes type) const
 
const DocumentCache & training_data () const
 
DocumentCache * mutable_training_data ()
 
Trainability GridSearchDictParams (const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
 
void DebugNetwork ()
 
bool LoadAllTrainingData (const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
 
bool MaintainCheckpoints (TestCallback tester, STRING *log_msg)
 
bool MaintainCheckpointsSpecific (int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
 
void PrepareLogMsg (STRING *log_msg) const
 
void LogIterations (const char *intro_str, STRING *log_msg) const
 
bool TransitionTrainingStage (float error_threshold)
 
int CurrentTrainingStage () const
 
virtual bool Serialize (SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
 
virtual bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
void StartSubtrainer (STRING *log_msg)
 
SubTrainerResult UpdateSubtrainer (STRING *log_msg)
 
void ReduceLearningRates (LSTMTrainer *samples_trainer, STRING *log_msg)
 
int ReduceLayerLearningRates (double factor, int num_samples, LSTMTrainer *samples_trainer)
 
bool EncodeString (const STRING &str, GenericVector< int > *labels) const
 
const ImageData * TrainOnLine (LSTMTrainer *samples_trainer, bool batch)
 
Trainability TrainOnLine (const ImageData *trainingdata, bool batch)
 
Trainability PrepareForBackward (const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
 
bool SaveTrainingDump (SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
 
bool ReadTrainingDump (const GenericVector< char > &data, LSTMTrainer *trainer) const
 
bool ReadSizedTrainingDump (const char *data, int size, LSTMTrainer *trainer) const
 
bool ReadLocalTrainingDump (const TessdataManager *mgr, const char *data, int size)
 
void SetupCheckpointInfo ()
 
bool SaveTraineddata (const STRING &filename)
 
void SaveRecognitionDump (GenericVector< char > *data) const
 
STRING DumpFilename () const
 
void FillErrorBuffer (double new_error, ErrorTypes type)
 
std::vector< int > MapRecoder (const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
 
- Public Member Functions inherited from tesseract::LSTMRecognizer
 LSTMRecognizer ()
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
double learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
GenericVector< STRING > EnumerateLayers () const
 
Network * GetLayer (const STRING &id) const
 
float GetLayerLearningRate (const STRING &id) const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const STRING &id, double factor)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const char *lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const char *lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
STRING DecodeLabels (const GenericVector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
 

Static Public Member Functions

static bool EncodeString (const STRING &str, const UNICHARSET &unicharset, const UnicharCompress *recoder, bool simple_text, int null_char, GenericVector< int > *labels)
 

Protected Member Functions

void InitCharSet ()
 
void SetNullChar ()
 
void EmptyConstructor ()
 
bool DebugLSTMTraining (const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
 
void DisplayTargets (const NetworkIO &targets, const char *window_name, ScrollView **window)
 
bool ComputeTextTargets (const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
 
bool ComputeCTCTargets (const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
 
double ComputeErrorRates (const NetworkIO &deltas, double char_error, double word_error)
 
double ComputeRMSError (const NetworkIO &deltas)
 
double ComputeWinnerError (const NetworkIO &deltas)
 
double ComputeCharError (const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
 
double ComputeWordError (STRING *truth_str, STRING *ocr_str)
 
void UpdateErrorBuffer (double new_error, ErrorTypes type)
 
void RollErrorBuffers ()
 
STRING UpdateErrorGraph (int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
 
- Protected Member Functions inherited from tesseract::LSTMRecognizer
void SetRandomSeed ()
 
void DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
const char * DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

ScrollView * align_win_
 
ScrollView * target_win_
 
ScrollView * ctc_win_
 
ScrollView * recon_win_
 
int debug_interval_
 
int checkpoint_iteration_
 
STRING model_base_
 
STRING checkpoint_name_
 
bool randomly_rotate_
 
DocumentCache training_data_
 
STRING best_model_name_
 
int num_training_stages_
 
FileReader file_reader_
 
FileWriter file_writer_
 
CheckPointReader checkpoint_reader_
 
CheckPointWriter checkpoint_writer_
 
double best_error_rate_
 
double best_error_rates_ [ET_COUNT]
 
int best_iteration_
 
double worst_error_rate_
 
double worst_error_rates_ [ET_COUNT]
 
int worst_iteration_
 
int stall_iteration_
 
GenericVector< char > best_model_data_
 
GenericVector< char > worst_model_data_
 
GenericVector< char > best_trainer_
 
LSTMTrainer * sub_trainer_
 
float error_rate_of_last_saved_best_
 
int training_stage_
 
GenericVector< double > best_error_history_
 
GenericVector< int > best_error_iterations_
 
int improvement_steps_
 
int learning_iteration_
 
int prev_sample_iteration_
 
int perfect_delay_
 
int last_perfect_training_iteration_
 
GenericVector< double > error_buffers_ [ET_COUNT]
 
double error_rates_ [ET_COUNT]
 
TessdataManager mgr_
 
- Protected Attributes inherited from tesseract::LSTMRecognizer
Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
STRING network_str_
 
inT32 training_flags_
 
inT32 training_iteration_
 
inT32 sample_iteration_
 
inT32 null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Static Protected Attributes

static const int kRollingBufferSize_ = 1000
 

Detailed Description

Definition at line 89 of file lstmtrainer.h.

Constructor & Destructor Documentation

◆ LSTMTrainer() [1/2]

tesseract::LSTMTrainer::LSTMTrainer ( )

Definition at line 73 of file lstmtrainer.cpp.

74  : randomly_rotate_(false),
75  training_data_(0),
82  sub_trainer_(NULL) {
84  debug_interval_ = 0;
85 }
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
DocumentCache training_data_
Definition: lstmtrainer.h:414
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291

◆ LSTMTrainer() [2/2]

tesseract::LSTMTrainer::LSTMTrainer ( FileReader  file_reader,
FileWriter  file_writer,
CheckPointReader  checkpoint_reader,
CheckPointWriter  checkpoint_writer,
const char *  model_base,
const char *  checkpoint_name,
int  debug_interval,
inT64  max_memory 
)

Definition at line 87 of file lstmtrainer.cpp.

92  : randomly_rotate_(false),
93  training_data_(max_memory),
94  file_reader_(file_reader),
95  file_writer_(file_writer),
96  checkpoint_reader_(checkpoint_reader),
97  checkpoint_writer_(checkpoint_writer),
98  sub_trainer_(NULL),
99  mgr_(file_reader) {
103  if (checkpoint_reader_ == NULL) {
106  }
107  if (checkpoint_writer_ == NULL) {
110  }
111  debug_interval_ = debug_interval;
112  model_base_ = model_base;
113  checkpoint_name_ = checkpoint_name;
114 }
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
TessdataManager mgr_
Definition: lstmtrainer.h:483
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
DocumentCache training_data_
Definition: lstmtrainer.h:414
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291

◆ ~LSTMTrainer()

tesseract::LSTMTrainer::~LSTMTrainer ( )
virtual

Definition at line 116 of file lstmtrainer.cpp.

116  {
117  delete align_win_;
118  delete target_win_;
119  delete ctc_win_;
120  delete recon_win_;
121  delete checkpoint_reader_;
122  delete checkpoint_writer_;
123  delete sub_trainer_;
124 }
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
ScrollView * target_win_
Definition: lstmtrainer.h:399
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
ScrollView * recon_win_
Definition: lstmtrainer.h:403
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
ScrollView * align_win_
Definition: lstmtrainer.h:397

Member Function Documentation

◆ ActivationError()

double tesseract::LSTMTrainer::ActivationError ( ) const
inline

Definition at line 136 of file lstmtrainer.h.

136  {
137  return error_rates_[ET_DELTA];
138  }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481

◆ best_error_rate()

double tesseract::LSTMTrainer::best_error_rate ( ) const
inline

Definition at line 143 of file lstmtrainer.h.

143  {
144  return best_error_rate_;
145  }

◆ best_iteration()

int tesseract::LSTMTrainer::best_iteration ( ) const
inline

Definition at line 146 of file lstmtrainer.h.

146  {
147  return best_iteration_;
148  }

◆ best_trainer()

const GenericVector<char>& tesseract::LSTMTrainer::best_trainer ( ) const
inline

Definition at line 152 of file lstmtrainer.h.

152 { return best_trainer_; }
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447

◆ CharError()

double tesseract::LSTMTrainer::CharError ( ) const
inline

Definition at line 139 of file lstmtrainer.h.

◆ ComputeCharError()

double tesseract::LSTMTrainer::ComputeCharError ( const GenericVector< int > &  truth_str,
const GenericVector< int > &  ocr_str 
)
protected

Definition at line 1217 of file lstmtrainer.cpp.

1218  {
1219  GenericVector<int> label_counts;
1220  label_counts.init_to_size(NumOutputs(), 0);
1221  int truth_size = 0;
1222  for (int i = 0; i < truth_str.size(); ++i) {
1223  if (truth_str[i] != null_char_) {
1224  ++label_counts[truth_str[i]];
1225  ++truth_size;
1226  }
1227  }
1228  for (int i = 0; i < ocr_str.size(); ++i) {
1229  if (ocr_str[i] != null_char_) {
1230  --label_counts[ocr_str[i]];
1231  }
1232  }
1233  int char_errors = 0;
1234  for (int i = 0; i < label_counts.size(); ++i) {
1235  char_errors += abs(label_counts[i]);
1236  }
1237  if (truth_size == 0) {
1238  return (char_errors == 0) ? 0.0 : 1.0;
1239  }
1240  return static_cast<double>(char_errors) / truth_size;
1241 }
int size() const
Definition: genericvector.h:72
void init_to_size(int size, T t)

◆ ComputeCTCTargets()

bool tesseract::LSTMTrainer::ComputeCTCTargets ( const GenericVector< int > &  truth_labels,
NetworkIO outputs,
NetworkIO targets 
)
protected

Definition at line 1149 of file lstmtrainer.cpp.

1150  {
1151  // Bottom-clip outputs to a minimum probability.
1152  CTC::NormalizeProbs(outputs);
1153  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1154  outputs->float_array(), targets);
1155 }
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53

◆ ComputeErrorRates()

double tesseract::LSTMTrainer::ComputeErrorRates ( const NetworkIO deltas,
double  char_error,
double  word_error 
)
protected

Definition at line 1160 of file lstmtrainer.cpp.

1161  {
1163  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1164  // score. If zero, then the top choice characters are guaranteed correct,
1165  // even when there is residue in the RMS error.
1166  double delta_error = ComputeWinnerError(deltas);
1167  UpdateErrorBuffer(delta_error, ET_DELTA);
1168  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1169  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1170  // Skip ratio measures the difference between sample_iteration_ and
1171  // training_iteration_, which reflects the number of unusable samples,
1172  // usually due to unencodable truth text, or the text not fitting in the
1173  // space for the output.
1174  double skip_count = sample_iteration_ - prev_sample_iteration_;
1175  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1176  return delta_error;
1177 }
double ComputeRMSError(const NetworkIO &deltas)
double ComputeWinnerError(const NetworkIO &deltas)
void UpdateErrorBuffer(double new_error, ErrorTypes type)

◆ ComputeRMSError()

double tesseract::LSTMTrainer::ComputeRMSError ( const NetworkIO deltas)
protected

Definition at line 1180 of file lstmtrainer.cpp.

1180  {
1181  double total_error = 0.0;
1182  int width = deltas.Width();
1183  int num_classes = deltas.NumFeatures();
1184  for (int t = 0; t < width; ++t) {
1185  const float* class_errs = deltas.f(t);
1186  for (int c = 0; c < num_classes; ++c) {
1187  double error = class_errs[c];
1188  total_error += error * error;
1189  }
1190  }
1191  return sqrt(total_error / (width * num_classes));
1192 }

◆ ComputeTextTargets()

bool tesseract::LSTMTrainer::ComputeTextTargets ( const NetworkIO outputs,
const GenericVector< int > &  truth_labels,
NetworkIO targets 
)
protected

Definition at line 1129 of file lstmtrainer.cpp.

1131  {
1132  if (truth_labels.size() > targets->Width()) {
1133  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1134  DecodeLabels(truth_labels).string(), targets->Width());
1135  return false;
1136  }
1137  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1138  targets->SetActivations(i, truth_labels[i], 1.0);
1139  }
1140  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1141  targets->SetActivations(i, null_char_, 1.0);
1142  }
1143  return true;
1144 }
int size() const
Definition: genericvector.h:72
#define tprintf(...)
Definition: tprintf.h:31
STRING DecodeLabels(const GenericVector< int > &labels)

◆ ComputeWinnerError()

double tesseract::LSTMTrainer::ComputeWinnerError ( const NetworkIO deltas)
protected

Definition at line 1199 of file lstmtrainer.cpp.

1199  {
1200  int num_errors = 0;
1201  int width = deltas.Width();
1202  int num_classes = deltas.NumFeatures();
1203  for (int t = 0; t < width; ++t) {
1204  const float* class_errs = deltas.f(t);
1205  for (int c = 0; c < num_classes; ++c) {
1206  float abs_delta = fabs(class_errs[c]);
1207  // TODO(rays) Filtering cases where the delta is very large to cut out
1208  // GT errors doesn't work. Find a better way or get better truth.
1209  if (0.5 <= abs_delta)
1210  ++num_errors;
1211  }
1212  }
1213  return static_cast<double>(num_errors) / width;
1214 }

◆ ComputeWordError()

double tesseract::LSTMTrainer::ComputeWordError ( STRING truth_str,
STRING ocr_str 
)
protected

Definition at line 1245 of file lstmtrainer.cpp.

1245  {
1246  typedef std::unordered_map<std::string, int, std::hash<std::string> > StrMap;
1247  GenericVector<STRING> truth_words, ocr_words;
1248  truth_str->split(' ', &truth_words);
1249  if (truth_words.empty()) return 0.0;
1250  ocr_str->split(' ', &ocr_words);
1251  StrMap word_counts;
1252  for (int i = 0; i < truth_words.size(); ++i) {
1253  std::string truth_word(truth_words[i].string());
1254  StrMap::iterator it = word_counts.find(truth_word);
1255  if (it == word_counts.end())
1256  word_counts.insert(std::make_pair(truth_word, 1));
1257  else
1258  ++it->second;
1259  }
1260  for (int i = 0; i < ocr_words.size(); ++i) {
1261  std::string ocr_word(ocr_words[i].string());
1262  StrMap::iterator it = word_counts.find(ocr_word);
1263  if (it == word_counts.end())
1264  word_counts.insert(std::make_pair(ocr_word, -1));
1265  else
1266  --it->second;
1267  }
1268  int word_recall_errs = 0;
1269  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1270  ++it) {
1271  if (it->second > 0) word_recall_errs += it->second;
1272  }
1273  return static_cast<double>(word_recall_errs) / truth_words.size();
1274 }
bool empty() const
Definition: genericvector.h:91
int size() const
Definition: genericvector.h:72
void split(const char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:286

◆ CurrentTrainingStage()

int tesseract::LSTMTrainer::CurrentTrainingStage ( ) const
inline

Definition at line 211 of file lstmtrainer.h.

211 { return training_stage_; }

◆ DebugLSTMTraining()

bool tesseract::LSTMTrainer::DebugLSTMTraining ( const NetworkIO inputs,
const ImageData trainingdata,
const NetworkIO fwd_outputs,
const GenericVector< int > &  truth_labels,
const NetworkIO outputs 
)
protected

Definition at line 1059 of file lstmtrainer.cpp.

1063  {
1064  const STRING& truth_text = DecodeLabels(truth_labels);
1065  if (truth_text.string() == NULL || truth_text.length() <= 0) {
1066  tprintf("Empty truth string at decode time!\n");
1067  return false;
1068  }
1069  if (debug_interval_ != 0) {
1070  // Get class labels, xcoords and string.
1071  GenericVector<int> labels;
1072  GenericVector<int> xcoords;
1073  LabelsFromOutputs(outputs, &labels, &xcoords);
1074  STRING text = DecodeLabels(labels);
1075  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1076  training_iteration(), text.string());
1077  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1078  tprintf("TRAINING activation path for truth string %s\n",
1079  truth_text.string());
1080  DebugActivationPath(outputs, labels, xcoords);
1081  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1082  if (OutputLossType() == LT_CTC) {
1083  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1084  DisplayTargets(outputs, "CTC Targets", &target_win_);
1085  }
1086  }
1087  }
1088  return true;
1089 }
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
ScrollView * target_win_
Definition: lstmtrainer.h:399
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
#define tprintf(...)
Definition: tprintf.h:31
LossType OutputLossType() const
const char * string() const
Definition: strngs.cpp:198
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: strngs.h:45
ScrollView * align_win_
Definition: lstmtrainer.h:397
inT32 length() const
Definition: strngs.cpp:193

◆ DebugNetwork()

void tesseract::LSTMTrainer::DebugNetwork ( )

Definition at line 293 of file lstmtrainer.cpp.

293  {
295 }
virtual void DebugWeights()
Definition: network.h:218

◆ DeSerialize()

bool tesseract::LSTMTrainer::DeSerialize ( const TessdataManager mgr,
TFile fp 
)
virtual

Definition at line 483 of file lstmtrainer.cpp.

483  {
484  if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
485  if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
486  // Special case. If we successfully decoded the recognizer, but fail here
487  // then it means we were just given a recognizer, so issue a warning and
488  // allow it.
489  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
492  return true;
493  }
494  if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_),
495  1) != 1)
496  return false;
497  if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
498  return false;
499  if (fp->FReadEndian(&last_perfect_training_iteration_,
500  sizeof(last_perfect_training_iteration_), 1) != 1)
501  return false;
502  for (int i = 0; i < ET_COUNT; ++i) {
503  if (!error_buffers_[i].DeSerialize(fp)) return false;
504  }
505  if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
506  if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
507  return false;
508  uinT8 amount;
509  if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
510  if (amount == LIGHT) return true; // Don't read the rest.
511  if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
512  return false;
513  if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
514  return false;
515  if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
516  return false;
517  if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
518  return false;
519  if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
520  return false;
521  if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
522  return false;
523  if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
524  return false;
525  if (!best_model_data_.DeSerialize(fp)) return false;
526  if (!worst_model_data_.DeSerialize(fp)) return false;
527  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
528  GenericVector<char> sub_data;
529  if (!sub_data.DeSerialize(fp)) return false;
530  delete sub_trainer_;
531  if (sub_data.empty()) {
532  sub_trainer_ = NULL;
533  } else {
534  sub_trainer_ = new LSTMTrainer();
535  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
536  }
537  if (!best_error_history_.DeSerialize(fp)) return false;
538  if (!best_error_iterations_.DeSerialize(fp)) return false;
539  if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
540  return false;
541  return true;
542 }
bool empty() const
Definition: genericvector.h:91
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112
#define tprintf(...)
Definition: tprintf.h:31
uint8_t uinT8
Definition: host.h:35
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
bool DeSerialize(bool swap, FILE *fp)
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
virtual bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444

◆ DisplayTargets()

void tesseract::LSTMTrainer::DisplayTargets ( const NetworkIO targets,
const char *  window_name,
ScrollView **  window 
)
protected

Definition at line 1092 of file lstmtrainer.cpp.

1093  {
1094 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1095  int width = targets.Width();
1096  int num_features = targets.NumFeatures();
1097  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1098  window);
1099  for (int c = 0; c < num_features; ++c) {
1100  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1101  (*window)->Pen(static_cast<ScrollView::Color>(color));
1102  int start_t = -1;
1103  for (int t = 0; t < width; ++t) {
1104  double target = targets.f(t)[c];
1105  target *= kTargetYScale;
1106  if (target >= 1) {
1107  if (start_t < 0) {
1108  (*window)->SetCursor(t - 1, 0);
1109  start_t = t;
1110  }
1111  (*window)->DrawTo(t, target);
1112  } else if (start_t >= 0) {
1113  (*window)->DrawTo(t, 0);
1114  (*window)->DrawTo(start_t - 1, 0);
1115  start_t = -1;
1116  }
1117  }
1118  if (start_t >= 0) {
1119  (*window)->DrawTo(width, 0);
1120  (*window)->DrawTo(start_t - 1, 0);
1121  }
1122  }
1123  (*window)->Update();
1124 #endif // GRAPHICS_DISABLED
1125 }
const int kTargetXScale
Definition: lstmtrainer.cpp:70
const int kTargetYScale
Definition: lstmtrainer.cpp:71
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309

◆ DumpFilename()

STRING tesseract::LSTMTrainer::DumpFilename ( ) const

Definition at line 970 of file lstmtrainer.cpp.

970  {
973  filename.add_str_int("_", best_iteration_);
974  filename += ".checkpoint";
975  return filename;
976 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const char * string() const
Definition: strngs.cpp:198
Definition: strngs.h:45

◆ EmptyConstructor()

void tesseract::LSTMTrainer::EmptyConstructor ( )
protected

Definition at line 1044 of file lstmtrainer.cpp.

1044  {
1045  align_win_ = NULL;
1046  target_win_ = NULL;
1047  ctc_win_ = NULL;
1048  recon_win_ = NULL;
1050  training_stage_ = 0;
1052  InitIterations();
1053 }
ScrollView * target_win_
Definition: lstmtrainer.h:399
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
ScrollView * recon_win_
Definition: lstmtrainer.h:403
ScrollView * align_win_
Definition: lstmtrainer.h:397

◆ EncodeString() [1/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING str,
GenericVector< int > *  labels 
) const
inline

Definition at line 246 of file lstmtrainer.h.

246  {
247  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL,
248  SimpleTextOutput(), null_char_, labels);
249  }
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
const UNICHARSET & GetUnicharset() const

◆ EncodeString() [2/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING str,
const UNICHARSET unicharset,
const UnicharCompress recoder,
bool  simple_text,
int  null_char,
GenericVector< int > *  labels 
)
static

Definition at line 748 of file lstmtrainer.cpp.

750  {
751  if (str.string() == NULL || str.length() <= 0) {
752  tprintf("Empty truth string!\n");
753  return false;
754  }
755  int err_index;
756  GenericVector<int> internal_labels;
757  labels->truncate(0);
758  if (!simple_text) labels->push_back(null_char);
759  string cleaned = unicharset.CleanupString(str.string());
760  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, NULL,
761  &err_index)) {
762  bool success = true;
763  for (int i = 0; i < internal_labels.size(); ++i) {
764  if (recoder != NULL) {
765  // Re-encode labels via recoder.
766  RecodedCharID code;
767  int len = recoder->EncodeUnichar(internal_labels[i], &code);
768  if (len > 0) {
769  for (int j = 0; j < len; ++j) {
770  labels->push_back(code(j));
771  if (!simple_text) labels->push_back(null_char);
772  }
773  } else {
774  success = false;
775  err_index = 0;
776  break;
777  }
778  } else {
779  labels->push_back(internal_labels[i]);
780  if (!simple_text) labels->push_back(null_char);
781  }
782  }
783  if (success) return true;
784  }
785  tprintf("Encoding of string failed! Failure bytes:");
786  while (err_index < cleaned.size()) {
787  tprintf(" %x", cleaned[err_index++]);
788  }
789  tprintf("\n");
790  return false;
791 }
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:256
int size() const
Definition: genericvector.h:72
static string CleanupString(const char *utf8_str)
Definition: unicharset.h:241
#define tprintf(...)
Definition: tprintf.h:31
void truncate(int size)
const char * string() const
Definition: strngs.cpp:198
int push_back(T object)
inT32 length() const
Definition: strngs.cpp:193

◆ error_rates()

const double* tesseract::LSTMTrainer::error_rates ( ) const
inline

Definition at line 140 of file lstmtrainer.h.

140  {
141  return error_rates_;
142  }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481

◆ FillErrorBuffer()

void tesseract::LSTMTrainer::FillErrorBuffer ( double  new_error,
ErrorTypes  type 
)

Definition at line 979 of file lstmtrainer.cpp.

979  {
980  for (int i = 0; i < kRollingBufferSize_; ++i)
981  error_buffers_[type][i] = new_error;
982  error_rates_[type] = 100.0 * new_error;
983 }
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479

◆ GridSearchDictParams()

Trainability tesseract::LSTMTrainer::GridSearchDictParams ( const ImageData trainingdata,
int  iteration,
double  min_dict_ratio,
double  dict_ratio_step,
double  max_dict_ratio,
double  min_cert_offset,
double  cert_offset_step,
double  max_cert_offset,
STRING results 
)

Definition at line 243 of file lstmtrainer.cpp.

246  {
247  sample_iteration_ = iteration;
248  NetworkIO fwd_outputs, targets;
249  Trainability result =
250  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
251  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == NULL)
252  return result;
253 
254  // Encode/decode the truth to get the normalization.
255  GenericVector<int> truth_labels, ocr_labels, xcoords;
256  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
257  // NO-dict error.
258  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), NULL);
259  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
260  NULL);
261  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
262  STRING truth_text = DecodeLabels(truth_labels);
263  STRING ocr_text = DecodeLabels(ocr_labels);
264  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
265  results->add_str_double("0,0=", baseline_error);
266 
267  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
268  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
269  for (double c = min_cert_offset; c < max_cert_offset;
270  c += cert_offset_step) {
271  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, NULL);
272  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
273  truth_text = DecodeLabels(truth_labels);
274  ocr_text = DecodeLabels(ocr_labels);
275  // This is destructive on both strings.
276  double word_error = ComputeWordError(&truth_text, &ocr_text);
277  if ((r == min_dict_ratio && c == min_cert_offset) ||
278  !std::isfinite(word_error)) {
279  STRING t = DecodeLabels(truth_labels);
280  STRING o = DecodeLabels(ocr_labels);
281  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
282  t.string(), o.string(), word_error, truth_labels[0]);
283  }
284  results->add_str_double(" ", r);
285  results->add_str_double(",", c);
286  results->add_str_double("=", word_error);
287  }
288  }
289  return result;
290 }
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
#define tprintf(...)
Definition: tprintf.h:31
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const char * string() const
Definition: strngs.cpp:198
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:371
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: strngs.h:45
#define ASSERT_HOST(x)
Definition: errcode.h:84
static const float kMinCertainty
Definition: recodebeam.h:213
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)

◆ improvement_steps()

int tesseract::LSTMTrainer::improvement_steps ( ) const
inline

Definition at line 150 of file lstmtrainer.h.

150 { return improvement_steps_; }

◆ InitCharSet() [1/3]

void tesseract::LSTMTrainer::InitCharSet ( const string &  traineddata_path)
inline

Definition at line 109 of file lstmtrainer.h.

109  {
110  ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
111  InitCharSet();
112  }
TessdataManager mgr_
Definition: lstmtrainer.h:483
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool Init(const char *data_file_name)

◆ InitCharSet() [2/3]

void tesseract::LSTMTrainer::InitCharSet ( const TessdataManager mgr)
inline

Definition at line 113 of file lstmtrainer.h.

113  {
114  mgr_ = mgr;
115  InitCharSet();
116  }
TessdataManager mgr_
Definition: lstmtrainer.h:483

◆ InitCharSet() [3/3]

void tesseract::LSTMTrainer::InitCharSet ( )
protected

Definition at line 1022 of file lstmtrainer.cpp.

1022  {
1023  EmptyConstructor();
1025  // Initialize the unicharset and recoder.
1026  if (!LoadCharsets(&mgr_)) {
1027  ASSERT_HOST(
1028  "Must provide a traineddata containing lstm_unicharset and"
1029  " lstm_recoder!\n" != nullptr);
1030  }
1031  SetNullChar();
1032 }
TessdataManager mgr_
Definition: lstmtrainer.h:483
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool LoadCharsets(const TessdataManager *mgr)

◆ InitIterations()

void tesseract::LSTMTrainer::InitIterations ( )

Definition at line 218 of file lstmtrainer.cpp.

218  {
219  sample_iteration_ = 0;
223  best_error_rate_ = 100.0;
224  best_iteration_ = 0;
225  worst_error_rate_ = 0.0;
226  worst_iteration_ = 0;
229  perfect_delay_ = 0;
231  for (int i = 0; i < ET_COUNT; ++i) {
232  best_error_rates_[i] = 100.0;
233  worst_error_rates_[i] = 0.0;
235  error_rates_[i] = 100.0;
236  }
238 }
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
void init_to_size(int size, T t)

◆ InitNetwork()

bool tesseract::LSTMTrainer::InitNetwork ( const STRING network_spec,
int  append_index,
int  net_flags,
float  weight_range,
float  learning_rate,
float  momentum,
float  adam_beta 
)

Definition at line 171 of file lstmtrainer.cpp.

174  {
175  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.string());
176  adam_beta_ = adam_beta;
178  momentum_ = momentum;
179  SetNullChar();
180  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
181  append_index, net_flags, weight_range,
182  &randomizer_, &network_)) {
183  return false;
184  }
185  network_str_ += network_spec;
186  tprintf("Built network:%s from request %s\n",
187  network_->spec().string(), network_spec.string());
188  tprintf(
189  "Training parameters:\n Debug interval = %d,"
190  " weights = %g, learning rate = %g, momentum=%g\n",
191  debug_interval_, weight_range, learning_rate_, momentum_);
192  tprintf("null char=%d\n", null_char_);
193  return true;
194 }
virtual STRING spec() const
Definition: network.h:141
double learning_rate() const
TessdataManager mgr_
Definition: lstmtrainer.h:483
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
void SetVersionString(const string &v_str)

◆ InitTensorFlowNetwork()

int tesseract::LSTMTrainer::InitTensorFlowNetwork ( const std::string &  tf_proto)

Definition at line 198 of file lstmtrainer.cpp.

198  {
199 #ifdef INCLUDE_TENSORFLOW
200  delete network_;
201  TFNetwork* tf_net = new TFNetwork("TensorFlow");
202  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
203  if (training_iteration_ == 0) {
204  tprintf("InitFromProtoStr failed!!\n");
205  return 0;
206  }
207  network_ = tf_net;
208  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
209  return training_iteration_;
210 #else
211  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
212  return 0;
213 #endif
214 }
#define tprintf(...)
Definition: tprintf.h:31
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ LastSingleError()

double tesseract::LSTMTrainer::LastSingleError ( ErrorTypes  type) const
inline

Definition at line 160 of file lstmtrainer.h.

160  {
161  return error_buffers_[type]
164  }
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479

◆ learning_iteration()

int tesseract::LSTMTrainer::learning_iteration ( ) const
inline

Definition at line 149 of file lstmtrainer.h.

149 { return learning_iteration_; }

◆ LoadAllTrainingData()

bool tesseract::LSTMTrainer::LoadAllTrainingData ( const GenericVector< STRING > &  filenames,
CachingStrategy  cache_strategy,
bool  randomly_rotate 
)

Definition at line 300 of file lstmtrainer.cpp.

302  {
303  randomly_rotate_ = randomly_rotate;
305  return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
306 }
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:573
DocumentCache training_data_
Definition: lstmtrainer.h:414

◆ LogIterations()

void tesseract::LSTMTrainer::LogIterations ( const char *  intro_str,
STRING log_msg 
) const

Definition at line 412 of file lstmtrainer.cpp.

412  {
413  *log_msg += intro_str;
414  log_msg->add_str_int(" iteration ", learning_iteration());
415  log_msg->add_str_int("/", training_iteration());
416  log_msg->add_str_int("/", sample_iteration());
417 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
int learning_iteration() const
Definition: lstmtrainer.h:149

◆ MaintainCheckpoints()

bool tesseract::LSTMTrainer::MaintainCheckpoints ( TestCallback  tester,
STRING log_msg 
)

Definition at line 312 of file lstmtrainer.cpp.

312  {
313  PrepareLogMsg(log_msg);
314  double error_rate = CharError();
315  int iteration = learning_iteration();
316  if (iteration >= stall_iteration_ &&
317  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
319  // It hasn't got any better in a long while, and is a margin worse than the
320  // best, so go back to the best model and try a different learning rate.
321  StartSubtrainer(log_msg);
322  }
323  SubTrainerResult sub_trainer_result = STR_NONE;
324  if (sub_trainer_ != NULL) {
325  sub_trainer_result = UpdateSubtrainer(log_msg);
326  if (sub_trainer_result == STR_REPLACED) {
327  // Reset the inputs, as we have overwritten *this.
328  error_rate = CharError();
329  iteration = learning_iteration();
330  PrepareLogMsg(log_msg);
331  }
332  }
333  bool result = true; // Something interesting happened.
334  GenericVector<char> rec_model_data;
335  if (error_rate < best_error_rate_) {
336  SaveRecognitionDump(&rec_model_data);
337  log_msg->add_str_double(" New best char error = ", error_rate);
338  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
339  // If sub_trainer_ is not NULL, either *this beat it to a new best, or it
340  // just overwrote *this. In either case, we have finished with it.
341  delete sub_trainer_;
342  sub_trainer_ = NULL;
345  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
346  }
349  STRING best_model_name = DumpFilename();
350  if (!(*file_writer_)(best_trainer_, best_model_name)) {
351  *log_msg += " failed to write best model:";
352  } else {
353  *log_msg += " wrote best model:";
355  }
356  *log_msg += best_model_name;
357  }
358  } else if (error_rate > worst_error_rate_) {
359  SaveRecognitionDump(&rec_model_data);
360  log_msg->add_str_double(" New worst char error = ", error_rate);
361  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
364  // Error rate has ballooned. Go back to the best model.
365  *log_msg += "\nDivergence! ";
366  // Copy best_trainer_ before reading it, as it will get overwritten.
367  GenericVector<char> revert_data(best_trainer_);
368  if (checkpoint_reader_->Run(revert_data, this)) {
369  LogIterations("Reverted to", log_msg);
370  ReduceLearningRates(this, log_msg);
371  } else {
372  LogIterations("Failed to Revert at", log_msg);
373  }
374  // If it fails again, we will wait twice as long before reverting again.
375  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
376  // Re-save the best trainer with the new learning rates and stall
377  // iteration.
379  }
380  } else {
381  // Something interesting happened only if the sub_trainer_ was trained.
382  result = sub_trainer_result != STR_NONE;
383  }
384  if (checkpoint_writer_ != NULL && file_writer_ != NULL &&
385  checkpoint_name_.length() > 0) {
386  // Write a current checkpoint.
387  GenericVector<char> checkpoint;
388  if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
389  !(*file_writer_)(checkpoint, checkpoint_name_)) {
390  *log_msg += " failed to write checkpoint.";
391  } else {
392  *log_msg += " wrote checkpoint.";
393  }
394  }
395  *log_msg += "\n";
396  return result;
397 }
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
bool empty() const
Definition: genericvector.h:91
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
virtual R Run(A1, A2)=0
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
void SaveRecognitionDump(GenericVector< char > *data) const
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
void LogIterations(const char *intro_str, STRING *log_msg) const
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
void StartSubtrainer(STRING *log_msg)
double CharError() const
Definition: lstmtrainer.h:139
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
int learning_iteration() const
Definition: lstmtrainer.h:149
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:68
bool TransitionTrainingStage(float error_threshold)
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
Definition: strngs.h:45
STRING DumpFilename() const
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:45
virtual R Run(A1, A2, A3)=0
void PrepareLogMsg(STRING *log_msg) const
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:62
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
inT32 length() const
Definition: strngs.cpp:193

◆ MaintainCheckpointsSpecific()

bool tesseract::LSTMTrainer::MaintainCheckpointsSpecific ( int  iteration,
const GenericVector< char > *  train_model,
const GenericVector< char > *  rec_model,
TestCallback  tester,
STRING log_msg 
)

◆ MapRecoder()

std::vector< int > tesseract::LSTMTrainer::MapRecoder ( const UNICHARSET old_chset,
const UnicharCompress old_recoder 
) const

Definition at line 987 of file lstmtrainer.cpp.

988  {
989  int num_new_codes = recoder_.code_range();
990  int num_new_unichars = GetUnicharset().size();
991  std::vector<int> code_map(num_new_codes, -1);
992  for (int c = 0; c < num_new_codes; ++c) {
993  int old_code = -1;
994  // Find all new unichar_ids that recode to something that includes c.
995  // The <= is to include the null char, which may be beyond the unicharset.
996  for (int uid = 0; uid <= num_new_unichars; ++uid) {
997  RecodedCharID codes;
998  int length = recoder_.EncodeUnichar(uid, &codes);
999  int code_index = 0;
1000  while (code_index < length && codes(code_index) != c) ++code_index;
1001  if (code_index == length) continue;
1002  // The old unicharset must have the same unichar.
1003  int old_uid =
1004  uid < num_new_unichars
1005  ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
1006  : old_chset.size() - 1;
1007  if (old_uid == INVALID_UNICHAR_ID) continue;
1008  // The encoding of old_uid at the same code_index is the old code.
1009  RecodedCharID old_codes;
1010  if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
1011  old_code = old_codes(code_index);
1012  break;
1013  }
1014  }
1015  code_map[c] = old_code;
1016  }
1017  return code_map;
1018 }
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
int size() const
Definition: unicharset.h:338
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:207
const UNICHARSET & GetUnicharset() const

◆ mutable_training_data()

DocumentCache* tesseract::LSTMTrainer::mutable_training_data ( )
inline

Definition at line 168 of file lstmtrainer.h.

168 { return &training_data_; }
DocumentCache training_data_
Definition: lstmtrainer.h:414

◆ NewSingleError()

double tesseract::LSTMTrainer::NewSingleError ( ErrorTypes  type) const
inline

Definition at line 154 of file lstmtrainer.h.

154  {
156  }
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479

◆ PrepareForBackward()

Trainability tesseract::LSTMTrainer::PrepareForBackward ( const ImageData trainingdata,
NetworkIO fwd_outputs,
NetworkIO targets 
)

Definition at line 828 of file lstmtrainer.cpp.

830  {
831  if (trainingdata == NULL) {
832  tprintf("Null trainingdata.\n");
833  return UNENCODABLE;
834  }
835  // Ensure repeatability of random elements even across checkpoints.
836  bool debug = debug_interval_ > 0 &&
838  GenericVector<int> truth_labels;
839  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
840  tprintf("Can't encode transcription: '%s' in language '%s'\n",
841  trainingdata->transcription().string(),
842  trainingdata->language().string());
843  return UNENCODABLE;
844  }
845  bool upside_down = false;
846  if (randomly_rotate_) {
847  // This ensures consistent training results.
848  SetRandomSeed();
849  upside_down = randomizer_.SignedRand(1.0) > 0.0;
850  if (upside_down) {
851  // Modify the truth labels to match the rotation:
852  // Apart from space and null, increment the label. This changes the
853  // script-id to the same script-id but upside-down.
854  // The labels need to be reversed in order, as the first is now the last.
855  for (int c = 0; c < truth_labels.size(); ++c) {
856  if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
857  ++truth_labels[c];
858  }
859  truth_labels.reverse();
860  }
861  }
862  int w = 0;
863  while (w < truth_labels.size() &&
864  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
865  ++w;
866  if (w == truth_labels.size()) {
867  tprintf("Blank transcription: %s\n",
868  trainingdata->transcription().string());
869  return UNENCODABLE;
870  }
871  float image_scale;
872  NetworkIO inputs;
873  bool invert = trainingdata->boxes().empty();
874  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
875  &image_scale, &inputs, fwd_outputs)) {
876  tprintf("Image not trainable\n");
877  return UNENCODABLE;
878  }
879  targets->Resize(*fwd_outputs, network_->NumOutputs());
880  LossType loss_type = OutputLossType();
881  if (loss_type == LT_SOFTMAX) {
882  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
883  tprintf("Compute simple targets failed!\n");
884  return UNENCODABLE;
885  }
886  } else if (loss_type == LT_CTC) {
887  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
888  tprintf("Compute CTC targets failed!\n");
889  return UNENCODABLE;
890  }
891  } else {
892  tprintf("Logistic outputs not implemented yet!\n");
893  return UNENCODABLE;
894  }
895  GenericVector<int> ocr_labels;
896  GenericVector<int> xcoords;
897  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
898  // CTC does not produce correct target labels to begin with.
899  if (loss_type != LT_CTC) {
900  LabelsFromOutputs(*targets, &truth_labels, &xcoords);
901  }
902  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
903  *targets)) {
904  tprintf("Input width was %d\n", inputs.Width());
905  return UNENCODABLE;
906  }
907  STRING ocr_text = DecodeLabels(ocr_labels);
908  STRING truth_text = DecodeLabels(truth_labels);
909  targets->SubtractAllFromFloat(*fwd_outputs);
910  if (debug_interval_ != 0) {
911  tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
912  ocr_text.string());
913  }
914  double char_error = ComputeCharError(truth_labels, ocr_labels);
915  double word_error = ComputeWordError(&truth_text, &ocr_text);
916  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
917  if (debug_interval_ != 0) {
918  tprintf("File %s page %d %s:\n", trainingdata->imagefilename().string(),
919  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
920  }
921  if (delta_error == 0.0) return PERFECT;
922  if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR;
923  return TRAINABLE;
924 }
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
int NumOutputs() const
Definition: network.h:123
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
int size() const
Definition: genericvector.h:72
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
#define tprintf(...)
Definition: tprintf.h:31
LossType OutputLossType() const
const char * string() const
Definition: strngs.cpp:198
double SignedRand(double range)
Definition: helpers.h:60
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words)
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: strngs.h:45
const double kHighConfidence
Definition: lstmtrainer.cpp:64
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)

◆ PrepareLogMsg()

void tesseract::LSTMTrainer::PrepareLogMsg ( STRING log_msg) const

Definition at line 400 of file lstmtrainer.cpp.

400  {
401  LogIterations("At", log_msg);
402  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
403  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
404  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
405  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
406  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
407  *log_msg += "%, ";
408 }
void LogIterations(const char *intro_str, STRING *log_msg) const
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391

◆ ReadLocalTrainingDump()

bool tesseract::LSTMTrainer::ReadLocalTrainingDump ( const TessdataManager mgr,
const char *  data,
int  size 
)

Definition at line 939 of file lstmtrainer.cpp.

940  {
941  if (size == 0) {
942  tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
943  return false;
944  }
945  TFile fp;
946  fp.Open(data, size);
947  return DeSerialize(mgr, &fp);
948 }
#define tprintf(...)
Definition: tprintf.h:31
virtual bool DeSerialize(const TessdataManager *mgr, TFile *fp)

◆ ReadSizedTrainingDump()

bool tesseract::LSTMTrainer::ReadSizedTrainingDump ( const char *  data,
int  size,
LSTMTrainer trainer 
) const
inline

Definition at line 296 of file lstmtrainer.h.

297  {
298  return trainer->ReadLocalTrainingDump(&mgr_, data, size);
299  }
TessdataManager mgr_
Definition: lstmtrainer.h:483

◆ ReadTrainingDump()

bool tesseract::LSTMTrainer::ReadTrainingDump ( const GenericVector< char > &  data,
LSTMTrainer trainer 
) const
inline

Definition at line 291 of file lstmtrainer.h.

292  {
293  if (data.empty()) return false;
294  return ReadSizedTrainingDump(&data[0], data.size(), trainer);
295  }
bool empty() const
Definition: genericvector.h:91
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:296
int size() const
Definition: genericvector.h:72

◆ ReduceLayerLearningRates()

int tesseract::LSTMTrainer::ReduceLayerLearningRates ( double  factor,
int  num_samples,
LSTMTrainer samples_trainer 
)

Definition at line 639 of file lstmtrainer.cpp.

640  {
641  enum WhichWay {
642  LR_DOWN, // Learning rate will go down by factor.
643  LR_SAME, // Learning rate will stay the same.
644  LR_COUNT // Size of arrays.
645  };
647  int num_layers = layers.size();
648  GenericVector<int> num_weights;
649  num_weights.init_to_size(num_layers, 0);
650  GenericVector<double> bad_sums[LR_COUNT];
651  GenericVector<double> ok_sums[LR_COUNT];
652  for (int i = 0; i < LR_COUNT; ++i) {
653  bad_sums[i].init_to_size(num_layers, 0.0);
654  ok_sums[i].init_to_size(num_layers, 0.0);
655  }
656  double momentum_factor = 1.0 / (1.0 - momentum_);
657  GenericVector<char> orig_trainer;
658  samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer);
659  for (int i = 0; i < num_layers; ++i) {
660  Network* layer = GetLayer(layers[i]);
661  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
662  }
663  int iteration = sample_iteration();
664  for (int s = 0; s < num_samples; ++s) {
665  // Which way will we modify the learning rate?
666  for (int ww = 0; ww < LR_COUNT; ++ww) {
667  // Transfer momentum to learning rate and adjust by the ww factor.
668  float ww_factor = momentum_factor;
669  if (ww == LR_DOWN) ww_factor *= factor;
670  // Make a copy of *this, so we can mess about without damaging anything.
671  LSTMTrainer copy_trainer;
672  samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
673  // Clear the updates, doing nothing else.
674  copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
675  // Adjust the learning rate in each layer.
676  for (int i = 0; i < num_layers; ++i) {
677  if (num_weights[i] == 0) continue;
678  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
679  }
680  copy_trainer.SetIteration(iteration);
681  // Train on the sample, but keep the update in updates_ instead of
682  // applying to the weights.
683  const ImageData* trainingdata =
684  copy_trainer.TrainOnLine(samples_trainer, true);
685  if (trainingdata == NULL) continue;
686  // We'll now use this trainer again for each layer.
687  GenericVector<char> updated_trainer;
688  samples_trainer->SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
689  for (int i = 0; i < num_layers; ++i) {
690  if (num_weights[i] == 0) continue;
691  LSTMTrainer layer_trainer;
692  samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
693  Network* layer = layer_trainer.GetLayer(layers[i]);
694  // Update the weights in just the layer, using Adam if enabled.
695  layer->Update(0.0, momentum_, adam_beta_,
696  layer_trainer.training_iteration_ + 1);
697  // Zero the updates matrix again.
698  layer->Update(0.0, 0.0, 0.0, 0);
699  // Train again on the same sample, again holding back the updates.
700  layer_trainer.TrainOnLine(trainingdata, true);
701  // Count the sign changes in the updates in layer vs in copy_trainer.
702  float before_bad = bad_sums[ww][i];
703  float before_ok = ok_sums[ww][i];
704  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
705  &ok_sums[ww][i], &bad_sums[ww][i]);
706  float bad_frac =
707  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
708  if (bad_frac > 0.0f)
709  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
710  }
711  }
712  ++iteration;
713  }
714  int num_lowered = 0;
715  for (int i = 0; i < num_layers; ++i) {
716  if (num_weights[i] == 0) continue;
717  Network* layer = GetLayer(layers[i]);
718  float lr = GetLayerLearningRate(layers[i]);
719  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
720  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
721  double frac_down = bad_sums[LR_DOWN][i] / total_down;
722  double frac_same = bad_sums[LR_SAME][i] / total_same;
723  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
724  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
725  if (frac_down < frac_same * kImprovementFraction) {
726  tprintf(" REDUCED\n");
727  ScaleLayerLearningRate(layers[i], factor);
728  ++num_lowered;
729  } else {
730  tprintf(" SAME\n");
731  }
732  }
733  if (num_lowered == 0) {
734  // Just lower everything to make sure.
735  for (int i = 0; i < num_layers; ++i) {
736  if (num_weights[i] > 0) {
737  ScaleLayerLearningRate(layers[i], factor);
738  ++num_lowered;
739  }
740  }
741  }
742  return num_lowered;
743 }
Network * GetLayer(const STRING &id) const
int size() const
Definition: genericvector.h:72
void ScaleLayerLearningRate(const STRING &id, double factor)
#define tprintf(...)
Definition: tprintf.h:31
float GetLayerLearningRate(const STRING &id) const
const double kImprovementFraction
Definition: lstmtrainer.cpp:66
GenericVector< STRING > EnumerateLayers() const
void init_to_size(int size, T t)

◆ ReduceLearningRates()

void tesseract::LSTMTrainer::ReduceLearningRates ( LSTMTrainer samples_trainer,
STRING log_msg 
)

Definition at line 620 of file lstmtrainer.cpp.

621  {
623  int num_reduced = ReduceLayerLearningRates(
624  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
625  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
626  } else {
628  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
629  }
630  *log_msg += "\n";
631 }
const double kLearningRateDecay
Definition: lstmtrainer.cpp:52
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:54
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
void ScaleLearningRate(double factor)

◆ RollErrorBuffers()

void tesseract::LSTMTrainer::RollErrorBuffers ( )
protected

Definition at line 1291 of file lstmtrainer.cpp.

1291  {
1293  if (NewSingleError(ET_DELTA) > 0.0)
1295  else
1298  if (debug_interval_ != 0) {
1299  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1303  }
1304 }
#define tprintf(...)
Definition: tprintf.h:31
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154

◆ SaveRecognitionDump()

void tesseract::LSTMTrainer::SaveRecognitionDump ( GenericVector< char > *  data) const

Definition at line 960 of file lstmtrainer.cpp.

960  {
961  TFile fp;
962  fp.OpenWrite(data);
966 }
bool Serialize(const TessdataManager *mgr, TFile *fp) const
TessdataManager mgr_
Definition: lstmtrainer.h:483
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ SaveTraineddata()

bool tesseract::LSTMTrainer::SaveTraineddata ( const STRING &  filename)

Definition at line 951 of file lstmtrainer.cpp.

951  {
952  GenericVector<char> recognizer_data;
953  SaveRecognitionDump(&recognizer_data);
954  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
955  recognizer_data.size());
956  return mgr_.SaveFile(filename, file_writer_);
957 }
bool SaveFile(const STRING &filename, FileWriter writer) const
void SaveRecognitionDump(GenericVector< char > *data) const
int size() const
Definition: genericvector.h:72
TessdataManager mgr_
Definition: lstmtrainer.h:483
void OverwriteEntry(TessdataType type, const char *data, int size)

◆ SaveTrainingDump()

bool tesseract::LSTMTrainer::SaveTrainingDump ( SerializeAmount  serialize_amount,
const LSTMTrainer *  trainer,
GenericVector< char > *  data 
) const

Definition at line 930 of file lstmtrainer.cpp.

932  {
933  TFile fp;
934  fp.OpenWrite(data);
935  return trainer->Serialize(serialize_amount, &mgr_, &fp);
936 }
TessdataManager mgr_
Definition: lstmtrainer.h:483

◆ Serialize()

bool tesseract::LSTMTrainer::Serialize ( SerializeAmount  serialize_amount,
const TessdataManager *  mgr,
TFile *  fp 
) const
virtual

Definition at line 431 of file lstmtrainer.cpp.

432  {
433  if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
434  if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
435  return false;
436  if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
437  1)
438  return false;
439  if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
440  if (fp->FWrite(&last_perfect_training_iteration_,
441  sizeof(last_perfect_training_iteration_), 1) != 1)
442  return false;
443  for (int i = 0; i < ET_COUNT; ++i) {
444  if (!error_buffers_[i].Serialize(fp)) return false;
445  }
446  if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
447  if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
448  return false;
449  uinT8 amount = serialize_amount;
450  if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
451  if (serialize_amount == LIGHT) return true; // We are done.
452  if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
453  return false;
454  if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
455  return false;
456  if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
457  return false;
458  if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
459  return false;
460  if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
461  return false;
462  if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
463  return false;
464  if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
465  return false;
466  if (!best_model_data_.Serialize(fp)) return false;
467  if (!worst_model_data_.Serialize(fp)) return false;
468  if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
469  return false;
470  GenericVector<char> sub_data;
471  if (sub_trainer_ != NULL && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
472  return false;
473  if (!sub_data.Serialize(fp)) return false;
474  if (!best_error_history_.Serialize(fp)) return false;
475  if (!best_error_iterations_.Serialize(fp)) return false;
476  if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
477  return false;
478  return true;
479 }
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
bool Serialize(const TessdataManager *mgr, TFile *fp) const
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
virtual bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
uint8_t uinT8
Definition: host.h:35
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
bool Serialize(FILE *fp) const
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444

◆ set_perfect_delay()

void tesseract::LSTMTrainer::set_perfect_delay ( int  delay)
inline

Definition at line 151 of file lstmtrainer.h.

151 { perfect_delay_ = delay; }

◆ SetNullChar()

void tesseract::LSTMTrainer::SetNullChar ( )
protected

Definition at line 1035 of file lstmtrainer.cpp.

1035  {
1037  : GetUnicharset().size();
1038  RecodedCharID code;
1040  null_char_ = code(0);
1041 }
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
int size() const
Definition: unicharset.h:338
bool has_special_codes() const
Definition: unicharset.h:721
const UNICHARSET & GetUnicharset() const

◆ SetupCheckpointInfo()

void tesseract::LSTMTrainer::SetupCheckpointInfo ( )

◆ StartSubtrainer()

void tesseract::LSTMTrainer::StartSubtrainer ( STRING *  log_msg)

Definition at line 547 of file lstmtrainer.cpp.

547  {
548  delete sub_trainer_;
549  sub_trainer_ = new LSTMTrainer();
551  *log_msg += " Failed to revert to previous best for trial!";
552  delete sub_trainer_;
553  sub_trainer_ = NULL;
554  } else {
555  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
557  // Reduce learning rate so it doesn't diverge this time.
558  sub_trainer_->ReduceLearningRates(this, log_msg);
559  // If it fails again, we will wait twice as long before reverting again.
560  int stall_offset =
562  stall_iteration_ = learning_iteration() + 2 * stall_offset;
564  // Re-save the best trainer with the new learning rates and stall iteration.
566  }
567 }
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
virtual R Run(A1, A2)=0
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
int learning_iteration() const
Definition: lstmtrainer.h:149
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
virtual R Run(A1, A2, A3)=0
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)

◆ training_data()

const DocumentCache& tesseract::LSTMTrainer::training_data ( ) const
inline

Definition at line 165 of file lstmtrainer.h.

165  {
166  return training_data_;
167  }
DocumentCache training_data_
Definition: lstmtrainer.h:414

◆ TrainOnLine() [1/2]

const ImageData* tesseract::LSTMTrainer::TrainOnLine ( LSTMTrainer *  samples_trainer,
bool  batch 
)
inline

Definition at line 259 of file lstmtrainer.h.

259  {
260  int sample_index = sample_iteration();
261  const ImageData* image =
262  samples_trainer->training_data_.GetPageBySerial(sample_index);
263  if (image != NULL) {
264  Trainability trainable = TrainOnLine(image, batch);
265  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
266  return NULL; // Sample was unusable.
267  }
268  } else {
270  }
271  return image;
272  }
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259

◆ TrainOnLine() [2/2]

Trainability tesseract::LSTMTrainer::TrainOnLine ( const ImageData *  trainingdata,
bool  batch 
)

Definition at line 795 of file lstmtrainer.cpp.

796  {
797  NetworkIO fwd_outputs, targets;
798  Trainability trainable =
799  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
801  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
802  return trainable; // Sample was unusable.
803  }
804  bool debug = debug_interval_ > 0 &&
806  // Run backprop on the output.
807  NetworkIO bp_deltas;
808  if (network_->IsTraining() &&
809  (trainable != PERFECT ||
812  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
814  training_iteration_ + 1);
815  }
816 #ifndef GRAPHICS_DISABLED
817  if (debug_interval_ == 1 && debug_win_ != NULL) {
819  }
820 #endif // GRAPHICS_DISABLED
821  // Roll the memory of past means.
823  return trainable;
824 }
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:231
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:449
bool IsTraining() const
Definition: network.h:115
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: network.h:273
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
NetworkScratch scratch_space_

◆ TransitionTrainingStage()

bool tesseract::LSTMTrainer::TransitionTrainingStage ( float  error_threshold)

Definition at line 421 of file lstmtrainer.cpp.

421  {
422  if (best_error_rate_ < error_threshold &&
424  ++training_stage_;
425  return true;
426  }
427  return false;
428 }

◆ TryLoadingCheckpoint()

bool tesseract::LSTMTrainer::TryLoadingCheckpoint ( const char *  filename,
const char *  old_traineddata 
)

Definition at line 128 of file lstmtrainer.cpp.

129  {
130  GenericVector<char> data;
131  if (!(*file_reader_)(filename, &data)) return false;
132  tprintf("Loaded file %s, unpacking...\n", filename);
133  if (!checkpoint_reader_->Run(data, this)) return false;
134  StaticShape shape = network_->OutputShape(network_->InputShape());
135  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
137  filename == old_traineddata) {
138  return true; // Normal checkpoint load complete.
139  }
140  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
141  recoder_.code_range());
142  if (old_traineddata == nullptr || *old_traineddata == '\0') {
143  tprintf("Must supply the old traineddata for code conversion!\n");
144  return false;
145  }
146  TessdataManager old_mgr;
147  ASSERT_HOST(old_mgr.Init(old_traineddata));
148  TFile fp;
149  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
150  UNICHARSET old_chset;
151  if (!old_chset.load_from_file(&fp, false)) return false;
152  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
153  UnicharCompress old_recoder;
154  if (!old_recoder.DeSerialize(&fp)) return false;
155  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
156  // Set the null_char_ to the new value.
157  int old_null_char = null_char_;
158  SetNullChar();
159  // Map the softmax(s) in the network.
160  network_->RemapOutputs(old_recoder.code_range(), code_map);
161  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
162  return true;
163 }
int NumOutputs() const
Definition: network.h:123
virtual R Run(A1, A2)=0
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
#define tprintf(...)
Definition: tprintf.h:31
virtual StaticShape InputShape() const
Definition: network.h:127
#define ASSERT_HOST(x)
Definition: errcode.h:84
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:387
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186

◆ UpdateErrorBuffer()

void tesseract::LSTMTrainer::UpdateErrorBuffer ( double  new_error,
ErrorTypes  type 
)
protected

Definition at line 1278 of file lstmtrainer.cpp.

1278  {
1280  error_buffers_[type][index] = new_error;
1281  // Compute the mean error.
1282  int mean_count = MIN(training_iteration_ + 1, error_buffers_[type].size());
1283  double buffer_sum = 0.0;
1284  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1285  double mean = buffer_sum / mean_count;
1286  // Trim precision to 1/1000 of 1%.
1287  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1288 }
#define MIN(x, y)
Definition: ndminx.h:28
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
int IntCastRounded(double x)
Definition: helpers.h:179

◆ UpdateErrorGraph()

STRING tesseract::LSTMTrainer::UpdateErrorGraph ( int  iteration,
double  error_rate,
const GenericVector< char > &  model_data,
TestCallback  tester 
)
protected

Definition at line 1310 of file lstmtrainer.cpp.

1312  {
1313  if (error_rate > best_error_rate_
1314  && iteration < best_iteration_ + kErrorGraphInterval) {
1315  // Too soon to record a new point.
1316  if (tester != NULL && !worst_model_data_.empty()) {
1319  return tester->Run(worst_iteration_, NULL, mgr_, CurrentTrainingStage());
1320  } else {
1321  return "";
1322  }
1323  }
1324  STRING result;
1325  // NOTE: there are 2 asymmetries here:
1326  // 1. We are computing the global minimum, but the local maximum in between.
1327  // 2. If the tester returns an empty string, indicating that it is busy,
1328  // call it repeatedly on new local maxima to test the previous min, but
1329  // not the other way around, as there is little point testing the maxima
1330  // between very frequent minima.
1331  if (error_rate < best_error_rate_) {
1332  // This is a new (global) minimum.
1333  if (tester != nullptr && !worst_model_data_.empty()) {
1336  result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
1339  best_model_data_ = model_data;
1340  }
1341  best_error_rate_ = error_rate;
1342  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1343  best_iteration_ = iteration;
1344  best_error_history_.push_back(error_rate);
1345  best_error_iterations_.push_back(iteration);
1346  // Compute 2% decay time.
1347  double two_percent_more = error_rate + 2.0;
1348  int i;
1349  for (i = best_error_history_.size() - 1;
1350  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1351  }
1352  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1353  improvement_steps_ = iteration - old_iteration;
1354  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1355  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1356  old_iteration);
1357  } else if (error_rate > best_error_rate_) {
1358  // This is a new (local) maximum.
1359  if (tester != NULL) {
1360  if (!best_model_data_.empty()) {
1363  result = tester->Run(best_iteration_, best_error_rates_, mgr_,
1365  } else if (!worst_model_data_.empty()) {
1366  // Allow for multiple data points with "worst" error rate.
1369  result = tester->Run(worst_iteration_, worst_error_rates_, mgr_,
1371  }
1372  if (result.length() > 0)
1374  worst_model_data_ = model_data;
1375  }
1376  }
1377  worst_error_rate_ = error_rate;
1378  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1379  worst_iteration_ = iteration;
1380  return result;
1381 }
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
bool empty() const
Definition: genericvector.h:91
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
int size() const
Definition: genericvector.h:72
TessdataManager mgr_
Definition: lstmtrainer.h:483
#define tprintf(...)
Definition: tprintf.h:31
void truncate(int size)
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
int push_back(T object)
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:56
Definition: strngs.h:45
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
void OverwriteEntry(TessdataType type, const char *data, int size)
inT32 length() const
Definition: strngs.cpp:193
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444

◆ UpdateSubtrainer()

SubTrainerResult tesseract::LSTMTrainer::UpdateSubtrainer ( STRING *  log_msg)

Definition at line 577 of file lstmtrainer.cpp.

577  {
578  double training_error = CharError();
579  double sub_error = sub_trainer_->CharError();
580  double sub_margin = (training_error - sub_error) / sub_error;
581  if (sub_margin >= kSubTrainerMarginFraction) {
582  log_msg->add_str_double(" sub_trainer=", sub_error);
583  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
584  *log_msg += "\n";
585  // Catch up to current iteration.
586  int end_iteration = training_iteration();
587  while (sub_trainer_->training_iteration() < end_iteration &&
588  sub_margin >= kSubTrainerMarginFraction) {
589  int target_iteration =
591  while (sub_trainer_->training_iteration() < target_iteration) {
592  sub_trainer_->TrainOnLine(this, false);
593  }
594  STRING batch_log = "Sub:";
595  sub_trainer_->PrepareLogMsg(&batch_log);
596  batch_log += "\n";
597  tprintf("UpdateSubtrainer:%s", batch_log.string());
598  *log_msg += batch_log;
599  sub_error = sub_trainer_->CharError();
600  sub_margin = (training_error - sub_error) / sub_error;
601  }
602  if (sub_error < best_error_rate_ &&
603  sub_margin >= kSubTrainerMarginFraction) {
604  // The sub_trainer_ has won the race to a new best. Switch to it.
605  GenericVector<char> updated_trainer;
606  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
607  ReadTrainingDump(updated_trainer, this);
608  log_msg->add_str_int(" Sub trainer wins at iteration ",
610  *log_msg += "\n";
611  return STR_REPLACED;
612  }
613  return STR_UPDATED;
614  }
615  return STR_NONE;
616 }
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
double CharError() const
Definition: lstmtrainer.h:139
#define tprintf(...)
Definition: tprintf.h:31
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const char * string() const
Definition: strngs.cpp:198
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
Definition: strngs.h:45
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
void PrepareLogMsg(STRING *log_msg) const
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291

Member Data Documentation

◆ align_win_

ScrollView* tesseract::LSTMTrainer::align_win_
protected

Definition at line 397 of file lstmtrainer.h.

◆ best_error_history_

GenericVector<double> tesseract::LSTMTrainer::best_error_history_
protected

Definition at line 457 of file lstmtrainer.h.

◆ best_error_iterations_

GenericVector<int> tesseract::LSTMTrainer::best_error_iterations_
protected

Definition at line 458 of file lstmtrainer.h.

◆ best_error_rate_

double tesseract::LSTMTrainer::best_error_rate_
protected

Definition at line 430 of file lstmtrainer.h.

◆ best_error_rates_

double tesseract::LSTMTrainer::best_error_rates_[ET_COUNT]
protected

Definition at line 432 of file lstmtrainer.h.

◆ best_iteration_

int tesseract::LSTMTrainer::best_iteration_
protected

Definition at line 434 of file lstmtrainer.h.

◆ best_model_data_

GenericVector<char> tesseract::LSTMTrainer::best_model_data_
protected

Definition at line 444 of file lstmtrainer.h.

◆ best_model_name_

STRING tesseract::LSTMTrainer::best_model_name_
protected

Definition at line 416 of file lstmtrainer.h.

◆ best_trainer_

GenericVector<char> tesseract::LSTMTrainer::best_trainer_
protected

Definition at line 447 of file lstmtrainer.h.

◆ checkpoint_iteration_

int tesseract::LSTMTrainer::checkpoint_iteration_
protected

Definition at line 407 of file lstmtrainer.h.

◆ checkpoint_name_

STRING tesseract::LSTMTrainer::checkpoint_name_
protected

Definition at line 411 of file lstmtrainer.h.

◆ checkpoint_reader_

CheckPointReader tesseract::LSTMTrainer::checkpoint_reader_
protected

Definition at line 424 of file lstmtrainer.h.

◆ checkpoint_writer_

CheckPointWriter tesseract::LSTMTrainer::checkpoint_writer_
protected

Definition at line 425 of file lstmtrainer.h.

◆ ctc_win_

ScrollView* tesseract::LSTMTrainer::ctc_win_
protected

Definition at line 401 of file lstmtrainer.h.

◆ debug_interval_

int tesseract::LSTMTrainer::debug_interval_
protected

Definition at line 405 of file lstmtrainer.h.

◆ error_buffers_

GenericVector<double> tesseract::LSTMTrainer::error_buffers_[ET_COUNT]
protected

Definition at line 479 of file lstmtrainer.h.

◆ error_rate_of_last_saved_best_

float tesseract::LSTMTrainer::error_rate_of_last_saved_best_
protected

Definition at line 452 of file lstmtrainer.h.

◆ error_rates_

double tesseract::LSTMTrainer::error_rates_[ET_COUNT]
protected

Definition at line 481 of file lstmtrainer.h.

◆ file_reader_

FileReader tesseract::LSTMTrainer::file_reader_
protected

Definition at line 420 of file lstmtrainer.h.

◆ file_writer_

FileWriter tesseract::LSTMTrainer::file_writer_
protected

Definition at line 421 of file lstmtrainer.h.

◆ improvement_steps_

int tesseract::LSTMTrainer::improvement_steps_
protected

Definition at line 460 of file lstmtrainer.h.

◆ kRollingBufferSize_

const int tesseract::LSTMTrainer::kRollingBufferSize_ = 1000
staticprotected

Definition at line 478 of file lstmtrainer.h.

◆ last_perfect_training_iteration_

int tesseract::LSTMTrainer::last_perfect_training_iteration_
protected

Definition at line 475 of file lstmtrainer.h.

◆ learning_iteration_

int tesseract::LSTMTrainer::learning_iteration_
protected

Definition at line 464 of file lstmtrainer.h.

◆ mgr_

TessdataManager tesseract::LSTMTrainer::mgr_
protected

Definition at line 483 of file lstmtrainer.h.

◆ model_base_

STRING tesseract::LSTMTrainer::model_base_
protected

Definition at line 409 of file lstmtrainer.h.

◆ num_training_stages_

int tesseract::LSTMTrainer::num_training_stages_
protected

Definition at line 418 of file lstmtrainer.h.

◆ perfect_delay_

int tesseract::LSTMTrainer::perfect_delay_
protected

Definition at line 472 of file lstmtrainer.h.

◆ prev_sample_iteration_

int tesseract::LSTMTrainer::prev_sample_iteration_
protected

Definition at line 466 of file lstmtrainer.h.

◆ randomly_rotate_

bool tesseract::LSTMTrainer::randomly_rotate_
protected

Definition at line 413 of file lstmtrainer.h.

◆ recon_win_

ScrollView* tesseract::LSTMTrainer::recon_win_
protected

Definition at line 403 of file lstmtrainer.h.

◆ stall_iteration_

int tesseract::LSTMTrainer::stall_iteration_
protected

Definition at line 442 of file lstmtrainer.h.

◆ sub_trainer_

LSTMTrainer* tesseract::LSTMTrainer::sub_trainer_
protected

Definition at line 450 of file lstmtrainer.h.

◆ target_win_

ScrollView* tesseract::LSTMTrainer::target_win_
protected

Definition at line 399 of file lstmtrainer.h.

◆ training_data_

DocumentCache tesseract::LSTMTrainer::training_data_
protected

Definition at line 414 of file lstmtrainer.h.

◆ training_stage_

int tesseract::LSTMTrainer::training_stage_
protected

Definition at line 454 of file lstmtrainer.h.

◆ worst_error_rate_

double tesseract::LSTMTrainer::worst_error_rate_
protected

Definition at line 436 of file lstmtrainer.h.

◆ worst_error_rates_

double tesseract::LSTMTrainer::worst_error_rates_[ET_COUNT]
protected

Definition at line 438 of file lstmtrainer.h.

◆ worst_iteration_

int tesseract::LSTMTrainer::worst_iteration_
protected

Definition at line 440 of file lstmtrainer.h.

◆ worst_model_data_

GenericVector<char> tesseract::LSTMTrainer::worst_model_data_
protected

Definition at line 445 of file lstmtrainer.h.


The documentation for this class was generated from the following files: