tesseract  4.00.00dev
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

 FullyConnected (const STRING &name, int ni, int no, NetworkType type)
 
virtual ~FullyConnected ()
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
virtual STRING spec () const
 
void ChangeType (NetworkType type)
 
virtual void SetEnableTraining (TrainingState state)
 
virtual int InitWeights (float range, TRand *randomizer)
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
virtual void ConvertToInt ()
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (const double *d_input, const inT8 *i_input, int t, double *output_line)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uinT32 flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
inT32 network_flags_
 
inT32 ni_
 
inT32 no_
 
inT32 num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file fullyconnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const STRING & name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 39 of file fullyconnected.cpp.

41  : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
42 }
NetworkType type() const
Definition: network.h:112
const TransposedArray * external_source_

◆ ~FullyConnected()

tesseract::FullyConnected::~FullyConnected ( )
virtual

Definition at line 44 of file fullyconnected.cpp.

44  {
45 }

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 219 of file fullyconnected.cpp.

221  {
222  if (debug) DisplayBackward(fwd_deltas);
223  back_deltas->Resize(fwd_deltas, ni_);
225  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
226  for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
228  if (needs_to_backprop_) {
229  temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
230  for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
231  }
232  int width = fwd_deltas.Width();
233  NetworkScratch::GradientStore errors_t;
234  errors_t.Init(no_, width, scratch);
235 #ifdef _OPENMP
236 #pragma omp parallel for num_threads(kNumThreads)
237  for (int t = 0; t < width; ++t) {
238  int thread_id = omp_get_thread_num();
239 #else
240  for (int t = 0; t < width; ++t) {
241  int thread_id = 0;
242 #endif
243  double* backprop = NULL;
244  if (needs_to_backprop_) backprop = temp_backprops[thread_id];
245  double* curr_errors = errors[thread_id];
246  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
247  if (backprop != NULL) {
248  back_deltas->WriteTimeStep(t, backprop);
249  }
250  }
251  FinishBackward(*errors_t.get());
252  if (needs_to_backprop_) {
253  back_deltas->ZeroInvalidElements();
254 #if DEBUG_DETAIL > 0
255  tprintf("F Backprop:%s\n", name_.string());
256  back_deltas->Print(10);
257 #endif
258  return true;
259  }
260  return false; // No point going further back.
261 }
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
const int kNumThreads
void FinishBackward(const TransposedArray &errors_t)
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
bool needs_to_backprop_
Definition: network.h:301
void init_to_size(int size, T t)

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO & fwd_deltas,
int  t,
double *  curr_errors,
TransposedArray * errors_t,
double *  backprop 
)

Definition at line 263 of file fullyconnected.cpp.

266  {
267  if (type_ == NT_TANH)
268  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
269  else if (type_ == NT_LOGISTIC)
270  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
271  else if (type_ == NT_POSCLIP)
272  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
273  else if (type_ == NT_SYMCLIP)
274  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
275  else if (type_ == NT_RELU)
276  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
277  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
278  type_ == NT_LINEAR)
279  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
280  else
281  ASSERT_HOST("Invalid fully-connected type!" == NULL);
282  // Generate backprop only if needed by the lower layer.
283  if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
284  errors_t->WriteStrided(t, curr_errors);
285 }
NetworkType type_
Definition: network.h:299
void VectorDotMatrix(const double *u, double *v) const
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 60 of file fullyconnected.h.

60  {
61  type_ = type;
62  }
NetworkType type_
Definition: network.h:299
NetworkType type() const
Definition: network.h:112

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 99 of file fullyconnected.cpp.

99  {
101 }

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network & other,
double *  same,
double *  changed 
) const
virtual

Reimplemented from tesseract::Network.

Definition at line 304 of file fullyconnected.cpp.

305  {
306  ASSERT_HOST(other.type() == type_);
307  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
308  weights_.CountAlternators(fc->weights_, same, changed);
309 }
NetworkType type_
Definition: network.h:299
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
#define ASSERT_HOST(x)
Definition: errcode.h:84
FullyConnected(const STRING &name, int ni, int no, NetworkType type)

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 104 of file fullyconnected.cpp.

104  {
106 }
const char * string() const
Definition: strngs.cpp:198
void Debug2D(const char *msg)

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile * fp)
virtual

Reimplemented from tesseract::Network.

