Browse Source

Merge pull request #5371 from martin-frbg/fixup-5357

Complete the infrastructure changes for adding BGEMM
pull/5373/head
Martin Kroeker GitHub 2 months ago
parent
commit
e927373f62
No known key found for this signature in database. GPG Key ID: B5690EEEBB952194
4 changed files with 44 additions and 7 deletions
  1. driver/others/parameter.c (+17, -0)
  2. kernel/setparam-ref.c (+21, -0)
  3. test/Makefile (+6, -6)
  4. test/compare_sgemm_bgemm.c (+0, -1)

+ 17
- 0
driver/others/parameter.c View File

@@ -72,6 +72,11 @@ BLASLONG shgemm_p = DEFAULT_GEMM_P;
#else #else
BLASLONG shgemm_p = SHGEMM_P; BLASLONG shgemm_p = SHGEMM_P;
#endif #endif
#if BGEMM_P == bgemm_p
BLASLONG bgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG bgemm_p = BGEMM_P;
#endif
#if SGEMM_P == sgemm_p #if SGEMM_P == sgemm_p
BLASLONG sgemm_p = DEFAULT_GEMM_P; BLASLONG sgemm_p = DEFAULT_GEMM_P;
#else #else
@@ -103,6 +108,11 @@ BLASLONG shgemm_q = DEFAULT_GEMM_Q;
#else #else
BLASLONG shgemm_q = SHGEMM_Q; BLASLONG shgemm_q = SHGEMM_Q;
#endif #endif
#if BGEMM_Q == bgemm_q
BLASLONG bgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG bgemm_q = BGEMM_Q;
#endif
#if SGEMM_Q == sgemm_q #if SGEMM_Q == sgemm_q
BLASLONG sgemm_q = DEFAULT_GEMM_Q; BLASLONG sgemm_q = DEFAULT_GEMM_Q;
#else #else
@@ -134,6 +144,11 @@ BLASLONG shgemm_r = DEFAULT_GEMM_R;
#else #else
BLASLONG shgemm_r = SHGEMM_R; BLASLONG shgemm_r = SHGEMM_R;
#endif #endif
#if BGEMM_R == bgemm_r
BLASLONG bgemm_r = DEFAULT_GEMM_R;
#else
BLASLONG bgemm_r = BGEMM_R;
#endif
#if SGEMM_R == sgemm_r #if SGEMM_R == sgemm_r
BLASLONG sgemm_r = DEFAULT_GEMM_R; BLASLONG sgemm_r = DEFAULT_GEMM_R;
#else #else
@@ -541,6 +556,7 @@ void blas_set_parameter(void){


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
bgemm_r = (((BUFFER_SIZE - ((BGEMM_P * BGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (BGEMM_Q * 4)) - 15) & ~15;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
@@ -653,6 +669,7 @@ void blas_set_parameter(void){


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
bgemm_r = (((BUFFER_SIZE - ((BGEMM_P * BGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (BGEMM_Q * 4)) - 15) & ~15;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15; shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;


+ 21
- 0
kernel/setparam-ref.c View File

@@ -926,6 +926,7 @@ gotoblas_t TABLE_NAME = {
static void init_parameter(void) { static void init_parameter(void) {
#if (BUILD_BFLOAT16) #if (BUILD_BFLOAT16)
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
#endif #endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) #if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
@@ -942,6 +943,7 @@ static void init_parameter(void) {


#if (BUILD_BFLOAT16) #if (BUILD_BFLOAT16)
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
#if BUILD_SINGLE == 1 || (BUILD_COMPLEX==1) #if BUILD_SINGLE == 1 || (BUILD_COMPLEX==1)
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
@@ -958,6 +960,7 @@ static void init_parameter(void) {


#if (BUILD_BFLOAT16) #if (BUILD_BFLOAT16)
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
TABLE_NAME.bgemm_r = BGEMM_DEFAULT_R;
#endif #endif
#if BUILD_SINGLE == 1 || (BUILD_COMPLEX==1) #if BUILD_SINGLE == 1 || (BUILD_COMPLEX==1)
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
@@ -1113,10 +1116,12 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
#endif #endif


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
TABLE_NAME.bgemm_r = BGEMM_DEFAULT_R;
#endif #endif


#if defined(LA464) #if defined(LA464)
@@ -1215,6 +1220,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
} }
#else // (ARCH_LOONGARCH64) #else // (ARCH_LOONGARCH64)
@@ -1223,6 +1229,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
#endif #endif
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
@@ -1231,6 +1238,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
TABLE_NAME.bgemm_r = BGEMM_DEFAULT_R;
#endif #endif
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
@@ -1240,6 +1248,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
@@ -1252,6 +1261,7 @@ static void init_parameter(void) {
static void init_parameter(void) { static void init_parameter(void) {
#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
#endif #endif
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
@@ -1260,6 +1270,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
TABLE_NAME.bgemm_r = BGEMM_DEFAULT_R;
#endif #endif
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
@@ -1269,6 +1280,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
@@ -1282,6 +1294,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P; TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
@@ -1293,6 +1306,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
TABLE_NAME.bgemm_r = BGEMM_DEFAULT_R;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R; TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
@@ -1305,6 +1319,7 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q; TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
@@ -1455,6 +1470,8 @@ static void init_parameter(void) {
#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
TABLE_NAME.bgemm_p = BGEMM_DEFAULT_P;
TABLE_NAME.bgemm_q = BGEMM_DEFAULT_Q;
#endif #endif
#ifdef BUILD_HFLOAT16 #ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P; TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
@@ -2053,6 +2070,10 @@ static void init_parameter(void) {
((TABLE_NAME.sbgemm_p * TABLE_NAME.sbgemm_q * 4 + TABLE_NAME.offsetA ((TABLE_NAME.sbgemm_p * TABLE_NAME.sbgemm_q * 4 + TABLE_NAME.offsetA
+ TABLE_NAME.align) & ~TABLE_NAME.align) + TABLE_NAME.align) & ~TABLE_NAME.align)
) / (TABLE_NAME.sbgemm_q * 4) - 15) & ~15); ) / (TABLE_NAME.sbgemm_q * 4) - 15) & ~15);
TABLE_NAME.bgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.bgemm_p * TABLE_NAME.bgemm_q * 4 + TABLE_NAME.offsetA
+ TABLE_NAME.align) & ~TABLE_NAME.align)
) / (TABLE_NAME.bgemm_q * 4) - 15) & ~15);
#endif #endif


#if BUILD_HFLOAT16==1 #if BUILD_HFLOAT16==1


+ 6
- 6
test/Makefile View File

@@ -229,8 +229,8 @@ ifneq ($(CROSS), 1)
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM
@$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0
# OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM
# @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM
@$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
endif endif
ifeq ($(BUILD_SINGLE),1) ifeq ($(BUILD_SINGLE),1)
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat
@@ -254,8 +254,8 @@ ifeq ($(USE_OPENMP), 1)
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
OMP_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM OMP_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM
@$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0
# OMP_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM
# @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
OMP_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM
@$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
endif endif
ifeq ($(BUILD_SINGLE),1) ifeq ($(BUILD_SINGLE),1)
OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat
@@ -277,8 +277,8 @@ else
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
OPENBLAS_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM OPENBLAS_NUM_THREADS=2 ./test_sbgemm > SBBLAT3.SUMM
@$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0
# OPENBLAS_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM
# @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
OPENBLAS_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM
@$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0
endif endif
ifeq ($(BUILD_SINGLE),1) ifeq ($(BUILD_SINGLE),1)
OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat


+ 0
- 1
test/compare_sgemm_bgemm.c View File

@@ -28,7 +28,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>


#include <arm_neon.h>


#define SGEMM BLASFUNC(sgemm) #define SGEMM BLASFUNC(sgemm)
#define BGEMM BLASFUNC(bgemm) #define BGEMM BLASFUNC(bgemm)


Loading…
Cancel
Save