All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
beam_search.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: beam_search.cpp
3  * Description: Class to implement Beam Word Search Algorithm
4  * Author: Ahmad Abdulkader
5  * Created: 2007
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 #include <algorithm>
21 
22 #include "beam_search.h"
23 #include "tesseractclass.h"
24 
25 namespace tesseract {
26 
27 BeamSearch::BeamSearch(CubeRecoContext *cntxt, bool word_mode) {
28  cntxt_ = cntxt;
29  seg_pt_cnt_ = 0;
30  col_cnt_ = 1;
31  col_ = NULL;
32  word_mode_ = word_mode;
33 }
34 
35 // Cleanup the lattice corresponding to the last search
36 void BeamSearch::Cleanup() {
37  if (col_ != NULL) {
38  for (int col = 0; col < col_cnt_; col++) {
39  if (col_[col])
40  delete col_[col];
41  }
42  delete []col_;
43  }
44  col_ = NULL;
45 }
46 
48  Cleanup();
49 }
50 
51 // Creates a set of children nodes emerging from a parent node based on
52 // the character alternate list and the language model.
53 void BeamSearch::CreateChildren(SearchColumn *out_col, LangModel *lang_mod,
54  SearchNode *parent_node,
55  LangModEdge *lm_parent_edge,
56  CharAltList *char_alt_list, int extra_cost) {
57  // get all the edges from this parent
58  int edge_cnt;
59  LangModEdge **lm_edges = lang_mod->GetEdges(char_alt_list,
60  lm_parent_edge, &edge_cnt);
61  if (lm_edges) {
62  // add them to the ending column with the appropriate parent
63  for (int edge = 0; edge < edge_cnt; edge++) {
64  // add a node to the column if the current column is not the
65  // last one, or if the lang model edge indicates it is valid EOW
66  if (!cntxt_->NoisyInput() && out_col->ColIdx() >= seg_pt_cnt_ &&
67  !lm_edges[edge]->IsEOW()) {
68  // free edge since no object is going to own it
69  delete lm_edges[edge];
70  continue;
71  }
72 
73  // compute the recognition cost of this node
74  int recognition_cost = MIN_PROB_COST;
75  if (char_alt_list && char_alt_list->AltCount() > 0) {
76  recognition_cost = MAX(0, char_alt_list->ClassCost(
77  lm_edges[edge]->ClassID()));
78  // Add the no space cost. This should zero in word mode
79  recognition_cost += extra_cost;
80  }
81 
82  // Note that the edge will be freed inside the column if
83  // AddNode is called
84  if (recognition_cost >= 0) {
85  out_col->AddNode(lm_edges[edge], recognition_cost, parent_node,
86  cntxt_);
87  } else {
88  delete lm_edges[edge];
89  }
90  } // edge
91  // free edge array
92  delete []lm_edges;
93  } // lm_edges
94 }
95 
96 // Performs a beam seach in the specified search using the specified
97 // language model; returns an alternate list of possible words as a result.
99  // verifications
100  if (!lang_mod)
101  lang_mod = cntxt_->LangMod();
102  if (!lang_mod) {
103  fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
104  "LangModel\n");
105  return NULL;
106  }
107 
108  // free existing state
109  Cleanup();
110 
111  // get seg pt count
112  seg_pt_cnt_ = srch_obj->SegPtCnt();
113  if (seg_pt_cnt_ < 0) {
114  return NULL;
115  }
116  col_cnt_ = seg_pt_cnt_ + 1;
117 
118  // disregard suspicious cases
119  if (seg_pt_cnt_ > 128) {
120  fprintf(stderr, "Cube ERROR (BeamSearch::Search): segment point count is "
121  "suspiciously high; bailing out\n");
122  return NULL;
123  }
124 
125  // alloc memory for columns
126  col_ = new SearchColumn *[col_cnt_];
127  if (!col_) {
128  fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
129  "SearchColumn array\n");
130  return NULL;
131  }
132  memset(col_, 0, col_cnt_ * sizeof(*col_));
133 
134  // for all possible segments
135  for (int end_seg = 1; end_seg <= (seg_pt_cnt_ + 1); end_seg++) {
136  // create a search column
137  col_[end_seg - 1] = new SearchColumn(end_seg - 1,
138  cntxt_->Params()->BeamWidth());
139  if (!col_[end_seg - 1]) {
140  fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
141  "SearchColumn for column %d\n", end_seg - 1);
142  return NULL;
143  }
144 
145  // for all possible start segments
146  int init_seg = MAX(0, end_seg - cntxt_->Params()->MaxSegPerChar());
147  for (int strt_seg = init_seg; strt_seg < end_seg; strt_seg++) {
148  int parent_nodes_cnt;
149  SearchNode **parent_nodes;
150 
151  // for the root segment, we do not have a parent
152  if (strt_seg == 0) {
153  parent_nodes_cnt = 1;
154  parent_nodes = NULL;
155  } else {
156  // for all the existing nodes in the starting column
157  parent_nodes_cnt = col_[strt_seg - 1]->NodeCount();
158  parent_nodes = col_[strt_seg - 1]->Nodes();
159  }
160 
161  // run the shape recognizer
162  CharAltList *char_alt_list = srch_obj->RecognizeSegment(strt_seg - 1,
163  end_seg - 1);
164  // for all the possible parents
165  for (int parent_idx = 0; parent_idx < parent_nodes_cnt; parent_idx++) {
166  // point to the parent node
167  SearchNode *parent_node = !parent_nodes ? NULL
168  : parent_nodes[parent_idx];
169  LangModEdge *lm_parent_edge = !parent_node ? lang_mod->Root()
170  : parent_node->LangModelEdge();
171 
172  // compute the cost of not having spaces within the segment range
173  int contig_cost = srch_obj->NoSpaceCost(strt_seg - 1, end_seg - 1);
174 
175  // In phrase mode, compute the cost of not having a space before
176  // this character
177  int no_space_cost = 0;
178  if (!word_mode_ && strt_seg > 0) {
179  no_space_cost = srch_obj->NoSpaceCost(strt_seg - 1);
180  }
181 
182  // if the no space cost is low enough
183  if ((contig_cost + no_space_cost) < MIN_PROB_COST) {
184  // Add the children nodes
185  CreateChildren(col_[end_seg - 1], lang_mod, parent_node,
186  lm_parent_edge, char_alt_list,
187  contig_cost + no_space_cost);
188  }
189 
190  // In phrase mode and if not starting at the root
191  if (!word_mode_ && strt_seg > 0) { // parent_node must be non-NULL
192  // consider starting a new word for nodes that are valid EOW
193  if (parent_node->LangModelEdge()->IsEOW()) {
194  // get the space cost
195  int space_cost = srch_obj->SpaceCost(strt_seg - 1);
196  // if the space cost is low enough
197  if ((contig_cost + space_cost) < MIN_PROB_COST) {
198  // Restart the language model and add nodes as children to the
199  // space node.
200  CreateChildren(col_[end_seg - 1], lang_mod, parent_node, NULL,
201  char_alt_list, contig_cost + space_cost);
202  }
203  }
204  }
205  } // parent
206  } // strt_seg
207 
208  // prune the column nodes
209  col_[end_seg - 1]->Prune();
210 
211  // Free the column hash table. No longer needed
212  col_[end_seg - 1]->FreeHashTable();
213  } // end_seg
214 
215  WordAltList *alt_list = CreateWordAltList(srch_obj);
216  return alt_list;
217 }
218 
219 // Creates a Word alternate list from the results in the lattice.
220 WordAltList *BeamSearch::CreateWordAltList(SearchObject *srch_obj) {
221  // create an alternate list of all the nodes in the last column
222  int node_cnt = col_[col_cnt_ - 1]->NodeCount();
223  SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
224  CharBigrams *bigrams = cntxt_->Bigrams();
225  WordUnigrams *word_unigrams = cntxt_->WordUnigramsObj();
226 
227  // Save the index of the best-cost node before the alt list is
228  // sorted, so that we can retrieve it from the node list when backtracking.
229  best_presorted_node_idx_ = 0;
230  int best_cost = -1;
231 
232  if (node_cnt <= 0)
233  return NULL;
234 
235  // start creating the word alternate list
236  WordAltList *alt_list = new WordAltList(node_cnt + 1);
237  for (int node_idx = 0; node_idx < node_cnt; node_idx++) {
238  // recognition cost
239  int recognition_cost = srch_nodes[node_idx]->BestCost();
240  // compute the size cost of the alternate
241  char_32 *ch_buff = NULL;
242  int size_cost = SizeCost(srch_obj, srch_nodes[node_idx], &ch_buff);
243  // accumulate other costs
244  if (ch_buff) {
245  int cost = 0;
246  // char bigram cost
247  int bigram_cost = !bigrams ? 0 :
248  bigrams->Cost(ch_buff, cntxt_->CharacterSet());
249  // word unigram cost
250  int unigram_cost = !word_unigrams ? 0 :
251  word_unigrams->Cost(ch_buff, cntxt_->LangMod(),
252  cntxt_->CharacterSet());
253  // overall cost
254  cost = static_cast<int>(
255  (size_cost * cntxt_->Params()->SizeWgt()) +
256  (bigram_cost * cntxt_->Params()->CharBigramWgt()) +
257  (unigram_cost * cntxt_->Params()->WordUnigramWgt()) +
258  (recognition_cost * cntxt_->Params()->RecoWgt()));
259 
260  // insert into word alt list
261  alt_list->Insert(ch_buff, cost,
262  static_cast<void *>(srch_nodes[node_idx]));
263  // Note that strict < is necessary because WordAltList::Sort()
264  // uses it in a bubble sort to swap entries.
265  if (best_cost < 0 || cost < best_cost) {
266  best_presorted_node_idx_ = node_idx;
267  best_cost = cost;
268  }
269  delete []ch_buff;
270  }
271  }
272 
273  // sort the alternates based on cost
274  alt_list->Sort();
275  return alt_list;
276 }
277 
278 // Returns the lattice column corresponding to the specified column index.
280  if (col < 0 || col >= col_cnt_ || !col_)
281  return NULL;
282  return col_[col];
283 }
284 
285 // Returns the best node in the last column of last performed search.
287  if (col_cnt_ < 1 || !col_ || !col_[col_cnt_ - 1])
288  return NULL;
289 
290  int node_cnt = col_[col_cnt_ - 1]->NodeCount();
291  SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
292  if (node_cnt < 1 || !srch_nodes || !srch_nodes[0])
293  return NULL;
294  return srch_nodes[0];
295 }
296 
297 // Returns the string corresponding to the specified alt.
298 char_32 *BeamSearch::Alt(int alt) const {
299  // get the last column of the lattice
300  if (col_cnt_ <= 0)
301  return NULL;
302 
303  SearchColumn *srch_col = col_[col_cnt_ - 1];
304  if (!srch_col)
305  return NULL;
306 
307  // point to the last node in the selected path
308  if (alt >= srch_col->NodeCount() || srch_col->Nodes() == NULL) {
309  return NULL;
310  }
311 
312  SearchNode *srch_node = srch_col->Nodes()[alt];
313  if (!srch_node)
314  return NULL;
315 
316  // get string
317  char_32 *str32 = srch_node->PathString();
318  if (!str32)
319  return NULL;
320 
321  return str32;
322 }
323 
324 // Backtracks from the specified node index and returns the corresponding
325 // character mapped segments and character count. Optional return
326 // arguments are the char_32 result string and character bounding
327 // boxes, if non-NULL values are passed in.
328 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, int node_index,
329  int *char_cnt, char_32 **str32,
330  Boxa **char_boxes) const {
331  // get the last column of the lattice
332  if (col_cnt_ <= 0)
333  return NULL;
334  SearchColumn *srch_col = col_[col_cnt_ - 1];
335  if (!srch_col)
336  return NULL;
337 
338  // point to the last node in the selected path
339  if (node_index >= srch_col->NodeCount() || !srch_col->Nodes())
340  return NULL;
341 
342  SearchNode *srch_node = srch_col->Nodes()[node_index];
343  if (!srch_node)
344  return NULL;
345  return BackTrack(srch_obj, srch_node, char_cnt, str32, char_boxes);
346 }
347 
348 // Backtracks from the specified node index and returns the corresponding
349 // character mapped segments and character count. Optional return
350 // arguments are the char_32 result string and character bounding
351 // boxes, if non-NULL values are passed in.
353  int *char_cnt, char_32 **str32,
354  Boxa **char_boxes) const {
355  if (!srch_node)
356  return NULL;
357 
358  if (str32) {
359  if (*str32)
360  delete [](*str32); // clear existing value
361  *str32 = srch_node->PathString();
362  if (!*str32)
363  return NULL;
364  }
365 
366  if (char_boxes && *char_boxes) {
367  boxaDestroy(char_boxes); // clear existing value
368  }
369 
370  CharSamp **chars;
371  chars = SplitByNode(srch_obj, srch_node, char_cnt, char_boxes);
372  if (!chars && str32)
373  delete []*str32;
374  return chars;
375 }
376 
377 // Backtracks from the given lattice node and return the corresponding
378 // char mapped segments and character count. The character bounding
379 // boxes are optional return arguments, if non-NULL values are passed in.
380 CharSamp **BeamSearch::SplitByNode(SearchObject *srch_obj,
381  SearchNode *srch_node,
382  int *char_cnt,
383  Boxa **char_boxes) const {
384  // Count the characters (could be less than the path length when in
385  // phrase mode)
386  *char_cnt = 0;
387  SearchNode *node = srch_node;
388  while (node) {
389  node = node->ParentNode();
390  (*char_cnt)++;
391  }
392 
393  if (*char_cnt == 0)
394  return NULL;
395 
396  // Allocate box array
397  if (char_boxes) {
398  if (*char_boxes)
399  boxaDestroy(char_boxes); // clear existing value
400  *char_boxes = boxaCreate(*char_cnt);
401  if (*char_boxes == NULL)
402  return NULL;
403  }
404 
405  // Allocate memory for CharSamp array.
406  CharSamp **chars = new CharSamp *[*char_cnt];
407  if (!chars) {
408  if (char_boxes)
409  boxaDestroy(char_boxes);
410  return NULL;
411  }
412 
413  int ch_idx = *char_cnt - 1;
414  int seg_pt_cnt = srch_obj->SegPtCnt();
415  bool success=true;
416  while (srch_node && ch_idx >= 0) {
417  // Parent node (could be null)
418  SearchNode *parent_node = srch_node->ParentNode();
419 
420  // Get the seg pts corresponding to the search node
421  int st_col = !parent_node ? 0 : parent_node->ColIdx() + 1;
422  int st_seg_pt = st_col <= 0 ? -1 : st_col - 1;
423  int end_col = srch_node->ColIdx();
424  int end_seg_pt = end_col >= seg_pt_cnt ? seg_pt_cnt : end_col;
425 
426  // Get a char sample corresponding to the segmentation points
427  CharSamp *samp = srch_obj->CharSample(st_seg_pt, end_seg_pt);
428  if (!samp) {
429  success = false;
430  break;
431  }
432  samp->SetLabel(srch_node->NodeString());
433  chars[ch_idx] = samp;
434  if (char_boxes) {
435  // Create the corresponding character bounding box
436  Box *char_box = boxCreate(samp->Left(), samp->Top(),
437  samp->Width(), samp->Height());
438  if (!char_box) {
439  success = false;
440  break;
441  }
442  boxaAddBox(*char_boxes, char_box, L_INSERT);
443  }
444  srch_node = parent_node;
445  ch_idx--;
446  }
447  if (!success) {
448  delete []chars;
449  if (char_boxes)
450  boxaDestroy(char_boxes);
451  return NULL;
452  }
453 
454  // Reverse the order of boxes.
455  if (char_boxes) {
456  int char_boxa_size = boxaGetCount(*char_boxes);
457  int limit = char_boxa_size / 2;
458  for (int i = 0; i < limit; ++i) {
459  int box1_idx = i;
460  int box2_idx = char_boxa_size - 1 - i;
461  Box *box1 = boxaGetBox(*char_boxes, box1_idx, L_CLONE);
462  Box *box2 = boxaGetBox(*char_boxes, box2_idx, L_CLONE);
463  boxaReplaceBox(*char_boxes, box2_idx, box1);
464  boxaReplaceBox(*char_boxes, box1_idx, box2);
465  }
466  }
467  return chars;
468 }
469 
470 // Returns the size cost of a string for a lattice path that
471 // ends at the specified lattice node.
473  char_32 **str32) const {
474  CharSamp **chars = NULL;
475  int char_cnt = 0;
476  if (!node)
477  return 0;
478  // Backtrack to get string and character segmentation
479  chars = BackTrack(srch_obj, node, &char_cnt, str32, NULL);
480  if (!chars)
481  return WORST_COST;
482  int size_cost = (cntxt_->SizeModel() == NULL) ? 0 :
483  cntxt_->SizeModel()->Cost(chars, char_cnt);
484  delete []chars;
485  return size_cost;
486 }
487 } // namespace tesesract
bool Insert(char_32 *char_ptr, int cost, void *tag=NULL)
#define MAX(x, y)
Definition: ndminx.h:24
#define WORST_COST
Definition: cube_const.h:30
double WordUnigramWgt() const
Definition: tuning_params.h:51
LangModel * LangMod() const
int MaxSegPerChar() const
Definition: tuning_params.h:52
virtual LangModEdge ** GetEdges(CharAltList *alt_list, LangModEdge *parent_edge, int *edge_cnt)=0
virtual int SpaceCost(int seg_pt)=0
virtual LangModEdge * Root()=0
int Cost(CharSamp **samp_array, int samp_cnt) const
SearchNode * BestNode() const
WordUnigrams * WordUnigramsObj() const
int Cost(const char_32 *str, CharSet *char_set) const
double CharBigramWgt() const
Definition: tuning_params.h:50
CharBigrams * Bigrams() const
SearchNode * AddNode(LangModEdge *edge, int score, SearchNode *parent, CubeRecoContext *cntxt)
int Cost(const char_32 *str32, LangModel *lang_mod, CharSet *char_set) const
int SizeCost(SearchObject *srch_obj, SearchNode *node, char_32 **str32=NULL) const
const char_32 * NodeString()
Definition: search_node.h:51
#define MIN_PROB_COST
Definition: cube_const.h:26
TuningParams * Params() const
virtual CharSamp * CharSample(int start_pt, int end_pt)=0
double RecoWgt() const
Definition: tuning_params.h:48
virtual int ClassID() const =0
void SetLabel(char_32 label)
Definition: char_samp.h:68
char_32 * Alt(int alt) const
BeamSearch(CubeRecoContext *cntxt, bool word_mode=true)
Definition: beam_search.cpp:27
LangModEdge * LangModelEdge()
Definition: search_node.h:70
double SizeWgt() const
Definition: tuning_params.h:49
WordSizeModel * SizeModel() const
SearchNode * ParentNode()
Definition: search_node.h:69
CharSet * CharacterSet() const
SearchNode ** Nodes() const
Definition: search_column.h:44
int ClassCost(int class_id) const
Definition: char_altlist.h:42
virtual int SegPtCnt()=0
int AltCount() const
Definition: altlist.h:39
virtual CharAltList * RecognizeSegment(int start_pt, int end_pt)=0
virtual bool IsEOW() const =0
signed int char_32
Definition: string_32.h:40
#define NULL
Definition: host.h:144
virtual int NoSpaceCost(int seg_pt)=0
WordAltList * Search(SearchObject *srch_obj, LangModel *lang_mod=NULL)
Definition: beam_search.cpp:98
CharSamp ** BackTrack(SearchObject *srch_obj, int node_index, int *char_cnt, char_32 **str32, Boxa **char_boxes) const
SearchColumn * Column(int col_idx) const