tesseract v5.3.3.20231005
tesseract::LSTMTrainerTest Class Reference

#include <lstm_test.h>

Inheritance diagram for tesseract::LSTMTrainerTest:
testing::Test

Protected Member Functions

void SetUp () override
 
 LSTMTrainerTest ()=default
 
std::string TestDataNameToPath (const std::string &name)
 
std::string TessDataNameToPath (const std::string &name)
 
std::string TestingNameToPath (const std::string &name)
 
void SetupTrainerEng (const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
 
void SetupTrainer (const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)
 
double TrainIterations (int max_iterations)
 
double TestIterations (int max_iterations)
 
double TestIntMode (int test_iterations)
 
void TestEncodeDecode (const std::string &lang, const std::string &str, bool recode)
 
void TestEncodeDecodeBoth (const std::string &lang, const std::string &str)
 
- Protected Member Functions inherited from testing::Test
 Test ()
 
virtual void SetUp ()
 
virtual void TearDown ()
 

Protected Attributes

std::unique_ptr< LSTMTrainer > trainer_
 

Additional Inherited Members

- Public Member Functions inherited from testing::Test
virtual ~Test ()
 
- Static Public Member Functions inherited from testing::Test
static void SetUpTestSuite ()
 
static void TearDownTestSuite ()
 
static void TearDownTestCase ()
 
static void SetUpTestCase ()
 
static bool HasFatalFailure ()
 
static bool HasNonfatalFailure ()
 
static bool IsSkipped ()
 
static bool HasFailure ()
 
static void RecordProperty (const std::string &key, const std::string &value)
 
static void RecordProperty (const std::string &key, int value)
 

Detailed Description

Definition at line 45 of file lstm_test.h.

Constructor & Destructor Documentation

◆ LSTMTrainerTest()

tesseract::LSTMTrainerTest::LSTMTrainerTest ( )
protected default

Member Function Documentation

◆ SetUp()

void tesseract::LSTMTrainerTest::SetUp ( )
inline override protected virtual

Reimplemented from testing::Test.

Definition at line 47 of file lstm_test.h.

47 {
48 std::locale::global(std::locale(""));
49 file::MakeTmpdir();
50 }
static void MakeTmpdir()
Definition: include_gunit.h:38

◆ SetupTrainer()

void tesseract::LSTMTrainerTest::SetupTrainer ( const std::string &  network_spec,
const std::string &  model_name,
const std::string &  unicharset_file,
const std::string &  lstmf_file,
bool  recode,
bool  adam,
float  learning_rate,
bool  layer_specific,
const std::string &  kLang 
)
inline protected

Definition at line 68 of file lstm_test.h.

70 {
71 // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
72 std::string unicharset_name = TestDataNameToPath(unicharset_file);
73 UNICHARSET unicharset;
74 ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
75 std::string script_dir = file::JoinPath(LANGDATA_DIR, "");
76 std::vector<std::string> words;
77 EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir, kLang, !recode,
78 words, words, words, false, nullptr, nullptr));
79 std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
80 std::string checkpoint_path = model_path + "_checkpoint";
81 trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
82 trainer_->InitCharSet(
83 file::JoinPath(FLAGS_test_tmpdir, kLang, kLang) + ".traineddata");
84 int net_mode = adam ? NF_ADAM : 0;
85 // Adam needs a higher learning rate, due to not multiplying the effective
86 // rate by 1/(1-momentum).
87 if (adam) {
88 learning_rate *= 20.0f;
89 }
90 if (layer_specific) {
91 net_mode |= NF_LAYER_SPECIFIC_LR;
92 }
93 EXPECT_TRUE(
94 trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95 std::vector<std::string> filenames;
96 filenames.emplace_back(TestDataNameToPath(lstmf_file).c_str());
97 EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
98 LOG(INFO) << "Setup network:" << model_name << "\n";
99 }
@ LOG
@ INFO
Definition: log.h:28
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2043
#define EXPECT_TRUE(condition)
Definition: gtest.h:1982
#define ASSERT_TRUE(condition)
Definition: gtest.h:1990
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
@ NF_ADAM
Definition: network.h:86
@ CS_SEQUENTIAL
Definition: imagedata.h:49
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
std::unique_ptr< LSTMTrainer > trainer_
Definition: lstm_test.h:180
std::string TestDataNameToPath(const std::string &name)
Definition: lstm_test.h:53

◆ SetupTrainerEng()

void tesseract::LSTMTrainerTest::SetupTrainerEng ( const std::string &  network_spec,
const std::string &  model_name,
bool  recode,
bool  adam 
)
inline protected

Definition at line 63 of file lstm_test.h.

64 {
65 SetupTrainer(network_spec, model_name, "eng/eng.unicharset", "eng.Arial.exp0.lstmf", recode,
66 adam, 5e-4, false, "eng");
67 }
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)
Definition: lstm_test.h:68

◆ TessDataNameToPath()