Definition at line 116 of file fullyconnected.cpp.

116  {
117  return weights_.DeSerialize(IsTraining(), fp);
118 }
bool IsTraining() const
Definition: network.h:115
bool DeSerialize(bool training, TFile *fp)

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray & errors_t)

Definition at line 287 of file fullyconnected.cpp.

287  {
288  if (external_source_ == NULL)
289  weights_.SumOuterTransposed(errors_t, source_t_, true);
290  else
291  weights_.SumOuterTransposed(errors_t, *external_source_, true);
292 }
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
const TransposedArray * external_source_

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 122 of file fullyconnected.cpp.

124  {
125  int width = input.Width();
126  if (type_ == NT_SOFTMAX)
127  output->ResizeFloat(input, no_);
128  else
129  output->Resize(input, no_);
130  SetupForward(input, input_transpose);
132  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
134  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
135  for (int i = 0; i < kNumThreads; ++i) {
136  temp_lines[i].Init(no_, scratch);
137  curr_input[i].Init(ni_, scratch);
138  }
139 #ifdef _OPENMP
140 #pragma omp parallel for num_threads(kNumThreads)
141  for (int t = 0; t < width; ++t) {
142  // Thread-local pointer to temporary storage.
143  int thread_id = omp_get_thread_num();
144 #else
145  for (int t = 0; t < width; ++t) {
146  // Thread-local pointer to temporary storage.
147  int thread_id = 0;
148 #endif
149  double* temp_line = temp_lines[thread_id];
150  const double* d_input = NULL;
151  const inT8* i_input = NULL;
152  if (input.int_mode()) {
153  i_input = input.i(t);
154  } else {
155  input.ReadTimeStep(t, curr_input[thread_id]);
156  d_input = curr_input[thread_id];
157  }
158  ForwardTimeStep(d_input, i_input, t, temp_line);
159  output->WriteTimeStep(t, temp_line);
160  if (IsTraining() && type_ != NT_SOFTMAX) {
161  acts_.CopyTimeStepFrom(t, *output, t);
162  }
163  }
164  // Zero all the elements that are in the padding around images that allows
165  // multiple different-sized images to exist in a single array.
166  // acts_ is only used if this is not a softmax op.
167  if (IsTraining() && type_ != NT_SOFTMAX) {
169  }
170  output->ZeroInvalidElements();
171 #if DEBUG_DETAIL > 0
172  tprintf("F Output:%s\n", name_.string());
173  output->Print(10);
174 #endif
175  if (debug) DisplayForward(*output);
176 }
void ZeroInvalidElements()
Definition: networkio.cpp:93
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)
NetworkType type_
Definition: network.h:299
const int kNumThreads
#define tprintf(...)
Definition: tprintf.h:31
bool IsTraining() const
Definition: network.h:115
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:388
int8_t inT8
Definition: host.h:34
const char * string() const
Definition: strngs.cpp:198
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void init_to_size(int size, T t)

◆ ForwardTimeStep()

void tesseract::FullyConnected::ForwardTimeStep ( const double *  d_input,
const inT8 * i_input,
int  t,
double *  output_line 
)

Definition at line 191 of file fullyconnected.cpp.

192  {
193  // input is copied to source_ line-by-line for cache coherency.
194  if (IsTraining() && external_source_ == NULL && d_input != NULL)
195  source_t_.WriteStrided(t, d_input);
196  if (d_input != NULL)
197  weights_.MatrixDotVector(d_input, output_line);
198  else
199  weights_.MatrixDotVector(i_input, output_line);
200  if (type_ == NT_TANH) {
201  FuncInplace<GFunc>(no_, output_line);
202  } else if (type_ == NT_LOGISTIC) {
203  FuncInplace<FFunc>(no_, output_line);
204  } else if (type_ == NT_POSCLIP) {
205  FuncInplace<ClipFFunc>(no_, output_line);
206  } else if (type_ == NT_SYMCLIP) {
207  FuncInplace<ClipGFunc>(no_, output_line);
208  } else if (type_ == NT_RELU) {
209  FuncInplace<Relu>(no_, output_line);
210  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
211  SoftmaxInPlace(no_, output_line);
212  } else if (type_ != NT_LINEAR) {
213  ASSERT_HOST("Invalid fully-connected type!" == NULL);
214  }
215 }
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:163
NetworkType type_
Definition: network.h:299
void MatrixDotVector(const double *u, double *v) const
bool IsTraining() const
Definition: network.h:115
#define ASSERT_HOST(x)
Definition: errcode.h:84
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
const TransposedArray * external_source_

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand * randomizer 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 80 of file fullyconnected.cpp.

