20# include "config_auto.h"
25#include <allheaders.h>
40#include <unordered_set>
58 , training_iteration_(0)
59 , sample_iteration_(0)
61 , learning_rate_(0.0f)
66 , debug_win_(nullptr) {}
206 tprintf(
"Space was garbled in recoding!!\n");
228 dict_->user_words_file.ResetFrom(params);
229 dict_->user_words_suffix.ResetFrom(params);
230 dict_->user_patterns_file.ResetFrom(params);
231 dict_->user_patterns_suffix.ResetFrom(params);
238 tprintf(
"Failed to load any lstm-specific dictionaries for lang %s!!\n", lang.c_str());
248 float invert_threshold,
bool debug,
249 double worst_dict_cert,
const TBOX &line_box,
251 int lstm_choice_amount) {
255 if (!
RecognizeLine(image_data, invert_threshold, debug,
false,
false, &scale_factor, &inputs, &outputs)) {
266 if (lstm_choice_mode) {
268 for (
int i = 0;
i < lstm_choice_amount; ++
i) {
274 unsigned char_it = 0;
275 for (
size_t i = 0;
i < words->
size(); ++
i) {
276 for (
int j = 0; j < words->
at(
i)->end; ++j) {
277 if (char_it < search_->ctc_choices.size()) {
280 if (char_it < search_->segmentedTimesteps.size()) {
285 words->
at(
i)->timesteps =
297 const int kOutputScale = INT8_MAX;
298 STATS stats(0, kOutputScale);
299 for (
int t = 0; t < outputs.
Width(); ++t) {
300 int best_label = outputs.
BestLabel(t,
nullptr);
302 float best_output = outputs.
f(t)[best_label];
303 stats.
add(
static_cast<int>(kOutputScale * best_output), 1);
313 *min_output =
static_cast<float>(stats.
min_bucket()) / kOutputScale;
314 *mean_output = stats.
mean() / kOutputScale;
315 *sd = stats.
sd() / kOutputScale;
322 float invert_threshold,
bool debug,
323 bool re_invert,
bool upside_down,
float *scale_factor,
329 if (pix ==
nullptr) {
330 tprintf(
"Line cannot be recognized!!\n");
334 const int kMaxImageWidth = 128 * pixGetHeight(pix);
336 tprintf(
"Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix), pixGetHeight(pix));
341 pixRotate180(pix, pix);
344 *scale_factor = min_width / *scale_factor;
350 if (invert_threshold > 0.0f) {
351 float pos_min, pos_mean, pos_sd;
352 OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
353 if (pos_mean < invert_threshold) {
361 float inv_min, inv_mean, inv_sd;
362 OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
363 if (inv_mean > pos_mean) {
366 tprintf(
"Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n", pos_min, pos_mean,
367 pos_sd, inv_min, inv_mean, inv_sd);
369 *outputs = inv_outputs;
370 *inputs = inv_inputs;
371 }
else if (re_invert) {
382 std::vector<int> labels, coords;
384#ifndef GRAPHICS_DISABLED
397 for (
unsigned start = 0; start < labels.size(); start = end) {
401 result +=
DecodeLabel(labels, start, &end,
nullptr);
407#ifndef GRAPHICS_DISABLED
412 const std::vector<int> &label_coords,
const char *window_name,
415 Network::ClearWindow(
false, window_name, pixGetWidth(input_pix), pixGetHeight(input_pix), window);
423 const std::vector<int> &xcoords,
int height,
428 for (
unsigned start = 0; start < labels.size(); start = end) {
429 int xpos = xcoords[start] * x_scale;
435 const char *str =
DecodeLabel(labels, start, &end,
nullptr);
439 xpos = xcoords[(start + end) / 2] * x_scale;
440 window->
Text(xpos, height, str);
442 window->
Line(xpos, 0, xpos, height * 3 / 2);
452 const std::vector<int> &xcoords) {
453 if (xcoords[0] > 0) {
457 for (
unsigned start = 0; start < labels.size(); start = end) {
464 const char *label =
DecodeLabel(labels, start, &end, &decoded);
466 for (
unsigned i = start + 1;
i < end; ++
i) {
477 int best_choice,
int x_start,
int x_end) {
478 tprintf(
"%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
479 double max_score = 0.0;
480 double mean_score = 0.0;
481 const int width = x_end - x_start;
482 for (
int x = x_start;
x < x_end; ++
x) {
483 const float *line = outputs.
f(
x);
484 const double score = line[best_choice] * 100.0;
485 if (score > max_score) {
488 mean_score += score / width;
490 double best_score = 0.0;
492 if (c != best_choice && line[c] > best_score) {
494 best_score = line[c];
499 tprintf(
", Mean=%g, max=%g\n", mean_score, max_score);
507 int null_char,
int t) {
508 if (
output.f(t)[null_char] >= null_thr)
return true;
520 std::vector<int> *xcoords) {
531 std::vector<int> *xcoords) {
543 std::vector<int> *xcoords) {
546 const int width =
output.Width();
547 for (
int t = 0; t < width; ++t) {
549 const int label =
output.BestLabel(t, &score);
551 labels->push_back(label);
552 xcoords->push_back(t);
555 xcoords->push_back(width);
567 if (decoded !=
nullptr) {
573 unsigned index = start;
575 code.
Set(code.
length(), labels[index++]);
576 while (index < labels.size() && labels[index] ==
null_char_) {
582 if (uni_id != INVALID_UNICHAR_ID &&
586 if (decoded !=
nullptr) {
595 return "<Undecodable>";
597 if (decoded !=
nullptr) {
598 *decoded = labels[start];
621 if (label == INVALID_UNICHAR_ID) {
void tprintf(const char *format,...)
@ TESSDATA_LSTM_UNICHARSET
void add(int32_t value, int32_t count)
int32_t get_total() const
int32_t min_bucket() const
std::string language_data_path_prefix
bool DeSerialize(std::string &data)
bool Serialize(const std::string &data)
bool GetComponent(TessdataType type, TFile *fp)
bool IsComponentAvailable(TessdataType type) const
void Set(int index, int value)
static const int kMaxCodeLen
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
bool DeSerialize(TFile *fp)
bool IsValidFirstCode(int code) const
void SetupPassThrough(const UNICHARSET &unicharset)
int DecodeUnichar(const RecodedCharID &code) const
bool Serialize(TFile *fp) const
bool load_from_file(const char *const filename, bool skip_fragments)
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
static DawgCache * GlobalDawgCache()
void LoadLSTM(const std::string &lang, TessdataManager *data_file)
void SetupForLoad(DawgCache *dawg_cache)
static Image PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
static void PreparePixInput(const StaticShape &shape, const Image pix, TRand *randomizer, NetworkIO *input)
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
std::string DecodeLabels(const std::vector< int > &labels)
bool LoadRecoder(TFile *fp)
bool SimpleTextOutput() const
NetworkScratch scratch_space_
void LabelsViaReEncode(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
const char * DecodeSingleLabel(int label)
bool LoadCharsets(const TessdataManager *mgr)
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
const char * DecodeLabel(const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
RecodeBeamSearch * search_
void LabelsViaSimpleText(const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
int32_t training_iteration_
void DisplayLSTMOutput(const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
bool Serialize(const TessdataManager *mgr, TFile *fp) const
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
const UNICHARSET & GetUnicharset() const
int32_t sample_iteration_
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
virtual int XScaleFactor() const
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
static Network * CreateFromFile(TFile *fp)
virtual bool Serialize(TFile *fp) const
virtual void CacheXScaleFactor(int factor)
static int DisplayImage(Image pix, ScrollView *window)
virtual void SetRandomizer(TRand *randomizer)
virtual StaticShape InputShape() const
void set_int_mode(bool is_quantized)
int BestLabel(int t, float *score) const
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float > > > > *segmentedTimesteps)
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
void extractSymbolChoices(const UNICHARSET *unicharset)
std::vector< std::unordered_set< int > > excludedUnichars
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
static constexpr float kMinCertainty
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
void segmentTimestepsByCharacters()
void Line(int x1, int y1, int x2, int y2)
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
void Text(int x, int y, const char *mystring)