tesseract  4.00.00dev
tesseract::IntSimdMatrix Class Reference

#include <intsimdmatrix.h>

Inheritance diagram for tesseract::IntSimdMatrix:
tesseract::IntSimdMatrixAVX2 tesseract::IntSimdMatrixSSE

Public Member Functions

 IntSimdMatrix ()
 
void Init (const GENERIC_2D_ARRAY< int8_t > &w)
 
int RoundInputs (int size) const
 
int RoundOutputs (int size) const
 
void MatrixDotVector (const GENERIC_2D_ARRAY< int8_t > &w, const GenericVector< double > &scales, const int8_t *u, double *v) const
 

Static Public Member Functions

static IntSimdMatrix * GetFastestMultiplier ()
 

Protected Types

typedef void(* PartialFunc) (const int8_t *w, const double *scales, const int8_t *u, int num_in, int num_out, double *v)
 

Static Protected Member Functions

static int Roundup (int input, int factor)
 

Protected Attributes

int num_outputs_per_register_
 
int max_output_registers_
 
int num_inputs_per_register_
 
int num_inputs_per_group_
 
int num_input_groups_
 
std::vector< int8_t > shaped_w_
 
std::vector< PartialFunc > partial_funcs_
 

Detailed Description

Definition at line 60 of file intsimdmatrix.h.

Member Typedef Documentation

◆ PartialFunc

typedef void(* tesseract::IntSimdMatrix::PartialFunc) (const int8_t *w, const double *scales, const int8_t *u, int num_in, int num_out, double *v)
protected

Definition at line 108 of file intsimdmatrix.h.

Constructor & Destructor Documentation

◆ IntSimdMatrix()

tesseract::IntSimdMatrix::IntSimdMatrix ( )
inline

Member Function Documentation

◆ GetFastestMultiplier()

IntSimdMatrix * tesseract::IntSimdMatrix::GetFastestMultiplier ( )
static

Definition at line 29 of file intsimdmatrix.cpp.

29  {
30  IntSimdMatrix* multiplier = nullptr;
31  if (SIMDDetect::IsAVX2Available()) {
32  multiplier = new IntSimdMatrixAVX2();
33  } else if (SIMDDetect::IsSSEAvailable()) {
34  multiplier = new IntSimdMatrixSSE();
35  } else {
36  // Default c++ implementation.
37  multiplier = new IntSimdMatrix();
38  }
39  return multiplier;
40 }
static bool IsSSEAvailable()
Definition: simddetect.h:38
static bool IsAVX2Available()
Definition: simddetect.h:28

◆ Init()

void tesseract::IntSimdMatrix::Init ( const GENERIC_2D_ARRAY< int8_t > &  w)

Definition at line 44 of file intsimdmatrix.cpp.

44  {
45  if (partial_funcs_.empty()) return;
46  int num_out = w.dim1();
47  int num_in = w.dim2() - 1;
48  // The rounded-up sizes of the reshaped weight matrix, excluding biases.
49  int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
50  int rounded_num_out = RoundOutputs(num_out);
51  // Add the bias and compute the required size.
52  shaped_w_.resize((rounded_num_in + 1) * rounded_num_out, 0);
53  int shaped_index = 0;
54  int output = 0;
55  // Each number of registers needs a different format! Iterates over the
56  // different numbers of registers (each a power of 2).
57  for (int num_registers = max_output_registers_; num_registers >= 1;
58  num_registers /= 2) {
59  // The number of outputs that we will generate with this many registers.
60  int num_outputs_per_register_set =
61  num_registers * num_outputs_per_register_;
62  // Use the max number of registers until we have to go fewer.
63  while (output + num_outputs_per_register_set <= rounded_num_out) {
64  // Accumulating outputs in registers saves iterating over the inputs, so
65  // we only have to do it once per output register set.
66  for (int input = 0; input < num_in; input += num_inputs_per_group_) {
67  // Iterate over the number of outputs in a register set.
68  for (int j = 0; j < num_outputs_per_register_set; ++j) {
69  // Inner-most loop corresponds to the number of inputs in an input
70  // group.
71  for (int i = 0; i < num_inputs_per_group_; ++i) {
72  int8_t weight = 0;
73  if (output + j < num_out && input + i < num_in)
74  weight = w(output + j, input + i);
75  shaped_w_[shaped_index++] = weight;
76  }
77  }
78  }
79  // Append the bias weights for the register set.
80  for (int j = 0; j < num_outputs_per_register_set; ++j) {
81  int8_t weight = 0;
82  if (output + j < num_out) weight = w(output + j, num_in);
83  shaped_w_[shaped_index++] = weight;
84  }
85  output += num_outputs_per_register_set;
86  }
87  }
88 }
std::vector< PartialFunc > partial_funcs_
int dim2() const
Definition: matrix.h:206
int RoundOutputs(int size) const
Definition: intsimdmatrix.h:84
std::vector< int8_t > shaped_w_
static int Roundup(int input, int factor)
int dim1() const
Definition: matrix.h:205

