tesseract 4.00.00dev
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
tesseract::Network

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
virtual ~LSTM ()
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
virtual STRING spec () const
 
virtual void SetEnableTraining (TrainingState state)
 
virtual int InitWeights (float range, TRand *randomizer)
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
virtual void ConvertToInt ()
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uinT32 flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
inT32 network_flags_
 
inT32 ni_
 
inT32 no_
 
inT32 num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Long Short-Term Memory (LSTM) recurrent network layer. It operates over a 1-D sequence or, when constructed as two-dimensional, over a true 2-D grid, and can optionally feed an integrated FullyConnected softmax output (types NT_LSTM_SOFTMAX and NT_LSTM_SOFTMAX_ENCODED).

Definition at line 28 of file lstm.h.

Member Enumeration Documentation

◆ WeightType

Enumerator
CI  Cell inputs.
GI  Gate at the input.
GF1  Forget gate at the memory (1-d or looking back 1 timestep).
GO  Gate at the output.
GFS  Forget gate at the memory, looking back in the other dimension.
WT_COUNT  Number of WeightTypes.

Definition at line 33 of file lstm.h.

33  enum WeightType {
34  CI, // Cell Inputs.
35  GI, // Gate at the input.
36  GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
37  GO, // Gate at the output.
38  GFS, // Forget gate at the memory, looking back in the other dimension.
39 
40  WT_COUNT // Number of WeightTypes.
41  };
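
These values index the per-gate arrays (gate_weights_, node_values_) used throughout the class. As a minimal sketch of the iteration idiom that InitWeights(), ConvertToInt(), Update() and others follow — GFS is skipped whenever the network is 1-D, because the second forget gate only exists in 2-D mode:

  // Per-gate loop idiom (sketch; mirrors the pattern in the methods below).
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;  // No 2-D forget gate in 1-D mode.
    gate_weights_[w].ConvertToInt();    // ...or any other per-gate operation.
  }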

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING &  name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

Definition at line 70 of file lstm.cpp.

70 LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
71            NetworkType type)
72  : Network(type, name, ni, no),
73  na_(ni + ns),
74  ns_(ns),
75  nf_(0),
76  is_2d_(two_dimensional),
77  softmax_(NULL),
78  input_width_(0) {
79  if (two_dimensional) na_ += ns_;
80  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
81  nf_ = 0;
82  // networkbuilder ensures this is always true.
83  ASSERT_HOST(no == ns);
84  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
85  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
86  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
87  } else {
88  tprintf("%d is invalid type of LSTM!\n", type);
89  ASSERT_HOST(false);
90  }
91  na_ += nf_;
92 }
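
A hedged construction sketch (illustrative only; in normal use the network builder creates LSTM layers from a spec string). A plain NT_LSTM must have num_outputs == num_states, per the ASSERT_HOST above; the softmax variants instead attach an internal FullyConnected layer:

  // Sketch: a 1-D NT_LSTM with 20 inputs and 96 states/outputs.
  tesseract::LSTM* lstm = new tesseract::LSTM(
      "lstm-example", /*num_inputs=*/20, /*num_states=*/96,
      /*num_outputs=*/96, /*two_dimensional=*/false, tesseract::NT_LSTM);
  TRand randomizer;            // From helpers.h.
  randomizer.set_seed(12345);  // Any fixed seed.
  int num_weights = lstm->InitWeights(0.1f, &randomizer);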

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
virtual

Definition at line 94 of file lstm.cpp.

94 LSTM::~LSTM() { delete softmax_; }

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 412 of file lstm.cpp.

412 bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
413                     NetworkScratch* scratch,
414                     NetworkIO* back_deltas) {
415  if (debug) DisplayBackward(fwd_deltas);
416  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
417  // ======Scratch space.======
418  // Output errors from deltas with recurrence from sourceerr.
419  NetworkScratch::FloatVec outputerr;
420  outputerr.Init(ns_, scratch);
421  // Recurrent error in the state/source.
422  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
423  curr_stateerr.Init(ns_, scratch);
424  curr_sourceerr.Init(na_, scratch);
425  ZeroVector<double>(ns_, curr_stateerr);
426  ZeroVector<double>(na_, curr_sourceerr);
427  // Errors in the gates.
428  NetworkScratch::FloatVec gate_errors[WT_COUNT];
429  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
430  // Rotating buffers of width buf_width allow storage of the recurrent time-
431  // steps used only for true 2-D. Stores one full strip of the major direction.
432  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
433  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
434  if (Is2D()) {
435  stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
436  sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
437  for (int t = 0; t < buf_width; ++t) {
438  stateerr[t].Init(ns_, scratch);
439  sourceerr[t].Init(na_, scratch);
440  ZeroVector<double>(ns_, stateerr[t]);
441  ZeroVector<double>(na_, sourceerr[t]);
442  }
443  }
444  // Parallel-generated sourceerr from each of the gates.
445  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
446  for (int w = 0; w < WT_COUNT; ++w)
447  sourceerr_temps[w].Init(na_, scratch);
448  int width = input_width_;
449  // Transposed gate errors stored over all timesteps for sum outer.
450  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
451  for (int w = 0; w < WT_COUNT; ++w) {
452  gate_errors_t[w].Init(ns_, width, scratch);
453  }
454  // Used only if softmax_ != NULL.
455  NetworkScratch::FloatVec softmax_errors;
456  NetworkScratch::GradientStore softmax_errors_t;
457  if (softmax_ != NULL) {
458  softmax_errors.Init(no_, scratch);
459  softmax_errors_t.Init(no_, width, scratch);
460  }
461  double state_clip = Is2D() ? 9.0 : 4.0;
462 #if DEBUG_DETAIL > 1
463  tprintf("fwd_deltas:%s\n", name_.string());
464  fwd_deltas.Print(10);
465 #endif
466  StrideMap::Index dest_index(input_map_);
467  dest_index.InitToLast();
468  // Used only by NT_LSTM_SUMMARY.
469  StrideMap::Index src_index(fwd_deltas.stride_map());
470  src_index.InitToLast();
471  do {
472  int t = dest_index.t();
473  bool at_last_x = dest_index.IsLast(FD_WIDTH);
474  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
475  // valid if >= 0, which is true if 2d and not on the top/bottom.
476  int up_pos = -1;
477  int down_pos = -1;
478  if (Is2D()) {
479  if (dest_index.index(FD_HEIGHT) > 0) {
480  StrideMap::Index up_index(dest_index);
481  if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
482  }
483  if (!dest_index.IsLast(FD_HEIGHT)) {
484  StrideMap::Index down_index(dest_index);
485  if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
486  }
487  }
488  // Index of the 2-D revolving buffers (sourceerr, stateerr).
489  int mod_t = Modulo(t, buf_width); // Current timestep.
490  // Zero the state in the major direction only at the end of every row.
491  if (at_last_x) {
492  ZeroVector<double>(na_, curr_sourceerr);
493  ZeroVector<double>(ns_, curr_stateerr);
494  }
495  // Setup the outputerr.
496  if (type_ == NT_LSTM_SUMMARY) {
497  if (dest_index.IsLast(FD_WIDTH)) {
498  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
499  src_index.Decrement();
500  } else {
501  ZeroVector<double>(ns_, outputerr);
502  }
503  } else if (softmax_ == NULL) {
504  fwd_deltas.ReadTimeStep(t, outputerr);
505  } else {
506  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
507  softmax_errors_t.get(), outputerr);
508  }
509  if (!at_last_x)
510  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
511  if (down_pos >= 0)
512  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
513  // Apply the 1-d forget gates.
514  if (!at_last_x) {
515  const float* next_node_gf1 = node_values_[GF1].f(t + 1);
516  for (int i = 0; i < ns_; ++i) {
517  curr_stateerr[i] *= next_node_gf1[i];
518  }
519  }
520  if (Is2D() && t + 1 < width) {
521  for (int i = 0; i < ns_; ++i) {
522  if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
523  }
524  if (down_pos >= 0) {
525  const float* right_node_gfs = node_values_[GFS].f(down_pos);
526  const double* right_stateerr = stateerr[mod_t];
527  for (int i = 0; i < ns_; ++i) {
528  if (which_fg_[down_pos][i] == 2) {
529  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
530  }
531  }
532  }
533  }
534  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
535  curr_stateerr);
536  // Clip stateerr_ to a sane range.
537  ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
538 #if DEBUG_DETAIL > 1
539  if (t + 10 > width) {
540  tprintf("t=%d, stateerr=", t);
541  for (int i = 0; i < ns_; ++i)
542  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
543  curr_sourceerr[ni_ + nf_ + i]);
544  tprintf("\n");
545  }
546 #endif
547  // Matrix multiply to get the source errors.
548  PARALLEL_IF_OPENMP(GFS)
549 
550  // Cell inputs.
551  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
552  curr_stateerr, gate_errors[CI]);
553  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
554  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
555  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
556 
557  SECTION_IF_OPENMP
558  // Input Gates.
559  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
560  curr_stateerr, gate_errors[GI]);
561  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
562  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
563  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
564 
565  SECTION_IF_OPENMP
566  // 1-D forget Gates.
567  if (t > 0) {
568  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
569  gate_errors[GF1]);
570  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
571  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
572  sourceerr_temps[GF1]);
573  } else {
574  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
575  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
576  }
577  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
578 
579  // 2-D forget Gates.
580  if (up_pos >= 0) {
581  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
582  gate_errors[GFS]);
583  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
584  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
585  sourceerr_temps[GFS]);
586  } else {
587  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
588  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
589  }
590  if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
591 
592  SECTION_IF_OPENMP
593  // Output gates.
594  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
595  gate_errors[GO]);
596  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
597  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
598  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
599  END_PARALLEL_IF_OPENMP
600 
601  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
602  sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
603  curr_sourceerr);
604  back_deltas->WriteTimeStep(t, curr_sourceerr);
605  // Save states for use by the 2nd dimension only if needed.
606  if (Is2D()) {
607  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
608  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
609  }
610  } while (dest_index.Decrement());
611 #if DEBUG_DETAIL > 2
612  for (int w = 0; w < WT_COUNT; ++w) {
613  tprintf("%s gate errors[%d]\n", name_.string(), w);
614  gate_errors_t[w].get()->PrintUnTransposed(10);
615  }
616 #endif
617  // Transposed source_ used to speed-up SumOuter.
618  NetworkScratch::GradientStore source_t, state_t;
619  source_t.Init(na_, width, scratch);
620  source_.Transpose(source_t.get());
621  state_t.Init(ns_, width, scratch);
622  state_.Transpose(state_t.get());
623 #ifdef _OPENMP
624 #pragma omp parallel for num_threads(GFS) if (!Is2D())
625 #endif
626  for (int w = 0; w < WT_COUNT; ++w) {
627  if (w == GFS && !Is2D()) continue;
628  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
629  }
630  if (softmax_ != NULL) {
631  softmax_->FinishBackward(*softmax_errors_t);
632  }
633  return needs_to_backprop_;
634 }
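
Two levels of clipping keep the backward pass stable: each gate's errors are clipped to ±kErrClip before the VectorDotMatrix calls, and the accumulated recurrent state error is clipped to ±4.0 (±9.0 in 2-D mode, where two recurrences feed it) before it propagates to the previous timestep. The return value is needs_to_backprop_, so an enclosing network can stop propagating deltas below the lowest layer that needs them.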

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 154 of file lstm.cpp.

154 void LSTM::ConvertToInt() {
155  for (int w = 0; w < WT_COUNT; ++w) {
156  if (w == GFS && !Is2D()) continue;
157  gate_weights_[w].ConvertToInt();
158  }
159  if (softmax_ != NULL) {
160  softmax_->ConvertToInt();
161  }
162 }

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
virtual

Reimplemented from tesseract::Network.

Definition at line 658 of file lstm.cpp.

658 void LSTM::CountAlternators(const Network& other, double* same,
659                             double* changed) const {
660  ASSERT_HOST(other.type() == type_);
661  const LSTM* lstm = static_cast<const LSTM*>(&other);
662  for (int w = 0; w < WT_COUNT; ++w) {
663  if (w == GFS && !Is2D()) continue;
664  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
665  }
666  if (softmax_ != NULL) {
667  softmax_->CountAlternators(*lstm->softmax_, same, changed);
668  }
669 }

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 165 of file lstm.cpp.

165 void LSTM::DebugWeights() {
166  for (int w = 0; w < WT_COUNT; ++w) {
167  if (w == GFS && !Is2D()) continue;
168  STRING msg = name_;
169  msg.add_str_int(" Gate weights ", w);
170  gate_weights_[w].Debug2D(msg.string());
171  }
172  if (softmax_ != NULL) {
173  softmax_->DebugWeights();
174  }
175 }

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile *  fp)
virtual

Reimplemented from tesseract::Network.

Definition at line 191 of file lstm.cpp.

191 bool LSTM::DeSerialize(TFile* fp) {
192  if (fp->FReadEndian(&na_, sizeof(na_), 1) != 1) return false;
193  if (type_ == NT_LSTM_SOFTMAX) {
194  nf_ = no_;
195  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
196  nf_ = IntCastRounded(ceil(log2(no_)));
197  } else {
198  nf_ = 0;
199  }
200  is_2d_ = false;
201  for (int w = 0; w < WT_COUNT; ++w) {
202  if (w == GFS && !Is2D()) continue;
203  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
204  if (w == CI) {
205  ns_ = gate_weights_[CI].NumOutputs();
206  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
207  }
208  }
209  delete softmax_;
210  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
211  softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
212  if (softmax_ == nullptr) return false;
213  } else {
214  softmax_ = nullptr;
215  }
216  return true;
217 }

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 221 of file lstm.cpp.

221 void LSTM::Forward(bool debug, const NetworkIO& input,
222                    const TransposedArray* input_transpose,
223                    NetworkScratch* scratch, NetworkIO* output) {
224  input_map_ = input.stride_map();
225  input_width_ = input.Width();
226  if (softmax_ != NULL)
227  output->ResizeFloat(input, no_);
228  else if (type_ == NT_LSTM_SUMMARY)
229  output->ResizeXTo1(input, no_);
230  else
231  output->Resize(input, no_);
232  ResizeForward(input);
233  // Temporary storage of forward computation for each gate.
234  NetworkScratch::FloatVec temp_lines[WT_COUNT];
235  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
236  // Single timestep buffers for the current/recurrent output and state.
237  NetworkScratch::FloatVec curr_state, curr_output;
238  curr_state.Init(ns_, scratch);
239  ZeroVector<double>(ns_, curr_state);
240  curr_output.Init(ns_, scratch);
241  ZeroVector<double>(ns_, curr_output);
242  // Rotating buffers of width buf_width allow storage of the state and output
243  // for the other dimension, used only when working in true 2D mode. The width
244  // is enough to hold an entire strip of the major direction.
245  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
246  GenericVector<NetworkScratch::FloatVec> states, outputs;
247  if (Is2D()) {
248  states.init_to_size(buf_width, NetworkScratch::FloatVec());
249  outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
250  for (int i = 0; i < buf_width; ++i) {
251  states[i].Init(ns_, scratch);
252  ZeroVector<double>(ns_, states[i]);
253  outputs[i].Init(ns_, scratch);
254  ZeroVector<double>(ns_, outputs[i]);
255  }
256  }
257  // Used only if a softmax LSTM.
258  NetworkScratch::FloatVec softmax_output;
259  NetworkScratch::IO int_output;
260  if (softmax_ != NULL) {
261  softmax_output.Init(no_, scratch);
262  ZeroVector<double>(no_, softmax_output);
263  int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
264  if (input.int_mode())
265  int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
266  softmax_->SetupForward(input, NULL);
267  }
268  NetworkScratch::FloatVec curr_input;
269  curr_input.Init(na_, scratch);
270  StrideMap::Index src_index(input_map_);
271  // Used only by NT_LSTM_SUMMARY.
272  StrideMap::Index dest_index(output->stride_map());
273  do {
274  int t = src_index.t();
275  // True if there is a valid old state for the 2nd dimension.
276  bool valid_2d = Is2D();
277  if (valid_2d) {
278  StrideMap::Index dim_index(src_index);
279  if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
280  }
281  // Index of the 2-D revolving buffers (outputs, states).
282  int mod_t = Modulo(t, buf_width); // Current timestep.
283  // Setup the padded input in source.
284  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
285  if (softmax_ != NULL) {
286  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
287  }
288  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
289  if (Is2D())
290  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
291  if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
292  // Matrix multiply the inputs with the source.
293  PARALLEL_IF_OPENMP(GFS)
294  // It looks inefficient to create the threads on each t iteration, but the
295  // alternative of putting the parallel outside the t loop, a single around
296  // the t-loop and then tasks in place of the sections is a *lot* slower.
297  // Cell inputs.
298  if (source_.int_mode())
299  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
300  else
301  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
302  FuncInplace<GFunc>(ns_, temp_lines[CI]);
303 
304  SECTION_IF_OPENMP
305  // Input Gates.
306  if (source_.int_mode())
307  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
308  else
309  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
310  FuncInplace<FFunc>(ns_, temp_lines[GI]);
311 
312  SECTION_IF_OPENMP
313  // 1-D forget gates.
314  if (source_.int_mode())
315  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
316  else
317  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
318  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
319 
320  // 2-D forget gates.
321  if (Is2D()) {
322  if (source_.int_mode())
323  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
324  else
325  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
326  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
327  }
328 
329  SECTION_IF_OPENMP
330  // Output gates.
331  if (source_.int_mode())
332  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
333  else
334  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
335  FuncInplace<FFunc>(ns_, temp_lines[GO]);
336  END_PARALLEL_IF_OPENMP
337 
338  // Apply forget gate to state.
339  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
340  if (Is2D()) {
341  // Max-pool the forget gates (in 2-d) instead of blindly adding.
342  inT8* which_fg_col = which_fg_[t];
343  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
344  if (valid_2d) {
345  const double* stepped_state = states[mod_t];
346  for (int i = 0; i < ns_; ++i) {
347  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
348  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
349  which_fg_col[i] = 2;
350  }
351  }
352  }
353  }
354  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
355  // Clip curr_state to a sane range.
356  ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
357  if (IsTraining()) {
358  // Save the gate node values.
359  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
360  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
361  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
362  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
363  if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
364  }
365  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
366  if (IsTraining()) state_.WriteTimeStep(t, curr_state);
367  if (softmax_ != NULL) {
368  if (input.int_mode()) {
369  int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
370  softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
371  } else {
372  softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
373  }
374  output->WriteTimeStep(t, softmax_output);
375  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
376  CodeInBinary(no_, nf_, softmax_output);
377  }
378  } else if (type_ == NT_LSTM_SUMMARY) {
379  // Output only at the end of a row.
380  if (src_index.IsLast(FD_WIDTH)) {
381  output->WriteTimeStep(dest_index.t(), curr_output);
382  dest_index.Increment();
383  }
384  } else {
385  output->WriteTimeStep(t, curr_output);
386  }
387  // Save states for use by the 2nd dimension only if needed.
388  if (Is2D()) {
389  CopyVector(ns_, curr_state, states[mod_t]);
390  CopyVector(ns_, curr_output, outputs[mod_t]);
391  }
392  // Always zero the states at the end of every row, but only for the major
393  // direction. The 2-D state remains intact.
394  if (src_index.IsLast(FD_WIDTH)) {
395  ZeroVector<double>(ns_, curr_state);
396  ZeroVector<double>(ns_, curr_output);
397  }
398  } while (src_index.Increment());
399 #if DEBUG_DETAIL > 0
400  tprintf("Source:%s\n", name_.string());
401  source_.Print(10);
402  tprintf("State:%s\n", name_.string());
403  state_.Print(10);
404  tprintf("Output:%s\n", name_.string());
405  output->Print(10);
406 #endif
407  if (debug) DisplayForward(*output);
408 }
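
A hedged sketch of driving the forward pass directly (normally the enclosing plumbing network does this; sizing and filling the input NetworkIO is abbreviated). The input_transpose argument appears unused by this class, so nullptr is fine here:

  // Sketch: one forward pass over an already-prepared input.
  NetworkIO input, output;
  NetworkScratch scratch;
  // ... resize input and write NumInputs() features per timestep ...
  lstm->Forward(/*debug=*/false, input, /*input_transpose=*/nullptr,
                &scratch, &output);
  // output now holds NumOutputs() features per timestep; an NT_LSTM_SUMMARY
  // layer instead emits one timestep per input row.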

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand *  randomizer 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 129 of file lstm.cpp.

129 int LSTM::InitWeights(float range, TRand* randomizer) {
130  Network::SetRandomizer(randomizer);
131  num_weights_ = 0;
132  for (int w = 0; w < WT_COUNT; ++w) {
133  if (w == GFS && !Is2D()) continue;
134  num_weights_ += gate_weights_[w].InitWeightsFloat(
135  ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
136  }
137  if (softmax_ != NULL) {
138  num_weights_ += softmax_->InitWeights(range, randomizer);
139  }
140  return num_weights_;
141 }

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

Definition at line 120 of file lstm.h.

120  bool Is2D() const {
121  return is_2d_;
122  }

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape &  input_shape) const
virtual

Reimplemented from tesseract::Network.

Definition at line 98 of file lstm.cpp.

98 StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
99  StaticShape result = input_shape;
100  result.set_depth(no_);
101  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
102  if (softmax_ != NULL) return softmax_->OutputShape(result);
103  return result;
104 }
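
For example, given an input shape of (batch 1, height 48, width 300, depth 20), a plain NT_LSTM with 96 outputs yields (1, 48, 300, 96); NT_LSTM_SUMMARY collapses the width to (1, 48, 1, 96); and the softmax variants delegate the final shape to the internal FullyConnected layer.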

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

Definition at line 698 of file lstm.cpp.

698 void LSTM::PrintDW() {
699  tprintf("Delta state:%s\n", name_.string());
700  for (int w = 0; w < WT_COUNT; ++w) {
701  if (w == GFS && !Is2D()) continue;
702  tprintf("Gate %d, inputs\n", w);
703  for (int i = 0; i < ni_; ++i) {
704  tprintf("Row %d:", i);
705  for (int s = 0; s < ns_; ++s)
706  tprintf(" %g", gate_weights_[w].GetDW(s, i));
707  tprintf("\n");
708  }
709  tprintf("Gate %d, outputs\n", w);
710  for (int i = ni_; i < ni_ + ns_; ++i) {
711  tprintf("Row %d:", i - ni_);
712  for (int s = 0; s < ns_; ++s)
713  tprintf(" %g", gate_weights_[w].GetDW(s, i));
714  tprintf("\n");
715  }
716  tprintf("Gate %d, bias\n", w);
717  for (int s = 0; s < ns_; ++s)
718  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
719  tprintf("\n");
720  }
721 }

◆ PrintW()

void tesseract::LSTM::PrintW ( )

Definition at line 672 of file lstm.cpp.

672 void LSTM::PrintW() {
673  tprintf("Weight state:%s\n", name_.string());
674  for (int w = 0; w < WT_COUNT; ++w) {
675  if (w == GFS && !Is2D()) continue;
676  tprintf("Gate %d, inputs\n", w);
677  for (int i = 0; i < ni_; ++i) {
678  tprintf("Row %d:", i);
679  for (int s = 0; s < ns_; ++s)
680  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
681  tprintf("\n");
682  }
683  tprintf("Gate %d, outputs\n", w);
684  for (int i = ni_; i < ni_ + ns_; ++i) {
685  tprintf("Row %d:", i - ni_);
686  for (int s = 0; s < ns_; ++s)
687  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
688  tprintf("\n");
689  }
690  tprintf("Gate %d, bias\n", w);
691  for (int s = 0; s < ns_; ++s)
692  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
693  tprintf("\n");
694  }
695 }

◆ RemapOutputs()

int tesseract::LSTM::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 145 of file lstm.cpp.

145 int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
146  if (softmax_ != NULL) {
147  num_weights_ -= softmax_->num_weights();
148  num_weights_ += softmax_->RemapOutputs(old_no, code_map);
149  }
150  return num_weights_;
151 }

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile *  fp) const
virtual

Reimplemented from tesseract::Network.

Definition at line 178 of file lstm.cpp.

178 bool LSTM::Serialize(TFile* fp) const {
179  if (!Network::Serialize(fp)) return false;
180  if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
181  for (int w = 0; w < WT_COUNT; ++w) {
182  if (w == GFS && !Is2D()) continue;
183  if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
184  }
185  if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
186  return true;
187 }
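
A hedged round-trip sketch, assuming the in-memory TFile interface from serialis.h. Serialize() writes the Network header first (via Network::Serialize()), so the matching reader is the inherited Network::CreateFromFile(), which reconstructs the correct type before calling DeSerialize():

  // Sketch: serialize to a memory buffer and read it back.
  GenericVector<char> data;
  TFile out;
  out.OpenWrite(&data);
  lstm->Serialize(&out);           // Header, na_, then per-gate weights.
  TFile in;
  in.Open(&data[0], data.size());  // Reopen the same bytes for reading.
  tesseract::Network* copy = tesseract::Network::CreateFromFile(&in);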

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented from tesseract::Network.

Definition at line 108 of file lstm.cpp.

108 void LSTM::SetEnableTraining(TrainingState state) {
109  if (state == TS_RE_ENABLE) {
110  // Enable only from temp disabled.
111  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
112  } else if (state == TS_TEMP_DISABLE) {
113  // Temp disable only from enabled.
114  if (training_ == TS_ENABLED) training_ = state;
115  } else {
116  if (state == TS_ENABLED && training_ != TS_ENABLED) {
117  for (int w = 0; w < WT_COUNT; ++w) {
118  if (w == GFS && !Is2D()) continue;
119  gate_weights_[w].InitBackward();
120  }
121  }
122  training_ = state;
123  }
124  if (softmax_ != NULL) softmax_->SetEnableTraining(state);
125 }

◆ spec()

virtual STRING tesseract::LSTM::spec ( ) const
inlinevirtual

Reimplemented from tesseract::Network.

Definition at line 58 of file lstm.h.

58  virtual STRING spec() const {
59  STRING spec;
60  if (type_ == NT_LSTM)
61  spec.add_str_int("Lfx", ns_);
62  else if (type_ == NT_LSTM_SUMMARY)
63  spec.add_str_int("Lfxs", ns_);
64  else if (type_ == NT_LSTM_SOFTMAX)
65  spec.add_str_int("LS", ns_);
66  else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
67  spec.add_str_int("LE", ns_);
68  if (softmax_ != NULL) spec += softmax_->spec();
69  return spec;
70  }
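
For example, a 1-D NT_LSTM with 96 states reports "Lfx96" and an NT_LSTM_SUMMARY reports "Lfxs96"; the softmax variants ("LS"/"LE") append the spec of their internal FullyConnected layer, so the full network spec string can be reproduced.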

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
overridevirtual

Reimplemented from tesseract::Network.

Definition at line 638 of file lstm.cpp.

638 void LSTM::Update(float learning_rate, float momentum, float adam_beta,
639                   int num_samples) {
640 #if DEBUG_DETAIL > 3
641  PrintW();
642 #endif
643  for (int w = 0; w < WT_COUNT; ++w) {
644  if (w == GFS && !Is2D()) continue;
645  gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
646  }
647  if (softmax_ != NULL) {
648  softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
649  }
650 #if DEBUG_DETAIL > 3
651  PrintDW();
652 #endif
653 }
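
Putting the training members together, a hedged single-step sketch (hyperparameter values illustrative; adam_beta only takes effect when the network was built with the NF_ADAM flag):

  // Sketch: one training step - forward, backprop, weight update.
  lstm->SetEnableTraining(tesseract::TS_ENABLED);  // Allocates delta buffers.
  NetworkIO input, fwd_out, fwd_deltas, back_deltas;
  NetworkScratch scratch;
  // ... fill input; derive fwd_deltas from fwd_out and the target labels ...
  lstm->Forward(false, input, nullptr, &scratch, &fwd_out);
  lstm->Backward(false, fwd_deltas, &scratch, &back_deltas);
  lstm->Update(/*learning_rate=*/0.001f, /*momentum=*/0.5f,
               /*adam_beta=*/0.999f, /*num_samples=*/1);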

The documentation for this class was generated from the following files:
lstm.h
lstm.cpp