tesseract v5.3.3.20231005
lstmtrainer_test.cc
Go to the documentation of this file.
1// (C) Copyright 2017, Google Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5// http://www.apache.org/licenses/LICENSE-2.0
6// Unless required by applicable law or agreed to in writing, software
7// distributed under the License is distributed on an "AS IS" BASIS,
8// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9// See the License for the specific language governing permissions and
10// limitations under the License.
11
12#include <allheaders.h>
13#include <tesseract/baseapi.h>
14#include "lstm_test.h"
15
16namespace tesseract {
17
18TEST_F(LSTMTrainerTest, EncodesEng) {
19 TestEncodeDecodeBoth("eng", "The quick brown 'fox' jumps over: the lazy dog!");
20}
21
22TEST_F(LSTMTrainerTest, EncodesKan) {
23 TestEncodeDecodeBoth("kan", "ಫ್ರಬ್ರವರಿ ತತ್ವಾಂಶಗಳೆಂದರೆ ಮತ್ತು ಜೊತೆಗೆ ಕ್ರಮವನ್ನು");
24}
25
26TEST_F(LSTMTrainerTest, EncodesKor) {
27 TestEncodeDecodeBoth("kor", "이는 것으로 다시 넣을 수는 있지만 선택의 의미는");
28}
29
31 LSTMTrainer fra_trainer;
32 fra_trainer.InitCharSet(TestDataNameToPath("fra/fra.traineddata"));
33 LSTMTrainer deu_trainer;
34 deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
35 // A string that uses characters common to French and German.
36 std::string kTestStr = "The quick brown 'fox' jumps over: the lazy dog!";
37 std::vector<int> deu_labels;
38 EXPECT_TRUE(deu_trainer.EncodeString(kTestStr.c_str(), &deu_labels));
39 // The french trainer cannot decode them correctly.
40 std::string badly_decoded = fra_trainer.DecodeLabels(deu_labels);
41 std::string bad_str(&badly_decoded[0], badly_decoded.length());
42 LOG(INFO) << "bad_str fra=" << bad_str << "\n";
43 EXPECT_NE(kTestStr, bad_str);
44 // Encode the string as fra.
45 std::vector<int> fra_labels;
46 EXPECT_TRUE(fra_trainer.EncodeString(kTestStr.c_str(), &fra_labels));
47 // Use the mapper to compute what the labels are as deu.
48 std::vector<int> mapping =
49 fra_trainer.MapRecoder(deu_trainer.GetUnicharset(), deu_trainer.GetRecoder());
50 std::vector<int> mapped_fra_labels(fra_labels.size(), -1);
51 for (unsigned i = 0; i < fra_labels.size(); ++i) {
52 mapped_fra_labels[i] = mapping[fra_labels[i]];
53 EXPECT_NE(-1, mapped_fra_labels[i]) << "i=" << i << ", ch=" << kTestStr[i];
54 EXPECT_EQ(mapped_fra_labels[i], deu_labels[i])
55 << "i=" << i << ", ch=" << kTestStr[i] << " has deu label=" << deu_labels[i]
56 << ", but mapped to " << mapped_fra_labels[i];
57 }
58 // The german trainer can now decode them correctly.
59 std::string decoded = deu_trainer.DecodeLabels(mapped_fra_labels);
60 std::string ok_str(&decoded[0], decoded.length());
61 LOG(INFO) << "ok_str deu=" << ok_str << "\n";
62 EXPECT_EQ(kTestStr, ok_str);
63}
64
65// Tests that the actual fra model can be converted to the deu character set
66// and still read an eng image with 100% accuracy.
67TEST_F(LSTMTrainerTest, ConvertModel) {
68 // Setup a trainer with a deu charset.
69 LSTMTrainer deu_trainer;
70 deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
71 // Load the fra traineddata, strip out the model, and save to a tmp file.
73 std::string fra_data = file::JoinPath(TESSDATA_DIR "_best", "fra.traineddata");
74 CHECK(mgr.Init(fra_data.c_str()));
75 LOG(INFO) << "Load " << fra_data << "\n";
77 std::string model_path = file::JoinPath(FLAGS_test_tmpdir, "fra.lstm");
78 CHECK(mgr.ExtractToFile(model_path.c_str()));
79 LOG(INFO) << "Extract " << model_path << "\n";
80 // Load the fra model into the deu_trainer, and save the converted model.
81 CHECK(deu_trainer.TryLoadingCheckpoint(model_path.c_str(), fra_data.c_str()));
82 LOG(INFO) << "Checkpoint load for " << model_path << " and " << fra_data << "\n";
83 std::string deu_data = file::JoinPath(FLAGS_test_tmpdir, "deu.traineddata");
84 CHECK(deu_trainer.SaveTraineddata(deu_data.c_str()));
85 LOG(INFO) << "Save " << deu_data << "\n";
86 // Now run the saved model on phototest. (See BasicTesseractTest in
87 // baseapi_test.cc).
88 TessBaseAPI api;
89 api.Init(FLAGS_test_tmpdir, "deu", tesseract::OEM_LSTM_ONLY);
90 Image src_pix = pixRead(TestingNameToPath("phototest.tif").c_str());
91 CHECK(src_pix);
92 api.SetImage(src_pix);
93 std::unique_ptr<char[]> result(api.GetUTF8Text());
94 std::string truth_text;
96 file::GetContents(TestingNameToPath("phototest.gold.txt"), &truth_text, file::Defaults()));
97
98 EXPECT_STREQ(truth_text.c_str(), result.get());
99 src_pix.destroy();
100}
101
102} // namespace tesseract
@ LOG
@ INFO
Definition: log.h:28
#define EXPECT_EQ(val1, val2)
Definition: gtest.h:2043
#define EXPECT_NE(val1, val2)
Definition: gtest.h:2045
#define EXPECT_TRUE(condition)
Definition: gtest.h:1982
#define EXPECT_STREQ(s1, s2)
Definition: gtest.h:2112
#define CHECK(condition)
Definition: include_gunit.h:76
#define CHECK_OK(test)
Definition: include_gunit.h:84
std::string TestDataNameToPath(const std::string &name)
TEST_F(EuroText, FastLatinOCR)
int Init(const char *datapath, const char *language, OcrEngineMode mode, char **configs, int configs_size, const std::vector< std::string > *vars_vec, const std::vector< std::string > *vars_values, bool set_only_non_debug_params)
Definition: baseapi.cpp:368
void SetImage(const unsigned char *imagedata, int width, int height, int bytes_per_pixel, int bytes_per_line)
Definition: baseapi.cpp:576
void destroy()
Definition: image.cpp:32
bool ExtractToFile(const char *filename)
bool Init(const char *data_file_name)
const UnicharCompress & GetRecoder() const
std::string DecodeLabels(const std::vector< int > &labels)
const UNICHARSET & GetUnicharset() const
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:254
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:100
bool SaveTraineddata(const char *filename)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
static int Defaults()
Definition: include_gunit.h:61
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
static bool GetContents(const std::string &filename, std::string *out, int)
Definition: include_gunit.h:52