Browse Source

Merge pull request #5290 from Srangrang/develop

Add support for FP16 to openBLAS and shgemm on RISCV
pull/5295/head
Martin Kroeker GitHub 3 months ago
parent
commit
d96daa220d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
37 changed files with 2522 additions and 92 deletions
  1. +3
    -0
      CMakeLists.txt
  2. +2
    -2
      Makefile.prebuild
  3. +4
    -4
      Makefile.riscv64
  4. +2
    -0
      Makefile.rule
  5. +6
    -0
      Makefile.system
  6. +5
    -2
      Makefile.tail
  7. +21
    -5
      benchmark/Makefile
  8. +8
    -1
      benchmark/gemm.c
  9. +4
    -0
      cblas.h
  10. +11
    -8
      cmake/system.cmake
  11. +15
    -0
      common.h
  12. +2
    -0
      common_interface.h
  13. +18
    -1
      common_level3.h
  14. +45
    -0
      common_macro.h
  15. +72
    -22
      common_param.h
  16. +72
    -0
      common_sh.h
  17. +6
    -0
      driver/level3/CMakeLists.txt
  18. +71
    -16
      driver/level3/Makefile
  19. +1
    -1
      driver/others/Makefile
  20. +22
    -0
      driver/others/parameter.c
  21. +10
    -7
      exports/Makefile
  22. +16
    -7
      exports/gensymbol
  23. +13
    -8
      exports/gensymbol.pl
  24. +2
    -0
      getarch_2nd.c
  25. +3
    -0
      interface/CMakeLists.txt
  26. +22
    -2
      interface/Makefile
  27. +8
    -6
      interface/gemm.c
  28. +55
    -0
      kernel/CMakeLists.txt
  29. +164
    -0
      kernel/Makefile.L3
  30. +11
    -0
      kernel/riscv64/KERNEL.RISCV64_ZVL128B
  31. +18
    -0
      kernel/riscv64/KERNEL.RISCV64_ZVL256B
  32. +969
    -0
      kernel/riscv64/shgemm_kernel_16x8_zvl256b.c
  33. +767
    -0
      kernel/riscv64/shgemm_kernel_8x8_zvl128b.c
  34. +37
    -0
      kernel/setparam-ref.c
  35. +1
    -0
      lapack/CMakeLists.txt
  36. +7
    -0
      openblas_config_template.h
  37. +29
    -0
      param.h

+ 3
- 0
CMakeLists.txt View File

@@ -152,6 +152,9 @@ endif ()
if (NOT DEFINED BUILD_BFLOAT16) if (NOT DEFINED BUILD_BFLOAT16)
set (BUILD_BFLOAT16 false) set (BUILD_BFLOAT16 false)
endif () endif ()
if (NOT DEFINED BUILD_HFLOAT16)
set (BUILD_HFLOAT16 false)
endif ()
# set which float types we want to build for # set which float types we want to build for
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16) if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
# if none are defined, build for all # if none are defined, build for all


+ 2
- 2
Makefile.prebuild View File

@@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d
endif endif


ifeq ($(TARGET), RISCV64_ZVL256B) ifeq ($(TARGET), RISCV64_ZVL256B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif endif


ifeq ($(TARGET), RISCV64_ZVL128B) ifeq ($(TARGET), RISCV64_ZVL128B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif endif


ifeq ($(TARGET), RISCV64_GENERIC) ifeq ($(TARGET), RISCV64_GENERIC)


+ 4
- 4
Makefile.riscv64 View File

@@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static
endif endif
ifeq ($(CORE), RISCV64_ZVL256B) ifeq ($(CORE), RISCV64_ZVL256B)
CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif endif
ifeq ($(CORE), RISCV64_ZVL128B) ifeq ($(CORE), RISCV64_ZVL128B)
CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif endif
ifeq ($(CORE), RISCV64_GENERIC) ifeq ($(CORE), RISCV64_GENERIC)
CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d


+ 2
- 0
Makefile.rule View File

@@ -308,6 +308,8 @@ COMMON_PROF = -pg
# If you want to enable the experimental BFLOAT16 support # If you want to enable the experimental BFLOAT16 support
# BUILD_BFLOAT16 = 1 # BUILD_BFLOAT16 = 1


# If you want to enable the experimental HFLOAT16 support
# BUILD_HFLOAT16 = 1


# Set the thread number threshold beyond which the job array for the threaded level3 BLAS # Set the thread number threshold beyond which the job array for the threaded level3 BLAS
# will be allocated on the heap rather than the stack. (This array alone requires # will be allocated on the heap rather than the stack. (This array alone requires


+ 6
- 0
Makefile.system View File

@@ -1556,6 +1556,9 @@ endif
ifeq ($(BUILD_BFLOAT16), 1) ifeq ($(BUILD_BFLOAT16), 1)
CCOMMON_OPT += -DBUILD_BFLOAT16 CCOMMON_OPT += -DBUILD_BFLOAT16
endif endif
ifeq ($(BUILD_HFLOAT16), 1)
CCOMMON_OPT += -DBUILD_HFLOAT16
endif
ifeq ($(BUILD_SINGLE), 1) ifeq ($(BUILD_SINGLE), 1)
CCOMMON_OPT += -DBUILD_SINGLE=1 CCOMMON_OPT += -DBUILD_SINGLE=1
endif endif
@@ -1898,11 +1901,14 @@ export TARGET_CORE
export NO_AVX512 export NO_AVX512
export NO_AVX2 export NO_AVX2
export BUILD_BFLOAT16 export BUILD_BFLOAT16
export BUILD_HFLOAT16
export NO_LSX export NO_LSX
export NO_LASX export NO_LASX


export SBGEMM_UNROLL_M export SBGEMM_UNROLL_M
export SBGEMM_UNROLL_N export SBGEMM_UNROLL_N
export SHGEMM_UNROLL_M
export SHGEMM_UNROLL_N
export SGEMM_UNROLL_M export SGEMM_UNROLL_M
export SGEMM_UNROLL_N export SGEMM_UNROLL_N
export DGEMM_UNROLL_M export DGEMM_UNROLL_M


+ 5
- 2
Makefile.tail View File

@@ -1,4 +1,5 @@
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))


HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))


BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)


ifdef EXPRECISION ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif endif


$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX


$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)


+ 21
- 5
benchmark/Makefile View File

@@ -56,9 +56,15 @@ GOTO_LAPACK_TARGETS=
endif endif


ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
GOTO_HALF_TARGETS=sbgemm.goto
GOTO_BFLOAT_TARGETS=sbgemm.goto
else else
GOTO_HALF_TARGETS=
GOTO_BFLOAT_TARGETS=
endif

ifeq ($(BUILD_HFLOAT16),1)
GOTO_HFLOAT_TARGETS=shgemm.goto
else
GOTO_HFLOAT_TARGETS=
endif endif


ifeq ($(OSNAME), WINNT) ifeq ($(OSNAME), WINNT)
@@ -104,7 +110,7 @@ goto :: slinpack.goto dlinpack.goto clinpack.goto zlinpack.goto \
spotrf.goto dpotrf.goto cpotrf.goto zpotrf.goto \ spotrf.goto dpotrf.goto cpotrf.goto zpotrf.goto \
ssymm.goto dsymm.goto csymm.goto zsymm.goto \ ssymm.goto dsymm.goto csymm.goto zsymm.goto \
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \ somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_HALF_TARGETS)
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)


acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \ acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \ scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -278,7 +284,7 @@ goto :: sgemm.goto dgemm.goto cgemm.goto zgemm.goto \
smin.goto dmin.goto \ smin.goto dmin.goto \
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto \ saxpby.goto daxpby.goto caxpby.goto zaxpby.goto \
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \ somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_HALF_TARGETS)
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)


acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \ acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \ scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -633,6 +639,11 @@ sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif endif


ifeq ($(BUILD_HFLOAT16),1)
shgemm.goto : shgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif

sgemm.goto : sgemm.$(SUFFIX) ../$(LIBNAME) sgemm.goto : sgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm


@@ -2960,7 +2971,12 @@ zcholesky.$(SUFFIX) : cholesky.c


ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
sbgemm.$(SUFFIX) : gemm.c sbgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DHALF -UCOMPLEX -UDOUBLE -o $(@F) $^
$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif

ifeq ($(BUILD_HFLOAT16),1)
shgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DHFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif endif


sgemm.$(SUFFIX) : gemm.c sgemm.$(SUFFIX) : gemm.c


+ 8
- 1
benchmark/gemm.c View File

@@ -33,10 +33,17 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#ifdef DOUBLE #ifdef DOUBLE
#define GEMM BLASFUNC(dgemm) #define GEMM BLASFUNC(dgemm)
#elif defined(HALF)
#elif defined(BFLOAT16)
#define GEMM BLASFUNC(sbgemm) #define GEMM BLASFUNC(sbgemm)
#undef IFLOAT
#define IFLOAT bfloat16
#elif defined(HFLOAT16)
#define GEMM BLASFUNC(shgemm)
#undef IFLOAT
#define IFLOAT hfloat16
#else #else
#define GEMM BLASFUNC(sgemm) #define GEMM BLASFUNC(sgemm)
#define IFLOAT float
#endif #endif


#else #else


+ 4
- 0
cblas.h View File

@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);


/*** FLOAT16 extensions ***/
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);

#ifdef __cplusplus #ifdef __cplusplus
} }
#endif /* __cplusplus */ #endif /* __cplusplus */


+ 11
- 8
cmake/system.cmake View File

@@ -640,6 +640,9 @@ endif()
if (BUILD_BFLOAT16) if (BUILD_BFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
endif() endif()
if (BUILD_HFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
endif()
if(NOT MSVC) if(NOT MSVC)
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}") set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
endif() endif()
@@ -647,14 +650,14 @@ endif()
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}") set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")


if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
endif () endif ()






+ 15
- 0
common.h View File

@@ -266,6 +266,14 @@ typedef uint16_t bfloat16;
#define BFLOAT16CONVERSION 1 #define BFLOAT16CONVERSION 1
#endif #endif


#ifdef BUILD_HFLOAT16
#ifndef hfloat16
typedef _Float16 hfloat16;
#endif
#else
typedef uint16_t hfloat16;
#endif

#ifdef USE64BITINT #ifdef USE64BITINT
typedef BLASLONG blasint; typedef BLASLONG blasint;
#if defined(OS_WINDOWS) && defined(__64BIT__) #if defined(OS_WINDOWS) && defined(__64BIT__)
@@ -313,6 +321,13 @@ typedef int blasint;
#define SIZE 2 #define SIZE 2
#define BASE_SHIFT 1 #define BASE_SHIFT 1
#define ZBASE_SHIFT 2 #define ZBASE_SHIFT 2
#elif defined(HFLOAT16)
#define IFLOAT hfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#else #else
#define FLOAT float #define FLOAT float
#define SIZE 4 #define SIZE 4


+ 2
- 0
common_interface.h View File

@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint


/* Level 3 routines */ /* Level 3 routines */


void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,


+ 18
- 1
common_level3.h View File

@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,


int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);



int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif #endif


int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);


int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);


int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif #endif


int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);


#ifdef __CUDACC__ #ifdef __CUDACC__
} }


+ 45
- 0
common_macro.h View File

@@ -39,6 +39,7 @@
#ifndef COMMON_MACRO #ifndef COMMON_MACRO
#define COMMON_MACRO #define COMMON_MACRO


#include "common_sh.h"
#include "common_sb.h" #include "common_sb.h"
#include "common_s.h" #include "common_s.h"
#include "common_d.h" #include "common_d.h"
@@ -656,6 +657,50 @@
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT #define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN #define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT #define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
#elif defined(HFLOAT16)
#define GEMM_BETA SHGEMM_BETA
#define GEMM_KERNEL_N SHGEMM_KERNEL
#define GEMM_KERNEL_L SHGEMM_KERNEL
#define GEMM_KERNEL_R SHGEMM_KERNEL
#define GEMM_KERNEL_B SHGEMM_KERNEL
#define GEMM_NN SHGEMM_NN
#define GEMM_CN SHGEMM_TN
#define GEMM_TN SHGEMM_TN
#define GEMM_NC SHGEMM_NT
#define GEMM_NT SHGEMM_NT
#define GEMM_CC SHGEMM_TT
#define GEMM_CT SHGEMM_TT
#define GEMM_TC SHGEMM_TT
#define GEMM_TT SHGEMM_TT
#define GEMM_NR SHGEMM_NN
#define GEMM_TR SHGEMM_TN
#define GEMM_CR SHGEMM_TN
#define GEMM_RN SHGEMM_NN
#define GEMM_RT SHGEMM_NT
#define GEMM_RC SHGEMM_NT
#define GEMM_RR SHGEMM_NN
#define GEMM_ONCOPY SHGEMM_ONCOPY
#define GEMM_OTCOPY SHGEMM_OTCOPY
#define GEMM_INCOPY SHGEMM_INCOPY
#define GEMM_ITCOPY SHGEMM_ITCOPY

#define GEMM_THREAD_NN SHGEMM_THREAD_NN
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN



#elif defined(BFLOAT16) #elif defined(BFLOAT16)




+ 72
- 22
common_param.h View File

@@ -48,6 +48,21 @@ typedef struct {
int dtb_entries; int dtb_entries;
int switch_ratio; int switch_ratio;
int offsetA, offsetB, align; int offsetA, offsetB, align;
#if BUILD_HFLOAT16 == 1
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;

int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);

int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);


#endif



#if BUILD_BFLOAT16 == 1 #if BUILD_BFLOAT16 == 1
int sbgemm_p, sbgemm_q, sbgemm_r; int sbgemm_p, sbgemm_q, sbgemm_r;
@@ -64,10 +79,10 @@ typedef struct {
float (*sbamin_k) (BLASLONG, float *, BLASLONG); float (*sbamin_k) (BLASLONG, float *, BLASLONG);
float (*sbmax_k) (BLASLONG, float *, BLASLONG); float (*sbmax_k) (BLASLONG, float *, BLASLONG);
float (*sbmin_k) (BLASLONG, float *, BLASLONG); float (*sbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);


float (*sbnrm2_k) (BLASLONG, float *, BLASLONG); float (*sbnrm2_k) (BLASLONG, float *, BLASLONG);
float (*sbasum_k) (BLASLONG, float *, BLASLONG); float (*sbasum_k) (BLASLONG, float *, BLASLONG);
@@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
#endif #endif


#if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1) #if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
#endif #endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) #if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
float (*snrm2_k) (BLASLONG, float *, BLASLONG); float (*snrm2_k) (BLASLONG, float *, BLASLONG);
float (*sasum_k) (BLASLONG, float *, BLASLONG); float (*sasum_k) (BLASLONG, float *, BLASLONG);
#endif #endif
@@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
double (*damin_k) (BLASLONG, double *, BLASLONG); double (*damin_k) (BLASLONG, double *, BLASLONG);
double (*dmax_k) (BLASLONG, double *, BLASLONG); double (*dmax_k) (BLASLONG, double *, BLASLONG);
double (*dmin_k) (BLASLONG, double *, BLASLONG); double (*dmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);


double (*dnrm2_k) (BLASLONG, double *, BLASLONG); double (*dnrm2_k) (BLASLONG, double *, BLASLONG);
double (*dasum_k) (BLASLONG, double *, BLASLONG); double (*dasum_k) (BLASLONG, double *, BLASLONG);
@@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);


xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
float (*camax_k) (BLASLONG, float *, BLASLONG); float (*camax_k) (BLASLONG, float *, BLASLONG);
float (*camin_k) (BLASLONG, float *, BLASLONG); float (*camin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);


float (*cnrm2_k) (BLASLONG, float *, BLASLONG); float (*cnrm2_k) (BLASLONG, float *, BLASLONG);
float (*casum_k) (BLASLONG, float *, BLASLONG); float (*casum_k) (BLASLONG, float *, BLASLONG);
@@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);


double (*zamax_k) (BLASLONG, double *, BLASLONG); double (*zamax_k) (BLASLONG, double *, BLASLONG);
double (*zamin_k) (BLASLONG, double *, BLASLONG); double (*zamin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);


