1. Added bfloat16 based dot as new API: shdot 2. Implemented generic kernel and cooperlake-specific (AVX512-BF16) kernel for shdot 3. Added 4 conversion APIs for bfloat16 data type <=> single/double: shstobf16 shdtobf16 sbf16tos dbf16tod shstobf16 -- convert single float array to bfloat16 array shdtobf16 -- convert double float array to bfloat16 array sbf16tos -- convert bfloat16 array to single float array dbf16tod -- convert bfloat16 array to double float array 4. Implemented generic kernels for all 4 conversion APIs, and cooperlake-specific kernel for shstobf16 and shdtobf16 5. Update level1 thread facilitate functions and macros to support multi-threading for these new APIs 6. Fix Cooperlake platform detection/specify issue when under dynamic-arch building 7. Change the typedef of bfloat16 from unsigned short to more strict uint16_t Signed-off-by: Chen, Guobing <guobing.chen@intel.com>tags/v0.3.11^2
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||
BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||
BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) | |||
BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||
BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) | |||
ifdef EXPRECISION | |||
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX | |||
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX | |||
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX | |||
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX | |||
$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX | |||
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||
libs :: $(BLASOBJS) $(COMMONOBJS) | |||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | |||
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint | |||
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, | |||
double *c, OPENBLAS_CONST blasint cldc); | |||
/*** BFLOAT16 and INT8 extensions ***/ | |||
/* convert float array to BFLOAT16 array by rounding */ | |||
void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); | |||
/* convert double array to BFLOAT16 array by rounding */ | |||
void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); | |||
/* convert BFLOAT16 array to float array */ | |||
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout); | |||
/* convert BFLOAT16 array to double array */ | |||
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); | |||
/* dot production of BFLOAT16 input arrays, and output as float */ | |||
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); | |||
#ifdef __cplusplus | |||
} | |||
@@ -126,12 +126,14 @@ if (BUILD_HALF) | |||
set(SHAXPYKERNEL ../arm/axpy.c) | |||
set(SHAXPBYKERNEL ../arm/axpby.c) | |||
set(SHCOPYKERNEL ../arm/copy.c) | |||
set(SHDOTKERNEL ../arm/dot.c) | |||
set(SHDOTKERNEL ../x86_64/shdot.c) | |||
set(SHROTKERNEL ../arm/rot.c) | |||
set(SHSCALKERNEL ../arm/scal.c) | |||
set(SHNRM2KERNEL ../arm/nrm2.c) | |||
set(SHSUMKERNEL ../arm/sum.c) | |||
set(SHSWAPKERNEL ../arm/swap.c) | |||
set(TOBF16KERNEL ../x86_64/tobf16.c) | |||
set(BF16TOKERNEL ../x86_64/bf16to.c) | |||
endif () | |||
endmacro () | |||
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG; | |||
#endif | |||
#ifndef BFLOAT16 | |||
typedef unsigned short bfloat16; | |||
#include <stdint.h> | |||
typedef uint16_t bfloat16; | |||
#define HALFCONVERSION 1 | |||
#endif | |||
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *); | |||
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *); | |||
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *); | |||
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *); | |||
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *); | |||
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *); | |||
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *); | |||
void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *); | |||
#ifdef RETURN_BY_STRUCT | |||
typedef struct { | |||
@@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG); | |||
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); | |||
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); | |||
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); | |||
void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||
void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); | |||
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
@@ -646,6 +646,11 @@ | |||
#elif defined(HALF) | |||
#define D_TO_BF16_K SHDTOBF16_K | |||
#define D_BF16_TO_K DBF16TOD_K | |||
#define S_TO_BF16_K SHSTOBF16_K | |||
#define S_BF16_TO_K SBF16TOS_K | |||
#define AMAX_K SAMAX_K | |||
#define AMIN_K SAMIN_K | |||
#define MAX_K SMAX_K | |||
@@ -657,6 +662,7 @@ | |||
#define ASUM_K SASUM_K | |||
#define DOTU_K SDOTU_K | |||
#define DOTC_K SDOTC_K | |||
#define BF16_DOT_K SHDOT_K | |||
#define AXPYU_K SAXPYU_K | |||
#define AXPYC_K SAXPYC_K | |||
#define AXPBY_K SAXPBY_K | |||
@@ -51,6 +51,11 @@ typedef struct { | |||
int shgemm_p, shgemm_q, shgemm_r; | |||
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; | |||
void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); | |||
void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); | |||
void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||
void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); | |||
float (*shamax_k) (BLASLONG, float *, BLASLONG); | |||
float (*shamin_k) (BLASLONG, float *, BLASLONG); | |||
float (*shmax_k) (BLASLONG, float *, BLASLONG); | |||
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); | |||
float (*shasum_k) (BLASLONG, float *, BLASLONG); | |||
float (*shsum_k) (BLASLONG, float *, BLASLONG); | |||
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); | |||
@@ -3,6 +3,12 @@ | |||
#ifndef DYNAMIC_ARCH | |||
#define SHDOT_K shdot_k | |||
#define SHSTOBF16_K shstobf16_k | |||
#define SHDTOBF16_K shdtobf16_k | |||
#define SBF16TOS_K sbf16tos_k | |||
#define DBF16TOD_K dbf16tod_k | |||
#define SHGEMM_ONCOPY shgemm_oncopy | |||
#define SHGEMM_OTCOPY shgemm_otcopy | |||
@@ -18,6 +24,12 @@ | |||
#else | |||
#define SHDOT_K gotoblas -> shdot_k | |||
#define SHSTOBF16_K gotoblas -> shstobf16_k | |||
#define SHDTOBF16_K gotoblas -> shdtobf16_k | |||
#define SBF16TOS_K gotoblas -> sbf16tos_k | |||
#define DBF16TOD_K gotoblas -> dbf16tod_k | |||
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy | |||
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy | |||
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy | |||
@@ -59,12 +59,19 @@ extern int blas_omp_linked; | |||
#define BLAS_PTHREAD 0x4000U | |||
#define BLAS_NODE 0x2000U | |||
#define BLAS_PREC 0x0003U | |||
#define BLAS_SINGLE 0x0000U | |||
#define BLAS_DOUBLE 0x0001U | |||
#define BLAS_XDOUBLE 0x0002U | |||
#define BLAS_REAL 0x0000U | |||
#define BLAS_COMPLEX 0x0004U | |||
#define BLAS_PREC 0x000FU | |||
#define BLAS_INT8 0x0000U | |||
#define BLAS_BFLOAT16 0x0001U | |||
#define BLAS_SINGLE 0x0002U | |||
#define BLAS_DOUBLE 0x0003U | |||
#define BLAS_XDOUBLE 0x0004U | |||
#define BLAS_STOBF16 0x0008U | |||
#define BLAS_DTOBF16 0x0009U | |||
#define BLAS_BF16TOS 0x000AU | |||
#define BLAS_BF16TOD 0x000BU | |||
#define BLAS_REAL 0x0000U | |||
#define BLAS_COMPLEX 0x1000U | |||
#define BLAS_TRANSA 0x0030U /* 2bit */ | |||
#define BLAS_TRANSA_N 0x0000U | |||
@@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){ | |||
#endif | |||
} | |||
static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx) | |||
{ | |||
#ifdef C_MSVC | |||
int cpuInfo[4] = {-1}; | |||
__cpuidex(cpuInfo, op, count); | |||
*eax = cpuInfo[0]; | |||
*ebx = cpuInfo[1]; | |||
*ecx = cpuInfo[2]; | |||
*edx = cpuInfo[3]; | |||
#else | |||
#if defined(__i386__) && defined(__PIC__) | |||
__asm__ __volatile__ | |||
("mov %%ebx, %%edi;" | |||
"cpuid;" | |||
"xchgl %%ebx, %%edi;" | |||
: "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); | |||
#else | |||
__asm__ __volatile__ | |||
("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); | |||
#endif | |||
#endif | |||
} | |||
/* | |||
#define WHEREAMI | |||
*/ | |||
@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha | |||
blas_arg_t args [MAX_CPU_NUMBER]; | |||
BLASLONG i, width, astride, bstride; | |||
int num_cpu, calc_type; | |||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; | |||
int num_cpu, calc_type_a, calc_type_b; | |||
switch (mode & BLAS_PREC) { | |||
case BLAS_INT8 : | |||
case BLAS_BFLOAT16: | |||
case BLAS_SINGLE : | |||
case BLAS_DOUBLE : | |||
case BLAS_XDOUBLE : | |||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_STOBF16 : | |||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_DTOBF16 : | |||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_BF16TOS : | |||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_BF16TOD : | |||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
default: | |||
calc_type_a = calc_type_b = 0; | |||
break; | |||
} | |||
mode |= BLAS_LEGACY; | |||
@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha | |||
bstride = width; | |||
} | |||
astride <<= calc_type; | |||
bstride <<= calc_type; | |||
astride <<= calc_type_a; | |||
bstride <<= calc_type_b; | |||
args[num_cpu].m = width; | |||
args[num_cpu].n = n; | |||
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL | |||
blas_arg_t args [MAX_CPU_NUMBER]; | |||
BLASLONG i, width, astride, bstride; | |||
int num_cpu, calc_type; | |||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; | |||
int num_cpu, calc_type_a, calc_type_b; | |||
switch (mode & BLAS_PREC) { | |||
case BLAS_INT8 : | |||
case BLAS_BFLOAT16: | |||
case BLAS_SINGLE : | |||
case BLAS_DOUBLE : | |||
case BLAS_XDOUBLE : | |||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_STOBF16 : | |||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_DTOBF16 : | |||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_BF16TOS : | |||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
case BLAS_BF16TOD : | |||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); | |||
break; | |||
default: | |||
calc_type_a = calc_type_b = 0; | |||
break; | |||
} | |||
mode |= BLAS_LEGACY; | |||
@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL | |||
bstride = width; | |||
} | |||
astride <<= calc_type; | |||
bstride <<= calc_type; | |||
astride <<= calc_type_a; | |||
bstride <<= calc_type_b; | |||
args[num_cpu].m = width; | |||
args[num_cpu].n = n; | |||
@@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
if (!(mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* REAL / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
/* REAL / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
/* REAL / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
float *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((float *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||
/* REAL / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
float *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((float *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#ifdef BUILD_HALF | |||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||
/* REAL / BFLOAT16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||
bfloat16 *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((bfloat16 *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||
/* REAL / BLAS_STOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||
float *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((float *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||
/* REAL / BLAS_DTOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||
double *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((double *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#endif | |||
} else { | |||
/* REAL / Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* COMPLEX / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE) { | |||
/* COMPLEX / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) { | |||
/* COMPLEX / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
@@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} | |||
} else { | |||
/* COMPLEX / Other types in future */ | |||
} | |||
} | |||
} | |||
@@ -414,33 +453,37 @@ blas_queue_t *tscq; | |||
if (sb == NULL) { | |||
if (!(queue -> mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) { | |||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} | |||
queue->sb=sb; | |||
} | |||
@@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
if (!(mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* REAL / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
/* REAL / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||
/* REAL / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
@@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#ifdef BUILD_HALF | |||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||
/* REAL / BFLOAT16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||
bfloat16 *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((bfloat16 *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||
/* REAL / BLAS_STOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||
float *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((float *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||
/* REAL / BLAS_DTOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||
double *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((double *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#endif | |||
} else { | |||
/* REAL / Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* COMPLEX / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
/* COMPLEX / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||
/* COMPLEX / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
@@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} | |||
} | |||
} else { | |||
/* COMPLEX / Other types in future */ | |||
} | |||
} | |||
} | |||
static void exec_threads(blas_queue_t *queue, int buf_index){ | |||
@@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){ | |||
if (sb == NULL) { | |||
if (!(queue -> mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){ | |||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} | |||
queue->sb=sb; | |||
@@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
if (!(mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* REAL / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
/* REAL / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||
/* REAL / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
@@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#ifdef BUILD_HALF | |||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||
/* REAL / BFLOAT16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||
bfloat16 *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((bfloat16 *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||
/* REAL / BLAS_STOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||
float *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((float *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||
/* REAL / BLAS_DTOBF16 */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||
double *, BLASLONG, void *) = func; | |||
afunc(args -> m, args -> n, args -> k, | |||
((double *)args -> alpha)[0], | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
#endif | |||
} else { | |||
/* REAL / Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (mode & BLAS_XDOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
/* COMPLEX / Extended Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | |||
xdouble *, BLASLONG, xdouble *, BLASLONG, | |||
@@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> c, args -> ldc, sb); | |||
} else | |||
#endif | |||
if (mode & BLAS_DOUBLE){ | |||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
/* COMPLEX / Double */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | |||
double *, BLASLONG, double *, BLASLONG, | |||
@@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} else { | |||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) { | |||
/* COMPLEX / Single */ | |||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | |||
float *, BLASLONG, float *, BLASLONG, | |||
@@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||
args -> a, args -> lda, | |||
args -> b, args -> ldb, | |||
args -> c, args -> ldc, sb); | |||
} | |||
} else { | |||
/* COMPLEX / Other types in future */ | |||
} | |||
} | |||
} | |||
@@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){ | |||
if (sb == NULL) { | |||
if (!(queue -> mode & BLAS_COMPLEX)){ | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} else { | |||
#ifdef EXPRECISION | |||
if (queue -> mode & BLAS_XDOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else | |||
#endif | |||
if (queue -> mode & BLAS_DOUBLE){ | |||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | |||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | |||
} else { | |||
/* Other types in future */ | |||
} | |||
} | |||
queue->sb=sb; | |||
@@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX; | |||
#else | |||
#define gotoblas_SKYLAKEX gotoblas_PRESCOTT | |||
#endif | |||
#ifdef DYN_COOPERLAKE | |||
extern gotoblas_t gotoblas_COOPERLAKE; | |||
#elif defined(DYN_SKYLAKEX) | |||
#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX | |||
#elif defined(DYN_HASWELL) | |||
#define gotoblas_COOPERLAKE gotoblas_HASWELL | |||
#elif defined(DYN_SANDYBRIDGE) | |||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE | |||
#elif defined(DYN_NEHALEM) | |||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM | |||
#else | |||
#define gotoblas_COOPERLAKE gotoblas_PRESCOTT | |||
#endif | |||
#else // not DYNAMIC_LIST | |||
@@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR; | |||
#ifdef NO_AVX2 | |||
#define gotoblas_HASWELL gotoblas_SANDYBRIDGE | |||
#define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE | |||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE | |||
#define gotoblas_ZEN gotoblas_SANDYBRIDGE | |||
#else | |||
extern gotoblas_t gotoblas_HASWELL; | |||
extern gotoblas_t gotoblas_ZEN; | |||
#ifndef NO_AVX512 | |||
extern gotoblas_t gotoblas_SKYLAKEX; | |||
extern gotoblas_t gotoblas_COOPERLAKE; | |||
#else | |||
#define gotoblas_SKYLAKEX gotoblas_HASWELL | |||
#define gotoblas_COOPERLAKE gotoblas_HASWELL | |||
#endif | |||
#endif | |||
#else | |||
@@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX; | |||
#define gotoblas_SANDYBRIDGE gotoblas_NEHALEM | |||
#define gotoblas_HASWELL gotoblas_NEHALEM | |||
#define gotoblas_SKYLAKEX gotoblas_NEHALEM | |||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM | |||
#define gotoblas_BULLDOZER gotoblas_BARCELONA | |||
#define gotoblas_PILEDRIVER gotoblas_BARCELONA | |||
#define gotoblas_STEAMROLLER gotoblas_BARCELONA | |||
@@ -343,6 +360,23 @@ int support_avx512(){ | |||
#endif | |||
} | |||
int support_avx512_bf16(){ | |||
#if !defined(NO_AVX) && !defined(NO_AVX512) | |||
int eax, ebx, ecx, edx; | |||
int ret=0; | |||
if (!support_avx512()) | |||
return 0; | |||
cpuid_count(7, 1, &eax, &ebx, &ecx, &edx); | |||
if((eax & 32) == 32){ | |||
ret=1; // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not | |||
} | |||
return ret; | |||
#else | |||
return 0; | |||
#endif | |||
} | |||
extern void openblas_warning(int verbose, const char * msg); | |||
#define FALLBACK_VERBOSE 1 | |||
#define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n" | |||
@@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){ | |||
return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels. | |||
} | |||
} | |||
if (model == 5) { | |||
if (model == 5) { | |||
// Intel Cooperlake | |||
if(support_avx512_bf16()) | |||
return &gotoblas_COOPERLAKE; | |||
// Intel Skylake X | |||
if (support_avx512()) | |||
return &gotoblas_SKYLAKEX; | |||
@@ -774,7 +811,8 @@ static char *corename[] = { | |||
"Steamroller", | |||
"Excavator", | |||
"Zen", | |||
"SkylakeX" | |||
"SkylakeX", | |||
"Cooperlake" | |||
}; | |||
char *gotoblas_corename(void) { | |||
@@ -838,6 +876,7 @@ char *gotoblas_corename(void) { | |||
if (gotoblas == &gotoblas_EXCAVATOR) return corename[22]; | |||
if (gotoblas == &gotoblas_ZEN) return corename[23]; | |||
if (gotoblas == &gotoblas_SKYLAKEX) return corename[24]; | |||
if (gotoblas == &gotoblas_COOPERLAKE) return corename[25]; | |||
return corename[0]; | |||
} | |||
@@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){ | |||
switch (found) | |||
{ | |||
case 25: return (&gotoblas_COOPERLAKE); | |||
case 24: return (&gotoblas_SKYLAKEX); | |||
case 23: return (&gotoblas_ZEN); | |||
case 22: return (&gotoblas_EXCAVATOR); | |||
@@ -46,7 +46,7 @@ | |||
ssum, dsum, scsum, dzsum | |||
); | |||
@halfblasobjs = (shgemm); | |||
@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod); | |||
@cblasobjs = ( | |||
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, | |||
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, | |||
@@ -84,7 +84,7 @@ | |||
cblas_xerbla | |||
); | |||
@halfcblasobjs = (cblas_shgemm); | |||
@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod); | |||
@exblasobjs = ( | |||
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, | |||
@@ -47,7 +47,9 @@ SBLAS3OBJS = \ | |||
sgeadd.$(SUFFIX) | |||
ifeq ($(BUILD_HALF),1) | |||
SHBLAS1OBJS = shdot.$(SUFFIX) | |||
SHBLAS3OBJS = shgemm.$(SUFFIX) | |||
SHEXTOBJS = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | |||
endif | |||
DBLAS1OBJS = \ | |||
@@ -281,7 +283,9 @@ CSBLAS3OBJS = \ | |||
cblas_sgeadd.$(SUFFIX) | |||
ifeq ($(BUILD_HALF),1) | |||
CSHBLAS1OBJS = cblas_shdot.$(SUFFIX) | |||
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) | |||
CSHEXTOBJS = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | |||
endif | |||
CDBLAS1OBJS = \ | |||
@@ -374,6 +378,7 @@ override CFLAGS += -I. | |||
SBLAS1OBJS += $(CSBLAS1OBJS) | |||
SBLAS2OBJS += $(CSBLAS2OBJS) | |||
SBLAS3OBJS += $(CSBLAS3OBJS) | |||
SHBLAS1OBJS += $(CSHBLAS1OBJS) | |||
SHBLAS3OBJS += $(CSHBLAS3OBJS) | |||
DBLAS1OBJS += $(CDBLAS1OBJS) | |||
DBLAS2OBJS += $(CDBLAS2OBJS) | |||
@@ -385,10 +390,11 @@ ZBLAS1OBJS += $(CZBLAS1OBJS) | |||
ZBLAS2OBJS += $(CZBLAS2OBJS) | |||
ZBLAS3OBJS += $(CZBLAS3OBJS) | |||
SHEXTOBJS += $(CSHEXTOBJS) | |||
endif | |||
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | |||
SHBLASOBJS = $(SHBLAS3OBJS) | |||
SHBLASOBJS = $(SHBLAS1OBJS) $(SHBLAS3OBJS) | |||
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | |||
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) | |||
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) | |||
@@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS) | |||
endif | |||
FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||
FUNCOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||
ifdef EXPRECISION | |||
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) | |||
@@ -491,7 +497,7 @@ endif | |||
clean :: | |||
@rm -f functable.h | |||
level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) | |||
level1 : $(BEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) | |||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | |||
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | |||
@@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c | |||
dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c | |||
$(CC) $(CFLAGS) -c $< -o $(@F) | |||
ifeq ($(BUILD_HALF),1) | |||
shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c | |||
$(CC) $(CFLAGS) -c $< -o $(@F) | |||
shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c | |||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||
shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c | |||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||
sbf16tos.$(SUFFIX) sbf16tos.$(PSUFFIX) : bf16to.c | |||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||
dbf16tod.$(SUFFIX) dbf16tod.$(PSUFFIX) : bf16to.c | |||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||
endif | |||
sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c | |||
$(CC) $(CFLAGS) -c $< -o $(@F) | |||
@@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c | |||
cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c | |||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | |||
ifeq ($(BUILD_HALF),1) | |||
cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c | |||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | |||
cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c | |||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||
cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c | |||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||
cblas_sbf16tos.$(SUFFIX) cblas_sbf16tos.$(PSUFFIX) : bf16to.c | |||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||
cblas_dbf16tod.$(SUFFIX) cblas_dbf16tod.$(PSUFFIX) : bf16to.c | |||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||
endif | |||
cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c | |||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | |||
@@ -0,0 +1,52 @@ | |||
#include <stdio.h> | |||
#include "common.h" | |||
#ifdef FUNCTION_PROFILE | |||
#include "functable.h" | |||
#endif | |||
#ifndef CBLAS | |||
float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){ | |||
BLASLONG n = *N; | |||
BLASLONG incx = *INCX; | |||
BLASLONG incy = *INCY; | |||
float ret; | |||
PRINT_DEBUG_NAME; | |||
if (n <= 0) return 0.; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (incx < 0) x -= (n - 1) * incx; | |||
if (incy < 0) y -= (n - 1) * incy; | |||
ret = BF16_DOT_K(n, x, incx, y, incy); | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
return ret; | |||
} | |||
#else | |||
float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){ | |||
float ret; | |||
PRINT_DEBUG_CNAME; | |||
if (n <= 0) return 0.; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (incx < 0) x -= (n - 1) * incx; | |||
if (incy < 0) y -= (n - 1) * incy; | |||
ret = BF16_DOT_K(n, x, incx, y, incy); | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
return ret; | |||
} | |||
#endif |
@@ -0,0 +1,62 @@ | |||
#include <stdio.h> | |||
#include "common.h" | |||
#ifdef FUNCTION_PROFILE | |||
#include "functable.h" | |||
#endif | |||
#if defined(DOUBLE_PREC) | |||
#define FLOAT_TYPE double | |||
#elif defined(SINGLE_PREC) | |||
#define FLOAT_TYPE float | |||
#else | |||
#endif | |||
#ifndef CBLAS | |||
void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){ | |||
BLASLONG n = *N; | |||
BLASLONG inc_in = *INC_IN; | |||
BLASLONG inc_out = *INC_OUT; | |||
PRINT_DEBUG_NAME; | |||
if (n <= 0) return; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (inc_in < 0) in -= (n - 1) * inc_in; | |||
if (inc_out < 0) out -= (n - 1) * inc_out; | |||
#if defined(DOUBLE_PREC) | |||
D_BF16_TO_K(n, in, inc_in, out, inc_out); | |||
#elif defined(SINGLE_PREC) | |||
S_BF16_TO_K(n, in, inc_in, out, inc_out); | |||
#else | |||
#endif | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
} | |||
#else | |||
void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){ | |||
PRINT_DEBUG_CNAME; | |||
if (n <= 0) return; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (inc_in < 0) in -= (n - 1) * inc_in; | |||
if (inc_out < 0) out -= (n - 1) * inc_out; | |||
#if defined(DOUBLE_PREC) | |||
D_BF16_TO_K(n, in, inc_in, out, inc_out); | |||
#elif defined(SINGLE_PREC) | |||
S_BF16_TO_K(n, in, inc_in, out, inc_out); | |||
#else | |||
#endif | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
} | |||
#endif |
@@ -0,0 +1,61 @@ | |||
#include <stdio.h> | |||
#include "common.h" | |||
#ifdef FUNCTION_PROFILE | |||
#include "functable.h" | |||
#endif | |||
#if defined(DOUBLE_PREC) | |||
#define FLOAT_TYPE double | |||
#elif defined(SINGLE_PREC) | |||
#define FLOAT_TYPE float | |||
#else | |||
#endif | |||
#ifndef CBLAS | |||
void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){ | |||
BLASLONG n = *N; | |||
BLASLONG inc_in = *INC_IN; | |||
BLASLONG inc_out = *INC_OUT; | |||
PRINT_DEBUG_NAME; | |||
if (n <= 0) return; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (inc_in < 0) in -= (n - 1) * inc_in; | |||
if (inc_out < 0) out -= (n - 1) * inc_out; | |||
#if defined(DOUBLE_PREC) | |||
D_TO_BF16_K(n, in, inc_in, out, inc_out); | |||
#elif defined(SINGLE_PREC) | |||
S_TO_BF16_K(n, in, inc_in, out, inc_out); | |||
#else | |||
#endif | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
} | |||
#else | |||
void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){ | |||
PRINT_DEBUG_CNAME; | |||
if (n <= 0) return; | |||
IDEBUG_START; | |||
FUNCTION_PROFILE_START(); | |||
if (inc_in < 0) in -= (n - 1) * inc_in; | |||
if (inc_out < 0) out -= (n - 1) * inc_out; | |||
#if defined(DOUBLE_PREC) | |||
D_TO_BF16_K(n, in, inc_in, out, inc_out); | |||
#elif defined(SINGLE_PREC) | |||
S_TO_BF16_K(n, in, inc_in, out, inc_out); | |||
#endif | |||
FUNCTION_PROFILE_END(1, 2 * n, 2 * n); | |||
IDEBUG_END; | |||
} | |||
#endif |
@@ -262,6 +262,20 @@ ifndef XDOTKERNEL | |||
XDOTKERNEL = zdot.S | |||
endif | |||
ifeq ($(BUILD_HALF),1) | |||
ifndef SHDOTKERNEL | |||
SHDOTKERNEL = ../x86_64/shdot.c | |||
endif | |||
ifndef TOBF16KERNEL | |||
TOBF16KERNEL = ../x86_64/tobf16.c | |||
endif | |||
ifndef BF16TOKERNEL | |||
BF16TOKERNEL = ../x86_64/bf16to.c | |||
endif | |||
endif | |||
### NRM2 ### | |||
ifndef SNRM2KERNEL | |||
@@ -516,6 +530,15 @@ XBLASOBJS += \ | |||
xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \ | |||
xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX) | |||
ifeq ($(BUILD_HALF),1) | |||
SHBLASOBJS += \ | |||
shdot_k$(TSUFFIX).$(SUFFIX) | |||
SHEXTOBJS += \ | |||
shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX) | |||
SHEXTOBJS += \ | |||
sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX) | |||
endif | |||
### AMAX ### | |||
@@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL | |||
$(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL) | |||
$(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@ | |||
ifeq ($(BUILD_HALF),1) | |||
$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL) | |||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ | |||
$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL) | |||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@ | |||
$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL) | |||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@ | |||
$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL) | |||
$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@ | |||
$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL) | |||
$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@ | |||
endif | |||
$(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL) | |||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@ | |||
@@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = { | |||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), | |||
#endif | |||
shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS, | |||
samax_kTS, samin_kTS, smax_kTS, smin_kTS, | |||
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, | |||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS, | |||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS, | |||
dsdot_kTS, | |||
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, | |||
sgemv_nTS, sgemv_tTS, sger_kTS, | |||
@@ -146,6 +146,18 @@ ifndef XDOTKERNEL | |||
XDOTKERNEL = zdot.S | |||
endif | |||
ifndef SHDOTKERNEL | |||
SHDOTKERNEL = shdot.c | |||
endif | |||
ifndef TOBF16KERNEL | |||
TOBF16KERNEL = tobf16.c | |||
endif | |||
ifndef BF16TOKERNEL | |||
BF16TOKERNEL = bf16to.c | |||
endif | |||
ifndef ISAMAXKERNEL | |||
ISAMAXKERNEL = iamax_sse.S | |||
endif | |||
@@ -0,0 +1,114 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <stddef.h> | |||
#include "common.h" | |||
#if defined(DOUBLE) | |||
#define FLOAT_TYPE double | |||
#elif defined(SINGLE) | |||
#define FLOAT_TYPE float | |||
#else | |||
#endif | |||
/* Notes for algorithm: | |||
* - Input denormal treated as zero | |||
* - Force to be QNAN | |||
*/ | |||
static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out) | |||
{ | |||
BLASLONG register index_in = 0; | |||
BLASLONG register index_out = 0; | |||
BLASLONG register index = 0; | |||
uint16_t * tmp = NULL; | |||
#if defined(DOUBLE) | |||
float float_out = 0.0; | |||
#endif | |||
while(index<n) { | |||
#if defined(DOUBLE) | |||
float_out = 0.0; | |||
tmp = (uint16_t*)(&float_out); | |||
#else | |||
*(out+index_out) = 0; | |||
tmp = (uint16_t*)(out+index_out); | |||
#endif | |||
switch((*(in+index_in)) & 0xff80u) { | |||
case (0x0000u): /* Type 1: Positive denormal */ | |||
tmp[1] = 0x0000u; | |||
tmp[0] = 0x0000u; | |||
break; | |||
case (0x8000u): /* Type 2: Negative denormal */ | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
tmp[1] = 0x8000u; | |||
tmp[0] = 0x0000u; | |||
#else | |||
tmp[1] = 0x0000u; | |||
tmp[0] = 0x8000u; | |||
#endif | |||
break; | |||
case (0x7f80u): /* Type 3: Positive infinity or NAN */ | |||
case (0xff80u): /* Type 4: Negative infinity or NAN */ | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
tmp[1] = *(in+index_in); | |||
#else | |||
tmp[0] = *(in+index_in); | |||
#endif | |||
/* Specific for NAN */ | |||
if (((*(in+index_in)) & 0x007fu) != 0) { | |||
/* Force to be QNAN */ | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
tmp[1] |= 0x0040u; | |||
#else | |||
tmp[0] |= 0x0040u; | |||
#endif | |||
} | |||
break; | |||
default: /* Type 5: Normal case */ | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
tmp[1] = *(in+index_in); | |||
#else | |||
tmp[0] = *(in+index_in); | |||
#endif | |||
break; | |||
} | |||
#if defined(DOUBLE) | |||
*(out+index_out) = (double)float_out; | |||
#endif | |||
index_in += inc_in; | |||
index_out += inc_out; | |||
index++; | |||
} | |||
} | |||
void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out) | |||
{ | |||
if (n <= 0) return; | |||
bf16to_kernel_1(n, in, inc_in, out, inc_out); | |||
} |
@@ -0,0 +1,104 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
/* need a new enough GCC for avx512 support */ | |||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||
#define HAVE_TOBF16_ACCL_KERNEL 1 | |||
#include "common.h" | |||
#include <immintrin.h> | |||
static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out) | |||
{ | |||
/* Get the 64-bytes unaligned header number targeting for avx512 | |||
* processing (Assume input float array is natural aligned) */ | |||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7; | |||
if (n < align_header) {align_header = n;} | |||
if (align_header != 0) { | |||
unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header)); | |||
__m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]); | |||
_mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a))); | |||
} | |||
if (n == align_header) { | |||
return; | |||
} else { | |||
n -= align_header; | |||
in += align_header; | |||
out += align_header; | |||
} | |||
int tail_index_8 = n&(~7); | |||
int tail_index_32 = n&(~31); | |||
int tail_index_128 = n&(~127); | |||
unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7))); | |||
/* Processing the main chunk with 128-elements per round */ | |||
for (int i = 0; i < tail_index_128; i += 128) { | |||
// Fold 1 | |||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1); | |||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1); | |||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low)); | |||
// Fold 2 | |||
__m512 data2_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1); | |||
__m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1); | |||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low)); | |||
// Fold 3 | |||
__m512 data3_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1); | |||
__m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1); | |||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low)); | |||
// Fold 4 | |||
__m512 data4_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1); | |||
__m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1); | |||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low)); | |||
} | |||
/* Processing the remaining <128 chunk with 32-elements per round */ | |||
for (int j = tail_index_128; j < tail_index_32; j += 32) { | |||
__m512 data1_512_low = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1); | |||
__m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1); | |||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low)); | |||
} | |||
/* Processing the remaining <32 chunk with 8-elements per round */ | |||
for (int j = tail_index_32; j < tail_index_8; j += 8) { | |||
_mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j])))); | |||
} | |||
/* Processing the remaining <8 chunk with masked processing */ | |||
if ((n&7) > 0) { | |||
__m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]); | |||
_mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512))); | |||
} | |||
} | |||
#endif |
@@ -0,0 +1,115 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#if defined(COOPERLAKE) | |||
#include "shdot_microk_cooperlake.c" | |||
#endif | |||
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) | |||
{ | |||
float d = 0.0; | |||
#ifdef HAVE_SHDOT_ACCL_KERNEL | |||
if ((inc_x == 1) && (inc_y == 1)) { | |||
return shdot_accl_kernel(n, x, y); | |||
} | |||
#endif | |||
float * x_fp32 = malloc(sizeof(float)*n); | |||
float * y_fp32 = malloc(sizeof(float)*n); | |||
SBF16TOS_K(n, x, inc_x, x_fp32, 1); | |||
SBF16TOS_K(n, y, inc_y, y_fp32, 1); | |||
d = SDOTU_K(n, x_fp32, 1, y_fp32, 1); | |||
free(x_fp32); | |||
free(y_fp32); | |||
return d; | |||
} | |||
#if defined(SMP) | |||
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2, | |||
bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y, | |||
float *result, BLASLONG dummy3) | |||
{ | |||
*(float *)result = shdot_compute(n, x, inc_x, y, inc_y); | |||
return 0; | |||
} | |||
extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, | |||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, | |||
int (*function)(), int nthreads); | |||
#endif | |||
float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) | |||
{ | |||
float dot_result = 0.0; | |||
if (n <= 0) return 0.0; | |||
#if defined(SMP) | |||
int nthreads; | |||
int thread_thres = 40960; | |||
bfloat16 dummy_alpha; | |||
#endif | |||
#if defined(SMP) | |||
if (inc_x == 0 || inc_y == 0 || n <= thread_thres) | |||
nthreads = 1; | |||
else | |||
nthreads = num_cpu_avail(1); | |||
int best_threads = (int) (n/(float)thread_thres + 0.5); | |||
if (best_threads < nthreads) { | |||
nthreads = best_threads; | |||
} | |||
if (nthreads <= 1) { | |||
dot_result = shdot_compute(n, x, inc_x, y, inc_y); | |||
} else { | |||
char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2]; | |||
int mode = BLAS_BFLOAT16 | BLAS_REAL; | |||
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha, | |||
x, inc_x, y, inc_y, thread_result, 0, | |||
(void *)shdot_thread_func, nthreads); | |||
float * ptr = (float *)thread_result; | |||
for (int i = 0; i < nthreads; i++) { | |||
dot_result += (*ptr); | |||
ptr = (float *)(((char *)ptr) + sizeof(double) * 2); | |||
} | |||
} | |||
#else | |||
dot_result = shdot_compute(n, x, inc_x, y, inc_y); | |||
#endif | |||
return dot_result; | |||
} |
@@ -0,0 +1,159 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
/* need a new enough GCC for avx512 support */ | |||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||
#define HAVE_SHDOT_ACCL_KERNEL 1 | |||
#include "common.h" | |||
#include <immintrin.h> | |||
static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) | |||
{ | |||
__m128 accum128 = _mm_setzero_ps(); | |||
if (n> 127) { /* n range from 128 to inf. */ | |||
long tail_index_32 = n&(~31); | |||
long tail_index_128 = n&(~127); | |||
unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31))); | |||
__mmask32 tail_mask = *((__mmask32*) &tail_mask_uint); | |||
__m512 accum512_0 = _mm512_setzero_ps(); | |||
__m512 accum512_1 = _mm512_setzero_ps(); | |||
__m512 accum512_2 = _mm512_setzero_ps(); | |||
__m512 accum512_3 = _mm512_setzero_ps(); | |||
/* Processing the main chunk with 128-elements per round */ | |||
for (long i = 0; i < tail_index_128; i += 128) { | |||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0])); | |||
accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32])); | |||
accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64])); | |||
accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96])); | |||
} | |||
/* Processing the remaining <128 chunk with 32-elements per round */ | |||
for (long j = tail_index_128; j < tail_index_32; j += 32) { | |||
accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j])); | |||
} | |||
/* Processing the remaining <32 chunk with masked 32-elements processing */ | |||
if ((n&31) != 0) { | |||
accum512_2 = _mm512_dpbf16_ps(accum512_2, | |||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]), | |||
(__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32])); | |||
} | |||
/* Accumulate the 4 registers into 1 register */ | |||
accum512_0 = _mm512_add_ps(accum512_0, accum512_1); | |||
accum512_2 = _mm512_add_ps(accum512_2, accum512_3); | |||
accum512_0 = _mm512_add_ps(accum512_0, accum512_2); | |||
__m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1)); | |||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); | |||
} else if (n > 31) { /* n range from 32 to 127 */ | |||
/* Processing <128 chunk with 32-elements per round */ | |||
__m256 accum256 = _mm256_setzero_ps(); | |||
__m256 accum256_1 = _mm256_setzero_ps(); | |||
int tail_index_32 = n&(~31); | |||
for (int j = 0; j < tail_index_32; j += 32) { | |||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0])); | |||
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16])); | |||
} | |||
accum256 = _mm256_add_ps(accum256, accum256_1); | |||
/* Processing the remaining <32 chunk with 16-elements processing */ | |||
if ((n&16) != 0) { | |||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32])); | |||
} | |||
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); | |||
/* Processing the remaining <16 chunk with 8-elements processing */ | |||
if ((n&8) != 0) { | |||
int tail_index_16 = n&(~15); | |||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); | |||
} | |||
/* Processing the remaining <8 chunk with masked 8-elements processing */ | |||
if ((n&7) != 0) { | |||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); | |||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); | |||
int tail_index_8 = n&(~7); | |||
accum128 = _mm_dpbf16_ps(accum128, | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); | |||
} | |||
} else if (n > 15) { /* n range from 16 to 31 */ | |||
/* Processing <32 chunk with 16-elements processing */ | |||
__m256 accum256 = _mm256_setzero_ps(); | |||
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0])); | |||
accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); | |||
/* Processing the remaining <16 chunk with 8-elements processing */ | |||
if ((n&8) != 0) { | |||
int tail_index_16 = n&(~15); | |||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); | |||
} | |||
/* Processing the remaining <8 chunk with masked 8-elements processing */ | |||
if ((n&7) != 0) { | |||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); | |||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); | |||
int tail_index_8 = n&(~7); | |||
accum128 = _mm_dpbf16_ps(accum128, | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); | |||
} | |||
} else if (n > 7) { /* n range from 8 to 15 */ | |||
/* Processing <16 chunk with 8-elements processing */ | |||
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0])); | |||
/* Processing the remaining <8 chunk with masked 8-elements processing */ | |||
if ((n&7) != 0) { | |||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); | |||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); | |||
int tail_index_8 = n&(~7); | |||
accum128 = _mm_dpbf16_ps(accum128, | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]), | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8])); | |||
} | |||
} else { /* n range from 1 to 7 */ | |||
unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7))); | |||
__mmask8 tail_mask = *((__mmask8*) &tail_mask_uint); | |||
accum128 = _mm_dpbf16_ps(accum128, | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]), | |||
(__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0])); | |||
} | |||
/* Add up the 4 elements into lowest entry */ | |||
__m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14); | |||
accum128 = _mm_add_ps(accum128, accum128_1); | |||
accum128_1 = _mm_shuffle_ps(accum128, accum128, 1); | |||
accum128 = _mm_add_ps(accum128, accum128_1); | |||
return accum128[0]; | |||
} | |||
#endif |
@@ -0,0 +1,86 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
/* need a new enough GCC for avx512 support */ | |||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||
#define HAVE_TOBF16_ACCL_KERNEL 1 | |||
#include "common.h" | |||
#include <immintrin.h> | |||
static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out) | |||
{ | |||
/* Get the 64-bytes unaligned header number targeting for avx512 | |||
* processing (Assume input float array is natural aligned) */ | |||
int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf; | |||
if (n < align_header) {align_header = n;} | |||
if (align_header != 0) { | |||
uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header)); | |||
__m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]); | |||
_mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a)); | |||
} | |||
if (n == align_header) { | |||
return; | |||
} else { | |||
n -= align_header; | |||
in += align_header; | |||
out += align_header; | |||
} | |||
int tail_index_32 = n&(~31); | |||
int tail_index_128 = n&(~127); | |||
uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31))); | |||
uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15))); | |||
/* Processing the main chunk with 128-elements per round */ | |||
for (int i = 0; i < tail_index_128; i += 128) { | |||
_mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0]))); | |||
_mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32]))); | |||
_mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64]))); | |||
_mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96]))); | |||
} | |||
/* Processing the remaining <128 chunk with 32-elements per round */ | |||
for (int j = tail_index_128; j < tail_index_32; j += 32) { | |||
_mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j]))); | |||
} | |||
/* Processing the remaining <32 chunk with masked processing */ | |||
if ((n&31) > 15) { | |||
__m512 b = _mm512_load_ps(&in[tail_index_32]); | |||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]); | |||
_mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b)); | |||
} else if ((n&31) > 0) { | |||
__m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]); | |||
_mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a)); | |||
} | |||
} | |||
#endif |
@@ -0,0 +1,170 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <stddef.h> | |||
#include "common.h" | |||
#if defined(DOUBLE) | |||
#define FLOAT_TYPE double | |||
#elif defined(SINGLE) | |||
#define FLOAT_TYPE float | |||
#else | |||
#endif | |||
#if defined(COOPERLAKE) | |||
#if defined(DOUBLE) | |||
#include "dtobf16_microk_cooperlake.c" | |||
#elif defined(SINGLE) | |||
#include "stobf16_microk_cooperlake.c" | |||
#endif | |||
#endif | |||
/* Notes for algorithm: | |||
* - Round to Nearest Even used generally | |||
* - QNAN for NAN case | |||
* - Input denormals are treated as zero | |||
*/ | |||
static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) | |||
{ | |||
BLASLONG register index_in = 0; | |||
BLASLONG register index_out = 0; | |||
BLASLONG register index = 0; | |||
float float_in = 0.0; | |||
uint32_t * uint32_in = (uint32_t *)(&float_in); | |||
uint16_t * uint16_in = (uint16_t *)(&float_in); | |||
while(index<n) { | |||
#if defined(DOUBLE) | |||
float_in = (float)(*(in+index_in)); | |||
#else | |||
float_in = *(in+index_in); | |||
#endif | |||
switch((*uint32_in) & 0xff800000u) { | |||
case (0x00000000u): /* Type 1: Positive denormal */ | |||
*(out+index_out) = 0x0000u; | |||
break; | |||
case (0x80000000u): /* Type 2: Negative denormal */ | |||
*(out+index_out) = 0x8000u; | |||
break; | |||
case (0x7f800000u): /* Type 3: Positive infinity or NAN */ | |||
case (0xff800000u): /* Type 4: Negative infinity or NAN */ | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
*(out+index_out) = uint16_in[1]; | |||
#else | |||
*(out+index_out) = uint16_in[0]; | |||
#endif | |||
/* Specific for NAN */ | |||
if (((*uint32_in) & 0x007fffffu) != 0) { | |||
/* Force to be QNAN */ | |||
*(out+index_out) |= 0x0040u; | |||
} | |||
break; | |||
default: /* Type 5: Normal case */ | |||
(*uint32_in) += ((((*uint32_in) >> 16) & 0x1u) + 0x7fffu); | |||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ | |||
*(out+index_out) = uint16_in[1]; | |||
#else | |||
*(out+index_out) = uint16_in[0]; | |||
#endif | |||
break; | |||
} | |||
index_in += inc_in; | |||
index_out += inc_out; | |||
index++; | |||
} | |||
} | |||
#ifndef HAVE_TOBF16_ACCL_KERNEL | |||
static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out) | |||
{ | |||
tobf16_generic_kernel(n, in, 1, out, 1); | |||
} | |||
#endif | |||
static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) | |||
{ | |||
if ((inc_in == 1) && (inc_out == 1)) { | |||
tobf16_accl_kernel(n, in, out); | |||
} else { | |||
tobf16_generic_kernel(n, in, inc_in, out, inc_out); | |||
} | |||
} | |||
#if defined(SMP) | |||
static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2, | |||
FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y, | |||
FLOAT_TYPE *dummy3, BLASLONG dummy4) | |||
{ | |||
tobf16_compute(n, x, inc_x, y, inc_y); | |||
return 0; | |||
} | |||
extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, | |||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, | |||
int (*function)(), int nthreads); | |||
#endif | |||
void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out) | |||
{ | |||
if (n <= 0) return; | |||
#if defined(SMP) | |||
int nthreads; | |||
FLOAT_TYPE dummy_alpha; | |||
FLOAT_TYPE dummy_c; | |||
#endif | |||
#if defined(SMP) | |||
if (inc_in == 0 || inc_out == 0 || n <= 100000) { | |||
nthreads = 1; | |||
} else { | |||
if (n/100000 < 100) { | |||
nthreads = 4; | |||
} else { | |||
nthreads = 16; | |||
} | |||
} | |||
if (nthreads == 1) { | |||
tobf16_compute(n, in, inc_in, out, inc_out); | |||
} else { | |||
#if defined(DOUBLE) | |||
int mode = BLAS_REAL | BLAS_DTOBF16; | |||
#elif defined(SINGLE) | |||
int mode = BLAS_REAL | BLAS_STOBF16; | |||
#endif | |||
blas_level1_thread(mode, n, 0, 0, &dummy_alpha, | |||
in, inc_in, out, inc_out, &dummy_c, 0, | |||
(void *)tobf16_thread_func, nthreads); | |||
} | |||
#else | |||
tobf16_compute(n, in, inc_in, out, inc_out); | |||
#endif | |||
} |
@@ -35,7 +35,8 @@ typedef unsigned long BLASULONG; | |||
#endif | |||
#ifndef BFLOAT16 | |||
typedef unsigned short bfloat16; | |||
#include <stdint.h> | |||
typedef uint16_t bfloat16; | |||
#endif | |||
#ifdef OPENBLAS_USE64BITINT | |||