tesseract v5.3.3.20231005
mastertrainer.h
Go to the documentation of this file.
1// Copyright 2010 Google Inc. All Rights Reserved.
2// Author: rays@google.com (Ray Smith)
4// File: mastertrainer.h
5// Description: Trainer to build the MasterClassifier.
6// Author: Ray Smith
7//
8// (C) Copyright 2010, Google Inc.
9// Licensed under the Apache License, Version 2.0 (the "License");
10// you may not use this file except in compliance with the License.
11// You may obtain a copy of the License at
12// http://www.apache.org/licenses/LICENSE-2.0
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
18//
20
21#ifndef TESSERACT_TRAINING_MASTERTRAINER_H_
22#define TESSERACT_TRAINING_MASTERTRAINER_H_
23
24#include "export.h"
25
26#include "classify.h"
27#include "cluster.h"
28#include "elst.h"
29#include "errorcounter.h"
30#include "featdefs.h"
31#include "fontinfo.h"
32#include "indexmapbidi.h"
33#include "intfeaturemap.h"
34#include "intfeaturespace.h"
35#include "intfx.h"
36#include "intmatcher.h"
37#include "params.h"
38#include "shapetable.h"
39#include "trainingsample.h"
40#include "trainingsampleset.h"
41#include "unicharset.h"
42
43namespace tesseract {
44
45class ShapeClassifier;
46
// Simple struct to hold the distance between two shapes during clustering.
// Ordering (operator<) considers only the distance, so sorting a collection
// of ShapeDist puts the closest pair of shapes first.
struct ShapeDist {
  // Default construction yields shape1 == shape2 == 0 and distance == 0,
  // supplied by the member initializers below.
  ShapeDist() = default;
  // Records that shapes s1 and s2 are separated by the given distance.
  ShapeDist(int s1, int s2, float dist) : shape1(s1), shape2(s2), distance(dist) {}

  // Sort operator to sort in ascending order of distance.
  // The shape ids do not take part in the comparison.
  bool operator<(const ShapeDist &other) const {
    return distance < other.distance;
  }

  // Identifiers of the two shapes being compared.
  // NOTE(review): presumably indices into a ShapeTable - confirm at call site.
  int shape1 = 0;
  int shape2 = 0;
  // Distance between shape1 and shape2.
  float distance = 0.0f;
};
61
62// Class to encapsulate training processes that use the TrainingSampleSet.
63// Initially supports shape clustering and mftrainining.
64// Other important features of the MasterTrainer are conditioning the data
65// by outlier elimination, replication with perturbation, and serialization.
66class TESS_COMMON_TRAINING_API MasterTrainer {
67public:
68 MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, bool replicate_samples,
69 int debug_level);
71
72 // Writes to the given file. Returns false in case of error.
73 bool Serialize(FILE *fp) const;
74
75 // Loads an initial unicharset, or sets one up if the file cannot be read.
76 void LoadUnicharset(const char *filename);
77
78 // Sets the feature space definition.
80 feature_space_ = fs;
81 feature_map_.Init(fs);
82 }
83
84 // Reads the samples and their features from the given file,
85 // adding them to the trainer with the font_id from the content of the file.
86 // If verification, then these are verification samples, not training.
87 void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs,
88 bool verification);
89
90 // Adds the given single sample to the trainer, setting the classid
91 // appropriately from the given unichar_str.
92 void AddSample(bool verification, const char *unichar_str, TrainingSample *sample);
93
94 // Loads all pages from the given tif filename and append to page_images_.
95 // Must be called after ReadTrainingSamples, as the current number of images
96 // is used as an offset for page numbers in the samples.
97 void LoadPageImages(const char *filename);
98
99 // Cleans up the samples after initial load from the tr files, and prior to
100 // saving the MasterTrainer:
101 // Remaps fragmented chars if running shape analysis.
102 // Sets up the samples appropriately for class/fontwise access.
103 // Deletes outlier samples.
104 void PostLoadCleanup();
105
106 // Gets the samples ready for training. Use after both
107 // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
108 // Re-indexes the features and computes canonical and cloud features.
109 void PreTrainingSetup();
110
111 // Sets up the master_shapes_ table, which tells which fonts should stay
112 // together until they get to a leaf node classifier.
113 void SetupMasterShapes();
114
115 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
116 // fragments and n-grams (all incorrectly segmented characters).
117 // Various training functions may result in incorrectly segmented characters
118 // being added to the unicharset of the main samples, perhaps because they
119 // form a "radical" decomposition of some (Indic) grapheme, or because they
120 // just look the same as a real character (like rn/m)
121 // This function moves all the junk samples, to the main samples_ set, but
122 // desirable junk, being any sample for which the unichar already exists in
123 // the samples_ unicharset gets the unichar-ids re-indexed to match, but
124 // anything else gets re-marked as unichar_id 0 (space character) to identify
125 // it as junk to the error counter.
126 void IncludeJunk();
127
128 // Replicates the samples and perturbs them if the enable_replication_ flag
129 // is set. MUST be used after the last call to OrganizeByFontAndClass on
130 // the training samples, ie after IncludeJunk if it is going to be used, as
131 // OrganizeByFontAndClass will eat the replicated samples into the regular
132 // samples.
133 void ReplicateAndRandomizeSamplesIfRequired();
134
135 // Loads the basic font properties file into fontinfo_table_.
136 // Returns false on failure.
137 bool LoadFontInfo(const char *filename);
138
139 // Loads the xheight font properties file into xheights_.
140 // Returns false on failure.
141 bool LoadXHeights(const char *filename);
142
143 // Reads spacing stats from filename and adds them to fontinfo_table.
144 // Returns false on failure.
145 bool AddSpacingInfo(const char *filename);
146
147 // Returns the font id corresponding to the given font name.
148 // Returns -1 if the font cannot be found.
149 int GetFontInfoId(const char *font_name);
150 // Returns the font_id of the closest matching font name to the given
151 // filename. It is assumed that a substring of the filename will match
152 // one of the fonts. If more than one is matched, the longest is returned.
153 int GetBestMatchingFontInfoId(const char *filename);
154
155 // Returns the filename of the tr file corresponding to the command-line
156 // argument with the given index.
157 const std::string &GetTRFileName(int index) const {
158 return tr_filenames_[index];
159 }
160
161 // Sets up a flat shapetable with one shape per class/font combination.
162 void SetupFlatShapeTable(ShapeTable *shape_table);
163
164 // Sets up a Clusterer for mftraining on a single shape_id.
165 // Call FreeClusterer on the return value after use.
166 CLUSTERER *SetupForClustering(const ShapeTable &shape_table,
167 const FEATURE_DEFS_STRUCT &feature_defs, int shape_id,
168 int *num_samples);
169
170 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
171 // to the given inttemp_file, and the corresponding pffmtable.
172 // The unicharset is the original encoding of graphemes, and shape_set should
173 // match the size of the shape_table, and may possibly be totally fake.
174 void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set,
175 const ShapeTable &shape_table, CLASS_STRUCT *float_classes,
176 const char *inttemp_file, const char *pffmtable_file);
177
178 const UNICHARSET &unicharset() const {
179 return samples_.unicharset();
180 }
182 return &samples_;
183 }
184 const ShapeTable &master_shapes() const {
185 return master_shapes_;
186 }
187
188 // Generates debug output relating to the canonical distance between the
189 // two given UTF8 grapheme strings.
190 void DebugCanonical(const char *unichar_str1, const char *unichar_str2);
191#ifndef GRAPHICS_DISABLED
192 // Debugging for cloud/canonical features.
193 // Displays a Features window containing:
194 // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
195 // displays the canonical features of the char/font combination in red.
196 // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
197 // displays the cloud feature of the char/font combination in green.
198 // The canonical features are drawn first to show which ones have no
199 // matches in the cloud features.
200 // Until the features window is destroyed, each click in the features window
201 // will display the samples that have that feature in a separate window.
202 void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2,
203 int canonical_font);
204#endif // !GRAPHICS_DISABLED
205
206 void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier,
207 ShapeClassifier *old_classifier);
208
209 // Tests the given test_classifier on the internal samples.
210 // See TestClassifier for details.
211 void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples,
212 ShapeClassifier *test_classifier, std::string *report_string);
213 // Tests the given test_classifier on the given samples
214 // error_mode indicates what counts as an error.
215 // report_levels:
216 // 0 = no output.
217 // 1 = bottom-line error rate.
218 // 2 = bottom-line error rate + time.
219 // 3 = font-level error rate + time.
220 // 4 = list of all errors + short classifier debug output on 16 errors.
221 // 5 = list of all errors + short classifier debug output on 25 errors.
222 // If replicate_samples is true, then the test is run on an extended test
223 // sample including replicated and systematically perturbed samples.
224 // If report_string is non-nullptr, a summary of the results for each font
225 // is appended to the report_string.
226 double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples,
227 TrainingSampleSet *samples, ShapeClassifier *test_classifier,
228 std::string *report_string);
229
230 // Returns the average (in some sense) distance between the two given
231 // shapes, which may contain multiple fonts and/or unichars.
232 // This function is public to facilitate testing.
233 float ShapeDistance(const ShapeTable &shapes, int s1, int s2);
234
235private:
236 // Replaces samples that are always fragmented with the corresponding
237 // fragment samples.
238 void ReplaceFragmentedSamples();
239
240 // Runs a hierarchical agglomerative clustering to merge shapes in the given
241 // shape_table, while satisfying the given constraints:
242 // * End with at least min_shapes left in shape_table,
243 // * No shape shall have more than max_shape_unichars in it,
244 // * Don't merge shapes where the distance between them exceeds max_dist.
245 void ClusterShapes(int min_shapes, int max_shape_unichars, float max_dist,
246 ShapeTable *shape_table);
247
248private:
249 NormalizationMode norm_mode_;
250 // Character set we are training for.
251 UNICHARSET unicharset_;
252 // Original feature space. Subspace mapping is contained in feature_map_.
253 IntFeatureSpace feature_space_;
254 TrainingSampleSet samples_;
255 TrainingSampleSet junk_samples_;
256 TrainingSampleSet verify_samples_;
257 // Master shape table defines what fonts stay together until the leaves.
258 ShapeTable master_shapes_;
259 // Flat shape table has each unichar/font id pair in a separate shape.
260 ShapeTable flat_shapes_;
261 // Font metrics gathered from multiple files.
262 FontInfoTable fontinfo_table_;
263 // Array of xheights indexed by font ids in fontinfo_table_;
264 std::vector<int32_t> xheights_;
265
266 // Non-serialized data initialized by other means or used temporarily
267 // during loading of training samples.
268 // Number of different class labels in unicharset_.
269 int charsetsize_;
270 // Flag to indicate that we are running shape analysis and need fragments
271 // fixing.
272 bool enable_shape_analysis_;
273 // Flag to indicate that sample replication is required.
274 bool enable_replication_;
275 // Array of classids of fragments that replace the correctly segmented chars.
276 int *fragments_;
277 // Classid of previous correctly segmented sample that was added.
278 int prev_unichar_id_;
279 // Debug output control.
280 int debug_level_;
281 // Feature map used to construct reduced feature spaces for compact
282 // classifiers.
283 IntFeatureMap feature_map_;
284 // Vector of Pix pointers used for classifiers that need the image.
285 // Indexed by page_num_ in the samples.
286 // These images are owned by the trainer and need to be pixDestroyed.
287 std::vector<Image > page_images_;
288 // Vector of filenames of loaded tr files.
289 std::vector<std::string> tr_filenames_;
290};
291
292} // namespace tesseract.
293
294#endif // TESSERACT_TRAINING_MASTERTRAINER_H_
void ReadTrainingSamples(const FEATURE_DEFS_STRUCT &feature_definitions, const char *feature_name, int max_samples, UNICHARSET *unicharset, FILE *file, LIST *training_samples)
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:236
FEATURE_DEFS_STRUCT feature_defs
NormalizationMode
Definition: normalis.h:46
void Init(uint8_t xbuckets, uint8_t ybuckets, uint8_t thetabuckets)
bool operator<(const ShapeDist &other) const
Definition: mastertrainer.h:53
ShapeDist(int s1, int s2, float dist)
Definition: mastertrainer.h:50
const std::string & GetTRFileName(int index) const
TrainingSampleSet * GetSamples()
void SetFeatureSpace(const IntFeatureSpace &fs)
Definition: mastertrainer.h:79
const UNICHARSET & unicharset() const
const ShapeTable & master_shapes() const