80  {
81  Network::SetRandomizer(randomizer);
83  range, randomizer);
84  return num_weights_;
85 }
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
inT32 num_weights_
Definition: network.h:305

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape & input_shape) const
virtual

Reimplemented from tesseract::Network.

Definition at line 49 of file fullyconnected.cpp.

49  {
50  LossType loss_type = LT_NONE;
51  if (type_ == NT_SOFTMAX)
52  loss_type = LT_CTC;
53  else if (type_ == NT_SOFTMAX_NO_CTC)
54  loss_type = LT_SOFTMAX;
55  else if (type_ == NT_LOGISTIC)
56  loss_type = LT_LOGISTIC;
57  StaticShape result(input_shape);
58  result.set_depth(no_);
59  result.set_loss_type(loss_type);
60  return result;
61 }
NetworkType type_
Definition: network.h:299

◆ RemapOutputs()

int tesseract::FullyConnected::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 90 of file fullyconnected.cpp.

90  {
91  if (type_ == NT_SOFTMAX && no_ == old_no) {
93  no_ = code_map.size();
94  }
95  return num_weights_;
96 }
int RemapOutputs(const std::vector< int > &code_map)
NetworkType type_
Definition: network.h:299
inT32 num_weights_
Definition: network.h:305

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile * fp) const
virtual

Reimplemented from tesseract::Network.

Definition at line 109 of file fullyconnected.cpp.

109  {
110  if (!Network::Serialize(fp)) return false;
111  if (!weights_.Serialize(IsTraining(), fp)) return false;
112  return true;
113 }
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
bool IsTraining() const
Definition: network.h:115
bool Serialize(bool training, TFile *fp) const

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented from tesseract::Network.

Definition at line 64 of file fullyconnected.cpp.

64  {
65  if (state == TS_RE_ENABLE) {
66  // Enable only from temp disabled.
68  } else if (state == TS_TEMP_DISABLE) {
69  // Temp disable only from enabled.
70  if (training_ == TS_ENABLED) training_ = state;
71  } else {
72  if (state == TS_ENABLED && training_ != TS_ENABLED)
74  training_ = state;
75  }
76 }
TrainingState training_
Definition: network.h:300

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO & input,
const TransposedArray * input_transpose 
)

Definition at line 179 of file fullyconnected.cpp.

180  {
181  // Softmax output is always float, so save the input type.
182  int_mode_ = input.int_mode();
183  if (IsTraining()) {
184  acts_.Resize(input, no_);
185  // Source_ is a transposed copy of input. It isn't needed if provided.
186  external_source_ = input_transpose;
187  if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
188  }
189 }
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:88
bool IsTraining() const
Definition: network.h:115
const TransposedArray * external_source_

◆ spec()

virtual STRING tesseract::FullyConnected::spec ( ) const
inline virtual

Reimplemented from tesseract::Network.

Definition at line 37 of file fullyconnected.h.

37  {
38  STRING spec;
39  if (type_ == NT_TANH)
40  spec.add_str_int("Ft", no_);
41  else if (type_ == NT_LOGISTIC)
42  spec.add_str_int("Fs", no_);
43  else if (type_ == NT_RELU)
44  spec.add_str_int("Fr", no_);
45  else if (type_ == NT_LINEAR)
46  spec.add_str_int("Fl", no_);
47  else if (type_ == NT_POSCLIP)
48  spec.add_str_int("Fp", no_);
49  else if (type_ == NT_SYMCLIP)
50  spec.add_str_int("Fs", no_);
51  else if (type_ == NT_SOFTMAX)
52  spec.add_str_int("Fc", no_);
53  else
54  spec.add_str_int("Fm", no_);
55  return spec;
56  }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
NetworkType type_
Definition: network.h:299
Definition: strngs.h:45
virtual STRING spec() const

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 296 of file fullyconnected.cpp.

297  {
298  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
299 }
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 126 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 124 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 129 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 121 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 119 of file fullyconnected.h.


The documentation for this class was generated from the following files: