/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
- #ifndef __BF16_COMMON_MACROS
- #define __BF16_COMMON_MACROS
-
- #include <immintrin.h>
-
/* Broadcast the 32-bit value at 'addr' (a pair of adjacent BF16 elements)
   into every 32-bit lane of the 512-bit register 'zmm'.
   The "m" input operand tells the compiler that *addr is read by the asm,
   so the load cannot be reordered across earlier stores to that location
   (the previous "r"-only form gave the asm no memory dependency at all). */
#define _MM512_BROADCASTD_EPI32(addr, zmm)                \
            __asm__ ("vpbroadcastd %1, %0"                \
                    : "=v" (zmm)                          \
                    : "m" (*(const unsigned int *)(addr)))
-
/* Prefetch the cache line containing 'addr' into all cache levels (T0 hint).
   Uses the portable _mm_prefetch intrinsic instead of raw inline asm so the
   compiler can schedule the hint and track the address expression. */
#define PREFETCH_T0(addr) \
            _mm_prefetch((const char *)(addr), _MM_HINT_T0)
-
/* Reinterpret the low 256 bits of two 512-bit float registers
   (reg512##_0/_1) as 256-bit registers (reg256##_0/_1).
   This is a pure cast: no data movement is generated. */
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512)    \
    reg256##_0 = _mm512_castps512_ps256(reg512##_0);   \
    reg256##_1 = _mm512_castps512_ps256(reg512##_1);
-
-
/* Unmasked tile loads from the row-major BF16 matrix 'a' with leading
   dimension 'lda', starting at element (idx_m, idx_n).
   Row r of the tile goes into regArray##_r.  All loads are unaligned. */

/* Load an 8-row x 32-element tile (one 512-bit load per row). */
#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n)        \
    regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]);    \
    regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]);    \
    regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]);    \
    regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]);    \
    regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]);    \
    regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]);    \
    regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]);    \
    regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]);


/* Load an 8-row x 16-element tile (one 256-bit load per row). */
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n)                    \
    regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n]));   \
    regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n]));   \
    regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n]));   \
    regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n]));   \
    regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n]));   \
    regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n]));   \
    regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n]));   \
    regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n]));


/* Load an 8-row x 8-element tile (one 128-bit load per row). */
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n)                  \
    regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n]));   \
    regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n]));   \
    regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n]));   \
    regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n]));   \
    regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n]));   \
    regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n]));   \
    regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n]));   \
    regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n]));


/* Load a single row of 32 BF16 elements into 'regArray'. */
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n)    \
    regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]);
-
-
/* Masked (zero-filling) tile loads from the row-major BF16 matrix 'a'.
   'mask' selects which 16-bit elements of each row are loaded; unselected
   lanes are zeroed, which makes the loads safe at the right/bottom edges
   of the matrix.  Row r of the tile goes into regArray##_r. */

/* Masked load of an 8-row x 32-element tile. */
#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask)     \
    regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
    regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
    regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
    regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
    regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
    regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);


/* Masked load of an 8-row x 16-element tile. */
#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask)     \
    regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
    regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
    regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
    regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
    regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
    regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);


/* Masked load of an 8-row x 8-element tile. */
#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask)   \
    regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
    regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \
    regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
    regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \
    regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \
    regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]);


/* Masked load of a 4-row x 32-element tile. */
#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask)     \
    regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
    regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);


/* Masked load of a 4-row x 16-element tile. */
#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask)     \
    regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \
    regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]);


/* Masked load of 8 registers taken from every SECOND row (stride 2*lda);
   the mask is presumably one that selects two consecutive rows per load --
   confirm against the callers. */
#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask)    \
    regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]);  \
    regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]);  \
    regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]);  \
    regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);  \
    regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]);  \
    regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \
    regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \
    regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]);


/* Masked load of 4 registers taken from every second row (stride 2*lda). */
#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask)   \
    regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \
    regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \
    regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \
    regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]);

/* Masked load of a single row of up to 32 BF16 elements. */
#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask)     \
    regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]);
