tesseract  4.00.00dev
series.h
Go to the documentation of this file.
1 // File: series.h
3 // Description: Runs networks in series on the same input.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:20:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_SERIES_H_
20 #define TESSERACT_LSTM_SERIES_H_
21 
22 #include "plumbing.h"
23 
24 namespace tesseract {
25 
26 // Runs two or more networks in series (layers) on the same input.
27 class Series : public Plumbing {
28  public:
29  // ni_ and no_ will be set by AddToStack.
30  explicit Series(const STRING& name);
31  virtual ~Series();
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  virtual StaticShape OutputShape(const StaticShape& input_shape) const;
36 
37  virtual STRING spec() const {
38  STRING spec("[");
39  for (int i = 0; i < stack_.size(); ++i)
40  spec += stack_[i]->spec();
41  spec += "]";
42  return spec;
43  }
44 
45  // Sets up the network for training. Initializes weights using weights of
46  // scale `range` picked according to the random number generator `randomizer`.
47  // Returns the number of weights initialized.
48  virtual int InitWeights(float range, TRand* randomizer);
49  // Recursively searches the network for softmaxes with old_no outputs,
50  // and remaps their outputs according to code_map. See network.h for details.
51  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
52 
53  // Sets needs_to_backprop_ to needs_backprop and returns true if
54  // needs_backprop || any weights in this network so the next layer forward
55  // can be told to produce backprop for this layer if needed.
56  virtual bool SetupNeedsBackprop(bool needs_backprop);
57 
58  // Returns an integer reduction factor that the network applies to the
59  // time sequence. Assumes that any 2-d is already eliminated. Used for
60  // scaling bounding boxes of truth data.
61  // WARNING: if GlobalMinimax is used to vary the scale, this will return
62  // the last used scale factor. Call it before any forward, and it will return
63  // the minimum scale factor of the paths through the GlobalMinimax.
64  virtual int XScaleFactor() const;
65 
66  // Provides the (minimum) x scale factor to the network (of interest only to
67  // input units) so they can determine how to scale bounding boxes.
68  virtual void CacheXScaleFactor(int factor);
69 
70  // Runs forward propagation of activations on the input line.
71  // See Network for a detailed discussion of the arguments.
72  virtual void Forward(bool debug, const NetworkIO& input,
73  const TransposedArray* input_transpose,
74  NetworkScratch* scratch, NetworkIO* output);
75 
76  // Runs backward propagation of errors on the deltas line.
77  // See Network for a detailed discussion of the arguments.
78  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
79  NetworkScratch* scratch,
80  NetworkIO* back_deltas);
81 
82  // Splits the series after the given index, returning the two parts and
83  // deletes itself. The first part, up to network with index last_start, goes
84  // into start, and the rest goes into end.
85  void SplitAt(int last_start, Series** start, Series** end);
86 
87  // Appends the elements of the src series to this, removing from src and
88  // deleting it.
89  void AppendSeries(Network* src);
90 };
91 
92 } // namespace tesseract.
93 
94 #endif // TESSERACT_LSTM_SERIES_H_
void AppendSeries(Network *src)
Definition: series.cpp:193
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: series.cpp:38
virtual void CacheXScaleFactor(int factor)
Definition: series.cpp:104
virtual STRING spec() const
Definition: series.h:37
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: series.cpp:110
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: series.cpp:65
Series(const STRING &name)
Definition: series.cpp:29
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:163
PointerVector< Network > stack_
Definition: plumbing.h:136
Definition: strngs.h:45
virtual int XScaleFactor() const
Definition: series.cpp:95
const STRING & name() const
Definition: network.h:138
virtual int InitWeights(float range, TRand *randomizer)
Definition: series.cpp:50
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: series.cpp:82
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: series.cpp:132
virtual ~Series()
Definition: series.cpp:33