tesseract v5.3.3.20231005
lstm.cpp
///////////////////////////////////////////////////////////////////////
// File:        lstm.cpp
// Description: Long Short-Term Memory (LSTM) recurrent neural network.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
# include "config_auto.h"
#endif

#include "lstm.h"

#ifdef _OPENMP
# include <omp.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <sstream> // for std::ostringstream

#if defined(_MSC_VER) && !defined(__clang__)
# include <intrin.h> // _BitScanReverse
#endif

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"
// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
# define PARALLEL_IF_OPENMP(__num_threads) \
  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
    PRAGMA(omp sections nowait) { \
      PRAGMA(omp section) {
# define SECTION_IF_OPENMP \
  } \
  PRAGMA(omp section) {
# define END_PARALLEL_IF_OPENMP \
  } \
  } /* end of sections */ \
  } /* end of parallel section */

// Define the portable PRAGMA macro.
# ifdef _MSC_VER // Different _Pragma
#  define PRAGMA(x) __pragma(x)
# else
#  define PRAGMA(x) _Pragma(# x)
# endif // _MSC_VER

#else // _OPENMP
# define PARALLEL_IF_OPENMP(__num_threads)
# define SECTION_IF_OPENMP
# define END_PARALLEL_IF_OPENMP
#endif // _OPENMP
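
// For illustration, a block written as
//   PARALLEL_IF_OPENMP(4)
//     TaskA();
//   SECTION_IF_OPENMP
//     TaskB();
//   END_PARALLEL_IF_OPENMP
// expands, when _OPENMP is defined, to an "omp parallel" region containing an
// "omp sections" block, with TaskA() and TaskB() in separate "omp section"
// blocks so they may run on different threads. Without OpenMP the macros
// expand to nothing and the tasks simply run in sequence.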

namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const TFloat kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const TFloat kErrClip = 1.0f;

// Calculate ceil(log2(n)).
static inline uint32_t ceil_log2(uint32_t n) {
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use fast inline assembler code for gcc or clang.
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use fast intrinsic function for MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  if (n == 0)
    return UINT_MAX;
  if (n == 1)
    return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}
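
// Example: ceil_log2(100) == 7, since 2^6 = 64 < 100 <= 128 = 2^7. An
// NT_LSTM_SOFTMAX_ENCODED network with 100 outputs therefore feeds back
// nf_ = 7 binary code units per timestep (see the constructor below).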

LSTM::LSTM(const std::string &name, int ni, int ns, int no, bool two_dimensional, NetworkType type)
    : Network(type, name, ni, no)
    , na_(ni + ns)
    , ns_(ns)
    , nf_(0)
    , is_2d_(two_dimensional)
    , softmax_(nullptr)
    , input_width_(0) {
  if (two_dimensional) {
    na_ += ns_;
  }
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is invalid type of LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}
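
// Note on sizes: after construction, na_ = ni_ + nf_ + ns_ (plus another ns_
// in the 2-D case). Each gate thus multiplies the concatenation of the
// external input (ni_), the softmax feedback code (nf_, zero for plain
// LSTMs), and the recurrent output(s) of the previous timestep(s), plus a
// bias input (hence the na_ + 1 columns in InitWeights below).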

LSTM::~LSTM() {
  delete softmax_;
}

// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape LSTM::OutputShape(const StaticShape &input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) {
    result.set_width(1);
  }
  if (softmax_ != nullptr) {
    return softmax_->OutputShape(result);
  }
  return result;
}
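
// For example, given an input shape of (batch, height, width, depth), a
// plain NT_LSTM keeps (batch, height, width) and sets depth to no_, while
// NT_LSTM_SUMMARY also collapses the width to 1, yielding one summary
// vector per row.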

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) {
      training_ = TS_ENABLED;
    }
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) {
      training_ = state;
    }
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) {
          continue;
        }
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) {
    softmax_->SetEnableTraining(state);
  }
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand *randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    num_weights_ +=
        gate_weights_[w].InitWeightsFloat(ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int> &code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    std::ostringstream msg;
    msg << name_ << " Gate weights " << w;
    gate_weights_[w].Debug2D(msg.str().c_str());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile *fp) const {
  if (!Network::Serialize(fp)) {
    return false;
  }
  if (!fp->Serialize(&na_)) {
    return false;
  }
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) {
      return false;
    }
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) {
    return false;
  }
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile *fp) {
  if (!fp->DeSerialize(&na_)) {
    return false;
  }
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) {
      return false;
    }
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
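      // There is no explicit 2-D flag in the serialized data: if the stored
      // gate input width na_, minus any softmax feedback nf_, leaves room
      // for two recurrent state vectors beyond the ni_ external inputs, the
      // network must carry a second (vertical) recurrence.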
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected *>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) {
      return false;
    }
  } else {
    softmax_ = nullptr;
  }
  return true;
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose,
                   NetworkScratch *scratch, NetworkIO *output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr) {
    output->ResizeFloat(input, no_);
  } else if (type_ == NT_LSTM_SUMMARY) {
    output->ResizeXTo1(input, no_);
  } else {
    output->Resize(input, no_);
  }
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  int ro = ns_;
  if (source_.int_mode() && IntSimdMatrix::intSimdMatrix) {
    ro = IntSimdMatrix::intSimdMatrix->RoundOutputs(ro);
  }
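  // The extra reserve of ro - ns_ elements lets the vectorized integer
  // kernel write a whole final register block; the buffers are still used
  // as ns_ logical outputs everywhere else.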
  for (auto &temp_line : temp_lines) {
    temp_line.Init(ns_, ro, scratch);
  }
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<TFloat>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
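  // (The cell directly above the current position was processed exactly
  // buf_width timesteps earlier, so indexing these buffers with t modulo
  // buf_width retrieves that cell's state and output for the vertical
  // recurrence.)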
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.resize(buf_width);
    outputs.resize(buf_width);
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<TFloat>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<TFloat>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode()) {
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    }
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) {
        valid_2d = false;
      }
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D()) {
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    }
    if (!source_.int_mode()) {
      source_.ReadTimeStep(t, curr_input);
    }
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, a single around
    // the t-loop and then tasks in place of the sections is a *lot* slower.
    // Cell inputs.
    if (source_.int_mode()) {
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    } else {
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    }
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode()) {
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    } else {
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode()) {
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    } else {
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode()) {
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      } else {
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      }
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode()) {
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    } else {
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    }
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

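    // temp_lines now hold this timestep's activated gate responses: the
    // squashed (GFunc) cell input in temp_lines[CI] and the gate values
    // (FFunc) in temp_lines[GI], [GF1], [GFS] and [GO]. The code below is
    // the usual LSTM cell update:
    //   state  = forget_gate * prev_state + input_gate * cell_input
    //   output = output_gate * h(state)
    // with the 2-D variant max-pooling between the horizontal (GF1) and
    // vertical (GFS) forget paths instead of summing them.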
    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
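      // which_fg_ records, per cell, which forget path won the max-pool
      // (1 = horizontal GF1, the default; 2 = vertical GFS), so that
      // Backward can route the state gradient along the same path.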
      int8_t *which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const TFloat *stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<TFloat>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) {
        node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
      }
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) {
      state_.WriteTimeStep(t, curr_state);
    }
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
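      // For the encoded variant, the no_-wide softmax result is re-encoded
      // in place into nf_ = ceil_log2(no_) binary code units before being
      // fed back as part of the next timestep's input.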
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<TFloat>(ns_, curr_state);
      ZeroVector<TFloat>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.c_str());
  source_.Print(10);
  tprintf("State:%s\n", name_.c_str());
  state_.Print(10);
  tprintf("Output:%s\n", name_.c_str());
  output->Print(10);
#endif
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayForward(*output);
  }
#endif
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch,
                    NetworkIO *back_deltas) {
#ifndef GRAPHICS_DISABLED
  if (debug) {
    DisplayBackward(fwd_deltas);
  }
#endif
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<TFloat>(ns_, curr_stateerr);
  ZeroVector<TFloat>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto &gate_error : gate_errors) {
    gate_error.Init(ns_, scratch);
  }
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  std::vector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.resize(buf_width);
    sourceerr.resize(buf_width);
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<TFloat>(ns_, stateerr[t]);
      ZeroVector<TFloat>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto &sourceerr_temp : sourceerr_temps) {
    sourceerr_temp.Init(na_, scratch);
  }
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto &w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  TFloat state_clip = Is2D() ? 9.0 : 4.0;
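  // The 2-D case tolerates a larger state error before clipping, presumably
  // because gradient arrives through both the horizontal and vertical
  // recurrences.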
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.c_str());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) {
          up_pos = up_index.t();
        }
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) {
          down_pos = down_index.t();
        }
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<TFloat>(na_, curr_sourceerr);
      ZeroVector<TFloat>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<TFloat>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors, softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x) {
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    }
    if (down_pos >= 0) {
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    }
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float *next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) {
          curr_stateerr[i] = 0.0;
        }
      }
      if (down_pos >= 0) {
        const float *right_node_gfs = node_values_[GFS].f(down_pos);
        const TFloat *right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr, curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<TFloat>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i], curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

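    // Each section below forms one gate's error signal: the derivative of
    // the gate nonlinearity times the vector that gate multiplied in
    // Forward, clipped to kErrClip. VectorDotMatrix then back-projects the
    // error through the gate's weights into sourceerr_temps, and a
    // transposed copy is kept in gate_errors_t for the weight-gradient
    // outer product after the loop.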
    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t, curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t, curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr, gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1], sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr, gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS], sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) {
      gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr, gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI], sourceerr_temps[GF1],
               sourceerr_temps[GO], sourceerr_temps[GFS], curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.c_str(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
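  // SumOuterTransposed accumulates each gate's weight gradient as the sum
  // over t of the outer product of gate error and source input; both factors
  // were stored transposed (feature-major) above so the sum runs along
  // contiguous time series.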
#ifdef _OPENMP
# pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta, int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network &other, TFloat *same, TFloat *changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM *lstm = static_cast<const LSTM *>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

#if DEBUG_DETAIL > 3

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    }
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) {
      continue;
    }
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s) {
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      }
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s) {
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    }
    tprintf("\n");
  }
}

#endif

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO &input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
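  // RoundInputs rounds na_ up to the input block size of the integer SIMD
  // matrix kernel (in float mode it returns na_ unchanged), so source_ is
  // allocated wide enough for MatrixDotVector to read whole blocks.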
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) {
        continue;
      }
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}

} // namespace tesseract.