-
/* Vector loads from the BF16 vector 'x' starting at element idx_n.
   The unmasked variants load 32/16/8 elements; the MASKZ variants zero the
   lanes not selected by 'mask' (safe for the vector tail). */

#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n)    \
    reg = _mm512_loadu_si512(x + idx_n);


#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n)    \
    reg = _mm256_loadu_si256((__m256i *)(x + idx_n));


#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n)     \
    reg = _mm_loadu_si128((__m128i *)(x + idx_n));


#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask)    \
    reg = _mm512_maskz_loadu_epi16(mask, x + idx_n);


#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask)    \
    reg = _mm256_maskz_loadu_epi16(mask, x + idx_n);


#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask)     \
    reg = _mm_maskz_loadu_epi16(mask, x + idx_n);
-
-
/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row
   Input  - register array of 8 rows of a row-major matrix (regArray##_0..7);
            regArray##_8..15 are used as scratch for the intermediate step
   Output - regArray##_0..7 hold the output of Step 2

   Step 1: 2-element interleave for matrix
   |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
   |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
   |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27
   |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27
   |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
   |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31
   |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31
   |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31

   Step 2: 4-element interleave for matrix
   |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
   |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
   |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25
   |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27
   |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
   |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
   |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29
   |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31
*/
#define BF16_INTERLEAVE_8x32(regArray)                                   \
    regArray##_8  = _mm512_unpacklo_epi32(regArray##_0,  regArray##_1);  \
    regArray##_9  = _mm512_unpacklo_epi32(regArray##_2,  regArray##_3);  \
    regArray##_10 = _mm512_unpacklo_epi32(regArray##_4,  regArray##_5);  \
    regArray##_11 = _mm512_unpacklo_epi32(regArray##_6,  regArray##_7);  \
    regArray##_12 = _mm512_unpackhi_epi32(regArray##_0,  regArray##_1);  \
    regArray##_13 = _mm512_unpackhi_epi32(regArray##_2,  regArray##_3);  \
    regArray##_14 = _mm512_unpackhi_epi32(regArray##_4,  regArray##_5);  \
    regArray##_15 = _mm512_unpackhi_epi32(regArray##_6,  regArray##_7);  \
                                                                         \
    regArray##_0  = _mm512_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1  = _mm512_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2  = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3  = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4  = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5  = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6  = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7  = _mm512_unpackhi_epi64(regArray##_14, regArray##_15);
-
-
/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row
   Input  - register array of 8 rows of a row-major matrix (regArray##_0..7);
            regArray##_8..15 are used as scratch for the intermediate step
   Output - regArray##_0..7 hold the output of Step 2

   Step 1: 2-element interleave for matrix
   |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
   |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
   |e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11
   |g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11
   |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
   |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15
   |e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15
   |g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15

   Step 2: 4-element interleave for matrix
   |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
   |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
   |e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9
   |e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11
   |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
   |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
   |e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13
   |e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15
*/
#define BF16_INTERLEAVE_8x16(regArray)                                   \
    regArray##_8  = _mm256_unpacklo_epi32(regArray##_0,  regArray##_1);  \
    regArray##_9  = _mm256_unpacklo_epi32(regArray##_2,  regArray##_3);  \
    regArray##_10 = _mm256_unpacklo_epi32(regArray##_4,  regArray##_5);  \
    regArray##_11 = _mm256_unpacklo_epi32(regArray##_6,  regArray##_7);  \
    regArray##_12 = _mm256_unpackhi_epi32(regArray##_0,  regArray##_1);  \
    regArray##_13 = _mm256_unpackhi_epi32(regArray##_2,  regArray##_3);  \
    regArray##_14 = _mm256_unpackhi_epi32(regArray##_4,  regArray##_5);  \
    regArray##_15 = _mm256_unpackhi_epi32(regArray##_6,  regArray##_7);  \
                                                                         \
    regArray##_0  = _mm256_unpacklo_epi64(regArray##_8,  regArray##_9);  \
    regArray##_1  = _mm256_unpackhi_epi64(regArray##_8,  regArray##_9);  \
    regArray##_2  = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \
    regArray##_3  = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \
    regArray##_4  = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \
    regArray##_5  = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \
    regArray##_6  = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \
    regArray##_7  = _mm256_unpackhi_epi64(regArray##_14, regArray##_15);
-
/* 2-step interleave for matrix against 4 rows with 32 BF16 elements per row
   Input  - register array of 4 rows of a row-major matrix (regArray##_0..3);
            regArray##_4..7 are used as scratch for the intermediate step
   Output - regArray##_0..3 hold the output of Step 2

   Step 1: 2-element interleave for matrix
   |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27
   |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27
   |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31
   |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31

   Step 2: 4-element interleave for matrix
   |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25
   |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27
   |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29
   |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31
*/
#define BF16_INTERLEAVE_4x32(regArray)                                 \
    regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1);  \
    regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3);  \
    regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1);  \
    regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3);  \
                                                                       \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5);  \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5);  \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7);  \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7);