double (*znrm2_k) (BLASLONG, double *, BLASLONG); double (*znrm2_k) (BLASLONG, double *, BLASLONG);
double (*zasum_k) (BLASLONG, double *, BLASLONG); double (*zasum_k) (BLASLONG, double *, BLASLONG);
@@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);


xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);


xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG); xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas;


#define HAVE_EX_L2 gotoblas -> exclusive_cache #define HAVE_EX_L2 gotoblas -> exclusive_cache


#if (BUILD_HFLOAT16==1)
#define SHGEMM_P gotoblas -> shgemm_p
#define SHGEMM_Q gotoblas -> shgemm_q
#define SHGEMM_R gotoblas -> shgemm_r
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
#endif

#if (BUILD_BFLOAT16==1) #if (BUILD_BFLOAT16==1)
#define SBGEMM_P gotoblas -> sbgemm_p #define SBGEMM_P gotoblas -> sbgemm_p
#define SBGEMM_Q gotoblas -> sbgemm_q #define SBGEMM_Q gotoblas -> sbgemm_q
@@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 0 #define HAVE_EX_L2 0
#endif #endif


#if (BUILD_HFLOAT16 == 1)
#define SHGEMM_P SHGEMM_DEFAULT_P
#define SHGEMM_Q SHGEMM_DEFAULT_Q
#define SHGEMM_R SHGEMM_DEFAULT_R
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#ifdef SHGEMM_DEFAULT_UNROLL_MN
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
#else
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
#endif
#endif

#if (BUILD_BFLOAT16 == 1) #if (BUILD_BFLOAT16 == 1)
#define SBGEMM_P SBGEMM_DEFAULT_P #define SBGEMM_P SBGEMM_DEFAULT_P
#define SBGEMM_Q SBGEMM_DEFAULT_Q #define SBGEMM_Q SBGEMM_DEFAULT_Q
@@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas;




#endif #endif

#endif #endif


#ifndef COMPLEX #ifndef COMPLEX
@@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas;
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R #define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N #define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
#elif defined(HFLOAT16)
#define GEMM_P SHGEMM_P
#define GEMM_Q SHGEMM_Q
#define GEMM_R SHGEMM_R
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#elif defined(BFLOAT16) #elif defined(BFLOAT16)
#define GEMM_P SBGEMM_P #define GEMM_P SBGEMM_P
#define GEMM_Q SBGEMM_Q #define GEMM_Q SBGEMM_Q


+ 72
- 0
common_sh.h View File

@@ -0,0 +1,72 @@
#ifndef COMMON_SH_H
#define COMMON_SH_H

#ifndef DYNAMIC_ARCH

#define SHGEMM_ONCOPY shgemm_oncopy
#define SHGEMM_OTCOPY shgemm_otcopy

#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY shgemm_oncopy
#define SHGEMM_ITCOPY shgemm_otcopy
#else
#define SHGEMM_INCOPY shgemm_incopy
#define SHGEMM_ITCOPY shgemm_itcopy
#endif

#define SHGEMM_BETA shgemm_beta
#define SHGEMM_KERNEL shgemm_kernel


#else // #DYNAMIC_ARCH

#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy
#else
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
#endif

#define SHGEMM_BETA gotoblas -> shgemm_beta
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
#endif // #DYNAMIC_ARCH

#define SHGEMM_NN shgemm_nn
#define SHGEMM_CN shgemm_tn
#define SHGEMM_TN shgemm_tn
#define SHGEMM_NC shgemm_nt
#define SHGEMM_NT shgemm_nt
#define SHGEMM_CC shgemm_tt
#define SHGEMM_CT shgemm_tt
#define SHGEMM_TC shgemm_tt
#define SHGEMM_TT shgemm_tt
#define SHGEMM_NR shgemm_nn
#define SHGEMM_TR shgemm_tn
#define SHGEMM_CR shgemm_tn
#define SHGEMM_RN shgemm_nn
#define SHGEMM_RT shgemm_nt
#define SHGEMM_RC shgemm_nt
#define SHGEMM_RR shgemm_nn

#define SHGEMM_THREAD_NN shgemm_thread_nn
#define SHGEMM_THREAD_CN shgemm_thread_tn
#define SHGEMM_THREAD_TN shgemm_thread_tn
#define SHGEMM_THREAD_NC shgemm_thread_nt
#define SHGEMM_THREAD_NT shgemm_thread_nt
#define SHGEMM_THREAD_CC shgemm_thread_tt
#define SHGEMM_THREAD_CT shgemm_thread_tt
#define SHGEMM_THREAD_TC shgemm_thread_tt
#define SHGEMM_THREAD_TT shgemm_thread_tt
#define SHGEMM_THREAD_NR shgemm_thread_nn
#define SHGEMM_THREAD_TR shgemm_thread_tn
#define SHGEMM_THREAD_CR shgemm_thread_tn
#define SHGEMM_THREAD_RN shgemm_thread_nn
#define SHGEMM_THREAD_RT shgemm_thread_nt
#define SHGEMM_THREAD_RC shgemm_thread_nt
#define SHGEMM_THREAD_RR shgemm_thread_nn


#endif // #COMMON_SH_H

+ 6
- 0
driver/level3/CMakeLists.txt View File

@@ -18,6 +18,12 @@ foreach (GEMM_DEFINE ${GEMM_DEFINES})
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16") GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16")
endif () endif ()
endif () endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE}" "gemm_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3)
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
endif ()
endif ()
endforeach () endforeach ()


if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)


+ 71
- 16
driver/level3/Makefile View File

@@ -23,6 +23,10 @@ ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
endif

SBLASOBJS += \ SBLASOBJS += \
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@@ -210,6 +214,9 @@ ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1)
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX) SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX)
endif endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
endif
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX) SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX) DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX) QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
@@ -344,16 +351,28 @@ endif
all :: all ::


sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


sbgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h sbgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)


sbgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h sbgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)


sbgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h sbgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)


sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -551,16 +570,28 @@ beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F) $(CC) -c $(CFLAGS) $< -o $(@F)


sbgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


sbgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)


sbgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)


sbgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)


sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -2736,16 +2767,28 @@ xtrsm_RCLN.$(SUFFIX) : trsm_R.c
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F) $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)


sbgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h sbgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


sbgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h sbgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)


sbgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h sbgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)


sbgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h sbgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)


sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -2959,16 +3002,28 @@ zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h




sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


sbgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)


sbgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)


sbgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sbgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)


sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


+ 1
- 1
driver/others/Makefile View File

@@ -218,7 +218,7 @@ mulx.$(SUFFIX) : $(ARCH)/mulx.c
$(CC) $(CFLAGS) -c -DXDOUBLE -UCOMPLEX $< -o $(@F) $(CC) $(CFLAGS) -c -DXDOUBLE -UCOMPLEX $< -o $(@F)


detect_riscv64.$(SUFFIX): detect_riscv64.c detect_riscv64.$(SUFFIX): detect_riscv64.c
$(CC) $(CFLAGS) -c -march=rv64imafdcv $< -o $(@F)
$(CC) $(CFLAGS) -c -march=rv64imafdcv_zvfh_zfh $< -o $(@F)


xerbla.$(PSUFFIX) : xerbla.c xerbla.$(PSUFFIX) : xerbla.c
$(CC) $(PFLAGS) -c $< -o $(@F) $(CC) $(PFLAGS) -c $< -o $(@F)


+ 22
- 0
driver/others/parameter.c View File

