tesseract  4.00.00dev
fullyconnected.cpp
Go to the documentation of this file.
1 // File: fullyconnected.cpp
3 // Description: Simple feed-forward layer with various non-linearities.
4 // Author: Ray Smith
5 // Created: Wed Feb 26 14:49:15 PST 2014
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #include "fullyconnected.h"
20 
21 #ifdef _OPENMP
22 #include <omp.h>
23 #endif
24 #include <stdio.h>
25 #include <stdlib.h>
26 
27 #include "functions.h"
28 #include "networkscratch.h"
29 
30 // Number of threads to use for parallel calculation of Forward and Backward.
31 #ifdef _OPENMP
32 const int kNumThreads = 4;
33 #else
34 const int kNumThreads = 1;
35 #endif
36 
37 namespace tesseract {
38 
39 FullyConnected::FullyConnected(const STRING& name, int ni, int no,
40  NetworkType type)
41  : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
42 }
43 
45 }
46 
47 // Returns the shape output from the network given an input shape (which may
48 // be partially unknown, i.e. zero).
50  LossType loss_type = LT_NONE;
51  if (type_ == NT_SOFTMAX)
52  loss_type = LT_CTC;
53  else if (type_ == NT_SOFTMAX_NO_CTC)
54  loss_type = LT_SOFTMAX;
55  else if (type_ == NT_LOGISTIC)
56  loss_type = LT_LOGISTIC;
57  StaticShape result(input_shape);
58  result.set_depth(no_);
59  result.set_loss_type(loss_type);
60  return result;
61 }
62 
63 // Suspends/Enables training by setting the training_ flag.
65  if (state == TS_RE_ENABLE) {
66  // Enable only from temp disabled.
68  } else if (state == TS_TEMP_DISABLE) {
69  // Temp disable only from enabled.
70  if (training_ == TS_ENABLED) training_ = state;
71  } else {
72  if (state == TS_ENABLED && training_ != TS_ENABLED)
74  training_ = state;
75  }
76 }
77 
78 // Sets up the network for training. Initializes the weights with random
79 // values of scale `range` drawn from the random number generator `randomizer`.
// Returns num_weights_.
80 int FullyConnected::InitWeights(float range, TRand* randomizer) {
81  Network::SetRandomizer(randomizer);
// NOTE(review): listing line 82, which opens the call that is closed on the
// next line, is absent from this extract.
83  range, randomizer);
84  return num_weights_;
85 }
86 
87 // Recursively searches the network for softmaxes with old_no outputs,
88 // and remaps their outputs according to code_map. See network.h for details.
// Only a layer of type NT_SOFTMAX whose no_ equals old_no is modified; its no_
// becomes code_map.size(). Returns num_weights_ in every case.
89 
90 int FullyConnected::RemapOutputs(int old_no, const std::vector<int>& code_map) {
91  if (type_ == NT_SOFTMAX && no_ == old_no) {
// NOTE(review): listing line 92 is absent from this extract.
93  no_ = code_map.size();
94  }
95  return num_weights_;
96 }
97 
98 // Converts a float network to an int network.
101 }
102 
103 // Provides debug output on the weights.
106 }
107 
108 // Writes to the given file. Returns false in case of error.
110  if (!Network::Serialize(fp)) return false;
111  if (!weights_.Serialize(IsTraining(), fp)) return false;
112  return true;
113 }
114 
115 // Reads from the given file. Returns false in case of error.
117  return weights_.DeSerialize(IsTraining(), fp);
118 }
119 
120 // Runs forward propagation of activations on the input line.
121 // See NetworkCpp for a detailed discussion of the arguments.
// Resizes output to no_ features (always float for NT_SOFTMAX), then for every
// timestep t of input computes the activations with ForwardTimeStep and writes
// them to output. While training, the activations of non-softmax layers are
// also copied into acts_. If debug is set the output is displayed at the end.
// NOTE(review): listing lines 131-134 and 168 are absent from this extract;
// temp_lines and curr_input are presumably declared in the first gap.
122 void FullyConnected::Forward(bool debug, const NetworkIO& input,
123  const TransposedArray* input_transpose,
124  NetworkScratch* scratch, NetworkIO* output) {
125  int width = input.Width();
126  if (type_ == NT_SOFTMAX)
127  output->ResizeFloat(input, no_);
128  else
129  output->Resize(input, no_);
130  SetupForward(input, input_transpose);
// One output buffer (no_ wide) and one input buffer (ni_ wide) per thread, so
// that timesteps processed in parallel never share temporary storage.
135  for (int i = 0; i < kNumThreads; ++i) {
136  temp_lines[i].Init(no_, scratch);
137  curr_input[i].Init(ni_, scratch);
138  }
139 #ifdef _OPENMP
140 #pragma omp parallel for num_threads(kNumThreads)
141  for (int t = 0; t < width; ++t) {
142  // Thread-local pointer to temporary storage.
143  int thread_id = omp_get_thread_num();
144 #else
145  for (int t = 0; t < width; ++t) {
146  // Thread-local pointer to temporary storage.
147  int thread_id = 0;
148 #endif
149  double* temp_line = temp_lines[thread_id];
// Exactly one of d_input / i_input is set below: integer input is read in
// place, float input is first copied into this thread's curr_input buffer.
150  const double* d_input = NULL;
151  const inT8* i_input = NULL;
152  if (input.int_mode()) {
153  i_input = input.i(t);
154  } else {
155  input.ReadTimeStep(t, curr_input[thread_id]);
156  d_input = curr_input[thread_id];
157  }
158  ForwardTimeStep(d_input, i_input, t, temp_line);
159  output->WriteTimeStep(t, temp_line);
160  if (IsTraining() && type_ != NT_SOFTMAX) {
161  acts_.CopyTimeStepFrom(t, *output, t);
162  }
163  }
164  // Zero all the elements that are in the padding around images that allows
165  // multiple different-sized images to exist in a single array.
166  // acts_ is only used if this is not a softmax op.
167  if (IsTraining() && type_ != NT_SOFTMAX) {
// NOTE(review): the body of this branch (listing line 168) is absent here.
169  }
170  output->ZeroInvalidElements();
171 #if DEBUG_DETAIL > 0
172  tprintf("F Output:%s\n", name_.string());
173  output->Print(10);
174 #endif
175  if (debug) DisplayForward(*output);
176 }
177 
178 // Components of Forward so FullyConnected can be reused inside LSTM.
180  const TransposedArray* input_transpose) {
181  // Softmax output is always float, so save the input type.
182  int_mode_ = input.int_mode();
183  if (IsTraining()) {
184  acts_.Resize(input, no_);
185  // Source_ is a transposed copy of input. It isn't needed if provided.
186  external_source_ = input_transpose;
187  if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
188  }
189 }
190 
191 void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
192  int t, double* output_line) {
193  // input is copied to source_ line-by-line for cache coherency.
194  if (IsTraining() && external_source_ == NULL && d_input != NULL)
195  source_t_.WriteStrided(t, d_input);
196  if (d_input != NULL)
197  weights_.MatrixDotVector(d_input, output_line);
198  else
199  weights_.MatrixDotVector(i_input, output_line);
200  if (type_ == NT_TANH) {
201  FuncInplace<GFunc>(no_, output_line);
202  } else if (type_ == NT_LOGISTIC) {
203  FuncInplace<FFunc>(no_, output_line);
204  } else if (type_ == NT_POSCLIP) {
205  FuncInplace<ClipFFunc>(no_, output_line);
206  } else if (type_ == NT_SYMCLIP) {
207  FuncInplace<ClipGFunc>(no_, output_line);
208  } else if (type_ == NT_RELU) {
209  FuncInplace<Relu>(no_, output_line);
210  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
211  SoftmaxInPlace(no_, output_line);
212  } else if (type_ != NT_LINEAR) {
213  ASSERT_HOST("Invalid fully-connected type!" == NULL);
214  }
215 }
216 
217 // Runs backward propagation of errors on the deltas line.
218 // See NetworkCpp for a detailed discussion of the arguments.
// For every timestep t of fwd_deltas, BackwardTimeStep computes the errors
// and, when needs_to_backprop_ is set, the deltas written to back_deltas.
// The transposed errors are then handed to FinishBackward.
// Returns true iff needs_to_backprop_ is set, i.e. back_deltas was filled in.
// NOTE(review): listing lines 224, 227 and 233 are absent from this extract;
// errors, temp_backprops and errors_t are presumably declared there.
219 bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
220  NetworkScratch* scratch,
221  NetworkIO* back_deltas) {
222  if (debug) DisplayBackward(fwd_deltas);
223  back_deltas->Resize(fwd_deltas, ni_);
// One error buffer (no_ wide) per thread.
225  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
226  for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
// Per-thread backprop buffers (ni_ wide) are only set up when the deltas are
// actually going to be propagated further back.
228  if (needs_to_backprop_) {
229  temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
230  for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
231  }
232  int width = fwd_deltas.Width();
234  errors_t.Init(no_, width, scratch);
235 #ifdef _OPENMP
236 #pragma omp parallel for num_threads(kNumThreads)
237  for (int t = 0; t < width; ++t) {
238  int thread_id = omp_get_thread_num();
239 #else
240  for (int t = 0; t < width; ++t) {
241  int thread_id = 0;
242 #endif
// A NULL backprop tells BackwardTimeStep to skip computing the back deltas.
243  double* backprop = NULL;
244  if (needs_to_backprop_) backprop = temp_backprops[thread_id];
245  double* curr_errors = errors[thread_id];
246  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
247  if (backprop != NULL) {
248  back_deltas->WriteTimeStep(t, backprop);
249  }
250  }
251  FinishBackward(*errors_t.get());
252  if (needs_to_backprop_) {
253  back_deltas->ZeroInvalidElements();
254 #if DEBUG_DETAIL > 0
255  tprintf("F Backprop:%s\n", name_.string());
256  back_deltas->Print(10);
257 #endif
258  return true;
259  }
260  return false; // No point going further back.
261 }
262 
263 void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
264  double* curr_errors,
265  TransposedArray* errors_t,
266  double* backprop) {
267  if (type_ == NT_TANH)
268  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
269  else if (type_ == NT_LOGISTIC)
270  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
271  else if (type_ == NT_POSCLIP)
272  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
273  else if (type_ == NT_SYMCLIP)
274  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
275  else if (type_ == NT_RELU)
276  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
277  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
278  type_ == NT_LINEAR)
279  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
280  else
281  ASSERT_HOST("Invalid fully-connected type!" == NULL);
282  // Generate backprop only if needed by the lower layer.
283  if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
284  errors_t->WriteStrided(t, curr_errors);
285 }
286 
288  if (external_source_ == NULL)
289  weights_.SumOuterTransposed(errors_t, source_t_, true);
290  else
291  weights_.SumOuterTransposed(errors_t, *external_source_, true);
292 }
293 
294 // Updates the weights using the given learning rate, momentum and adam_beta.
295 // num_samples is used in the adam computation iff use_adam_ is true.
296 void FullyConnected::Update(float learning_rate, float momentum,
297  float adam_beta, int num_samples) {
298  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
299 }
300 
301 // Sums the products of weight updates in *this and other, splitting into
302 // positive (same direction) in *same and negative (different direction) in
303 // *changed.
304 void FullyConnected::CountAlternators(const Network& other, double* same,
305  double* changed) const {
306  ASSERT_HOST(other.type() == type_);
307  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
308  weights_.CountAlternators(fc->weights_, same, changed);
309 }
310 
311 } // namespace tesseract.
virtual void SetEnableTraining(TrainingState state)
virtual int InitWeights(float range, TRand *randomizer)
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:163
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
NetworkType
Definition: network.h:43
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
int RemapOutputs(const std::vector< int > &code_map)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:88
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
void ZeroInvalidElements()
Definition: networkio.cpp:93
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)
TrainingState training_
Definition: network.h:300
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
NetworkType type_
Definition: network.h:299
const int kNumThreads
void MatrixDotVector(const double *u, double *v) const
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
virtual void CountAlternators(const Network &other, double *same, double *changed) const
void FinishBackward(const TransposedArray &errors_t)
void VectorDotMatrix(const double *u, double *v) const
void Init(int size1, int size2, NetworkScratch *scratch)
#define tprintf(...)
Definition: tprintf.h:31
virtual bool Serialize(TFile *fp) const
void ReadTimeStep(int t, double *output) const
Definition: networkio.cpp:603
bool IsTraining() const
Definition: network.h:115
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:388
int8_t inT8
Definition: host.h:34
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
bool int_mode() const
Definition: networkio.h:127
void set_loss_type(LossType value)
Definition: static_shape.h:49
const char * string() const
Definition: strngs.cpp:198
Definition: strngs.h:45
inT32 num_weights_
Definition: network.h:305
void WriteTimeStep(int t, const double *input)
Definition: networkio.cpp:650
void Debug2D(const char *msg)
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259
TrainingState
Definition: network.h:92
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
#define ASSERT_HOST(x)
Definition: errcode.h:84
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:52
NetworkType type() const
Definition: network.h:112
virtual bool DeSerialize(TFile *fp)
int Width() const
Definition: networkio.h:107
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
void Print(int num) const
Definition: networkio.cpp:371
void set_depth(int value)
Definition: static_shape.h:47
bool needs_to_backprop_
Definition: network.h:301
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
bool DeSerialize(bool training, TFile *fp)
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
void init_to_size(int size, T t)
const TransposedArray * external_source_
const inT8 * i(int t) const
Definition: networkio.h:123
bool Serialize(bool training, TFile *fp) const