-
-
/* 2-step interleave for matrix against 4 rows with 16 BF16 elements per row
   Input  - register array of 4 rows of a row-major matrix (regArray##_0..3);
            regArray##_4..7 are used as scratch for the intermediate step
   Output - regArray##_0..3 hold the output of Step 2

   Step 1: 2-element interleave for matrix
   |a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11
   |c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11
   |a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15
   |c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15

   Step 2: 4-element interleave for matrix
   |a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9
   |a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11
   |a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13
   |a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15
*/
#define BF16_INTERLEAVE_4x16(regArray)                                 \
    regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1);  \
    regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3);  \
    regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1);  \
    regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3);  \
                                                                       \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5);  \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5);  \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7);  \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7);
-
-
/* 2-step interleave (self-duplication) for vector x with 32 BF16 elements
   Input  - regArray##_0 holds the original vector; regArray##_1..3 scratch
   Output - regArray##_0..3 hold the output of Step 2

   Step 1: 2-element interleave for x:
   |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27
   |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31

   Step 2: 4-element interleave for x:
   |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25
   |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27
   |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29
   |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31
*/
#define BF16_INTERLEAVE_1x32(regArray)                                 \
    regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0);  \
    regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0);  \
                                                                       \
    regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1);  \
    regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1);  \
    regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3);  \
    regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3);
-
-
/* 2-step interleave (self-duplication) for vector x with 16 BF16 elements
   Input  - regArray##_0 holds the original vector; regArray##_1..3 scratch
   Output - regArray##_0..3 hold the output of Step 2

   Step 1: 2-element interleave for x:
   |x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11
   |x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15

   Step 2: 4-element interleave for x:
   |x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9
   |x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11
   |x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13
   |x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15
*/
#define BF16_INTERLEAVE_1x16(regArray)                                 \
    regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0);  \
    regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0);  \
                                                                       \
    regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1);  \
    regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1);  \
    regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3);  \
    regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3);
-
/* 1-step interleave to combine the low 256 bits (imm 0x44) and high 256 bits
   (imm 0xee) of 4 pairs of registers (regArray##_8..15), producing
   regArray##_0..7:
   |a0|a1|...|a14|a15|i0|i1|...|i14|i15|
   |b0|b1|...|b14|b15|j0|j1|...|j14|j15|
   |c0|c1|...|c14|c15|k0|k1|...|k14|k15|
   |d0|d1|...|d14|d15|l0|l1|...|l14|l15|
   |e0|e1|...|e14|e15|m0|m1|...|m14|m15|
   |f0|f1|...|f14|f15|n0|n1|...|n14|n15|
   |g0|g1|...|g14|g15|o0|o1|...|o14|o15|
   |h0|h1|...|h14|h15|p0|p1|...|p14|p15|
*/
#define BF16_INTERLEAVE256_8x32(regArray)                                       \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0x44);    \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_8,  regArray##_12, 0xee);    \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0x44);    \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_9,  regArray##_13, 0xee);    \
    regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44);    \
    regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee);    \
    regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44);    \
    regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee);
