tesseract  4.00.00dev
tesseract::WeightMatrix Class Reference

#include <weightmatrix.h>

Public Member Functions

 WeightMatrix ()
 
int InitWeightsFloat (int no, int ni, bool ada_grad, float weight_range, TRand *randomizer)
 
void ConvertToInt ()
 
bool is_int_mode () const
 
int NumOutputs () const
 
const double * GetWeights (int index) const
 
double GetDW (int i, int j) const
 
void InitBackward ()
 
bool Serialize (bool training, TFile *fp) const
 
bool DeSerialize (bool training, TFile *fp)
 
bool DeSerializeOld (bool training, TFile *fp)
 
void MatrixDotVector (const double *u, double *v) const
 
void MatrixDotVector (const inT8 *u, double *v) const
 
void MultiplyAccumulate (const double *v, double *inout)
 
void VectorDotMatrix (const double *u, double *v) const
 
void SumOuterTransposed (const TransposedArray &u, const TransposedArray &v, bool parallel)
 
void Update (double learning_rate, double momentum, int num_samples)
 
void AddDeltas (const WeightMatrix &other)
 
void CountAlternators (const WeightMatrix &other, double *same, double *changed) const
 
void Debug2D (const char *msg)
 

Static Public Member Functions

static double DotProduct (const double *u, const double *v, int n)
 
static void FloatToDouble (const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)
 

Detailed Description

Definition at line 63 of file weightmatrix.h.

Constructor & Destructor Documentation

◆ WeightMatrix()

tesseract::WeightMatrix::WeightMatrix ( )
inline

Definition at line 65 of file weightmatrix.h.

65 : int_mode_(false), use_ada_grad_(false) {}

Member Function Documentation

◆ AddDeltas()

void tesseract::WeightMatrix::AddDeltas ( const WeightMatrix &  other)

Definition at line 267 of file weightmatrix.cpp.

267  {
268  ASSERT_HOST(dw_.dim1() == other.dw_.dim1());
269  ASSERT_HOST(dw_.dim2() == other.dw_.dim2());
270  dw_ += other.dw_;
271 }
#define ASSERT_HOST(x)
Definition: errcode.h:84
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202

◆ ConvertToInt()

void tesseract::WeightMatrix::ConvertToInt ( )

Definition at line 62 of file weightmatrix.cpp.

62  {
63  wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
64  scales_.init_to_size(wi_.dim1(), 0.0);
65  int dim2 = wi_.dim2();
66  for (int t = 0; t < wi_.dim1(); ++t) {
67  double* f_line = wf_[t];
68  inT8* i_line = wi_[t];
69  double max_abs = 0.0;
70  for (int f = 0; f < dim2; ++f) {
71  double abs_val = fabs(f_line[f]);
72  if (abs_val > max_abs) max_abs = abs_val;
73  }
74  double scale = max_abs / MAX_INT8;
75  scales_[t] = scale;
76  if (scale == 0.0) scale = 1.0;
77  for (int f = 0; f < dim2; ++f) {
78  i_line[f] = IntCastRounded(f_line[f] / scale);
79  }
80  }
81  wf_.Resize(1, 1, 0.0);
82  int_mode_ = true;
83 }
void init_to_size(int size, T t)
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:98
int IntCastRounded(double x)
Definition: helpers.h:179
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
void ResizeNoInit(int size1, int size2)
Definition: matrix.h:86
int8_t inT8
Definition: host.h:34
#define MAX_INT8
Definition: host.h:60

◆ CountAlternators()

void tesseract::WeightMatrix::CountAlternators ( const WeightMatrix &  other,
double *  same,
double *  changed 
) const

Definition at line 276 of file weightmatrix.cpp.

277  {
278  int num_outputs = updates_.dim1();
279  int num_inputs = updates_.dim2();
280  ASSERT_HOST(num_outputs == other.updates_.dim1());
281  ASSERT_HOST(num_inputs == other.updates_.dim2());
282  for (int i = 0; i < num_outputs; ++i) {
283  const double* this_i = updates_[i];
284  const double* other_i = other.updates_[i];
285  for (int j = 0; j < num_inputs; ++j) {
286  double product = this_i[j] * other_i[j];
287  if (product < 0.0)
288  *changed -= product;
289  else
290  *same += product;
291  }
292  }
293 }
#define ASSERT_HOST(x)
Definition: errcode.h:84
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202

◆ Debug2D()

void tesseract::WeightMatrix::Debug2D ( const char *  msg)

Definition at line 307 of file weightmatrix.cpp.

