All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
sampleiterator.cpp
Go to the documentation of this file.
1 // Copyright 2011 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
15 
16 #include "sampleiterator.h"
17 
18 #include "indexmapbidi.h"
19 #include "shapetable.h"
20 #include "trainingsample.h"
21 #include "trainingsampleset.h"
22 
23 namespace tesseract {
24 
25 // ================== SampleIterator Implementation =================
26 
28  : charset_map_(NULL),
29  shape_table_(NULL),
30  sample_set_(NULL),
31  randomize_(false),
32  owned_shape_table_(NULL) {
33  num_shapes_ = 0;
34  Begin();
35 }
36 
38  Clear();
39 }
40 
42  delete owned_shape_table_;
43  owned_shape_table_ = NULL;
44 }
45 
46 // See class comment for arguments.
47 void SampleIterator::Init(const IndexMapBiDi* charset_map,
48  const ShapeTable* shape_table,
49  bool randomize,
50  TrainingSampleSet* sample_set) {
51  Clear();
52  charset_map_ = charset_map;
53  shape_table_ = shape_table;
54  sample_set_ = sample_set;
55  randomize_ = randomize;
56  if (shape_table_ == NULL && charset_map_ != NULL) {
57  // The caller wishes to iterate by class. The easiest way to do this
58  // is to create a dummy shape_table_ that we will own.
59  int num_fonts = sample_set_->NumFonts();
60  owned_shape_table_ = new ShapeTable(sample_set_->unicharset());
61  int charsetsize = sample_set_->unicharset().size();
62  for (int c = 0; c < charsetsize; ++c) {
63  // We always add a shape for each character to keep the index in sync
64  // with the unichar_id.
65  int shape_id = owned_shape_table_->AddShape(c, 0);
66  for (int f = 1; f < num_fonts; ++f) {
67  if (sample_set_->NumClassSamples(f, c, true) > 0) {
68  owned_shape_table_->AddToShape(shape_id, c, f);
69  }
70  }
71  }
72  shape_table_ = owned_shape_table_;
73  }
74  if (shape_table_ != NULL) {
75  num_shapes_ = shape_table_->NumShapes();
76  } else {
77  num_shapes_ = randomize ? sample_set_->num_samples()
78  : sample_set_->num_raw_samples();
79  }
80  Begin();
81 }
82 
83 // Iterator functions designed for use with a simple for loop:
84 // for (it.Begin(); !it.AtEnd(); it.Next()) {
85 // const TrainingSample& sample = it.GetSample();
86 // }
88  shape_index_ = -1;
89  shape_char_index_ = 0;
90  num_shape_chars_ = 0;
91  shape_font_index_ = 0;
92  num_shape_fonts_ = 0;
93  sample_index_ = 0;
94  num_samples_ = 0;
95  // Find the first indexable sample.
96  Next();
97 }
98 
99 bool SampleIterator::AtEnd() const {
100  return shape_index_ >= num_shapes_;
101 }
102 
104  if (shape_table_ != NULL) {
105  const UnicharAndFonts* shape_entry = GetShapeEntry();
106  int char_id = shape_entry->unichar_id;
107  int font_id = shape_entry->font_ids[shape_font_index_];
108  return *sample_set_->GetSample(font_id, char_id, sample_index_);
109  } else {
110  return *sample_set_->GetSample(shape_index_);
111  }
112 }
113 
115  if (shape_table_ != NULL) {
116  const UnicharAndFonts* shape_entry = GetShapeEntry();
117  int char_id = shape_entry->unichar_id;
118  int font_id = shape_entry->font_ids[shape_font_index_];
119  return sample_set_->MutableSample(font_id, char_id, sample_index_);
120  } else {
121  return sample_set_->mutable_sample(shape_index_);
122  }
123 }
124 
125 // Returns the total index (from the original set of samples) of the current
126 // sample.
128  if (shape_table_ != NULL) {
129  const UnicharAndFonts* shape_entry = GetShapeEntry();
130  int char_id = shape_entry->unichar_id;
131  int font_id = shape_entry->font_ids[shape_font_index_];
132  return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_);
133  } else {
134  return shape_index_;
135  }
136 }
137 
138 // Returns the index of the current sample in compact charset space, so
139 // in a 2-class problem between x and y, the returned indices will all be
140 // 0 or 1, and have nothing to do with the unichar_ids.
141 // If the charset_map_ is NULL, then this is equal to GetSparseClassID().
143  return charset_map_ != NULL ? charset_map_->SparseToCompact(shape_index_)
144  : GetSparseClassID();
145 }
146 // Returns the index of the current sample in sparse charset space, so
147 // in a 2-class problem between x and y, the returned indices will all be
148 // x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids
149 // with a shape_table_.
151  return shape_table_ != NULL ? shape_index_ : GetSample().class_id();
152 }
153 
154 // Moves on to the next indexable sample. If the end is reached, leaves
155 // the state such that AtEnd() is true.
157  if (shape_table_ != NULL) {
158  // Next sample in this class/font combination.
159  ++sample_index_;
160  if (sample_index_ < num_samples_)
161  return;
162  // Next font in this class in this shape.
163  sample_index_ = 0;
164  do {
165  ++shape_font_index_;
166  if (shape_font_index_ >= num_shape_fonts_) {
167  // Next unichar in this shape.
168  shape_font_index_ = 0;
169  ++shape_char_index_;
170  if (shape_char_index_ >= num_shape_chars_) {
171  // Find the next shape that is mapped in the charset_map_.
172  shape_char_index_ = 0;
173  do {
174  ++shape_index_;
175  } while (shape_index_ < num_shapes_ &&
176  charset_map_ != NULL &&
177  charset_map_->SparseToCompact(shape_index_) < 0);
178  if (shape_index_ >= num_shapes_)
179  return; // The end.
180  num_shape_chars_ = shape_table_->GetShape(shape_index_).size();
181  }
182  }
183  const UnicharAndFonts* shape_entry = GetShapeEntry();
184  num_shape_fonts_ = shape_entry->font_ids.size();
185  int char_id = shape_entry->unichar_id;
186  int font_id = shape_entry->font_ids[shape_font_index_];
187  num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_);
188  } while (num_samples_ == 0);
189  } else {
190  // We are just iterating over the samples.
191  ++shape_index_;
192  }
193 }
194 
195 // Returns the size of the compact charset space.
197  return charset_map_ != NULL ? charset_map_->CompactSize()
198  : SparseCharsetSize();
199 }
200 
201 // Returns the size of the sparse charset space.
203  return charset_map_ != NULL
204  ? charset_map_->SparseSize()
205  : (shape_table_ != NULL ? shape_table_->NumShapes()
206  : sample_set_->charsetsize());
207 }
208 
209 // Apply the supplied feature_space/feature_map transform to all samples
210 // accessed by this iterator.
212  for (Begin(); !AtEnd(); Next()) {
214  sample->MapFeatures(feature_map);
215  }
216 }
217 
218 // Adjust the weights of all the samples to be uniform in the given charset.
219 // Returns the number of samples in the iterator.
221  int num_good_samples = 0;
222  for (Begin(); !AtEnd(); Next()) {
224  sample->set_weight(1.0);
225  ++num_good_samples;
226  }
228  return num_good_samples;
229 }
230 
231 // Normalize the weights of all the samples in the charset_map so they sum
232 // to 1. Returns the minimum assigned sample weight.
234  double total_weight = 0.0;
235  int sample_count = 0;
236  for (Begin(); !AtEnd(); Next()) {
237  const TrainingSample& sample = GetSample();
238  total_weight += sample.weight();
239  ++sample_count;
240  }
241  // Normalize samples.
242  double min_assigned_sample_weight = 1.0;
243  if (total_weight > 0.0) {
244  for (Begin(); !AtEnd(); Next()) {
246  double weight = sample->weight() / total_weight;
247  if (weight < min_assigned_sample_weight)
248  min_assigned_sample_weight = weight;
249  sample->set_weight(weight);
250  }
251  }
252  return min_assigned_sample_weight;
253 }
254 
255 // Helper returns the current UnicharAndFont shape_entry.
256 const UnicharAndFonts* SampleIterator::GetShapeEntry() const {
257  const Shape& shape = shape_table_->GetShape(shape_index_);
258  return &shape[shape_char_index_];
259 }
260 
261 } // namespace tesseract.
262 
int size() const
Definition: shapetable.h:202
int size() const
Definition: genericvector.h:72
virtual int SparseSize() const
Definition: indexmapbidi.h:142
void Init(const IndexMapBiDi *charset_map, const ShapeTable *shape_table, bool randomize, TrainingSampleSet *sample_set)
const TrainingSample * GetSample(int index) const
TrainingSample * MutableSample() const
void MapFeatures(const IntFeatureMap &feature_map)
TrainingSample * MutableSample(int font_id, int class_id, int index)
const ShapeTable * shape_table() const
const IndexMapBiDi & charset_map() const
int GlobalSampleIndex(int font_id, int class_id, int index) const
virtual int SparseToCompact(int sparse_index) const
Definition: indexmapbidi.h:138
int NumClassSamples(int font_id, int class_id, bool randomize) const
GenericVector< inT32 > font_ids
Definition: shapetable.h:176
const TrainingSample & GetSample() const
UNICHAR_ID class_id() const
const UNICHARSET & unicharset() const
int CompactSize() const
Definition: indexmapbidi.h:61
Definition: cluster.h:32
const TrainingSampleSet * sample_set() const
#define NULL
Definition: host.h:144
void MapSampleFeatures(const IntFeatureMap &feature_map)
int AddShape(int unichar_id, int font_id)
Definition: shapetable.cpp:346
TrainingSample * mutable_sample(int index)
void set_weight(double value)
int size() const
Definition: unicharset.h:297
const Shape & GetShape(int shape_id) const
Definition: shapetable.h:323
int NumShapes() const
Definition: shapetable.h:278
void AddToShape(int shape_id, int unichar_id, int font_id)
Definition: shapetable.cpp:379