This program reads in a text file consisting of feature samples from a training page.
The result of this program is a binary inttemp file used by the OCR engine.
76 if (FLAGS_model_output.empty()) {
77 tprintf(
"Must provide a --model_output!\n");
80 if (FLAGS_traineddata.empty()) {
81 tprintf(
"Must provide a --traineddata see training wiki\n");
85 for (
int i = 0; i < model_output.
length(); ++i) {
86 if (model_output[i] ==
'[' || model_output[i] ==
']')
87 model_output[i] =
'-';
88 if (model_output[i] ==
'(' || model_output[i] ==
')')
89 model_output[i] =
'_';
92 STRING checkpoint_file = FLAGS_model_output.
c_str();
93 checkpoint_file +=
"_checkpoint";
94 STRING checkpoint_bak = checkpoint_file +
".bak";
96 nullptr,
nullptr,
nullptr,
nullptr, FLAGS_model_output.c_str(),
97 checkpoint_file.
c_str(), FLAGS_debug_interval,
98 static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
99 trainer.InitCharSet(FLAGS_traineddata.c_str());
103 if (FLAGS_stop_training || FLAGS_debug_network) {
104 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
nullptr)) {
105 tprintf(
"Failed to read continue from: %s\n",
106 FLAGS_continue_from.c_str());
109 if (FLAGS_debug_network) {
110 trainer.DebugNetwork();
112 if (FLAGS_convert_to_int) trainer.ConvertToInt();
113 if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
114 tprintf(
"Failed to write recognition model : %s\n",
115 FLAGS_model_output.c_str());
122 if (FLAGS_train_listfile.empty()) {
123 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
129 tprintf(
"Failed to load list of training filenames from %s\n",
130 FLAGS_train_listfile.c_str());
135 if (trainer.TryLoadingCheckpoint(checkpoint_file.
string(),
nullptr) ||
136 trainer.TryLoadingCheckpoint(checkpoint_bak.
string(),
nullptr)) {
137 tprintf(
"Successfully restored trainer from %s\n",
138 checkpoint_file.
string());
140 if (!FLAGS_continue_from.empty()) {
142 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
143 FLAGS_append_index >= 0
144 ? FLAGS_continue_from.c_str()
145 : FLAGS_old_traineddata.c_str())) {
146 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
149 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
150 trainer.InitIterations();
152 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
153 if (FLAGS_append_index >= 0) {
154 tprintf(
"Appending a new network to an old one!!");
155 if (FLAGS_continue_from.empty()) {
156 tprintf(
"Must set --continue_from for appending!\n");
161 if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
162 FLAGS_net_mode, FLAGS_weight_range,
163 FLAGS_learning_rate, FLAGS_momentum,
165 tprintf(
"Failed to create network from spec: %s\n",
166 FLAGS_net_spec.c_str());
169 trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
172 if (!trainer.LoadAllTrainingData(filenames,
173 FLAGS_sequential_training
176 FLAGS_randomly_rotate)) {
177 tprintf(
"Load of images failed!!\n");
184 if (!FLAGS_eval_listfile.empty()) {
185 if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
186 tprintf(
"Failed to load eval data from: %s\n",
187 FLAGS_eval_listfile.c_str());
195 int iteration = trainer.training_iteration();
197 iteration < target_iteration;
198 iteration = trainer.training_iteration()) {
199 trainer.TrainOnLine(&trainer,
false);
202 trainer.MaintainCheckpoints(tester_callback, &log_str);
204 }
while (trainer.best_error_rate() > FLAGS_target_error_rate &&
205 (trainer.training_iteration() < FLAGS_max_iterations ||
206 FLAGS_max_iterations == 0));
207 delete tester_callback;
208 tprintf(
"Finished! Error rate = %g\n", trainer.best_error_rate());
const int kNumPagesPerBatch
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
const char * string() const
void ParseArguments(int *argc, char ***argv)
const char * c_str() const
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)