Browse Source

Add FP16 support for RISCV

pull/5290/head
Srangrang 4 months ago
parent
commit
0a967797a1
12 changed files with 270 additions and 35 deletions
  1. +1
    -0
      Makefile.system
  2. +5
    -2
      Makefile.tail
  3. +4
    -0
      cblas.h
  4. +11
    -8
      cmake/system.cmake
  5. +7
    -2
      common.h
  6. +2
    -0
      common_interface.h
  7. +18
    -1
      common_level3.h
  8. +45
    -0
      common_macro.h
  9. +72
    -22
      common_param.h
  10. +72
    -0
      common_sh.h
  11. +5
    -0
      openblas_config_template.h
  12. +28
    -0
      param.h

+ 1
- 0
Makefile.system View File

@@ -1889,6 +1889,7 @@ export TARGET_CORE
export NO_AVX512
export NO_AVX2
export BUILD_BFLOAT16
export BUILD_HFLOAT16
export NO_LSX
export NO_LASX



+ 5
- 2
Makefile.tail View File

@@ -1,4 +1,5 @@
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))

HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))

BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)

ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif

$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX

$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)


+ 4
- 0
cblas.h View File

@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array,
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);

/*** FLOAT16 extensions */
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);

#ifdef __cplusplus
}
#endif /* __cplusplus */


+ 11
- 8
cmake/system.cmake View File

@@ -640,6 +640,9 @@ endif()
if (BUILD_BFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
endif()
if (BUILD_HFLOAT16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16")
endif()
if(NOT MSVC)
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
endif()
@@ -647,14 +650,14 @@ endif()
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}")
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")

if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
if ("${F_COMPILER}" STREQUAL "FLANG")
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3)
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops")
endif ()
endif ()
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2")
endif ()
endif ()




+ 7
- 2
common.h View File

@@ -266,6 +266,11 @@ typedef uint16_t bfloat16;
#define BFLOAT16CONVERSION 1
#endif

#ifndef hfloat16
#include <stdint.h>
typedef uint16_t hfloat16;
#endif

#ifdef USE64BITINT
typedef BLASLONG blasint;
#if defined(OS_WINDOWS) && defined(__64BIT__)
@@ -313,8 +318,8 @@ typedef int blasint;
#define SIZE 2
#define BASE_SHIFT 1
#define ZBASE_SHIFT 2
#elif defined(FLOAT16)
#define IFLOAT float16
#elif defined(HFLOAT16)
#define IFLOAT hfloat16
#define XFLOAT IFLOAT
#define FLOAT float
#define SIZE 2


+ 2
- 0
common_interface.h View File

@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint

/* Level 3 routines */

void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,


+ 18
- 1
common_level3.h View File

@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);


int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif

int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b);
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);

int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);

int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif

int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG);

int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums);

#ifdef __CUDACC__
}


+ 45
- 0
common_macro.h View File

@@ -39,6 +39,7 @@
#ifndef COMMON_MACRO
#define COMMON_MACRO

#include "common_sh.h"
#include "common_sb.h"
#include "common_s.h"
#include "common_d.h"
@@ -656,6 +657,50 @@
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT
#elif defined(HFLOAT16)
#define GEMM_BETA SHGEMM_BETA
#define GEMM_KERNEL_N SHGEMM_KERNEL
#define GEMM_KERNEL_L SHGEMM_KERNEL
#define GEMM_KERNEL_R SHGEMM_KERNEL
#define GEMM_KERNEL_B SHGEMM_KERNEL
#define GEMM_NN SHGEMM_NN
#define GEMM_CN SHGEMM_TN
#define GEMM_TN SHGEMM_TN
#define GEMM_NC SHGEMM_NT
#define GEMM_NT SHGEMM_NT
#define GEMM_CC SHGEMM_TT
#define GEMM_CT SHGEMM_TT
#define GEMM_TC SHGEMM_TT
#define GEMM_TT SHGEMM_TT
#define GEMM_NR SHGEMM_NN
#define GEMM_TR SHGEMM_TN
#define GEMM_CR SHGEMM_TN
#define GEMM_RN SHGEMM_NN
#define GEMM_RT SHGEMM_NT
#define GEMM_RC SHGEMM_NT
#define GEMM_RR SHGEMM_NN
#define GEMM_ONCOPY SHGEMM_ONCOPY
#define GEMM_OTCOPY SHGEMM_OTCOPY
#define GEMM_INCOPY SHGEMM_INCOPY
#define GEMM_ITCOPY SHGEMM_ITCOPY