@@ -67,6 +67,11 @@ BLASLONG sbgemm_p = DEFAULT_GEMM_P;
#else #else
BLASLONG sbgemm_p = SBGEMM_P; BLASLONG sbgemm_p = SBGEMM_P;
#endif #endif
#if SHGEMM_P == shgemm_p
BLASLONG shgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG shgemm_p = SHGEMM_P;
#endif
#if SGEMM_P == sgemm_p #if SGEMM_P == sgemm_p
BLASLONG sgemm_p = DEFAULT_GEMM_P; BLASLONG sgemm_p = DEFAULT_GEMM_P;
#else #else
@@ -93,6 +98,11 @@ BLASLONG sbgemm_q = DEFAULT_GEMM_Q;
#else #else
BLASLONG sbgemm_q = SBGEMM_Q; BLASLONG sbgemm_q = SBGEMM_Q;
#endif #endif
#if SHGEMM_Q == shgemm_q
BLASLONG shgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG shgemm_q = SHGEMM_Q;
#endif
#if SGEMM_Q == sgemm_q #if SGEMM_Q == sgemm_q
BLASLONG sgemm_q = DEFAULT_GEMM_Q; BLASLONG sgemm_q = DEFAULT_GEMM_Q;
#else #else
@@ -119,6 +129,11 @@ BLASLONG sbgemm_r = DEFAULT_GEMM_R;
#else #else
BLASLONG sbgemm_r = SBGEMM_R; BLASLONG sbgemm_r = SBGEMM_R;
#endif #endif
#if SHGEMM_R == shgemm_r
BLASLONG shgemm_r = DEFAULT_GEMM_R;
#else
BLASLONG shgemm_r = SHGEMM_R;
#endif
#if SGEMM_R == sgemm_r #if SGEMM_R == sgemm_r
BLASLONG sgemm_r = DEFAULT_GEMM_R; BLASLONG sgemm_r = DEFAULT_GEMM_R;
#else #else
@@ -526,6 +541,9 @@ void blas_set_parameter(void){


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
#endif
#ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
#endif #endif
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;
@@ -619,6 +637,7 @@ void blas_set_parameter(void){
size = BITMASK(cpuid3, 16, 0xff); size = BITMASK(cpuid3, 16, 0xff);


sbgemm_p = 192 * (size + 1); sbgemm_p = 192 * (size + 1);
shgemm_p = 192 * (size + 1);
sgemm_p = 192 * (size + 1); sgemm_p = 192 * (size + 1);
dgemm_p = 96 * (size + 1); dgemm_p = 96 * (size + 1);
cgemm_p = 96 * (size + 1); cgemm_p = 96 * (size + 1);
@@ -634,6 +653,9 @@ void blas_set_parameter(void){


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
#endif
#ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
#endif #endif
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;


+ 10
- 7
exports/Makefile View File

@@ -39,6 +39,9 @@ endif
ifndef BUILD_BFLOAT16 ifndef BUILD_BFLOAT16
BUILD_BFLOAT16 = 0 BUILD_BFLOAT16 = 0
endif endif
ifndef BUILD_HFLOAT16
BUILD_HFLOAT16 = 0
endif
ifndef BUILD_SINGLE ifndef BUILD_SINGLE
BUILD_SINGLE = 0 BUILD_SINGLE = 0
endif endif
@@ -130,10 +133,10 @@ dll : ../$(LIBDLLNAME)
-Wl,--whole-archive ../$(LIBNAME) -Wl,--no-whole-archive $(FEXTRALIB) $(EXTRALIB) -Wl,--whole-archive ../$(LIBNAME) -Wl,--no-whole-archive $(FEXTRALIB) $(EXTRALIB)


$(LIBPREFIX).def : $(GENSYM) $(LIBPREFIX).def : $(GENSYM)
./$(GENSYM) win2k $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) win2k $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


libgoto_hpl.def : $(GENSYM) libgoto_hpl.def : $(GENSYM)
./$(GENSYM) win2khpl $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) win2khpl $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


ifeq ($(OSNAME), Darwin) ifeq ($(OSNAME), Darwin)
ifeq ($(FIXED_LIBNAME),1) ifeq ($(FIXED_LIBNAME),1)
@@ -298,23 +301,23 @@ static : ../$(LIBNAME)
rm -f goto.$(SUFFIX) rm -f goto.$(SUFFIX)


osx.def : $(GENSYM) ../Makefile.system ../getarch.c osx.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) osx $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) osx $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


aix.def : $(GENSYM) ../Makefile.system ../getarch.c aix.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) aix $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) aix $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


objcopy.def : $(GENSYM) ../Makefile.system ../getarch.c objcopy.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) objcopy $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) objcopy $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


objconv.def : $(GENSYM) ../Makefile.system ../getarch.c objconv.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) objconv $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) objconv $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)


test : linktest.c test : linktest.c
$(CC) $(CFLAGS) $(LDFLAGS) -w -o linktest linktest.c ../$(LIBSONAME) -lm && echo OK. $(CC) $(CFLAGS) $(LDFLAGS) -w -o linktest linktest.c ../$(LIBSONAME) -lm && echo OK.
rm -f linktest rm -f linktest


linktest.c : $(GENSYM) ../Makefile.system ../getarch.c linktest.c : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) linktest $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > linktest.c
./$(GENSYM) linktest $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > linktest.c


clean :: clean ::
@rm -f *.def *.dylib __.SYMDEF* *.renamed @rm -f *.def *.dylib __.SYMDEF* *.renamed


+ 16
- 7
exports/gensymbol View File

@@ -52,6 +52,7 @@ blasobjsz="


blasobjs="lsame xerbla" blasobjs="lsame xerbla"
bfblasobjs="sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" bfblasobjs="sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod"
hfblasobjs="shgemm"
cblasobjsc=" cblasobjsc="
cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv
cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k
@@ -100,6 +101,7 @@ cblasobjsz="
cblasobjs="cblas_xerbla" cblasobjs="cblas_xerbla"


bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch" bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch"
hfcblasobjs="cblas_shgemm"


exblasobjs=" exblasobjs="
qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm
@@ -3814,6 +3816,8 @@ shift
p16=$9 p16=$9
shift shift
p17=$9 p17=$9
shift
p18=$9


if [ $p13 -eq 1 ]; then if [ $p13 -eq 1 ]; then
blasobjs="$blasobjs $bfblasobjs" blasobjs="$blasobjs $bfblasobjs"
@@ -3821,6 +3825,11 @@ if [ $p13 -eq 1 ]; then
fi fi


if [ $p14 -eq 1 ]; then if [ $p14 -eq 1 ]; then
blasobjs="$blasobjs $hfblasobjs"
cblasobjs="$cblasobjs $hfcblasobjs"
fi