307  {
308  STATS histogram(0, kHistogramBuckets);
309  if (int_mode_) {
310  for (int i = 0; i < wi_.dim1(); ++i) {
311  for (int j = 0; j < wi_.dim2(); ++j) {
312  HistogramWeight(wi_[i][j] * scales_[i], &histogram);
313  }
314  }
315  } else {
316  for (int i = 0; i < wf_.dim1(); ++i) {
317  for (int j = 0; j < wf_.dim2(); ++j) {
318  HistogramWeight(wf_[i][j], &histogram);
319  }
320  }
321  }
322  tprintf("%s\n", msg);
323  histogram.print();
324 }
const int kHistogramBuckets
#define tprintf(...)
Definition: tprintf.h:31
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
Definition: statistc.h:33

◆ DeSerialize()

bool tesseract::WeightMatrix::DeSerialize ( bool  training,
TFile *  fp 
)

Definition at line 125 of file weightmatrix.cpp.

125  {
126  uinT8 mode = 0;
127  if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
128  int_mode_ = (mode & kInt8Flag) != 0;
129  use_ada_grad_ = (mode & kAdaGradFlag) != 0;
130  if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, fp);
131  if (int_mode_) {
132  if (!wi_.DeSerialize(fp)) return false;
133  if (!scales_.DeSerialize(fp)) return false;
134  } else {
135  if (!wf_.DeSerialize(fp)) return false;
136  if (training) {
137  InitBackward();
138  if (!updates_.DeSerialize(fp)) return false;
139  if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
140  }
141  }
142  return true;
143 }
bool DeSerialize(bool swap, FILE *fp)
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:155
const char int mode
Definition: ioapi.h:38
const int kDoubleFlag
const int kInt8Flag
uint8_t uinT8
Definition: host.h:35
bool DeSerializeOld(bool training, TFile *fp)
const int kAdaGradFlag

◆ DeSerializeOld()

bool tesseract::WeightMatrix::DeSerializeOld ( bool  training,
TFile *  fp 
)

Definition at line 147 of file weightmatrix.cpp.

147  {
148  GENERIC_2D_ARRAY<float> float_array;
149  if (int_mode_) {
150  if (!wi_.DeSerialize(fp)) return false;
151  GenericVector<float> old_scales;
152  if (!old_scales.DeSerialize(fp)) return false;
153  scales_.resize_no_init(old_scales.size());
154  for (int i = 0; i < old_scales.size(); ++i) scales_[i] = old_scales[i];
155  } else {
156  if (!float_array.DeSerialize(fp)) return false;
157  FloatToDouble(float_array, &wf_);
158  }
159  if (training) {
160  InitBackward();
161  if (!float_array.DeSerialize(fp)) return false;
162  FloatToDouble(float_array, &updates_);
163  // Errs was only used in int training, which is now dead.
164  if (!float_array.DeSerialize(fp)) return false;
165  }
166  return true;
167 }
bool DeSerialize(bool swap, FILE *fp)
void resize_no_init(int size)
Definition: genericvector.h:66
int size() const
Definition: genericvector.h:72
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:155
static void FloatToDouble(const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)

◆ DotProduct()

double tesseract::WeightMatrix::DotProduct ( const double *  u,
const double *  v,
int  n 
)
static

Definition at line 328 of file weightmatrix.cpp.

328  {
329  // Note: because the order of addition is different among the 3 DotProduct
330  // functions, the results can (and do) vary slightly (although they agree
331  // to within about 4e-15). This produces different results when running
332  // training, despite all random inputs being precisely equal.
333  // To get consistent results, use just one of these DotProduct functions.
334  // On a test multi-layer network, serial is 57% slower than sse, and avx
335  // is about 8% faster than sse. This suggests that the time is memory
336  // bandwidth constrained and could benefit from holding the reused vector
337  // in AVX registers.
338 /*
339 omp simd code
340 real 4m17,294s
341 user 12m39,344s
342 sys 0m2,252s
343 
344 real 4m22,403s
345 user 12m53,408s
346 sys 0m2,116s
347 
348 old code
349 real 2m52,396s
350 user 7m42,624s
351 sys 0m2,008s
352 */
353 
354 #ifndef _OPENMP
355  if (SIMDDetect::IsAVXAvailable()) return DotProductAVX(u, v, n);
356  if (SIMDDetect::IsSSEAvailable()) return DotProductSSE(u, v, n);
357 #endif
358  double total = 0.0;
359 #ifdef _OPENMP
360 #pragma omp simd
361 #endif
362  for (int k = 0; k < n; ++k) total += u[k] * v[k];
363  return total;
364 }
double u[max]
double DotProductSSE(const double *u, const double *v, int n)
static bool IsAVXAvailable()
Definition: simddetect.h:26
double DotProductAVX(const double *u, const double *v, int n)
static bool IsSSEAvailable()
Definition: simddetect.h:28
double v[max]