#define GEMM_THREAD_NN SHGEMM_THREAD_NN
#define GEMM_THREAD_CN SHGEMM_THREAD_TN
#define GEMM_THREAD_TN SHGEMM_THREAD_TN
#define GEMM_THREAD_NC SHGEMM_THREAD_NT
#define GEMM_THREAD_NT SHGEMM_THREAD_NT
#define GEMM_THREAD_CC SHGEMM_THREAD_TT
#define GEMM_THREAD_CT SHGEMM_THREAD_TT
#define GEMM_THREAD_TC SHGEMM_THREAD_TT
#define GEMM_THREAD_TT SHGEMM_THREAD_TT
#define GEMM_THREAD_NR SHGEMM_THREAD_NN
#define GEMM_THREAD_TR SHGEMM_THREAD_TN
#define GEMM_THREAD_CR SHGEMM_THREAD_TN
#define GEMM_THREAD_RN SHGEMM_THREAD_NN
#define GEMM_THREAD_RT SHGEMM_THREAD_NT
#define GEMM_THREAD_RC SHGEMM_THREAD_NT
#define GEMM_THREAD_RR SHGEMM_THREAD_NN


#elif defined(BFLOAT16)



+ 72
- 22
common_param.h View File

@@ -48,6 +48,21 @@ typedef struct {
int dtb_entries;
int switch_ratio;
int offsetA, offsetB, align;
#if BUILD_HFLOAT16 == 1
int shgemm_p, shgemm_q, shgemm_r;
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;

int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG);
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG);

int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);


#endif


#if BUILD_BFLOAT16 == 1
int sbgemm_p, sbgemm_q, sbgemm_r;
@@ -64,10 +79,10 @@ typedef struct {
float (*sbamin_k) (BLASLONG, float *, BLASLONG);
float (*sbmax_k) (BLASLONG, float *, BLASLONG);
float (*sbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);

float (*sbnrm2_k) (BLASLONG, float *, BLASLONG);
float (*sbasum_k) (BLASLONG, float *, BLASLONG);
@@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
#endif

#if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG);
#endif
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1)
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
float (*snrm2_k) (BLASLONG, float *, BLASLONG);
float (*sasum_k) (BLASLONG, float *, BLASLONG);
#endif
@@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
double (*damin_k) (BLASLONG, double *, BLASLONG);
double (*dmax_k) (BLASLONG, double *, BLASLONG);
double (*dmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);

double (*dnrm2_k) (BLASLONG, double *, BLASLONG);
double (*dasum_k) (BLASLONG, double *, BLASLONG);
@@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG);
xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);

xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG);
float (*camax_k) (BLASLONG, float *, BLASLONG);
float (*camin_k) (BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG);
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);

float (*cnrm2_k) (BLASLONG, float *, BLASLONG);
float (*casum_k) (BLASLONG, float *, BLASLONG);
@@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG);

double (*zamax_k) (BLASLONG, double *, BLASLONG);
double (*zamin_k) (BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG);
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);

double (*znrm2_k) (BLASLONG, double *, BLASLONG);
double (*zasum_k) (BLASLONG, double *, BLASLONG);
@@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG);

xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG);
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG);

xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG);
xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG);
@@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas;

#define HAVE_EX_L2 gotoblas -> exclusive_cache

#if (BUILD_HFLOAT16==1)
#define SHGEMM_P gotoblas -> shgemm_p
#define SHGEMM_Q gotoblas -> shgemm_q
#define SHGEMM_R gotoblas -> shgemm_r
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
#endif

#if (BUILD_BFLOAT16==1)
#define SBGEMM_P gotoblas -> sbgemm_p
#define SBGEMM_Q gotoblas -> sbgemm_q
@@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas;
#define HAVE_EX_L2 0
#endif

#if (BUILD_HFLOAT16 == 1)
#define SHGEMM_P SHGEMM_DEFAULT_P
#define SHGEMM_Q SHGEMM_DEFAULT_Q
#define SHGEMM_R SHGEMM_DEFAULT_R
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#ifdef SHGEMM_DEFAULT_UNROLL_MN
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
#else
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
#endif
#endif

#if (BUILD_BFLOAT16 == 1)
#define SBGEMM_P SBGEMM_DEFAULT_P
#define SBGEMM_Q SBGEMM_DEFAULT_Q
@@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas;


#endif

#endif

#ifndef COMPLEX
@@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas;
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
#elif defined(HFLOAT16)
#define GEMM_P SHGEMM_P
#define GEMM_Q SHGEMM_Q
#define GEMM_R SHGEMM_R
#define GEMM_UNROLL_M SHGEMM_UNROLL_M
#define GEMM_UNROLL_N SHGEMM_UNROLL_N
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#elif defined(BFLOAT16)
#define GEMM_P SBGEMM_P
#define GEMM_Q SBGEMM_Q


+ 72
- 0
common_sh.h View File

@@ -0,0 +1,72 @@
#ifndef COMMON_SH_H
#define COMMON_SH_H

#ifndef DYNAMIC_ARCH

#define SHGEMM_ONCOPY shgemm_oncopy
#define SHGEMM_OTCOPY shgemm_otcopy

#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY shgemm_oncopy
#define SHGEMM_ITCOPY shgemm_otcopy
#else
#define SHGEMM_INCOPY shgemm_incopy
#define SHGEMM_ITCOPY shgemm_itcopy
#endif

#define SHGEMM_BETA shgemm_beta
#define SHGEMM_KERNEL shgemm_kernel


#else // #DYNAMIC_ARCH

#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N
#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy
#else
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
#endif

#define SHGEMM_BETA gotoblas -> shgemm_beta
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
#endif // #DYNAMIC_ARCH