if [ $p15 -eq 1 ]; then
blasobjs="$blasobjs $blasobjss" blasobjs="$blasobjs $blasobjss"
cblasobjs="$cblasobjs $cblasobjss" cblasobjs="$cblasobjs $cblasobjss"
lapackobjs="$lapackobjs $lapackobjss" lapackobjs="$lapackobjs $lapackobjss"
@@ -3833,11 +3842,11 @@ if [ $p14 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjss" lapackeobjs="$lapackeobjs $lapackeobjss"
fi fi


if [ $p15 -eq 1 ]; then
if [ $p16 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsd" blasobjs="$blasobjs $blasobjsd"
cblasobjs="$cblasobjs $cblasobjsd" cblasobjs="$cblasobjs $cblasobjsd"
lapackobjs="$lapackobjs $lapackobjsd" lapackobjs="$lapackobjs $lapackobjsd"
if [ $p14 -eq 0 ]; then
if [ $p15 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2ds" lapackobjs2="$lapackobjs2 $lapackobjs2ds"
fi fi
lapackobjs2="$lapackobjs2 $lapackobjs2d $lapackobjs2dz" lapackobjs2="$lapackobjs2 $lapackobjs2d $lapackobjs2dz"
@@ -3847,14 +3856,14 @@ if [ $p15 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjsd" lapackeobjs="$lapackeobjs $lapackeobjsd"
fi fi


if [ $p16 -eq 1 ]; then
if [ $p17 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsc" blasobjs="$blasobjs $blasobjsc"
cblasobjs="$cblasobjs $cblasobjsc" cblasobjs="$cblasobjs $cblasobjsc"
gemm3mobjs="$gemm3mobjs $gemm3mobjsc" gemm3mobjs="$gemm3mobjs $gemm3mobjsc"
cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsc" cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsc"
lapackobjs="$lapackobjs $lapackobjsc" lapackobjs="$lapackobjs $lapackobjsc"
lapackobjs2="$lapackobjs2 $lapackobjs2c $lapackobjs2zc" lapackobjs2="$lapackobjs2 $lapackobjs2c $lapackobjs2zc"
if [ $p14 -eq 0 ]; then
if [ $p15 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2sc" lapackobjs2="$lapackobjs2 $lapackobjs2sc"
fi fi
lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsc" lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsc"
@@ -3863,17 +3872,17 @@ if [ $p16 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjsc" lapackeobjs="$lapackeobjs $lapackeobjsc"
fi fi


if [ $p17 -eq 1 ]; then
if [ $p18 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsz" blasobjs="$blasobjs $blasobjsz"
cblasobjs="$cblasobjs $cblasobjsz" cblasobjs="$cblasobjs $cblasobjsz"
gemm3mobjs="$gemm3mobjs $gemm3mobjsz" gemm3mobjs="$gemm3mobjs $gemm3mobjsz"
cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsz" cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsz"
lapackobjs="$lapackobjs $lapackobjsz" lapackobjs="$lapackobjs $lapackobjsz"
lapackobjs2="$lapackobjs2 $lapackobjs2z" lapackobjs2="$lapackobjs2 $lapackobjs2z"
if [ $p16 -eq 0 ]; then
if [ $p17 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2zc" lapackobjs2="$lapackobjs2 $lapackobjs2zc"
fi fi
if [ $p15 -eq 0 ]; then
if [ $p16 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2dz" lapackobjs2="$lapackobjs2 $lapackobjs2dz"
fi fi
lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsz" lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsz"


+ 13
- 8
exports/gensymbol.pl View File

@@ -52,6 +52,7 @@


@blasobjs = (lsame, xerbla); @blasobjs = (lsame, xerbla);
@bfblasobjs = (sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); @bfblasobjs = (sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@hfblasobjs = (shgemm);
@cblasobjsc = ( @cblasobjsc = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
@@ -97,7 +98,7 @@
@cblasobjs = ( cblas_xerbla ); @cblasobjs = ( cblas_xerbla );


@bfcblasobjs = (cblas_sbgemm, cblas_sbgemmt, cblas_sbgemmtr, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod, cblas_sbgemm_batch); @bfcblasobjs = (cblas_sbgemm, cblas_sbgemmt, cblas_sbgemmtr, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod, cblas_sbgemm_batch);
@hfcblasobjs = (cblas_shgemm);
@exblasobjs = ( @exblasobjs = (
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
qgemv,qger,qmax,qmin, qgemv,qger,qmax,qmin,
@@ -3777,6 +3778,10 @@ if ($ARGV[12] == 1) {
@cblasobjs = (@cblasobjs, @bfcblasobjs); @cblasobjs = (@cblasobjs, @bfcblasobjs);
} }
if ($ARGV[13] == 1) { if ($ARGV[13] == 1) {
@blasobjs = (@blasobjs, @hfblasobjs);
@cblasobjs = (@cblasobjs, @hfcblasobjs);
}
if ($ARGV[14] == 1) {
@blasobjs = (@blasobjs, @blasobjss); @blasobjs = (@blasobjs, @blasobjss);
@cblasobjs = (@cblasobjs, @cblasobjss); @cblasobjs = (@cblasobjs, @cblasobjss);
@lapackobjs = (@lapackobjs, @lapackobjss); @lapackobjs = (@lapackobjs, @lapackobjss);
@@ -3788,11 +3793,11 @@ if ($ARGV[13] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_s); @lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_s);
@lapackeobjs = (@lapackeobjs, @lapackeobjss); @lapackeobjs = (@lapackeobjs, @lapackeobjss);
} }
if ($ARGV[14] == 1) {
if ($ARGV[15] == 1) {
@blasobjs = (@blasobjs, @blasobjsd); @blasobjs = (@blasobjs, @blasobjsd);
@cblasobjs = (@cblasobjs, @cblasobjsd); @cblasobjs = (@cblasobjs, @cblasobjsd);
@lapackobjs = (@lapackobjs, @lapackobjsd); @lapackobjs = (@lapackobjs, @lapackobjsd);
if ($ARGV[13] == 0) {
if ($ARGV[14] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2ds); @lapackobjs2 = (@lapackobjs2, @lapackobjs2ds);
} }
@lapackobjs2 = (@lapackobjs2, @lapackobjs2d, @lapackobjs2dz); @lapackobjs2 = (@lapackobjs2, @lapackobjs2d, @lapackobjs2dz);
@@ -3801,14 +3806,14 @@ if ($ARGV[14] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_d); @lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_d);
@lapackeobjs = (@lapackeobjs, @lapackeobjsd); @lapackeobjs = (@lapackeobjs, @lapackeobjsd);
} }
if ($ARGV[15] == 1) {
if ($ARGV[16] == 1) {
@blasobjs = (@blasobjs, @blasobjsc); @blasobjs = (@blasobjs, @blasobjsc);
@cblasobjs = (@cblasobjs, @cblasobjsc); @cblasobjs = (@cblasobjs, @cblasobjsc);
@gemm3mobjs = (@gemm3mobjs, @gemm3mobjsc); @gemm3mobjs = (@gemm3mobjs, @gemm3mobjsc);
@cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsc); @cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsc);
@lapackobjs = (@lapackobjs, @lapackobjsc); @lapackobjs = (@lapackobjs, @lapackobjsc);
@lapackobjs2 = (@lapackobjs2, @lapackobjs2c, @lapackobjs2zc); @lapackobjs2 = (@lapackobjs2, @lapackobjs2c, @lapackobjs2zc);
if ($ARGV[13] == 0) {
if ($ARGV[14] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2sc); @lapackobjs2 = (@lapackobjs2, @lapackobjs2sc);
} }
@lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsc); @lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsc);
@@ -3816,17 +3821,17 @@ if ($ARGV[15] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_c); @lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_c);
@lapackeobjs = (@lapackeobjs, @lapackeobjsc); @lapackeobjs = (@lapackeobjs, @lapackeobjsc);
} }
if ($ARGV[16] == 1) {
if ($ARGV[17] == 1) {
@blasobjs = (@blasobjs, @blasobjsz); @blasobjs = (@blasobjs, @blasobjsz);
@cblasobjs = (@cblasobjs, @cblasobjsz); @cblasobjs = (@cblasobjs, @cblasobjsz);
@gemm3mobjs = (@gemm3mobjs, @gemm3mobjsz); @gemm3mobjs = (@gemm3mobjs, @gemm3mobjsz);
@cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsz); @cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsz);
@lapackobjs = (@lapackobjs, @lapackobjsz); @lapackobjs = (@lapackobjs, @lapackobjsz);
@lapackobjs2 = (@lapackobjs2, @lapackobjs2z); @lapackobjs2 = (@lapackobjs2, @lapackobjs2z);
if ($ARGV[15] == 0) {
if ($ARGV[16] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2zc); @lapackobjs2 = (@lapackobjs2, @lapackobjs2zc);
} }
if ($ARGV[14] == 0) {
if ($ARGV[15] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2dz); @lapackobjs2 = (@lapackobjs2, @lapackobjs2dz);
} }
@lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsz); @lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsz);


+ 2
- 0
getarch_2nd.c View File

@@ -19,6 +19,8 @@ int main(int argc, char **argv) {
if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) {
printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M); printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M);
printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N); printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N);
printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M);
printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N);
printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M);
printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N);
printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M); printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M);


+ 3
- 0
interface/CMakeLists.txt View File

@@ -136,6 +136,9 @@ if (BUILD_BFLOAT16)
GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16")
endif () endif ()
endif () endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16")
endif ()


# complex-specific sources # complex-specific sources
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})


+ 22
- 2
interface/Makefile View File

@@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX)
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
SHBLAS3OBJS = shgemm.$(SUFFIX)
endif

DBLAS1OBJS = \ DBLAS1OBJS = \
daxpy.$(SUFFIX) dswap.$(SUFFIX) \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \
dcopy.$(SUFFIX) dscal.$(SUFFIX) \ dcopy.$(SUFFIX) dscal.$(SUFFIX) \
@@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
endif

CDBLAS1OBJS = \ CDBLAS1OBJS = \
cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \
cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \
@@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS)
SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS)
SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS)
SBBLAS3OBJS += $(CSBBLAS3OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS)
SHBLAS3OBJS += $(CSHBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS2OBJS += $(CDBLAS2OBJS)
DBLAS3OBJS += $(CDBLAS3OBJS) DBLAS3OBJS += $(CDBLAS3OBJS)
@@ -405,6 +414,7 @@ endif


SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
SHBLASOBJS = $(SHBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -512,7 +522,7 @@ ifneq ($(BUILD_COMPLEX16),1)
ZBLASOBJS= ZBLASOBJS=
endif endif


FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS)


ifeq ($(EXPRECISION), 1) ifeq ($(EXPRECISION), 1)
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $
level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^


level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^


aux : $(CBAUXOBJS) aux : $(CBAUXOBJS)
@@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
endif

sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F) $(CC) -c $(CFLAGS) $< -o $(@F)