-
-
/* 1-step interleave to combine the low 256 bits (imm 0x44) and high 256 bits
   (imm 0xee) of 2 pairs of registers (regArray##_4..7), producing
   regArray##_0..3:
   |a0|a1|...|a14|a15|e0|e1|...|e14|e15|
   |b0|b1|...|b14|b15|f0|f1|...|f14|f15|
   |c0|c1|...|c14|c15|g0|g1|...|g14|g15|
   |d0|d1|...|d14|d15|h0|h1|...|h14|h15|
*/
#define BF16_INTERLEAVE256_4x32(regArray)                                   \
    regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44);  \
    regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee);  \
    regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44);  \
    regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee);
-
-
/* Permute families: apply the permutation-control vector 'idx' to each of
   the input registers, writing the results into the scratch registers.
   The plain variants permute 16-bit lanes (_epi16); the _2 variants permute
   32-bit lanes (_epi32). */

/* Permute 8 registers regArray##_0..7 into regArray##_8..15 (16-bit lanes). */
#define BF16_PERMUTE_8x32(idx, regArray)                          \
    regArray##_8  = _mm512_permutexvar_epi16(idx, regArray##_0);  \
    regArray##_9  = _mm512_permutexvar_epi16(idx, regArray##_1);  \
    regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2);  \
    regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3);  \
    regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4);  \
    regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5);  \
    regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6);  \
    regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7);


/* Permute 8 registers regArray##_0..7 into regArray##_8..15 (32-bit lanes). */
#define BF16_PERMUTE_8x32_2(idx, regArray)                        \
    regArray##_8  = _mm512_permutexvar_epi32(idx, regArray##_0);  \
    regArray##_9  = _mm512_permutexvar_epi32(idx, regArray##_1);  \
    regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2);  \
    regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3);  \
    regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4);  \
    regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5);  \
    regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6);  \
    regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7);


/* Permute 4 registers regArray##_0..3 into regArray##_4..7 (16-bit lanes). */
#define BF16_PERMUTE_4x32(idx, regArray)                          \
    regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0);   \
    regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1);   \
    regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2);   \
    regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3);


/* Permute 4 registers regArray##_0..3 into regArray##_4..7 (32-bit lanes). */
#define BF16_PERMUTE_4x32_2(idx, regArray)                        \
    regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0);   \
    regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1);   \
    regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2);   \
    regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3);
-
-
/* Dot-product accumulation for 2-step interleaved matrix rows and a
   correspondingly interleaved vector.  Each _mm512_dpbf16_ps multiplies
   BF16 pairs and accumulates into fp32 lanes of the accumulator.
   Statements alternate between the two accumulators to expose instruction
   level parallelism (assumes dpbf16 throughput ~0.5; tunable per platform). */
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray)                                 \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3);


/* 256-bit variant of BF16_2STEP_INTERLEAVED_DOT_8x32. */
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray)                                 \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3);

/* 4-row variant: matArray##_0..3 against xArray##_0..3 (512-bit). */
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray)                                 \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3);


/* 4-row variant (256-bit). */
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray)                                 \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3);
-
-
/* Plain (non-interleaved) dot-product accumulation: each matrix row register
   is multiplied against the same vector register 'xArray' and accumulated
   into the per-row fp32 accumulator.
   (Assumes dpbf16 throughput ~0.5; tunable per platform.) */

/* 8 rows x 32 BF16 elements per row (512-bit). */
#define BF16_DOT_8x32(accumArray, matArray, xArray)                                             \
    accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \
    accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \
    accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \
    accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \
    accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \
    accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \
    accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \
    accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray);

/* 1 row x 32 BF16 elements (512-bit). */
#define BF16_DOT_1x32(accumArray, matArray, xArray)                                     \
    accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray);

/* 8 rows x 16 BF16 elements per row (256-bit). */
#define BF16_DOT_8x16(accumArray, matArray, xArray)                                             \
    accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \
    accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \
    accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \
    accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \
    accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \
    accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \
    accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \
    accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray);
