Browse Source

Fix spr sbgemm error

tags/v0.3.24
Honglin Zhu 2 years ago
parent
commit
f249ccb741
2 changed files with 11 additions and 8 deletions
  1. +6
    -2
      cpuid_x86.c
  2. +5
    -6
      kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c

+ 6
- 2
cpuid_x86.c View File

@@ -1479,6 +1479,8 @@ int get_cpuname(void){
else else
return CPUTYPE_NEHALEM; return CPUTYPE_NEHALEM;
case 15: // Sapphire Rapids case 15: // Sapphire Rapids
if(support_amx_bf16())
return CPUTYPE_SAPPHIRERAPIDS;
if(support_avx512_bf16()) if(support_avx512_bf16())
return CPUTYPE_COOPERLAKE; return CPUTYPE_COOPERLAKE;
if(support_avx512()) if(support_avx512())
@@ -1845,7 +1847,8 @@ static char *cpuname[] = {
"ZEN", "ZEN",
"SKYLAKEX", "SKYLAKEX",
"DHYANA", "DHYANA",
"COOPERLAKE"
"COOPERLAKE",
"SAPPHIRERAPIDS",
}; };


static char *lowercpuname[] = { static char *lowercpuname[] = {
@@ -1902,7 +1905,8 @@ static char *lowercpuname[] = {
"zen", "zen",
"skylakex", "skylakex",
"dhyana", "dhyana",
"cooperlake"
"cooperlake",
"sapphirerapids",
}; };


static char *corename[] = { static char *corename[] = {


+ 5
- 6
kernel/x86_64/sbgemm_kernel_16x16_spr_tmpl.c View File

@@ -97,33 +97,32 @@ typedef struct {
#define T_C10 6 #define T_C10 6
#define T_C11 7 #define T_C11 7


// FIXME: gcc11 seem have problem in tile load/store address calc,
// need to multiply with element size (2 or 4) here.

#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) #define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2)
#define LOAD_A_TAIL(M, N) {\ #define LOAD_A_TAIL(M, N) {\
__m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ __m256i ymm = _mm256_loadu_epi16(ptr_a##M); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
} }
#define MASK_LOAD_A_TAIL(M, N) {\ #define MASK_LOAD_A_TAIL(M, N) {\
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ __m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ _mm512_storeu_epi16(tail_a + 16 * M, zmm); \
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \
} }
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) #define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2)
#define LOAD_B_TAIL(M, N) {\ #define LOAD_B_TAIL(M, N) {\
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ __m256i ymm = _mm256_loadu_epi16(ptr_b##N); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \ _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
} }
#define MASK_LOAD_B_TAIL(M, N) {\ #define MASK_LOAD_B_TAIL(M, N) {\
__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ __m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ __m512i zmm = _mm512_cvtepu16_epi32(ymm); \
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \ _mm512_storeu_epi16(tail_b + 16 * N, zmm); \
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \
} }


#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) #define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N)


Loading…
Cancel
Save