Browse Source

Merge pull request #5290 from Srangrang/develop

Add support for FP16 to openBLAS and shgemm on RISCV
pull/5295/head
Martin Kroeker GitHub 3 months ago
parent
commit
d96daa220d
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
37 changed files with 2522 additions and 92 deletions
  1. +3
    -0
      CMakeLists.txt
  2. +2
    -2
      Makefile.prebuild
  3. +4
    -4
      Makefile.riscv64
  4. +2
    -0
      Makefile.rule
  5. +6
    -0
      Makefile.system
  6. +5
    -2
      Makefile.tail
  7. +21
    -5
      benchmark/Makefile
  8. +8
    -1
      benchmark/gemm.c
  9. +4
    -0
      cblas.h
  10. +11
    -8
      cmake/system.cmake
  11. +15
    -0
      common.h
  12. +2
    -0
      common_interface.h
  13. +18
    -1
      common_level3.h
  14. +45
    -0
      common_macro.h
  15. +72
    -22
      common_param.h
  16. +72
    -0
      common_sh.h
  17. +6
    -0
      driver/level3/CMakeLists.txt
  18. +71
    -16
      driver/level3/Makefile
  19. +1
    -1
      driver/others/Makefile
  20. +22
    -0
      driver/others/parameter.c
  21. +10
    -7
      exports/Makefile
  22. +16
    -7
      exports/gensymbol
  23. +13
    -8
      exports/gensymbol.pl
  24. +2
    -0
      getarch_2nd.c
  25. +3
    -0
      interface/CMakeLists.txt
  26. +22
    -2
      interface/Makefile
  27. +8
    -6
      interface/gemm.c
  28. +55
    -0
      kernel/CMakeLists.txt
  29. +164
    -0
      kernel/Makefile.L3
  30. +11
    -0
      kernel/riscv64/KERNEL.RISCV64_ZVL128B
  31. +18
    -0
      kernel/riscv64/KERNEL.RISCV64_ZVL256B
  32. +969
    -0
      kernel/riscv64/shgemm_kernel_16x8_zvl256b.c
  33. +767
    -0
      kernel/riscv64/shgemm_kernel_8x8_zvl128b.c
  34. +37
    -0
      kernel/setparam-ref.c
  35. +1
    -0
      lapack/CMakeLists.txt
  36. +7
    -0
      openblas_config_template.h
  37. +29
    -0
      param.h

+ 3
- 0
CMakeLists.txt View File

@@ -152,6 +152,9 @@ endif ()
if (NOT DEFINED BUILD_BFLOAT16)
set (BUILD_BFLOAT16 false)
endif ()
if (NOT DEFINED BUILD_HFLOAT16)
set (BUILD_HFLOAT16 false)
endif ()
# set which float types we want to build for
if (NOT DEFINED BUILD_SINGLE AND NOT DEFINED BUILD_DOUBLE AND NOT DEFINED BUILD_COMPLEX AND NOT DEFINED BUILD_COMPLEX16)
# if none are defined, build for all


+ 2
- 2
Makefile.prebuild View File

@@ -64,11 +64,11 @@ TARGET_FLAGS = -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_ZVL256B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_ZVL128B)
TARGET_FLAGS = -march=rv64imafdcv -mabi=lp64d
TARGET_FLAGS = -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif

ifeq ($(TARGET), RISCV64_GENERIC)


+ 4
- 4
Makefile.riscv64 View File

@@ -7,12 +7,12 @@ CCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh_zvl512b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zba_zbb_zfh -mabi=lp64d -static
endif
ifeq ($(CORE), RISCV64_ZVL256B)
CCOMMON_OPT += -march=rv64imafdcv_zvl256b -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvl256b_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif
ifeq ($(CORE), RISCV64_ZVL128B)
CCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv -mabi=lp64d
CCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
FCOMMON_OPT += -march=rv64imafdcv_zvfh_zfh -mabi=lp64d
endif
ifeq ($(CORE), RISCV64_GENERIC)
CCOMMON_OPT += -march=rv64imafdc -mabi=lp64d


+ 2
- 0
Makefile.rule View File

@@ -308,6 +308,8 @@ COMMON_PROF = -pg
# If you want to enable the experimental BFLOAT16 support
# BUILD_BFLOAT16 = 1

# If you want to enable the experimental HFLOAT16 support
# BUILD_HFLOAT16 = 1

# Set the thread number threshold beyond which the job array for the threaded level3 BLAS
# will be allocated on the heap rather than the stack. (This array alone requires


+ 6
- 0
Makefile.system View File

@@ -1556,6 +1556,9 @@ endif
ifeq ($(BUILD_BFLOAT16), 1)
CCOMMON_OPT += -DBUILD_BFLOAT16
endif
ifeq ($(BUILD_HFLOAT16), 1)
CCOMMON_OPT += -DBUILD_HFLOAT16
endif
ifeq ($(BUILD_SINGLE), 1)
CCOMMON_OPT += -DBUILD_SINGLE=1
endif
@@ -1898,11 +1901,14 @@ export TARGET_CORE
export NO_AVX512
export NO_AVX2
export BUILD_BFLOAT16
export BUILD_HFLOAT16
export NO_LSX
export NO_LASX

export SBGEMM_UNROLL_M
export SBGEMM_UNROLL_N
export SHGEMM_UNROLL_M
export SHGEMM_UNROLL_N
export SGEMM_UNROLL_M
export SGEMM_UNROLL_N
export DGEMM_UNROLL_M


+ 5
- 2
Makefile.tail View File

@@ -1,4 +1,5 @@
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))

HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))

BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)

ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif

$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX

$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)


+ 21
- 5
benchmark/Makefile View File

@@ -56,9 +56,15 @@ GOTO_LAPACK_TARGETS=
endif

ifeq ($(BUILD_BFLOAT16),1)
GOTO_HALF_TARGETS=sbgemm.goto
GOTO_BFLOAT_TARGETS=sbgemm.goto
else
GOTO_HALF_TARGETS=
GOTO_BFLOAT_TARGETS=
endif

ifeq ($(BUILD_HFLOAT16),1)
GOTO_HFLOAT_TARGETS=shgemm.goto
else
GOTO_HFLOAT_TARGETS=
endif

ifeq ($(OSNAME), WINNT)
@@ -104,7 +110,7 @@ goto :: slinpack.goto dlinpack.goto clinpack.goto zlinpack.goto \
spotrf.goto dpotrf.goto cpotrf.goto zpotrf.goto \
ssymm.goto dsymm.goto csymm.goto zsymm.goto \
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_HALF_TARGETS)
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)

acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -278,7 +284,7 @@ goto :: sgemm.goto dgemm.goto cgemm.goto zgemm.goto \
smin.goto dmin.goto \
saxpby.goto daxpby.goto caxpby.goto zaxpby.goto \
somatcopy.goto domatcopy.goto comatcopy.goto zomatcopy.goto \
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_HALF_TARGETS)
snrm2.goto dnrm2.goto scnrm2.goto dznrm2.goto $(GOTO_LAPACK_TARGETS) $(GOTO_BFLOAT_TARGETS) $(GOTO_HFLOAT_TARGETS)

acml :: slinpack.acml dlinpack.acml clinpack.acml zlinpack.acml \
scholesky.acml dcholesky.acml ccholesky.acml zcholesky.acml \
@@ -633,6 +639,11 @@ sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif

ifeq ($(BUILD_HFLOAT16),1)
shgemm.goto : shgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif

sgemm.goto : sgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm

@@ -2960,7 +2971,12 @@ zcholesky.$(SUFFIX) : cholesky.c

ifeq ($(BUILD_BFLOAT16),1)
sbgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DHALF -UCOMPLEX -UDOUBLE -o $(@F) $^
$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif

ifeq ($(BUILD_HFLOAT16),1)
shgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DHFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif

sgemm.$(SUFFIX) : gemm.c


+ 8
- 1
benchmark/gemm.c View File

@@ -33,10 +33,17 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifdef DOUBLE
#define GEMM BLASFUNC(dgemm)
#elif defined(HALF)
#elif defined(BFLOAT16)
#define GEMM BLASFUNC(sbgemm)
#undef IFLOAT
#define IFLOAT bfloat16
#elif defined(HFLOAT16)
#define GEMM BLASFUNC(shgemm)
#undef IFLOAT
#define IFLOAT hfloat16
#else
#define GEMM BLASFUNC(sgemm)
#define IFLOAT float
#endif

#else


+ 4
- 0
cblas.h View File

@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

/*** FLOAT16 extensions ***/
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);

#ifdef __cplusplus
}
#endif /* __cplusplus */


+ 11
- 8
cmake/system.cmake View File

@@ -640,6 +640,9 @@ endif()
if (BUILD_BFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
endif()
if (BUILD_HFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
endif()
if(NOT MSVC)
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
endif()
@@ -647,14 +650,14 @@ endif()
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")

if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
endif ()




+ 15
- 0
common.h View File

@@ -266,6 +266,14 @@ typedef uint16_t bfloat16;
#define BFLOAT16CONVERSION 1
#endif

#ifdef BUILD_HFLOAT16
#ifndef hfloat16
typedef _Float16 hfloat16;
#endif
#else
typedef uint16_t hfloat16;
#endif

#ifdef USE64BITINT
typedef BLASLONG blasint;
#if defined(OS_WINDOWS) && defined(__64BIT__)
@@ -313,6 +321,13 @@ typedef int blasint;
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#elif defined(HFLOAT16)
#define IFLOAT hfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#else
#define FLOAT float
#define SIZE 4


+ 2
- 0
common_interface.h View File

@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint

/* Level 3 routines */

void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,


+ 18
- 1
common_level3.h View File

@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);


int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif

int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);

int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);

int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif

int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);

#ifdef __CUDACC__
}


+ 45
- 0
common_macro.h View File

@@ -39,6 +39,7 @@
#ifndef COMMON_MACRO
#define COMMON_MACRO

#include "common_sh.h"
#include "common_sb.h"
#include "common_s.h"
#include "common_d.h"
@@ -656,6 +657,50 @@
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
#elif defined(HFLOAT16)
#define GEMM_BETA SHGEMM_BETA
#define GEMM_KERNEL_N SHGEMM_KERNEL
#define GEMM_KERNEL_L SHGEMM_KERNEL
#define GEMM_KERNEL_R SHGEMM_KERNEL
#define GEMM_KERNEL_B SHGEMM_KERNEL
#define GEMM_NN SHGEMM_NN
#define GEMM_CN SHGEMM_TN
#define GEMM_TN SHGEMM_TN
#define GEMM_NC SHGEMM_NT
#define GEMM_NT SHGEMM_NT
#define GEMM_CC SHGEMM_TT
#define GEMM_CT SHGEMM_TT
#define GEMM_TC SHGEMM_TT
#define GEMM_TT SHGEMM_TT
#define GEMM_NR SHGEMM_NN
#define GEMM_TR SHGEMM_TN
#define GEMM_CR SHGEMM_TN
#define GEMM_RN SHGEMM_NN
#define GEMM_RT SHGEMM_NT
#define GEMM_RC SHGEMM_NT
#define GEMM_RR SHGEMM_NN
#define GEMM_ONCOPY SHGEMM_ONCOPY
#define GEMM_OTCOPY SHGEMM_OTCOPY
#define GEMM_INCOPY SHGEMM_INCOPY
#define GEMM_ITCOPY SHGEMM_ITCOPY

#define GEMM_THREAD_NN SHGEMM_THREAD_NN
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN


#elif defined(BFLOAT16)



+ 72
- 22
common_param.h View File

@@ -48,6 +48,21 @@ typedef struct {
int dtb_entries;
int switch_ratio;
int offsetA, offsetB, align;
#if BUILD_HFLOAT16 == 1
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;

int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);

int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);


#endif


#if BUILD_BFLOAT16 == 1
int sbgemm_p, sbgemm_q, sbgemm_r;
@@ -64,10 +79,10 @@ typedef struct {
float (*sbamin_k) (BLASLONG, float *, BLASLONG);
float (*sbmax_k) (BLASLONG, float *, BLASLONG);
float (*sbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);

float (*sbnrm2_k) (BLASLONG, float *, BLASLONG);
float (*sbasum_k) (BLASLONG, float *, BLASLONG);
@@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
#endif

#if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
#endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
float (*snrm2_k) (BLASLONG, float *, BLASLONG);
float (*sasum_k) (BLASLONG, float *, BLASLONG);
#endif
@@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
double (*damin_k) (BLASLONG, double *, BLASLONG);
double (*dmax_k) (BLASLONG, double *, BLASLONG);
double (*dmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);

double (*dnrm2_k) (BLASLONG, double *, BLASLONG);
double (*dasum_k) (BLASLONG, double *, BLASLONG);
@@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);

xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
float (*camax_k) (BLASLONG, float *, BLASLONG);
float (*camin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);

float (*cnrm2_k) (BLASLONG, float *, BLASLONG);
float (*casum_k) (BLASLONG, float *, BLASLONG);
@@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);

double (*zamax_k) (BLASLONG, double *, BLASLONG);
double (*zamin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);

double (*znrm2_k) (BLASLONG, double *, BLASLONG);
double (*zasum_k) (BLASLONG, double *, BLASLONG);
@@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);

xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);

xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas;

#define HAVE_EX_L2 gotoblas -> exclusive_cache

#if (BUILD_HFLOAT16==1)
#define SHGEMM_P gotoblas -> shgemm_p
#define SHGEMM_Q gotoblas -> shgemm_q
#define SHGEMM_R gotoblas -> shgemm_r
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
#endif

#if (BUILD_BFLOAT16==1)
#define SBGEMM_P gotoblas -> sbgemm_p
#define SBGEMM_Q gotoblas -> sbgemm_q
@@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 0
#endif

#if (BUILD_HFLOAT16 == 1)
#define SHGEMM_P SHGEMM_DEFAULT_P
#define SHGEMM_Q SHGEMM_DEFAULT_Q
#define SHGEMM_R SHGEMM_DEFAULT_R
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#ifdef SHGEMM_DEFAULT_UNROLL_MN
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
#else
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
#endif
#endif

#if (BUILD_BFLOAT16 == 1)
#define SBGEMM_P SBGEMM_DEFAULT_P
#define SBGEMM_Q SBGEMM_DEFAULT_Q
@@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas;


#endif

#endif