@@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif

cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)




+ 8
- 6
interface/gemm.c View File

@@ -56,6 +56,8 @@
#elif defined(BFLOAT16) #elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMM " #define ERROR_NAME "SBGEMM "
#define GEMV BLASFUNC(sbgemv) #define GEMV BLASFUNC(sbgemv)
#elif defined(HFLOAT16)
#define ERROR_NAME "SHGEMM "
#else #else
#define ERROR_NAME "SGEMM " #define ERROR_NAME "SGEMM "
#define GEMV BLASFUNC(sgemv) #define GEMV BLASFUNC(sgemv)
@@ -111,7 +113,7 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B
#endif #endif
}; };


#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) &&!defined(HFLOAT16)
#define USE_SMALL_MATRIX_OPT 1 #define USE_SMALL_MATRIX_OPT 1
#else #else
#define USE_SMALL_MATRIX_OPT 0 #define USE_SMALL_MATRIX_OPT 0
@@ -219,11 +221,11 @@ static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) {


static inline int get_gemm_optimal_nthreads(double MNK) { static inline int get_gemm_optimal_nthreads(double MNK) {
int ncpu = num_cpu_avail(3); int ncpu = num_cpu_avail(3);
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu); return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
if (strcmp(gotoblas_corename(), "neoversev1") == 0) { if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
} }
@@ -417,7 +419,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS


PRINT_DEBUG_CNAME; PRINT_DEBUG_CNAME;


#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) #if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH) #if defined(DYNAMIC_ARCH)
if (support_avx512() ) if (support_avx512() )
@@ -577,7 +579,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
args.m, args.n, args.k, args.lda, args.ldb, args.ldc); args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif #endif


#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(ARCH_ARM64) #if defined(ARCH_ARM64)
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c} // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
// perform poorly in certain circumstances. We use the following boolean // perform poorly in certain circumstances. We use the following boolean


+ 55
- 0
kernel/CMakeLists.txt View File

@@ -351,6 +351,22 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${SBGEMMKERNEL}" "" "gemm_kernel" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMMKERNEL}" "" "gemm_kernel" false "" "" false "BFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_BETA}" "" "gemm_beta" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMM_BETA}" "" "gemm_beta" false "" "" false "BFLOAT16")
endif () endif ()
if (BUILD_HFLOAT16)
if (SHGEMMINCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMINCOPY}" "" "${SHGEMMINCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMITCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMITCOPY}" "" "${SHGEMMITCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMONCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMONCOPY}" "" "${SHGEMMONCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMOTCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMOTCOPY}" "" "${SHGEMMOTCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
GenerateNamedObjects("${KERNELDIR}/${SHGEMMKERNEL}" "" "gemm_kernel" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_BETA}" "" "gemm_beta" false "" "" false "HFLOAT16")
endif ()
foreach (float_type ${FLOAT_TYPES}) foreach (float_type ${FLOAT_TYPES})
string(SUBSTRING ${float_type} 0 1 float_char) string(SUBSTRING ${float_type} 0 1 float_char)
if (${float_char}GEMMINCOPY) if (${float_char}GEMMINCOPY)
@@ -769,6 +785,45 @@ endif ()
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "BFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "BFLOAT16") GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "BFLOAT16")
endif () endif ()

if (BUILD_HFLOAT16)
if (NOT DEFINED SHGEMM_SMALL_M_PERMIT)
set(SHGEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_NN)
set(SHGEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_NT)
set(SHGEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_TN)
set(SHGEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_TT)
set(SHGEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_NN)
set(SHGEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_NT)
set(SHGEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_TN)
set(SHGEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_TT)
set(SHGEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c)
endif ()
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TT}" "" "gemm_small_kernel_tt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "HFLOAT16")
endif ()
endif () endif ()


if (NOT DEFINED ${float_char}OMATCOPY_CN) if (NOT DEFINED ${float_char}OMATCOPY_CN)


+ 164
- 0
kernel/Makefile.L3 View File

@@ -129,6 +129,26 @@ SBKERNELOBJS += \
$(SBGEMMONCOPYOBJ) $(SBGEMMOTCOPYOBJ) $(SBGEMMONCOPYOBJ) $(SBGEMMOTCOPYOBJ)
endif endif


ifeq ($(BUILD_HFLOAT16), 1)
ifndef SHGEMMKERNEL
SHGEMM_BETA = ../generic/gemm_beta.c
SHGEMMKERNEL = ../generic/gemmkernel_2x2.c
SHGEMMONCOPY = ../generic/gemm_ncopy_2.c
SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
SHGEMMINCOPY = ../generic/gemm_ncopy_2.c
SHGEMMITCOPY = ../generic/gemm_tcopy_2.c
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
endif

SHKERNELOBJS += \
shgemm_kernel$(TSUFFIX).$(SUFFIX) \
$(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
$(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
endif

ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" "" ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" ""
SKERNELOBJS += \ SKERNELOBJS += \
sgemm_kernel$(TSUFFIX).$(SUFFIX) \ sgemm_kernel$(TSUFFIX).$(SUFFIX) \
@@ -192,6 +212,9 @@ XKERNELOBJS += \
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += $(SBKERNELOBJS) SBBLASOBJS += $(SBKERNELOBJS)
endif endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += $(SHKERNELOBJS)
endif
SBLASOBJS += $(SKERNELOBJS) SBLASOBJS += $(SKERNELOBJS)
DBLASOBJS += $(DKERNELOBJS) DBLASOBJS += $(DKERNELOBJS)
QBLASOBJS += $(QKERNELOBJS) QBLASOBJS += $(QKERNELOBJS)
@@ -202,6 +225,9 @@ XBLASOBJS += $(XKERNELOBJS)
ifeq ($(BUILD_BFLOAT16),1) ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX)
endif endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
endif


ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" "" ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" ""
SBLASOBJS += \ SBLASOBJS += \
@@ -493,6 +519,15 @@ SBBLASOBJS += \
sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
endif endif


ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += \
shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
endif

SBLASOBJS += \ SBLASOBJS += \
sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
@@ -599,6 +634,13 @@ SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SBGEMMOTCOPYOBJ_P = $(SBGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SBGEMMOTCOPYOBJ_P = $(SBGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
endif endif


ifeq ($(BUILD_HFLOAT16), 1)
SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
endif

SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
@@ -629,6 +671,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif endif


ifeq ($(BUILD_HFLOAT16),1)
$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA) $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@


@@ -671,6 +718,25 @@ $(KDIR)$(SBGEMMITCOPYOBJ) : $(KERNELDIR)/$(SBGEMMITCOPY)
endif endif
endif endif


ifeq ($(BUILD_HFLOAT16), 1)

$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))

$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#endif
endif

$(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@


@@ -853,6 +919,12 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMM
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif endif


ifeq ($(BUILD_HFLOAT16), 1)

$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND) $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
ifeq ($(OS), AIX) ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -S -DDOUBLE -UCOMPLEX $< -o - > dgemm_kernel$(TSUFFIX).s $(CC) $(CFLAGS) -S -DDOUBLE -UCOMPLEX $< -o - > dgemm_kernel$(TSUFFIX).s
@@ -2840,6 +2912,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA)
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif endif


ifeq ($(BUILD_HFLOAT16),1)
$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA) $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
$(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@


@@ -2873,6 +2950,23 @@ $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY)
endif endif
endif endif


ifeq ($(BUILD_HFLOAT16), 1)
$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#endif
endif

$(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY) $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@


@@ -2983,6 +3077,11 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEM
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif endif


ifeq ($(BUILD_HFLOAT16), 1)
$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND) $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@


