tesseract  4.00.00dev
lstm.cpp
// File:        lstm.cpp
// Description: Long Short-Term Memory (LSTM) recurrent neural network.
// Author:      Ray Smith
// Created:     Wed May 01 17:43:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lstm.h"

#ifdef _OPENMP
#include <omp.h>
#endif
#include <stdio.h>
#include <stdlib.h>

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"

// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)                                   \
  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) {  \
    PRAGMA(omp sections nowait) {                                           \
      PRAGMA(omp section) {
#define SECTION_IF_OPENMP \
        }                 \
        PRAGMA(omp section) \
        {

#define END_PARALLEL_IF_OPENMP \
        }                      \
      }  /* end of sections */ \
  }  /* end of parallel section */

// Define the portable PRAGMA macro.
#ifdef _MSC_VER  // Different _Pragma
#define PRAGMA(x) __pragma(x)
#else
#define PRAGMA(x) _Pragma(#x)
#endif  // _MSC_VER

#else  // _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
#endif  // _OPENMP

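// For reference, a region bracketed with these macros, e.g.:
//   PARALLEL_IF_OPENMP(4)
//     ComputeGateA();
//   SECTION_IF_OPENMP
//     ComputeGateB();
//   END_PARALLEL_IF_OPENMP
// expands under OpenMP to an omp parallel/omp sections construct in which
// each bracketed piece runs as its own omp section, and to the plain
// sequential statements when OpenMP is unavailable. (ComputeGateA/B are
// hypothetical placeholders; the real uses are in Forward() and Backward().)
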
namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0f;

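// Note on sizes: each gate sees a concatenated input of width na_ =
// ni (external input) + ns (the cell's previous output), plus another ns in
// 2-D mode (the previous output from the other dimension), plus nf when
// softmax feedback is enabled; the constructor below builds na_ up in
// exactly that order.
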
LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
           NetworkType type)
    : Network(type, name, ni, no),
      na_(ni + ns),
      ns_(ns),
      nf_(0),
      is_2d_(two_dimensional),
      softmax_(NULL),
      input_width_(0) {
  if (two_dimensional) na_ += ns_;
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("Invalid LSTM type %d!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}

LSTM::~LSTM() { delete softmax_; }

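// Shape example (not in the original source): for an NT_LSTM_SUMMARY node,
// OutputShape() collapses the width to 1, so a [batch, h, w, d] input maps
// to [batch, h, 1, no_]; with an attached softmax, the result comes from the
// softmax's own OutputShape() instead.
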
// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
  if (softmax_ != NULL) return softmax_->OutputShape(result);
  return result;
}

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if training is disabled.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) continue;
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != NULL) softmax_->SetEnableTraining(state);
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    num_weights_ += gate_weights_[w].InitWeightsFloat(
        ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != NULL) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

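// Weight-matrix geometry: InitWeightsFloat(ns_, na_ + 1, ...) above gives
// every gate an ns_ x (na_ + 1) matrix; the extra input column holds the
// bias (see PrintW(), which reads it at index na_).
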
// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
  if (softmax_ != NULL) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != NULL) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    STRING msg = name_;
    msg.add_str_int(" Gate weights ", w);
    gate_weights_[w].Debug2D(msg.string());
  }
  if (softmax_ != NULL) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
  }
  if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
  return true;
}

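// On-disk layout written above: the base Network data, then na_, then one
// WeightMatrix per gate (GFS omitted when 1-D), then the softmax sub-network
// if present. DeSerialize() below reads the same sequence and re-derives
// nf_, ns_ and is_2d_ from the stored sizes.
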
// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile* fp) {
  if (fp->FReadEndian(&na_, sizeof(na_), 1) != 1) return false;
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = IntCastRounded(ceil(log2(no_)));
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) return false;
  } else {
    softmax_ = nullptr;
  }
  return true;
}

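// For reference, the recurrence computed by Forward() is the standard LSTM
// cell (CI = cell input, GI = input gate, GF1 = forget gate, GO = output
// gate, with GFS as the extra 2-D forget gate):
//   g_t = GFunc(W_ci . x_t)            cell input
//   i_t = FFunc(W_gi . x_t)            input gate
//   f_t = FFunc(W_gf1 . x_t)           forget gate
//   o_t = FFunc(W_go . x_t)            output gate
//   s_t = f_t * s_(t-1) + i_t * g_t    state update, clipped to +/-kStateClip
//   h_t = o_t * HFunc(s_t)             output
// where x_t is the padded source vector assembled in source_ (external input
// plus recurrent and softmax feedback) and * is elementwise multiplication.
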
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
                   const TransposedArray* input_transpose,
                   NetworkScratch* scratch, NetworkIO* output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != NULL)
    output->ResizeFloat(input, no_);
  else if (type_ == NT_LSTM_SUMMARY)
    output->ResizeXTo1(input, no_);
  else
    output->Resize(input, no_);
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.init_to_size(buf_width, NetworkScratch::FloatVec());
    outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != NULL) {
    softmax_output.Init(no_, scratch);
    ZeroVector<double>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode())
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    softmax_->SetupForward(input, NULL);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width);  // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != NULL) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D())
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, an omp single
    // around the t-loop and then tasks in place of the sections is a *lot*
    // slower.
    // Cell inputs.
    if (source_.int_mode())
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    else
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode())
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    else
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode())
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    else
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode())
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      else
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode())
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    else
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      inT8* which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const double* stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
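    // Note: which_fg_ records which forget gate won the max-pool for each
    // cell (1 = the 1-D gate, 2 = the 2-D gate), so Backward() can route the
    // state error through the winning gate only.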
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
    if (softmax_ != NULL) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<double>(ns_, curr_state);
      ZeroVector<double>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.string());
  source_.Print(10);
  tprintf("State:%s\n", name_.string());
  state_.Print(10);
  tprintf("Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}

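// Backward() below walks the strides in the reverse of Forward()'s order
// (InitToLast()/Decrement()), so at each timestep the recurrent errors from
// the already-processed (later) timesteps are available in curr_stateerr/
// curr_sourceerr and in the rotating 2-D buffers.
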
// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w)
    sourceerr_temps[w].Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w) {
    gate_errors_t[w].Init(ns_, width, scratch);
  }
  // Used only if softmax_ != NULL.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != NULL) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.string());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and they are
    // only valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width);  // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == NULL) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
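    // Each gate error below is the derivative of the loss wrt that gate's
    // input: the appropriate FPrime/GPrime product with curr_stateerr (or
    // outputerr for GO), clipped to +/-kErrClip, then mapped back through the
    // gate's weights with VectorDotMatrix to build the source error, and
    // stored transposed for the SumOuterTransposed weight-gradient update.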
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.string(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != NULL) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}
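
// A note on the gradient computation above: for each gate,
// SumOuterTransposed() accumulates the weight gradient as the sum over all
// timesteps of the outer product of that gate's error with the source input;
// the transposed copies exist to speed up that sum (see comment above).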

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta,
                  int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != NULL) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
                            double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM* lstm = static_cast<const LSTM*>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != NULL) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    tprintf("\n");
  }
}

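// Note: RoundInputs() may round the requested width up (e.g. to a multiple
// convenient for the integer SIMD code), so source_ below can be allocated
// slightly wider than the logical na_.
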
// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) continue;
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}


}  // namespace tesseract.