// Integer command-line flag, default 0. main() passes FLAGS_debug_interval as
// one argument of a call whose opening line is not visible in this excerpt.
33static INT_PARAM_FLAG(debug_interval, 0,
"How often to display the alignment.");
// Integer command-line flag, default 0. No read of FLAGS_perfect_sample_delay
// is visible in this excerpt; presumably it is handed to the trainer's
// set_perfect_delay() -- TODO confirm against the full file.
36static INT_PARAM_FLAG(perfect_sample_delay, 0,
"How many imperfect samples between perfect ones.");
41 "Resets all stored learning rates to the value specified by --learning_rate.");
// Integer command-line flag, default 6000. main() casts the value to int64_t
// before multiplying by 1048576, so the product is computed in 64 bits (the
// default alone, 6000 * 1048576, would not fit in a 32-bit int).
44static INT_PARAM_FLAG(max_image_MB, 6000,
"Max memory to use for images.");
48 "File listing training files in lstmf training format.");
// String command-line flag, default empty. main() tests
// `!FLAGS_eval_listfile.empty()` before its eval-data handling, and the
// "Failed to load eval data from" message prints this value.
49static STRING_PARAM_FLAG(eval_listfile,
"",
"File listing eval files in lstmf training format.");
// Boolean command-line flag, default false. main() tests it immediately before
// the feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID) call; the line
// between the test and the call is not visible in this excerpt.
51static BOOL_PARAM_FLAG(debug_float,
false,
"Raise error on certain float errors.");
// Boolean command-line flag, default false. main() tests it OR-ed with
// debug_network (`FLAGS_stop_training || FLAGS_debug_network`); the guarded
// branch is only partly visible in this excerpt.
53static BOOL_PARAM_FLAG(stop_training,
false,
"Just convert the training model to a runtime model.");
// Boolean command-line flag, default false. main() tests it with
// `if (FLAGS_convert_to_int)`; the statements it guards are not visible in
// this excerpt.
54static BOOL_PARAM_FLAG(convert_to_int,
false,
"Convert the recognition model to an integer model.");
56 "Use the training files sequentially instead of round-robin.");
58 "Index in continue_from Network at which to"
59 " attach the new network defined by net_spec");
// Boolean command-line flag, default false. main() tests it twice: once OR-ed
// with stop_training, and once on its own (`if (FLAGS_debug_network)`).
60static BOOL_PARAM_FLAG(debug_network,
false,
"Get info on distribution of weight values");
// Integer command-line flag, default 0. main() copies it into a local and
// adjusts it by sign: a negative value becomes filenames.size() * (-value),
// and zero becomes INT_MAX. The visible branches leave a positive value
// unchanged. The result bounds the loop via `iteration < max_iterations`.
61static INT_PARAM_FLAG(max_iterations, 0,
"If set, exit after this many iterations");
// String command-line flag, default empty. main() prints
// "Must provide a --traineddata ..." when it is empty. It also passes the value
// to trainer.InitCharSet() and prints "Error, failed to read ..." when that
// call returns false.
62static STRING_PARAM_FLAG(traineddata,
"",
"Combined Dawgs/Unicharset/Recoder for language model");
64 "When changing the character set, this specifies the old"
65 " character set that is to be replaced");
67 "Train OSD and randomly turn training samples upside-down");
76int main(
int argc,
char **argv) {
77 tesseract::CheckSharedLibraryVersion();
80 if (FLAGS_debug_float) {
82 feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
85 if (FLAGS_model_output.empty()) {
86 tprintf(
"Must provide a --model_output!\n");
89 if (FLAGS_traineddata.empty()) {
90 tprintf(
"Must provide a --traineddata see training documentation\n");
95 std::string test_file = FLAGS_model_output.c_str();
96 test_file +=
"_wtest";
97 FILE *f = fopen(test_file.c_str(),
"wb");
100 if (remove(test_file.c_str()) != 0) {
101 tprintf(
"Error, failed to remove %s: %s\n", test_file.c_str(), strerror(errno));
105 tprintf(
"Error, model output cannot be written: %s\n", strerror(errno));
110 std::string checkpoint_file = FLAGS_model_output.c_str();
111 checkpoint_file +=
"_checkpoint";
112 std::string checkpoint_bak = checkpoint_file +
".bak";
114 FLAGS_debug_interval,
115 static_cast<int64_t
>(FLAGS_max_image_MB) * 1048576);
116 if (!trainer.
InitCharSet(FLAGS_traineddata.c_str())) {
117 tprintf(
"Error, failed to read %s\n", FLAGS_traineddata.c_str());
123 if (FLAGS_stop_training || FLAGS_debug_network) {
125 tprintf(
"Failed to read continue from: %s\n", FLAGS_continue_from.c_str());
128 if (FLAGS_debug_network) {
131 if (FLAGS_convert_to_int) {
135 tprintf(
"Failed to write recognition model : %s\n", FLAGS_model_output.c_str());
142 if (FLAGS_train_listfile.empty()) {
143 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
146 std::vector<std::string> filenames;
148 tprintf(
"Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str());
155 tprintf(
"Successfully restored trainer from %s\n", checkpoint_file.c_str());
157 if (!FLAGS_continue_from.empty()) {
160 FLAGS_append_index >= 0 ? FLAGS_continue_from.c_str()
161 : FLAGS_old_traineddata.c_str())) {
162 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
165 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
166 if (FLAGS_reset_learning_rate) {
168 tprintf(
"Set learning rate to %f\n",
static_cast<float>(FLAGS_learning_rate));
172 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
173 if (FLAGS_append_index >= 0) {
174 tprintf(
"Appending a new network to an old one!!");
175 if (FLAGS_continue_from.empty()) {
176 tprintf(
"Must set --continue_from for appending!\n");
181 if (!trainer.
InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode,
182 FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum,
184 tprintf(
"Failed to create network from spec: %s\n", FLAGS_net_spec.c_str());
193 FLAGS_randomly_rotate)) {
194 tprintf(
"Load of images failed!!\n");
200 if (!FLAGS_eval_listfile.empty()) {
201 using namespace std::placeholders;
203 tprintf(
"Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
209 int max_iterations = FLAGS_max_iterations;
210 if (max_iterations < 0) {
212 max_iterations = filenames.size() * (-max_iterations);
213 }
else if (max_iterations == 0) {
215 max_iterations = INT_MAX;
222 iteration < target_iteration && iteration < max_iterations;
226 std::stringstream log_str;
227 log_str.imbue(std::locale::classic());
229 tprintf(
"%s\n", log_str.str().c_str());
232 tprintf(
"Finished! Selected model with minimal training error rate (BCER) = %g\n",
#define DOUBLE_PARAM_FLAG(name, val, comment)
#define BOOL_PARAM_FLAG(name, val, comment)
#define INT_PARAM_FLAG(name, val, comment)
#define STRING_PARAM_FLAG(name, val, comment)
int main(int argc, char **argv)
const int kNumPagesPerBatch
void tprintf(const char *format,...)
void ParseArguments(int *argc, char ***argv)
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
bool LoadFileLinesToStrings(const char *filename, std::vector< std::string > *lines)
int training_iteration() const
void SetLearningRate(float learning_rate)
std::string RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
bool LoadAllEvalData(const char *filenames_file)
bool MaintainCheckpoints(const TestCallback &tester, std::stringstream &log_msg)
bool LoadAllTrainingData(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
bool InitCharSet(const std::string &traineddata_path)
double best_error_rate() const
bool SaveTraineddata(const char *filename)
void set_perfect_delay(int delay)
bool InitNetwork(const char *network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)