@@ -4843,6 +4942,71 @@ $(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMA
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@ $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
endif endif


ifeq ($(BUILD_HFLOAT16), 1)
ifndef SHGEMM_SMALL_M_PERMIT
SHGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
endif

ifndef SHGEMM_SMALL_K_NN
SHGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif

ifndef SHGEMM_SMALL_K_NT
SHGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif

ifndef SHGEMM_SMALL_K_TN
SHGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif

ifndef SHGEMM_SMALL_K_TT
SHGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif

$(KDIR)shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_M_PERMIT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

ifndef SHGEMM_SMALL_K_B0_NN
SHGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif

ifndef SHGEMM_SMALL_K_B0_NT
SHGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif

ifndef SHGEMM_SMALL_K_B0_TN
SHGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif

ifndef SHGEMM_SMALL_K_B0_TT
SHGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif

$(KDIR)shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
endif

ifndef CGEMM_SMALL_M_PERMIT ifndef CGEMM_SMALL_M_PERMIT
CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c
endif endif


+ 11
- 0
kernel/riscv64/KERNEL.RISCV64_ZVL128B View File

@@ -245,3 +245,14 @@ endif
ifndef ZGEMM_BETA ifndef ZGEMM_BETA
ZGEMM_BETA = zgemm_beta_rvv.c ZGEMM_BETA = zgemm_beta_rvv.c
endif endif

ifeq ($(BUILD_BFLOAT16), 1)
SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl128b.c
SHGEMMONCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c
SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
ifndef SHGEMM_BETA
SHGEMM_BETA = gemm_beta_rvv.c
endif
endif

+ 18
- 0
kernel/riscv64/KERNEL.RISCV64_ZVL256B View File

@@ -209,5 +209,23 @@ COMATCOPY_CN = zomatcopy_cn_vector.c
DOMATCOPY_CN = omatcopy_cn_vector.c DOMATCOPY_CN = omatcopy_cn_vector.c
SOMATCOPY_CN = omatcopy_cn_vector.c SOMATCOPY_CN = omatcopy_cn_vector.c



ifeq ($(BUILD_BFLOAT16), 1)
SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl256b.c
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
SHGEMMINCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_M).c
SHGEMMITCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_M).c
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
endif
SHGEMMONCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c
SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
ifndef SHGEMM_BETA
SHGEMM_BETA = gemm_beta_rvv.c
endif
endif

SAXPBYKERNEL = axpby_vector_v2.c SAXPBYKERNEL = axpby_vector_v2.c
DAXPBYKERNEL = axpby_vector_v2.c DAXPBYKERNEL = axpby_vector_v2.c

+ 969
- 0
kernel/riscv64/shgemm_kernel_16x8_zvl256b.c View File

@@ -0,0 +1,969 @@

#include "common.h"
#include <riscv_vector.h>
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
{
BLASLONG gvl = 0;
BLASLONG m_top = 0;
BLASLONG n_top = 0;

// -- MAIN PASS
for (BLASLONG j=0; j<N/8; j+=1) {
m_top = 0;
BLASLONG gvl = __riscv_vsetvl_e16m1(16);

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;

_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;
A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c7, gvl);
m_top += 16;
}



// -- tails for main pass
if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
vfloat32m1_t result4 = __riscv_vfwmul_vf_f32m1( A0, B4, gvl);
vfloat32m1_t result5 = __riscv_vfwmul_vf_f32m1( A0, B5, gvl);
vfloat32m1_t result6 = __riscv_vfwmul_vf_f32m1( A0, B6, gvl);
vfloat32m1_t result7 = __riscv_vfwmul_vf_f32m1( A0, B7, gvl);

for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;
result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m1(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m1(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m1(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m1(result7, B7, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c4 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c5 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c6 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c7 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m1(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m1(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m1(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m1(c7, alpha, result7, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c7, gvl);
m_top += 8;
}


if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
vfloat32m1_t result4 = __riscv_vfwmul_vf_f32m1( A0, B4, gvl);
vfloat32m1_t result5 = __riscv_vfwmul_vf_f32m1( A0, B5, gvl);
vfloat32m1_t result6 = __riscv_vfwmul_vf_f32m1( A0, B6, gvl);
vfloat32m1_t result7 = __riscv_vfwmul_vf_f32m1( A0, B7, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m1(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m1(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m1(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m1(result7, B7, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c4 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c5 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c6 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c7 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m1(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m1(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m1(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m1(c7, alpha, result7, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c7, gvl);
m_top += 4;
}

if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
float result8 = 0;
float result9 = 0;
float result10 = 0;
float result11 = 0;
float result12 = 0;
float result13 = 0;
float result14 = 0;
float result15 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;
for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
result8+=(float)(A[ai+0]*B[bi+4]);
result9+=(float)(A[ai+1]*B[bi+4]);
result10+=(float)(A[ai+0]*B[bi+5]);
result11+=(float)(A[ai+1]*B[bi+5]);
result12+=(float)(A[ai+0]*B[bi+6]);
result13+=(float)(A[ai+1]*B[bi+6]);
result14+=(float)(A[ai+0]*B[bi+7]);
result15+=(float)(A[ai+1]*B[bi+7]);
ai+=2;
bi+=8;
}

BLASLONG ci=n_top*ldc+m_top;

C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;
C[ci + 4 * ldc + 0] += alpha * result8;
C[ci + 4 * ldc + 1] += alpha * result9;
C[ci + 5 * ldc + 0] += alpha * result10;
C[ci + 5 * ldc + 1] += alpha * result11;
C[ci + 6 * ldc + 0] += alpha * result12;
C[ci + 6 * ldc + 1] += alpha * result13;
C[ci + 7 * ldc + 0] += alpha * result14;
C[ci + 7 * ldc + 1] += alpha * result15;

m_top+=2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
result4+=(float)(A[ai+0]*B[bi+4]);
result5+=(float)(A[ai+0]*B[bi+5]);
result6+=(float)(A[ai+0]*B[bi+6]);
result7+=(float)(A[ai+0]*B[bi+7]);
ai+=1;
bi+=8;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
C[ci + 4 * ldc + 0] += alpha * result4;
C[ci + 5 * ldc + 0] += alpha * result5;
C[ci + 6 * ldc + 0] += alpha * result6;
C[ci + 7 * ldc + 0] += alpha * result7;
m_top+=1;
}
n_top += 8;
}

if( N & 4 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c3, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
ai+=2;
bi+=4;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
ai+=1;
bi+=4;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
m_top += 1;
}

n_top += 4;
}



// -- tails for N=2
if( N & 2 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c1, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
ai+=2;
bi+=2;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
ai+=1;
bi+=2;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
m_top += 1;
}

n_top += 2;
}



// -- tails for N=1
if( N & 1 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);

for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
ai+=2;
bi+=1;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
ai+=1;
bi+=1;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
m_top += 1;
}

n_top += 1;
}
return 0;
}

+ 767
- 0
kernel/riscv64/shgemm_kernel_8x8_zvl128b.c View File

@@ -0,0 +1,767 @@

#include "common.h"
#include <riscv_vector.h>

int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
{
BLASLONG gvl = 0;
BLASLONG m_top = 0;
BLASLONG n_top = 0;

// -- MAIN PASS
for (BLASLONG j=0; j<N/8; j+=1) {
m_top = 0;
BLASLONG gvl = __riscv_vsetvl_e16m1(8);

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c7, gvl); ci += ldc-gvl*0;
m_top += 8;
}

// -- tails for main pass --

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c7, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
float result8 = 0;
float result9 = 0;
float result10 = 0;
float result11 = 0;
float result12 = 0;
float result13 = 0;
float result14 = 0;
float result15 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
result8+=(float)(A[ai+0]*B[bi+4]);
result9+=(float)(A[ai+1]*B[bi+4]);
result10+=(float)(A[ai+0]*B[bi+5]);
result11+=(float)(A[ai+1]*B[bi+5]);
result12+=(float)(A[ai+0]*B[bi+6]);
result13+=(float)(A[ai+1]*B[bi+6]);
result14+=(float)(A[ai+0]*B[bi+7]);
result15+=(float)(A[ai+1]*B[bi+7]);
ai+=2;
bi+=8;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;
C[ci + 4 * ldc + 0] += alpha * result8;
C[ci + 4 * ldc + 1] += alpha * result9;
C[ci + 5 * ldc + 0] += alpha * result10;
C[ci + 5 * ldc + 1] += alpha * result11;
C[ci + 6 * ldc + 0] += alpha * result12;
C[ci + 6 * ldc + 1] += alpha * result13;
C[ci + 7 * ldc + 0] += alpha * result14;
C[ci + 7 * ldc + 1] += alpha * result15;

m_top+=2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
result4+=(float)(A[ai+0]*B[bi+4]);
result5+=(float)(A[ai+0]*B[bi+5]);
result6+=(float)(A[ai+0]*B[bi+6]);
result7+=(float)(A[ai+0]*B[bi+7]);
ai+=1;
bi+=8;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
C[ci + 4 * ldc + 0] += alpha * result4;
C[ci + 5 * ldc + 0] += alpha * result5;
C[ci + 6 * ldc + 0] += alpha * result6;
C[ci + 7 * ldc + 0] += alpha * result7;
m_top+=1;
}

n_top += 8;
}

// -- tails for N=4
if( N & 4 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c3, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
ai+=2;
bi+=4;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
ai+=1;
bi+=4;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
m_top += 1;
}

n_top += 4;
}



// -- tails for N=2
if( N & 2 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
ai+=2;
bi+=2;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
ai+=1;
bi+=2;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
m_top += 1;
}

n_top += 2;
}



// -- tails for N=1
if( N & 1 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
ai+=2;
bi+=1;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
ai+=1;
bi+=1;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
m_top += 1;
}

n_top += 1;
}

