///////////////////////////////////////////////////////////////////////
// File:        intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author:      Ray Smith
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#include "intsimdmatrix.h"

#if !defined(__AVX2__)
#  if defined(__i686__) || defined(__x86_64__)
#    error Implementation only for AVX2 capable architectures
#  endif
#else
#  include <immintrin.h>
#  include <algorithm>
#  include <cstdint>
#  include <vector>

#  if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
      defined(_WIN32) && !defined(_WIN64)
// Optimize for size (/Os) instead of using the default optimization for some
// versions of the 32-bit Visual Studio compiler, which generate buggy code.
#    pragma optimize("", off)
#    pragma optimize("s", on)
#  endif

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
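
// Compile-time sanity checks on the geometry above: one 256-bit register
// holds 32 8-bit inputs in 8 broadcast groups of 4, and the widest kernel
// below (PartialMatrixDotVector64) fills 8 registers of 8 outputs each.
static_assert(kNumInputsPerGroup * kNumInputGroups == kNumInputsPerRegister,
              "inputs register must split evenly into broadcast groups");
static_assert(kNumOutputsPerRegister * kMaxOutputRegisters == 64,
              "largest partial kernel computes 64 outputs per pass");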

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see PartialMatrixDotVector64 below) in w,
// which is multiplied by u of length num_in, to produce output v after
// scaling the integer results by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
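// As a concrete example of that order: for a 64-output block with
// num_in = 8, wi holds 64 x 4 weights for inputs 0..3 (4 contiguous bytes
// per output, outputs in order), then 64 x 4 weights for inputs 4..7,
// followed by the 64 bias weights, one per output.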

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
                                 __m256i &weights, __m256i &reps, __m256i &result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve:
  // maddubs below treats its first operand as unsigned.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
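
// In scalar terms, one MultiplyGroup call accumulates, for each of the 8
// outputs o covered by result, the dot product of the current 4 inputs
// with that output's 4 weights; a sketch, with w viewed as the re-ordered
// per-output blocks:
//   for (int o = 0; o < 8; ++o)
//     result[o] += u[j] * w[o][0] + u[j + 1] * w[o][1] +
//                  u[j + 2] * w[o][2] + u[j + 3] * w[o][3];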

// Load 64 bits into the bottom of a 128bit register.
// We don't actually care what the top 64bits are, but this ends
// up with them being zero.
static inline __m128i load64_to_128(const int8_t *wi_) {
  const auto *wi = reinterpret_cast<const int64_t *>(wi_);
  return _mm_set_epi64x(0, wi[0]);
}
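
// Note on the "bias * 127" steps in the ExtractResults functions below: by
// the time they run, wi has advanced past the multiplicative weights to the
// block's bias weights. Each bias is added in as bias * 127, i.e. as if it
// multiplied a constant input at the maximum int8 value (as in the scalar
// reference implementation), and the per-output scale then maps the integer
// accumulator back to a real-valued result.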

#if defined(FAST_FLOAT)

static inline void ExtractResults8(__m256i result, const int8_t *wi,
                                   const float *scales, float *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result);
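  // Note: the result of this permute is unused in the float variants (here
  // and in ExtractResults16 below); it mirrors the double-precision code
  // later in this file, which still needs the upper 128 bits after
  // converting the lower four lanes.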
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1,
                                    const int8_t *&wi, const float *&scales,
                                    float *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit vals, filling the 128bit reg
  const __m256i bias_scale =
      _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256 scale01234567 = _mm256_loadu_ps(scales);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256 res01234567 = _mm256_cvtepi32_ps(result0);
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v, res01234567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale01234567 = _mm256_loadu_ps(scales + 8);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res01234567 = _mm256_cvtepi32_ps(result1);
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res01234567 = _mm256_mul_ps(res01234567, scale01234567);
  _mm256_storeu_ps(v + 8, res01234567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u,
                                     int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u,
                                           int num_in, float *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales,
                            const int8_t *u, float *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVectorN function produces group_size
  // outputs, except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
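  // Worked example: for rounded_num_out = 88 the cascade below runs the
  // 64-wide kernel once (outputs 0..63), skips the 32-wide kernel, then runs
  // the 16-wide (64..79) and 8-wide (80..87) kernels. Each output consumes
  // rounded_num_in weights plus one bias, hence the w_step above.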
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#else
static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales,
                                   double *v) {
  __m128i w128 = load64_to_128(wi);          // 8x8bit vals in bottom of 128bit reg
  __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
  __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result = _mm256_add_epi32(result, w256);     // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
                                    const double *&scales, double *&v) {
  __m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
  // 16x8bit vals, filling the 128bit reg
  const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
  __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  __m256d scale0123 = _mm256_loadu_pd(scales);
  __m256d scale4567 = _mm256_loadu_pd(scales + 4);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result0 = _mm256_add_epi32(result0, w256);   // result += bias * 127
  __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
  __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v, res0123);
  _mm256_storeu_pd(v + 4, res4567);
  w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
  w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
  scale0123 = _mm256_loadu_pd(scales + 8);
  scale4567 = _mm256_loadu_pd(scales + 12);
  w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
  result1 = _mm256_add_epi32(result1, w256);   // result += bias * 127
  res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
  res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
  res0123 = _mm256_mul_pd(res0123, scale0123);
  res4567 = _mm256_mul_pd(res4567, scale4567);
  _mm256_storeu_pd(v + 8, res0123);
  _mm256_storeu_pd(v + 12, res4567);
  wi += 16;
  scales += 16;
  v += 16;
}

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
  ExtractResults16(result4, result5, wi, scales, v);
  ExtractResults16(result6, result7, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
  ExtractResults16(result2, result3, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u,
                                     int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults16(result0, result1, wi, scales, v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u,
                                           int num_in, double *v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults8(result0, wi, scales, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales,
                            const int8_t *u, double *v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a PartialMatrixDotVectorN function produces group_size
  // outputs, except the last one, which can produce fewer.
  const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
    output += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  if (output + group_size <= rounded_num_out) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v);
  }
}
#endif

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};
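
// Usage note, as implied by matrixDotVector above: dim1 is the number of
// outputs and dim2 = num_in + 1, the extra column holding the bias weights.
// u must be zero-padded to a multiple of kNumInputsPerGroup inputs, and v
// must have room for dim1 rounded up to a multiple of
// kNumOutputsPerRegister results.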

} // namespace tesseract.

#endif