|
|
@@ -97,33 +97,32 @@ typedef struct { |
|
|
|
#define T_C10 6 |
|
|
|
#define T_C11 7 |
|
|
|
|
|
|
|
// FIXME: gcc11 seem have problem in tile load/store address calc, |
|
|
|
// need to multiply with element size (2 or 4) here. |
|
|
|
|
|
|
|
#define LOAD_A(M, N) _tile_loadd(T_A##M, ptr_a##M, lda * 2) |
|
|
|
#define LOAD_A_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_loadu_epi16(ptr_a##M); \ |
|
|
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ |
|
|
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \ |
|
|
|
} |
|
|
|
#define MASK_LOAD_A_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_maskz_loadu_epi16(amask, ptr_a##M); \ |
|
|
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ |
|
|
|
_mm512_storeu_epi16(tail_a + 16 * M, zmm); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * 2 * M, 2 * 2); \ |
|
|
|
_tile_loadd(T_A##M, tail_a + 16 * M, 2 * 2); \ |
|
|
|
} |
|
|
|
#define LOAD_B(M, N) _tile_loadd(T_B##N, ptr_b##N, ldb * 2) |
|
|
|
#define LOAD_B_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_loadu_epi16(ptr_b##N); \ |
|
|
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ |
|
|
|
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \ |
|
|
|
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ |
|
|
|
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \ |
|
|
|
} |
|
|
|
#define MASK_LOAD_B_TAIL(M, N) {\ |
|
|
|
__m256i ymm = _mm256_maskz_loadu_epi16(bmask, ptr_b##N); \ |
|
|
|
__m512i zmm = _mm512_cvtepu16_epi32(ymm); \ |
|
|
|
_mm512_storeu_epi16(tail_b + 16 * N, zmm); \ |
|
|
|
_tile_loadd(T_B##N, tail_b + 16 * 2 * N, 2 * 2); \ |
|
|
|
_tile_loadd(T_B##N, tail_b + 16 * N, 2 * 2); \ |
|
|
|
} |
|
|
|
|
|
|
|
#define MATMUL(M, N) _tile_dpbf16ps(T_C##M##N, T_A##M, T_B##N) |
|
|
|