All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
neural_net.h
Go to the documentation of this file.
1 // Copyright 2008 Google Inc.
2 // All Rights Reserved.
3 // Author: ahmadab@google.com (Ahmad Abdulkader)
4 //
5 // neural_net.h: Declarations of a class for an object that
6 // represents an arbitrary network of neurons
7 //
8 
9 #ifndef NEURAL_NET_H
10 #define NEURAL_NET_H
11 
12 #include <string>
13 #include <vector>
14 #include "neuron.h"
15 #include "input_file_buffer.h"
16 
17 namespace tesseract {
18 
// Threshold on an input's range: when the range falls below this value, the
// corresponding input weight is set to zero.
static const float kMinInputRange = 0.000001f;
21 
22 class NeuralNet {
23  public:
24  NeuralNet();
25  virtual ~NeuralNet();
26  // create a net object from a file. Uses stdio
27  static NeuralNet *FromFile(const string file_name);
28  // create a net object from an input buffer
30  // Different flavors of feed forward function
31  template <typename Type> bool FeedForward(const Type *inputs,
32  Type *outputs);
33  // Compute the output of a specific output node.
34  // This function is useful for application that are interested in a single
35  // output of the net and do not want to waste time on the rest
36  template <typename Type> bool GetNetOutput(const Type *inputs,
37  int output_id,
38  Type *output);
39  // Accessor functions
40  int in_cnt() const { return in_cnt_; }
41  int out_cnt() const { return out_cnt_; }
42 
43  protected:
44  struct Node;
45  // A node-weight pair
46  struct WeightedNode {
48  float input_weight;
49  };
50  // node struct used for fast feedforward in
51  // Read only nets
52  struct Node {
53  float out;
54  float bias;
57  };
58  // Read-Only flag (no training: On by default)
59  // will presumeably be set to false by
60  // the inherting TrainableNeuralNet class
61  bool read_only_;
62  // input count
63  int in_cnt_;
64  // output count
65  int out_cnt_;
66  // Total neuron count (including inputs)
68  // count of unique weights
69  int wts_cnt_;
70  // Neuron vector
72  // size of allocated weight chunk (in weights)
73  // This is basically the size of the biggest network
74  // that I have trained. However, the class will allow
75  // a bigger sized net if desired
76  static const int kWgtChunkSize = 0x10000;
77  // Magic number expected at the beginning of the NN
78  // binary file
79  static const unsigned int kNetSignature = 0xFEFEABD0;
80  // count of allocated wgts in the last chunk
82  // vector of weights buffers
83  vector<vector<float> *>wts_vec_;
84  // Is the net an auto-encoder type
86  // vector of input max values
87  vector<float> inputs_max_;
88  // vector of input min values
89  vector<float> inputs_min_;
90  // vector of input mean values
91  vector<float> inputs_mean_;
92  // vector of input standard deviation values
93  vector<float> inputs_std_dev_;
94  // vector of input offsets used by fast read-only
95  // feedforward function
96  vector<Node> fast_nodes_;
97  // Network Initialization function
98  void Init();
99  // Clears all neurons
100  void Clear() {
101  for (int node = 0; node < neuron_cnt_; node++) {
102  neurons_[node].Clear();
103  }
104  }
105  // Reads the net from an input buffer
106  template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) {
107  // Init vars
108  Init();
109  // is this an autoencoder
110  unsigned int read_val;
111  unsigned int auto_encode;
112  // read and verify signature
113  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
114  return false;
115  }
116  if (read_val != kNetSignature) {
117  return false;
118  }
119  if (input_buff->Read(&auto_encode, sizeof(auto_encode)) !=
120  sizeof(auto_encode)) {
121  return false;
122  }
123  auto_encoder_ = auto_encode;
124  // read and validate total # of nodes
125  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
126  return false;
127  }
128  neuron_cnt_ = read_val;
129  if (neuron_cnt_ <= 0) {
130  return false;
131  }
132  // set the size of the neurons vector
133  neurons_ = new Neuron[neuron_cnt_];
134  if (neurons_ == NULL) {
135  return false;
136  }
137  // read & validate inputs
138  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
139  return false;
140  }
141  in_cnt_ = read_val;
142  if (in_cnt_ <= 0) {
143  return false;
144  }
145  // read outputs
146  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
147  return false;
148  }
149  out_cnt_ = read_val;
150  if (out_cnt_ <= 0) {
151  return false;
152  }
153  // set neuron ids and types
154  for (int idx = 0; idx < neuron_cnt_; idx++) {
155  neurons_[idx].set_id(idx);
156  // input type
157  if (idx < in_cnt_) {
158  neurons_[idx].set_node_type(Neuron::Input);
159  } else if (idx >= (neuron_cnt_ - out_cnt_)) {
160  neurons_[idx].set_node_type(Neuron::Output);
161  } else {
162  neurons_[idx].set_node_type(Neuron::Hidden);
163  }
164  }
165  // read the connections
166  for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
167  // read fanout
168  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
169  return false;
170  }
171  // read the neuron's info
172  int fan_out_cnt = read_val;
173  for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) {
174  // read the neuron id
175  if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) {
176  return false;
177  }
178  // create the connection
179  if (!SetConnection(node_idx, read_val)) {
180  return false;
181  }
182  }
183  }
184  // read all the neurons' fan-in connections
185  for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) {
186  // read
187  if (!neurons_[node_idx].ReadBinary(input_buff)) {
188  return false;
189  }
190  }
191  // size input stats vector to expected input size
192  inputs_mean_.resize(in_cnt_);
193  inputs_std_dev_.resize(in_cnt_);
194  inputs_min_.resize(in_cnt_);
195  inputs_max_.resize(in_cnt_);
196  // read stats
197  if (input_buff->Read(&(inputs_mean_.front()),
198  sizeof(inputs_mean_[0]) * in_cnt_) !=
199  sizeof(inputs_mean_[0]) * in_cnt_) {
200  return false;
201  }
202  if (input_buff->Read(&(inputs_std_dev_.front()),
203  sizeof(inputs_std_dev_[0]) * in_cnt_) !=
204  sizeof(inputs_std_dev_[0]) * in_cnt_) {
205  return false;
206  }
207  if (input_buff->Read(&(inputs_min_.front()),
208  sizeof(inputs_min_[0]) * in_cnt_) !=
209  sizeof(inputs_min_[0]) * in_cnt_) {
210  return false;
211  }
212  if (input_buff->Read(&(inputs_max_.front()),
213  sizeof(inputs_max_[0]) * in_cnt_) !=
214  sizeof(inputs_max_[0]) * in_cnt_) {
215  return false;
216  }
217  // create a readonly version for fast feedforward
218  if (read_only_) {
219  return CreateFastNet();
220  }
221  return true;
222  }
223 
224  // creates a connection between two nodes
225  bool SetConnection(int from, int to);
226  // Create a read only version of the net that
227  // has faster feedforward performance
228  bool CreateFastNet();
229  // internal function to allocate a new set of weights
230  // Centralized weight allocation attempts to increase
231  // weights locality of reference making it more cache friendly
232  float *AllocWgt(int wgt_cnt);
233  // different flavors read-only feedforward function
234  template <typename Type> bool FastFeedForward(const Type *inputs,
235  Type *outputs);
236  // Compute the output of a specific output node.
237  // This function is useful for application that are interested in a single
238  // output of the net and do not want to waste time on the rest
239  // This is the fast-read-only version of this function
240  template <typename Type> bool FastGetNetOutput(const Type *inputs,
241  int output_id,
242  Type *output);
243 };
244 }
245 
246 #endif // NEURAL_NET_H__
static NeuralNet * FromInputBuffer(InputFileBuffer *ib)
Definition: neural_net.cpp:213
static const int kWgtChunkSize
Definition: neural_net.h:76
int in_cnt() const
Definition: neural_net.h:40
vector< float > inputs_min_
Definition: neural_net.h:89
float * AllocWgt(int wgt_cnt)
Definition: neural_net.cpp:189
bool ReadBinary(ReadBuffType *input_buff)
Definition: neural_net.h:106
bool FastFeedForward(const Type *inputs, Type *outputs)
Definition: neural_net.cpp:52
bool FastGetNetOutput(const Type *inputs, int output_id, Type *output)
Definition: neural_net.cpp:231
bool GetNetOutput(const Type *inputs, int output_id, Type *output)
Definition: neural_net.cpp:265
vector< float > inputs_mean_
Definition: neural_net.h:91
bool FeedForward(const Type *inputs, Type *outputs)
Definition: neural_net.cpp:79
vector< float > inputs_std_dev_
Definition: neural_net.h:93
vector< Node > fast_nodes_
Definition: neural_net.h:96
int out_cnt() const
Definition: neural_net.h:41
static const unsigned int kNetSignature
Definition: neural_net.h:79
vector< float > inputs_max_
Definition: neural_net.h:87
void Clear()
Definition: neuron.h:37
void set_id(int id)
Definition: neuron.h:111
void set_node_type(NeuronTypes type)
Definition: neuron.cpp:62
#define NULL
Definition: host.h:144
vector< vector< float > * > wts_vec_
Definition: neural_net.h:83
bool SetConnection(int from, int to)
Definition: neural_net.cpp:112
WeightedNode * inputs
Definition: neural_net.h:56
static NeuralNet * FromFile(const string file_name)
Definition: neural_net.cpp:204