46 int num_out = w.
dim1();
47 int num_in = w.
dim2() - 1;
52 shaped_w_.resize((rounded_num_in + 1) * rounded_num_out, 0);
60 int num_outputs_per_register_set =
63 while (output + num_outputs_per_register_set <= rounded_num_out) {
68 for (
int j = 0; j < num_outputs_per_register_set; ++j) {
73 if (output + j < num_out && input + i < num_in)
74 weight = w(output + j, input + i);
80 for (
int j = 0; j < num_outputs_per_register_set; ++j) {
82 if (output + j < num_out) weight = w(output + j, num_in);
85 output += num_outputs_per_register_set;
96 const int8_t* u,
double* v)
const {
97 int num_out = w.
dim1();
98 int num_in = w.
dim2() - 1;
101 for (
int i = 0; i < num_out; ++i) {
102 const int8_t* wi = w[i];
104 for (
int j = 0; j < num_in; ++j) total += wi[j] * u[j];
106 v[i] = (
static_cast<double>(total) /
MAX_INT8 + wi[num_in]) * scales[i];
110 const double* scales_data = &scales[0];
119 int w_step = (rounded_num_in + 1) * group_size;
122 for (; output + group_size <= rounded_num_out; output += group_size) {
123 (*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v);
125 scales_data += group_size;
static bool IsSSEAvailable()
int num_outputs_per_register_
void MatrixDotVector(const GENERIC_2D_ARRAY< int8_t > &w, const GenericVector< double > &scales, const int8_t *u, double *v) const
std::vector< PartialFunc > partial_funcs_
void Init(const GENERIC_2D_ARRAY< int8_t > &w)
int RoundOutputs(int size) const
std::vector< int8_t > shaped_w_
int max_output_registers_
static bool IsAVX2Available()
static int Roundup(int input, int factor)
static IntSimdMatrix * GetFastestMultiplier()
int num_inputs_per_group_