tesseract  4.0.0-beta.1-59-g2cc4
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

 FullyConnected (const STRING &name, int ni, int no, NetworkType type)
 
virtual ~FullyConnected ()
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
virtual STRING spec () const
 
void ChangeType (NetworkType type)
 
virtual void SetEnableTraining (TrainingState state)
 
virtual int InitWeights (float range, TRand *randomizer)
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
virtual void ConvertToInt ()
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (const double *d_input, const int8_t *i_input, int t, double *output_line)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file fullyconnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const STRING & name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 39 of file fullyconnected.cpp.

41  : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {
42 }
const TransposedArray * external_source_
NetworkType type() const
Definition: network.h:112

◆ ~FullyConnected()

tesseract::FullyConnected::~FullyConnected ( )
virtual

Definition at line 44 of file fullyconnected.cpp.

44  {
45 }

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 219 of file fullyconnected.cpp.

221  {
222  if (debug) DisplayBackward(fwd_deltas);
223  back_deltas->Resize(fwd_deltas, ni_);
224  GenericVector<NetworkScratch::FloatVec> errors;
225  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
226  for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
227  GenericVector<NetworkScratch::FloatVec> temp_backprops;
228  if (needs_to_backprop_) {
229  temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
230  for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
231  }
232  int width = fwd_deltas.Width();
233  NetworkScratch::GradientStore errors_t;
234  errors_t.Init(no_, width, scratch);
235 #ifdef _OPENMP
236 #pragma omp parallel for num_threads(kNumThreads)
237  for (int t = 0; t < width; ++t) {
238  int thread_id = omp_get_thread_num();
239 #else
240  for (int t = 0; t < width; ++t) {
241  int thread_id = 0;
242 #endif
243  double* backprop = nullptr;
244  if (needs_to_backprop_) backprop = temp_backprops[thread_id];
245  double* curr_errors = errors[thread_id];
246  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
247  if (backprop != nullptr) {
248  back_deltas->WriteTimeStep(t, backprop);
249  }
250  }
251  FinishBackward(*errors_t.get());
252  if (needs_to_backprop_) {
253  back_deltas->ZeroInvalidElements();
254 #if DEBUG_DETAIL > 0
255  tprintf("F Backprop:%s\n", name_.string());
256  back_deltas->Print(10);
257 #endif
258  return true;
259  }
260  return false; // No point going further back.
261 }
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
void FinishBackward(const TransposedArray &errors_t)
bool needs_to_backprop_
Definition: network.h:301
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
const int kNumThreads
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
void init_to_size(int size, T t)

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO & fwd_deltas,
int  t,
double *  curr_errors,
TransposedArray * errors_t,
double *  backprop 
)

Definition at line 263 of file fullyconnected.cpp.

266  {
267  if (type_ == NT_TANH)
268  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
269  else if (type_ == NT_LOGISTIC)
270  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
271  else if (type_ == NT_POSCLIP)
272  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
273  else if (type_ == NT_SYMCLIP)
274  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
275  else if (type_ == NT_RELU)
276  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
277  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
278  type_ == NT_LINEAR)
279  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
280  else
281  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
282  // Generate backprop only if needed by the lower layer.
283  if (backprop != nullptr) weights_.VectorDotMatrix(curr_errors, backprop);
284  errors_t->WriteStrided(t, curr_errors);
285 }
NetworkType type_
Definition: network.h:299
void VectorDotMatrix(const double *u, double *v) const
#define ASSERT_HOST(x)
Definition: errcode.h:84
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 60 of file fullyconnected.h.

60  {
61  type_ = type;
62  }
NetworkType type_
Definition: network.h:299
NetworkType type() const
Definition: network.h:112

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 99 of file fullyconnected.cpp.

99  {
100  weights_.ConvertToInt();
101 }

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network & other,
double *  same,
double *  changed 
) const
virtual

Reimplemented from tesseract::Network.

Definition at line 304 of file fullyconnected.cpp.

305  {
306  ASSERT_HOST(other.type() == type_);
307  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
308  weights_.CountAlternators(fc->weights_, same, changed);
309 }
NetworkType type_
Definition: network.h:299
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
#define ASSERT_HOST(x)
Definition: errcode.h:84
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 104 of file fullyconnected.cpp.

104  {
105  weights_.Debug2D(name_.string());
106 }
void Debug2D(const char *msg)
const char * string() const
Definition: strngs.cpp:198

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile * fp)
virtual

Reimplemented from tesseract::Network.

Definition at line 116 of file fullyconnected.cpp.

116  {
117  return weights_.DeSerialize(IsTraining(), fp);
118 }
bool DeSerialize(bool training, TFile *fp)
bool IsTraining() const
Definition: network.h:115

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray & errors_t)

Definition at line 287 of file fullyconnected.cpp.

287  {
288  if (external_source_ == nullptr)
289  weights_.SumOuterTransposed(errors_t, source_t_, true);
290  else
291  weights_.SumOuterTransposed(errors_t, *external_source_, true);
292 }
const TransposedArray * external_source_
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 122 of file fullyconnected.cpp.

124  {
125  int width = input.Width();
126  if (type_ == NT_SOFTMAX)
127  output->ResizeFloat(input, no_);
128  else
129  output->Resize(input, no_);
130  SetupForward(input, input_transpose);
131  GenericVector<NetworkScratch::FloatVec> temp_lines;
132  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
133  GenericVector<NetworkScratch::FloatVec> curr_input;
134  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
135  for (int i = 0; i < kNumThreads; ++i) {
136  temp_lines[i].Init(no_, scratch);
137  curr_input[i].Init(ni_, scratch);
138  }
139 #ifdef _OPENMP
140 #pragma omp parallel for num_threads(kNumThreads)
141  for (int t = 0; t < width; ++t) {
142  // Thread-local pointer to temporary storage.
143  int thread_id = omp_get_thread_num();
144 #else
145  for (int t = 0; t < width; ++t) {
146  // Thread-local pointer to temporary storage.
147  int thread_id = 0;
148 #endif
149  double* temp_line = temp_lines[thread_id];
150  const double* d_input = nullptr;
151  const int8_t* i_input = nullptr;
152  if (input.int_mode()) {
153  i_input = input.i(t);
154  } else {
155  input.ReadTimeStep(t, curr_input[thread_id]);
156  d_input = curr_input[thread_id];
157  }
158  ForwardTimeStep(d_input, i_input, t, temp_line);
159  output->WriteTimeStep(t, temp_line);
160  if (IsTraining() && type_ != NT_SOFTMAX) {
161  acts_.CopyTimeStepFrom(t, *output, t);
162  }
163  }
164  // Zero all the elements that are in the padding around images that allows
165  // multiple different-sized images to exist in a single array.
166  // acts_ is only used if this is not a softmax op.
167  if (IsTraining() && type_ != NT_SOFTMAX) {
168  acts_.ZeroInvalidElements();
169  }
170  output->ZeroInvalidElements();
171 #if DEBUG_DETAIL > 0
172  tprintf("F Output:%s\n", name_.string());
173  output->Print(10);
174 #endif
175  if (debug) DisplayForward(*output);
176 }
void ForwardTimeStep(const double *d_input, const int8_t *i_input, int t, double *output_line)
NetworkType type_
Definition: network.h:299
void ZeroInvalidElements()
Definition: networkio.cpp:93
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
const int kNumThreads
#define tprintf(...)
Definition: tprintf.h:31
bool IsTraining() const
Definition: network.h:115
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:388
const char * string() const
Definition: strngs.cpp:198
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void init_to_size(int size, T t)

◆ ForwardTimeStep()

void tesseract::FullyConnected::ForwardTimeStep ( const double *  d_input,
const int8_t *  i_input,
int  t,
double *  output_line 
)

Definition at line 191 of file fullyconnected.cpp.

192  {
193  // input is copied to source_ line-by-line for cache coherency.
194  if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
195  source_t_.WriteStrided(t, d_input);
196  if (d_input != nullptr)
197  weights_.MatrixDotVector(d_input, output_line);
198  else
199  weights_.MatrixDotVector(i_input, output_line);
200  if (type_ == NT_TANH) {
201  FuncInplace<GFunc>(no_, output_line);
202  } else if (type_ == NT_LOGISTIC) {
203  FuncInplace<FFunc>(no_, output_line);
204  } else if (type_ == NT_POSCLIP) {
205  FuncInplace<ClipFFunc>(no_, output_line);
206  } else if (type_ == NT_SYMCLIP) {
207  FuncInplace<ClipGFunc>(no_, output_line);
208  } else if (type_ == NT_RELU) {
209  FuncInplace<Relu>(no_, output_line);
210  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
211  SoftmaxInPlace(no_, output_line);
212  } else if (type_ != NT_LINEAR) {
213  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
214  }
215 }
const TransposedArray * external_source_
NetworkType type_
Definition: network.h:299
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool IsTraining() const
Definition: network.h:115
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:163
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
void MatrixDotVector(const double *u, double *v) const

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand * randomizer 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 80 of file fullyconnected.cpp.

80  {
81  Network::SetRandomizer(randomizer);
82  num_weights_ = weights_.InitWeightsFloat(no_, ni_, TestFlag(NF_ADAM),
83  range, randomizer);
84  return num_weights_;
85 }
int32_t num_weights_
Definition: network.h:305
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape & input_shape) const
virtual

Reimplemented from tesseract::Network.

Definition at line 49 of file fullyconnected.cpp.

49  {
50  LossType loss_type = LT_NONE;
51  if (type_ == NT_SOFTMAX)
52  loss_type = LT_CTC;
53  else if (type_ == NT_SOFTMAX_NO_CTC)
54  loss_type = LT_SOFTMAX;
55  else if (type_ == NT_LOGISTIC)
56  loss_type = LT_LOGISTIC;
57  StaticShape result(input_shape);
58  result.set_depth(no_);
59  result.set_loss_type(loss_type);
60  return result;
61 }
NetworkType type_
Definition: network.h:299

◆ RemapOutputs()

int tesseract::FullyConnected::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 90 of file fullyconnected.cpp.

90  {
91  if (type_ == NT_SOFTMAX && no_ == old_no) {
92  num_weights_ = weights_.RemapOutputs(code_map);
93  no_ = code_map.size();
94  }
95  return num_weights_;
96 }
int32_t num_weights_
Definition: network.h:305
NetworkType type_
Definition: network.h:299
int RemapOutputs(const std::vector< int > &code_map)

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile * fp) const
virtual

Reimplemented from tesseract::Network.

Definition at line 109 of file fullyconnected.cpp.

109  {
110  if (!Network::Serialize(fp)) return false;
111  if (!weights_.Serialize(IsTraining(), fp)) return false;
112  return true;
113 }
bool Serialize(bool training, TFile *fp) const
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
bool IsTraining() const
Definition: network.h:115

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented from tesseract::Network.

Definition at line 64 of file fullyconnected.cpp.

64  {
65  if (state == TS_RE_ENABLE) {
66  // Enable only from temp disabled.
67  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
68  } else if (state == TS_TEMP_DISABLE) {
69  // Temp disable only from enabled.
70  if (training_ == TS_ENABLED) training_ = state;
71  } else {
72  if (state == TS_ENABLED && training_ != TS_ENABLED)
73  weights_.InitBackward();
74  training_ = state;
75  }
76 }
TrainingState training_
Definition: network.h:300

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO & input,
const TransposedArray * input_transpose 
)

Definition at line 179 of file fullyconnected.cpp.

180  {
181  // Softmax output is always float, so save the input type.
182  int_mode_ = input.int_mode();
183  if (IsTraining()) {
184  acts_.Resize(input, no_);
185  // Source_ is a transposed copy of input. It isn't needed if provided.
186  external_source_ = input_transpose;
187  if (external_source_ == nullptr) source_t_.ResizeNoInit(ni_, input.Width());
188  }
189 }
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:88
const TransposedArray * external_source_
bool IsTraining() const
Definition: network.h:115

◆ spec()

virtual STRING tesseract::FullyConnected::spec ( ) const
inline virtual

Reimplemented from tesseract::Network.

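Definition at line 37 of file fullyconnected.h.

Note: as listed below, both NT_LOGISTIC and NT_SYMCLIP emit the prefix "Fs", so the returned spec string does not distinguish these two layer types.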

37  {
38  STRING spec;
39  if (type_ == NT_TANH)
40  spec.add_str_int("Ft", no_);
41  else if (type_ == NT_LOGISTIC)
42  spec.add_str_int("Fs", no_);
43  else if (type_ == NT_RELU)
44  spec.add_str_int("Fr", no_);
45  else if (type_ == NT_LINEAR)
46  spec.add_str_int("Fl", no_);
47  else if (type_ == NT_POSCLIP)
48  spec.add_str_int("Fp", no_);
49  else if (type_ == NT_SYMCLIP)
50  spec.add_str_int("Fs", no_);
51  else if (type_ == NT_SOFTMAX)
52  spec.add_str_int("Fc", no_);
53  else
54  spec.add_str_int("Fm", no_);
55  return spec;
56  }
Definition: strngs.h:45
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
NetworkType type_
Definition: network.h:299
virtual STRING spec() const

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 296 of file fullyconnected.cpp.

297  {
298  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
299 }
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 126 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 124 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 129 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 121 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 119 of file fullyconnected.h.


The documentation for this class was generated from the following files: