All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
hybrid_neural_net_classifier.cpp
Go to the documentation of this file.
1 /**********************************************************************
2  * File: hybrid_neural_net_classifier.cpp
3  * Description: Implementation of Hybrid-NeuralNet Character Classifier
4  * Author: Ahmad Abdulkader
5  * Created: 2007
6  *
7  * (C) Copyright 2008, Google Inc.
8  ** Licensed under the Apache License, Version 2.0 (the "License");
9  ** you may not use this file except in compliance with the License.
10  ** You may obtain a copy of the License at
11  ** http://www.apache.org/licenses/LICENSE-2.0
12  ** Unless required by applicable law or agreed to in writing, software
13  ** distributed under the License is distributed on an "AS IS" BASIS,
14  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  ** See the License for the specific language governing permissions and
16  ** limitations under the License.
17  *
18  **********************************************************************/
19 
20 #include <algorithm>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string>
24 #include <vector>
25 #include <wctype.h>
26 
27 #include "classifier_base.h"
28 #include "char_set.h"
29 #include "const.h"
30 #include "conv_net_classifier.h"
31 #include "cube_utils.h"
32 #include "feature_base.h"
33 #include "feature_bmp.h"
35 #include "tess_lang_model.h"
36 
37 namespace tesseract {
38 
40  CharSet *char_set,
41  TuningParams *params,
42  FeatureBase *feat_extract)
43  : CharClassifier(char_set, params, feat_extract) {
44  net_input_ = NULL;
45  net_output_ = NULL;
46 }
47 
49  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
50  if (nets_[net_idx] != NULL) {
51  delete nets_[net_idx];
52  }
53  }
54  nets_.clear();
55 
56  if (net_input_ != NULL) {
57  delete []net_input_;
58  net_input_ = NULL;
59  }
60 
61  if (net_output_ != NULL) {
62  delete []net_output_;
63  net_output_ = NULL;
64  }
65 }
66 
67 // The main training function. Given a sample and a class ID the classifier
68 // updates its parameters according to its learning algorithm. This function
69 // is currently not implemented. TODO(ahmadab): implement end-2-end training
70 bool HybridNeuralNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
71  return false;
72 }
73 
74 // A secondary function needed for training. Allows the trainer to set the
75 // value of any train-time paramter. This function is currently not
76 // implemented. TODO(ahmadab): implement end-2-end training
77 bool HybridNeuralNetCharClassifier::SetLearnParam(char *var_name, float val) {
78  // TODO(ahmadab): implementation of parameter initializing.
79  return false;
80 }
81 
82 // Folds the output of the NeuralNet using the loaded folding sets
83 void HybridNeuralNetCharClassifier::Fold() {
84  // in case insensitive mode
85  if (case_sensitive_ == false) {
86  int class_cnt = char_set_->ClassCount();
87  // fold case
88  for (int class_id = 0; class_id < class_cnt; class_id++) {
89  // get class string
90  const char_32 *str32 = char_set_->ClassString(class_id);
91  // get the upper case form of the string
92  string_32 upper_form32 = str32;
93  for (int ch = 0; ch < upper_form32.length(); ch++) {
94  if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
95  upper_form32[ch] = towupper(upper_form32[ch]);
96  }
97  }
98 
99  // find out the upperform class-id if any
100  int upper_class_id =
101  char_set_->ClassID(reinterpret_cast<const char_32 *>(
102  upper_form32.c_str()));
103  if (upper_class_id != -1 && class_id != upper_class_id) {
104  float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]);
105  net_output_[class_id] = max_out;
106  net_output_[upper_class_id] = max_out;
107  }
108  }
109  }
110 
111  // The folding sets specify how groups of classes should be folded
112  // Folding involved assigning a min-activation to all the members
113  // of the folding set. The min-activation is a fraction of the max-activation
114  // of the members of the folding set
115  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
116  float max_prob = net_output_[fold_sets_[fold_set][0]];
117 
118  for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
119  if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
120  max_prob = net_output_[fold_sets_[fold_set][ch]];
121  }
122  }
123  for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
124  net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
125  net_output_[fold_sets_[fold_set][ch]]);
126  }
127  }
128 }
129 
130 // compute the features of specified charsamp and
131 // feedforward the specified nets
132 bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
133  int feat_cnt = feat_extract_->FeatureCnt();
134  int class_cnt = char_set_->ClassCount();
135 
136  // allocate i/p and o/p buffers if needed
137  if (net_input_ == NULL) {
138  net_input_ = new float[feat_cnt];
139  if (net_input_ == NULL) {
140  return false;
141  }
142 
143  net_output_ = new float[class_cnt];
144  if (net_output_ == NULL) {
145  return false;
146  }
147  }
148 
149  // compute input features
150  if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
151  return false;
152  }
153 
154  // go thru all the nets
155  memset(net_output_, 0, class_cnt * sizeof(*net_output_));
156  float *inputs = net_input_;
157  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
158  // run each net
159  vector<float> net_out(class_cnt, 0.0);
160  if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
161  return false;
162  }
163  // add the output values
164  for (int class_idx = 0; class_idx < class_cnt; class_idx++) {
165  net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
166  }
167  // increment inputs pointer
168  inputs += nets_[net_idx]->in_cnt();
169  }
170 
171  Fold();
172 
173  return true;
174 }
175 
176 // return the cost of being a char
178  // it is by design that a character cost is equal to zero
179  // when no nets are present. This is the case during training.
180  if (RunNets(char_samp) == false) {
181  return 0;
182  }
183 
184  return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
185 }
186 
187 // classifies a charsamp and returns an alternate list
188 // of chars sorted by char costs
190  // run the needed nets
191  if (RunNets(char_samp) == false) {
192  return NULL;
193  }
194 
195  int class_cnt = char_set_->ClassCount();
196 
197  // create an altlist
198  CharAltList *alt_list = new CharAltList(char_set_, class_cnt);
199  if (alt_list == NULL) {
200  return NULL;
201  }
202 
203  for (int out = 1; out < class_cnt; out++) {
204  int cost = CubeUtils::Prob2Cost(net_output_[out]);
205  alt_list->Insert(out, cost);
206  }
207 
208  return alt_list;
209 }
210 
211 // set an external net (for training purposes)
213 }
214 
215 // Load folding sets
216 // This function returns true on success or if the file can't be read,
217 // returns false if an error is encountered.
218 bool HybridNeuralNetCharClassifier::LoadFoldingSets(
219  const string &data_file_path, const string &lang, LangModel *lang_mod) {
220  fold_set_cnt_ = 0;
221  string fold_file_name;
222  fold_file_name = data_file_path + lang;
223  fold_file_name += ".cube.fold";
224 
225  // folding sets are optional
226  FILE *fp = fopen(fold_file_name.c_str(), "rb");
227  if (fp == NULL) {
228  return true;
229  }
230  fclose(fp);
231 
232  string fold_sets_str;
233  if (!CubeUtils::ReadFileToString(fold_file_name,
234  &fold_sets_str)) {
235  return false;
236  }
237 
238  // split into lines
239  vector<string> str_vec;
240  CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
241  fold_set_cnt_ = str_vec.size();
242  fold_sets_ = new int *[fold_set_cnt_];
243  if (fold_sets_ == NULL) {
244  return false;
245  }
246  fold_set_len_ = new int[fold_set_cnt_];
247  if (fold_set_len_ == NULL) {
248  fold_set_cnt_ = 0;
249  return false;
250  }
251 
252  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
253  reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
254  &str_vec[fold_set]);
255 
256  // if all or all but one character are invalid, invalidate this set
257  if (str_vec[fold_set].length() <= 1) {
258  fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): "
259  "invalidating folding set %d\n", fold_set);
260  fold_set_len_[fold_set] = 0;
261  fold_sets_[fold_set] = NULL;
262  continue;
263  }
264 
265  string_32 str32;
266  CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
267  fold_set_len_[fold_set] = str32.length();
268  fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
269  if (fold_sets_[fold_set] == NULL) {
270  fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadFoldingSets): "
271  "could not allocate folding set\n");
272  fold_set_cnt_ = fold_set;
273  return false;
274  }
275  for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
276  fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
277  }
278  }
279  return true;
280 }
281 
282 // Init the classifier provided a data-path and a language string
283 bool HybridNeuralNetCharClassifier::Init(const string &data_file_path,
284  const string &lang,
285  LangModel *lang_mod) {
286  if (init_ == true) {
287  return true;
288  }
289 
290  // load the nets if any. This function will return true if the net file
291  // does not exist. But will fail if the net did not pass the sanity checks
292  if (!LoadNets(data_file_path, lang)) {
293  return false;
294  }
295 
296  // load the folding sets if any. This function will return true if the
297  // file does not exist. But will fail if the it did not pass the sanity checks
298  if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
299  return false;
300  }
301 
302  init_ = true;
303  return true;
304 }
305 
306 // Load the classifier's Neural Nets
307 // This function will return true if the net file does not exist.
308 // But will fail if the net did not pass the sanity checks
309 bool HybridNeuralNetCharClassifier::LoadNets(const string &data_file_path,
310  const string &lang) {
311  string hybrid_net_file;
312  string junk_net_file;
313 
314  // add the lang identifier
315  hybrid_net_file = data_file_path + lang;
316  hybrid_net_file += ".cube.hybrid";
317 
318  // neural network is optional
319  FILE *fp = fopen(hybrid_net_file.c_str(), "rb");
320  if (fp == NULL) {
321  return true;
322  }
323  fclose(fp);
324 
325  string str;
326  if (!CubeUtils::ReadFileToString(hybrid_net_file, &str)) {
327  return false;
328  }
329 
330  // split into lines
331  vector<string> str_vec;
332  CubeUtils::SplitStringUsing(str, "\r\n", &str_vec);
333  if (str_vec.size() <= 0) {
334  return false;
335  }
336 
337  // create and add the nets
338  nets_.resize(str_vec.size(), NULL);
339  net_wgts_.resize(str_vec.size(), 0);
340  int total_input_size = 0;
341  for (int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
342  // parse the string
343  vector<string> tokens_vec;
344  CubeUtils::SplitStringUsing(str_vec[net_idx], " \t", &tokens_vec);
345  // has to be 2 tokens, net name and input size
346  if (tokens_vec.size() != 2) {
347  return false;
348  }
349  // load the net
350  string net_file_name = data_file_path + tokens_vec[0];
351  nets_[net_idx] = tesseract::NeuralNet::FromFile(net_file_name);
352  if (nets_[net_idx] == NULL) {
353  return false;
354  }
355  // parse the input size and validate it
356  net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
357  if (net_wgts_[net_idx] < 0.0) {
358  return false;
359  }
360  total_input_size += nets_[net_idx]->in_cnt();
361  }
362  // validate total input count
363  if (total_input_size != feat_extract_->FeatureCnt()) {
364  return false;
365  }
366  // success
367  return true;
368 }
369 } // tesseract
#define MAX(x, y)
Definition: ndminx.h:24
bool Insert(int class_id, int cost, void *tag=NULL)
static int Prob2Cost(double prob_val)
Definition: cube_utils.cpp:37
virtual bool ComputeFeatures(CharSamp *char_samp, float *features)=0
basic_string< char_32 > string_32
Definition: string_32.h:41
static bool ReadFileToString(const string &file_name, string *str)
Definition: cube_utils.cpp:195
virtual bool Train(CharSamp *char_samp, int ClassID)
virtual CharAltList * Classify(CharSamp *char_samp)
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
Definition: cube_utils.cpp:266
int ClassID(const char_32 *str) const
Definition: char_set.h:54
HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
virtual bool SetLearnParam(char *var_name, float val)
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
Definition: cube_utils.cpp:230
signed int char_32
Definition: string_32.h:40
int ClassCount() const
Definition: char_set.h:111
#define NULL
Definition: host.h:144
virtual int FeatureCnt()=0
const char_32 * ClassString(int class_id) const
Definition: char_set.h:104
static NeuralNet * FromFile(const string file_name)
Definition: neural_net.cpp:204