#ifndef COMPLEX
@@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas;
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
#elif defined(HFLOAT16)
#define GEMM_P SHGEMM_P
#define GEMM_Q SHGEMM_Q
#define GEMM_R SHGEMM_R
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#elif defined(BFLOAT16)
#define GEMM_P SBGEMM_P
#define GEMM_Q SBGEMM_Q


+ 72
- 0
common_sh.h View File

@@ -0,0 +1,72 @@
#ifndef COMMON_SH_H
#define COMMON_SH_H

#ifndef DYNAMIC_ARCH

#define SHGEMM_ONCOPY shgemm_oncopy
#define SHGEMM_OTCOPY shgemm_otcopy

#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY shgemm_oncopy
#define SHGEMM_ITCOPY shgemm_otcopy
#else
#define SHGEMM_INCOPY shgemm_incopy
#define SHGEMM_ITCOPY shgemm_itcopy
#endif

#define SHGEMM_BETA shgemm_beta
#define SHGEMM_KERNEL shgemm_kernel


#else // #DYNAMIC_ARCH

#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy
#else
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
#endif

#define SHGEMM_BETA gotoblas -> shgemm_beta
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
#endif // #DYNAMIC_ARCH

#define SHGEMM_NN shgemm_nn
#define SHGEMM_CN shgemm_tn
#define SHGEMM_TN shgemm_tn
#define SHGEMM_NC shgemm_nt
#define SHGEMM_NT shgemm_nt
#define SHGEMM_CC shgemm_tt
#define SHGEMM_CT shgemm_tt
#define SHGEMM_TC shgemm_tt
#define SHGEMM_TT shgemm_tt
#define SHGEMM_NR shgemm_nn
#define SHGEMM_TR shgemm_tn
#define SHGEMM_CR shgemm_tn
#define SHGEMM_RN shgemm_nn
#define SHGEMM_RT shgemm_nt
#define SHGEMM_RC shgemm_nt
#define SHGEMM_RR shgemm_nn

#define SHGEMM_THREAD_NN shgemm_thread_nn
#define SHGEMM_THREAD_CN shgemm_thread_tn
#define SHGEMM_THREAD_TN shgemm_thread_tn
#define SHGEMM_THREAD_NC shgemm_thread_nt
#define SHGEMM_THREAD_NT shgemm_thread_nt
#define SHGEMM_THREAD_CC shgemm_thread_tt
#define SHGEMM_THREAD_CT shgemm_thread_tt
#define SHGEMM_THREAD_TC shgemm_thread_tt
#define SHGEMM_THREAD_TT shgemm_thread_tt
#define SHGEMM_THREAD_NR shgemm_thread_nn
#define SHGEMM_THREAD_TR shgemm_thread_tn
#define SHGEMM_THREAD_CR shgemm_thread_tn
#define SHGEMM_THREAD_RN shgemm_thread_nn
#define SHGEMM_THREAD_RT shgemm_thread_nt
#define SHGEMM_THREAD_RC shgemm_thread_nt
#define SHGEMM_THREAD_RR shgemm_thread_nn


#endif // #COMMON_SH_H

+ 6
- 0
driver/level3/CMakeLists.txt View File

@@ -18,6 +18,12 @@ foreach (GEMM_DEFINE ${GEMM_DEFINES})
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16")
endif ()
endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE}" "gemm_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3)
GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "HFLOAT16")
endif ()
endif ()
endforeach ()

if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)


+ 71
- 16
driver/level3/Makefile View File

@@ -23,6 +23,10 @@ ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX)
endif

ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
endif

SBLASOBJS += \
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@@ -210,6 +214,9 @@ ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1)
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX)
endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
endif
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
@@ -344,16 +351,28 @@ endif
all ::

sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

sbgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

sbgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

sbgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -551,16 +570,28 @@ beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)

sbgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

sbgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

sbgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

sbgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -2736,16 +2767,28 @@ xtrsm_RCLN.$(SUFFIX) : trsm_R.c
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)

sbgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

sbgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

sbgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

sbgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -2959,16 +3002,28 @@ zgemm_batch_thread.$(SUFFIX) : gemm_batch_thread.c ../../common.h


sbgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

sbgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

sbgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

sbgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DBFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)

shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)

shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)

shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHFLOAT16 -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)

sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)


+ 1
- 1
driver/others/Makefile View File

@@ -218,7 +218,7 @@ mulx.$(SUFFIX) : $(ARCH)/mulx.c
$(CC) $(CFLAGS) -c -DXDOUBLE -UCOMPLEX $< -o $(@F)

detect_riscv64.$(SUFFIX): detect_riscv64.c
$(CC) $(CFLAGS) -c -march=rv64imafdcv $< -o $(@F)
$(CC) $(CFLAGS) -c -march=rv64imafdcv_zvfh_zfh $< -o $(@F)

xerbla.$(PSUFFIX) : xerbla.c
$(CC) $(PFLAGS) -c $< -o $(@F)


+ 22
- 0
driver/others/parameter.c View File

@@ -67,6 +67,11 @@ BLASLONG sbgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG sbgemm_p = SBGEMM_P;
#endif
#if SHGEMM_P == shgemm_p
BLASLONG shgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG shgemm_p = SHGEMM_P;
#endif
#if SGEMM_P == sgemm_p
BLASLONG sgemm_p = DEFAULT_GEMM_P;
#else
@@ -93,6 +98,11 @@ BLASLONG sbgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG sbgemm_q = SBGEMM_Q;
#endif
#if SHGEMM_Q == shgemm_q
BLASLONG shgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG shgemm_q = SHGEMM_Q;
#endif
#if SGEMM_Q == sgemm_q
BLASLONG sgemm_q = DEFAULT_GEMM_Q;
#else
@@ -119,6 +129,11 @@ BLASLONG sbgemm_r = DEFAULT_GEMM_R;
#else
BLASLONG sbgemm_r = SBGEMM_R;
#endif
#if SHGEMM_R == shgemm_r
BLASLONG shgemm_r = DEFAULT_GEMM_R;
#else
BLASLONG shgemm_r = SHGEMM_R;
#endif
#if SGEMM_R == sgemm_r
BLASLONG sgemm_r = DEFAULT_GEMM_R;
#else
@@ -526,6 +541,9 @@ void blas_set_parameter(void){

#ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
#endif
#ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
#endif
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;
@@ -619,6 +637,7 @@ void blas_set_parameter(void){
size = BITMASK(cpuid3, 16, 0xff);

sbgemm_p = 192 * (size + 1);
shgemm_p = 192 * (size + 1);
sgemm_p = 192 * (size + 1);
dgemm_p = 96 * (size + 1);
cgemm_p = 96 * (size + 1);
@@ -634,6 +653,9 @@ void blas_set_parameter(void){

#ifdef BUILD_BFLOAT16
sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15;
#endif
#ifdef BUILD_HFLOAT16
shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
#endif
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;


+ 10
- 7
exports/Makefile View File

@@ -39,6 +39,9 @@ endif
ifndef BUILD_BFLOAT16
BUILD_BFLOAT16 = 0
endif
ifndef BUILD_HFLOAT16
BUILD_HFLOAT16 = 0
endif
ifndef BUILD_SINGLE
BUILD_SINGLE = 0
endif
@@ -130,10 +133,10 @@ dll : ../$(LIBDLLNAME)
-Wl,--whole-archive ../$(LIBNAME) -Wl,--no-whole-archive $(FEXTRALIB) $(EXTRALIB)

$(LIBPREFIX).def : $(GENSYM)
./$(GENSYM) win2k $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) win2k $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

libgoto_hpl.def : $(GENSYM)
./$(GENSYM) win2khpl $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) win2khpl $(ARCH) dummy $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

ifeq ($(OSNAME), Darwin)
ifeq ($(FIXED_LIBNAME),1)
@@ -298,23 +301,23 @@ static : ../$(LIBNAME)
rm -f goto.$(SUFFIX)

osx.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) osx $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) osx $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

