|
- /***************************************************************************
- * Copyright (c) 2021, The OpenBLAS Project
- * All rights reserved.
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are
- * met:
- * 1. Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * 2. Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in
- * the documentation and/or other materials provided with the
- * distribution.
- * 3. Neither the name of the OpenBLAS project nor the names of
- * its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
- * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
- * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- * *****************************************************************************/
-
- #include <immintrin.h>
- #include <string.h>
- #include "common.h"
-
- #ifndef SBGEMM_KERNEL_SPR
- #define SBGEMM_KERNEL_SPR
- typedef struct {
- char palette_id;
- char start_row;
- char dummy0[14]; // bytes 2-15 reserved, must be zero
- short tile_colsb[8];
- char dummy1[16]; // bytes 32-47 reserved, must be zero
- char tile_rows[8];
- char dummy2[16]; // bytes 56-63 reserved, must be zero
- } tilecfg;
-
- /* tile0/tile1 -- A (m x 2k)
- * tile2/tile3 -- B (2k x n)
- * tile4-7 -- C (m x n)
- */
- #define TCONF(cfg, m, n, k2) \
- memset(&cfg, 0, sizeof(tilecfg)); \
- cfg.palette_id = 1; \
- cfg.tile_rows[0] = m; \
- cfg.tile_rows[1] = m; \
- cfg.tile_rows[2] = k2>>1; \
- cfg.tile_rows[3] = k2>>1; \
- cfg.tile_rows[4] = m; \
- cfg.tile_rows[5] = m; \
- cfg.tile_rows[6] = m; \
- cfg.tile_rows[7] = m; \
- cfg.tile_colsb[0] = k2<<1; \
- cfg.tile_colsb[1] = k2<<1; \
- cfg.tile_colsb[2] = n * 4; \
- cfg.tile_colsb[3] = n * 4; \
- cfg.tile_colsb[4] = n * 4; \
- cfg.tile_colsb[5] = n * 4; \
- cfg.tile_colsb[6] = n * 4; \
- cfg.tile_colsb[7] = n * 4; \
- _tile_loadconfig(&cfg);
-
- /* CONFIG for handling k2 and odd tail at the same time
- * tile0 -- A (m x 2k)
- * tile1 -- A (m x 1)
- * tile2 -- B (2k x n)
- * tile3 -- B (1 x n)
- * tile4 -- C (m x n)
- */
- #define TCONF_TAIL(cfg, m, n, k2) \
- memset(&cfg, 0, sizeof(tilecfg)); \
- cfg.palette_id = 1; \
- cfg.tile_rows[0] = m; \
- cfg.tile_rows[1] = m; \
- cfg.tile_rows[2] = k2>>1; \
- cfg.tile_rows[3] = 1; \
- cfg.tile_rows[4] = m; \
- cfg.tile_colsb[0] = k2<<1; \
- cfg.tile_colsb[1] = 4; \
- cfg.tile_colsb[2] = n * 4; \
- cfg.tile_colsb[3] = n * 4; \
- cfg.tile_colsb[4] = n * 4; \
- _tile_loadconfig(&cfg);
-
- #define T_A0 0
- #define T_A1 1
- #define T_B0 2
- #define T_B1 3
- #define T_C00 4
- #define T_C01 5
- #define T_C10 6
- #define T_C11 7
-
-
- #define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
- #define LOAD_A_TAIL(M, N) {\
- __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
- __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
- _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
- _tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
- }
- #define MASK_LOAD_A_TAIL(M, N) {\
- __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
- __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
- _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
- _tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
- }
- #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
- #define LOAD_B_TAIL(M, N) {\
- __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
- __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
- _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
- _tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
- }
- #define MASK_LOAD_B_TAIL(M, N) {\
- __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
- __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
- _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
- _tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
- }
-
- #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)
- #define MATMUL_TAIL(M, N) _tile_dpbf16ps(T_C00, T_A##M, T_B##N)
- #define STORE_C(M, N) _tile_stored(T_C##M##N, ptr_c##M##N, ldc * 4)
- #define LOAD_C_F(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
-
- #endif // end of SBGEMM_KERNEL_SPR
-
- #ifdef ALPHA_ONE
- #undef LOAD_C
- #define LOAD_C(M, N) _tile_loadd(T_C##M##N, ptr_c##M##N, ldc * 4)
- #else
- #undef LOAD_C
- #define LOAD_C(M, N) _tile_zero(T_C##M##N)
- #define ALPHA_STORE(N) \
- __m512 zmm_d##N = _mm512_loadu_ps(dst##N + noffset); \
- __m512 zmm_s##N = _mm512_loadu_ps(src##N + noffset); \
- zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
- _mm512_storeu_ps(dst##N + noffset, zmm_d##N);
- #define MASK_APLPHA_STORE(N) \
- __m512 zmm_d##N = _mm512_maskz_loadu_ps(mask, dst##N + noffset); \
- __m512 zmm_s##N = _mm512_maskz_loadu_ps(mask, src##N + noffset); \
- zmm_d##N = _mm512_fmadd_ps(alpha_512, zmm_s##N, zmm_d##N); \
- _mm512_mask_storeu_ps(dst##N + noffset, mask, zmm_d##N);
- #endif // end of ALPHA_ONE
-
-
- #ifdef ALPHA_ONE
- int sbgemm_kernel_spr_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
- #else
- int sbgemm_kernel_spr_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
- #endif
- {
- /* Row Major matrix for AMX requirement */
- IFLOAT *ptr_a = A, *ptr_b = B;
- IFLOAT *ptr_b0, *ptr_b1;
- IFLOAT *ptr_a0, *ptr_a1;
- FLOAT *ptr_c = C;
- FLOAT *ptr_c00, *ptr_c01, *ptr_c10, *ptr_c11;
-
- BLASLONG lda, ldb;
- BLASLONG m_count = m;
- BLASLONG n_count, k_count;
-
- #ifndef ALPHA_ONE
- // make sure each row is 64 bytes aligned
- BLASLONG cn = (n & 31) ? (n & ~31) + 32 : n;
- FLOAT *raw_tmp_c;
- if (k < 32) {
- // only need to zero buff in this situation
- raw_tmp_c = (FLOAT *)calloc(1, sizeof(FLOAT) * m * cn + 64);
- } else {
- raw_tmp_c = (FLOAT *)malloc(sizeof(FLOAT) * m * cn + 64);
- }
- // align buf to 64 byte boundary
- FLOAT *tmp_c = (FLOAT *)(((uintptr_t) raw_tmp_c + 63) & ~(uintptr_t)63);
- ptr_c = tmp_c;
- BLASLONG ldc_o = ldc;
- ldc = cn;
- #endif
- IFLOAT tail_a[32 * 2] __attribute__ ((aligned (64)));
- IFLOAT tail_b[32 * 2] __attribute__ ((aligned (64)));
- tilecfg cfg;
-
- if (k > 31) {
- for (; m_count > 31; m_count -= 32) {
- ptr_b = B;
-
- ptr_c00 = ptr_c;
- ptr_c01 = ptr_c00 + 16;
- ptr_c10 = ptr_c + 16 * ldc;
- ptr_c11 = ptr_c10 + 16;
- ptr_c += 32 * ldc;
- n_count = n;
- TCONF(cfg, 16, 16, 32);
- for (; n_count > 31; n_count -= 32) {
- ptr_a0 = ptr_a;
- ptr_a1 = ptr_a + 16 * k;
-
- ptr_b0 = ptr_b;
- ptr_b1 = ptr_b + 16 * k;
- ptr_b += 32 * k;
-
- lda = 32;
- ldb = 32;
- LOAD_C(0, 0); LOAD_C(0, 1);
- LOAD_C(1, 0); LOAD_C(1, 1);
- k_count = k;
- for (; k_count > 31; k_count -= 32) {
- LOAD_A(0, x); LOAD_A(1, x);
- ptr_a0 += 16 * 32;
- ptr_a1 += 16 * 32;
- LOAD_B(x, 0); LOAD_B(x, 1);
- ptr_b0 += 16 * 32;
- ptr_b1 += 16 * 32;
-
- MATMUL(0, 0); MATMUL(0, 1);
- MATMUL(1, 0); MATMUL(1, 1);
- }
- STORE_C(0, 0); STORE_C(0, 1);
- STORE_C(1, 0); STORE_C(1, 1);
- ptr_c00 += 32;
- ptr_c01 += 32;
- ptr_c10 += 32;
- ptr_c11 += 32;
- }
- for (; n_count > 0; n_count -= 16) {
- int tail_n = (n_count > 16) ? 16: n_count;
- ptr_a0 = ptr_a;
- ptr_a1 = ptr_a + 16 * k;
-
- ptr_b0 = ptr_b;
- ptr_b += tail_n * k;
-
- lda = 32;
- ldb = 2 * tail_n;
- TCONF(cfg, 16, tail_n, 32);
- LOAD_C(0, 0);
- LOAD_C(1, 0);
- k_count = k;
- for (; k_count > 31; k_count -= 32) {
- LOAD_A(0, x); LOAD_A(1, x);
- ptr_a0 += 16 * 32;
- ptr_a1 += 16 * 32;
- LOAD_B(x, 0);
- ptr_b0 += tail_n * 32;
-
- MATMUL(0, 0);
- MATMUL(1, 0);
- }
- STORE_C(0, 0);
- STORE_C(1, 0);
- ptr_c00 += tail_n;
- ptr_c10 += tail_n;
- }
- ptr_a += 32 * k;
- }
- for (; m_count > 0; m_count -= 16) {
- // process at most 16 m at a time
- int tail_m = (m_count > 16) ? 16: m_count;
-
- ptr_b = B;
-
- ptr_c00 = ptr_c;
- ptr_c01 = ptr_c00 + 16;
- ptr_c += tail_m * ldc;
- n_count = n;
- TCONF(cfg, tail_m, 16, 32);
- for (; n_count > 31; n_count -= 32) {
- ptr_a0 = ptr_a;
-
- ptr_b0 = ptr_b;
- ptr_b1 = ptr_b + 16 * k;
- ptr_b += 32 * k;
-
- lda = 32;
- ldb = 32;
- LOAD_C(0, 0); LOAD_C(0, 1);
- k_count = k;
- for (; k_count > 31; k_count -= 32) {
- LOAD_A(0, x);
- ptr_a0 += tail_m * 32;
- LOAD_B(x, 0); LOAD_B(x, 1);
- ptr_b0 += 16 * 32;
- ptr_b1 += 16 * 32;
-
- MATMUL(0, 0); MATMUL(0, 1);
- }
- STORE_C(0, 0); STORE_C(0, 1);
- ptr_c00 += 32;
- ptr_c01 += 32;
- }
- for (; n_count > 0; n_count -= 16) {
- int tail_n = (n_count > 16) ? 16: n_count;
- ptr_a0 = ptr_a;
-
- ptr_b0 = ptr_b;
- ptr_b += tail_n * k;
-
- lda = 32;
- ldb = 2 * tail_n;
- TCONF(cfg, tail_m, tail_n, 32);
- LOAD_C(0, 0);
- k_count = k;
- for (; k_count > 31; k_count -= 32) {
- LOAD_A(0, x);
- ptr_a0 += tail_m * 32;
- LOAD_B(x, 0);
- ptr_b0 += tail_n * 32;
-
- MATMUL(0, 0);
- }
- STORE_C(0, 0);
- ptr_c00 += tail_n;
- }
- ptr_a += tail_m * k;
- }
- }
-
- // process for k < 32
- BLASLONG k32 = k & ~31;
- BLASLONG k2 = k & ~1;
- if (k32 != k) {
- int remain_k2 = k2 - k32;
- m_count = m;
- ptr_a = A;
- #ifndef ALPHA_ONE
- ptr_c = tmp_c;
- #else
- ptr_c = C;
- #endif
- if (remain_k2 > 0 && k2 != k) { // k%32 = 2x + 1 (x != 0)
- for (; m_count > 0; m_count -= 16) {
- int tail_m = (m_count > 16) ? 16: m_count;
- __mmask16 amask = (1UL << tail_m) - 1;
-
- ptr_a0 = ptr_a + tail_m * k32;
- ptr_a1 = ptr_a + tail_m * k2;
- ptr_a += tail_m * k;
- ptr_b = B;
- ptr_c00 = ptr_c;
- ptr_c += tail_m * ldc;
- n_count = n;
- lda = remain_k2;
- ldb = 32;
- if (n_count > 15) {
- TCONF_TAIL(cfg, tail_m, 16, remain_k2);
- LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
- for (; n_count > 15; n_count -= 16) {
- ptr_b0 = ptr_b + 16 * k32;
- ptr_b1 = ptr_b + 16 * k2;
- LOAD_C_F(0, 0);
- LOAD_B(x, 0); LOAD_B_TAIL(x, 1);
- MATMUL(0, 0); MATMUL_TAIL(1, 1);
- STORE_C(0, 0);
- ptr_b += 16 * k;
- ptr_c00 += 16;
- }
- }
- if (n_count > 0) {
- int tail_n = (n_count > 16) ? 16: n_count;
- __mmask16 bmask = (1UL << tail_n) - 1;
- ptr_b0 = ptr_b + tail_n * k32;
- ptr_b1 = ptr_b + tail_n * k2;
- ldb = 2 * tail_n;
- TCONF_TAIL(cfg, tail_m, tail_n, remain_k2);
- LOAD_C_F(0, 0);
- LOAD_A(0, x); MASK_LOAD_A_TAIL(1, x);
- LOAD_B(x, 0); MASK_LOAD_B_TAIL(x, 1);
- MATMUL(0, 0); MATMUL_TAIL(1, 1);
- STORE_C(0, 0);
- }
- }
-
- } else if (remain_k2 > 0) { // k%32 = 2x
- for (; m_count > 0; m_count -= 16) {
- int tail_m = (m_count > 16) ? 16: m_count;
-
- ptr_a0 = ptr_a + tail_m * k32;
- ptr_a += tail_m * k;
- ptr_b = B;
- ptr_c00 = ptr_c;
- ptr_c += tail_m * ldc;
- n_count = n;
- lda = remain_k2;
- ldb = 32;
- if (n_count > 15) {
- TCONF(cfg, tail_m, 16, remain_k2);
- LOAD_A(0, x);
- for (; n_count > 15; n_count -= 16) {
- ptr_b0 = ptr_b + 16 * k32;
- LOAD_C_F(0, 0);
- LOAD_B(x, 0);
- MATMUL(0, 0);
- STORE_C(0, 0);
- ptr_b += 16 * k;
- ptr_c00 += 16;
- }
- }
- if (n_count > 0) {
- int tail_n = (n_count > 16) ? 16: n_count;
- ptr_b0 = ptr_b + tail_n * k32;
- ldb = 2 * tail_n;
- TCONF(cfg, tail_m, tail_n, remain_k2);
- LOAD_C_F(0, 0);
- LOAD_A(0, x);
- LOAD_B(x, 0);
- MATMUL(0, 0);
- STORE_C(0, 0);
- }
- }
- } else { // k%32 = 1
- for (; m_count > 0; m_count -= 16) {
- int tail_m = (m_count > 16) ? 16: m_count;
- __mmask16 amask = (1UL << tail_m) - 1;
-
- ptr_a0 = ptr_a + tail_m * k2;
- ptr_a += tail_m * k;
- ptr_b = B;
- ptr_c00 = ptr_c;
- ptr_c += tail_m * ldc;
- n_count = n;
- if (n_count > 15) {
- TCONF(cfg, tail_m, 16, 2);
- MASK_LOAD_A_TAIL(0, x);
- for (; n_count > 15; n_count -= 16) {
- ptr_b0 = ptr_b + 16 * k2;
- LOAD_C_F(0, 0);
- LOAD_B_TAIL(x, 0);
- MATMUL(0, 0);
- STORE_C(0, 0);
- ptr_b += 16 * k;
- ptr_c00 += 16;
- }
- }
- if (n_count > 0) {
- int tail_n = (n_count > 16) ? 16: n_count;
- __mmask16 bmask = (1UL << tail_n) - 1;
- ptr_b0 = ptr_b + tail_n * k2;
- TCONF(cfg, tail_m, tail_n, 2);
- LOAD_C_F(0, 0);
- MASK_LOAD_A_TAIL(0, x);
- MASK_LOAD_B_TAIL(x, 0);
- MATMUL(0, 0);
- STORE_C(0, 0);
- }
- }
-
- }
- }
- #ifndef ALPHA_ONE
- __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
- BLASLONG n16 = n & ~15;
- BLASLONG noffset;
- FLOAT *src0, *src1, *src2, *src3;
- FLOAT *dst0, *dst1, *dst2, *dst3;
- FLOAT *src = tmp_c;
- FLOAT *dst = C;
- m_count = m;
- for (; m_count > 3; m_count -= 4) {
- src0 = src;
- src1 = src0 + ldc;
- src2 = src1 + ldc;
- src3 = src2 + ldc;
- src += 4 * ldc;
-
- dst0 = dst;
- dst1 = dst0 + ldc_o;
- dst2 = dst1 + ldc_o;
- dst3 = dst2 + ldc_o;
- dst += 4 * ldc_o;
-
- noffset = 0;
- for (; noffset < n16; noffset += 16) {
- ALPHA_STORE(0);
- ALPHA_STORE(1);
- ALPHA_STORE(2);
- ALPHA_STORE(3);
- }
- if (noffset < n) {
- __mmask16 mask = (1UL << (n - noffset)) - 1;
- MASK_APLPHA_STORE(0);
- MASK_APLPHA_STORE(1);
- MASK_APLPHA_STORE(2);
- MASK_APLPHA_STORE(3);
- }
- }
- for (; m_count > 1; m_count -= 2) {
- src0 = src;
- src1 = src0 + ldc;
- src += 2 * ldc;
-
- dst0 = dst;
- dst1 = dst0 + ldc_o;
- dst += 2 * ldc_o;
-
- noffset = 0;
- for (; noffset < n16; noffset += 16) {
- ALPHA_STORE(0);
- ALPHA_STORE(1);
- }
- if (noffset < n) {
- __mmask16 mask = (1UL << (n - noffset)) - 1;
- MASK_APLPHA_STORE(0);
- MASK_APLPHA_STORE(1);
- }
- }
- for (; m_count > 0; m_count -= 1) {
- src0 = src;
- dst0 = dst;
- noffset = 0;
- for (; noffset < n16; noffset += 16) {
- ALPHA_STORE(0);
- }
- if (noffset < n) {
- __mmask16 mask = (1UL << (n - noffset)) - 1;
- MASK_APLPHA_STORE(0);
- }
- }
- free(raw_tmp_c);
- #endif
- return 0;
- }
|