◆ FloatToDouble()

void tesseract::WeightMatrix::FloatToDouble ( const GENERIC_2D_ARRAY< float > &  wf,
GENERIC_2D_ARRAY< double > *  wd 
)
static

Definition at line 369 of file weightmatrix.cpp.

370  {
371  int dim1 = wf.dim1();
372  int dim2 = wf.dim2();
373  wd->ResizeNoInit(dim1, dim2);
374  for (int i = 0; i < dim1; ++i) {
375  const float* wfi = wf[i];
376  double* wdi = (*wd)[i];
377  for (int j = 0; j < dim2; ++j) wdi[j] = static_cast<double>(wfi[j]);
378  }
379 }
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
void ResizeNoInit(int size1, int size2)
Definition: matrix.h:86

◆ GetDW()

double tesseract::WeightMatrix::GetDW ( int  i,
int  j 
) const
inline

Definition at line 91 of file weightmatrix.h.

91 { return dw_(i, j); }

◆ GetWeights()

const double* tesseract::WeightMatrix::GetWeights ( int  index) const
inline

Definition at line 89 of file weightmatrix.h.

89 { return wf_[index]; }

◆ InitBackward()

void tesseract::WeightMatrix::InitBackward ( )

Definition at line 87 of file weightmatrix.cpp.

87  {
88  int no = int_mode_ ? wi_.dim1() : wf_.dim1();
89  int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
90  dw_.Resize(no, ni, 0.0);
91  updates_.Resize(no, ni, 0.0);
92  wf_t_.Transpose(wf_);
93  if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
94 }
void Transpose(const GENERIC_2D_ARRAY< double > &input)
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:98
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202

◆ InitWeightsFloat()

int tesseract::WeightMatrix::InitWeightsFloat ( int  no,
int  ni,
bool  ada_grad,
float  weight_range,
TRand *  randomizer 
)

Definition at line 39 of file weightmatrix.cpp.

40  {
41  int_mode_ = false;
42  wf_.Resize(no, ni, 0.0);
43  if (randomizer != NULL) {
44  for (int i = 0; i < no; ++i) {
45  for (int j = 0; j < ni; ++j) {
46  wf_[i][j] = randomizer->SignedRand(weight_range);
47  }
48  }
49  }
50  use_ada_grad_ = ada_grad;
51  InitBackward();
52  return ni * no;
53 }
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:98

◆ is_int_mode()

bool tesseract::WeightMatrix::is_int_mode ( ) const
inline

Definition at line 84 of file weightmatrix.h.

84  {
85  return int_mode_;
86  }

◆ MatrixDotVector() [1/2]

void tesseract::WeightMatrix::MatrixDotVector ( const double *  u,
double *  v 
) const

Definition at line 174 of file weightmatrix.cpp.

174  {
175  ASSERT_HOST(!int_mode_);
176  MatrixDotVectorInternal(wf_, true, false, u, v);
177 }
double u[max]
#define ASSERT_HOST(x)
Definition: errcode.h:84
double v[max]

◆ MatrixDotVector() [2/2]

void tesseract::WeightMatrix::MatrixDotVector ( const inT8 *  u,
double *  v 
) const

Definition at line 179 of file weightmatrix.cpp.

179  {
180  ASSERT_HOST(int_mode_);
181  int num_out = wi_.dim1();
182  int num_in = wi_.dim2() - 1;
183  for (int i = 0; i < num_out; ++i) {
184  const inT8* Wi = wi_[i];
185  int total = 0;
186  if (SIMDDetect::IsSSEAvailable()) {
187  total = IntDotProductSSE(u, Wi, num_in);
188  } else {
189  for (int j = 0; j < num_in; ++j) total += Wi[j] * u[j];
190  }
191  // Add in the bias and correct for integer values.
192  v[i] = (static_cast<double>(total) / MAX_INT8 + Wi[num_in]) * scales_[i];
193  }
194 }
double u[max]
#define ASSERT_HOST(x)
Definition: errcode.h:84
int dim1() const
Definition: matrix.h:201
int32_t IntDotProductSSE(const int8_t *u, const int8_t *v, int n)
int dim2() const
Definition: matrix.h:202
int8_t inT8
Definition: host.h:34
#define MAX_INT8
Definition: host.h:60
static bool IsSSEAvailable()
Definition: simddetect.h:28
double v[max]

◆ MultiplyAccumulate()

void tesseract::WeightMatrix::MultiplyAccumulate ( const double *  v,
double *  inout 
)

Definition at line 198 of file weightmatrix.cpp.