aix.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) aix $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) aix $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

objcopy.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) objcopy $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) objcopy $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

objconv.def : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) objconv $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)
./$(GENSYM) objconv $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > $(@F)

test : linktest.c
$(CC) $(CFLAGS) $(LDFLAGS) -w -o linktest linktest.c ../$(LIBSONAME) -lm && echo OK.
rm -f linktest

linktest.c : $(GENSYM) ../Makefile.system ../getarch.c
./$(GENSYM) linktest $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > linktest.c
./$(GENSYM) linktest $(ARCH) "$(BU)" $(EXPRECISION) $(NO_CBLAS) $(NO_LAPACK) $(NO_LAPACKE) $(NEED2UNDERSCORES) $(ONLY_CBLAS) "$(SYMBOLPREFIX)" "$(SYMBOLSUFFIX)" $(BUILD_LAPACK_DEPRECATED) $(BUILD_BFLOAT16) $(BUILD_HFLOAT16) $(BUILD_SINGLE) $(BUILD_DOUBLE) $(BUILD_COMPLEX) $(BUILD_COMPLEX16) > linktest.c

clean ::
@rm -f *.def *.dylib __.SYMDEF* *.renamed


+ 16
- 7
exports/gensymbol View File

@@ -52,6 +52,7 @@ blasobjsz="

blasobjs="lsame xerbla"
bfblasobjs="sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod"
hfblasobjs="shgemm"
cblasobjsc="
cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv
cblas_cgerc cblas_cgeru cblas_chbmv cblas_chemm cblas_chemv cblas_cher2 cblas_cher2k
@@ -100,6 +101,7 @@ cblasobjsz="
cblasobjs="cblas_xerbla"

bfcblasobjs="cblas_sbgemm cblas_sbgemv cblas_sbdot cblas_sbstobf16 cblas_sbdtobf16 cblas_sbf16tos cblas_dbf16tod cblas_sbgemm_batch"
hfcblasobjs="cblas_shgemm"

exblasobjs="
qamax qamin qasum qaxpy qcabs1 qcopy qdot qgbmv qgemm
@@ -3814,6 +3816,8 @@ shift
p16=$9
shift
p17=$9
shift
p18=$9

if [ $p13 -eq 1 ]; then
blasobjs="$blasobjs $bfblasobjs"
@@ -3821,6 +3825,11 @@ if [ $p13 -eq 1 ]; then
fi

if [ $p14 -eq 1 ]; then
blasobjs="$blasobjs $hfblasobjs"
cblasobjs="$cblasobjs $hfcblasobjs"
fi

if [ $p15 -eq 1 ]; then
blasobjs="$blasobjs $blasobjss"
cblasobjs="$cblasobjs $cblasobjss"
lapackobjs="$lapackobjs $lapackobjss"
@@ -3833,11 +3842,11 @@ if [ $p14 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjss"
fi

if [ $p15 -eq 1 ]; then
if [ $p16 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsd"
cblasobjs="$cblasobjs $cblasobjsd"
lapackobjs="$lapackobjs $lapackobjsd"
if [ $p14 -eq 0 ]; then
if [ $p15 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2ds"
fi
lapackobjs2="$lapackobjs2 $lapackobjs2d $lapackobjs2dz"
@@ -3847,14 +3856,14 @@ if [ $p15 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjsd"
fi

if [ $p16 -eq 1 ]; then
if [ $p17 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsc"
cblasobjs="$cblasobjs $cblasobjsc"
gemm3mobjs="$gemm3mobjs $gemm3mobjsc"
cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsc"
lapackobjs="$lapackobjs $lapackobjsc"
lapackobjs2="$lapackobjs2 $lapackobjs2c $lapackobjs2zc"
if [ $p14 -eq 0 ]; then
if [ $p15 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2sc"
fi
lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsc"
@@ -3863,17 +3872,17 @@ if [ $p16 -eq 1 ]; then
lapackeobjs="$lapackeobjs $lapackeobjsc"
fi

if [ $p17 -eq 1 ]; then
if [ $p18 -eq 1 ]; then
blasobjs="$blasobjs $blasobjsz"
cblasobjs="$cblasobjs $cblasobjsz"
gemm3mobjs="$gemm3mobjs $gemm3mobjsz"
cblasgemm3mobjs="$cblasgemm3mobjs $cblasgemm3mobjsz"
lapackobjs="$lapackobjs $lapackobjsz"
lapackobjs2="$lapackobjs2 $lapackobjs2z"
if [ $p16 -eq 0 ]; then
if [ $p17 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2zc"
fi
if [ $p15 -eq 0 ]; then
if [ $p16 -eq 0 ]; then
lapackobjs2="$lapackobjs2 $lapackobjs2dz"
fi
lapack_deprecated_objs="$lapack_deprecated_objs $lapack_deprecated_objsz"


+ 13
- 8
exports/gensymbol.pl View File

@@ -52,6 +52,7 @@

@blasobjs = (lsame, xerbla);
@bfblasobjs = (sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@hfblasobjs = (shgemm);
@cblasobjsc = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
@@ -97,7 +98,7 @@
@cblasobjs = ( cblas_xerbla );

@bfcblasobjs = (cblas_sbgemm, cblas_sbgemmt, cblas_sbgemmtr, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod, cblas_sbgemm_batch);
@hfcblasobjs = (cblas_shgemm);
@exblasobjs = (
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
qgemv,qger,qmax,qmin,
@@ -3777,6 +3778,10 @@ if ($ARGV[12] == 1) {
@cblasobjs = (@cblasobjs, @bfcblasobjs);
}
if ($ARGV[13] == 1) {
@blasobjs = (@blasobjs, @hfblasobjs);
@cblasobjs = (@cblasobjs, @hfcblasobjs);
}
if ($ARGV[14] == 1) {
@blasobjs = (@blasobjs, @blasobjss);
@cblasobjs = (@cblasobjs, @cblasobjss);
@lapackobjs = (@lapackobjs, @lapackobjss);
@@ -3788,11 +3793,11 @@ if ($ARGV[13] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_s);
@lapackeobjs = (@lapackeobjs, @lapackeobjss);
}
if ($ARGV[14] == 1) {
if ($ARGV[15] == 1) {
@blasobjs = (@blasobjs, @blasobjsd);
@cblasobjs = (@cblasobjs, @cblasobjsd);
@lapackobjs = (@lapackobjs, @lapackobjsd);
if ($ARGV[13] == 0) {
if ($ARGV[14] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2ds);
}
@lapackobjs2 = (@lapackobjs2, @lapackobjs2d, @lapackobjs2dz);
@@ -3801,14 +3806,14 @@ if ($ARGV[14] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_d);
@lapackeobjs = (@lapackeobjs, @lapackeobjsd);
}
if ($ARGV[15] == 1) {
if ($ARGV[16] == 1) {
@blasobjs = (@blasobjs, @blasobjsc);
@cblasobjs = (@cblasobjs, @cblasobjsc);
@gemm3mobjs = (@gemm3mobjs, @gemm3mobjsc);
@cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsc);
@lapackobjs = (@lapackobjs, @lapackobjsc);
@lapackobjs2 = (@lapackobjs2, @lapackobjs2c, @lapackobjs2zc);
if ($ARGV[13] == 0) {
if ($ARGV[14] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2sc);
}
@lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsc);
@@ -3816,17 +3821,17 @@ if ($ARGV[15] == 1) {
@lapack_embeded_underscore_objs = (@lapack_embeded_underscore_objs, @lapack_embeded_underscore_objs_c);
@lapackeobjs = (@lapackeobjs, @lapackeobjsc);
}
if ($ARGV[16] == 1) {
if ($ARGV[17] == 1) {
@blasobjs = (@blasobjs, @blasobjsz);
@cblasobjs = (@cblasobjs, @cblasobjsz);
@gemm3mobjs = (@gemm3mobjs, @gemm3mobjsz);
@cblasgemm3mobjs = (@cblasgemm3mobjs, @cblasgemm3mobjsz);
@lapackobjs = (@lapackobjs, @lapackobjsz);
@lapackobjs2 = (@lapackobjs2, @lapackobjs2z);
if ($ARGV[15] == 0) {
if ($ARGV[16] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2zc);
}
if ($ARGV[14] == 0) {
if ($ARGV[15] == 0) {
@lapackobjs2 = (@lapackobjs2, @lapackobjs2dz);
}
@lapack_deprecated_objs = (@lapack_deprecated_objs, @lapack_deprecated_objsz);


+ 2
- 0
getarch_2nd.c View File

@@ -19,6 +19,8 @@ int main(int argc, char **argv) {
if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) {
printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M);
printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N);
printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M);
printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N);
printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M);
printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N);
printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M);


+ 3
- 0
interface/CMakeLists.txt View File

@@ -136,6 +136,9 @@ if (BUILD_BFLOAT16)
GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16")
endif ()
endif ()
if (BUILD_HFLOAT16)
GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16")
endif ()

# complex-specific sources
foreach (float_type ${FLOAT_TYPES})


+ 22
- 2
interface/Makefile View File

@@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX)
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
endif

ifeq ($(BUILD_HFLOAT16),1)
SHBLAS3OBJS = shgemm.$(SUFFIX)
endif

DBLAS1OBJS = \
daxpy.$(SUFFIX) dswap.$(SUFFIX) \
dcopy.$(SUFFIX) dscal.$(SUFFIX) \
@@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif

ifeq ($(BUILD_HFLOAT16),1)
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
endif

CDBLAS1OBJS = \
cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \
cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \
@@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS)
SBBLAS1OBJS += $(CSBBLAS1OBJS)
SBBLAS2OBJS += $(CSBBLAS2OBJS)
SBBLAS3OBJS += $(CSBBLAS3OBJS)
SHBLAS3OBJS += $(CSHBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
DBLAS3OBJS += $(CDBLAS3OBJS)
@@ -405,6 +414,7 @@ endif

SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
SHBLASOBJS = $(SHBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -512,7 +522,7 @@ ifneq ($(BUILD_COMPLEX16),1)
ZBLASOBJS=
endif

FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS)

ifeq ($(EXPRECISION), 1)
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $
level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

aux : $(CBAUXOBJS)
@@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -c $(CFLAGS) -DRNAME $< -o $(@F)
endif

ifeq ($(BUILD_HFLOAT16),1)
shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
endif

sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)

@@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif

ifeq ($(BUILD_HFLOAT16),1)
cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif

cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)



+ 8
- 6
interface/gemm.c View File

@@ -56,6 +56,8 @@
#elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMM "
#define GEMV BLASFUNC(sbgemv)
#elif defined(HFLOAT16)
#define ERROR_NAME "SHGEMM "
#else
#define ERROR_NAME "SGEMM "
#define GEMV BLASFUNC(sgemv)
@@ -111,7 +113,7 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B
#endif
};

#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) &&!defined(HFLOAT16)
#define USE_SMALL_MATRIX_OPT 1
#else
#define USE_SMALL_MATRIX_OPT 0
@@ -219,11 +221,11 @@ static inline int get_gemm_optimal_nthreads_neoversev2(double MNK, int ncpu) {

static inline int get_gemm_optimal_nthreads(double MNK) {
int ncpu = num_cpu_avail(3);
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
return get_gemm_optimal_nthreads_neoversev2(MNK, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
}
@@ -417,7 +419,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS

PRINT_DEBUG_CNAME;

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_x86) && (defined(USE_SGEMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_avx512() )
@@ -577,7 +579,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif

#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(HFLOAT16) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(ARCH_ARM64)
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
// perform poorly in certain circumstances. We use the following boolean


+ 55
- 0
kernel/CMakeLists.txt View File

@@ -351,6 +351,22 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${SBGEMMKERNEL}" "" "gemm_kernel" false "" "" false "BFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_BETA}" "" "gemm_beta" false "" "" false "BFLOAT16")
endif ()
if (BUILD_HFLOAT16)
if (SHGEMMINCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMINCOPY}" "" "${SHGEMMINCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMITCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMITCOPY}" "" "${SHGEMMITCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMONCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMONCOPY}" "" "${SHGEMMONCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
if (SHGEMMOTCOPY)
GenerateNamedObjects("${KERNELDIR}/${SHGEMMOTCOPY}" "" "${SHGEMMOTCOPYOBJ}" false "" "" true "HFLOAT16")
endif ()
GenerateNamedObjects("${KERNELDIR}/${SHGEMMKERNEL}" "" "gemm_kernel" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_BETA}" "" "gemm_beta" false "" "" false "HFLOAT16")
endif ()
foreach (float_type ${FLOAT_TYPES})
string(SUBSTRING ${float_type} 0 1 float_char)
if (${float_char}GEMMINCOPY)
@@ -769,6 +785,45 @@ endif ()
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "BFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "BFLOAT16")
endif ()

if (BUILD_HFLOAT16)
if (NOT DEFINED SHGEMM_SMALL_M_PERMIT)
set(SHGEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_NN)
set(SHGEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_NT)
set(SHGEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_TN)
set(SHGEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_TT)
set(SHGEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_NN)
set(SHGEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_NT)
set(SHGEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_TN)
set(SHGEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c)
endif ()
if (NOT DEFINED SHGEMM_SMALL_K_B0_TT)
set(SHGEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c)
endif ()
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_TT}" "" "gemm_small_kernel_tt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "HFLOAT16")
GenerateNamedObjects("${KERNELDIR}/${SHGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "HFLOAT16")
endif ()
endif ()

if (NOT DEFINED ${float_char}OMATCOPY_CN)


+ 164
- 0
kernel/Makefile.L3 View File

@@ -129,6 +129,26 @@ SBKERNELOBJS += \
$(SBGEMMONCOPYOBJ) $(SBGEMMOTCOPYOBJ)
endif

ifeq ($(BUILD_HFLOAT16), 1)
ifndef SHGEMMKERNEL
SHGEMM_BETA = ../generic/gemm_beta.c
SHGEMMKERNEL = ../generic/gemmkernel_2x2.c
SHGEMMONCOPY = ../generic/gemm_ncopy_2.c
SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
SHGEMMINCOPY = ../generic/gemm_ncopy_2.c
SHGEMMITCOPY = ../generic/gemm_tcopy_2.c
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
endif

SHKERNELOBJS += \
shgemm_kernel$(TSUFFIX).$(SUFFIX) \
$(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
$(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
endif

ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" ""
SKERNELOBJS += \
sgemm_kernel$(TSUFFIX).$(SUFFIX) \
@@ -192,6 +212,9 @@ XKERNELOBJS += \
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += $(SBKERNELOBJS)
endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += $(SHKERNELOBJS)
endif
SBLASOBJS += $(SKERNELOBJS)
DBLASOBJS += $(DKERNELOBJS)
QBLASOBJS += $(QKERNELOBJS)
@@ -202,6 +225,9 @@ XBLASOBJS += $(XKERNELOBJS)
ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX)
endif
ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
endif

ifneq "$(or $(BUILD_SINGLE),$(BUILD_DOUBLE),$(BUILD_COMPLEX))" ""
SBLASOBJS += \
@@ -493,6 +519,15 @@ SBBLASOBJS += \
sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
endif

ifeq ($(BUILD_HFLOAT16),1)
SHBLASOBJS += \
shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
endif

SBLASOBJS += \
sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
@@ -599,6 +634,13 @@ SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SBGEMMOTCOPYOBJ_P = $(SBGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
endif

ifeq ($(BUILD_HFLOAT16), 1)
SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
endif

SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
@@ -629,6 +671,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA)
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

ifeq ($(BUILD_HFLOAT16),1)
$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@

@@ -671,6 +718,25 @@ $(KDIR)$(SBGEMMITCOPYOBJ) : $(KERNELDIR)/$(SBGEMMITCOPY)
endif
endif

ifeq ($(BUILD_HFLOAT16), 1)

$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))

$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#endif
endif

$(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@

@@ -853,6 +919,12 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMM
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

ifeq ($(BUILD_HFLOAT16), 1)

$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -S -DDOUBLE -UCOMPLEX $< -o - > dgemm_kernel$(TSUFFIX).s
@@ -2840,6 +2912,11 @@ $(KDIR)sbgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA)
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

ifeq ($(BUILD_HFLOAT16),1)
$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
$(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@

@@ -2873,6 +2950,23 @@ $(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY)
endif
endif

ifeq ($(BUILD_HFLOAT16), 1)
$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

#endif
endif

$(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@

@@ -2983,6 +3077,11 @@ $(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEM
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

ifeq ($(BUILD_HFLOAT16), 1)
$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
$(CC) $(PFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif

$(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@

@@ -4843,6 +4942,71 @@ $(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMA
$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
endif

ifeq ($(BUILD_HFLOAT16), 1)
ifndef SHGEMM_SMALL_M_PERMIT
SHGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
endif

ifndef SHGEMM_SMALL_K_NN
SHGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif

ifndef SHGEMM_SMALL_K_NT
SHGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif

ifndef SHGEMM_SMALL_K_TN
SHGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif

ifndef SHGEMM_SMALL_K_TT
SHGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif

$(KDIR)shgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_M_PERMIT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_NT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

$(KDIR)shgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_TT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

ifndef SHGEMM_SMALL_K_B0_NN
SHGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
endif

ifndef SHGEMM_SMALL_K_B0_NT
SHGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
endif

ifndef SHGEMM_SMALL_K_B0_TN
SHGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
endif

ifndef SHGEMM_SMALL_K_B0_TT
SHGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
endif

$(KDIR)shgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_NT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TN)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@

$(KDIR)shgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_SMALL_K_B0_TT)
$(CC) $(CFLAGS) -c -DHFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
endif

ifndef CGEMM_SMALL_M_PERMIT
CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c
endif


+ 11
- 0
kernel/riscv64/KERNEL.RISCV64_ZVL128B View File

@@ -245,3 +245,14 @@ endif
ifndef ZGEMM_BETA
ZGEMM_BETA = zgemm_beta_rvv.c
endif

ifeq ($(BUILD_BFLOAT16), 1)
SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl128b.c
SHGEMMONCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c
SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
ifndef SHGEMM_BETA
SHGEMM_BETA = gemm_beta_rvv.c
endif
endif

+ 18
- 0
kernel/riscv64/KERNEL.RISCV64_ZVL256B View File

@@ -209,5 +209,23 @@ COMATCOPY_CN = zomatcopy_cn_vector.c
DOMATCOPY_CN = omatcopy_cn_vector.c
SOMATCOPY_CN = omatcopy_cn_vector.c


ifeq ($(BUILD_BFLOAT16), 1)
SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_zvl256b.c
ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
SHGEMMINCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_M).c
SHGEMMITCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_M).c
SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
endif
SHGEMMONCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_N).c
SHGEMMOTCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_N).c
SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
ifndef SHGEMM_BETA
SHGEMM_BETA = gemm_beta_rvv.c
endif
endif

SAXPBYKERNEL = axpby_vector_v2.c
DAXPBYKERNEL = axpby_vector_v2.c

+ 969
- 0
kernel/riscv64/shgemm_kernel_16x8_zvl256b.c View File

@@ -0,0 +1,969 @@

#include "common.h"
#include <riscv_vector.h>
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
{
BLASLONG gvl = 0;
BLASLONG m_top = 0;
BLASLONG n_top = 0;

// -- MAIN PASS
for (BLASLONG j=0; j<N/8; j+=1) {
m_top = 0;
BLASLONG gvl = __riscv_vsetvl_e16m1(16);

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;

_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;
A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c7, gvl);
m_top += 16;
}



// -- tails for main pass
if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
vfloat32m1_t result4 = __riscv_vfwmul_vf_f32m1( A0, B4, gvl);
vfloat32m1_t result5 = __riscv_vfwmul_vf_f32m1( A0, B5, gvl);
vfloat32m1_t result6 = __riscv_vfwmul_vf_f32m1( A0, B6, gvl);
vfloat32m1_t result7 = __riscv_vfwmul_vf_f32m1( A0, B7, gvl);

for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;
result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m1(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m1(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m1(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m1(result7, B7, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c4 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c5 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c6 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m1_t c7 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m1(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m1(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m1(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m1(c7, alpha, result7, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c7, gvl);
m_top += 8;
}


if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
vfloat32m1_t result4 = __riscv_vfwmul_vf_f32m1( A0, B4, gvl);
vfloat32m1_t result5 = __riscv_vfwmul_vf_f32m1( A0, B5, gvl);
vfloat32m1_t result6 = __riscv_vfwmul_vf_f32m1( A0, B6, gvl);
vfloat32m1_t result7 = __riscv_vfwmul_vf_f32m1( A0, B7, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m1(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m1(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m1(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m1(result7, B7, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c4 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c5 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c6 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c7 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m1(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m1(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m1(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m1(c7, alpha, result7, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c7, gvl);
m_top += 4;
}

if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
float result8 = 0;
float result9 = 0;
float result10 = 0;
float result11 = 0;
float result12 = 0;
float result13 = 0;
float result14 = 0;
float result15 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;
for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
result8+=(float)(A[ai+0]*B[bi+4]);
result9+=(float)(A[ai+1]*B[bi+4]);
result10+=(float)(A[ai+0]*B[bi+5]);
result11+=(float)(A[ai+1]*B[bi+5]);
result12+=(float)(A[ai+0]*B[bi+6]);
result13+=(float)(A[ai+1]*B[bi+6]);
result14+=(float)(A[ai+0]*B[bi+7]);
result15+=(float)(A[ai+1]*B[bi+7]);
ai+=2;
bi+=8;
}

BLASLONG ci=n_top*ldc+m_top;

C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;
C[ci + 4 * ldc + 0] += alpha * result8;
C[ci + 4 * ldc + 1] += alpha * result9;
C[ci + 5 * ldc + 0] += alpha * result10;
C[ci + 5 * ldc + 1] += alpha * result11;
C[ci + 6 * ldc + 0] += alpha * result12;
C[ci + 6 * ldc + 1] += alpha * result13;
C[ci + 7 * ldc + 0] += alpha * result14;
C[ci + 7 * ldc + 1] += alpha * result15;

m_top+=2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
result4+=(float)(A[ai+0]*B[bi+4]);
result5+=(float)(A[ai+0]*B[bi+5]);
result6+=(float)(A[ai+0]*B[bi+6]);
result7+=(float)(A[ai+0]*B[bi+7]);
ai+=1;
bi+=8;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
C[ci + 4 * ldc + 0] += alpha * result4;
C[ci + 5 * ldc + 0] += alpha * result5;
C[ci + 6 * ldc + 0] += alpha * result6;
C[ci + 7 * ldc + 0] += alpha * result7;
m_top+=1;
}
n_top += 8;
}

if( N & 4 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c3, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
vfloat32m1_t result2 = __riscv_vfwmul_vf_f32m1( A0, B2, gvl);
vfloat32m1_t result3 = __riscv_vfwmul_vf_f32m1( A0, B3, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m1(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m1(result3, B3, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c2 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c3 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m1(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m1(c3, alpha, result3, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c3, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
ai+=2;
bi+=4;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
ai+=1;
bi+=4;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
m_top += 1;
}

n_top += 4;
}



// -- tails for N=2
if( N & 2 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m1( &C[ci], c1, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
vfloat32m1_t result1 = __riscv_vfwmul_vf_f32m1( A0, B1, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m1(result1, B1, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m1_t c1 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m1(c1, alpha, result1, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m1(&C[ci], c1, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
ai+=2;
bi+=2;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
ai+=1;
bi+=2;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
m_top += 1;
}

n_top += 2;
}



// -- tails for N=1
if( N & 1 ) {
gvl = __riscv_vsetvl_e16m1(16);
m_top = 0;

for (BLASLONG i=0; i<M/16; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);

for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 16;
result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}
BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci=n_top*ldc+m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl);
m_top += 16;
}

if( M & 8 ) {
gvl = __riscv_vsetvl_e16mf2(8);
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m1( &C[ci], c0, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16mf2(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16mf2_t A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

vfloat32m1_t result0 = __riscv_vfwmul_vf_f32m1( A0, B0, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16mf2( &A[ai+0*gvl], gvl );
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m1(result0, B0, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m1_t c0 = __riscv_vle32_v_f32m1(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m1(c0, alpha, result0, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m1(&C[ci], c0, gvl);
m_top += 4;
}


if( M & 2 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
ai+=2;
bi+=1;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
ai+=1;
bi+=1;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
m_top += 1;
}

n_top += 1;
}
return 0;
}

+ 767
- 0
kernel/riscv64/shgemm_kernel_8x8_zvl128b.c View File

@@ -0,0 +1,767 @@

#include "common.h"
#include <riscv_vector.h>

int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
{
BLASLONG gvl = 0;
BLASLONG m_top = 0;
BLASLONG n_top = 0;

// -- MAIN PASS
for (BLASLONG j=0; j<N/8; j+=1) {
m_top = 0;
BLASLONG gvl = __riscv_vsetvl_e16m1(8);

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c7, gvl); ci += ldc-gvl*0;
m_top += 8;
}

// -- tails for main pass --

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
_Float16 B4 = B[bi+4];
_Float16 B5 = B[bi+5];
_Float16 B6 = B[bi+6];
_Float16 B7 = B[bi+7];
bi += 8;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
B4 = B[bi+4];
B5 = B[bi+5];
B6 = B[bi+6];
B7 = B[bi+7];
bi += 8;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c4 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c5 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c6 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c7 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c3, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c4, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c5, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c6, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c7, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
float result8 = 0;
float result9 = 0;
float result10 = 0;
float result11 = 0;
float result12 = 0;
float result13 = 0;
float result14 = 0;
float result15 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
result8+=(float)(A[ai+0]*B[bi+4]);
result9+=(float)(A[ai+1]*B[bi+4]);
result10+=(float)(A[ai+0]*B[bi+5]);
result11+=(float)(A[ai+1]*B[bi+5]);
result12+=(float)(A[ai+0]*B[bi+6]);
result13+=(float)(A[ai+1]*B[bi+6]);
result14+=(float)(A[ai+0]*B[bi+7]);
result15+=(float)(A[ai+1]*B[bi+7]);
ai+=2;
bi+=8;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;
C[ci + 4 * ldc + 0] += alpha * result8;
C[ci + 4 * ldc + 1] += alpha * result9;
C[ci + 5 * ldc + 0] += alpha * result10;
C[ci + 5 * ldc + 1] += alpha * result11;
C[ci + 6 * ldc + 0] += alpha * result12;
C[ci + 6 * ldc + 1] += alpha * result13;
C[ci + 7 * ldc + 0] += alpha * result14;
C[ci + 7 * ldc + 1] += alpha * result15;

m_top+=2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
result4+=(float)(A[ai+0]*B[bi+4]);
result5+=(float)(A[ai+0]*B[bi+5]);
result6+=(float)(A[ai+0]*B[bi+6]);
result7+=(float)(A[ai+0]*B[bi+7]);
ai+=1;
bi+=8;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
C[ci + 4 * ldc + 0] += alpha * result4;
C[ci + 5 * ldc + 0] += alpha * result5;
C[ci + 6 * ldc + 0] += alpha * result6;
C[ci + 7 * ldc + 0] += alpha * result7;
m_top+=1;
}

n_top += 8;
}

// -- tails for N=4
if( N & 4 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}

BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c3, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
_Float16 B2 = B[bi+2];
_Float16 B3 = B[bi+3];
bi += 4;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
B2 = B[bi+2];
B3 = B[bi+3];
bi += 4;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c3, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
float result4 = 0;
float result5 = 0;
float result6 = 0;
float result7 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
result4+=(float)(A[ai+0]*B[bi+2]);
result5+=(float)(A[ai+1]*B[bi+2]);
result6+=(float)(A[ai+0]*B[bi+3]);
result7+=(float)(A[ai+1]*B[bi+3]);
ai+=2;
bi+=4;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;
C[ci + 2 * ldc + 0] += alpha * result4;
C[ci + 2 * ldc + 1] += alpha * result5;
C[ci + 3 * ldc + 0] += alpha * result6;
C[ci + 3 * ldc + 1] += alpha * result7;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
result2+=(float)(A[ai+0]*B[bi+2]);
result3+=(float)(A[ai+0]*B[bi+3]);
ai+=1;
bi+=4;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
C[ci + 2 * ldc + 0] += alpha * result2;
C[ci + 3 * ldc + 0] += alpha * result3;
m_top += 1;
}

n_top += 4;
}



// -- tails for N=2
if( N & 2 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
__riscv_vse32_v_f32m2( &C[ci], c1, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
_Float16 B1 = B[bi+1];
bi += 2;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
B1 = B[bi+1];
bi += 2;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
ci += ldc - gvl * 0;
vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
__riscv_vse32_v_f32m2(&C[ci], c1, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;
float result2 = 0;
float result3 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
result2+=(float)(A[ai+0]*B[bi+1]);
result3+=(float)(A[ai+1]*B[bi+1]);
ai+=2;
bi+=2;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;
C[ci + 1 * ldc + 0] += alpha * result2;
C[ci + 1 * ldc + 1] += alpha * result3;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
float result1 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+0]*B[bi+1]);
ai+=1;
bi+=2;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 1 * ldc + 0] += alpha * result1;
m_top += 1;
}

n_top += 2;
}



// -- tails for N=1
if( N & 1 ) {
gvl = __riscv_vsetvl_e16m1(8);
m_top = 0;

for (BLASLONG i=0; i<M/8; i+=1) {
BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
for(BLASLONG k=1; k<K; k++) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
ai += 8;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}


BLASLONG ci=n_top*ldc+m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci = n_top * ldc + m_top;

__riscv_vse32_v_f32m2( &C[ci], c0, gvl);
m_top += 8;
}

if( M & 4 ) {
gvl = __riscv_vsetvl_e16m1(4);

BLASLONG ai=m_top*K;
BLASLONG bi=n_top*K;
_Float16 B0 = B[bi+0];
bi += 1;

vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);

for(BLASLONG k=1; k < K; ++k) {
B0 = B[bi+0];
bi += 1;

A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
ai += 4;

result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
}

BLASLONG ci = n_top * ldc + m_top;

vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);

ci= n_top * ldc + m_top;

__riscv_vse32_v_f32m2(&C[ci], c0, gvl);
m_top += 4;
}


if( M & 2 ) {

BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

float result0 = 0;
float result1 = 0;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
result1+=(float)(A[ai+1]*B[bi+0]);
ai+=2;
bi+=1;
}
BLASLONG ci=n_top*ldc+m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
C[ci + 0 * ldc + 1] += alpha * result1;

m_top += 2;
}


if( M & 1 ) {
float result0 = 0;
BLASLONG ai = m_top * K;
BLASLONG bi = n_top * K;

for(BLASLONG k=0; k<K; k++) {
result0+=(float)(A[ai+0]*B[bi+0]);
ai+=1;
bi+=1;
}

BLASLONG ci = n_top * ldc + m_top;
C[ci + 0 * ldc + 0] += alpha * result0;
m_top += 1;
}

n_top += 1;
}

return 0;

}

+ 37
- 0
kernel/setparam-ref.c View File

@@ -125,6 +125,23 @@ gotoblas_t TABLE_NAME = {
#endif
#endif

#ifdef BUILD_HFLOAT16
0, 0, 0,
SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
#ifdef SHGEMM_DEFAULT_UNROLL_MN
SHGEMM_DEFAULT_UNROLL_MN,
#else
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
#endif
shgemm_kernelTS, shgemm_betaTS,
#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
shgemm_incopyTS, shgemm_itcopyTS,
#else
shgemm_oncopyTS, shgemm_otcopyTS,
#endif
shgemm_oncopyTS, shgemm_otcopyTS,
#endif

#if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1)
0, 0, 0,
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
@@ -1252,6 +1269,9 @@ static void init_parameter(void) {

#ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
#endif
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
@@ -1260,6 +1280,9 @@ static void init_parameter(void) {

#ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_r = SBGEMM_DEFAULT_R;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
#endif
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
@@ -1269,6 +1292,9 @@ static void init_parameter(void) {

#ifdef BUILD_BFLOAT16
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
#endif
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
@@ -1417,6 +1443,10 @@ static void init_parameter(void) {
TABLE_NAME.sbgemm_p = SBGEMM_DEFAULT_P;
TABLE_NAME.sbgemm_q = SBGEMM_DEFAULT_Q;
#endif
#ifdef BUILD_HFLOAT16
TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
#endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
#endif
@@ -2012,6 +2042,13 @@ static void init_parameter(void) {
) / (TABLE_NAME.sbgemm_q * 4) - 15) & ~15);
#endif

#if BUILD_HFLOAT16==1
TABLE_NAME.shgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA
+ TABLE_NAME.align) & ~TABLE_NAME.align)
) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15);
#endif

#if BUILD_SINGLE==1
TABLE_NAME.sgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA


+ 1
- 0
lapack/CMakeLists.txt View File

@@ -3,6 +3,7 @@ include_directories(${PROJECT_SOURCE_DIR})
include_directories(${PROJECT_BINARY_DIR})

list (REMOVE_ITEM FLOAT_TYPES "BFLOAT16")
list (REMOVE_ITEM FLOAT_TYPES "HFLOAT16")

set(LAPACK_SOURCES
potrf/potrf_U_single.c


+ 7
- 0
openblas_config_template.h View File

@@ -39,6 +39,13 @@ typedef unsigned long BLASULONG;
typedef uint16_t bfloat16;
#endif

#if defined(__GNUC__) && (__GNUC__ >= 12)
typedef _Float16 hfloat16;
#else
#include <stdint.h>
typedef uint16_t hfloat16;
#endif

#ifdef OPENBLAS_USE64BITINT
typedef BLASLONG blasint;
#else


+ 29
- 0
param.h View File

@@ -72,6 +72,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef PARAM_H
#define PARAM_H

#define SHGEMM_DEFAULT_UNROLL_N 8
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_MN 32
#define SHGEMM_DEFAULT_P 128
#define SHGEMM_DEFAULT_R 240
#define SHGEMM_DEFAULT_Q 12288

#define SBGEMM_DEFAULT_UNROLL_N 4
#define SBGEMM_DEFAULT_UNROLL_M 8
@@ -3138,10 +3144,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif

#ifdef RISCV64_ZVL128B

#define GEMM_DEFAULT_OFFSET_A 0
#define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL

#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 8
#define SGEMM_DEFAULT_UNROLL_N 8

@@ -3154,16 +3166,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 4
#define ZGEMM_DEFAULT_UNROLL_N 4

#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 128
#define CGEMM_DEFAULT_P 96
#define ZGEMM_DEFAULT_P 64

#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 240
#define SGEMM_DEFAULT_Q 240
#define DGEMM_DEFAULT_Q 120
#define CGEMM_DEFAULT_Q 120
#define ZGEMM_DEFAULT_Q 120

#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 12288
#define SGEMM_DEFAULT_R 12288
#define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 4096
@@ -3181,6 +3199,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN 0x03fffUL

#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 16
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 16
#define SGEMM_DEFAULT_UNROLL_N 8

@@ -3193,16 +3216,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 8
#define ZGEMM_DEFAULT_UNROLL_N 4

#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 64
#define CGEMM_DEFAULT_P 64
#define ZGEMM_DEFAULT_P 64

#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 128
#define SGEMM_DEFAULT_Q 128
#define DGEMM_DEFAULT_Q 128
#define CGEMM_DEFAULT_Q 128
#define ZGEMM_DEFAULT_Q 64

#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 16384
#define SGEMM_DEFAULT_R 16384
#define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 8192


Loading…
Cancel
Save