@@ -1889,6 +1889,7 @@ export TARGET_CORE | |||
export NO_AVX512 | |||
export NO_AVX2 | |||
export BUILD_BFLOAT16 | |||
export BUILD_HFLOAT16 | |||
export NO_LSX | |||
export NO_LASX | |||
@@ -1,4 +1,5 @@ | |||
SBBLASOBJS_P = $(SBBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
SHBLASPBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
@@ -11,8 +12,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) | |||
BLASOBJS = $(SHBLASOBJS) $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) | |||
BLASOBJS_P = $(SHBLASPBJS_P) $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) | |||
ifdef EXPRECISION | |||
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||
@@ -24,6 +25,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) | |||
endif | |||
$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHFLOAT16 -UDOUBLE -UCOMPLEX | |||
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX | |||
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX | |||
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX | |||
@@ -33,6 +35,7 @@ $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX | |||
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX | |||
$(SBEXTOBJS) $(SBEXTOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX | |||
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(SBBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
@@ -446,6 +446,10 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C | |||
void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, | |||
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); | |||
/*** FLOAT16 extensions */ | |||
void cblas_shgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, | |||
OPENBLAS_CONST float alpha, OPENBLAS_CONST hfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST hfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc); | |||
#ifdef __cplusplus | |||
} | |||
#endif /* __cplusplus */ | |||
@@ -640,6 +640,9 @@ endif() | |||
if (BUILD_BFLOAT16) | |||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16") | |||
endif() | |||
if (BUILD_HFLOAT16) | |||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_HFLOAT16") | |||
endif() | |||
if(NOT MSVC) | |||
set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}") | |||
endif() | |||
@@ -647,14 +650,14 @@ endif() | |||
set(PFLAGS "${PFLAGS} ${CCOMMON_OPT} -I${TOPDIR} -DPROFILE ${COMMON_PROF}") | |||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||
if ("${F_COMPILER}" STREQUAL "FLANG") | |||
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) | |||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") | |||
endif () | |||
endif () | |||
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") | |||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") | |||
endif () | |||
if ("${F_COMPILER}" STREQUAL "FLANG") | |||
if (${CMAKE_Fortran_COMPILER_VERSION} VERSION_LESS_EQUAL 3) | |||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -fno-unroll-loops") | |||
endif () | |||
endif () | |||
if (ARM64 AND CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*" AND CMAKE_SYSTEM_NAME STREQUAL "Windows") | |||
set(CMAKE_Fortran_FLAGS_RELEASE "${CMAKE_Fortran_FLAGS_RELEASE} -O2") | |||
endif () | |||
endif () | |||
@@ -266,6 +266,11 @@ typedef uint16_t bfloat16; | |||
#define BFLOAT16CONVERSION 1 | |||
#endif | |||
#ifndef hfloat16 | |||
#include <stdint.h> | |||
typedef uint16_t hfloat16; | |||
#endif | |||
#ifdef USE64BITINT | |||
typedef BLASLONG blasint; | |||
#if defined(OS_WINDOWS) && defined(__64BIT__) | |||
@@ -313,8 +318,8 @@ typedef int blasint; | |||
#define SIZE 2 | |||
#define BASE_SHIFT 1 | |||
#define ZBASE_SHIFT 2 | |||
#elif defined(FLOAT16) | |||
#define IFLOAT float16 | |||
#elif defined(HFLOAT16) | |||
#define IFLOAT hfloat16 | |||
#define XFLOAT IFLOAT | |||
#define FLOAT float | |||
#define SIZE 2 | |||
@@ -481,6 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint | |||
/* Level 3 routines */ | |||
void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
hfloat16 *, blasint *, hfloat16 *, blasint *, float *, float *, blasint *); | |||
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); | |||
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, | |||
@@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | |||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); | |||
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
@@ -78,6 +79,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); | |||
#endif | |||
int shgemm_incopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); | |||
int shgemm_itcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); | |||
int shgemm_oncopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); | |||
int shgemm_otcopy(BLASLONG m, BLASLONG n, hfloat16 *a, BLASLONG lda, hfloat16 *b); | |||
int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); | |||
int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); | |||
int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); | |||
@@ -505,6 +510,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl | |||
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); | |||
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); | |||
int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); | |||
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); | |||
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); | |||
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); | |||
@@ -657,6 +663,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float | |||
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); | |||
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); | |||
int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
@@ -754,6 +765,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON | |||
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); | |||
#endif | |||
int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, hfloat16 *, hfloat16 *, BLASLONG); | |||
int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); | |||
@@ -1944,6 +1960,7 @@ int dgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
int cgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
int zgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
int sbgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
// int shgemm_batch_thread(blas_arg_t * queue, BLASLONG nums); | |||
#ifdef __CUDACC__ | |||
} | |||
@@ -39,6 +39,7 @@ | |||
#ifndef COMMON_MACRO | |||
#define COMMON_MACRO | |||
#include "common_sh.h" | |||
#include "common_sb.h" | |||
#include "common_s.h" | |||
#include "common_d.h" | |||
@@ -656,6 +657,50 @@ | |||
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT | |||
#elif defined(HFLOAT16) | |||
#define GEMM_BETA SHGEMM_BETA | |||
#define GEMM_KERNEL_N SHGEMM_KERNEL | |||
#define GEMM_KERNEL_L SHGEMM_KERNEL | |||
#define GEMM_KERNEL_R SHGEMM_KERNEL | |||
#define GEMM_KERNEL_B SHGEMM_KERNEL | |||
#define GEMM_NN SHGEMM_NN | |||
#define GEMM_CN SHGEMM_TN | |||
#define GEMM_TN SHGEMM_TN | |||
#define GEMM_NC SHGEMM_NT | |||
#define GEMM_NT SHGEMM_NT | |||
#define GEMM_CC SHGEMM_TT | |||
#define GEMM_CT SHGEMM_TT | |||
#define GEMM_TC SHGEMM_TT | |||
#define GEMM_TT SHGEMM_TT | |||
#define GEMM_NR SHGEMM_NN | |||
#define GEMM_TR SHGEMM_TN | |||
#define GEMM_CR SHGEMM_TN | |||
#define GEMM_RN SHGEMM_NN | |||
#define GEMM_RT SHGEMM_NT | |||
#define GEMM_RC SHGEMM_NT | |||
#define GEMM_RR SHGEMM_NN | |||
#define GEMM_ONCOPY SHGEMM_ONCOPY | |||
#define GEMM_OTCOPY SHGEMM_OTCOPY | |||
#define GEMM_INCOPY SHGEMM_INCOPY | |||
#define GEMM_ITCOPY SHGEMM_ITCOPY | |||
#define GEMM_THREAD_NN SHGEMM_THREAD_NN | |||
#define GEMM_THREAD_CN SHGEMM_THREAD_TN | |||
#define GEMM_THREAD_TN SHGEMM_THREAD_TN | |||
#define GEMM_THREAD_NC SHGEMM_THREAD_NT | |||
#define GEMM_THREAD_NT SHGEMM_THREAD_NT | |||
#define GEMM_THREAD_CC SHGEMM_THREAD_TT | |||
#define GEMM_THREAD_CT SHGEMM_THREAD_TT | |||
#define GEMM_THREAD_TC SHGEMM_THREAD_TT | |||
#define GEMM_THREAD_TT SHGEMM_THREAD_TT | |||
#define GEMM_THREAD_NR SHGEMM_THREAD_NN | |||
#define GEMM_THREAD_TR SHGEMM_THREAD_TN | |||
#define GEMM_THREAD_CR SHGEMM_THREAD_TN | |||
#define GEMM_THREAD_RN SHGEMM_THREAD_NN | |||
#define GEMM_THREAD_RT SHGEMM_THREAD_NT | |||
#define GEMM_THREAD_RC SHGEMM_THREAD_NT | |||
#define GEMM_THREAD_RR SHGEMM_THREAD_NN | |||
#elif defined(BFLOAT16) | |||
@@ -48,6 +48,21 @@ typedef struct { | |||
int dtb_entries; | |||
int switch_ratio; | |||
int offsetA, offsetB, align; | |||
#if BUILD_HFLOAT16 == 1 | |||
int shgemm_p, shgemm_q, shgemm_r; | |||
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; | |||
int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, hfloat16 *, float *, BLASLONG); | |||
int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BLASLONG, float *, BLASLONG); | |||
int (*shgemm_incopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||
int (*shgemm_itcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||
int (*shgemm_oncopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||
int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||
#endif | |||
#if BUILD_BFLOAT16 == 1 | |||
int sbgemm_p, sbgemm_q, sbgemm_r; | |||
@@ -64,10 +79,10 @@ typedef struct { | |||
float (*sbamin_k) (BLASLONG, float *, BLASLONG); | |||
float (*sbmax_k) (BLASLONG, float *, BLASLONG); | |||
float (*sbmin_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbamax_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbamin_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbmax_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); | |||
float (*sbnrm2_k) (BLASLONG, float *, BLASLONG); | |||
float (*sbasum_k) (BLASLONG, float *, BLASLONG); | |||
@@ -180,12 +195,12 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); | |||
#endif | |||
#if (BUILD_SINGLE==1) || (BUILD_DOUBLE ==1) || (BUILD_COMPLEX==1) | |||
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isamax_k)(BLASLONG, float *, BLASLONG); | |||
#endif | |||
#if (BUILD_SINGLE==1) || (BUILD_COMPLEX==1) | |||
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*isamin_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*ismax_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); | |||
float (*snrm2_k) (BLASLONG, float *, BLASLONG); | |||
float (*sasum_k) (BLASLONG, float *, BLASLONG); | |||
#endif | |||
@@ -316,10 +331,10 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); | |||
double (*damin_k) (BLASLONG, double *, BLASLONG); | |||
double (*dmax_k) (BLASLONG, double *, BLASLONG); | |||
double (*dmin_k) (BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idamax_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idamin_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idmax_k) (BLASLONG, double *, BLASLONG); | |||
BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); | |||
double (*dnrm2_k) (BLASLONG, double *, BLASLONG); | |||
double (*dasum_k) (BLASLONG, double *, BLASLONG); | |||
@@ -435,10 +450,10 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); | |||
xdouble (*qamin_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*qmax_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*qmin_k) (BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqamax_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqamin_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqmax_k) (BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*qnrm2_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*qasum_k) (BLASLONG, xdouble *, BLASLONG); | |||
@@ -528,8 +543,8 @@ BLASLONG (*iqmin_k) (BLASLONG, xdouble *, BLASLONG); | |||
float (*camax_k) (BLASLONG, float *, BLASLONG); | |||
float (*camin_k) (BLASLONG, float *, BLASLONG); | |||
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*icamax_k)(BLASLONG, float *, BLASLONG); | |||
BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); | |||
float (*cnrm2_k) (BLASLONG, float *, BLASLONG); | |||
float (*casum_k) (BLASLONG, float *, BLASLONG); | |||
@@ -739,8 +754,8 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); | |||
double (*zamax_k) (BLASLONG, double *, BLASLONG); | |||
double (*zamin_k) (BLASLONG, double *, BLASLONG); | |||
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*izamax_k)(BLASLONG, double *, BLASLONG); | |||
BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); | |||
double (*znrm2_k) (BLASLONG, double *, BLASLONG); | |||
double (*zasum_k) (BLASLONG, double *, BLASLONG); | |||
@@ -950,8 +965,8 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); | |||
xdouble (*xamax_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*xamin_k) (BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*ixamax_k)(BLASLONG, xdouble *, BLASLONG); | |||
BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*xnrm2_k) (BLASLONG, xdouble *, BLASLONG); | |||
xdouble (*xasum_k) (BLASLONG, xdouble *, BLASLONG); | |||
@@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas; | |||
#define HAVE_EX_L2 gotoblas -> exclusive_cache | |||
#if (BUILD_HFLOAT16==1) | |||
#define SHGEMM_P gotoblas -> shgemm_p | |||
#define SHGEMM_Q gotoblas -> shgemm_q | |||
#define SHGEMM_R gotoblas -> shgemm_r | |||
#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m | |||
#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n | |||
#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn | |||
#endif | |||
#if (BUILD_BFLOAT16==1) | |||
#define SBGEMM_P gotoblas -> sbgemm_p | |||
#define SBGEMM_Q gotoblas -> sbgemm_q | |||
@@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas; | |||
#define HAVE_EX_L2 0 | |||
#endif | |||
#if (BUILD_HFLOAT16 == 1) | |||
#define SHGEMM_P SHGEMM_DEFAULT_P | |||
#define SHGEMM_Q SHGEMM_DEFAULT_Q | |||
#define SHGEMM_R SHGEMM_DEFAULT_R | |||
#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M | |||
#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N | |||
#ifdef SHGEMM_DEFAULT_UNROLL_MN | |||
#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN | |||
#else | |||
#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N)) | |||
#endif | |||
#endif | |||
#if (BUILD_BFLOAT16 == 1) | |||
#define SBGEMM_P SBGEMM_DEFAULT_P | |||
#define SBGEMM_Q SBGEMM_DEFAULT_Q | |||
@@ -1478,6 +1515,7 @@ extern gotoblas_t *gotoblas; | |||
#endif | |||
#endif | |||
#ifndef COMPLEX | |||
@@ -1505,6 +1543,18 @@ extern gotoblas_t *gotoblas; | |||
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R | |||
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M | |||
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N | |||
#elif defined(HFLOAT16) | |||
#define GEMM_P SHGEMM_P | |||
#define GEMM_Q SHGEMM_Q | |||
#define GEMM_R SHGEMM_R | |||
#define GEMM_UNROLL_M SHGEMM_UNROLL_M | |||
#define GEMM_UNROLL_N SHGEMM_UNROLL_N | |||
#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN | |||
#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P | |||
#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q | |||
#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R | |||
#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M | |||
#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N | |||
#elif defined(BFLOAT16) | |||
#define GEMM_P SBGEMM_P | |||
#define GEMM_Q SBGEMM_Q | |||
@@ -0,0 +1,72 @@ | |||
#ifndef COMMON_SH_H | |||
#define COMMON_SH_H | |||
#ifndef DYNAMIC_ARCH | |||
#define SHGEMM_ONCOPY shgemm_oncopy | |||
#define SHGEMM_OTCOPY shgemm_otcopy | |||
#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N | |||
#define SHGEMM_INCOPY shgemm_oncopy | |||
#define SHGEMM_ITCOPY shgemm_otcopy | |||
#else | |||
#define SHGEMM_INCOPY shgemm_incopy | |||
#define SHGEMM_ITCOPY shgemm_itcopy | |||
#endif | |||
#define SHGEMM_BETA shgemm_beta | |||
#define SHGEMM_KERNEL shgemm_kernel | |||
#else // #DYNAMIC_ARCH | |||
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy | |||
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy | |||
#if SGEMM_DEFAULT_UNROLL_M == SGEMM_DEFAULT_UNROLL_N | |||
#define SHGEMM_INCOPY gotoblas -> shgemm_oncopy | |||
#define SHGEMM_ITCOPY gotoblas -> shgemm_otcopy | |||
#else | |||
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy | |||
#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy | |||
#endif | |||
#define SHGEMM_BETA gotoblas -> shgemm_beta | |||
#define SHGEMM_KERNEL gotoblas -> shgemm_kernel | |||
#endif // #DYNAMIC_ARCH | |||
#define SHGEMM_NN shgemm_nn | |||
#define SHGEMM_CN shgemm_tn | |||
#define SHGEMM_TN shgemm_tn | |||
#define SHGEMM_NC shgemm_nt | |||
#define SHGEMM_NT shgemm_nt | |||
#define SHGEMM_CC shgemm_tt | |||
#define SHGEMM_CT shgemm_tt | |||
#define SHGEMM_TC shgemm_tt | |||
#define SHGEMM_TT shgemm_tt | |||
#define SHGEMM_NR shgemm_nn | |||
#define SHGEMM_TR shgemm_tn | |||
#define SHGEMM_CR shgemm_tn | |||
#define SHGEMM_RN shgemm_nn | |||
#define SHGEMM_RT shgemm_nt | |||
#define SHGEMM_RC shgemm_nt | |||
#define SHGEMM_RR shgemm_nn | |||
#define SHGEMM_THREAD_NN shgemm_thread_nn | |||
#define SHGEMM_THREAD_CN shgemm_thread_tn | |||
#define SHGEMM_THREAD_TN shgemm_thread_tn | |||
#define SHGEMM_THREAD_NC shgemm_thread_nt | |||
#define SHGEMM_THREAD_NT shgemm_thread_nt | |||
#define SHGEMM_THREAD_CC shgemm_thread_tt | |||
#define SHGEMM_THREAD_CT shgemm_thread_tt | |||
#define SHGEMM_THREAD_TC shgemm_thread_tt | |||
#define SHGEMM_THREAD_TT shgemm_thread_tt | |||
#define SHGEMM_THREAD_NR shgemm_thread_nn | |||
#define SHGEMM_THREAD_TR shgemm_thread_tn | |||
#define SHGEMM_THREAD_CR shgemm_thread_tn | |||
#define SHGEMM_THREAD_RN shgemm_thread_nn | |||
#define SHGEMM_THREAD_RT shgemm_thread_nt | |||
#define SHGEMM_THREAD_RC shgemm_thread_nt | |||
#define SHGEMM_THREAD_RR shgemm_thread_nn | |||
#endif // #COMMON_SH_H |
@@ -39,6 +39,11 @@ typedef unsigned long BLASULONG; | |||
typedef uint16_t bfloat16; | |||
#endif | |||
#ifndef HFLOAT16 | |||
#include <stdint.h> | |||
typedef uint16_t hfloat16; | |||
#endif | |||
#ifdef OPENBLAS_USE64BITINT | |||
typedef BLASLONG blasint; | |||
#else | |||
@@ -72,6 +72,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#ifndef PARAM_H | |||
#define PARAM_H | |||
#define SHGEMM_DEFAULT_UNROLL_N 8 | |||
#define SHGEMM_DEFAULT_UNROLL_M 8 | |||
#define SHGEMM_DEFAULT_P 128 | |||
#define SHGEMM_DEFAULT_R 240 | |||
#define SHGEMM_DEFAULT_Q 12288 | |||
#define SBGEMM_DEFAULT_UNROLL_N 4 | |||
#define SBGEMM_DEFAULT_UNROLL_M 8 | |||
@@ -3138,10 +3143,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#endif | |||
#ifdef RISCV64_ZVL128B | |||
#define GEMM_DEFAULT_OFFSET_A 0 | |||
#define GEMM_DEFAULT_OFFSET_B 0 | |||
#define GEMM_DEFAULT_ALIGN (BLASLONG)0x03fffUL | |||
#undef SHGEMM_DEFAULT_UNROLL_M | |||
#undef SHGEMM_DEFAULT_UNROLL_N | |||
#define SHGEMM_DEFAULT_UNROLL_M 8 | |||
#define SHGEMM_DEFAULT_UNROLL_N 8 | |||
#define SGEMM_DEFAULT_UNROLL_M 8 | |||
#define SGEMM_DEFAULT_UNROLL_N 8 | |||
@@ -3154,16 +3165,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#define ZGEMM_DEFAULT_UNROLL_M 4 | |||
#define ZGEMM_DEFAULT_UNROLL_N 4 | |||
#undef SHGEMM_DEFAULT_P | |||
#define SHGEMM_DEFAULT_P 128 | |||
#define SGEMM_DEFAULT_P 128 | |||
#define DGEMM_DEFAULT_P 128 | |||
#define CGEMM_DEFAULT_P 96 | |||
#define ZGEMM_DEFAULT_P 64 | |||
#undef SHGEMM_DEFAULT_Q | |||
#define SHGEMM_DEFAULT_Q 240 | |||
#define SGEMM_DEFAULT_Q 240 | |||
#define DGEMM_DEFAULT_Q 120 | |||
#define CGEMM_DEFAULT_Q 120 | |||
#define ZGEMM_DEFAULT_Q 120 | |||
#undef SHGEMM_DEFAULT_R | |||
#define SHGEMM_DEFAULT_R 12288 | |||
#define SGEMM_DEFAULT_R 12288 | |||
#define DGEMM_DEFAULT_R 8192 | |||
#define CGEMM_DEFAULT_R 4096 | |||
@@ -3181,6 +3198,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#define GEMM_DEFAULT_OFFSET_B 0 | |||
#define GEMM_DEFAULT_ALIGN 0x03fffUL | |||
#undef SHGEMM_DEFAULT_UNROLL_M | |||
#undef SHGEMM_DEFAULT_UNROLL_N | |||
#define SHGEMM_DEFAULT_UNROLL_M 16 | |||
#define SHGEMM_DEFAULT_UNROLL_N 8 | |||
#define SGEMM_DEFAULT_UNROLL_M 16 | |||
#define SGEMM_DEFAULT_UNROLL_N 8 | |||
@@ -3193,16 +3215,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#define ZGEMM_DEFAULT_UNROLL_M 8 | |||
#define ZGEMM_DEFAULT_UNROLL_N 4 | |||
#undef SHGEMM_DEFAULT_P | |||
#define SHGEMM_DEFAULT_P 128 | |||
#define SGEMM_DEFAULT_P 128 | |||
#define DGEMM_DEFAULT_P 64 | |||
#define CGEMM_DEFAULT_P 64 | |||
#define ZGEMM_DEFAULT_P 64 | |||
#undef SHGEMM_DEFAULT_Q | |||
#define SHGEMM_DEFAULT_Q 128 | |||
#define SGEMM_DEFAULT_Q 128 | |||
#define DGEMM_DEFAULT_Q 128 | |||
#define CGEMM_DEFAULT_Q 128 | |||
#define ZGEMM_DEFAULT_Q 64 | |||
#undef SHGEMM_DEFAULT_R | |||
#define SHGEMM_DEFAULT_R 16384 | |||
#define SGEMM_DEFAULT_R 16384 | |||
#define DGEMM_DEFAULT_R 8192 | |||
#define CGEMM_DEFAULT_R 8192 | |||