18#ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19#define TESSERACT_LSTM_LSTMTRAINER_H_
87 LSTMTrainer(
const char *model_base,
const char *checkpoint_name,
88 int debug_interval, int64_t max_memory);
95 bool TryLoadingCheckpoint(
const char *filename,
const char *old_traineddata);
101 bool success = mgr_.Init(traineddata_path.c_str());
118 bool InitNetwork(
const char *network_spec,
int append_index,
int net_flags,
119 float weight_range,
float learning_rate,
float momentum,
127 void InitIterations();
140 return best_error_rate_;
143 return best_iteration_;
146 return learning_iteration_;
149 return improvement_steps_;
152 perfect_delay_ = delay;
155 return best_trainer_;
159 return error_buffers_[
type][training_iteration() % kRollingBufferSize_];
165 return error_buffers_[
type]
166 [(training_iteration() + kRollingBufferSize_ - 1) %
167 kRollingBufferSize_];
170 return training_data_;
173 return &training_data_;
180 const ImageData *trainingdata,
int iteration,
double min_dict_ratio,
181 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
182 double cert_offset_step,
double max_cert_offset, std::string &results);
190 bool LoadAllTrainingData(
const std::vector<std::string> &filenames,
192 bool randomly_rotate);
196 bool MaintainCheckpoints(
const TestCallback &tester, std::stringstream &log_msg);
203 const std::vector<char> *train_model,
204 const std::vector<char> *rec_model,
207 void PrepareLogMsg(std::stringstream &log_msg)
const;
210 void LogIterations(
const char *intro_str, std::stringstream &log_msg)
const;
215 bool TransitionTrainingStage(
float error_threshold);
218 return training_stage_;
230 void StartSubtrainer(std::stringstream &log_msg);
242 void ReduceLearningRates(
LSTMTrainer *samples_trainer, std::stringstream &log_msg);
249 int ReduceLayerLearningRates(
TFloat factor,
int num_samples,
254 bool EncodeString(
const std::string &str, std::vector<int> *labels)
const {
255 return EncodeString(str, GetUnicharset(),
256 IsRecoding() ? &recoder_ :
nullptr, SimpleTextOutput(),
260 static bool EncodeString(
const std::string &str,
const UNICHARSET &unicharset,
262 int null_char, std::vector<int> *labels);
269 int sample_index = sample_iteration();
272 if (image !=
nullptr) {
295 std::vector<char> *data)
const;
305 return ReadSizedTrainingDump(&data[0], data.size(), trainer);
312 bool ReadLocalTrainingDump(
const TessdataManager *mgr,
const char *data,
319 bool SaveTraineddata(
const char *filename);
322 void SaveRecognitionDump(std::vector<char> *data)
const;
326 std::string DumpFilename()
const;
332 std::vector<int> MapRecoder(
const UNICHARSET &old_chset,
343 void EmptyConstructor();
351 const std::vector<int> &truth_labels,
354 void DisplayTargets(
const NetworkIO &targets,
const char *window_name,
359 bool ComputeTextTargets(
const NetworkIO &outputs,
360 const std::vector<int> &truth_labels,
366 bool ComputeCTCTargets(
const std::vector<int> &truth_labels,
372 double ComputeErrorRates(
const NetworkIO &deltas,
double char_error,
376 double ComputeRMSError(
const NetworkIO &deltas);
383 double ComputeWinnerError(
const NetworkIO &deltas);
386 double ComputeCharError(
const std::vector<int> &truth_str,
387 const std::vector<int> &ocr_str);
390 double ComputeWordError(std::string *truth_str, std::string *ocr_str);
397 void RollErrorBuffers();
401 std::string UpdateErrorGraph(
int iteration,
double error_rate,
402 const std::vector<char> &model_data,
406#ifndef GRAPHICS_DISABLED
483 static const int kRollingBufferSize_ = 1000;
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
bool Serialize(FILE *fp, const std::vector< T > &data)
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
const ImageData * GetPageBySerial(int serial)
std::vector< int32_t > best_error_iterations_
bool MaintainCheckpointsSpecific(int iteration, const std::vector< char > *train_model, const std::vector< char > *rec_model, TestCallback tester, std::stringstream &log_msg)
std::vector< char > worst_model_data_
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
bool EncodeString(const std::string &str, std::vector< int > *labels) const
const double * error_rates() const
bool InitCharSet(const std::string &traineddata_path)
int InitTensorFlowNetwork(const std::string &tf_proto)
std::string best_model_name_
double NewSingleError(ErrorTypes type) const
std::vector< char > best_trainer_
double best_error_rate() const
double LastSingleError(ErrorTypes type) const
DocumentCache * mutable_training_data()
const std::vector< char > & best_trainer() const
float error_rate_of_last_saved_best_
int last_perfect_training_iteration_
int learning_iteration() const
void set_perfect_delay(int delay)
std::string checkpoint_name_
int32_t improvement_steps_
int CurrentTrainingStage() const
double ActivationError() const
std::vector< char > best_model_data_
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer &trainer) const
void InitCharSet(const TessdataManager &mgr)
DocumentCache training_data_
int checkpoint_iteration_
int prev_sample_iteration_
std::unique_ptr< LSTMTrainer > sub_trainer_
const DocumentCache & training_data() const
bool ReadTrainingDump(const std::vector< char > &data, LSTMTrainer &trainer) const
int32_t improvement_steps() const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
std::vector< double > best_error_history_
int best_iteration() const
void SetupCheckpointInfo()