◆ MatrixDotVector()

void tesseract::IntSimdMatrix::MatrixDotVector ( const GENERIC_2D_ARRAY< int8_t > &  w,
const GenericVector< double > &  scales,
const int8_t *  u,
double *  v 
) const

Definition at line 94 of file intsimdmatrix.cpp.

96  {
97  int num_out = w.dim1();
98  int num_in = w.dim2() - 1;
99  if (partial_funcs_.empty()) {
100  // Base implementation.
101  for (int i = 0; i < num_out; ++i) {
102  const int8_t* wi = w[i];
103  int total = 0;
104  for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
105  // Add in the bias and correct for integer values.
106  v[i] = (static_cast<double>(total) / MAX_INT8 + wi[num_in]) * scales[i];
107  }
108  } else {
109  const int8_t* w_data = shaped_w_.data();
110  const double* scales_data = &scales[0];
111  // Each call to a partial_func_ produces group_size outputs, except the
112  // last one, which can produce less.
113  int group_size = num_outputs_per_register_ * max_output_registers_;
114  int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
115  int rounded_num_out = RoundOutputs(num_out);
116  int output = 0;
117  for (auto fn : partial_funcs_) {
118  // The amount of w_data consumed by each call to fn.
119  int w_step = (rounded_num_in + 1) * group_size;
120  // Run with this group size, until it would produce too much output, then
121  // switch to a smaller size.
122  for (; output + group_size <= rounded_num_out; output += group_size) {
123  (*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v);
124  w_data += w_step;
125  scales_data += group_size;
126  v += group_size;
127  }
128  group_size /= 2;
129  }
130  }
131 }
#define MAX_INT8
Definition: host.h:60
std::vector< PartialFunc > partial_funcs_
int dim2() const
Definition: matrix.h:206
int RoundOutputs(int size) const
Definition: intsimdmatrix.h:84
std::vector< int8_t > shaped_w_
static int Roundup(int input, int factor)
int dim1() const
Definition: matrix.h:205

◆ RoundInputs()

int tesseract::IntSimdMatrix::RoundInputs ( int  size) const
inline

Definition at line 80 of file intsimdmatrix.h.

80  {
81  return Roundup(size, num_inputs_per_register_);
82  }
static int Roundup(int input, int factor)

◆ RoundOutputs()

int tesseract::IntSimdMatrix::RoundOutputs ( int  size) const
inline

Definition at line 84 of file intsimdmatrix.h.

84  {
85  return Roundup(size, num_outputs_per_register_);
86  }
static int Roundup(int input, int factor)

◆ Roundup()

static int tesseract::IntSimdMatrix::Roundup ( int  input,
int  factor 
)
inline static protected

Definition at line 113 of file intsimdmatrix.h.

113  {
114  return (input + factor - 1) / factor * factor;
115  }

Member Data Documentation

◆ max_output_registers_

int tesseract::IntSimdMatrix::max_output_registers_
protected

Definition at line 120 of file intsimdmatrix.h.

◆ num_input_groups_

int tesseract::IntSimdMatrix::num_input_groups_
protected

Definition at line 126 of file intsimdmatrix.h.

◆ num_inputs_per_group_

int tesseract::IntSimdMatrix::num_inputs_per_group_
protected

Definition at line 124 of file intsimdmatrix.h.

◆ num_inputs_per_register_

int tesseract::IntSimdMatrix::num_inputs_per_register_
protected

Definition at line 122 of file intsimdmatrix.h.

◆ num_outputs_per_register_

int tesseract::IntSimdMatrix::num_outputs_per_register_
protected

Definition at line 118 of file intsimdmatrix.h.

◆ partial_funcs_

std::vector<PartialFunc> tesseract::IntSimdMatrix::partial_funcs_
protected

Definition at line 130 of file intsimdmatrix.h.

◆ shaped_w_

std::vector<int8_t> tesseract::IntSimdMatrix::shaped_w_
protected

Definition at line 128 of file intsimdmatrix.h.


The documentation for this class was generated from the following files:
intsimdmatrix.h
intsimdmatrix.cpp