|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847 |
- /***************************************************************************
- Copyright (c) 2014, The OpenBLAS Project
- All rights reserved.
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions are
- met:
- 1. Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- 2. Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- 3. Neither the name of the OpenBLAS project nor the names of
- its contributors may be used to endorse or promote products
- derived from this software without specific prior written permission.
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
- USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *****************************************************************************/
- #ifndef __BF16_COMMON_MACROS
- #define __BF16_COMMON_MACROS
-
- #include <immintrin.h>
-
- #define _MM512_BROADCASTD_EPI32(addr, zmm) \
- __asm__ ("vpbroadcastd (%1), %0;" \
- : "=v" (zmm) \
- : "r" (addr) )
-
- #define PREFETCH_T0(addr) \
- __asm__ ("prefetcht0 (%0);" \
- : \
- : "r" (addr) )
-
- #define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \
- reg256##_0 = _mm512_castps512_ps256(reg512##_0); \
- reg256##_1 = _mm512_castps512_ps256(reg512##_1);
-
-
- #define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \
- regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]); \
- regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]); \
- regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]); \
- regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]); \
- regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]);
-
-
- #define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
- regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \
- regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \
- regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \
- regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \
- regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \
- regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \
- regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \
- regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n]));
-
-
- #define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
- regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \
- regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \
- regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \
- regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \
- regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \
- regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \
- regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \
- regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n]));
-
-
- #define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
- regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
- regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
- regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
- regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
- regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
- regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
- regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
- regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
- regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
- regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
- regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
- regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
- regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
- regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
- regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
- regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]); \
- regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \
- regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \
- regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]);
-
-
- #define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \
- regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
- regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
- regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
- regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);
-
- #define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \
- regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]);
-
- #define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \
- reg = _mm512_loadu_si512(x + idx_n);
-
-
- #define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
- reg = _mm256_loadu_si256((__m256i *)(x + idx_n));
-
-
- #define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
- reg = _mm_loadu_si128((__m128i *)(x + idx_n));
-
-
- #define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \
- reg = _mm512_maskz_loadu_epi16(mask, x + idx_n);
-
-
- #define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \
- reg = _mm256_maskz_loadu_epi16(mask, x + idx_n);
-
-
- #define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \
- reg = _mm_maskz_loadu_epi16(mask, x + idx_n);
-
-
- /* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
- |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
- |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
- |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
- |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
- |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
- |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
- |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31
-
- Step 2: 4-element interleave for matrix
- |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
- |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
- |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
- |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
- |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
- |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
- |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
- |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
- */
- #define BF16_INTERLEAVE_8x32(regArray) \
- regArray##_8 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
- regArray##_9 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
- regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \
- regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \
- regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
- regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
- regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \
- regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \
- \
- regArray##_0 = _mm512_unpacklo_epi64(regArray##_8, regArray##_9); \
- regArray##_1 = _mm512_unpackhi_epi64(regArray##_8, regArray##_9); \
- regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
- regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
- regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
- regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
- regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
- regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
-
-
- /* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
- |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
- |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
- |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
- |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
- |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
- |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
- |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15
-
- Step 2: 4-element interleave for matrix
- |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
- |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
- |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
- |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
- |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
- |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
- |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
- |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
- */
- #define BF16_INTERLEAVE_8x16(regArray) \
- regArray##_8 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
- regArray##_9 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
- regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \
- regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \
- regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
- regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
- regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \
- regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \
- \
- regArray##_0 = _mm256_unpacklo_epi64(regArray##_8, regArray##_9); \
- regArray##_1 = _mm256_unpackhi_epi64(regArray##_8, regArray##_9); \
- regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
- regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
- regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
- regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
- regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
- regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
-
- /* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
- |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
- |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
- |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
-
- Step 2: 4-element interleave for matrix
- |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
- |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
- |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
- |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
- */
- #define BF16_INTERLEAVE_4x32(regArray) \
- regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \
- regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \
- regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \
- regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \
- \
- regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \
- regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \
- regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \
- regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
-
-
- /* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
- |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
- |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
- |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
-
- Step 2: 4-element interleave for matrix
- |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
- |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
- |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
- |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
- */
- #define BF16_INTERLEAVE_4x16(regArray) \
- regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \
- regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \
- regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \
- regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \
- \
- regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \
- regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \
- regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \
- regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
-
-
- /* 2-step interleave for x with 32 BF16 elements
- Input - original vector
- Output - the output of Step 2
-
- Step 1: 2-element interleave for x:
- |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
- |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31
-
- Step 2: 4-element interleave for x:
- |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
- |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
- |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
- |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
- */
- #define BF16_INTERLEAVE_1x32(regArray) \
- regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \
- regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \
- \
- regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \
- regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \
- regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \
- regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
-
-
- /* 2-step interleave for x with 16 BF16 elements
- Input - original vector
- Output - the output of Step 2
-
- Step 1: 2-element interleave for x:
- |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
- |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15
-
- Step 2: 4-element interleave for x:
- |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
- |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
- |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
- |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
- */
- #define BF16_INTERLEAVE_1x16(regArray) \
- regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \
- regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \
- \
- regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \
- regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \
- regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \
- regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
-
- /* 1-step interleave to exchange the high-256s bit and low-256 bits of 4 pair of registers
- |a0|a1|...|a14|a15|i0|i1|...|i14|i15|
- |b0|b1|...|b14|b15|j0|j1|...|j14|j15|
- |c0|c1|...|c14|c15|k0|k1|...|k14|k15|
- |d0|d1|...|d14|d15|l0|l1|...|l14|l15|
- |e0|e1|...|e14|e15|m0|m1|...|m14|m15|
- |f0|f1|...|f14|f15|n0|n1|...|n14|n15|
- |g0|g1|...|g14|g15|o0|o1|...|o14|o15|
- |h0|h1|...|h14|h15|p0|p1|...|p14|p15|
- */
- #define BF16_INTERLEAVE256_8x32(regArray) \
- regArray##_0 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0x44); \
- regArray##_1 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0xee); \
- regArray##_2 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0x44); \
- regArray##_3 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0xee); \
- regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \
- regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \
- regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \
- regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
-
-
- /* 1-step interleave to exchange the high-256s bit and low-256 bits of 2 pair of registers
- |a0|a1|...|a14|a15|e0|e1|...|e14|e15|
- |b0|b1|...|b14|b15|f0|f1|...|f14|f15|
- |c0|c1|...|c14|c15|g0|g1|...|g14|g15|
- |d0|d1|...|d14|d15|h0|h1|...|h14|h15|
- */
- #define BF16_INTERLEAVE256_4x32(regArray) \
- regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \
- regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \
- regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \
- regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
-
-
- #define BF16_PERMUTE_8x32(idx, regArray) \
- regArray##_8 = _mm512_permutexvar_epi16(idx, regArray##_0); \
- regArray##_9 = _mm512_permutexvar_epi16(idx, regArray##_1); \
- regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \
- regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \
- regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \
- regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \
- regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \
- regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);
-
-
- #define BF16_PERMUTE_8x32_2(idx, regArray) \
- regArray##_8 = _mm512_permutexvar_epi32(idx, regArray##_0); \
- regArray##_9 = _mm512_permutexvar_epi32(idx, regArray##_1); \
- regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \
- regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \
- regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \
- regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \
- regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \
- regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);
-
-
- #define BF16_PERMUTE_4x32(idx, regArray) \
- regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \
- regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \
- regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \
- regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);
-
-
- #define BF16_PERMUTE_4x32_2(idx, regArray) \
- regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \
- regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \
- regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \
- regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
-
-
- /* Calculate the dot result for 2-step interleaved matrix and vector
- (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);
-
-
- /* Calculate the dot result for 2-step interleaved matrix and vector
- (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);
-
- /* Calculate the dot result for 2-step interleaved matrix and vector
- (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);
-
-
- /* Calculate the dot result for 2-step interleaved matrix and vector
- (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
-
-
- /* Calculate the dot result for matrix and vector at 32 elements per row
- (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_DOT_8x32(accumArray, matArray, xArray) \
- accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
- accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
- accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
- accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
- accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
- accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
- accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
- accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);
-
- /* Calculate the dot result for matrix and vector at 32 elements per row
- (Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_DOT_1x32(accumArray, matArray, xArray) \
- accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray);
-
- /* Calculate the dot result for matrix and vector at 16 elements per row
- (Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform)
- */
- #define BF16_DOT_8x16(accumArray, matArray, xArray) \
- accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
- accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
- accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
- accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
- accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
- accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
- accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
- accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
-
-
- /* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
- |c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
- |e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
- |g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
- |a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
- |c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
- |e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
- |g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|
-
- Step 2: 4-element interleave for matrix
- |a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
- |a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
- |e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
- |e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
- |a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
- |a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
- |e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
- |e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
- */
- #define FP32_INTERLEAVE_8x16(regArray) \
- regArray##_8 = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \
- regArray##_9 = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \
- regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \
- regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \
- regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \
- regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \
- regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \
- regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \
- \
- regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
- regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \
- regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
- regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \
- regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
- regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \
- regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \
- regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
-
- #define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
- regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
- regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
- regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
- regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
- regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
- regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
- regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
- regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
- \
- regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
- regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
- regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
- regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
- regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
- regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
- regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
- regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
-
- /* 2-step interleave for matrix against 8 rows with 8 fp32 elements per row
- Input - register array of 8 rows of raw-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|b0|a1|b1|a4|b4|a5|b5|
- |c0|d0|c1|d1|c4|d4|c5|d5|
- |e0|f0|e1|f1|e4|f4|e5|f5|
- |g0|h0|g1|h1|g4|h4|g5|h5|
- |a2|b2|a3|b3|a6|b6|a7|b7|
- |c2|d2|c3|d3|c6|d6|c7|d7|
- |e2|f2|e3|f3|e6|f6|e7|f7|
- |g2|h2|g3|h3|g6|h6|g7|h7|
-
- Step 2: 4-element interleave for matrix
- |a0|b0|c0|d0|a4|b4|c4|d4|
- |a1|b1|c1|d1|a5|b5|c5|d5|
- |e0|f0|g0|h0|e4|f4|g4|h4|
- |e1|f1|g1|h1|e5|f5|g5|h5|
- |a2|b2|c2|d2|a6|b6|c6|d6|
- |a3|b3|c3|d3|a7|b7|c7|d7|
- |e2|f2|g2|h2|e6|f6|g6|h6|
- |e3|f3|g3|h3|e7|f7|g7|h7|
- */
- #define FP32_INTERLEAVE_8x8(regArray) \
- regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
- regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
- regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
- regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
- regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
- regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
- regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
- regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
- \
- regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
- regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
- regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
- regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
- regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
- regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
- regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
- regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
-
-
- /* Accumulate the result for 2 batch of 4-registers
- */
- #define FP32_ACCUM2_8x16(regArray) \
- regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
- regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
- regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
- regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
- regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
- regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
-
- #define FP32_ACCUM2_8x16_ARRAY(regArray) \
- regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
- regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
- regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
- regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
- regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
- regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
-
- /* Accumulate the result for 2 batch of 4-registers
- */
- #define FP32_ACCUM2_8x8(regArray) \
- regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
- regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
- regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
- regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
- regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
- regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
-
-
- /* Store 16 (alpha * result + beta * y) to y
- */
- #define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (alpha * result + beta * y) to y
- */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (alpha * result + beta * y) to y
- */
- #define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (alpha * result + beta * y) to y
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (alpha * result + beta * y) to y
- */
- #define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (alpha * result + beta * y) to y
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (alpha * result + y) to y
- */
- #define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (alpha * result + y) to y
- */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (alpha * result + y) to y
- */
- #define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (alpha * result + y) to y
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (alpha * result + y) to y
- */
- #define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (alpha * result + y) to y
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (result + y) to y
- */
- #define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (result + y) to y
- */
- #define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (result + y) to y
- */
- #define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (result + y) to y
- */
- #define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (result + y) to y
- */
- #define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (result + y) to y
- */
- #define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (alpha * result) to y
- */
- #define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));
-
-
- /* Masked store 16 (alpha * result) to y
- */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));
-
-
- /* Store 8 (alpha * result) to y
- */
- #define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
-
-
- /* Masked store 8 (alpha * result) to y
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
-
-
- /* Store 4 (alpha * result) to y
- */
- #define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
-
-
- /* Masked store 4 (alpha * result) to y
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
-
-
- /* Store 16 result to y
- */
- #define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 result to y
- */
- #define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 result to y
- */
- #define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 result to y
- */
- #define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 result to y
- */
- #define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 result to y
- */
- #define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
- #endif
|