19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_ 20 #define TESSERACT_LSTM_LSTMTRAINER_H_ 94 CheckPointReader checkpoint_reader,
95 CheckPointWriter checkpoint_writer,
96 const char* model_base,
const char* checkpoint_name,
97 int debug_interval,
inT64 max_memory);
124 bool InitNetwork(
const STRING& network_spec,
int append_index,
int net_flags,
174 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
175 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
176 double cert_offset_step,
double max_cert_offset, STRING* results);
186 bool randomly_rotate);
199 TestCallback tester, STRING* log_msg);
204 void LogIterations(
const char* intro_str, STRING* log_msg)
const;
214 virtual bool Serialize(SerializeAmount serialize_amount,
215 const TessdataManager* mgr,
TFile* fp)
const;
242 LSTMTrainer* samples_trainer);
285 const LSTMTrainer* trainer,
292 LSTMTrainer* trainer)
const {
293 if (data.
empty())
return false;
297 LSTMTrainer* trainer)
const {
393 TestCallback tester);
488 #endif // TESSERACT_LSTM_LSTMTRAINER_H_ int CurrentTrainingStage() const
const DocumentCache & training_data() const
static const int kRollingBufferSize_
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
int last_perfect_training_iteration_
GenericVector< char > best_trainer_
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
GenericVector< double > best_error_history_
int InitTensorFlowNetwork(const std::string &tf_proto)
void FillErrorBuffer(double new_error, ErrorTypes type)
double learning_rate() const
double worst_error_rates_[ET_COUNT]
double ComputeRMSError(const NetworkIO &deltas)
void SetupCheckpointInfo()
const GenericVector< char > & best_trainer() const
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
void set_perfect_delay(int delay)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
double ComputeWinnerError(const NetworkIO &deltas)
double best_error_rates_[ET_COUNT]
void SaveRecognitionDump(GenericVector< char > *data) const
GenericVector< char > worst_model_data_
const ImageData * GetPageBySerial(int serial)
TessResultCallback4< STRING, int, const double *, const TessdataManager &, int > * TestCallback
LSTMTrainer * sub_trainer_
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
void LogIterations(const char *intro_str, STRING *log_msg) const
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
float error_rate_of_last_saved_best_
CheckPointReader checkpoint_reader_
void StartSubtrainer(STRING *log_msg)
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
int sample_iteration() const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
double LastSingleError(ErrorTypes type) const
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
bool SaveTraineddata(const STRING &filename)
int training_iteration() const
const double * error_rates() const
int learning_iteration() const
virtual bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
bool SimpleTextOutput() const
double error_rates_[ET_COUNT]
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
int prev_sample_iteration_
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
bool TransitionTrainingStage(float error_threshold)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
double NewSingleError(ErrorTypes type) const
CheckPointWriter checkpoint_writer_
double ActivationError() const
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
int improvement_steps() const
GenericVector< double > error_buffers_[ET_COUNT]
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
STRING DumpFilename() const
GenericVector< int > best_error_iterations_
void InitCharSet(const string &traineddata_path)
double best_error_rate() const
void InitCharSet(const TessdataManager &mgr)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
int best_iteration() const
void PrepareLogMsg(STRING *log_msg) const
const UNICHARSET & GetUnicharset() const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
bool Init(const char *data_file_name)
DocumentCache training_data_
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
virtual bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
DocumentCache * mutable_training_data()
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
int checkpoint_iteration_
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
GenericVector< char > best_model_data_