20#ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
21#define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
32#include <unordered_set>
132 memcpy(
this, &src,
sizeof(src));
// Declaration only; the body is not in this view.
// Const member function returning nothing. Takes an int null_char, a
// UNICHARSET by const reference, and an int depth.
// NOTE(review): presumably a debug print of one node, with depth bounding
// how far it recurses or indents -- confirm against the definition.
140 void Print(
int null_char,
const UNICHARSET &unicharset,
int depth)
const;
// Declaration only; the body is not in this view.
// Non-const member: takes the network output by const reference, three
// double tuning values (dict_ratio, cert_offset, worst_dict_cert), a
// pointer-to-const charset, and lstm_choice_mode, which defaults to 0.
// NOTE(review): whether charset may be null is not visible here -- confirm
// against the definition and callers.
189 void Decode(
const NetworkIO &
output,
double dict_ratio,
double cert_offset,
190 double worst_dict_cert,
const UNICHARSET *charset,
int lstm_choice_mode = 0);
192 double worst_dict_cert,
const UNICHARSET *charset);
// Declaration only; the body is not in this view.
// Same parameter list as Decode(const NetworkIO &, ...) declared above,
// including the lstm_choice_mode default of 0.
// NOTE(review): presumably fills secondary_beam_ rather than beam_ --
// confirm against the definition.
194 void DecodeSecondaryBeams(
const NetworkIO &
output,
double dict_ratio,
double cert_offset,
195 double worst_dict_cert,
const UNICHARSET *charset,
196 int lstm_choice_mode = 0);
// Declaration only; the body is not in this view.
// Const member function; labels and xcoords are non-const vector pointers,
// i.e. output parameters written by the callee.
// NOTE(review): presumably the two vectors are filled in parallel (one
// x-coordinate per label) -- confirm against the definition.
199 void ExtractBestPathAsLabels(std::vector<int> *labels, std::vector<int> *xcoords)
const;
// Declaration only; the body is not in this view.
// Const member function. Inputs: debug flag and a pointer-to-const
// unicharset. Outputs (non-const vector pointers): unichar_ids, certs,
// ratings, xcoords.
// NOTE(review): the output-parameter list matches the static
// ExtractPathAsUnicharIds declared later in this file, so this is presumably
// a wrapper around it -- confirm against the definition.
202 void ExtractBestPathAsUnicharIds(
bool debug,
const UNICHARSET *unicharset,
203 std::vector<int> *unichar_ids, std::vector<float> *certs,
204 std::vector<float> *ratings, std::vector<int> *xcoords)
const;
// Declaration only; the body is not in this view.
// Non-const member: takes a line bounding box by const reference, a float
// scale_factor, a debug flag, and lstm_choice_mode, which defaults to 0.
// NOTE(review): the embedded numbering jumps from 207 to 209, so one source
// line of parameters may be missing from this view -- compare against the
// upstream header before relying on this signature.
207 void ExtractBestPathAsWords(
const TBOX &line_box,
float scale_factor,
bool debug,
209 int lstm_choice_mode = 0);
// Declaration only; the body is not in this view.
// Const member function taking the unicharset by const reference.
// NOTE(review): presumably dumps the beam contents for debugging -- confirm
// against the definition.
212 void DebugBeams(
const UNICHARSET &unicharset)
const;
// Declaration only; the body is not in this view.
// Non-const member taking a pointer-to-const unicharset; being non-const it
// may modify members.
// NOTE(review): presumably populates ctc_choices -- confirm against the
// definition.
219 void extractSymbolChoices(
const UNICHARSET *unicharset);
// Declaration only; the body is not in this view.
// Const member function taking two flags (uids, secondary), an int
// num_outputs and a pointer-to-const charset.
// NOTE(review): presumably `secondary` selects secondary_beam_ over beam_
// and `uids` selects unichar-id versus code output -- confirm against the
// definition.
222 void PrintBeam2(
bool uids,
int num_outputs,
const UNICHARSET *charset,
bool secondary)
const;
// Declaration only; the body is not in this view.
// Non-const member with no parameters and no return value, so any result
// must be stored in member state.
// NOTE(review): presumably reads `timesteps` and `character_boundaries_`
// and writes `segmentedTimesteps` -- confirm against the definition.
224 void segmentTimestepsByCharacters();
// Declaration only; the body is not in this view.
// Takes a pointer to a three-level nested vector of (const char *, float)
// pairs and returns a two-level nested vector of the same pair type by
// value, i.e. the result has one nesting level fewer than the input.
// NOTE(review): the argument is a non-const pointer, so the callee may
// modify it; null handling is not visible here -- confirm.
225 std::vector<std::vector<std::pair<const char *, float>>>
227 combineSegmentedTimesteps(
228 std::vector<std::vector<std::vector<std::pair<const char *, float>>>> *segmentedTimesteps);
// Data member: a vector of vectors of (const char *, float) pairs.
// NOTE(review): the const char * values are non-owning pointers; the storage
// they point into (presumably strings held by a UNICHARSET) must outlive
// this member -- confirm the lifetime.
231 std::vector<std::vector<std::pair<const char *, float>>>
timesteps;
// Data member with the same type as `timesteps`: a vector of vectors of
// (const char *, float) pairs.
// NOTE(review): the const char * values are non-owning; confirm what owns
// the pointed-to strings and that it outlives this member.
234 std::vector<std::vector<std::pair<const char *, float>>>
ctc_choices;
// Class-scope integer constant equal to 2 * NC_COUNT * kNumLengths.
// The index expression `(is_dawg * NC_COUNT + cont) * kNumLengths + length`
// that appears elsewhere in this file ranges over exactly this many values
// when is_dawg is 0 or 1, cont < NC_COUNT and length < kNumLengths.
248 static const int kNumBeams = 2 *
NC_COUNT * kNumLengths;
251 return index % kNumLengths;
257 return index / (kNumLengths *
NC_COUNT) > 0;
261 return (is_dawg *
NC_COUNT + cont) * kNumLengths + length;
271 for (
auto &beam : beams_) {
275 for (
auto &best_initial_dawg : best_initial_dawgs_) {
276 best_initial_dawg = empty;
// Fixed-size array of NC_COUNT RecodeNode values, stored by value.
// NOTE(review): presumably indexed by NodeContinuation -- confirm against
// the code that reads and writes it.
295 RecodeNode best_initial_dawgs_[
NC_COUNT];
// Alias for a KDPairInc keyed by float with an int payload; used as the
// element type of top_heap_ (GenericHeap<TopPair>).
// NOTE(review): presumably (score, output index) -- confirm against
// ComputeTopN.
297 using TopPair = KDPairInc<float, int>;
// Declaration only; the body is not in this view.
// Const member function taking the unicharset and a single RecodeHeap, both
// by const reference.
// NOTE(review): presumably prints the contents of one heap -- confirm
// against the definition.
300 void DebugBeamPos(
const UNICHARSET &unicharset,
const RecodeHeap &heap)
const;
// Declaration only; the body is not in this view.
// Static member: reads best_nodes (const reference to a vector of
// pointers-to-const RecodeNode) and writes through the output pointers
// unichar_ids, certs, ratings and xcoords.
// character_boundaries defaults to nullptr, so callers may omit it.
// NOTE(review): presumably the callee skips boundary output when it is
// null -- confirm against the definition.
304 static void ExtractPathAsUnicharIds(
const std::vector<const RecodeNode *> &best_nodes,
305 std::vector<int> *unichar_ids, std::vector<float> *certs,
306 std::vector<float> *ratings, std::vector<int> *xcoords,
307 std::vector<int> *character_boundaries =
nullptr);
// Declaration only; the body is not in this view.
// Returns a raw WERD_RES pointer built from a leading_space flag, the line
// box, a [word_start, word_end] index pair, a space certainty, the
// unicharset, the xcoords vector (const reference) and a scale factor.
// NOTE(review): the return is a raw pointer; presumably the caller takes
// ownership of a newly allocated WERD_RES -- confirm against the definition.
// NOTE(review): whether word_end is inclusive or exclusive is not visible.
311 WERD_RES *InitializeWord(
bool leading_space,
const TBOX &line_box,
int word_start,
int word_end,
312 float space_certainty,
const UNICHARSET *unicharset,
313 const std::vector<int> &xcoords,
float scale_factor);
// Declaration only; the body is not in this view.
// Non-const member taking a pointer to const float outputs, the count
// num_outputs and an int top_n.
// NOTE(review): presumably `outputs` has num_outputs elements and results
// go to top_n_flags_ / top_heap_ -- confirm against the definition.
317 void ComputeTopN(
const float *outputs,
int num_outputs,
int top_n);
319 void ComputeSecTopN(std::unordered_set<int> *exList,
const float *outputs,
int num_outputs,
// Declaration only; the body is not in this view.
// Non-const member taking a pointer to const float outputs, an int t, three
// double tuning values, a pointer-to-const charset, and a debug flag that
// defaults to false.
// NOTE(review): presumably `t` is the timestep index and `outputs` the
// network outputs for that timestep -- confirm against the definition.
325 void DecodeStep(
const float *outputs,
int t,
double dict_ratio,
double cert_offset,
326 double worst_dict_cert,
const UNICHARSET *charset,
bool debug =
false);
// Declaration only; the body is not in this view.
// Same parameter list as DecodeStep, including the debug default of false.
// NOTE(review): presumably the per-timestep counterpart used by
// DecodeSecondaryBeams -- confirm against the definition.
328 void DecodeSecondaryStep(
const float *outputs,
int t,
double dict_ratio,
double cert_offset,
329 double worst_dict_cert,
const UNICHARSET *charset,
bool debug =
false);
332 void SaveMostCertainChoices(
const float *outputs,
int num_outputs,
const UNICHARSET *charset,
// Declaration only; the body is not in this view.
// Static member taking three non-const vector<int> pointers and an int
// maxWidth.
// NOTE(review): the parameter character_boundaries_ carries a trailing
// underscore and matches the name of a data member listed in this file; as
// this function is static the parameter cannot alias the member implicitly,
// but the shared name is easy to misread.
// NOTE(review): presumably starts/ends are inputs and
// character_boundaries_ is the output -- confirm against the definition.
337 static void calculateCharBoundaries(std::vector<int> *starts, std::vector<int> *ends,
338 std::vector<int> *character_boundaries_,
int maxWidth);
// Declaration only; the body is not in this view.
// Non-const member taking the previous node (pointer-to-const), an int
// index, the float outputs, a TopNState flag, the unicharset, three double
// tuning values and a non-const RecodeBeam pointer `step`.
// NOTE(review): `step` is the only non-const pointer, so it is presumably
// the beam being extended -- confirm against the definition.
// NOTE(review): worst_dict_cert is double here but float in ContinueUnichar
// below, so passing it on narrows the value -- confirm that is intended.
344 void ContinueContext(
const RecodeNode *prev,
int index,
const float *outputs,
345 TopNState top_n_flag,
const UNICHARSET *unicharset,
double dict_ratio,
346 double cert_offset,
double worst_dict_cert, RecodeBeam *step);
// Declaration only; the body is not in this view.
// Non-const member taking an int code, an int unichar_id, float cert and
// float worst_dict_cert, the previous node (pointer-to-const) and a
// non-const RecodeBeam pointer `step`.
// NOTE(review): the embedded numbering jumps from 348 to 350, so one source
// line of parameters may be missing from this view -- compare against the
// upstream header before relying on this signature.
348 void ContinueUnichar(
int code,
int unichar_id,
float cert,
float worst_dict_cert,
350 const RecodeNode *prev, RecodeBeam *step);
// Declaration only; the body is not in this view.
// Non-const member taking an int code, an int unichar_id, a float cert, a
// NodeContinuation, the previous node (pointer-to-const) and a non-const
// RecodeBeam pointer `step`.
// NOTE(review): presumably extends `step` with dictionary (dawg)
// constrained continuations of `prev` -- confirm against the definition.
353 void ContinueDawg(
int code,
int unichar_id,
float cert,
NodeContinuation cont,
354 const RecodeNode *prev, RecodeBeam *step);
357 void PushInitialDawgIfBetter(
int code,
int unichar_id,
PermuterType permuter,
bool start,
363 void PushDupOrNoDawgIfBetter(
int length,
bool dup,
int code,
int unichar_id,
float cert,
364 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
// Declaration only; the body is not in this view.
// Field-wise overload: max_size and heap, plus parameters (code, unichar_id,
// permuter, dawg_start, word_start, end, dup, cert, prev, d) that mirror, in
// order, parameters of the RecodeNode constructor listed in this file.
// NOTE(review): `d` is a non-const DawgPositionVector pointer; whether
// ownership passes to the created node is not visible -- confirm against
// the definition.
368 void PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
PermuterType permuter,
369 bool dawg_start,
bool word_start,
bool end,
bool dup,
float cert,
370 const RecodeNode *prev, DawgPositionVector *d,
RecodeHeap *heap);
// Declaration only; the body is not in this view.
// Node overload: takes max_size, an already-built RecodeNode by non-const
// pointer, and the target heap.
// NOTE(review): `node` is non-const, so the callee may modify or move from
// it -- confirm against the definition before reusing the node afterwards.
373 void PushHeapIfBetter(
int max_size, RecodeNode *node,
RecodeHeap *heap);
// Declaration only; the body is not in this view.
// Returns bool; takes a non-const RecodeNode pointer and a non-const
// RecodeHeap pointer.
// NOTE(review): presumably returns true when an equivalent entry was found
// in the heap (and possibly updated from new_node) -- confirm against the
// definition.
376 bool UpdateHeapIfMatched(RecodeNode *new_node,
RecodeHeap *heap);
// Declaration only; the body is not in this view.
// Const member function returning a 64-bit unsigned hash computed from an
// int code, a dup flag and the previous node (pointer-to-const).
// NOTE(review): whether prev may be null is not visible -- confirm against
// the definition.
378 uint64_t ComputeCodeHash(
int code,
bool dup,
const RecodeNode *prev)
const;
// Declaration only; the body is not in this view.
// Const member function with two output parameters, each a pointer to a
// vector of pointers-to-const RecodeNode.
// NOTE(review): the returned node pointers are non-owning and presumably
// point into the beams, so they are valid only while the beams are --
// confirm. Whether second_nodes may be null is not visible here.
383 void ExtractBestPaths(std::vector<const RecodeNode *> *best_nodes,
384 std::vector<const RecodeNode *> *second_nodes)
const;
// Declaration only; the body is not in this view.
// Const member function taking a starting node (pointer-to-const) and an
// output vector of pointers-to-const RecodeNode.
// NOTE(review): presumably follows the node's prev links and stores the
// chain in `path`; the resulting order is not visible -- confirm against
// the definition.
387 void ExtractPath(
const RecodeNode *node, std::vector<const RecodeNode *> *path)
const;
388 void ExtractPath(
const RecodeNode *node, std::vector<const RecodeNode *> *path,
// Declaration only; the body is not in this view.
// Const member function taking a pointer-to-const unicharset and a path
// (const reference to a vector of pointers-to-const RecodeNode).
// NOTE(review): presumably prints each node on the path -- confirm against
// the definition.
391 void DebugPath(
const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path)
const;
// Declaration only; the body is not in this view.
// Const member function; every argument is read-only: the unicharset, the
// node path, and four vectors (unichar_ids, certs, ratings, xcoords) passed
// by const reference.
// NOTE(review): the four vectors are presumably parallel (same length, one
// entry per unichar) -- confirm against the definition.
393 void DebugUnicharPath(
const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path,
394 const std::vector<int> &unichar_ids,
const std::vector<float> &certs,
395 const std::vector<float> &ratings,
const std::vector<int> &xcoords)
const;
// Reference member: the referenced UnicharCompress must outlive this object,
// and a reference member makes the implicit copy-assignment operator deleted.
400 const UnicharCompress &recoder_;
// Vector of raw RecodeBeam pointers.
// NOTE(review): ownership is not visible here; confirm that this class
// deletes the beams and is not copied.
402 std::vector<RecodeBeam *> beam_;
// Second vector of raw RecodeBeam pointers, same type as beam_.
// NOTE(review): same ownership question as beam_.
404 std::vector<RecodeBeam *> secondary_beam_;
// One TopNState per entry.
// NOTE(review): presumably one entry per network output, written by
// ComputeTopN -- confirm.
409 std::vector<TopNState> top_n_flags_;
// Heap of TopPair (KDPairInc<float, int>) elements.
// NOTE(review): presumably scratch space for ComputeTopN -- confirm.
414 GenericHeap<TopPair> top_heap_;
// Boolean flag.
// NOTE(review): presumably true when the script separates words with
// spaces -- confirm where it is set.
419 bool space_delimited_;
// Boolean flag.
// NOTE(review): presumably true when the recoder is a pass-through (no
// multi-code encodings) -- confirm where it is set.
422 bool is_simple_text_;
const float kMinCertainty
GenericHeap< RecodePair > RecodeHeap
static const int kMaxCodeLen
RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start, bool word_start, bool end, bool dup, float cert, float s, const RecodeNode *p, DawgPositionVector *d, uint64_t hash)
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
DawgPositionVector * dawgs
RecodeNode(const RecodeNode &src)
RecodeNode & operator=(const RecodeNode &src)
static bool IsDawgFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
static int LengthFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > timesteps
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
std::vector< int > character_boundaries_
std::vector< std::unordered_set< int > > excludedUnichars
static NodeContinuation ContinuationFromBeamsIndex(int index)
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)