tesseract v5.3.3.20231005
lang_model_test.cc
Go to the documentation of this file.
1// (C) Copyright 2017, Google Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5// http://www.apache.org/licenses/LICENSE-2.0
6// Unless required by applicable law or agreed to in writing, software
7// distributed under the License is distributed on an "AS IS" BASIS,
8// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9// See the License for the specific language governing permissions and
10// limitations under the License.
11
12#include <string> // for std::string
13
14#include "gmock/gmock.h" // for testing::ElementsAreArray
15
16#include "include_gunit.h"
17#include "lang_model_helpers.h"
18#include "log.h" // for LOG
19#include "lstmtrainer.h"
21
22namespace tesseract {
23
24std::string TestDataNameToPath(const std::string &name) {
25 return file::JoinPath(TESTING_DIR, name);
26}
27
28// This is an integration test that verifies that CombineLangModel works to
29// the extent that an LSTMTrainer can be initialized with the result, and it
30// can encode strings. More importantly, the test verifies that adding an extra
31// character to the unicharset does not change the encoding of strings.
32TEST(LangModelTest, AddACharacter) {
33 constexpr char kTestString[] = "Simple ASCII string to encode !@#$%&";
34 constexpr char kTestStringRupees[] = "ASCII string with Rupee symbol ₹";
35 // Setup the arguments.
36 std::string script_dir = LANGDATA_DIR;
37 std::string eng_dir = file::JoinPath(script_dir, "eng");
38 std::string unicharset_path = TestDataNameToPath("eng_beam.unicharset");
39 UNICHARSET unicharset;
40 EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
41 std::string version_str = "TestVersion";
43 std::string output_dir = FLAGS_test_tmpdir;
44 LOG(INFO) << "Output dir=" << output_dir << "\n";
45 std::string lang1 = "eng";
46 bool pass_through_recoder = false;
47 // If these reads fail, we get a warning message and an empty list of words.
48 std::vector<std::string> words = split(ReadFile(file::JoinPath(eng_dir, "eng.wordlist")), '\n');
49 EXPECT_GT(words.size(), 0);
50 std::vector<std::string> puncs = split(ReadFile(file::JoinPath(eng_dir, "eng.punc")), '\n');
51 EXPECT_GT(puncs.size(), 0);
52 std::vector<std::string> numbers = split(ReadFile(file::JoinPath(eng_dir, "eng.numbers")), '\n');
53 EXPECT_GT(numbers.size(), 0);
54 bool lang_is_rtl = false;
55 // Generate the traineddata file.
56 EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang1,
57 pass_through_recoder, words, puncs, numbers, lang_is_rtl, nullptr,
58 nullptr));
59 // Init a trainer with it, and encode kTestString.
60 std::string traineddata1 = file::JoinPath(output_dir, lang1, lang1) + ".traineddata";
61 LSTMTrainer trainer1;
62 trainer1.InitCharSet(traineddata1);
63 std::vector<int> labels1;
64 EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
65 std::string test1_decoded = trainer1.DecodeLabels(labels1);
66 std::string test1_str(&test1_decoded[0], test1_decoded.length());
67 LOG(INFO) << "Labels1=" << test1_str << "\n";
68
69 // Add a new character to the unicharset and try again.
70 int size_before = unicharset.size();
71 unicharset.unichar_insert("₹");
72 SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false, &unicharset);
73 EXPECT_EQ(size_before + 1, unicharset.size());
74 // Generate the traineddata file.
75 std::string lang2 = "extended";
76 EXPECT_EQ(EXIT_SUCCESS, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang2,
77 pass_through_recoder, words, puncs, numbers, lang_is_rtl,
78 nullptr, nullptr));
79 // Init a trainer with it, and encode kTestString.
80 std::string traineddata2 = file::JoinPath(output_dir, lang2, lang2) + ".traineddata";
81 LSTMTrainer trainer2;
82 trainer2.InitCharSet(traineddata2);
83 std::vector<int> labels2;
84 EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
85 std::string test2_decoded = trainer2.DecodeLabels(labels2);
86 std::string test2_str(&test2_decoded[0], test2_decoded.length());
87 LOG(INFO) << "Labels2=" << test2_str << "\n";
88 // encode kTestStringRupees.
89 std::vector<int> labels3;
90 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
91 std::string test3_decoded = trainer2.DecodeLabels(labels3);
92 std::string test3_str(&test3_decoded[0], test3_decoded.length());
93 LOG(INFO) << "labels3=" << test3_str << "\n";
94 // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
95 // Since Tensor Flow's CTC implementation insists on having the null be the
96 // last label, and we want to be compatible, null has to be renumbered when
97 // we add a class.
98 int null1 = trainer1.null_char();
99 int null2 = trainer2.null_char();
100 EXPECT_EQ(null1 + 1, null2);
101 std::vector<int> labels1_v(labels1.size());
102 for (unsigned i = 0; i < labels1.size(); ++i) {
103 if (labels1[i] == null1) {
104 labels1_v[i] = null2;
105 } else {
106 labels1_v[i] = labels1[i];
107 }
108 }
109 EXPECT_THAT(labels1_v, testing::ElementsAreArray(&labels2[0], labels2.size()));
110 // To make sure we we are not cheating somehow, we can now encode the Rupee
111 // symbol, which we could not do before.
112 EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
113 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
114}
115
116// Same as above test, for hin instead of eng
117TEST(LangModelTest, AddACharacterHindi) {
118 constexpr char kTestString[] = "हिन्दी में एक लाइन लिखें";
119 constexpr char kTestStringRupees[] = "हिंदी में रूपये का चिन्ह प्रयोग करें ₹१००.००";
120 // Setup the arguments.
121 std::string script_dir = LANGDATA_DIR;
122 std::string hin_dir = file::JoinPath(script_dir, "hin");
123 std::string unicharset_path = TestDataNameToPath("hin_beam.unicharset");
124 UNICHARSET unicharset;
125 EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
126 std::string version_str = "TestVersion";
128 std::string output_dir = FLAGS_test_tmpdir;
129 LOG(INFO) << "Output dir=" << output_dir << "\n";
130 std::string lang1 = "hin";
131 bool pass_through_recoder = false;
132 // If these reads fail, we get a warning message and an empty list of words.
133 std::vector<std::string> words = split(ReadFile(file::JoinPath(hin_dir, "hin.wordlist")), '\n');
134 EXPECT_GT(words.size(), 0);
135 std::vector<std::string> puncs = split(ReadFile(file::JoinPath(hin_dir, "hin.punc")), '\n');
136 EXPECT_GT(puncs.size(), 0);
137 std::vector<std::string> numbers = split(ReadFile(file::JoinPath(hin_dir, "hin.numbers")), '\n');
138 EXPECT_GT(numbers.size(), 0);
139 bool lang_is_rtl = false;
140 // Generate the traineddata file.
141 EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang1,
142 pass_through_recoder, words, puncs, numbers, lang_is_rtl, nullptr,
143 nullptr));
144 // Init a trainer with it, and encode kTestString.
145 std::string traineddata1 = file::JoinPath(output_dir, lang1, lang1) + ".traineddata";
146 LSTMTrainer trainer1;
147 trainer1.InitCharSet(traineddata1);
148 std::vector<int> labels1;
149 EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
150 std::string test1_decoded = trainer1.DecodeLabels(labels1);
151 std::string test1_str(&test1_decoded[0], test1_decoded.length());
152 LOG(INFO) << "Labels1=" << test1_str << "\n";
153
154 // Add a new character to the unicharset and try again.
155 int size_before = unicharset.size();
156 unicharset.unichar_insert("₹");
157 SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false, &unicharset);
158 EXPECT_EQ(size_before + 1, unicharset.size());
159 // Generate the traineddata file.
160 std::string lang2 = "extendedhin";
161 EXPECT_EQ(EXIT_SUCCESS, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang2,
162 pass_through_recoder, words, puncs, numbers, lang_is_rtl,
163 nullptr, nullptr));
164 // Init a trainer with it, and encode kTestString.
165 std::string traineddata2 = file::JoinPath(output_dir, lang2, lang2) + ".traineddata";
166 LSTMTrainer trainer2;
167 trainer2.InitCharSet(traineddata2);
168 std::vector<int> labels2;
169 EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
170 std::string test2_decoded = trainer2.DecodeLabels(labels2);
171 std::string test2_str(&test2_decoded[0], test2_decoded.length());
172 LOG(INFO) << "Labels2=" << test2_str << "\n";
173 // encode kTestStringRupees.
174 std::vector<int> labels3;
175 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
176 std::string test3_decoded = trainer2.DecodeLabels(labels3);
177 std::string test3_str(&test3_decoded[0], test3_decoded.length());
178 LOG(INFO) << "labels3=" << test3_str << "\n";
179 // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
180 // Since Tensor Flow's CTC implementation insists on having the null be the
181 // last label, and we want to be compatible, null has to be renumbered when
182 // we add a class.
183 int null1 = trainer1.null_char();
184 int null2 = trainer2.null_char();
185 EXPECT_EQ(null1 + 1, null2);
186 std::vector<int> labels1_v(labels1.size());
187 for (unsigned i = 0; i < labels1.size(); ++i) {
188 if (labels1[i] == null1) {
189 labels1_v[i] = null2;
190 } else {
191 labels1_v[i] = labels1[i];
192 }
193 }
194 EXPECT_THAT(labels1_v, testing::ElementsAreArray(&labels2[0], labels2.size()));
195 // To make sure we we are not cheating somehow, we can now encode the Rupee
196 // symbol, which we could not do before.
197 EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
198 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
199}
200
201} // namespace tesseract
@ LOG
@ INFO
Definition: log.h:28
#define EXPECT_THAT(value, matcher)
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2043
#define EXPECT_GT(val1, val2)
Definition: gtest.h:2053
#define EXPECT_TRUE(condition)
Definition: gtest.h:1982
#define EXPECT_FALSE(condition)
Definition: gtest.h:1986
void SetupBasicProperties(bool report_errors, bool decompose, UNICHARSET *unicharset)
std::string TestDataNameToPath(const std::string &name)
std::string ReadFile(const std::string &filename, FileReader reader)
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:43
TEST(TesseractInstanceTest, TestMultipleTessInstances)
void unichar_insert(const char *const unichar_repr, OldUncleanUnichars old_style)
Definition: unicharset.cpp:654
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
size_t size() const
Definition: unicharset.h:355
std::string DecodeLabels(const std::vector< int > &labels)
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:254
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:100
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65