198  {
199  ASSERT_HOST(!int_mode_);
200  ASSERT_HOST(wf_.dim1() == 1);
201  int n = wf_.dim2();
202  const double* u = wf_[0];
203  for (int i = 0; i < n; ++i) {
204  inout[i] += u[i] * v[i];
205  }
206 }
double u[max]
#define ASSERT_HOST(x)
Definition: errcode.h:84
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
double v[max]

◆ NumOutputs()

int tesseract::WeightMatrix::NumOutputs ( ) const
inline

Definition at line 87 of file weightmatrix.h.

87 { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
int dim1() const
Definition: matrix.h:201

◆ Serialize()

bool tesseract::WeightMatrix::Serialize ( bool  training,
TFile *  fp 
) const

Definition at line 106 of file weightmatrix.cpp.

106  {
107  // For backward compatibility, add kDoubleFlag to mode to indicate the doubles
108  // format, without errs, so we can detect and read old format weight matrices.
109  uinT8 mode = (int_mode_ ? kInt8Flag : 0) |
110  (use_ada_grad_ ? kAdaGradFlag : 0) | kDoubleFlag;
111  if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
112  if (int_mode_) {
113  if (!wi_.Serialize(fp)) return false;
114  if (!scales_.Serialize(fp)) return false;
115  } else {
116  if (!wf_.Serialize(fp)) return false;
117  if (training && !updates_.Serialize(fp)) return false;
118  if (training && use_ada_grad_ && !dw_sq_sum_.Serialize(fp)) return false;
119  }
120  return true;
121 }
bool Serialize(FILE *fp) const
Definition: matrix.h:137
const char int mode
Definition: ioapi.h:38
const int kDoubleFlag
const int kInt8Flag
uint8_t uinT8
Definition: host.h:35
bool Serialize(FILE *fp) const
const int kAdaGradFlag

◆ SumOuterTransposed()

void tesseract::WeightMatrix::SumOuterTransposed ( const TransposedArray &  u,
const TransposedArray &  v,
bool  parallel 
)

Definition at line 222 of file weightmatrix.cpp.

224  {
225  ASSERT_HOST(!int_mode_);
226  int num_outputs = dw_.dim1();
227  ASSERT_HOST(u.dim1() == num_outputs);
228  ASSERT_HOST(u.dim2() == v.dim2());
229  int num_inputs = dw_.dim2() - 1;
230  int num_samples = u.dim2();
231  // v is missing the last element in dim1.
232  ASSERT_HOST(v.dim1() == num_inputs);
233 #ifdef _OPENMP
234 #pragma omp parallel for num_threads(4) if (in_parallel)
235 #endif
236  for (int i = 0; i < num_outputs; ++i) {
237  double* dwi = dw_[i];
238  const double* ui = u[i];
239  for (int j = 0; j < num_inputs; ++j) {
240  dwi[j] = DotProduct(ui, v[j], num_samples);
241  }
242  // The last element of v is missing, presumed 1.0f.
243  double total = 0.0;
244  for (int k = 0; k < num_samples; ++k) total += ui[k];
245  dwi[num_inputs] = total;
246  }
247 }
double u[max]
#define ASSERT_HOST(x)
Definition: errcode.h:84
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
static double DotProduct(const double *u, const double *v, int n)
double v[max]

◆ Update()

void tesseract::WeightMatrix::Update ( double  learning_rate,
double  momentum,
int  num_samples 
)

Definition at line 252 of file weightmatrix.cpp.

253  {
254  ASSERT_HOST(!int_mode_);
255  if (use_ada_grad_ && num_samples > 0) {
256  dw_sq_sum_.SumSquares(dw_);
257  dw_.AdaGradScaling(dw_sq_sum_, num_samples);
258  }
259  dw_ *= learning_rate;
260  updates_ += dw_;
261  if (momentum > 0.0) wf_ += updates_;
262  if (momentum >= 0.0) updates_ *= momentum;
263  wf_t_.Transpose(wf_);
264 }
void AdaGradScaling(const GENERIC_2D_ARRAY< T > &sqsum, int num_samples)
Definition: matrix.h:372
void SumSquares(const GENERIC_2D_ARRAY< T > &src)
Definition: matrix.h:363
void Transpose(const GENERIC_2D_ARRAY< double > &input)
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ VectorDotMatrix()

void tesseract::WeightMatrix::VectorDotMatrix ( const double *  u,
double *  v 
) const

Definition at line 212 of file weightmatrix.cpp.

212  {
213  ASSERT_HOST(!int_mode_);
214  MatrixDotVectorInternal(wf_t_, false, true, u, v);
215 }
double u[max]
#define ASSERT_HOST(x)
Definition: errcode.h:84
double v[max]

The documentation for this class was generated from the following files: