All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
tess_lang_model.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: tess_lang_model.cpp
3  * Description: Implementation of the Tesseract Language Model Class
4  * Author: Ahmad Abdulkader
5  * Created: 2008
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
// The TessLangModel class abstracts the Tesseract language model. It inherits
// from the LangModel class. The Tesseract language model encompasses several
// Dawgs (words from training data, punctuation, numbers, document words).
// On top of this, Cube adds an OOD state machine.
// The class provides methods to traverse the language model in a generative
// fashion: given any node in the DAWG, the language model can generate a list
// of children (or fan-out) edges.
27 
28 #include <string>
29 #include <vector>
30 
31 #include "char_samp.h"
32 #include "cube_utils.h"
33 #include "dict.h"
34 #include "tesseractclass.h"
35 #include "tess_lang_model.h"
36 #include "tessdatamanager.h"
37 #include "unicharset.h"
38 
39 namespace tesseract {
40 // max fan-out (used for preallocation). Initialized here, but modified by
41 // constructor
42 int TessLangModel::max_edge_ = 4096;
43 
44 // Language model extra State machines
45 const Dawg *TessLangModel::ood_dawg_ = reinterpret_cast<Dawg *>(DAWG_OOD);
46 const Dawg *TessLangModel::number_dawg_ = reinterpret_cast<Dawg *>(DAWG_NUMBER);
47 
48 // number state machine
49 const int TessLangModel::num_state_machine_[kStateCnt][kNumLiteralCnt] = {
50  {0, 1, 1, NUM_TRM, NUM_TRM},
51  {NUM_TRM, 1, 1, 3, 2},
52  {NUM_TRM, NUM_TRM, 1, NUM_TRM, 2},
53  {NUM_TRM, NUM_TRM, 3, NUM_TRM, 2},
54 };
55 const int TessLangModel::num_max_repeat_[kStateCnt] = {3, 32, 8, 3};
56 
57 // thresholds and penalties
58 int TessLangModel::max_ood_shape_cost_ = CubeUtils::Prob2Cost(1e-4);
59 
// Builds the language model for recognition context cntxt:
// - caches the context's case flag;
// - parses the <lang>.cube.lm contents passed in lm_params;
// - if the tessdata file contains a cube unicharset, creates word_dawgs_
//   (plus the system dawg when load_system_dawg is set and that component
//   is present).
// Otherwise word_dawgs_ is NULL, and NumDawgs()/GetDawg() fall back to the
// Tesseract dictionary's dawgs.
// NOTE(review): data_file_path is not referenced in this body, and the bool
// returned by LoadLangModelElements is discarded, so a malformed .lm file is
// not reported from here.
// NOTE(review): the listing this was extracted from skips original line 77
// (an argument of the SquishedDawg constructor call below); restore it from
// the upstream source before compiling.
60 TessLangModel::TessLangModel(const string &lm_params,
61  const string &data_file_path,
62  bool load_system_dawg,
63  TessdataManager *tessdata_manager,
64  CubeRecoContext *cntxt) {
65  cntxt_ = cntxt;
66  has_case_ = cntxt_->HasCase();
67  // Load the rest of the language model elements from file
68  LoadLangModelElements(lm_params);
69  // Load word_dawgs_ if needed.
70  if (tessdata_manager->SeekToStart(TESSDATA_CUBE_UNICHARSET)) {
71  word_dawgs_ = new DawgVector();
72  if (load_system_dawg &&
73  tessdata_manager->SeekToStart(TESSDATA_CUBE_SYSTEM_DAWG)) {
// Presumably SeekToStart leaves GetDataFilePtr() positioned at the system
// dawg component so SquishedDawg reads it from there -- confirm.
74  // The last parameter to the Dawg constructor (the debug level) is set to
75  // false, until Cube has a way to express its preferred debug level.
76  *word_dawgs_ += new SquishedDawg(tessdata_manager->GetDataFilePtr(),
78  cntxt_->Lang().c_str(),
79  SYSTEM_DAWG_PERM, false);
80  }
81  } else {
// No cube unicharset: leave word_dawgs_ NULL so NumDawgs()/GetDawg() use the
// Tesseract dictionary's dawgs instead.
82  word_dawgs_ = NULL;
83  }
84 }
85 
86 // Cleanup an edge array
87 void TessLangModel::FreeEdges(int edge_cnt, LangModEdge **edge_array) {
88  if (edge_array != NULL) {
89  for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
90  if (edge_array[edge_idx] != NULL) {
91  delete edge_array[edge_idx];
92  }
93  }
94  delete []edge_array;
95  }
96 }
97 
98 // Determines if a sequence of 32-bit chars is valid in this language model
99 // starting from the specified edge. If the eow_flag is ON, also checks for
100 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
101 // edge
102 bool TessLangModel::IsValidSequence(LangModEdge *edge,
103  const char_32 *sequence,
104  bool eow_flag,
105  LangModEdge **final_edge) {
106  // get the edges emerging from this edge
107  int edge_cnt = 0;
108  LangModEdge **edge_array = GetEdges(NULL, edge, &edge_cnt);
109 
110  // find the 1st char in the sequence in the children
111  for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
112  // found a match
113  if (sequence[0] == edge_array[edge_idx]->EdgeString()[0]) {
114  // if this is the last char
115  if (sequence[1] == 0) {
116  // succeed if we are in prefix mode or this is a terminal edge
117  if (eow_flag == false || edge_array[edge_idx]->IsEOW()) {
118  if (final_edge != NULL) {
119  (*final_edge) = edge_array[edge_idx];
120  edge_array[edge_idx] = NULL;
121  }
122 
123  FreeEdges(edge_cnt, edge_array);
124  return true;
125  }
126  } else {
127  // not the last char continue checking
128  if (IsValidSequence(edge_array[edge_idx], sequence + 1, eow_flag,
129  final_edge) == true) {
130  FreeEdges(edge_cnt, edge_array);
131  return true;
132  }
133  }
134  }
135  }
136 
137  FreeEdges(edge_cnt, edge_array);
138  return false;
139 }
140 
141 // Determines if a sequence of 32-bit chars is valid in this language model
142 // starting from the root. If the eow_flag is ON, also checks for
143 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
144 // edge
145 bool TessLangModel::IsValidSequence(const char_32 *sequence, bool eow_flag,
146  LangModEdge **final_edge) {
147  if (final_edge != NULL) {
148  (*final_edge) = NULL;
149  }
150 
151  return IsValidSequence(NULL, sequence, eow_flag, final_edge);
152 }
153 
155  return lead_punc_.find(ch) != string::npos;
156 }
157 
159  return trail_punc_.find(ch) != string::npos;
160 }
161 
163  return digits_.find(ch) != string::npos;
164 }
165 
166 // The general fan-out generation function. Returns the list of edges
167 // fanning-out of the specified edge and their count. If an AltList is
168 // specified, only the class-ids with a minimum cost are considered
170  LangModEdge *lang_mod_edge,
171  int *edge_cnt) {
172  TessLangModEdge *tess_lm_edge =
173  reinterpret_cast<TessLangModEdge *>(lang_mod_edge);
174  LangModEdge **edge_array = NULL;
175  (*edge_cnt) = 0;
176 
177  // if we are starting from the root, we'll instantiate every DAWG
178  // and get the all the edges that emerge from the root
179  if (tess_lm_edge == NULL) {
180  // get DAWG count from Tesseract
181  int dawg_cnt = NumDawgs();
182  // preallocate the edge buffer
183  (*edge_cnt) = dawg_cnt * max_edge_;
184  edge_array = new LangModEdge *[(*edge_cnt)];
185  if (edge_array == NULL) {
186  return NULL;
187  }
188 
189  for (int dawg_idx = (*edge_cnt) = 0; dawg_idx < dawg_cnt; dawg_idx++) {
190  const Dawg *curr_dawg = GetDawg(dawg_idx);
191  // Only look through word Dawgs (since there is a special way of
192  // handling numbers and punctuation).
193  if (curr_dawg->type() == DAWG_TYPE_WORD) {
194  (*edge_cnt) += FanOut(alt_list, curr_dawg, 0, 0, NULL, true,
195  edge_array + (*edge_cnt));
196  }
197  } // dawg
198 
199  (*edge_cnt) += FanOut(alt_list, number_dawg_, 0, 0, NULL, true,
200  edge_array + (*edge_cnt));
201 
202  // OOD: it is intentionally not added to the list to make sure it comes
203  // at the end
204  (*edge_cnt) += FanOut(alt_list, ood_dawg_, 0, 0, NULL, true,
205  edge_array + (*edge_cnt));
206 
207  // set the root flag for all root edges
208  for (int edge_idx = 0; edge_idx < (*edge_cnt); edge_idx++) {
209  edge_array[edge_idx]->SetRoot(true);
210  }
211  } else { // not starting at the root
212  // preallocate the edge buffer
213  (*edge_cnt) = max_edge_;
214  // allocate memory for edges
215  edge_array = new LangModEdge *[(*edge_cnt)];
216  if (edge_array == NULL) {
217  return NULL;
218  }
219 
220  // get the FanOut edges from the root of each dawg
221  (*edge_cnt) = FanOut(alt_list,
222  tess_lm_edge->GetDawg(),
223  tess_lm_edge->EndEdge(), tess_lm_edge->EdgeMask(),
224  tess_lm_edge->EdgeString(), false, edge_array);
225  }
226  return edge_array;
227 }
228 
229 // generate edges from an NULL terminated string
230 // (used for punctuation, operators and digits)
231 int TessLangModel::Edges(const char *strng, const Dawg *dawg,
232  EDGE_REF edge_ref, EDGE_REF edge_mask,
233  LangModEdge **edge_array) {
234  int edge_idx,
235  edge_cnt = 0;
236 
237  for (edge_idx = 0; strng[edge_idx] != 0; edge_idx++) {
238  int class_id = cntxt_->CharacterSet()->ClassID((char_32)strng[edge_idx]);
239  if (class_id != INVALID_UNICHAR_ID) {
240  // create an edge object
241  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg, edge_ref,
242  class_id);
243  if (edge_array[edge_cnt] == NULL) {
244  return 0;
245  }
246 
247  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
248  SetEdgeMask(edge_mask);
249  edge_cnt++;
250  }
251  }
252 
253  return edge_cnt;
254 }
255 
256 // generate OOD edges
257 int TessLangModel::OODEdges(CharAltList *alt_list, EDGE_REF edge_ref,
258  EDGE_REF edge_ref_mask, LangModEdge **edge_array) {
259  int class_cnt = cntxt_->CharacterSet()->ClassCount();
260  int edge_cnt = 0;
261  for (int class_id = 0; class_id < class_cnt; class_id++) {
262  // produce an OOD edge only if the cost of the char is low enough
263  if ((alt_list == NULL ||
264  alt_list->ClassCost(class_id) <= max_ood_shape_cost_)) {
265  // create an edge object
266  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, class_id);
267  if (edge_array[edge_cnt] == NULL) {
268  return 0;
269  }
270 
271  edge_cnt++;
272  }
273  }
274 
275  return edge_cnt;
276 }
277 
278 // computes and returns the edges that fan out of an edge ref
279 int TessLangModel::FanOut(CharAltList *alt_list, const Dawg *dawg,
280  EDGE_REF edge_ref, EDGE_REF edge_mask,
281  const char_32 *str, bool root_flag,
282  LangModEdge **edge_array) {
283  int edge_cnt = 0;
284  NODE_REF next_node = NO_EDGE;
285 
286  // OOD
287  if (dawg == reinterpret_cast<Dawg *>(DAWG_OOD)) {
288  if (ood_enabled_ == true) {
289  return OODEdges(alt_list, edge_ref, edge_mask, edge_array);
290  } else {
291  return 0;
292  }
293  } else if (dawg == reinterpret_cast<Dawg *>(DAWG_NUMBER)) {
294  // Number
295  if (numeric_enabled_ == true) {
296  return NumberEdges(edge_ref, edge_array);
297  } else {
298  return 0;
299  }
300  } else if (IsTrailingPuncEdge(edge_mask)) {
301  // a TRAILING PUNC MASK, generate more trailing punctuation and return
302  if (punc_enabled_ == true) {
303  EDGE_REF trail_cnt = TrailingPuncCount(edge_mask);
304  return Edges(trail_punc_.c_str(), dawg, edge_ref,
305  TrailingPuncEdgeMask(trail_cnt + 1), edge_array);
306  } else {
307  return 0;
308  }
309  } else if (root_flag == true || edge_ref == 0) {
310  // Root, generate leading punctuation and continue
311  if (root_flag) {
312  if (punc_enabled_ == true) {
313  edge_cnt += Edges(lead_punc_.c_str(), dawg, 0, LEAD_PUNC_EDGE_REF_MASK,
314  edge_array);
315  }
316  }
317  next_node = 0;
318  } else {
319  // a node in the main trie
320  bool eow_flag = (dawg->end_of_word(edge_ref) != 0);
321 
322  // for EOW
323  if (eow_flag == true) {
324  // generate trailing punctuation
325  if (punc_enabled_ == true) {
326  edge_cnt += Edges(trail_punc_.c_str(), dawg, edge_ref,
327  TrailingPuncEdgeMask((EDGE_REF)1), edge_array);
328  // generate a hyphen and go back to the root
329  edge_cnt += Edges("-/", dawg, 0, 0, edge_array + edge_cnt);
330  }
331  }
332 
333  // advance node
334  next_node = dawg->next_node(edge_ref);
335  if (next_node == 0 || next_node == NO_EDGE) {
336  return edge_cnt;
337  }
338  }
339 
340  // now get all the emerging edges if word list is enabled
341  if (word_list_enabled_ == true && next_node != NO_EDGE) {
342  // create child edges
343  int child_edge_cnt =
344  TessLangModEdge::CreateChildren(cntxt_, dawg, next_node,
345  edge_array + edge_cnt);
346  int strt_cnt = edge_cnt;
347 
348  // set the edge mask
349  for (int child = 0; child < child_edge_cnt; child++) {
350  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt++])->
351  SetEdgeMask(edge_mask);
352  }
353 
354  // if we are at the root, create upper case forms of these edges if possible
355  if (root_flag == true) {
356  for (int child = 0; child < child_edge_cnt; child++) {
357  TessLangModEdge *child_edge =
358  reinterpret_cast<TessLangModEdge *>(edge_array[strt_cnt + child]);
359 
360  if (has_case_ == true) {
361  const char_32 *edge_str = child_edge->EdgeString();
362  if (edge_str != NULL && islower(edge_str[0]) != 0 &&
363  edge_str[1] == 0) {
364  int class_id =
365  cntxt_->CharacterSet()->ClassID(toupper(edge_str[0]));
366  if (class_id != INVALID_UNICHAR_ID) {
367  // generate an upper case edge for lower case chars
368  edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg,
369  child_edge->StartEdge(), child_edge->EndEdge(), class_id);
370 
371  if (edge_array[edge_cnt] != NULL) {
372  reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
373  SetEdgeMask(edge_mask);
374  edge_cnt++;
375  }
376  }
377  }
378  }
379  }
380  }
381  }
382  return edge_cnt;
383 }
384 
385 // Generate the edges fanning-out from an edge in the number state machine
386 int TessLangModel::NumberEdges(EDGE_REF edge_ref, LangModEdge **edge_array) {
387  EDGE_REF new_state,
388  state;
389 
390  inT64 repeat_cnt,
391  new_repeat_cnt;
392 
393  state = ((edge_ref & NUMBER_STATE_MASK) >> NUMBER_STATE_SHIFT);
394  repeat_cnt = ((edge_ref & NUMBER_REPEAT_MASK) >> NUMBER_REPEAT_SHIFT);
395 
396  if (state < 0 || state >= kStateCnt) {
397  return 0;
398  }
399 
400  // go thru all valid transitions from the state
401  int edge_cnt = 0;
402 
403  EDGE_REF new_edge_ref;
404 
405  for (int lit = 0; lit < kNumLiteralCnt; lit++) {
406  // move to the new state
407  new_state = num_state_machine_[state][lit];
408  if (new_state == NUM_TRM) {
409  continue;
410  }
411 
412  if (new_state == state) {
413  new_repeat_cnt = repeat_cnt + 1;
414  } else {
415  new_repeat_cnt = 1;
416  }
417 
418  // not allowed to repeat beyond this
419  if (new_repeat_cnt > num_max_repeat_[state]) {
420  continue;
421  }
422 
423  new_edge_ref = (new_state << NUMBER_STATE_SHIFT) |
424  (lit << NUMBER_LITERAL_SHIFT) |
425  (new_repeat_cnt << NUMBER_REPEAT_SHIFT);
426 
427  edge_cnt += Edges(literal_str_[lit]->c_str(), number_dawg_,
428  new_edge_ref, 0, edge_array + edge_cnt);
429  }
430 
431  return edge_cnt;
432 }
433 
434 // Loads Language model elements from contents of the <lang>.cube.lm file
435 bool TessLangModel::LoadLangModelElements(const string &lm_params) {
436  bool success = true;
437  // split into lines, each corresponding to a token type below
438  vector<string> str_vec;
439  CubeUtils::SplitStringUsing(lm_params, "\r\n", &str_vec);
440  for (int entry = 0; entry < str_vec.size(); entry++) {
441  vector<string> tokens;
442  // should be only two tokens: type and value
443  CubeUtils::SplitStringUsing(str_vec[entry], "=", &tokens);
444  if (tokens.size() != 2)
445  success = false;
446  if (tokens[0] == "LeadPunc") {
447  lead_punc_ = tokens[1];
448  } else if (tokens[0] == "TrailPunc") {
449  trail_punc_ = tokens[1];
450  } else if (tokens[0] == "NumLeadPunc") {
451  num_lead_punc_ = tokens[1];
452  } else if (tokens[0] == "NumTrailPunc") {
453  num_trail_punc_ = tokens[1];
454  } else if (tokens[0] == "Operators") {
455  operators_ = tokens[1];
456  } else if (tokens[0] == "Digits") {
457  digits_ = tokens[1];
458  } else if (tokens[0] == "Alphas") {
459  alphas_ = tokens[1];
460  } else {
461  success = false;
462  }
463  }
464 
465  RemoveInvalidCharacters(&num_lead_punc_);
466  RemoveInvalidCharacters(&num_trail_punc_);
467  RemoveInvalidCharacters(&digits_);
468  RemoveInvalidCharacters(&operators_);
469  RemoveInvalidCharacters(&alphas_);
470 
471  // form the array of literal strings needed for number state machine
472  // It is essential that the literal strings go in the order below
473  literal_str_[0] = &num_lead_punc_;
474  literal_str_[1] = &num_trail_punc_;
475  literal_str_[2] = &digits_;
476  literal_str_[3] = &operators_;
477  literal_str_[4] = &alphas_;
478 
479  return success;
480 }
481 
483  CharSet *char_set = cntxt_->CharacterSet();
484  tesseract::string_32 lm_str32;
485  CubeUtils::UTF8ToUTF32(lm_str->c_str(), &lm_str32);
486 
487  int len = CubeUtils::StrLen(lm_str32.c_str());
488  char_32 *clean_str32 = new char_32[len + 1];
489  if (!clean_str32)
490  return;
491  int clean_len = 0;
492  for (int i = 0; i < len; ++i) {
493  int class_id = char_set->ClassID((char_32)lm_str32[i]);
494  if (class_id != INVALID_UNICHAR_ID) {
495  clean_str32[clean_len] = lm_str32[i];
496  ++clean_len;
497  }
498  }
499  clean_str32[clean_len] = 0;
500  if (clean_len < len) {
501  lm_str->clear();
502  CubeUtils::UTF32ToUTF8(clean_str32, lm_str);
503  }
504  delete [] clean_str32;
505 }
506 
507 int TessLangModel::NumDawgs() const {
508  return (word_dawgs_ != NULL) ?
509  word_dawgs_->size() : cntxt_->TesseractObject()->getDict().NumDawgs();
510 }
511 
512 // Returns the dawgs with the given index from either the dawgs
513 // stored by the Tesseract object, or the word_dawgs_.
514 const Dawg *TessLangModel::GetDawg(int index) const {
515  if (word_dawgs_ != NULL) {
516  ASSERT_HOST(index < word_dawgs_->size());
517  return (*word_dawgs_)[index];
518  } else {
519  ASSERT_HOST(index < cntxt_->TesseractObject()->getDict().NumDawgs());
520  return cntxt_->TesseractObject()->getDict().GetDawg(index);
521  }
522 }
523 }
FILE * GetDataFilePtr() const
TessLangModel(const string &lm_params, const string &data_file_path, bool load_system_dawg, TessdataManager *tessdata_manager, CubeRecoContext *cntxt)
void RemoveInvalidCharacters(string *lm_str)
int size() const
Definition: genericvector.h:72
#define NUM_TRM
#define NUMBER_LITERAL_SHIFT
static int Prob2Cost(double prob_val)
Definition: cube_utils.cpp:37
#define IsTrailingPuncEdge(edge_mask)
tesseract::Tesseract * TesseractObject() const
basic_string< char_32 > string_32
Definition: string_32.h:41
#define NUMBER_REPEAT_SHIFT
GenericVector< Dawg * > DawgVector
Definition: dict.h:50
const char_32 * EdgeString() const
virtual void SetRoot(bool flag)=0
#define TrailingPuncCount(edge_mask)
#define ASSERT_HOST(x)
Definition: errcode.h:84
#define DAWG_NUMBER
#define NUMBER_REPEAT_MASK
static int CreateChildren(CubeRecoContext *cntxt, const Dawg *edges, NODE_REF edge_reg, LangModEdge **lm_edges)
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
Definition: cube_utils.cpp:266
int ClassID(const char_32 *str) const
Definition: char_set.h:54
#define NUMBER_STATE_SHIFT
Dict & getDict()
Definition: classify.h:65
#define NUMBER_STATE_MASK
const int NumDawgs() const
Return the number of dawgs in the dawgs_ vector.
Definition: dict.h:404
static int StrLen(const char_32 *str)
Definition: cube_utils.cpp:54
bool IsLeadingPunc(char_32 ch)
const Dawg * GetDawg() const
LangModEdge ** GetEdges(CharAltList *alt_list, LangModEdge *edge, int *edge_cnt)
static void UTF32ToUTF8(const char_32 *utf32_str, string *str)
Definition: cube_utils.cpp:282
bool SeekToStart(TessdataType tessdata_type)
DawgType type() const
Definition: dawg.h:127
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
Definition: cube_utils.cpp:230
CharSet * CharacterSet() const
const int kNumLiteralCnt
bool IsTrailingPunc(char_32 ch)
#define LEAD_PUNC_EDGE_REF_MASK
bool IsValidSequence(const char_32 *sequence, bool eow_flag, LangModEdge **final_edge=NULL)
const Dawg * GetDawg(int index) const
Return i-th dawg pointer recorded in the dawgs_ vector.
Definition: dict.h:406
inT64 EDGE_REF
Definition: dawg.h:54
signed int char_32
Definition: string_32.h:40
int ClassCount() const
Definition: char_set.h:111
#define TrailingPuncEdgeMask(Cnt)
#define NULL
Definition: host.h:144
const int kStateCnt
inT64 NODE_REF
Definition: dawg.h:55
#define DAWG_OOD
const string & Lang() const
long long int inT64
Definition: host.h:108