-
-
/* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row
   Input  - register array of 8 rows of a row-major matrix (regArray##_0..7);
            regArray##_8..15 are used as scratch for the intermediate step
   Output - regArray##_0..7 hold the output of Step 2

   Step 1: 2-element interleave for matrix
   |a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13|
   |c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13|
   |e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13|
   |g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13|
   |a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15|
   |c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15|
   |e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15|
   |g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15|

   Step 2: 4-element interleave for matrix
   |a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12|
   |a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13|
   |e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12|
   |e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13|
   |a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14|
   |a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15|
   |e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14|
   |e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15|
*/
#define FP32_INTERLEAVE_8x16(regArray)                                                                \
    regArray##_8  = _mm512_unpacklo_ps(regArray##_0, regArray##_1);                                   \
    regArray##_9  = _mm512_unpacklo_ps(regArray##_2, regArray##_3);                                   \
    regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5);                                   \
    regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7);                                   \
    regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1);                                   \
    regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3);                                   \
    regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5);                                   \
    regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7);                                   \
                                                                                                      \
    regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8,  (__m512d) regArray##_9);      \
    regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8,  (__m512d) regArray##_9);      \
    regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11);     \
    regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11);     \
    regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13);     \
    regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13);     \
    regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15);     \
    regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15);
