#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else

#  include <immintrin.h>
#  if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
      defined(_WIN32) && !defined(_WIN64)
// Work around a code generation bug in 32-bit MSVC 2019 builds
// (_MSC_VER 1925 to 1929) by compiling this file optimized for size.
#    pragma optimize("", off)
#    pragma optimize("s", on)
#  endif
// Number of outputs held in each register (8 x 32-bit ints).
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use to hold outputs.
constexpr int kMaxOutputRegisters = 8;
// Number of 8-bit inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
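
// How the constants combine: the widest kernel below produces
// kNumOutputsPerRegister * kMaxOutputRegisters = 8 * 8 = 64 outputs per pass,
// and each 256-bit input load supplies kNumInputGroups = 32 / 4 = 8 groups of
// 4 inputs. The static_asserts just restate this arithmetic.
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "widest kernel computes 64 outputs per pass");
static_assert(kNumInputGroups == 8,
              "one input register load yields 8 groups of 4 inputs");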
// Rounds input up to the next multiple of factor.
static int Roundup(int input, int factor) {
  return (input + factor - 1) / factor * factor;
}

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds the 4 products within each group, making 8x32-bit
// results. rep_input holds 32x8-bit signed inputs: one group of 4 inputs
// replicated 8 times. wi is advanced past the 32 weights consumed.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones,
                                 const int8_t *&wi, __m256i &weights,
                                 __m256i &reps, __m256i &result) {
  // Load a 256-bit register from the 8-bit weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input and weights, so weights is always +ve:
  // _mm256_maddubs_epi16 treats its first operand as unsigned.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights, making 16x16-bit results
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply the 16x16-bit results by 16x16-bit ones, adding adjacent pairs
  // again, making 8x32-bit results.
  weights = _mm256_madd_epi16(weights, ones);
  // Accumulate into the caller's result register.
  result = _mm256_add_epi32(result, weights);
}
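
// For reference, a scalar sketch of what one MultiplyGroup call adds to each
// of the 8 output lanes of `result`. The function and parameter names below
// are illustrative, not part of the original code. It assumes inputs and
// weights lie in [-127, 127], so the 16-bit pair sums in _mm256_maddubs_epi16
// cannot saturate (2 * 127 * 127 = 32258 < 32767), and the sign trick above
// is exact: sign(a, w) * |w| == a * w for every such pair.
[[maybe_unused]] static inline void MultiplyGroupScalarSketch(
    const int8_t *four_inputs, const int8_t *weights32, int32_t result8[8]) {
  for (int out = 0; out < 8; ++out) {
    int32_t dot = 0;
    for (int i = 0; i < 4; ++i) {
      dot += static_cast<int32_t>(four_inputs[i]) * weights32[out * 4 + i];
    }
    result8[out] += dot; // One 32-bit lane of `result`.
  }
}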
// Loads 64 bits into the bottom of a 128-bit register.
// We don't actually care what the top 64 bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}
#if defined(FAST_FLOAT)

// Float variants of the result-extraction helpers and kernels.
// Adds the 8 bias weights at wi (scaled by 127) to the 8x32-bit sums in
// result, converts to float, multiplies by scales, and stores 8 outputs to v.
static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8-bit vals in bottom of 128-bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32-bit vals in 256-bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}
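
// In scalar terms, the extraction above computes, for each of the 8 outputs:
//   v[i] = (result[i] + bias[i] * 127) * scales[i]
// where bias[i] is the int8 bias weight read from wi. The factor 127 matches
// the integer encoding of 1.0 used for the quantized inputs, so the bias
// enters the sum at the same scale as the other products.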
// As ExtractResults8, but processes two result registers (16 outputs) and
// advances wi, scales and v past the 16 outputs produced.
static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  // 16x8-bit bias weights in a 128-bit reg.
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32-bit vals in 256-bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  // Move the upper 8 bias weights down and repeat for result1.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8);
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale);
  result1 = _mm256_add_epi32(result1, w256);
  res01234567 = _mm256_cvtepi32_ps(result1);
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  // Advance the caller's pointers past the 16 outputs just produced, so that
  // consecutive calls fill consecutive slices of the output.
  wi += 16;
  scales += 16;
  v += 16;
}
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup) groups of (N output dim groups of
// (kNumInputsPerGroup inputs)). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup * ceil(num_in / kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales,
                                     const int8_t *u, int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inner loop iterates over groups of 4 inputs, replicating them 8 times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
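
// A scalar reference for the whole of PartialMatrixDotVector64, spelling out
// the weight layout the comment above requires. Illustrative only: the
// function name is hypothetical, and num_in is assumed to be a multiple of
// kNumInputsPerGroup (callers pass the rounded value).
[[maybe_unused]] static void PartialMatrixDotVector64Sketch(
    const int8_t *wi, const float *scales, const int8_t *u, int num_in,
    float *v) {
  constexpr int kN = 64;
  int32_t result[kN] = {0};
  // Weight stream: for each group of 4 inputs, 4 weights for each of the 64
  // outputs, in output order; the 64 bias weights follow at the end.
  for (int j = 0; j < num_in; j += kNumInputsPerGroup) {
    for (int out = 0; out < kN; ++out) {
      for (int i = 0; i < kNumInputsPerGroup; ++i) {
        result[out] += static_cast<int32_t>(u[j + i]) * static_cast<int32_t>(*wi++);
      }
    }
  }
  for (int out = 0; out < kN; ++out) {
    // Bias is scaled by 127, the integer representation of 1.0.
    v[out] = static_cast<float>(result[out] + static_cast<int32_t>(*wi++) * 127) *
             scales[out];
  }
}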
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales,
                                     const int8_t *u, int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales,
                                     const int8_t *u, int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}
// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales,
                                           const int8_t *u, int num_in, float *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}
// Computes matrix.vector v = Wu, where u has dim2 - 1 elements and the last
// weight of each row of W is the bias: u is treated as if it had a final
// element equal to 1, represented as 127 on the integer side.
static void matrixDotVector(int dim1, int dim2, const int8_t *wi,
                            const float *scales, const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  const int rounded_num_in = Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = Roundup(num_out, kNumOutputsPerRegister);
  // Each kernel call produces group_size outputs, except the last one,
  // which can produce less.
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  // Weights for each group: rounded_num_in inputs plus 1 bias per output.
  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
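
// Worked example of the shrinking group sizes: for num_out = 100,
// rounded_num_out = Roundup(100, 8) = 104. The 64-wide kernel runs once
// (output = 64), the 32-wide kernel runs once (output = 96), the 16-wide
// kernel is skipped (96 + 16 > 104), and the 8-wide kernel finishes the job
// (output = 104).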
#else

// Double variants of the result-extraction helpers and kernels.
// Adds the 8 bias weights at wi (scaled by 127) to the 8x32-bit sums in
// result, converts to double, multiplies by scales, and stores 8 outputs to v.
static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const double *scales, double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8-bit vals in bottom of 128-bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32-bit vals in 256-bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  // Swing the upper 4 results down into the low 128 bits.
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}
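
// Unlike the float path, a 256-bit register holds only 4 doubles, so the 8
// 32-bit sums are converted and scaled in two halves. The permute immediate
// 2 + (3 << 2) moves 64-bit lanes 2 and 3 of result into positions 0 and 1,
// which puts the upper four 32-bit sums where _mm256_cvtepi32_pd can read
// them.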
// As ExtractResults8, but processes two result registers (16 outputs) and
// advances wi, scales and v past the 16 outputs produced.
static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const double *&scales,
                                    double *&v) {
  // 16x8-bit bias weights in a 128-bit reg.
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32-bit vals in 256-bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  // Move the upper 8 bias weights down and repeat for result1.
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8);
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale);
  result1 = _mm256_add_epi32(result1, w256);
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  // Advance the caller's pointers past the 16 outputs just produced, so that
  // consecutive calls fill consecutive slices of the output.
  wi += 16;
  scales += 16;
  v += 16;
}
// Double-precision counterpart of PartialMatrixDotVector64 above; see that
// function for the required weight layout and input padding.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales,
                                     const int8_t *u, int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}
// Double-precision counterpart of PartialMatrixDotVector32 above.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales,
                                     const int8_t *u, int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}
// Double-precision counterpart of PartialMatrixDotVector16 above.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales,
                                     const int8_t *u, int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}
// Double-precision counterpart of PartialMatrixDotVector8 above.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales,
                                           const int8_t *u, int num_in, double *v) {
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  __m256i result0 = _mm256_setzero_si256();
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}
// Double-precision counterpart of matrixDotVector above.
static void matrixDotVector(int dim1, int dim2, const int8_t *wi,
                            const double *scales, const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  const int rounded_num_in = Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = Roundup(num_out, kNumOutputsPerRegister);
  // Each kernel call produces group_size outputs, except the last one,
  // which can produce less.
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  // Weights for each group: rounded_num_in inputs plus 1 bias per output.
  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}

#endif // FAST_FLOAT
static const IntSimdMatrix intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32-bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8-bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup,
};
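
// Hypothetical caller-side sketch (the real dispatch lives in the library's
// SIMD-detection code, which selects this table at run time on AVX2-capable
// CPUs): with weights already shaped as the kernels above require,
//
//   intSimdMatrixAVX2.matrixDotVector(dim1, dim2, shaped_wi, scales, u, v);
//
// computes v = Wu for a dim1 x dim2 weight matrix whose last column holds the
// biases.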
#endif // __AVX2__