12#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
13#define TESSERACT_UNITTEST_LSTM_TEST_H_
48 std::locale::global(std::locale(
""));
63 void SetupTrainerEng(
const std::string &network_spec,
const std::string &model_name,
bool recode,
65 SetupTrainer(network_spec, model_name,
"eng/eng.unicharset",
"eng.Arial.exp0.lstmf", recode,
66 adam, 5e-4,
false,
"eng");
68 void SetupTrainer(
const std::string &network_spec,
const std::string &model_name,
69 const std::string &unicharset_file,
const std::string &lstmf_file,
bool recode,
70 bool adam,
float learning_rate,
bool layer_specific,
const std::string &kLang) {
76 std::vector<std::string> words;
78 words, words, words,
false,
nullptr,
nullptr));
79 std::string model_path =
file::JoinPath(FLAGS_test_tmpdir, model_name);
80 std::string checkpoint_path = model_path +
"_checkpoint";
81 trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
84 int net_mode = adam ?
NF_ADAM : 0;
88 learning_rate *= 20.0f;
94 trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95 std::vector<std::string> filenames;
98 LOG(
INFO) <<
"Setup network:" << model_name <<
"\n";
102 int iteration =
trainer_->training_iteration();
103 int iteration_limit = iteration + max_iterations;
104 double best_error = 100.0;
106 std::stringstream log_str;
109 double mean_error = 0.0;
110 while (iteration < target_iteration && iteration < iteration_limit) {
112 iteration =
trainer_->training_iteration();
115 trainer_->MaintainCheckpoints(
nullptr, log_str);
116 iteration =
trainer_->training_iteration();
118 if (mean_error < best_error) {
119 best_error = mean_error;
121 }
while (iteration < iteration_limit);
122 LOG(
INFO) <<
"Trainer error rate = " << best_error <<
"\n";
128 int iteration =
trainer_->sample_iteration();
129 double mean_error = 0.0;
131 while (error_count < max_iterations) {
133 *
trainer_->mutable_training_data()->GetPageBySerial(iteration);
139 trainer_->SetIteration(++iteration);
141 mean_error *= 100.0 / max_iterations;
142 LOG(
INFO) <<
"Tester error rate = " << mean_error <<
"\n";
149 std::vector<char> trainer_data;
158 return int_err - float_err;
164 std::string unicharset_name = lang +
"/" + lang +
".unicharset";
165 std::string lstmf_name = lang +
".Arial_Unicode_MS.exp0.lstmf";
166 SetupTrainer(
"[1,1,0,32 Lbx100 O1c1]",
"bidi-lstm", unicharset_name, lstmf_name, recode,
true,
168 std::vector<int> labels;
170 std::string decoded =
trainer_->DecodeLabels(labels);
171 std::string decoded_str(&decoded[0], decoded.length());
#define EXPECT_EQ(val1, val2)
#define EXPECT_TRUE(condition)
#define ASSERT_TRUE(condition)
#define EXPECT_LT(val1, val2)
#define CHECK_GT(test, value)
const int kBatchIterations
const int kTrainerIterations
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
bool load_from_file(const char *const filename, bool skip_fragments)
static std::string JoinPath(const std::string &s1, const std::string &s2)
double TestIntMode(int test_iterations)
std::unique_ptr< LSTMTrainer > trainer_
std::string TessDataNameToPath(const std::string &name)
double TrainIterations(int max_iterations)
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
double TestIterations(int max_iterations)
std::string TestingNameToPath(const std::string &name)
LSTMTrainerTest()=default
void TestEncodeDecodeBoth(const std::string &lang, const std::string &str)
void SetupTrainerEng(const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
std::string TestDataNameToPath(const std::string &name)
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)