return 0;

}

+ 37
- 0
kernel/setparam-ref.c View File

@@ -125,6 +125,23 @@ gotoblas_t TABLE_NAME = {
#endif #endif
#endif #endif


#ifdef BUILD_HFLOAT16
0, 0, 0,
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
#ifdef SHGEMM_DEFAULT_UNROLL_MN
SHGEMM_DEFAULT_UNROLL_MN,
#else
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
shgemm_kernelTS, shgemm_betaTS,
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
shgemm_incopyTS, shgemm_itcopyTS,
#else
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
shgemm_oncopyTS, shgemm_otcopyTS,
#endif

#if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1) #if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1)
0, 0, 0, 0, 0, 0,
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N, SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
@@ -1252,6 +1269,9 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
#endif #endif
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P; TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P; TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
@@ -1260,6 +1280,9 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R; TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
#endif #endif
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R; TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R; TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
@@ -1269,6 +1292,9 @@ static void init_parameter(void) {


#ifdef BUILD_BFLOAT16 #ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
#endif #endif
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q; TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
@@ -1417,6 +1443,10 @@ static void init_parameter(void) {
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P; TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q; TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
#endif #endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
#endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) #if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q; TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
#endif #endif
@@ -2012,6 +2042,13 @@ static void init_parameter(void) {
) / (TABLE_NAME.sbgemm_q * 4) - 15) & ~15); ) / (TABLE_NAME.sbgemm_q * 4) - 15) & ~15);
#endif #endif


#if BUILD_HFLOAT16==1
TABLE_NAME.shgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA
+ TABLE_NAME.align) & ~TABLE_NAME.align)
) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15);
#endif

#if BUILD_SINGLE==1 #if BUILD_SINGLE==1
TABLE_NAME.sgemm_r = (((BUFFER_SIZE - TABLE_NAME.sgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA ((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA


+ 1
- 0
lapack/CMakeLists.txt View File

@@ -3,6 +3,7 @@ include_directories(${PROJECT_SOURCE_DIR})
include_directories(${PROJECT_BINARY_DIR}) include_directories(${PROJECT_BINARY_DIR})


list (REMOVE_ITEM FLOAT_TYPES "BFLOAT16") list (REMOVE_ITEM FLOAT_TYPES "BFLOAT16")
list (REMOVE_ITEM FLOAT_TYPES "HFLOAT16")


set(LAPACK_SOURCES set(LAPACK_SOURCES
potrf/potrf_U_single.c potrf/potrf_U_single.c


+ 7
- 0
openblas_config_template.h View File

@@ -39,6 +39,13 @@ typedef unsigned long BLASULONG;
typedef uint16_t bfloat16; typedef uint16_t bfloat16;
#endif #endif


#if defined(__GNUC__) && (__GNUC__ >= 12)
typedef _Float16 hfloat16;
#else
#include <stdint.h>
typedef uint16_t hfloat16;
#endif

#ifdef OPENBLAS_USE64BITINT #ifdef OPENBLAS_USE64BITINT
typedef BLASLONG blasint; typedef BLASLONG blasint;
#else #else


+ 29
- 0
param.h View File

@@ -72,6 +72,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef PARAM_H #ifndef PARAM_H
#define PARAM_H #define PARAM_H


#define SHGEMM_DEFAULT_UNROLL_N 8
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_MN 32
#define SHGEMM_DEFAULT_P 128
#define SHGEMM_DEFAULT_R 240
#define SHGEMM_DEFAULT_Q 12288


#define SBGEMM_DEFAULT_UNROLL_N 4 #define SBGEMM_DEFAULT_UNROLL_N 4
#define SBGEMM_DEFAULT_UNROLL_M 8 #define SBGEMM_DEFAULT_UNROLL_M 8
@@ -3138,10 +3144,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif #endif


#ifdef RISCV64_ZVL128B #ifdef RISCV64_ZVL128B

#define GEMM_DEFAULT_OFFSET_A 0 #define GEMM_DEFAULT_OFFSET_A 0
#define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL #define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL


#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 8 #define SGEMM_DEFAULT_UNROLL_M 8
#define SGEMM_DEFAULT_UNROLL_N 8 #define SGEMM_DEFAULT_UNROLL_N 8


@@ -3154,16 +3166,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 4 #define ZGEMM_DEFAULT_UNROLL_M 4
#define ZGEMM_DEFAULT_UNROLL_N 4 #define ZGEMM_DEFAULT_UNROLL_N 4


#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 128 #define DGEMM_DEFAULT_P 128
#define CGEMM_DEFAULT_P 96 #define CGEMM_DEFAULT_P 96
#define ZGEMM_DEFAULT_P 64 #define ZGEMM_DEFAULT_P 64


#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 240
#define SGEMM_DEFAULT_Q 240 #define SGEMM_DEFAULT_Q 240
#define DGEMM_DEFAULT_Q 120 #define DGEMM_DEFAULT_Q 120
#define CGEMM_DEFAULT_Q 120 #define CGEMM_DEFAULT_Q 120
#define ZGEMM_DEFAULT_Q 120 #define ZGEMM_DEFAULT_Q 120


#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 12288
#define SGEMM_DEFAULT_R 12288 #define SGEMM_DEFAULT_R 12288
#define DGEMM_DEFAULT_R 8192 #define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 4096 #define CGEMM_DEFAULT_R 4096
@@ -3181,6 +3199,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GEMM_DEFAULT_OFFSET_B 0 #define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN 0x03fffUL #define GEMM_DEFAULT_ALIGN 0x03fffUL


#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 16
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 16 #define SGEMM_DEFAULT_UNROLL_M 16
#define SGEMM_DEFAULT_UNROLL_N 8 #define SGEMM_DEFAULT_UNROLL_N 8


@@ -3193,16 +3216,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 8 #define ZGEMM_DEFAULT_UNROLL_M 8
#define ZGEMM_DEFAULT_UNROLL_N 4 #define ZGEMM_DEFAULT_UNROLL_N 4


#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128 #define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 64 #define DGEMM_DEFAULT_P 64
#define CGEMM_DEFAULT_P 64 #define CGEMM_DEFAULT_P 64
#define ZGEMM_DEFAULT_P 64 #define ZGEMM_DEFAULT_P 64


#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 128
#define SGEMM_DEFAULT_Q 128 #define SGEMM_DEFAULT_Q 128
#define DGEMM_DEFAULT_Q 128 #define DGEMM_DEFAULT_Q 128
#define CGEMM_DEFAULT_Q 128 #define CGEMM_DEFAULT_Q 128
#define ZGEMM_DEFAULT_Q 64 #define ZGEMM_DEFAULT_Q 64


#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 16384
#define SGEMM_DEFAULT_R 16384 #define SGEMM_DEFAULT_R 16384
#define DGEMM_DEFAULT_R 8192 #define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 8192 #define CGEMM_DEFAULT_R 8192


Loading…
Cancel
Save