32 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
35static const char *kNodeContNames[] = {
"Anything",
"OnlyDup",
"NoDup"};
40 if (
code == null_char) {
49 if (depth > 0 &&
prev !=
nullptr) {
51 prev->
Print(null_char, unicharset, depth - 1);
59 int null_char,
bool simple_text,
Dict *dict)
65 space_delimited_(true),
66 is_simple_text_(simple_text),
67 null_char_(null_char) {
69 space_delimited_ =
false;
74 for (
auto data : beam_) {
77 for (
auto data : secondary_beam_) {
84 double cert_offset,
double worst_dict_cert,
85 const UNICHARSET *charset,
int lstm_choice_mode) {
87 int width =
output.Width();
88 if (lstm_choice_mode) {
91 for (
int t = 0; t < width; ++t) {
92 ComputeTopN(
output.f(t),
output.NumFeatures(), kBeamWidths[0]);
93 DecodeStep(
output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
95 if (lstm_choice_mode) {
96 SaveMostCertainChoices(
output.f(t),
output.NumFeatures(), charset, t);
101 double dict_ratio,
double cert_offset,
102 double worst_dict_cert,
105 int width =
output.dim1();
106 for (
int t = 0; t < width; ++t) {
108 DecodeStep(
output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
114 double worst_dict_cert,
const UNICHARSET *charset,
int lstm_choice_mode) {
115 for (
auto data : secondary_beam_) {
118 secondary_beam_.clear();
122 int width =
output.Width();
123 unsigned bucketNumber = 0;
124 for (
int t = 0; t < width; ++t) {
130 output.NumFeatures(), kBeamWidths[0]);
131 DecodeSecondaryStep(
output.f(t), t, dict_ratio, cert_offset,
132 worst_dict_cert, charset);
136void RecodeBeamSearch::SaveMostCertainChoices(
const float *outputs,
140 std::vector<std::pair<const char *, float>> choices;
141 for (
int i = 0;
i < num_outputs; ++
i) {
142 if (outputs[
i] >= 0.01f) {
144 if (
i + 2 >= num_outputs) {
154 while (choices.size() > pos && choices[pos].second > outputs[
i]) {
157 choices.insert(choices.begin() + pos,
158 std::pair<const char *, float>(
character, outputs[
i]));
166 std::vector<std::vector<std::pair<const char *, float>>> segment;
174std::vector<std::vector<std::pair<const char *, float>>>
176 std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
177 *segmentedTimesteps) {
178 std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
180 for (
auto &j : segmentedTimestep) {
181 combined_timesteps.push_back(j);
184 return combined_timesteps;
187void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
188 std::vector<int> *ends,
189 std::vector<int> *char_bounds_,
191 char_bounds_->push_back(0);
192 for (
unsigned i = 0;
i < ends->size(); ++
i) {
193 int middle = ((*starts)[
i + 1] - (*ends)[
i]) / 2;
194 char_bounds_->push_back((*ends)[
i] + middle);
196 char_bounds_->pop_back();
197 char_bounds_->push_back(maxWidth);
202 std::vector<int> *labels, std::vector<int> *xcoords)
const {
205 std::vector<const RecodeNode *> best_nodes;
206 ExtractBestPaths(&best_nodes,
nullptr);
209 int width = best_nodes.size();
211 int label = best_nodes[t]->code;
212 if (label != null_char_) {
213 labels->push_back(label);
214 xcoords->push_back(t);
216 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
219 xcoords->push_back(width);
225 bool debug,
const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
226 std::vector<float> *certs, std::vector<float> *ratings,
227 std::vector<int> *xcoords)
const {
228 std::vector<const RecodeNode *> best_nodes;
229 ExtractBestPaths(&best_nodes,
nullptr);
230 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
232 DebugPath(unicharset, best_nodes);
233 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
240 float scale_factor,
bool debug,
243 int lstm_choice_mode) {
245 std::vector<int> unichar_ids;
246 std::vector<float> certs;
247 std::vector<float> ratings;
248 std::vector<int> xcoords;
249 std::vector<const RecodeNode *> best_nodes;
250 std::vector<const RecodeNode *> second_nodes;
252 ExtractBestPaths(&best_nodes, &second_nodes);
254 DebugPath(unicharset, best_nodes);
255 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
257 tprintf(
"\nSecond choice path:\n");
258 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
264 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
266 int num_ids = unichar_ids.size();
268 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
273 float prev_space_cert = 0.0f;
274 for (
int word_start = 0; word_start < num_ids; word_start = word_end) {
275 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
282 int index = xcoords[word_end];
283 if (best_nodes[index]->start_of_word) {
292 float space_cert = 0.0f;
293 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE) {
294 space_cert = certs[word_end];
297 word_start > 0 && unichar_ids[word_start - 1] ==
UNICHAR_SPACE;
300 InitializeWord(leading_space, line_box, word_start, word_end,
301 std::min(space_cert, prev_space_cert), unicharset,
302 xcoords, scale_factor);
303 for (
int i = word_start;
i < word_end; ++
i) {
304 auto *choices =
new BLOB_CHOICE_LIST;
305 BLOB_CHOICE_IT bc_it(choices);
306 auto *choice =
new BLOB_CHOICE(unichar_ids[
i], ratings[
i], certs[
i], -1,
307 1.0f,
static_cast<float>(INT16_MAX), 0.0f,
309 int col =
i - word_start;
310 choice->set_matrix_cell(col, col);
311 bc_it.add_after_then_move(choice);
314 int index = xcoords[word_end - 1];
317 prev_space_cert = space_cert;
318 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE) {
332 bool secondary)
const {
333 std::vector<std::vector<const RecodeNode *>> topology;
334 std::unordered_set<const RecodeNode *> visited;
335 const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
337 for (
int step = beam.size() - 1; step >= 0; --step) {
338 std::vector<const RecodeNode *> layer;
339 topology.push_back(layer);
342 for (
int step = beam.size() - 1; step >= 0; --step) {
343 std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
344 for (
auto &&node : heaps) {
347 while (curr !=
nullptr && !visited.count(curr)) {
348 visited.insert(curr);
349 topology[step - backtracker].push_back(curr);
357 for (
const std::vector<const RecodeNode *> &layer : topology) {
368 if (node->unichar_id != INVALID_UNICHAR_ID) {
370 intCode = node->unichar_id;
371 }
else if (node->code == null_char_) {
379 const char *prevCode;
381 if (node->prev !=
nullptr) {
382 prevScore = node->prev->score;
383 if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
385 intPrevCode = node->prev->unichar_id;
386 }
else if (node->code == null_char_) {
397 tprintf(
"%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
400 tprintf(
"%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
415 std::vector<RecodeBeam *> ¤tBeam =
416 secondary_beam_.empty() ? beam_ : secondary_beam_;
419 std::vector<int> unichar_ids;
420 std::vector<float> certs;
421 std::vector<float> ratings;
422 std::vector<int> xcoords;
424 std::vector<tesseract::RecodePair> &heaps =
426 std::vector<const RecodeNode *> best_nodes;
427 std::vector<const RecodeNode *> best;
429 for (
auto &&entry : heaps) {
430 bool validChar =
false;
433 while (node !=
nullptr && backcounter < backpath) {
434 if (node->
code != null_char_ &&
443 best.push_back(&entry.data());
449 ExtractPath(best[0], &best_nodes, backpath);
450 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
453 if (!unichar_ids.empty()) {
455 for (
unsigned i = 1;
i < unichar_ids.size(); ++
i) {
456 if (ratings[
i] < ratings[bestPos]) {
462 for (
auto &node : best_nodes) {
463 if (node->unichar_id == unichar_ids[bestPos]) {
464 bestCode = node->code;
469 std::unordered_set<int> excludeCodeList;
470 for (
auto &best_node : best_nodes) {
471 if (best_node->code != null_char_) {
472 excludeCodeList.insert(best_node->code);
476 for (
auto elem : excludeCodeList) {
484 int id = unichar_ids[bestPos];
486 float rating = ratings[bestPos];
488 std::pair<const char *, float>(result, rating));
490 std::vector<std::pair<const char *, float>> choice;
491 int id = unichar_ids[bestPos];
493 float rating = ratings[bestPos];
494 choice.emplace_back(result, rating);
500 std::unordered_set<int> excludeCodeList;
504 std::vector<std::pair<const char *, float>> choice;
509 for (
auto data : secondary_beam_) {
512 secondary_beam_.clear();
517 for (
int p = 0;
p < beam_size_; ++
p) {
518 for (
int d = 0; d < 2; ++d) {
519 for (
int c = 0; c <
NC_COUNT; ++c) {
522 if (beam_[
p]->beams_[index].empty()) {
526 tprintf(
"Position %d: %s+%s beam\n",
p, d ?
"Dict" :
"Non-Dict",
528 DebugBeamPos(unicharset, beam_[
p]->beams_[index]);
535void RecodeBeamSearch::DebugBeamPos(
const UNICHARSET &unicharset,
537 std::vector<const RecodeNode *> unichar_bests(unicharset.
size());
539 int heap_size = heap.
size();
540 for (
int i = 0;
i < heap_size; ++
i) {
543 if (null_best ==
nullptr || null_best->
score < node->
score) {
547 if (unichar_bests[node->
unichar_id] ==
nullptr ||
553 for (
auto &unichar_best : unichar_bests) {
554 if (unichar_best !=
nullptr) {
555 const RecodeNode &node = *unichar_best;
556 node.
Print(null_char_, unicharset, 1);
559 if (null_best !=
nullptr) {
560 null_best->
Print(null_char_, unicharset, 1);
567void RecodeBeamSearch::ExtractPathAsUnicharIds(
568 const std::vector<const RecodeNode *> &best_nodes,
569 std::vector<int> *unichar_ids, std::vector<float> *certs,
570 std::vector<float> *ratings, std::vector<int> *xcoords,
571 std::vector<int> *character_boundaries) {
572 unichar_ids->clear();
576 std::vector<int> starts;
577 std::vector<int> ends;
580 int width = best_nodes.size();
582 double certainty = 0.0;
584 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
585 double cert = best_nodes[t++]->certainty;
586 if (cert < certainty) {
593 int unichar_id = best_nodes[t]->unichar_id;
595 best_nodes[t]->permuter !=
NO_PERM) {
598 if (certainty < certs->back()) {
599 certs->back() = certainty;
601 ratings->back() += rating;
605 unichar_ids->push_back(unichar_id);
606 xcoords->push_back(t);
608 double cert = best_nodes[t++]->certainty;
612 best_nodes[t - 1]->permuter ==
NO_PERM)) {
616 }
while (t < width && best_nodes[t]->duplicate);
618 certs->push_back(certainty);
619 ratings->push_back(rating);
620 }
else if (!certs->empty()) {
621 if (certainty < certs->back()) {
622 certs->back() = certainty;
624 ratings->back() += rating;
627 starts.push_back(width);
628 if (character_boundaries !=
nullptr) {
629 calculateCharBoundaries(&starts, &ends, character_boundaries, width);
631 xcoords->push_back(width);
636WERD_RES *RecodeBeamSearch::InitializeWord(
bool leading_space,
637 const TBOX &line_box,
int word_start,
638 int word_end,
float space_certainty,
639 const UNICHARSET *unicharset,
640 const std::vector<int> &xcoords,
641 float scale_factor) {
644 C_BLOB_IT b_it(&blobs);
645 for (
int i = word_start;
i < word_end; ++
i) {
647 TBOX box(
static_cast<int16_t
>(
651 static_cast<int16_t
>(
659 WERD *word =
new WERD(&blobs, leading_space,
nullptr);
661 auto *word_res =
new WERD_RES(word);
662 word_res->end = word_end - word_start + leading_space;
663 word_res->uch_set = unicharset;
664 word_res->combination =
true;
665 word_res->space_certainty = space_certainty;
666 word_res->ratings =
new MATRIX(word_end - word_start, 1);
672void RecodeBeamSearch::ComputeTopN(
const float *outputs,
int num_outputs,
674 top_n_flags_.clear();
679 for (
int i = 0;
i < num_outputs; ++
i) {
680 if (top_heap_.size() < top_n || outputs[
i] > top_heap_.PeekTop().key()) {
681 TopPair entry(outputs[
i],
i);
682 top_heap_.Push(&entry);
683 if (top_heap_.size() > top_n) {
684 top_heap_.Pop(&entry);
688 while (!top_heap_.empty()) {
690 top_heap_.Pop(&entry);
691 if (top_heap_.size() > 1) {
692 top_n_flags_[entry.data()] =
TN_TOPN;
694 top_n_flags_[entry.data()] =
TN_TOP2;
695 if (top_heap_.empty()) {
696 top_code_ = entry.data();
698 second_code_ = entry.data();
702 top_n_flags_[null_char_] =
TN_TOP2;
705void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
706 const float *outputs,
int num_outputs,
708 top_n_flags_.clear();
713 for (
int i = 0;
i < num_outputs; ++
i) {
714 if ((top_heap_.size() < top_n || outputs[
i] > top_heap_.PeekTop().key()) &&
716 TopPair entry(outputs[
i],
i);
717 top_heap_.Push(&entry);
718 if (top_heap_.size() > top_n) {
719 top_heap_.Pop(&entry);
723 while (!top_heap_.empty()) {
725 top_heap_.Pop(&entry);
726 if (top_heap_.size() > 1) {
727 top_n_flags_[entry.data()] =
TN_TOPN;
729 top_n_flags_[entry.data()] =
TN_TOP2;
730 if (top_heap_.empty()) {
731 top_code_ = entry.data();
733 second_code_ = entry.data();
737 top_n_flags_[null_char_] =
TN_TOP2;
743void RecodeBeamSearch::DecodeStep(
const float *outputs,
int t,
744 double dict_ratio,
double cert_offset,
745 double worst_dict_cert,
746 const UNICHARSET *charset,
bool debug) {
747 if (t ==
static_cast<int>(beam_.size())) {
748 beam_.push_back(
new RecodeBeam);
750 RecodeBeam *step = beam_[t];
756 charset, dict_ratio, cert_offset, worst_dict_cert, step);
757 if (dict_ !=
nullptr) {
759 TN_TOP2, charset, dict_ratio, cert_offset,
760 worst_dict_cert, step);
763 RecodeBeam *prev = beam_[t - 1];
766 for (
int i = prev->beams_[beam_index].size() - 1;
i >= 0; --
i) {
767 std::vector<const RecodeNode *> path;
768 ExtractPath(&prev->beams_[beam_index].get(
i).data(), &path);
769 tprintf(
"Step %d: Dawg beam %d:\n", t,
i);
770 DebugPath(charset, path);
773 for (
int i = prev->beams_[beam_index].size() - 1;
i >= 0; --
i) {
774 std::vector<const RecodeNode *> path;
775 ExtractPath(&prev->beams_[beam_index].get(
i).data(), &path);
776 tprintf(
"Step %d: Non-Dawg beam %d:\n", t,
i);
777 DebugPath(charset, path);
785 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
787 for (
int index = 0; index <
kNumBeams; ++index) {
791 for (
int i = prev->beams_[index].size() - 1;
i >= 0; --
i) {
792 ContinueContext(&prev->beams_[index].get(
i).data(), index, outputs,
793 top_n, charset, dict_ratio, cert_offset,
794 worst_dict_cert, step);
797 for (
int index = 0; index <
kNumBeams; ++index) {
799 total_beam += step->beams_[index].size();
805 for (
int c = 0; c <
NC_COUNT; ++c) {
806 if (step->best_initial_dawgs_[c].code >= 0) {
809 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
816void RecodeBeamSearch::DecodeSecondaryStep(
817 const float *outputs,
int t,
double dict_ratio,
double cert_offset,
818 double worst_dict_cert,
const UNICHARSET *charset,
bool debug) {
819 if (t ==
static_cast<int>(secondary_beam_.size())) {
820 secondary_beam_.push_back(
new RecodeBeam);
822 RecodeBeam *step = secondary_beam_[t];
827 charset, dict_ratio, cert_offset, worst_dict_cert, step);
828 if (dict_ !=
nullptr) {
830 TN_TOP2, charset, dict_ratio, cert_offset,
831 worst_dict_cert, step);
834 RecodeBeam *prev = secondary_beam_[t - 1];
837 for (
int i = prev->beams_[beam_index].size() - 1;
i >= 0; --
i) {
838 std::vector<const RecodeNode *> path;
839 ExtractPath(&prev->beams_[beam_index].get(
i).data(), &path);
840 tprintf(
"Step %d: Dawg beam %d:\n", t,
i);
841 DebugPath(charset, path);
844 for (
int i = prev->beams_[beam_index].size() - 1;
i >= 0; --
i) {
845 std::vector<const RecodeNode *> path;
846 ExtractPath(&prev->beams_[beam_index].get(
i).data(), &path);
847 tprintf(
"Step %d: Non-Dawg beam %d:\n", t,
i);
848 DebugPath(charset, path);
856 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
858 for (
int index = 0; index <
kNumBeams; ++index) {
862 for (
int i = prev->beams_[index].size() - 1;
i >= 0; --
i) {
863 ContinueContext(&prev->beams_[index].get(
i).data(), index, outputs,
864 top_n, charset, dict_ratio, cert_offset,
865 worst_dict_cert, step);
868 for (
int index = 0; index <
kNumBeams; ++index) {
870 total_beam += step->beams_[index].size();
876 for (
int c = 0; c <
NC_COUNT; ++c) {
877 if (step->best_initial_dawgs_[c].code >= 0) {
880 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
891void RecodeBeamSearch::ContinueContext(
892 const RecodeNode *prev,
int index,
const float *outputs,
893 TopNState top_n_flag,
const UNICHARSET *charset,
double dict_ratio,
894 double cert_offset,
double worst_dict_cert, RecodeBeam *step) {
895 RecodedCharID prefix;
896 RecodedCharID full_code;
897 const RecodeNode *previous = prev;
901 for (
int p = length - 1;
p >= 0; --
p, previous = previous->prev) {
902 while (previous !=
nullptr &&
903 (previous->duplicate || previous->code == null_char_)) {
904 previous = previous->prev;
906 if (previous !=
nullptr) {
907 prefix.Set(
p, previous->code);
908 full_code.Set(
p, previous->code);
911 if (prev !=
nullptr && !is_simple_text_) {
912 if (top_n_flags_[prev->code] == top_n_flag) {
916 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
917 cert, worst_dict_cert, dict_ratio, use_dawgs,
921 prev->code != null_char_) {
923 outputs[null_char_]) +
925 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
926 cert, worst_dict_cert, dict_ratio, use_dawgs,
933 if (prev->code != null_char_ && length > 0 &&
934 top_n_flags_[null_char_] == top_n_flag) {
939 PushDupOrNoDawgIfBetter(length,
false, null_char_, INVALID_UNICHAR_ID,
940 cert, worst_dict_cert, dict_ratio, use_dawgs,
944 const std::vector<int> *final_codes = recoder_.
GetFinalCodes(prefix);
945 if (final_codes !=
nullptr) {
946 for (
int code : *final_codes) {
947 if (top_n_flags_[code] != top_n_flag) {
950 if (prev !=
nullptr && prev->code == code && !is_simple_text_) {
957 full_code.Set(length, code);
960 if (length == 0 && code == null_char_) {
961 unichar_id = INVALID_UNICHAR_ID;
963 if (unichar_id != INVALID_UNICHAR_ID && charset !=
nullptr &&
964 !charset->get_enabled(unichar_id)) {
967 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
969 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
970 float prob = outputs[code] + outputs[null_char_];
972 prev->code != null_char_ &&
973 ((prev->code == top_code_ && code == second_code_) ||
974 (code == top_code_ && prev->code == second_code_))) {
975 prob += outputs[prev->code];
978 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
983 const std::vector<int> *next_codes = recoder_.
GetNextCodes(prefix);
984 if (next_codes !=
nullptr) {
985 for (
int code : *next_codes) {
986 if (top_n_flags_[code] != top_n_flag) {
989 if (prev !=
nullptr && prev->code == code && !is_simple_text_) {
993 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID, cert,
994 worst_dict_cert, dict_ratio, use_dawgs,
996 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
997 float prob = outputs[code] + outputs[null_char_];
999 prev->code != null_char_ &&
1000 ((prev->code == top_code_ && code == second_code_) ||
1001 (code == top_code_ && prev->code == second_code_))) {
1002 prob += outputs[prev->code];
1005 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID,
1006 cert, worst_dict_cert, dict_ratio, use_dawgs,
1014void RecodeBeamSearch::ContinueUnichar(
int code,
int unichar_id,
float cert,
1015 float worst_dict_cert,
float dict_ratio,
1017 const RecodeNode *prev,
1020 if (cert > worst_dict_cert) {
1021 ContinueDawg(code, unichar_id, cert, cont, prev, step);
1025 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
TOP_CHOICE_PERM,
false,
1026 false,
false,
false, cert * dict_ratio, prev,
nullptr,
1028 if (dict_ !=
nullptr &&
1034 float dawg_cert = cert;
1048 dawg_cert *= dict_ratio;
1050 PushInitialDawgIfBetter(code, unichar_id, permuter,
false,
false,
1051 dawg_cert, cont, prev, step);
1059void RecodeBeamSearch::ContinueDawg(
int code,
int unichar_id,
float cert,
1061 const RecodeNode *prev, RecodeBeam *step) {
1064 if (unichar_id == INVALID_UNICHAR_ID) {
1065 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
NO_PERM,
false,
false,
1066 false,
false, cert, prev,
nullptr, dawg_heap);
1071 if (prev !=
nullptr) {
1072 score += prev->score;
1074 if (dawg_heap->size() >= kBeamWidths[0] &&
1075 score <= dawg_heap->PeekTop().data().score &&
1076 nodawg_heap->size() >= kBeamWidths[0] &&
1077 score <= nodawg_heap->PeekTop().data().score) {
1080 const RecodeNode *uni_prev = prev;
1083 while (uni_prev !=
nullptr &&
1084 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1085 uni_prev = uni_prev->prev;
1088 if (uni_prev !=
nullptr && uni_prev->end_of_word) {
1091 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter,
false,
1092 false, cert, cont, prev, step);
1093 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1094 false,
false,
false,
false, cert, prev,
nullptr,
1098 }
else if (uni_prev !=
nullptr && uni_prev->start_of_dawg &&
1104 DawgPositionVector initial_dawgs;
1105 auto *updated_dawgs =
new DawgPositionVector;
1106 DawgArgs dawg_args(&initial_dawgs, updated_dawgs,
NO_PERM);
1107 bool word_start =
false;
1108 if (uni_prev ==
nullptr) {
1112 }
else if (uni_prev->dawgs !=
nullptr) {
1114 dawg_args.active_dawgs = uni_prev->dawgs;
1115 word_start = uni_prev->start_of_dawg;
1122 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1123 word_start, dawg_args.valid_end,
false, cert, prev,
1124 dawg_args.updated_dawgs, dawg_heap);
1125 if (dawg_args.valid_end && !space_delimited_) {
1129 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start,
true,
1130 cert, cont, prev, step);
1131 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1132 word_start,
true,
false, cert, prev,
nullptr,
1136 delete updated_dawgs;
1143void RecodeBeamSearch::PushInitialDawgIfBetter(
int code,
int unichar_id,
1145 bool start,
bool end,
float cert,
1147 const RecodeNode *prev,
1149 RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1151 if (prev !=
nullptr) {
1152 score += prev->score;
1154 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1155 auto *initial_dawgs =
new DawgPositionVector;
1157 RecodeNode node(code, unichar_id, permuter,
true, start, end,
false, cert,
1158 score, prev, initial_dawgs,
1159 ComputeCodeHash(code,
false, prev));
1160 *best_initial_dawg = node;
1168void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1169 int length,
bool dup,
int code,
int unichar_id,
float cert,
1170 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
1172 int index =
BeamIndex(use_dawgs, cont, length);
1174 if (cert > worst_dict_cert) {
1175 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1176 prev ? prev->permuter :
NO_PERM,
false,
false,
false,
1177 dup, cert, prev,
nullptr, &step->beams_[index]);
1182 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1184 false, dup, cert, prev,
nullptr, &step->beams_[index]);
1192void RecodeBeamSearch::PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
1194 bool word_start,
bool end,
bool dup,
1195 float cert,
const RecodeNode *prev,
1196 DawgPositionVector *d,
1199 if (prev !=
nullptr) {
1200 score += prev->score;
1202 if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1203 uint64_t hash = ComputeCodeHash(code, dup, prev);
1204 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1205 dup, cert, score, prev, d, hash);
1206 if (UpdateHeapIfMatched(&node, heap)) {
1212 if (heap->size() > max_size) {
1222void RecodeBeamSearch::PushHeapIfBetter(
int max_size, RecodeNode *node,
1224 if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1225 if (UpdateHeapIfMatched(node, heap)) {
1231 if (heap->size() > max_size) {
1239bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1244 std::vector<RecodePair> &nodes = heap->heap();
1245 for (
auto &
i : nodes) {
1246 RecodeNode &node =
i.data();
1247 if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1248 node.permuter == new_node->permuter &&
1249 node.start_of_dawg == new_node->start_of_dawg) {
1250 if (new_node->score > node.score) {
1254 i.key() = node.score;
1255 heap->Reshuffle(&
i);
1264uint64_t RecodeBeamSearch::ComputeCodeHash(
int code,
bool dup,
1265 const RecodeNode *prev)
const {
1266 uint64_t hash = prev ==
nullptr ? 0 : prev->code_hash;
1267 if (!dup && code != null_char_) {
1269 uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1270 hash *= num_classes;
1281void RecodeBeamSearch::ExtractBestPaths(
1282 std::vector<const RecodeNode *> *best_nodes,
1283 std::vector<const RecodeNode *> *second_nodes)
const {
1285 const RecodeNode *best_node =
nullptr;
1286 const RecodeNode *second_best_node =
nullptr;
1287 const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1288 for (
int c = 0; c <
NC_COUNT; ++c) {
1293 for (
int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1294 int beam_index =
BeamIndex(is_dawg, cont, 0);
1295 int heap_size = last_beam->beams_[beam_index].size();
1296 for (
int h = 0; h < heap_size; ++h) {
1297 const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1301 const RecodeNode *dawg_node = node;
1302 while (dawg_node !=
nullptr &&
1303 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1304 dawg_node->duplicate)) {
1305 dawg_node = dawg_node->prev;
1307 if (dawg_node ==
nullptr ||
1308 (!dawg_node->end_of_word &&
1314 if (best_node ==
nullptr || node->score > best_node->score) {
1315 second_best_node = best_node;
1317 }
else if (second_best_node ==
nullptr ||
1318 node->score > second_best_node->score) {
1319 second_best_node = node;
1324 if (second_nodes !=
nullptr) {
1325 ExtractPath(second_best_node, second_nodes);
1327 ExtractPath(best_node, best_nodes);
1332void RecodeBeamSearch::ExtractPath(
1333 const RecodeNode *node, std::vector<const RecodeNode *> *path)
const {
1335 while (node !=
nullptr) {
1336 path->push_back(node);
1339 std::reverse(path->begin(), path->end());
1342void RecodeBeamSearch::ExtractPath(
const RecodeNode *node,
1343 std::vector<const RecodeNode *> *path,
1344 int limiter)
const {
1345 int pathcounter = 0;
1347 while (node !=
nullptr && pathcounter < limiter) {
1348 path->push_back(node);
1352 std::reverse(path->begin(), path->end());
1356void RecodeBeamSearch::DebugPath(
1357 const UNICHARSET *unicharset,
1358 const std::vector<const RecodeNode *> &path)
const {
1359 for (
unsigned c = 0; c < path.size(); ++c) {
1360 const RecodeNode &node = *path[c];
1362 node.Print(null_char_, *unicharset, 1);
1367void RecodeBeamSearch::DebugUnicharPath(
1368 const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path,
1369 const std::vector<int> &unichar_ids,
const std::vector<float> &certs,
1370 const std::vector<float> &ratings,
const std::vector<int> &xcoords)
const {
1371 auto num_ids = unichar_ids.size();
1372 double total_rating = 0.0;
1373 for (
unsigned c = 0; c < num_ids; ++c) {
1374 int coord = xcoords[c];
1375 tprintf(
"%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1376 unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1377 path[coord]->start_of_word, path[coord]->end_of_word,
1378 path[coord]->permuter);
1379 total_rating += ratings[c];
1381 tprintf(
"Path total rating = %g\n", total_rating);
KDPairInc< double, RecodeNode > RecodePair
void tprintf(const char *format,...)
GenericHeap< RecodePair > RecodeHeap
void put(ICOORD pos, const T &thing)
void FakeWordFromRatings(PermuterType permuter)
static C_BLOB * FakeBlob(const TBOX &box)
const Pair & get(int index) const
static const int kMaxCodeLen
const std::vector< int > * GetFinalCodes(const RecodedCharID &code) const
const std::vector< int > * GetNextCodes(const RecodedCharID &code) const
int DecodeUnichar(const RecodedCharID &code) const
const char * id_to_unichar(UNICHAR_ID id) const
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const
std::string debug_str(UNICHAR_ID id) const
const char * id_to_unichar_ext(UNICHAR_ID id) const
bool IsSpaceDelimitedLang() const
Returns true if the language is space-delimited (not CJ, or T).
void default_dawgs(DawgPositionVector *anylength_dawgs, bool suppress_patterns) const
int def_letter_is_okay(void *void_dawg_args, const UNICHARSET &unicharset, UNICHAR_ID unichar_id, bool word_end) const
const UNICHARSET & getUnicharset() const
static float ProbToCertainty(float prob)
bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) const
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
static bool IsDawgFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
static int LengthFromBeamsIndex(int index)
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
std::vector< std::vector< std::pair< const char *, float > > > timesteps
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float > > > > *segmentedTimesteps)
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
void extractSymbolChoices(const UNICHARSET *unicharset)
std::vector< int > character_boundaries_
std::vector< std::unordered_set< int > > excludedUnichars
static NodeContinuation ContinuationFromBeamsIndex(int index)
void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const
void DebugBeams(const UNICHARSET &unicharset) const
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, std::vector< int > *unichar_ids, std::vector< float > *certs, std::vector< float > *ratings, std::vector< int > *xcoords) const
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict)
static const int kNumBeams
static constexpr float kMinCertainty
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
void segmentTimestepsByCharacters()