#define SHGEMM_NN shgemm_nn
#define SHGEMM_CN shgemm_tn
#define SHGEMM_TN shgemm_tn
#define SHGEMM_NC shgemm_nt
#define SHGEMM_NT shgemm_nt
#define SHGEMM_CC shgemm_tt
#define SHGEMM_CT shgemm_tt
#define SHGEMM_TC shgemm_tt
#define SHGEMM_TT shgemm_tt
#define SHGEMM_NR shgemm_nn
#define SHGEMM_TR shgemm_tn
#define SHGEMM_CR shgemm_tn
#define SHGEMM_RN shgemm_nn
#define SHGEMM_RT shgemm_nt
#define SHGEMM_RC shgemm_nt
#define SHGEMM_RR shgemm_nn

#define SHGEMM_THREAD_NN shgemm_thread_nn
#define SHGEMM_THREAD_CN shgemm_thread_tn
#define SHGEMM_THREAD_TN shgemm_thread_tn
#define SHGEMM_THREAD_NC shgemm_thread_nt
#define SHGEMM_THREAD_NT shgemm_thread_nt
#define SHGEMM_THREAD_CC shgemm_thread_tt
#define SHGEMM_THREAD_CT shgemm_thread_tt
#define SHGEMM_THREAD_TC shgemm_thread_tt
#define SHGEMM_THREAD_TT shgemm_thread_tt
#define SHGEMM_THREAD_NR shgemm_thread_nn
#define SHGEMM_THREAD_TR shgemm_thread_tn
#define SHGEMM_THREAD_CR shgemm_thread_tn
#define SHGEMM_THREAD_RN shgemm_thread_nn
#define SHGEMM_THREAD_RT shgemm_thread_nt
#define SHGEMM_THREAD_RC shgemm_thread_nt
#define SHGEMM_THREAD_RR shgemm_thread_nn


#endif // #COMMON_SH_H

+ 5
- 0
openblas_config_template.h View File

@@ -39,6 +39,11 @@ typedef unsigned long BLASULONG;
typedef uint16_t bfloat16;
#endif

#ifndef HFLOAT16
#include <stdint.h>
typedef uint16_t hfloat16;
#endif

#ifdef OPENBLAS_USE64BITINT
typedef BLASLONG blasint;
#else


+ 28
- 0
param.h View File

@@ -72,6 +72,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef PARAM_H
#define PARAM_H

#define SHGEMM_DEFAULT_UNROLL_N 8
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_P 128
#define SHGEMM_DEFAULT_R 240
#define SHGEMM_DEFAULT_Q 12288

#define SBGEMM_DEFAULT_UNROLL_N 4
#define SBGEMM_DEFAULT_UNROLL_M 8
@@ -3138,10 +3143,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif

#ifdef RISCV64_ZVL128B

#define GEMM_DEFAULT_OFFSET_A 0
#define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL

#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 8
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 8
#define SGEMM_DEFAULT_UNROLL_N 8

@@ -3154,16 +3165,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 4
#define ZGEMM_DEFAULT_UNROLL_N 4

#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 128
#define CGEMM_DEFAULT_P 96
#define ZGEMM_DEFAULT_P 64

#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 240
#define SGEMM_DEFAULT_Q 240
#define DGEMM_DEFAULT_Q 120
#define CGEMM_DEFAULT_Q 120
#define ZGEMM_DEFAULT_Q 120

#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 12288
#define SGEMM_DEFAULT_R 12288
#define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 4096
@@ -3181,6 +3198,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GEMM_DEFAULT_OFFSET_B 0
#define GEMM_DEFAULT_ALIGN 0x03fffUL

#undef SHGEMM_DEFAULT_UNROLL_M
#undef SHGEMM_DEFAULT_UNROLL_N
#define SHGEMM_DEFAULT_UNROLL_M 16
#define SHGEMM_DEFAULT_UNROLL_N 8

#define SGEMM_DEFAULT_UNROLL_M 16
#define SGEMM_DEFAULT_UNROLL_N 8

@@ -3193,16 +3215,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define ZGEMM_DEFAULT_UNROLL_M 8
#define ZGEMM_DEFAULT_UNROLL_N 4

#undef SHGEMM_DEFAULT_P
#define SHGEMM_DEFAULT_P 128
#define SGEMM_DEFAULT_P 128
#define DGEMM_DEFAULT_P 64
#define CGEMM_DEFAULT_P 64
#define ZGEMM_DEFAULT_P 64

#undef SHGEMM_DEFAULT_Q
#define SHGEMM_DEFAULT_Q 128
#define SGEMM_DEFAULT_Q 128
#define DGEMM_DEFAULT_Q 128
#define CGEMM_DEFAULT_Q 128
#define ZGEMM_DEFAULT_Q 64

#undef SHGEMM_DEFAULT_R
#define SHGEMM_DEFAULT_R 16384
#define SGEMM_DEFAULT_R 16384
#define DGEMM_DEFAULT_R 8192
#define CGEMM_DEFAULT_R 8192


Loading…
Cancel
Save