28 int32_t &rounded_num_out)
const {
29 const int num_out = w.
dim1();
30 const int num_in = w.
dim2() - 1;
35 shaped_w.resize((rounded_num_in + 1) * rounded_num_out, 0);
44 while (
output + num_outputs_per_register_set <= rounded_num_out) {
49 for (
int j = 0; j < num_outputs_per_register_set; ++j) {
54 if (
output + j < num_out && input +
i < num_in) {
55 weight = w(
output + j, input +
i);
57 shaped_w[shaped_index++] = weight;
62 for (
int j = 0; j < num_outputs_per_register_set; ++j) {
64 if (
output + j < num_out) {
65 weight = w(
output + j, num_in);
67 shaped_w[shaped_index++] = weight;
69 output += num_outputs_per_register_set;
79 const std::vector<TFloat> &scales,
const int8_t *u,
TFloat *v) {
80 int num_out = w.
dim1();
81 int num_in = w.
dim2() - 1;
83 for (
int i = 0;
i < num_out; ++
i) {
84 const int8_t *wi = w[
i];
86 for (
int j = 0; j < num_in; ++j) {
87 total += wi[j] * u[j];
90 v[
i] = (total + wi[num_in] * INT8_MAX) * scales[
i];
int num_inputs_per_group_
int max_output_registers_
int RoundOutputs(int size) const
static void MatrixDotVector(const GENERIC_2D_ARRAY< int8_t > &w, const std::vector< TFloat > &scales, const int8_t *u, TFloat *v)
int num_outputs_per_register_
static int Roundup(int input, int factor)
static const IntSimdMatrix * intSimdMatrix
void Init(const GENERIC_2D_ARRAY< int8_t > &w, std::vector< int8_t > &shaped_w, int32_t &rounded_num_out) const