tesseract v5.3.3.20231005
recodebeam.cpp
Go to the documentation of this file.
1
2// File: recodebeam.cpp
3// Description: Beam search to decode from the re-encoded CJK as a sequence of
4// smaller numbers in place of a single large code.
5// Author: Ray Smith
6//
7// (C) Copyright 2015, Google Inc.
8// Licensed under the Apache License, Version 2.0 (the "License");
9// you may not use this file except in compliance with the License.
10// You may obtain a copy of the License at
11// http://www.apache.org/licenses/LICENSE-2.0
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17//
19
20#include "recodebeam.h"
21
22#include "networkio.h"
23#include "pageres.h"
24#include "unicharcompress.h"
25
26#include <algorithm> // for std::reverse
27
28namespace tesseract {
29
30// The beam width at each code position.
31const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
32 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
33};
34
35static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
36
37// Prints debug details of the node.
38void RecodeNode::Print(int null_char, const UNICHARSET &unicharset,
39 int depth) const {
40 if (code == null_char) {
41 tprintf("null_char");
42 } else {
43 tprintf("label=%d, uid=%d=%s", code, unichar_id,
44 unicharset.debug_str(unichar_id).c_str());
45 }
46 tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty,
47 start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "",
48 end_of_word ? " End" : "", permuter, code_hash);
49 if (depth > 0 && prev != nullptr) {
50 tprintf(" prev:");
51 prev->Print(null_char, unicharset, depth - 1);
52 } else {
53 tprintf("\n");
54 }
55}
56
57// Borrows the pointer, which is expected to survive until *this is deleted.
59 int null_char, bool simple_text, Dict *dict)
60 : recoder_(recoder),
61 beam_size_(0),
62 top_code_(-1),
63 second_code_(-1),
64 dict_(dict),
65 space_delimited_(true),
66 is_simple_text_(simple_text),
67 null_char_(null_char) {
68 if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) {
69 space_delimited_ = false;
70 }
71}
72
74 for (auto data : beam_) {
75 delete data;
76 }
77 for (auto data : secondary_beam_) {
78 delete data;
79 }
80}
81
82// Decodes the set of network outputs, storing the lattice internally.
83void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio,
84 double cert_offset, double worst_dict_cert,
85 const UNICHARSET *charset, int lstm_choice_mode) {
86 beam_size_ = 0;
87 int width = output.Width();
88 if (lstm_choice_mode) {
89 timesteps.clear();
90 }
91 for (int t = 0; t < width; ++t) {
92 ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
93 DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
94 charset);
95 if (lstm_choice_mode) {
96 SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
97 }
98 }
99}
101 double dict_ratio, double cert_offset,
102 double worst_dict_cert,
103 const UNICHARSET *charset) {
104 beam_size_ = 0;
105 int width = output.dim1();
106 for (int t = 0; t < width; ++t) {
107 ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
108 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
109 }
110}
111
113 const NetworkIO &output, double dict_ratio, double cert_offset,
114 double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) {
115 for (auto data : secondary_beam_) {
116 delete data;
117 }
118 secondary_beam_.clear();
119 if (character_boundaries_.size() < 2) {
120 return;
121 }
122 int width = output.Width();
123 unsigned bucketNumber = 0;
124 for (int t = 0; t < width; ++t) {
125 while ((bucketNumber + 1) < character_boundaries_.size() &&
126 t >= character_boundaries_[bucketNumber + 1]) {
127 ++bucketNumber;
128 }
129 ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t),
130 output.NumFeatures(), kBeamWidths[0]);
131 DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset,
132 worst_dict_cert, charset);
133 }
134}
135
136void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs,
137 int num_outputs,
138 const UNICHARSET *charset,
139 int xCoord) {
140 std::vector<std::pair<const char *, float>> choices;
141 for (int i = 0; i < num_outputs; ++i) {
142 if (outputs[i] >= 0.01f) {
143 const char *character;
144 if (i + 2 >= num_outputs) {
145 character = "";
146 } else if (i > 0) {
147 character = charset->id_to_unichar_ext(i + 2);
148 } else {
149 character = charset->id_to_unichar_ext(i);
150 }
151 size_t pos = 0;
152 // order the possible choices within one timestep
153 // beginning with the most likely
154 while (choices.size() > pos && choices[pos].second > outputs[i]) {
155 pos++;
156 }
157 choices.insert(choices.begin() + pos,
158 std::pair<const char *, float>(character, outputs[i]));
159 }
160 }
161 timesteps.push_back(choices);
162}
163
165 for (unsigned i = 1; i < character_boundaries_.size(); ++i) {
166 std::vector<std::vector<std::pair<const char *, float>>> segment;
167 for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i];
168 ++j) {
169 segment.push_back(timesteps[j]);
170 }
171 segmentedTimesteps.push_back(segment);
172 }
173}
174std::vector<std::vector<std::pair<const char *, float>>>
176 std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
177 *segmentedTimesteps) {
178 std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
179 for (auto &segmentedTimestep : *segmentedTimesteps) {
180 for (auto &j : segmentedTimestep) {
181 combined_timesteps.push_back(j);
182 }
183 }
184 return combined_timesteps;
185}
186
187void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
188 std::vector<int> *ends,
189 std::vector<int> *char_bounds_,
190 int maxWidth) {
191 char_bounds_->push_back(0);
192 for (unsigned i = 0; i < ends->size(); ++i) {
193 int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
194 char_bounds_->push_back((*ends)[i] + middle);
195 }
196 char_bounds_->pop_back();
197 char_bounds_->push_back(maxWidth);
198}
199
200// Returns the best path as labels/scores/xcoords similar to simple CTC.
202 std::vector<int> *labels, std::vector<int> *xcoords) const {
203 labels->clear();
204 xcoords->clear();
205 std::vector<const RecodeNode *> best_nodes;
206 ExtractBestPaths(&best_nodes, nullptr);
207 // Now just run CTC on the best nodes.
208 int t = 0;
209 int width = best_nodes.size();
210 while (t < width) {
211 int label = best_nodes[t]->code;
212 if (label != null_char_) {
213 labels->push_back(label);
214 xcoords->push_back(t);
215 }
216 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
217 }
218 }
219 xcoords->push_back(width);
220}
221
222// Returns the best path as unichar-ids/certs/ratings/xcoords skipping
223// duplicates, nulls and intermediate parts.
225 bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
226 std::vector<float> *certs, std::vector<float> *ratings,
227 std::vector<int> *xcoords) const {
228 std::vector<const RecodeNode *> best_nodes;
229 ExtractBestPaths(&best_nodes, nullptr);
230 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
231 if (debug) {
232 DebugPath(unicharset, best_nodes);
233 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
234 *xcoords);
235 }
236}
237
238// Returns the best path as a set of WERD_RES.
240 float scale_factor, bool debug,
241 const UNICHARSET *unicharset,
243 int lstm_choice_mode) {
244 words->truncate(0);
245 std::vector<int> unichar_ids;
246 std::vector<float> certs;
247 std::vector<float> ratings;
248 std::vector<int> xcoords;
249 std::vector<const RecodeNode *> best_nodes;
250 std::vector<const RecodeNode *> second_nodes;
251 character_boundaries_.clear();
252 ExtractBestPaths(&best_nodes, &second_nodes);
253 if (debug) {
254 DebugPath(unicharset, best_nodes);
255 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
256 &xcoords);
257 tprintf("\nSecond choice path:\n");
258 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
259 xcoords);
260 }
261 // If lstm choice mode is required in granularity level 2, it stores the x
262 // Coordinates of every chosen character, to match the alternative choices to
263 // it.
264 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
266 int num_ids = unichar_ids.size();
267 if (debug) {
268 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
269 xcoords);
270 }
271 // Convert labels to unichar-ids.
272 int word_end = 0;
273 float prev_space_cert = 0.0f;
274 for (int word_start = 0; word_start < num_ids; word_start = word_end) {
275 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
276 // A word is terminated when a space character or start_of_word flag is
277 // hit. We also want to force a separate word for every non
278 // space-delimited character when not in a dictionary context.
279 if (unichar_ids[word_end] == UNICHAR_SPACE) {
280 break;
281 }
282 int index = xcoords[word_end];
283 if (best_nodes[index]->start_of_word) {
284 break;
285 }
286 if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
287 (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
288 !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) {
289 break;
290 }
291 }
292 float space_cert = 0.0f;
293 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
294 space_cert = certs[word_end];
295 }
296 bool leading_space =
297 word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
298 // Create a WERD_RES for the output word.
299 WERD_RES *word_res =
300 InitializeWord(leading_space, line_box, word_start, word_end,
301 std::min(space_cert, prev_space_cert), unicharset,
302 xcoords, scale_factor);
303 for (int i = word_start; i < word_end; ++i) {
304 auto *choices = new BLOB_CHOICE_LIST;
305 BLOB_CHOICE_IT bc_it(choices);
306 auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1,
307 1.0f, static_cast<float>(INT16_MAX), 0.0f,
309 int col = i - word_start;
310 choice->set_matrix_cell(col, col);
311 bc_it.add_after_then_move(choice);
312 word_res->ratings->put(col, col, choices);
313 }
314 int index = xcoords[word_end - 1];
315 word_res->FakeWordFromRatings(best_nodes[index]->permuter);
316 words->push_back(word_res);
317 prev_space_cert = space_cert;
318 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
319 ++word_end;
320 }
321 }
322}
323
325 inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const {
326 return (node1->score > node2->score);
327 }
328};
329
330void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs,
331 const UNICHARSET *charset,
332 bool secondary) const {
333 std::vector<std::vector<const RecodeNode *>> topology;
334 std::unordered_set<const RecodeNode *> visited;
335 const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
336 // create the topology
337 for (int step = beam.size() - 1; step >= 0; --step) {
338 std::vector<const RecodeNode *> layer;
339 topology.push_back(layer);
340 }
341 // fill the topology with depths first
342 for (int step = beam.size() - 1; step >= 0; --step) {
343 std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
344 for (auto &&node : heaps) {
345 int backtracker = 0;
346 const RecodeNode *curr = &node.data();
347 while (curr != nullptr && !visited.count(curr)) {
348 visited.insert(curr);
349 topology[step - backtracker].push_back(curr);
350 curr = curr->prev;
351 ++backtracker;
352 }
353 }
354 }
355 int ct = 0;
356 unsigned cb = 1;
357 for (const std::vector<const RecodeNode *> &layer : topology) {
358 if (cb >= character_boundaries_.size()) {
359 break;
360 }
361 if (ct == character_boundaries_[cb]) {
362 tprintf("***\n");
363 ++cb;
364 }
365 for (const RecodeNode *node : layer) {
366 const char *code;
367 int intCode;
368 if (node->unichar_id != INVALID_UNICHAR_ID) {
369 code = charset->id_to_unichar(node->unichar_id);
370 intCode = node->unichar_id;
371 } else if (node->code == null_char_) {
372 intCode = 0;
373 code = " ";
374 } else {
375 intCode = 666;
376 code = "*";
377 }
378 int intPrevCode = 0;
379 const char *prevCode;
380 float prevScore = 0;
381 if (node->prev != nullptr) {
382 prevScore = node->prev->score;
383 if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
384 prevCode = charset->id_to_unichar(node->prev->unichar_id);
385 intPrevCode = node->prev->unichar_id;
386 } else if (node->code == null_char_) {
387 intPrevCode = 0;
388 prevCode = " ";
389 } else {
390 prevCode = "*";
391 intPrevCode = 666;
392 }
393 } else {
394 prevCode = " ";
395 }
396 if (uids) {
397 tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
398 node->score);
399 } else {
400 tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
401 }
402 }
403 tprintf("-\n");
404 ++ct;
405 }
406 tprintf("***\n");
407}
408
410 if (character_boundaries_.size() < 2) {
411 return;
412 }
413 // For the first iteration the original beam is analyzed. After that a
414 // new beam is calculated based on the results from the original beam.
415 std::vector<RecodeBeam *> &currentBeam =
416 secondary_beam_.empty() ? beam_ : secondary_beam_;
418 for (unsigned j = 1; j < character_boundaries_.size(); ++j) {
419 std::vector<int> unichar_ids;
420 std::vector<float> certs;
421 std::vector<float> ratings;
422 std::vector<int> xcoords;
423 int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
424 std::vector<tesseract::RecodePair> &heaps =
425 currentBeam.at(character_boundaries_[j] - 1)->beams_->heap();
426 std::vector<const RecodeNode *> best_nodes;
427 std::vector<const RecodeNode *> best;
428 // Scan the segmented node chain for valid unichar ids.
429 for (auto &&entry : heaps) {
430 bool validChar = false;
431 int backcounter = 0;
432 const RecodeNode *node = &entry.data();
433 while (node != nullptr && backcounter < backpath) {
434 if (node->code != null_char_ &&
435 node->unichar_id != INVALID_UNICHAR_ID) {
436 validChar = true;
437 break;
438 }
439 node = node->prev;
440 ++backcounter;
441 }
442 if (validChar) {
443 best.push_back(&entry.data());
444 }
445 }
446 // find the best rated segmented node chain and extract the unichar id.
447 if (!best.empty()) {
448 std::sort(best.begin(), best.end(), greater_than());
449 ExtractPath(best[0], &best_nodes, backpath);
450 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
451 &xcoords);
452 }
453 if (!unichar_ids.empty()) {
454 int bestPos = 0;
455 for (unsigned i = 1; i < unichar_ids.size(); ++i) {
456 if (ratings[i] < ratings[bestPos]) {
457 bestPos = i;
458 }
459 }
460#if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60).
461 int bestCode = -10;
462 for (auto &node : best_nodes) {
463 if (node->unichar_id == unichar_ids[bestPos]) {
464 bestCode = node->code;
465 }
466 }
467#endif
468 // Exclude the best choice for the followup decoding.
469 std::unordered_set<int> excludeCodeList;
470 for (auto &best_node : best_nodes) {
471 if (best_node->code != null_char_) {
472 excludeCodeList.insert(best_node->code);
473 }
474 }
475 if (j - 1 < excludedUnichars.size()) {
476 for (auto elem : excludeCodeList) {
477 excludedUnichars[j - 1].insert(elem);
478 }
479 } else {
480 excludedUnichars.push_back(excludeCodeList);
481 }
482 // Save the best choice for the choice iterator.
483 if (j - 1 < ctc_choices.size()) {
484 int id = unichar_ids[bestPos];
485 const char *result = unicharset->id_to_unichar_ext(id);
486 float rating = ratings[bestPos];
487 ctc_choices[j - 1].push_back(
488 std::pair<const char *, float>(result, rating));
489 } else {
490 std::vector<std::pair<const char *, float>> choice;
491 int id = unichar_ids[bestPos];
492 const char *result = unicharset->id_to_unichar_ext(id);
493 float rating = ratings[bestPos];
494 choice.emplace_back(result, rating);
495 ctc_choices.push_back(choice);
496 }
497 // fill the blank spot with an empty array
498 } else {
499 if (j - 1 >= excludedUnichars.size()) {
500 std::unordered_set<int> excludeCodeList;
501 excludedUnichars.push_back(excludeCodeList);
502 }
503 if (j - 1 >= ctc_choices.size()) {
504 std::vector<std::pair<const char *, float>> choice;
505 ctc_choices.push_back(choice);
506 }
507 }
508 }
509 for (auto data : secondary_beam_) {
510 delete data;
511 }
512 secondary_beam_.clear();
513}
514
515// Generates debug output of the content of the beams after a Decode.
516void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const {
517 for (int p = 0; p < beam_size_; ++p) {
518 for (int d = 0; d < 2; ++d) {
519 for (int c = 0; c < NC_COUNT; ++c) {
520 auto cont = static_cast<NodeContinuation>(c);
521 int index = BeamIndex(d, cont, 0);
522 if (beam_[p]->beams_[index].empty()) {
523 continue;
524 }
525 // Print all the best scoring nodes for each unichar found.
526 tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict",
527 kNodeContNames[c]);
528 DebugBeamPos(unicharset, beam_[p]->beams_[index]);
529 }
530 }
531 }
532}
533
534// Generates debug output of the content of a single beam position.
535void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset,
536 const RecodeHeap &heap) const {
537 std::vector<const RecodeNode *> unichar_bests(unicharset.size());
538 const RecodeNode *null_best = nullptr;
539 int heap_size = heap.size();
540 for (int i = 0; i < heap_size; ++i) {
541 const RecodeNode *node = &heap.get(i).data();
542 if (node->unichar_id == INVALID_UNICHAR_ID) {
543 if (null_best == nullptr || null_best->score < node->score) {
544 null_best = node;
545 }
546 } else {
547 if (unichar_bests[node->unichar_id] == nullptr ||
548 unichar_bests[node->unichar_id]->score < node->score) {
549 unichar_bests[node->unichar_id] = node;
550 }
551 }
552 }
553 for (auto &unichar_best : unichar_bests) {
554 if (unichar_best != nullptr) {
555 const RecodeNode &node = *unichar_best;
556 node.Print(null_char_, unicharset, 1);
557 }
558 }
559 if (null_best != nullptr) {
560 null_best->Print(null_char_, unicharset, 1);
561 }
562}
563
564// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
565// duplicates, nulls and intermediate parts.
566/* static */
567void RecodeBeamSearch::ExtractPathAsUnicharIds(
568 const std::vector<const RecodeNode *> &best_nodes,
569 std::vector<int> *unichar_ids, std::vector<float> *certs,
570 std::vector<float> *ratings, std::vector<int> *xcoords,
571 std::vector<int> *character_boundaries) {
572 unichar_ids->clear();
573 certs->clear();
574 ratings->clear();
575 xcoords->clear();
576 std::vector<int> starts;
577 std::vector<int> ends;
578 // Backtrack extracting only valid, non-duplicate unichar-ids.
579 int t = 0;
580 int width = best_nodes.size();
581 while (t < width) {
582 double certainty = 0.0;
583 double rating = 0.0;
584 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
585 double cert = best_nodes[t++]->certainty;
586 if (cert < certainty) {
587 certainty = cert;
588 }
589 rating -= cert;
590 }
591 starts.push_back(t);
592 if (t < width) {
593 int unichar_id = best_nodes[t]->unichar_id;
594 if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
595 best_nodes[t]->permuter != NO_PERM) {
596 // All the rating and certainty go on the previous character except
597 // for the space itself.
598 if (certainty < certs->back()) {
599 certs->back() = certainty;
600 }
601 ratings->back() += rating;
602 certainty = 0.0;
603 rating = 0.0;
604 }
605 unichar_ids->push_back(unichar_id);
606 xcoords->push_back(t);
607 do {
608 double cert = best_nodes[t++]->certainty;
609 // Special-case NO-PERM space to forget the certainty of the previous
610 // nulls. See long comment in ContinueContext.
611 if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
612 best_nodes[t - 1]->permuter == NO_PERM)) {
613 certainty = cert;
614 }
615 rating -= cert;
616 } while (t < width && best_nodes[t]->duplicate);
617 ends.push_back(t);
618 certs->push_back(certainty);
619 ratings->push_back(rating);
620 } else if (!certs->empty()) {
621 if (certainty < certs->back()) {
622 certs->back() = certainty;
623 }
624 ratings->back() += rating;
625 }
626 }
627 starts.push_back(width);
628 if (character_boundaries != nullptr) {
629 calculateCharBoundaries(&starts, &ends, character_boundaries, width);
630 }
631 xcoords->push_back(width);
632}
633
634// Sets up a word with the ratings matrix and fake blobs with boxes in the
635// right places.
636WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space,
637 const TBOX &line_box, int word_start,
638 int word_end, float space_certainty,
639 const UNICHARSET *unicharset,
640 const std::vector<int> &xcoords,
641 float scale_factor) {
642 // Make a fake blob for each non-zero label.
643 C_BLOB_LIST blobs;
644 C_BLOB_IT b_it(&blobs);
645 for (int i = word_start; i < word_end; ++i) {
646 if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) {
647 TBOX box(static_cast<int16_t>(
648 std::floor(character_boundaries_[i] * scale_factor)) +
649 line_box.left(),
650 line_box.bottom(),
651 static_cast<int16_t>(
652 std::ceil(character_boundaries_[i + 1] * scale_factor)) +
653 line_box.left(),
654 line_box.top());
655 b_it.add_after_then_move(C_BLOB::FakeBlob(box));
656 }
657 }
658 // Make a fake word from the blobs.
659 WERD *word = new WERD(&blobs, leading_space, nullptr);
660 // Make a WERD_RES from the word.
661 auto *word_res = new WERD_RES(word);
662 word_res->end = word_end - word_start + leading_space;
663 word_res->uch_set = unicharset;
664 word_res->combination = true; // Give it ownership of the word.
665 word_res->space_certainty = space_certainty;
666 word_res->ratings = new MATRIX(word_end - word_start, 1);
667 return word_res;
668}
669
670// Fills top_n_flags_ with bools that are true iff the corresponding output
671// is one of the top_n.
672void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs,
673 int top_n) {
674 top_n_flags_.clear();
675 top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
676 top_code_ = -1;
677 second_code_ = -1;
678 top_heap_.clear();
679 for (int i = 0; i < num_outputs; ++i) {
680 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
681 TopPair entry(outputs[i], i);
682 top_heap_.Push(&entry);
683 if (top_heap_.size() > top_n) {
684 top_heap_.Pop(&entry);
685 }
686 }
687 }
688 while (!top_heap_.empty()) {
689 TopPair entry;
690 top_heap_.Pop(&entry);
691 if (top_heap_.size() > 1) {
692 top_n_flags_[entry.data()] = TN_TOPN;
693 } else {
694 top_n_flags_[entry.data()] = TN_TOP2;
695 if (top_heap_.empty()) {
696 top_code_ = entry.data();
697 } else {
698 second_code_ = entry.data();
699 }
700 }
701 }
702 top_n_flags_[null_char_] = TN_TOP2;
703}
704
705void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
706 const float *outputs, int num_outputs,
707 int top_n) {
708 top_n_flags_.clear();
709 top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
710 top_code_ = -1;
711 second_code_ = -1;
712 top_heap_.clear();
713 for (int i = 0; i < num_outputs; ++i) {
714 if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) &&
715 !exList->count(i)) {
716 TopPair entry(outputs[i], i);
717 top_heap_.Push(&entry);
718 if (top_heap_.size() > top_n) {
719 top_heap_.Pop(&entry);
720 }
721 }
722 }
723 while (!top_heap_.empty()) {
724 TopPair entry;
725 top_heap_.Pop(&entry);
726 if (top_heap_.size() > 1) {
727 top_n_flags_[entry.data()] = TN_TOPN;
728 } else {
729 top_n_flags_[entry.data()] = TN_TOP2;
730 if (top_heap_.empty()) {
731 top_code_ = entry.data();
732 } else {
733 second_code_ = entry.data();
734 }
735 }
736 }
737 top_n_flags_[null_char_] = TN_TOP2;
738}
739
740// Adds the computation for the current time-step to the beam. Call at each
741// time-step in sequence from left to right. outputs is the activation vector
742// for the current timestep.
743void RecodeBeamSearch::DecodeStep(const float *outputs, int t,
744 double dict_ratio, double cert_offset,
745 double worst_dict_cert,
746 const UNICHARSET *charset, bool debug) {
747 if (t == static_cast<int>(beam_.size())) {
748 beam_.push_back(new RecodeBeam);
749 }
750 RecodeBeam *step = beam_[t];
751 beam_size_ = t + 1;
752 step->Clear();
753 if (t == 0) {
754 // The first step can only use singles and initials.
755 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
756 charset, dict_ratio, cert_offset, worst_dict_cert, step);
757 if (dict_ != nullptr) {
758 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
759 TN_TOP2, charset, dict_ratio, cert_offset,
760 worst_dict_cert, step);
761 }
762 } else {
763 RecodeBeam *prev = beam_[t - 1];
764 if (debug) {
765 int beam_index = BeamIndex(true, NC_ANYTHING, 0);
766 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
767 std::vector<const RecodeNode *> path;
768 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
769 tprintf("Step %d: Dawg beam %d:\n", t, i);
770 DebugPath(charset, path);
771 }
772 beam_index = BeamIndex(false, NC_ANYTHING, 0);
773 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
774 std::vector<const RecodeNode *> path;
775 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
776 tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
777 DebugPath(charset, path);
778 }
779 }
780 int total_beam = 0;
781 // Work through the scores by group (top-2, top-n, the rest) while the beam
782 // is empty. This enables extending the context using only the top-n results
783 // first, which may have an empty intersection with the valid codes, so we
784 // fall back to the rest if the beam is empty.
785 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
786 auto top_n = static_cast<TopNState>(tn);
787 for (int index = 0; index < kNumBeams; ++index) {
788 // Working backwards through the heaps doesn't guarantee that we see the
789 // best first, but it comes before a lot of the worst, so it is slightly
790 // more efficient than going forwards.
791 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
792 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
793 top_n, charset, dict_ratio, cert_offset,
794 worst_dict_cert, step);
795 }
796 }
797 for (int index = 0; index < kNumBeams; ++index) {
799 total_beam += step->beams_[index].size();
800 }
801 }
802 }
803 // Special case for the best initial dawg. Push it on the heap if good
804 // enough, but there is only one, so it doesn't blow up the beam.
805 for (int c = 0; c < NC_COUNT; ++c) {
806 if (step->best_initial_dawgs_[c].code >= 0) {
807 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
808 RecodeHeap *dawg_heap = &step->beams_[index];
809 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
810 dawg_heap);
811 }
812 }
813 }
814}
815
816void RecodeBeamSearch::DecodeSecondaryStep(
817 const float *outputs, int t, double dict_ratio, double cert_offset,
818 double worst_dict_cert, const UNICHARSET *charset, bool debug) {
819 if (t == static_cast<int>(secondary_beam_.size())) {
820 secondary_beam_.push_back(new RecodeBeam);
821 }
822 RecodeBeam *step = secondary_beam_[t];
823 step->Clear();
824 if (t == 0) {
825 // The first step can only use singles and initials.
826 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
827 charset, dict_ratio, cert_offset, worst_dict_cert, step);
828 if (dict_ != nullptr) {
829 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
830 TN_TOP2, charset, dict_ratio, cert_offset,
831 worst_dict_cert, step);
832 }
833 } else {
834 RecodeBeam *prev = secondary_beam_[t - 1];
835 if (debug) {
836 int beam_index = BeamIndex(true, NC_ANYTHING, 0);
837 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
838 std::vector<const RecodeNode *> path;
839 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
840 tprintf("Step %d: Dawg beam %d:\n", t, i);
841 DebugPath(charset, path);
842 }
843 beam_index = BeamIndex(false, NC_ANYTHING, 0);
844 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
845 std::vector<const RecodeNode *> path;
846 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
847 tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
848 DebugPath(charset, path);
849 }
850 }
851 int total_beam = 0;
852 // Work through the scores by group (top-2, top-n, the rest) while the beam
853 // is empty. This enables extending the context using only the top-n results
854 // first, which may have an empty intersection with the valid codes, so we
855 // fall back to the rest if the beam is empty.
856 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
857 auto top_n = static_cast<TopNState>(tn);
858 for (int index = 0; index < kNumBeams; ++index) {
859 // Working backwards through the heaps doesn't guarantee that we see the
860 // best first, but it comes before a lot of the worst, so it is slightly
861 // more efficient than going forwards.
862 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
863 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
864 top_n, charset, dict_ratio, cert_offset,
865 worst_dict_cert, step);
866 }
867 }
868 for (int index = 0; index < kNumBeams; ++index) {
870 total_beam += step->beams_[index].size();
871 }
872 }
873 }
874 // Special case for the best initial dawg. Push it on the heap if good
875 // enough, but there is only one, so it doesn't blow up the beam.
876 for (int c = 0; c < NC_COUNT; ++c) {
877 if (step->best_initial_dawgs_[c].code >= 0) {
878 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
879 RecodeHeap *dawg_heap = &step->beams_[index];
880 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
881 dawg_heap);
882 }
883 }
884 }
885}
886
887// Adds to the appropriate beams the legal (according to recoder)
888// continuations of context prev, which is of the given length, using the
889// given network outputs to provide scores to the choices. Uses only those
890// choices for which top_n_flags[index] == top_n_flag.
891void RecodeBeamSearch::ContinueContext(
892 const RecodeNode *prev, int index, const float *outputs,
893 TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio,
894 double cert_offset, double worst_dict_cert, RecodeBeam *step) {
895 RecodedCharID prefix;
896 RecodedCharID full_code;
897 const RecodeNode *previous = prev;
898 int length = LengthFromBeamsIndex(index);
899 bool use_dawgs = IsDawgFromBeamsIndex(index);
901 for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
902 while (previous != nullptr &&
903 (previous->duplicate || previous->code == null_char_)) {
904 previous = previous->prev;
905 }
906 if (previous != nullptr) {
907 prefix.Set(p, previous->code);
908 full_code.Set(p, previous->code);
909 }
910 }
911 if (prev != nullptr && !is_simple_text_) {
912 if (top_n_flags_[prev->code] == top_n_flag) {
913 if (prev_cont != NC_NO_DUP) {
914 float cert =
915 NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
916 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
917 cert, worst_dict_cert, dict_ratio, use_dawgs,
918 NC_ANYTHING, prev, step);
919 }
920 if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 &&
921 prev->code != null_char_) {
922 float cert = NetworkIO::ProbToCertainty(outputs[prev->code] +
923 outputs[null_char_]) +
924 cert_offset;
925 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
926 cert, worst_dict_cert, dict_ratio, use_dawgs,
927 NC_NO_DUP, prev, step);
928 }
929 }
930 if (prev_cont == NC_ONLY_DUP) {
931 return;
932 }
933 if (prev->code != null_char_ && length > 0 &&
934 top_n_flags_[null_char_] == top_n_flag) {
935 // Allow nulls within multi code sequences, as the nulls within are not
936 // explicitly included in the code sequence.
937 float cert =
938 NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
939 PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID,
940 cert, worst_dict_cert, dict_ratio, use_dawgs,
941 NC_ANYTHING, prev, step);
942 }
943 }
944 const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix);
945 if (final_codes != nullptr) {
946 for (int code : *final_codes) {
947 if (top_n_flags_[code] != top_n_flag) {
948 continue;
949 }
950 if (prev != nullptr && prev->code == code && !is_simple_text_) {
951 continue;
952 }
953 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
954 if (cert < kMinCertainty && code != null_char_) {
955 continue;
956 }
957 full_code.Set(length, code);
958 int unichar_id = recoder_.DecodeUnichar(full_code);
959 // Map the null char to INVALID.
960 if (length == 0 && code == null_char_) {
961 unichar_id = INVALID_UNICHAR_ID;
962 }
963 if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr &&
964 !charset->get_enabled(unichar_id)) {
965 continue; // disabled by whitelist/blacklist
966 }
967 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
968 use_dawgs, NC_ANYTHING, prev, step);
969 if (top_n_flag == TN_TOP2 && code != null_char_) {
970 float prob = outputs[code] + outputs[null_char_];
971 if (prev != nullptr && prev_cont == NC_ANYTHING &&
972 prev->code != null_char_ &&
973 ((prev->code == top_code_ && code == second_code_) ||
974 (code == top_code_ && prev->code == second_code_))) {
975 prob += outputs[prev->code];
976 }
977 cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
978 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
979 use_dawgs, NC_ONLY_DUP, prev, step);
980 }
981 }
982 }
983 const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix);
984 if (next_codes != nullptr) {
985 for (int code : *next_codes) {
986 if (top_n_flags_[code] != top_n_flag) {
987 continue;
988 }
989 if (prev != nullptr && prev->code == code && !is_simple_text_) {
990 continue;
991 }
992 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
993 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert,
994 worst_dict_cert, dict_ratio, use_dawgs,
995 NC_ANYTHING, prev, step);
996 if (top_n_flag == TN_TOP2 && code != null_char_) {
997 float prob = outputs[code] + outputs[null_char_];
998 if (prev != nullptr && prev_cont == NC_ANYTHING &&
999 prev->code != null_char_ &&
1000 ((prev->code == top_code_ && code == second_code_) ||
1001 (code == top_code_ && prev->code == second_code_))) {
1002 prob += outputs[prev->code];
1003 }
1004 cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
1005 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID,
1006 cert, worst_dict_cert, dict_ratio, use_dawgs,
1007 NC_ONLY_DUP, prev, step);
1008 }
1009 }
1010 }
1011}
1012
1013// Continues for a new unichar, using dawg or non-dawg as per flag.
1014void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert,
1015 float worst_dict_cert, float dict_ratio,
1016 bool use_dawgs, NodeContinuation cont,
1017 const RecodeNode *prev,
1018 RecodeBeam *step) {
1019 if (use_dawgs) {
1020 if (cert > worst_dict_cert) {
1021 ContinueDawg(code, unichar_id, cert, cont, prev, step);
1022 }
1023 } else {
1024 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1025 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false,
1026 false, false, false, cert * dict_ratio, prev, nullptr,
1027 nodawg_heap);
1028 if (dict_ != nullptr &&
1029 ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
1030 !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
1031 // Any top choice position that can start a new word, ie a space or
1032 // any non-space-delimited character, should also be considered
1033 // by the dawg search, so push initial dawg to the dawg heap.
1034 float dawg_cert = cert;
1035 PermuterType permuter = TOP_CHOICE_PERM;
1036 // Since we use the space either side of a dictionary word in the
1037 // certainty of the word, (to properly handle weak spaces) and the
1038 // space is coming from a non-dict word, we need special conditions
1039 // to avoid degrading the certainty of the dict word that follows.
1040 // With a space we don't multiply the certainty by dict_ratio, and we
1041 // flag the space with NO_PERM to indicate that we should not use the
1042 // predecessor nulls to generate the confidence for the space, as they
1043 // have already been multiplied by dict_ratio, and we can't go back to
1044 // insert more entries in any previous heaps.
1045 if (unichar_id == UNICHAR_SPACE) {
1046 permuter = NO_PERM;
1047 } else {
1048 dawg_cert *= dict_ratio;
1049 }
1050 PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
1051 dawg_cert, cont, prev, step);
1052 }
1053 }
1054}
1055
1056// Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
1057// appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
1058// is a valid continuation of whatever is in prev.
1059void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert,
1060 NodeContinuation cont,
1061 const RecodeNode *prev, RecodeBeam *step) {
1062 RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)];
1063 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1064 if (unichar_id == INVALID_UNICHAR_ID) {
1065 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false,
1066 false, false, cert, prev, nullptr, dawg_heap);
1067 return;
1068 }
1069 // Avoid dictionary probe if score a total loss.
1070 float score = cert;
1071 if (prev != nullptr) {
1072 score += prev->score;
1073 }
1074 if (dawg_heap->size() >= kBeamWidths[0] &&
1075 score <= dawg_heap->PeekTop().data().score &&
1076 nodawg_heap->size() >= kBeamWidths[0] &&
1077 score <= nodawg_heap->PeekTop().data().score) {
1078 return;
1079 }
1080 const RecodeNode *uni_prev = prev;
1081 // Prev may be a partial code, null_char, or duplicate, so scan back to the
1082 // last valid unichar_id.
1083 while (uni_prev != nullptr &&
1084 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1085 uni_prev = uni_prev->prev;
1086 }
1087 if (unichar_id == UNICHAR_SPACE) {
1088 if (uni_prev != nullptr && uni_prev->end_of_word) {
1089 // Space is good. Push initial state, to the dawg beam and a regular
1090 // space to the top choice beam.
1091 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
1092 false, cert, cont, prev, step);
1093 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1094 false, false, false, false, cert, prev, nullptr,
1095 nodawg_heap);
1096 }
1097 return;
1098 } else if (uni_prev != nullptr && uni_prev->start_of_dawg &&
1099 uni_prev->unichar_id != UNICHAR_SPACE &&
1100 dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
1101 dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
1102 return; // Can't break words between space delimited chars.
1103 }
1104 DawgPositionVector initial_dawgs;
1105 auto *updated_dawgs = new DawgPositionVector;
1106 DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
1107 bool word_start = false;
1108 if (uni_prev == nullptr) {
1109 // Starting from beginning of line.
1110 dict_->default_dawgs(&initial_dawgs, false);
1111 word_start = true;
1112 } else if (uni_prev->dawgs != nullptr) {
1113 // Continuing a previous dict word.
1114 dawg_args.active_dawgs = uni_prev->dawgs;
1115 word_start = uni_prev->start_of_dawg;
1116 } else {
1117 return; // Can't continue if not a dict word.
1118 }
1119 auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay(
1120 &dawg_args, dict_->getUnicharset(), unichar_id, false));
1121 if (permuter != NO_PERM) {
1122 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1123 word_start, dawg_args.valid_end, false, cert, prev,
1124 dawg_args.updated_dawgs, dawg_heap);
1125 if (dawg_args.valid_end && !space_delimited_) {
1126 // We can start another word right away, so push initial state as well,
1127 // to the dawg beam, and the regular character to the top choice beam,
1128 // since non-dict words can start here too.
1129 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
1130 cert, cont, prev, step);
1131 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1132 word_start, true, false, cert, prev, nullptr,
1133 nodawg_heap);
1134 }
1135 } else {
1136 delete updated_dawgs;
1137 }
1138}
1139
1140// Adds a RecodeNode composed of the tuple (code, unichar_id,
1141// initial-dawg-state, prev, cert) to the given heap if/ there is room or if
1142// better than the current worst element if already full.
1143void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
1144 PermuterType permuter,
1145 bool start, bool end, float cert,
1146 NodeContinuation cont,
1147 const RecodeNode *prev,
1148 RecodeBeam *step) {
1149 RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1150 float score = cert;
1151 if (prev != nullptr) {
1152 score += prev->score;
1153 }
1154 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1155 auto *initial_dawgs = new DawgPositionVector;
1156 dict_->default_dawgs(initial_dawgs, false);
1157 RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
1158 score, prev, initial_dawgs,
1159 ComputeCodeHash(code, false, prev));
1160 *best_initial_dawg = node;
1161 }
1162}
1163
1164// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1165// false, false, false, false, cert, prev, nullptr) to heap if there is room
1166// or if better than the current worst element if already full.
1167/* static */
1168void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1169 int length, bool dup, int code, int unichar_id, float cert,
1170 float worst_dict_cert, float dict_ratio, bool use_dawgs,
1171 NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) {
1172 int index = BeamIndex(use_dawgs, cont, length);
1173 if (use_dawgs) {
1174 if (cert > worst_dict_cert) {
1175 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1176 prev ? prev->permuter : NO_PERM, false, false, false,
1177 dup, cert, prev, nullptr, &step->beams_[index]);
1178 }
1179 } else {
1180 cert *= dict_ratio;
1181 if (cert >= kMinCertainty || code == null_char_) {
1182 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1183 prev ? prev->permuter : TOP_CHOICE_PERM, false, false,
1184 false, dup, cert, prev, nullptr, &step->beams_[index]);
1185 }
1186 }
1187}
1188
1189// Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1190// dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
1191// or if better than the current worst element if already full.
1192void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
1193 PermuterType permuter, bool dawg_start,
1194 bool word_start, bool end, bool dup,
1195 float cert, const RecodeNode *prev,
1196 DawgPositionVector *d,
1197 RecodeHeap *heap) {
1198 float score = cert;
1199 if (prev != nullptr) {
1200 score += prev->score;
1201 }
1202 if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1203 uint64_t hash = ComputeCodeHash(code, dup, prev);
1204 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1205 dup, cert, score, prev, d, hash);
1206 if (UpdateHeapIfMatched(&node, heap)) {
1207 return;
1208 }
1209 RecodePair entry(score, node);
1210 heap->Push(&entry);
1211 ASSERT_HOST(entry.data().dawgs == nullptr);
1212 if (heap->size() > max_size) {
1213 heap->Pop(&entry);
1214 }
1215 } else {
1216 delete d;
1217 }
1218}
1219
1220// Adds a RecodeNode to heap if there is room
1221// or if better than the current worst element if already full.
1222void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node,
1223 RecodeHeap *heap) {
1224 if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1225 if (UpdateHeapIfMatched(node, heap)) {
1226 return;
1227 }
1228 RecodePair entry(node->score, *node);
1229 heap->Push(&entry);
1230 ASSERT_HOST(entry.data().dawgs == nullptr);
1231 if (heap->size() > max_size) {
1232 heap->Pop(&entry);
1233 }
1234 }
1235}
1236
1237// Searches the heap for a matching entry, and updates the score with
1238// reshuffle if needed. Returns true if there was a match.
1239bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1240 RecodeHeap *heap) {
1241 // TODO(rays) consider hash map instead of linear search.
1242 // It might not be faster because the hash map would have to be updated
1243 // every time a heap reshuffle happens, and that would be a lot of overhead.
1244 std::vector<RecodePair> &nodes = heap->heap();
1245 for (auto &i : nodes) {
1246 RecodeNode &node = i.data();
1247 if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1248 node.permuter == new_node->permuter &&
1249 node.start_of_dawg == new_node->start_of_dawg) {
1250 if (new_node->score > node.score) {
1251 // The new one is better. Update the entire node in the heap and
1252 // reshuffle.
1253 node = *new_node;
1254 i.key() = node.score;
1255 heap->Reshuffle(&i);
1256 }
1257 return true;
1258 }
1259 }
1260 return false;
1261}
1262
1263// Computes and returns the code-hash for the given code and prev.
1264uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup,
1265 const RecodeNode *prev) const {
1266 uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
1267 if (!dup && code != null_char_) {
1268 int num_classes = recoder_.code_range();
1269 uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1270 hash *= num_classes;
1271 hash += carry;
1272 hash += code;
1273 }
1274 return hash;
1275}
1276
1277// Backtracks to extract the best path through the lattice that was built
1278// during Decode. On return the best_nodes vector essentially contains the set
1279// of code, score pairs that make the optimal path with the constraint that
1280// the recoder can decode the code sequence back to a sequence of unichar-ids.
1281void RecodeBeamSearch::ExtractBestPaths(
1282 std::vector<const RecodeNode *> *best_nodes,
1283 std::vector<const RecodeNode *> *second_nodes) const {
1284 // Scan both beams to extract the best and second best paths.
1285 const RecodeNode *best_node = nullptr;
1286 const RecodeNode *second_best_node = nullptr;
1287 const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1288 for (int c = 0; c < NC_COUNT; ++c) {
1289 if (c == NC_ONLY_DUP) {
1290 continue;
1291 }
1292 auto cont = static_cast<NodeContinuation>(c);
1293 for (int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1294 int beam_index = BeamIndex(is_dawg, cont, 0);
1295 int heap_size = last_beam->beams_[beam_index].size();
1296 for (int h = 0; h < heap_size; ++h) {
1297 const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1298 if (is_dawg) {
1299 // dawg_node may be a null_char, or duplicate, so scan back to the
1300 // last valid unichar_id.
1301 const RecodeNode *dawg_node = node;
1302 while (dawg_node != nullptr &&
1303 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1304 dawg_node->duplicate)) {
1305 dawg_node = dawg_node->prev;
1306 }
1307 if (dawg_node == nullptr ||
1308 (!dawg_node->end_of_word &&
1309 dawg_node->unichar_id != UNICHAR_SPACE)) {
1310 // Dawg node is not valid.
1311 continue;
1312 }
1313 }
1314 if (best_node == nullptr || node->score > best_node->score) {
1315 second_best_node = best_node;
1316 best_node = node;
1317 } else if (second_best_node == nullptr ||
1318 node->score > second_best_node->score) {
1319 second_best_node = node;
1320 }
1321 }
1322 }
1323 }
1324 if (second_nodes != nullptr) {
1325 ExtractPath(second_best_node, second_nodes);
1326 }
1327 ExtractPath(best_node, best_nodes);
1328}
1329
1330// Helper backtracks through the lattice from the given node, storing the
1331// path and reversing it.
1332void RecodeBeamSearch::ExtractPath(
1333 const RecodeNode *node, std::vector<const RecodeNode *> *path) const {
1334 path->clear();
1335 while (node != nullptr) {
1336 path->push_back(node);
1337 node = node->prev;
1338 }
1339 std::reverse(path->begin(), path->end());
1340}
1341
1342void RecodeBeamSearch::ExtractPath(const RecodeNode *node,
1343 std::vector<const RecodeNode *> *path,
1344 int limiter) const {
1345 int pathcounter = 0;
1346 path->clear();
1347 while (node != nullptr && pathcounter < limiter) {
1348 path->push_back(node);
1349 node = node->prev;
1350 ++pathcounter;
1351 }
1352 std::reverse(path->begin(), path->end());
1353}
1354
1355// Helper prints debug information on the given lattice path.
1356void RecodeBeamSearch::DebugPath(
1357 const UNICHARSET *unicharset,
1358 const std::vector<const RecodeNode *> &path) const {
1359 for (unsigned c = 0; c < path.size(); ++c) {
1360 const RecodeNode &node = *path[c];
1361 tprintf("%u ", c);
1362 node.Print(null_char_, *unicharset, 1);
1363 }
1364}
1365
1366// Helper prints debug information on the given unichar path.
1367void RecodeBeamSearch::DebugUnicharPath(
1368 const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path,
1369 const std::vector<int> &unichar_ids, const std::vector<float> &certs,
1370 const std::vector<float> &ratings, const std::vector<int> &xcoords) const {
1371 auto num_ids = unichar_ids.size();
1372 double total_rating = 0.0;
1373 for (unsigned c = 0; c < num_ids; ++c) {
1374 int coord = xcoords[c];
1375 tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1376 unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1377 path[coord]->start_of_word, path[coord]->end_of_word,
1378 path[coord]->permuter);
1379 total_rating += ratings[c];
1380 }
1381 tprintf("Path total rating = %g\n", total_rating);
1382}
1383
1384} // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:54
@ TBOX
const char * p
KDPairInc< double, RecodeNode > RecodePair
Definition: recodebeam.h:177
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
@ character
Definition: mfoutline.h:53
@ TN_ALSO_RAN
Definition: recodebeam.h:87
@ UNICHAR_SPACE
Definition: unicharset.h:36
GenericHeap< RecodePair > RecodeHeap
Definition: recodebeam.h:178
@ BCC_STATIC_CLASSIFIER
Definition: ratngs.h:49
PermuterType
Definition: ratngs.h:235
@ TOP_CHOICE_PERM
Definition: ratngs.h:238
@ NO_PERM
Definition: ratngs.h:236
NodeContinuation
Definition: recodebeam.h:72
@ NC_ANYTHING
Definition: recodebeam.h:73
@ NC_ONLY_DUP
Definition: recodebeam.h:74
void put(ICOORD pos, const T &thing)
Definition: matrix.h:260
void FakeWordFromRatings(PermuterType permuter)
Definition: pageres.cpp:930
MATRIX * ratings
Definition: pageres.h:235
static C_BLOB * FakeBlob(const TBOX &box)
Definition: stepblob.cpp:238
const Pair & get(int index) const
Definition: genericheap.h:87
Data & data()
Definition: kdpair.h:41
static const int kMaxCodeLen
const std::vector< int > * GetFinalCodes(const RecodedCharID &code) const
const std::vector< int > * GetNextCodes(const RecodedCharID &code) const
int DecodeUnichar(const RecodedCharID &code) const
const char * id_to_unichar(UNICHAR_ID id) const
Definition: unicharset.cpp:279
size_t size() const
Definition: unicharset.h:355
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const
Definition: unicharset.h:668
std::string debug_str(UNICHAR_ID id) const
Definition: unicharset.cpp:331
const char * id_to_unichar_ext(UNICHAR_ID id) const
Definition: unicharset.cpp:287
bool IsSpaceDelimitedLang() const
Returns true if the language is space-delimited (not CJ, or T).
Definition: dict.cpp:912
void default_dawgs(DawgPositionVector *anylength_dawgs, bool suppress_patterns) const
Definition: dict.cpp:624
int def_letter_is_okay(void *void_dawg_args, const UNICHARSET &unicharset, UNICHAR_ID unichar_id, bool word_end) const
Definition: dict.cpp:406
const UNICHARSET & getUnicharset() const
Definition: dict.h:104
static float ProbToCertainty(float prob)
Definition: networkio.cpp:580
bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const
Definition: recodebeam.cpp:325
const RecodeNode * prev
Definition: recodebeam.h:169
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
Definition: recodebeam.cpp:38
PermuterType permuter
Definition: recodebeam.h:149
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:83
static bool IsDawgFromBeamsIndex(int index)
Definition: recodebeam.h:256
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:234
static int LengthFromBeamsIndex(int index)
Definition: recodebeam.h:250
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:112
std::vector< std::vector< std::pair< const char *, float > > > timesteps
Definition: recodebeam.h:231
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float > > > > *segmentedTimesteps)
Definition: recodebeam.cpp:175
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:232
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:409
std::vector< int > character_boundaries_
Definition: recodebeam.h:238
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:236
static NodeContinuation ContinuationFromBeamsIndex(int index)
Definition: recodebeam.h:253
void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const
Definition: recodebeam.cpp:330
void DebugBeams(const UNICHARSET &unicharset) const
Definition: recodebeam.cpp:516
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, std::vector< int > *unichar_ids, std::vector< float > *certs, std::vector< float > *ratings, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:224
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
Definition: recodebeam.cpp:201
RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict)
Definition: recodebeam.cpp:58
static const int kNumBeams
Definition: recodebeam.h:248
static constexpr float kMinCertainty
Definition: recodebeam.h:243
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)
Definition: recodebeam.h:260
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:239