#if defined(__ARM_NEON)

#  include "intsimdmatrix.h"
#  include "tesstypes.h"

#  include <arm_neon.h>
#  include <cstdint>

namespace tesseract {
// Number of outputs held in each register. (We use a pair of 4x32-bit
// registers, so 8 x 32-bit ints.)
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 1;
// Number of 8-bit inputs in the inputs register.
constexpr int kNumInputsPerRegister = 8;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 8;
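
// Illustrative layout sketch (inferred from the loads in
// PartialMatrixDotVector8 below; not a comment from the original source):
// for each block of 8 outputs, the shaped weight stream wi is laid out as
//   [group 0: rows 0..7 x inputs 0..7   -> 64 bytes]
//   [group 1: rows 0..7 x inputs 8..15  -> 64 bytes]
//   ...
//   [8 bias bytes, one per output row]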
// Computes part of matrix.vector v = Wu. Produces 8 results.
// The weights must be arranged so that consecutive reads from wi give the
// weights of 8 output rows for each group of kNumInputsPerGroup inputs,
// followed by the 8 bias values after the last group. u must be padded
// with zeros to kNumInputsPerGroup * ceil(num_in / kNumInputsPerGroup)
// elements.
static inline void PartialMatrixDotVector8(const int8_t *__restrict wi,
                                           const TFloat *__restrict scales,
                                           const int8_t *__restrict u, int num_in,
                                           TFloat *__restrict v, int num_out) {
  // Initialize all the results to 0.
  int32x4_t result0123 = {0, 0, 0, 0};
  int32x4_t result4567 = {0, 0, 0, 0};
  // Biases are stored as int8 and multiplied by 127 (the int8 range)
  // before being added to the integer accumulators.
  int8x8_t bias_scale = {127, 127, 127, 127, 127, 127, 127, 127};
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in; j += 8) {
    int8x8_t vu = vld1_s8(u);      // vu = u0 u1 u2 u3 u4 u5 u6 u7
    int8x16_t vw01 = vld1q_s8(wi); // weights of rows 0 and 1 for this group
    int8x16_t vw23 = vld1q_s8(wi + 8 * 2); // rows 2 and 3
    int8x16_t vw45 = vld1q_s8(wi + 8 * 4); // rows 4 and 5
    int8x16_t vw67 = vld1q_s8(wi + 8 * 6); // rows 6 and 7
    // Multiply each weight row by the inputs, widening to 16 bits:
    // vrow0q = w00*u0, w01*u1, ..., w07*u7 (likewise for rows 1..7).
    int16x8_t vrow0q = vmull_s8(vget_low_s8(vw01), vu);
    int16x8_t vrow1q = vmull_s8(vget_high_s8(vw01), vu);
    int16x8_t vrow2q = vmull_s8(vget_low_s8(vw23), vu);
    int16x8_t vrow3q = vmull_s8(vget_high_s8(vw23), vu);
    int16x8_t vrow4q = vmull_s8(vget_low_s8(vw45), vu);
    int16x8_t vrow5q = vmull_s8(vget_high_s8(vw45), vu);
    int16x8_t vrow6q = vmull_s8(vget_low_s8(vw67), vu);
    int16x8_t vrow7q = vmull_s8(vget_high_s8(vw67), vu);
    // Pairwise add adjacent 16-bit products, widening to 32 bits:
    // each vrowNq2 now holds 4 partial sums for row N.
    int32x4_t vrow0q2 = vpaddlq_s16(vrow0q);
    int32x4_t vrow1q2 = vpaddlq_s16(vrow1q);
    int32x4_t vrow2q2 = vpaddlq_s16(vrow2q);
    int32x4_t vrow3q2 = vpaddlq_s16(vrow3q);
    int32x4_t vrow4q2 = vpaddlq_s16(vrow4q);
    int32x4_t vrow5q2 = vpaddlq_s16(vrow5q);
    int32x4_t vrow6q2 = vpaddlq_s16(vrow6q);
    int32x4_t vrow7q2 = vpaddlq_s16(vrow7q);
    // Reduce each row's 4 partial sums to 2, packing two rows per register.
    vrow0q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow0q2), vget_high_s32(vrow0q2)),
                           vpadd_s32(vget_low_s32(vrow1q2), vget_high_s32(vrow1q2)));
    vrow2q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow2q2), vget_high_s32(vrow2q2)),
                           vpadd_s32(vget_low_s32(vrow3q2), vget_high_s32(vrow3q2)));
    vrow4q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow4q2), vget_high_s32(vrow4q2)),
                           vpadd_s32(vget_low_s32(vrow5q2), vget_high_s32(vrow5q2)));
    vrow6q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow6q2), vget_high_s32(vrow6q2)),
                           vpadd_s32(vget_low_s32(vrow7q2), vget_high_s32(vrow7q2)));
    // Reduce again: one 32-bit sum per row, four rows per register.
    vrow0q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow0q2), vget_high_s32(vrow0q2)),
                           vpadd_s32(vget_low_s32(vrow2q2), vget_high_s32(vrow2q2)));
    vrow4q2 = vcombine_s32(vpadd_s32(vget_low_s32(vrow4q2), vget_high_s32(vrow4q2)),
                           vpadd_s32(vget_low_s32(vrow6q2), vget_high_s32(vrow6q2)));
    // Accumulate this group's row sums into the running totals.
    result0123 = vaddq_s32(result0123, vrow0q2);
    result4567 = vaddq_s32(result4567, vrow4q2);
    // Advance to the next group of 8 inputs and its 8x8 weight block.
    u += 8;
    wi += 8 * 8;
  }
  {
    // Load the biases and add them, scaled up to the int8 range.
    int8x8_t bias = vld1_s8(wi);
    int16x8_t scaled_bias = vmull_s8(bias, bias_scale);
    result0123 = vaddw_s16(result0123, vget_low_s16(scaled_bias));
    result4567 = vaddw_s16(result4567, vget_high_s16(scaled_bias));
    // Scale and write the results. The last block may produce fewer than
    // 8 outputs, so every write beyond the first is guarded by num_out.
    *v++ = vget_lane_s32(vget_low_s32(result0123), 0) * *scales++;
    if (num_out > 1)
      *v++ = vget_lane_s32(vget_low_s32(result0123), 1) * *scales++;
    if (num_out > 2)
      *v++ = vget_lane_s32(vget_high_s32(result0123), 0) * *scales++;
    if (num_out > 3)
      *v++ = vget_lane_s32(vget_high_s32(result0123), 1) * *scales++;
    if (num_out > 4)
      *v++ = vget_lane_s32(vget_low_s32(result4567), 0) * *scales++;
    if (num_out > 5)
      *v++ = vget_lane_s32(vget_low_s32(result4567), 1) * *scales++;
    if (num_out > 6)
      *v++ = vget_lane_s32(vget_high_s32(result4567), 0) * *scales++;
    if (num_out > 7)
      *v = vget_lane_s32(vget_high_s32(result4567), 1) * *scales;
  }
}
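
// For reference, a scalar sketch of what the intrinsics above compute. This
// function is not part of the original file; the name and the layout
// assumptions (64 weight bytes per group of 8 inputs, 8 bias bytes after the
// last group, inputs zero-padded) are taken from the loads above.
#if 0 // illustrative only
static void PartialMatrixDotVector8Scalar(const int8_t *wi, const TFloat *scales,
                                          const int8_t *u, int num_in,
                                          TFloat *v, int num_out) {
  int32_t total[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  for (int j = 0; j < num_in; j += 8) {
    for (int row = 0; row < 8; ++row) {
      for (int k = 0; k < 8; ++k) {
        // int8 x int8 products accumulated in 32 bits, as in the NEON path.
        total[row] += wi[row * 8 + k] * u[j + k];
      }
    }
    wi += 8 * 8; // next 8x8 weight block
  }
  for (int row = 0; row < num_out && row < 8; ++row) {
    total[row] += wi[row] * 127; // bias, scaled up to the int8 range
    v[row] = total[row] * scales[row];
  }
}
#endif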
// Computes matrix.vector v = Wu, where dim1 = num_out rows and
// dim2 = num_in + 1 columns (the extra column holds the biases).
static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales,
                            const int8_t *u, TFloat *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to PartialMatrixDotVector8 produces group_size outputs,
  // except the last one, which can produce fewer.
  // IntSimdMatrix::Roundup(input, factor) rounds input up to a multiple
  // of factor.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int w_step = (rounded_num_in + 1) * group_size;
  int output = 0;
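  // Worked example (illustrative numbers): with num_in = 20 inputs,
  // rounded_num_in = Roundup(20, 8) = 24, so each block of 8 outputs
  // occupies w_step = (24 + 1) * 8 = 200 bytes of wi: three 64-byte weight
  // groups plus 8 bias bytes, and u must be zero-padded to 24 elements.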
  for (; output + group_size <= num_out; output += group_size) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v, kNumOutputsPerRegister);
    // Advance past this block's weights, scales and outputs.
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  // Handle the remaining (fewer than group_size) outputs, if any.
  if (output < num_out)
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v,
                            num_out & (kNumOutputsPerRegister - 1));
}
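
// Usage sketch (illustrative, not from the original file): callers reach this
// implementation through the IntSimdMatrix table below rather than calling
// matrixDotVector directly, roughly:
//   const IntSimdMatrix *m = IntSimdMatrix::intSimdMatrixNEON;
//   m->matrixDotVectorFunction(num_out, num_in + 1, shaped_wi, scales, u, v);
// where shaped_wi is assumed to be the weight matrix reshaped into the group
// layout described above and u is zero-padded to a multiple of
// kNumInputsPerGroup.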
static const IntSimdMatrix simdMatrix = {
    // Function.
    matrixDotVector,
    // Number of 32-bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use.
    kMaxOutputRegisters,
    // Number of 8-bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup};

const IntSimdMatrix *IntSimdMatrix::intSimdMatrixNEON = &simdMatrix;

} // namespace tesseract

#endif /* defined(__ARM_NEON) */