std::string tesseract::LSTMTrainerTest::TessDataNameToPath ( const std::string &  name)
inline protected

Definition at line 56 of file lstm_test.h.

56 {
57 return file::JoinPath(TESSDATA_DIR, "" + name);
58 }

◆ TestDataNameToPath()

std::string tesseract::LSTMTrainerTest::TestDataNameToPath ( const std::string &  name)
inline protected

Definition at line 53 of file lstm_test.h.

53 {
54 return file::JoinPath(TESTDATA_DIR, "" + name);
55 }

◆ TestEncodeDecode()

void tesseract::LSTMTrainerTest::TestEncodeDecode ( const std::string &  lang,
const std::string &  str,
bool  recode 
)
inline protected

Definition at line 163 of file lstm_test.h.

163 {
164 std::string unicharset_name = lang + "/" + lang + ".unicharset";
165 std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
166 SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name, lstmf_name, recode, true,
167 5e-4f, true, lang);
168 std::vector<int> labels;
169 EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
170 std::string decoded = trainer_->DecodeLabels(labels);
171 std::string decoded_str(&decoded[0], decoded.length());
172 EXPECT_EQ(str, decoded_str);
173 }

◆ TestEncodeDecodeBoth()

void tesseract::LSTMTrainerTest::TestEncodeDecodeBoth ( const std::string &  lang,
const std::string &  str 
)
inline protected

Definition at line 175 of file lstm_test.h.

175 {
176 TestEncodeDecode(lang, str, false);
177 TestEncodeDecode(lang, str, true);
178 }
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
Definition: lstm_test.h:163

◆ TestingNameToPath()

std::string tesseract::LSTMTrainerTest::TestingNameToPath ( const std::string &  name)
inline protected

Definition at line 59 of file lstm_test.h.

59 {
60 return file::JoinPath(TESTING_DIR, "" + name);
61 }

◆ TestIntMode()

double tesseract::LSTMTrainerTest::TestIntMode ( int  test_iterations)
inline protected

Definition at line 148 of file lstm_test.h.

148 {
149 std::vector<char> trainer_data;
150 EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, *trainer_, &trainer_data));
151 // Get the error on the next few iterations in float mode.
152 double float_err = TestIterations(test_iterations);
153 // Restore the dump, convert to int and test error on that.
154 EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, *trainer_));
155 trainer_->ConvertToInt();
156 double int_err = TestIterations(test_iterations);
157 EXPECT_LT(int_err, float_err + 1.0);
158 return int_err - float_err;
159 }
#define EXPECT_LT(val1, val2)
Definition: gtest.h:2049
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:62
double TestIterations(int max_iterations)
Definition: lstm_test.h:126

◆ TestIterations()

double tesseract::LSTMTrainerTest::TestIterations ( int  max_iterations)
inline protected

Definition at line 126 of file lstm_test.h.

126 {
127 CHECK_GT(max_iterations, 0);
128 int iteration = trainer_->sample_iteration();
129 double mean_error = 0.0;
130 int error_count = 0;
131 while (error_count < max_iterations) {
132 const ImageData &trainingdata =
133 *trainer_->mutable_training_data()->GetPageBySerial(iteration);
134 NetworkIO fwd_outputs, targets;
135 if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) != UNENCODABLE) {
136 mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
137 ++error_count;
138 }
139 trainer_->SetIteration(++iteration);
140 }
141 mean_error *= 100.0 / max_iterations;
142 LOG(INFO) << "Tester error rate = " << mean_error << "\n";
143 return mean_error;
144 }
#define CHECK_GT(test, value)
Definition: include_gunit.h:81
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:45

◆ TrainIterations()

double tesseract::LSTMTrainerTest::TrainIterations ( int  max_iterations)
inline protected

Definition at line 101 of file lstm_test.h.

101 {
102 int iteration = trainer_->training_iteration();
103 int iteration_limit = iteration + max_iterations;
104 double best_error = 100.0;
105 do {
106 std::stringstream log_str;
107 int target_iteration = iteration + kBatchIterations;
108 // Train a few.
109 double mean_error = 0.0;
110 while (iteration < target_iteration && iteration < iteration_limit) {
111 trainer_->TrainOnLine(trainer_.get(), false);
112 iteration = trainer_->training_iteration();
113 mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
114 }
115 trainer_->MaintainCheckpoints(nullptr, log_str);
116 iteration = trainer_->training_iteration();
117 mean_error *= 100.0 / kBatchIterations;
118 if (mean_error < best_error) {
119 best_error = mean_error;
120 }
121 } while (iteration < iteration_limit);
122 LOG(INFO) << "Trainer error rate = " << best_error << "\n";
123 return best_error;
124 }
const int kBatchIterations
Definition: lstm_test.h:36

Member Data Documentation

◆ trainer_

std::unique_ptr<LSTMTrainer> tesseract::LSTMTrainerTest::trainer_
protected

Definition at line 180 of file lstm_test.h.


The documentation for this class was generated from the following file: