You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

bf16_common_macros.h 47 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847
  1. /***************************************************************************
  2. Copyright (c) 2014, The OpenBLAS Project
  3. All rights reserved.
  4. Redistribution and use in source and binary forms, with or without
  5. modification, are permitted provided that the following conditions are
  6. met:
  7. 1. Redistributions of source code must retain the above copyright
  8. notice, this list of conditions and the following disclaimer.
  9. 2. Redistributions in binary form must reproduce the above copyright
  10. notice, this list of conditions and the following disclaimer in
  11. the documentation and/or other materials provided with the
  12. distribution.
  13. 3. Neither the name of the OpenBLAS project nor the names of
  14. its contributors may be used to endorse or promote products
  15. derived from this software without specific prior written permission.
  16. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  21. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  22. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  23. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  24. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
  25. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. *****************************************************************************/
  27. #ifndef __BF16_COMMON_MACROS
  28. #define __BF16_COMMON_MACROS
  29. #include <immintrin.h>
/* Broadcast the 32-bit value at `addr` (one adjacent bf16 pair) into every
   32-bit lane of vector register `zmm`, via vpbroadcastd with a memory
   operand.
   NOTE(review): the macro name itself uses a reserved identifier spelling
   (leading underscore + capital), and the asm declares no "memory" clobber;
   assumed safe at current call sites -- confirm before reusing elsewhere. */
#define _MM512_BROADCASTD_EPI32(addr, zmm) \
    __asm__ ("vpbroadcastd (%1), %0;" \
             : "=v" (zmm) \
             : "r" (addr) )

/* Prefetch the cache line containing `addr` (prefetcht0 hint: all cache
   levels).  No outputs; purely a performance hint. */
#define PREFETCH_T0(addr) \
    __asm__ ("prefetcht0 (%0);" \
             : \
             : "r" (addr) )
/* Copy the low 256 bits of two 512-bit float registers (reg512_0/reg512_1)
   into the corresponding 256-bit registers (reg256_0/reg256_1); this is a
   register-width cast, not a shuffle. */
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
    reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
    reg256##_1 = _mm512_castps512_ps256(reg512##_1);
  41. #define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
  42. regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]); \
  43. regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]); \
  44. regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]); \
  45. regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]); \
  46. regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]); \
  47. regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]); \
  48. regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]); \
  49. regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]);
  50. #define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
  51. regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \
  52. regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \
  53. regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \
  54. regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \
  55. regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \
  56. regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \
  57. regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \
  58. regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n]));
  59. #define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
  60. regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \
  61. regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \
  62. regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \
  63. regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \
  64. regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \
  65. regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \
  66. regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \
  67. regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n]));
  68. #define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
  69. regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]);
  70. #define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
  71. regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  72. regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
  73. regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  74. regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
  75. regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
  76. regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
  77. regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
  78. regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
  79. #define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
  80. regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  81. regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
  82. regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  83. regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
  84. regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
  85. regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
  86. regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
  87. regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
  88. #define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
  89. regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  90. regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
  91. regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  92. regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
  93. regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
  94. regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
  95. regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
  96. regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
  97. #define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
  98. regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  99. regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
  100. regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  101. regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
  102. #define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
  103. regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  104. regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
  105. regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  106. regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
  107. #define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
  108. regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  109. regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  110. regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
  111. regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
  112. regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]); \
  113. regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \
  114. regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \
  115. regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]);
  116. #define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
  117. regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
  118. regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
  119. regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
  120. regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);
  121. #define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
  122. regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]);
  123. #define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
  124. reg = _mm512_loadu_si512(x + idx_n);
  125. #define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
  126. reg = _mm256_loadu_si256((__m256i *)(x + idx_n));
  127. #define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
  128. reg = _mm_loadu_si128((__m128i *)(x + idx_n));
  129. #define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
  130. reg = _mm512_maskz_loadu_epi16(mask, x + idx_n);
  131. #define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
  132. reg = _mm256_maskz_loadu_epi16(mask, x + idx_n);
  133. #define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
  134. reg = _mm_maskz_loadu_epi16(mask, x + idx_n);
  135. /* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
Input - register array of 8 rows of row-major matrix
  137. Output - the output of Step 2
  138. Step 1: 2-element interleave for matrix
  139. |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
  140. |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
  141. |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
  142. |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
  143. |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
  144. |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
  145. |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
  146. |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
  147. Step 2: 4-element interleave for matrix
  148. |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
  149. |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
  150. |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
  151. |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
  152. |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
  153. |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
  154. |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
  155. |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
  156. */
/* Step 1 (32-bit unpack) and step 2 (64-bit unpack) of the interleave
   documented above.  regArray_8..15 are used as scratch; the interleaved
   result overwrites regArray_0..7, so a 16-register array is required. */
#define BF16_INTERLEAVE_8x32(regArray) \
    regArray##_8 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_9 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
                                                                       \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_8, regArray##_9); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_8, regArray##_9); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
  175. /* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
Input - register array of 8 rows of row-major matrix
  177. Output - the output of Step 2
  178. Step 1: 2-element interleave for matrix
  179. |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
  180. |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
  181. |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
  182. |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
  183. |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
  184. |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
  185. |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
  186. |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
  187. Step 2: 4-element interleave for matrix
  188. |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
  189. |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
  190. |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
  191. |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
  192. |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
  193. |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
  194. |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
  195. |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
  196. */
/* 256-bit variant of BF16_INTERLEAVE_8x32 (see diagram above).
   regArray_8..15 are scratch; the result overwrites regArray_0..7. */
#define BF16_INTERLEAVE_8x16(regArray) \
    regArray##_8 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_9 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
    regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
    regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
    regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
    regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
                                                                       \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_8, regArray##_9); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_8, regArray##_9); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
/* 2-step interleave for matrix against 4 rows with 32 BF16 elements per row
Input - register array of 4 rows of row-major matrix
  217. Output - the output of Step 2
  218. Step 1: 2-element interleave for matrix
  219. |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
  220. |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
  221. |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
  222. |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
  223. Step 2: 4-element interleave for matrix
  224. |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
  225. |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
  226. |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
  227. |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
  228. */
/* 4-row variant of the 2-step interleave (see diagram above).
   regArray_4..7 are scratch; the result overwrites regArray_0..3. */
#define BF16_INTERLEAVE_4x32(regArray) \
    regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
                                                                      \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
/* 2-step interleave for matrix against 4 rows with 16 BF16 elements per row
Input - register array of 4 rows of row-major matrix
  241. Output - the output of Step 2
  242. Step 1: 2-element interleave for matrix
  243. |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
  244. |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
  245. |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
  246. |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
  247. Step 2: 4-element interleave for matrix
  248. |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
  249. |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
  250. |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
  251. |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
  252. */
/* 256-bit, 4-row variant of the 2-step interleave (see diagram above).
   regArray_4..7 are scratch; the result overwrites regArray_0..3. */
#define BF16_INTERLEAVE_4x16(regArray) \
    regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
    regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
    regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
    regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
                                                                      \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
  263. /* 2-step interleave for x with 32 BF16 elements
  264. Input - original vector
  265. Output - the output of Step 2
  266. Step 1: 2-element interleave for x:
  267. |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
  268. |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
  269. Step 2: 4-element interleave for x:
  270. |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
  271. |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
  272. |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
  273. |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
  274. */
/* 2-step self-interleave of the single vector in regArray_0 (see diagram
   above); results are produced in regArray_0..3 and regArray_0 is
   overwritten in the process. */
#define BF16_INTERLEAVE_1x32(regArray) \
    regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
    regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
                                                                      \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
  283. /* 2-step interleave for x with 16 BF16 elements
  284. Input - original vector
  285. Output - the output of Step 2
  286. Step 1: 2-element interleave for x:
  287. |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
  288. |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
  289. Step 2: 4-element interleave for x:
  290. |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
  291. |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
  292. |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
  293. |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
  294. */
/* 256-bit variant of BF16_INTERLEAVE_1x32 (see diagram above); results in
   regArray_0..3, regArray_0 overwritten. */
#define BF16_INTERLEAVE_1x16(regArray) \
    regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
    regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
                                                                      \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
/* 1-step interleave to exchange the high-256 bits and low-256 bits of 4 pairs of registers
  304. |a0|a1|...|a14|a15|i0|i1|...|i14|i15|
  305. |b0|b1|...|b14|b15|j0|j1|...|j14|j15|
  306. |c0|c1|...|c14|c15|k0|k1|...|k14|k15|
  307. |d0|d1|...|d14|d15|l0|l1|...|l14|l15|
  308. |e0|e1|...|e14|e15|m0|m1|...|m14|m15|
  309. |f0|f1|...|f14|f15|n0|n1|...|n14|n15|
  310. |g0|g1|...|g14|g15|o0|o1|...|o14|o15|
  311. |h0|h1|...|h14|h15|p0|p1|...|p14|p15|
  312. */
/* Recombine 256-bit halves of regArray_8..15 into regArray_0..7:
   immediate 0x44 selects the low 256-bit half of each source, 0xee the
   high half, pairing register i with register i+4. */
#define BF16_INTERLEAVE256_8x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0xee); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0x44); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0xee); \
    regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
    regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
    regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
    regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
/* 1-step interleave to exchange the high-256 bits and low-256 bits of 2 pairs of registers
  323. |a0|a1|...|a14|a15|e0|e1|...|e14|e15|
  324. |b0|b1|...|b14|b15|f0|f1|...|f14|f15|
  325. |c0|c1|...|c14|c15|g0|g1|...|g14|g15|
  326. |d0|d1|...|d14|d15|h0|h1|...|h14|h15|
  327. */
/* 2-pair variant of BF16_INTERLEAVE256_8x32: recombines 256-bit halves of
   regArray_4..7 into regArray_0..3 (0x44 = low halves, 0xee = high). */
#define BF16_INTERLEAVE256_4x32(regArray) \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
  333. #define BF16_PERMUTE_8x32(idx, regArray) \
  334. regArray##_8 = _mm512_permutexvar_epi16(idx, regArray##_0); \
  335. regArray##_9 = _mm512_permutexvar_epi16(idx, regArray##_1); \
  336. regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \
  337. regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \
  338. regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \
  339. regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \
  340. regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \
  341. regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);
  342. #define BF16_PERMUTE_8x32_2(idx, regArray) \
  343. regArray##_8 = _mm512_permutexvar_epi32(idx, regArray##_0); \
  344. regArray##_9 = _mm512_permutexvar_epi32(idx, regArray##_1); \
  345. regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \
  346. regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \
  347. regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \
  348. regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \
  349. regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \
  350. regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);
  351. #define BF16_PERMUTE_4x32(idx, regArray) \
  352. regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \
  353. regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \
  354. regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \
  355. regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);
  356. #define BF16_PERMUTE_4x32_2(idx, regArray) \
  357. regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \
  358. regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \
  359. regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \
  360. regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
/* Accumulate dot products of a 2-step-interleaved 8-row bf16 matrix block
   with the 2-step-interleaved x vector; statements alternate between the
   two accumulators to expose instruction-level parallelism.
   (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) */
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);

/* 256-bit variant of BF16_2STEP_INTERLEAVED_DOT_8x32.
   (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform) */
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);

/* 4-row, 512-bit variant of the interleaved dot accumulation.
   (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) */
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);

/* 4-row, 256-bit variant of the interleaved dot accumulation.
   (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform) */
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
/* Calculate the dot result for a matrix against a vector at 32 bf16
 * elements per row: 8 matrix-row registers, one shared x register.
 * Each of the 8 accumulators is independent, so all FMAs can overlap.
 * (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform.)
 */
#define BF16_DOT_8x32(accumArray, matArray, xArray) \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
    accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
    accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
    accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
    accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
    accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
    accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);
/* Calculate the dot result for one matrix row against the vector,
 * 32 bf16 elements per row (single-register form of BF16_DOT_8x32).
 * (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform.)
 */
#define BF16_DOT_1x32(accumArray, matArray, xArray) \
    accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray);
/* Calculate the dot result for a matrix against a vector at 16 bf16
 * elements per row (AVX2-width variant of BF16_DOT_8x32).
 * (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform.)
 */
#define BF16_DOT_8x16(accumArray, matArray, xArray) \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
    accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
    accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
    accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
    accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
    accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
    accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
/* 2-step interleave for a matrix tile of 8 rows with 16 fp32 elements per row.
 * Input  - register array of 8 rows of a row-major matrix (regArray_0.._7);
 *          regArray_8.._15 are clobbered as scratch.
 * Output - the output of Step 2, written back into regArray_0.._7.
 * Step 1: 2-element interleave for matrix
 *   |a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
 *   |c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
 *   |e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
 *   |g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
 *   |a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
 *   |c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
 *   |e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
 *   |g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
 * Step 2: 4-element interleave for matrix
 *   |a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
 *   |a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
 *   |e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
 *   |e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
 *   |a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
 *   |a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
 *   |e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
 *   |e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
 */
#define FP32_INTERLEAVE_8x16(regArray) \
    /* Step 1: pairwise 32-bit interleave into scratch registers _8.._15 */ \
    regArray##_8  = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
    regArray##_9  = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
    regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
    regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
    regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
    regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
    regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
    regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
    \
    /* Step 2: 64-bit interleave of the pairs back into _0.._7 */ \
    regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8,  (__m512d) regArray##_9); \
    regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8,  (__m512d) regArray##_9); \
    regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
    regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
    regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
    regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
    regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
    regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
/* Same 2-step interleave as FP32_INTERLEAVE_8x16, but operating on a
 * C array of at least 16 __m512 registers (regArray[0..7] in/out,
 * regArray[8..15] scratch) instead of token-pasted register names.
 */
#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
    /* Step 1: pairwise 32-bit interleave into scratch slots [8..15] */ \
    regArray[8]  = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
    regArray[9]  = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
    regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
    regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
    regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
    regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
    regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
    regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
    \
    /* Step 2: 64-bit interleave of the pairs back into [0..7] */ \
    regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8],  (__m512d) regArray[9]); \
    regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8],  (__m512d) regArray[9]); \
    regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
    regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
    regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
    regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
    regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
    regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
/* 2-step interleave for a matrix tile of 8 rows with 8 fp32 elements per row
 * (AVX2-width variant of FP32_INTERLEAVE_8x16).
 * Input  - register array of 8 rows of a row-major matrix (regArray_0.._7);
 *          regArray_8.._15 are clobbered as scratch.
 * Output - the output of Step 2, written back into regArray_0.._7.
 * Step 1: 2-element interleave for matrix
 *   |a0|b0|a1|b1|a4|b4|a5|b5|
 *   |c0|d0|c1|d1|c4|d4|c5|d5|
 *   |e0|f0|e1|f1|e4|f4|e5|f5|
 *   |g0|h0|g1|h1|g4|h4|g5|h5|
 *   |a2|b2|a3|b3|a6|b6|a7|b7|
 *   |c2|d2|c3|d3|c6|d6|c7|d7|
 *   |e2|f2|e3|f3|e6|f6|e7|f7|
 *   |g2|h2|g3|h3|g6|h6|g7|h7|
 * Step 2: 4-element interleave for matrix
 *   |a0|b0|c0|d0|a4|b4|c4|d4|
 *   |a1|b1|c1|d1|a5|b5|c5|d5|
 *   |e0|f0|g0|h0|e4|f4|g4|h4|
 *   |e1|f1|g1|h1|e5|f5|g5|h5|
 *   |a2|b2|c2|d2|a6|b6|c6|d6|
 *   |a3|b3|c3|d3|a7|b7|c7|d7|
 *   |e2|f2|g2|h2|e6|f6|g6|h6|
 *   |e3|f3|g3|h3|e7|f7|g7|h7|
 */
#define FP32_INTERLEAVE_8x8(regArray) \
    /* Step 1: pairwise 32-bit interleave into scratch registers _8.._15 */ \
    regArray##_8  = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
    regArray##_9  = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
    regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
    regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
    regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
    regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
    regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
    regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
    \
    /* Step 2: 64-bit interleave of the pairs back into _0.._7 */ \
    regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8,  (__m256d) regArray##_9); \
    regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8,  (__m256d) regArray##_9); \
    regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
    regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
    regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
    regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
    regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
    regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
/* Accumulate the result for 2 batches of 4 registers:
 * sums regArray_0.._3 into regArray_0 and regArray_4.._7 into
 * regArray_4 (pairwise tree reduction; _1.._3 and _5.._7 become dead).
 */
#define FP32_ACCUM2_8x16(regArray) \
    regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
    regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
    regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
    regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
    regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
    regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
/* Array variant of FP32_ACCUM2_8x16: sums regArray[0..3] into
 * regArray[0] and regArray[4..7] into regArray[4].
 */
#define FP32_ACCUM2_8x16_ARRAY(regArray) \
    regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
    regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
    regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
    regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
    regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
    regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
/* Accumulate the result for 2 batches of 4 registers (AVX2-width
 * variant of FP32_ACCUM2_8x16): sums regArray_0.._3 into regArray_0
 * and regArray_4.._7 into regArray_4.
 */
#define FP32_ACCUM2_8x8(regArray) \
    regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
    regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
    regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
    regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
    regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
    regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
/* Store 16 fp32 values of (alpha * result + beta * y) to y.
 * NOTE(review): ALPHAVECTOR / BETAVECTOR are assumed to be __m512
 * values (presumably broadcasts of alpha/beta) declared by the
 * including kernel — confirm at the call site. regResult is clobbered.
 */
#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store 16 (alpha * result + beta * y) to y: only the lanes
 * selected by mask are loaded (others read as zero) and stored.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 (alpha * result + beta * y) to y; alpha/beta vectors are
 * narrowed to ymm via a lane cast (no data movement).
 */
#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store 8 (alpha * result + beta * y) to y.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 (alpha * result + beta * y) to y (xmm width).
 */
#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store 4 (alpha * result + beta * y) to y.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 16 fp32 values of (alpha * result + y) to y — the beta == 1
 * specialization, saving the beta multiply.
 * NOTE(review): ALPHAVECTOR is assumed to be an __m512 declared by the
 * including kernel. regResult is clobbered.
 */
#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store 16 (alpha * result + y) to y.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 (alpha * result + y) to y.
 */
#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store 8 (alpha * result + y) to y.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 (alpha * result + y) to y.
 */
#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store 4 (alpha * result + y) to y.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
    regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 16 fp32 values of (result + y) to y — the alpha == 1,
 * beta == 1 specialization (plain accumulate into y).
 * regResult is clobbered.
 */
#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store 16 (result + y) to y.
 */
#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 (result + y) to y.
 */
#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store 8 (result + y) to y.
 */
#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 (result + y) to y.
 */
#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
    regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store 4 (result + y) to y.
 */
#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
    regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);
/* Store 16 fp32 values of (alpha * result) to y — the beta == 0
 * specialization (y is overwritten, not read).
 * NOTE(review): ALPHAVECTOR is assumed to be an __m512 declared by the
 * including kernel. regResult is NOT modified by this family.
 */
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));

/* Masked store 16 (alpha * result) to y.
 */
#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));

/* Store 8 (alpha * result) to y.
 */
#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));

/* Masked store 8 (alpha * result) to y.
 */
#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));

/* Store 4 (alpha * result) to y.
 */
#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
    _mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));

/* Masked store 4 (alpha * result) to y.
 */
#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
    _mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
/* Store 16 fp32 results directly to y — the alpha == 1, beta == 0
 * specialization (plain overwrite, no arithmetic).
 */
#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm512_storeu_ps(targetAddr, regResult);

/* Masked store 16 results to y.
 */
#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm512_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 8 results to y.
 */
#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm256_storeu_ps(targetAddr, regResult);

/* Masked store 8 results to y.
 */
#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm256_mask_storeu_ps(targetAddr, mask, regResult);

/* Store 4 results to y.
 */
#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
    _mm_storeu_ps(targetAddr, regResult);

/* Masked store 4 results to y.
 */
#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
    _mm_mask_storeu_ps(targetAddr, mask, regResult);
  691. #endif