-
- /* Array-indexed variant of FP32_INTERLEAVE_8x16: performs the same
-    2-step interleave on regArray[0..7] (inputs/outputs), using
-    regArray[8..15] as scratch. `regArray` must be an array of at least
-    16 __m512. */
- #define FP32_INTERLEAVE_8x16_ARRAY(regArray) \
- regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \
- regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \
- regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \
- regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \
- regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \
- regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \
- regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \
- regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \
- \
- regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
- regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \
- regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
- regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \
- regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
- regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \
- regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \
- regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]);
-
- /* 2-step interleave for matrix against 8 rows with 8 fp32 elements per row
- Input - register array of 8 rows of a row-major matrix
- Output - the output of Step 2
-
- Step 1: 2-element interleave for matrix
- |a0|b0|a1|b1|a4|b4|a5|b5|
- |c0|d0|c1|d1|c4|d4|c5|d5|
- |e0|f0|e1|f1|e4|f4|e5|f5|
- |g0|h0|g1|h1|g4|h4|g5|h5|
- |a2|b2|a3|b3|a6|b6|a7|b7|
- |c2|d2|c3|d3|c6|d6|c7|d7|
- |e2|f2|e3|f3|e6|f6|e7|f7|
- |g2|h2|g3|h3|g6|h6|g7|h7|
-
- Step 2: 4-element interleave for matrix
- |a0|b0|c0|d0|a4|b4|c4|d4|
- |a1|b1|c1|d1|a5|b5|c5|d5|
- |e0|f0|g0|h0|e4|f4|g4|h4|
- |e1|f1|g1|h1|e5|f5|g5|h5|
- |a2|b2|c2|d2|a6|b6|c6|d6|
- |a3|b3|c3|d3|a7|b7|c7|d7|
- |e2|f2|g2|h2|e6|f6|g6|h6|
- |e3|f3|g3|h3|e7|f7|g7|h7|
- */
- /* 256-bit (ymm) variant of FP32_INTERLEAVE_8x16 for 8 rows of 8 fp32
-    each: _0.._7 are inputs/outputs, _8.._15 are clobbered as scratch.
-    See the layout diagram above for the output row ordering. */
- #define FP32_INTERLEAVE_8x8(regArray) \
- regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \
- regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \
- regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \
- regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \
- regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \
- regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \
- regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \
- regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \
- \
- regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
- regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \
- regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
- regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \
- regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
- regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \
- regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \
- regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15);
-
-
- /* Accumulate 8 zmm registers as two independent batches of 4:
-    the sum of rows 0..3 lands in regArray##_0 and the sum of rows
-    4..7 in regArray##_4; _1.._3 and _5.._7 are consumed. */
- #define FP32_ACCUM2_8x16(regArray) \
- regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \
- regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \
- regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \
- regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \
- regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \
- regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6);
-
- /* Array-indexed variant of FP32_ACCUM2_8x16: the two batch sums land
-    in regArray[0] and regArray[4]. */
- #define FP32_ACCUM2_8x16_ARRAY(regArray) \
- regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \
- regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \
- regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \
- regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \
- regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \
- regArray[4] = _mm512_add_ps(regArray[4], regArray[6]);
-
- /* 256-bit (ymm) variant of FP32_ACCUM2_8x16: two independent batch
-    sums land in regArray##_0 and regArray##_4. */
- #define FP32_ACCUM2_8x8(regArray) \
- regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \
- regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \
- regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \
- regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \
- regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \
- regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6);
-
-
- /* Store 16 (alpha * result + beta * y) to y.
-    ALPHAVECTOR / BETAVECTOR are __m512 registers that the including
-    kernel must define before using these macros; regResult is
-    overwritten with the blended value. Loads/stores are unaligned. */
- #define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (alpha * result + beta * y) to y.
-    maskz load zeroes inactive lanes (so no out-of-bounds read is
-    consumed) and the mask store leaves the corresponding memory
-    untouched. */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (alpha * result + beta * y) to y; uses the low 256 bits of
-    ALPHAVECTOR / BETAVECTOR via a (free) cast. */
- #define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (alpha * result + beta * y) to y.
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (alpha * result + beta * y) to y; uses the low 128 bits of
-    ALPHAVECTOR / BETAVECTOR via a (free) cast. */
- #define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (alpha * result + beta * y) to y.
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (alpha * result + y) to y — the beta == 1 fast path, which
-    skips the beta multiply. ALPHAVECTOR is an __m512 the including
-    kernel must define; regResult is overwritten. */
- #define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (alpha * result + y) to y; inactive lanes are read
-    as zero and never written back. */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (alpha * result + y) to y; low 256 bits of ALPHAVECTOR.
- */
- #define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (alpha * result + y) to y.
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (alpha * result + y) to y; low 128 bits of ALPHAVECTOR.
- */
- #define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (alpha * result + y) to y.
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \
- regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (result + y) to y — the alpha == 1 && beta == 1 fast path:
-    a plain accumulate into memory. regResult is overwritten. */
- #define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 (result + y) to y; inactive lanes are read as zero
-    and never written back. */
- #define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 (result + y) to y.
- */
- #define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 (result + y) to y.
- */
- #define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 (result + y) to y.
- */
- #define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \
- regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 (result + y) to y.
- */
- #define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \
- regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 16 (alpha * result) to y — the beta == 0 path: y is
-    overwritten without being read, so it may hold garbage (including
-    NaN) beforehand. ALPHAVECTOR is an __m512 the including kernel must
-    define; regResult is NOT modified by these macros. */
- #define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult));
-
-
- /* Masked store 16 (alpha * result) to y; lanes outside `mask` are left
-    untouched in memory. */
- #define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult));
-
-
- /* Store 8 (alpha * result) to y; low 256 bits of ALPHAVECTOR.
- */
- #define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
-
-
- /* Masked store 8 (alpha * result) to y.
- */
- #define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult));
-
-
- /* Store 4 (alpha * result) to y; low 128 bits of ALPHAVECTOR.
- */
- #define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \
- _mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
-
-
- /* Masked store 4 (alpha * result) to y.
- */
- #define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \
- _mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult));
-
-
- /* Store 16 result to y — the alpha == 1 && beta == 0 path: a plain
-    unaligned store, no scaling and no read of y. regResult is NOT
-    modified by these macros. */
- #define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm512_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 16 result to y; lanes outside `mask` are left untouched
-    in memory. */
- #define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm512_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 8 result to y.
- */
- #define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm256_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 8 result to y.
- */
- #define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm256_mask_storeu_ps(targetAddr, mask, regResult);
-
-
- /* Store 4 result to y.
- */
- #define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \
- _mm_storeu_ps(targetAddr, regResult);
-
-
- /* Masked store 4 result to y.
- */
- #define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \
- _mm_mask_storeu_ps(targetAddr, mask, regResult);
-
- #endif
|