1. Added a bfloat16-based dot product as a new API: shdot
2. Implemented a generic kernel and a Cooperlake-specific (AVX512-BF16) kernel for shdot
3. Added 4 conversion APIs for bfloat16 data type <=> single/double: shstobf16, shdtobf16, sbf16tos, dbf16tod
   shstobf16 -- convert single float array to bfloat16 array
   shdtobf16 -- convert double float array to bfloat16 array
   sbf16tos -- convert bfloat16 array to single float array
   dbf16tod -- convert bfloat16 array to double float array
4. Implemented generic kernels for all 4 conversion APIs, and Cooperlake-specific kernels for shstobf16 and shdtobf16
5. Updated the level1 thread facility functions and macros to support multi-threading for these new APIs
6. Fixed the Cooperlake platform detection/specification issue under dynamic-arch builds
7. Changed the typedef of bfloat16 from unsigned short to the stricter uint16_t
Signed-off-by: Chen, Guobing <guobing.chen@intel.com>
tags/v0.3.11^2
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||||
CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX)) | |||||
COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) | ||||
BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||||
BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) | |||||
BLASOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||||
BLASOBJS_P = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) | |||||
ifdef EXPRECISION | ifdef EXPRECISION | ||||
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) | ||||
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX | |||||
$(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX | $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE -DCOMPLEX | ||||
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX | $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX | ||||
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX | $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX | ||||
$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX | |||||
$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||||
$(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
$(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | ||||
$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) | |||||
libs :: $(BLASOBJS) $(COMMONOBJS) | libs :: $(BLASOBJS) $(COMMONOBJS) | ||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint | |||||
void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, | void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, | ||||
double *c, OPENBLAS_CONST blasint cldc); | double *c, OPENBLAS_CONST blasint cldc); | ||||
/*** BFLOAT16 and INT8 extensions ***/ | |||||
/* convert float array to BFLOAT16 array by rounding */ | |||||
void cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); | |||||
/* convert double array to BFLOAT16 array by rounding */ | |||||
void cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout); | |||||
/* convert BFLOAT16 array to float array */ | |||||
void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout); | |||||
/* convert BFLOAT16 array to double array */ | |||||
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); | |||||
/* dot product of BFLOAT16 input arrays, with the result returned as float */ | |||||
float cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
@@ -126,12 +126,14 @@ if (BUILD_HALF) | |||||
set(SHAXPYKERNEL ../arm/axpy.c) | set(SHAXPYKERNEL ../arm/axpy.c) | ||||
set(SHAXPBYKERNEL ../arm/axpby.c) | set(SHAXPBYKERNEL ../arm/axpby.c) | ||||
set(SHCOPYKERNEL ../arm/copy.c) | set(SHCOPYKERNEL ../arm/copy.c) | ||||
set(SHDOTKERNEL ../arm/dot.c) | |||||
set(SHDOTKERNEL ../x86_64/shdot.c) | |||||
set(SHROTKERNEL ../arm/rot.c) | set(SHROTKERNEL ../arm/rot.c) | ||||
set(SHSCALKERNEL ../arm/scal.c) | set(SHSCALKERNEL ../arm/scal.c) | ||||
set(SHNRM2KERNEL ../arm/nrm2.c) | set(SHNRM2KERNEL ../arm/nrm2.c) | ||||
set(SHSUMKERNEL ../arm/sum.c) | set(SHSUMKERNEL ../arm/sum.c) | ||||
set(SHSWAPKERNEL ../arm/swap.c) | set(SHSWAPKERNEL ../arm/swap.c) | ||||
set(TOBF16KERNEL ../x86_64/tobf16.c) | |||||
set(BF16TOKERNEL ../x86_64/bf16to.c) | |||||
endif () | endif () | ||||
endmacro () | endmacro () | ||||
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG; | |||||
#endif | #endif | ||||
#ifndef BFLOAT16 | #ifndef BFLOAT16 | ||||
typedef unsigned short bfloat16; | |||||
#include <stdint.h> | |||||
typedef uint16_t bfloat16; | |||||
#define HALFCONVERSION 1 | #define HALFCONVERSION 1 | ||||
#endif | #endif | ||||
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, float *, blasint *); | |||||
double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *); | double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *); | ||||
xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *); | xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *); | ||||
float BLASFUNC(shdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *); | |||||
void BLASFUNC(shstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *); | |||||
void BLASFUNC(shdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *); | |||||
void BLASFUNC(sbf16tos) (blasint *, bfloat16 *, blasint *, float *, blasint *); | |||||
void BLASFUNC(dbf16tod) (blasint *, bfloat16 *, blasint *, double *, blasint *); | |||||
#ifdef RETURN_BY_STRUCT | #ifdef RETURN_BY_STRUCT | ||||
typedef struct { | typedef struct { | ||||
@@ -46,6 +46,12 @@ float sdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||||
double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); | double dsdot_k(BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG); | double ddot_k(BLASLONG, double *, BLASLONG, double *, BLASLONG); | ||||
xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); | xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); | ||||
float shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||||
void shstobf16_k(BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); | |||||
void shdtobf16_k(BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); | |||||
void sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||||
void dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); | |||||
openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | openblas_complex_float cdotc_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | openblas_complex_float cdotu_k (BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
@@ -646,6 +646,11 @@ | |||||
#elif defined(HALF) | #elif defined(HALF) | ||||
#define D_TO_BF16_K SHDTOBF16_K | |||||
#define D_BF16_TO_K DBF16TOD_K | |||||
#define S_TO_BF16_K SHSTOBF16_K | |||||
#define S_BF16_TO_K SBF16TOS_K | |||||
#define AMAX_K SAMAX_K | #define AMAX_K SAMAX_K | ||||
#define AMIN_K SAMIN_K | #define AMIN_K SAMIN_K | ||||
#define MAX_K SMAX_K | #define MAX_K SMAX_K | ||||
@@ -657,6 +662,7 @@ | |||||
#define ASUM_K SASUM_K | #define ASUM_K SASUM_K | ||||
#define DOTU_K SDOTU_K | #define DOTU_K SDOTU_K | ||||
#define DOTC_K SDOTC_K | #define DOTC_K SDOTC_K | ||||
#define BF16_DOT_K SHDOT_K | |||||
#define AXPYU_K SAXPYU_K | #define AXPYU_K SAXPYU_K | ||||
#define AXPYC_K SAXPYC_K | #define AXPYC_K SAXPYC_K | ||||
#define AXPBY_K SAXPBY_K | #define AXPBY_K SAXPBY_K | ||||
@@ -51,6 +51,11 @@ typedef struct { | |||||
int shgemm_p, shgemm_q, shgemm_r; | int shgemm_p, shgemm_q, shgemm_r; | ||||
int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; | int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn; | ||||
void (*shstobf16_k) (BLASLONG, float *, BLASLONG, bfloat16 *, BLASLONG); | |||||
void (*shdtobf16_k) (BLASLONG, double *, BLASLONG, bfloat16 *, BLASLONG); | |||||
void (*sbf16tos_k) (BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); | |||||
void (*dbf16tod_k) (BLASLONG, bfloat16 *, BLASLONG, double *, BLASLONG); | |||||
float (*shamax_k) (BLASLONG, float *, BLASLONG); | float (*shamax_k) (BLASLONG, float *, BLASLONG); | ||||
float (*shamin_k) (BLASLONG, float *, BLASLONG); | float (*shamin_k) (BLASLONG, float *, BLASLONG); | ||||
float (*shmax_k) (BLASLONG, float *, BLASLONG); | float (*shmax_k) (BLASLONG, float *, BLASLONG); | ||||
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG); | |||||
float (*shasum_k) (BLASLONG, float *, BLASLONG); | float (*shasum_k) (BLASLONG, float *, BLASLONG); | ||||
float (*shsum_k) (BLASLONG, float *, BLASLONG); | float (*shsum_k) (BLASLONG, float *, BLASLONG); | ||||
int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | int (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
float (*shdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||||
float (*shdot_k) (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); | |||||
double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); | int (*shrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); | ||||
@@ -3,6 +3,12 @@ | |||||
#ifndef DYNAMIC_ARCH | #ifndef DYNAMIC_ARCH | ||||
#define SHDOT_K shdot_k | |||||
#define SHSTOBF16_K shstobf16_k | |||||
#define SHDTOBF16_K shdtobf16_k | |||||
#define SBF16TOS_K sbf16tos_k | |||||
#define DBF16TOD_K dbf16tod_k | |||||
#define SHGEMM_ONCOPY shgemm_oncopy | #define SHGEMM_ONCOPY shgemm_oncopy | ||||
#define SHGEMM_OTCOPY shgemm_otcopy | #define SHGEMM_OTCOPY shgemm_otcopy | ||||
@@ -18,6 +24,12 @@ | |||||
#else | #else | ||||
#define SHDOT_K gotoblas -> shdot_k | |||||
#define SHSTOBF16_K gotoblas -> shstobf16_k | |||||
#define SHDTOBF16_K gotoblas -> shdtobf16_k | |||||
#define SBF16TOS_K gotoblas -> sbf16tos_k | |||||
#define DBF16TOD_K gotoblas -> dbf16tod_k | |||||
#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy | #define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy | ||||
#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy | #define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy | ||||
#define SHGEMM_INCOPY gotoblas -> shgemm_incopy | #define SHGEMM_INCOPY gotoblas -> shgemm_incopy | ||||
@@ -59,12 +59,19 @@ extern int blas_omp_linked; | |||||
#define BLAS_PTHREAD 0x4000U | #define BLAS_PTHREAD 0x4000U | ||||
#define BLAS_NODE 0x2000U | #define BLAS_NODE 0x2000U | ||||
#define BLAS_PREC 0x0003U | |||||
#define BLAS_SINGLE 0x0000U | |||||
#define BLAS_DOUBLE 0x0001U | |||||
#define BLAS_XDOUBLE 0x0002U | |||||
#define BLAS_REAL 0x0000U | |||||
#define BLAS_COMPLEX 0x0004U | |||||
#define BLAS_PREC 0x000FU | |||||
#define BLAS_INT8 0x0000U | |||||
#define BLAS_BFLOAT16 0x0001U | |||||
#define BLAS_SINGLE 0x0002U | |||||
#define BLAS_DOUBLE 0x0003U | |||||
#define BLAS_XDOUBLE 0x0004U | |||||
#define BLAS_STOBF16 0x0008U | |||||
#define BLAS_DTOBF16 0x0009U | |||||
#define BLAS_BF16TOS 0x000AU | |||||
#define BLAS_BF16TOD 0x000BU | |||||
#define BLAS_REAL 0x0000U | |||||
#define BLAS_COMPLEX 0x1000U | |||||
#define BLAS_TRANSA 0x0030U /* 2bit */ | #define BLAS_TRANSA 0x0030U /* 2bit */ | ||||
#define BLAS_TRANSA_N 0x0000U | #define BLAS_TRANSA_N 0x0000U | ||||
@@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){ | |||||
#endif | #endif | ||||
} | } | ||||
/*
 * Run the CPUID instruction with `op` loaded into EAX and `count` loaded into
 * ECX, and store the resulting EAX/EBX/ECX/EDX values through the four output
 * pointers.
 *
 *   op    - value placed in EAX before CPUID executes
 *   count - value placed in ECX before CPUID executes
 *   eax, ebx, ecx, edx - receive the register contents after CPUID
 *
 * Three implementations are selected at compile time: the MSVC intrinsic,
 * an i386 position-independent-code variant, and the plain inline-asm form.
 */
static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx) | |||||
{ | |||||
#ifdef C_MSVC | |||||
/* MSVC path: __cpuidex fills cpuInfo[0..3], which are copied out as eax..edx. */
int cpuInfo[4] = {-1}; | |||||
__cpuidex(cpuInfo, op, count); | |||||
*eax = cpuInfo[0]; | |||||
*ebx = cpuInfo[1]; | |||||
*ecx = cpuInfo[2]; | |||||
*edx = cpuInfo[3]; | |||||
#else | |||||
#if defined(__i386__) && defined(__PIC__) | |||||
/*
 * i386 PIC path: EBX is copied to EDI before CPUID and swapped back afterwards,
 * so EBX keeps its original value and the EBX result is read from EDI ("=D").
 * NOTE(review): presumably needed because EBX is reserved by the compiler in
 * i386 PIC builds -- confirm against the toolchain documentation.
 */
__asm__ __volatile__ | |||||
("mov %%ebx, %%edi;" | |||||
"cpuid;" | |||||
"xchgl %%ebx, %%edi;" | |||||
: "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); | |||||
#else | |||||
/* Default path: inputs are tied to output operands 0 (EAX) and 2 (ECX). */
__asm__ __volatile__ | |||||
("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc"); | |||||
#endif | |||||
#endif | |||||
} | |||||
/* | /* | ||||
#define WHEREAMI | #define WHEREAMI | ||||
*/ | */ | ||||
@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha | |||||
blas_arg_t args [MAX_CPU_NUMBER]; | blas_arg_t args [MAX_CPU_NUMBER]; | ||||
BLASLONG i, width, astride, bstride; | BLASLONG i, width, astride, bstride; | ||||
int num_cpu, calc_type; | |||||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; | |||||
int num_cpu, calc_type_a, calc_type_b; | |||||
switch (mode & BLAS_PREC) { | |||||
case BLAS_INT8 : | |||||
case BLAS_BFLOAT16: | |||||
case BLAS_SINGLE : | |||||
case BLAS_DOUBLE : | |||||
case BLAS_XDOUBLE : | |||||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_STOBF16 : | |||||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_DTOBF16 : | |||||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_BF16TOS : | |||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_BF16TOD : | |||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
default: | |||||
calc_type_a = calc_type_b = 0; | |||||
break; | |||||
} | |||||
mode |= BLAS_LEGACY; | mode |= BLAS_LEGACY; | ||||
@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha | |||||
bstride = width; | bstride = width; | ||||
} | } | ||||
astride <<= calc_type; | |||||
bstride <<= calc_type; | |||||
astride <<= calc_type_a; | |||||
bstride <<= calc_type_b; | |||||
args[num_cpu].m = width; | args[num_cpu].m = width; | ||||
args[num_cpu].n = n; | args[num_cpu].n = n; | ||||
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL | |||||
blas_arg_t args [MAX_CPU_NUMBER]; | blas_arg_t args [MAX_CPU_NUMBER]; | ||||
BLASLONG i, width, astride, bstride; | BLASLONG i, width, astride, bstride; | ||||
int num_cpu, calc_type; | |||||
calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2; | |||||
int num_cpu, calc_type_a, calc_type_b; | |||||
switch (mode & BLAS_PREC) { | |||||
case BLAS_INT8 : | |||||
case BLAS_BFLOAT16: | |||||
case BLAS_SINGLE : | |||||
case BLAS_DOUBLE : | |||||
case BLAS_XDOUBLE : | |||||
calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_STOBF16 : | |||||
calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_DTOBF16 : | |||||
calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_BF16TOS : | |||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
case BLAS_BF16TOD : | |||||
calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0); | |||||
calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0); | |||||
break; | |||||
default: | |||||
calc_type_a = calc_type_b = 0; | |||||
break; | |||||
} | |||||
mode |= BLAS_LEGACY; | mode |= BLAS_LEGACY; | ||||
@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL | |||||
bstride = width; | bstride = width; | ||||
} | } | ||||
astride <<= calc_type; | |||||
bstride <<= calc_type; | |||||
astride <<= calc_type_a; | |||||
bstride <<= calc_type_b; | |||||
args[num_cpu].m = width; | args[num_cpu].m = width; | ||||
args[num_cpu].n = n; | args[num_cpu].n = n; | ||||
@@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
if (!(mode & BLAS_COMPLEX)){ | if (!(mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* REAL / Extended Double */ | /* REAL / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
/* REAL / Double */ | /* REAL / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
/* REAL / Single */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||||
float *, BLASLONG, float *, BLASLONG, | |||||
float *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((float *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||||
/* REAL / Single */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||||
float *, BLASLONG, float *, BLASLONG, | |||||
float *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((float *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
#ifdef BUILD_HALF | |||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||||
/* REAL / BFLOAT16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||||
bfloat16 *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((bfloat16 *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||||
/* REAL / BLAS_STOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||||
float *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((float *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||||
/* REAL / BLAS_DTOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||||
double *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((double *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
#endif | |||||
} else { | |||||
/* REAL / Other types in future */ | |||||
} | } | ||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* COMPLEX / Extended Double */ | /* COMPLEX / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE) { | |||||
/* COMPLEX / Double */ | /* COMPLEX / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
/* COMPLEX / Single */ | /* COMPLEX / Single */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | ||||
float *, BLASLONG, float *, BLASLONG, | float *, BLASLONG, float *, BLASLONG, | ||||
@@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} | |||||
} else { | |||||
/* COMPLEX / Other types in future */ | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -414,33 +453,37 @@ blas_queue_t *tscq; | |||||
if (sb == NULL) { | if (sb == NULL) { | ||||
if (!(queue -> mode & BLAS_COMPLEX)){ | if (!(queue -> mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} | |||||
} else { | |||||
/* Other types in future */ | |||||
} | |||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} | |||||
} else { | |||||
/* Other types in future */ | |||||
} | |||||
} | } | ||||
queue->sb=sb; | queue->sb=sb; | ||||
} | } | ||||
@@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
if (!(mode & BLAS_COMPLEX)){ | if (!(mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* REAL / Extended Double */ | /* REAL / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
/* REAL / Double */ | /* REAL / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||||
/* REAL / Single */ | /* REAL / Single */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | ||||
float *, BLASLONG, float *, BLASLONG, | float *, BLASLONG, float *, BLASLONG, | ||||
@@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
#ifdef BUILD_HALF | |||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||||
/* REAL / BFLOAT16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||||
bfloat16 *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((bfloat16 *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||||
/* REAL / BLAS_STOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||||
float *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((float *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||||
/* REAL / BLAS_DTOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||||
double *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((double *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
#endif | |||||
} else { | |||||
/* REAL / Other types in future */ | |||||
} | } | ||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* COMPLEX / Extended Double */ | /* COMPLEX / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
/* COMPLEX / Double */ | /* COMPLEX / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||||
/* COMPLEX / Single */ | /* COMPLEX / Single */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | ||||
float *, BLASLONG, float *, BLASLONG, | float *, BLASLONG, float *, BLASLONG, | ||||
@@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} | |||||
} | |||||
} else { | |||||
/* COMPLEX / Other types in future */ | |||||
} | |||||
} | |||||
} | } | ||||
static void exec_threads(blas_queue_t *queue, int buf_index){ | static void exec_threads(blas_queue_t *queue, int buf_index){ | ||||
@@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){ | |||||
if (sb == NULL) { | if (sb == NULL) { | ||||
if (!(queue -> mode & BLAS_COMPLEX)){ | if (!(queue -> mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
/* Other types in future */ | |||||
} | } | ||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
/* Other types in future */ | |||||
} | } | ||||
} | } | ||||
queue->sb=sb; | queue->sb=sb; | ||||
@@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
if (!(mode & BLAS_COMPLEX)){ | if (!(mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* REAL / Extended Double */ | /* REAL / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
/* REAL / Double */ | /* REAL / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE){ | |||||
/* REAL / Single */ | /* REAL / Single */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | ||||
float *, BLASLONG, float *, BLASLONG, | float *, BLASLONG, float *, BLASLONG, | ||||
@@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
#ifdef BUILD_HALF | |||||
} else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){ | |||||
/* REAL / BFLOAT16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16, | |||||
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, | |||||
bfloat16 *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((bfloat16 *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_STOBF16){ | |||||
/* REAL / BLAS_STOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, | |||||
float *, BLASLONG, bfloat16 *, BLASLONG, | |||||
float *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((float *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
} else if ((mode & BLAS_PREC) == BLAS_DTOBF16){ | |||||
/* REAL / BLAS_DTOBF16 */ | |||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, | |||||
double *, BLASLONG, bfloat16 *, BLASLONG, | |||||
double *, BLASLONG, void *) = func; | |||||
afunc(args -> m, args -> n, args -> k, | |||||
((double *)args -> alpha)[0], | |||||
args -> a, args -> lda, | |||||
args -> b, args -> ldb, | |||||
args -> c, args -> ldc, sb); | |||||
#endif | |||||
} else { | |||||
/* REAL / Other types in future */ | |||||
} | } | ||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (mode & BLAS_XDOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
/* COMPLEX / Extended Double */ | /* COMPLEX / Extended Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, | ||||
xdouble *, BLASLONG, xdouble *, BLASLONG, | xdouble *, BLASLONG, xdouble *, BLASLONG, | ||||
@@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (mode & BLAS_DOUBLE){ | |||||
if ((mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
/* COMPLEX / Double */ | /* COMPLEX / Double */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double, | ||||
double *, BLASLONG, double *, BLASLONG, | double *, BLASLONG, double *, BLASLONG, | ||||
@@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} else { | |||||
} else if ((mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
/* COMPLEX / Single */ | /* COMPLEX / Single */ | ||||
void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float, | ||||
float *, BLASLONG, float *, BLASLONG, | float *, BLASLONG, float *, BLASLONG, | ||||
@@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){ | |||||
args -> a, args -> lda, | args -> a, args -> lda, | ||||
args -> b, args -> ldb, | args -> b, args -> ldb, | ||||
args -> c, args -> ldc, sb); | args -> c, args -> ldc, sb); | ||||
} | |||||
} else { | |||||
/* COMPLEX / Other types in future */ | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){ | |||||
if (sb == NULL) { | if (sb == NULL) { | ||||
if (!(queue -> mode & BLAS_COMPLEX)){ | if (!(queue -> mode & BLAS_COMPLEX)){ | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
/* Other types in future */ | |||||
} | } | ||||
} else { | } else { | ||||
#ifdef EXPRECISION | #ifdef EXPRECISION | ||||
if (queue -> mode & BLAS_XDOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else | } else | ||||
#endif | #endif | ||||
if (queue -> mode & BLAS_DOUBLE){ | |||||
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){ | |||||
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) { | |||||
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float) | ||||
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B); | ||||
} else { | |||||
/* Other types in future */ | |||||
} | } | ||||
} | } | ||||
queue->sb=sb; | queue->sb=sb; | ||||
@@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX; | |||||
#else | #else | ||||
#define gotoblas_SKYLAKEX gotoblas_PRESCOTT | #define gotoblas_SKYLAKEX gotoblas_PRESCOTT | ||||
#endif | #endif | ||||
#ifdef DYN_COOPERLAKE | |||||
extern gotoblas_t gotoblas_COOPERLAKE; | |||||
#elif defined(DYN_SKYLAKEX) | |||||
#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX | |||||
#elif defined(DYN_HASWELL) | |||||
#define gotoblas_COOPERLAKE gotoblas_HASWELL | |||||
#elif defined(DYN_SANDYBRIDGE) | |||||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE | |||||
#elif defined(DYN_NEHALEM) | |||||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM | |||||
#else | |||||
#define gotoblas_COOPERLAKE gotoblas_PRESCOTT | |||||
#endif | |||||
#else // not DYNAMIC_LIST | #else // not DYNAMIC_LIST | ||||
@@ -247,14 +260,17 @@ extern gotoblas_t gotoblas_EXCAVATOR; | |||||
#ifdef NO_AVX2 | #ifdef NO_AVX2 | ||||
#define gotoblas_HASWELL gotoblas_SANDYBRIDGE | #define gotoblas_HASWELL gotoblas_SANDYBRIDGE | ||||
#define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE | #define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE | ||||
#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE | |||||
#define gotoblas_ZEN gotoblas_SANDYBRIDGE | #define gotoblas_ZEN gotoblas_SANDYBRIDGE | ||||
#else | #else | ||||
extern gotoblas_t gotoblas_HASWELL; | extern gotoblas_t gotoblas_HASWELL; | ||||
extern gotoblas_t gotoblas_ZEN; | extern gotoblas_t gotoblas_ZEN; | ||||
#ifndef NO_AVX512 | #ifndef NO_AVX512 | ||||
extern gotoblas_t gotoblas_SKYLAKEX; | extern gotoblas_t gotoblas_SKYLAKEX; | ||||
extern gotoblas_t gotoblas_COOPERLAKE; | |||||
#else | #else | ||||
#define gotoblas_SKYLAKEX gotoblas_HASWELL | #define gotoblas_SKYLAKEX gotoblas_HASWELL | ||||
#define gotoblas_COOPERLAKE gotoblas_HASWELL | |||||
#endif | #endif | ||||
#endif | #endif | ||||
#else | #else | ||||
@@ -262,6 +278,7 @@ extern gotoblas_t gotoblas_SKYLAKEX; | |||||
#define gotoblas_SANDYBRIDGE gotoblas_NEHALEM | #define gotoblas_SANDYBRIDGE gotoblas_NEHALEM | ||||
#define gotoblas_HASWELL gotoblas_NEHALEM | #define gotoblas_HASWELL gotoblas_NEHALEM | ||||
#define gotoblas_SKYLAKEX gotoblas_NEHALEM | #define gotoblas_SKYLAKEX gotoblas_NEHALEM | ||||
#define gotoblas_COOPERLAKE gotoblas_NEHALEM | |||||
#define gotoblas_BULLDOZER gotoblas_BARCELONA | #define gotoblas_BULLDOZER gotoblas_BARCELONA | ||||
#define gotoblas_PILEDRIVER gotoblas_BARCELONA | #define gotoblas_PILEDRIVER gotoblas_BARCELONA | ||||
#define gotoblas_STEAMROLLER gotoblas_BARCELONA | #define gotoblas_STEAMROLLER gotoblas_BARCELONA | ||||
@@ -343,6 +360,23 @@ int support_avx512(){ | |||||
#endif | #endif | ||||
} | } | ||||
/*
 * Runtime probe for the AVX512_BF16 extension.
 *
 * Returns 1 only when support_avx512() succeeds and bit 5 of EAX from
 * CPUID leaf 7, sub-leaf 1 is set.  Returns 0 otherwise, and always
 * returns 0 when the build defines NO_AVX or NO_AVX512.
 */
int support_avx512_bf16(){
#if !defined(NO_AVX) && !defined(NO_AVX512)
  int eax, ebx, ecx, edx;

  /* Without baseline AVX512 support there is nothing further to probe. */
  if (!support_avx512()) {
    return 0;
  }

  cpuid_count(7, 1, &eax, &ebx, &ecx, &edx);

  /* CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 is supported. */
  return ((eax & 32) != 0) ? 1 : 0;
#else
  return 0;
#endif
}
extern void openblas_warning(int verbose, const char * msg); | extern void openblas_warning(int verbose, const char * msg); | ||||
#define FALLBACK_VERBOSE 1 | #define FALLBACK_VERBOSE 1 | ||||
#define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n" | #define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n" | ||||
@@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){ | |||||
return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels. | return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels. | ||||
} | } | ||||
} | } | ||||
if (model == 5) { | |||||
if (model == 5) { | |||||
// Intel Cooperlake | |||||
if(support_avx512_bf16()) | |||||
return &gotoblas_COOPERLAKE; | |||||
// Intel Skylake X | // Intel Skylake X | ||||
if (support_avx512()) | if (support_avx512()) | ||||
return &gotoblas_SKYLAKEX; | return &gotoblas_SKYLAKEX; | ||||
@@ -774,7 +811,8 @@ static char *corename[] = { | |||||
"Steamroller", | "Steamroller", | ||||
"Excavator", | "Excavator", | ||||
"Zen", | "Zen", | ||||
"SkylakeX" | |||||
"SkylakeX", | |||||
"Cooperlake" | |||||
}; | }; | ||||
char *gotoblas_corename(void) { | char *gotoblas_corename(void) { | ||||
@@ -838,6 +876,7 @@ char *gotoblas_corename(void) { | |||||
if (gotoblas == &gotoblas_EXCAVATOR) return corename[22]; | if (gotoblas == &gotoblas_EXCAVATOR) return corename[22]; | ||||
if (gotoblas == &gotoblas_ZEN) return corename[23]; | if (gotoblas == &gotoblas_ZEN) return corename[23]; | ||||
if (gotoblas == &gotoblas_SKYLAKEX) return corename[24]; | if (gotoblas == &gotoblas_SKYLAKEX) return corename[24]; | ||||
if (gotoblas == &gotoblas_COOPERLAKE) return corename[25]; | |||||
return corename[0]; | return corename[0]; | ||||
} | } | ||||
@@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){ | |||||
switch (found) | switch (found) | ||||
{ | { | ||||
case 25: return (&gotoblas_COOPERLAKE); | |||||
case 24: return (&gotoblas_SKYLAKEX); | case 24: return (&gotoblas_SKYLAKEX); | ||||
case 23: return (&gotoblas_ZEN); | case 23: return (&gotoblas_ZEN); | ||||
case 22: return (&gotoblas_EXCAVATOR); | case 22: return (&gotoblas_EXCAVATOR); | ||||
@@ -46,7 +46,7 @@ | |||||
ssum, dsum, scsum, dzsum | ssum, dsum, scsum, dzsum | ||||
); | ); | ||||
@halfblasobjs = (shgemm); | |||||
@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod); | |||||
@cblasobjs = ( | @cblasobjs = ( | ||||
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, | cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, | ||||
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, | cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, | ||||
@@ -84,7 +84,7 @@ | |||||
cblas_xerbla | cblas_xerbla | ||||
); | ); | ||||
@halfcblasobjs = (cblas_shgemm); | |||||
@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod); | |||||
@exblasobjs = ( | @exblasobjs = ( | ||||
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, | qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, | ||||
@@ -47,7 +47,9 @@ SBLAS3OBJS = \ | |||||
sgeadd.$(SUFFIX) | sgeadd.$(SUFFIX) | ||||
ifeq ($(BUILD_HALF),1) | ifeq ($(BUILD_HALF),1) | ||||
SHBLAS1OBJS = shdot.$(SUFFIX) | |||||
SHBLAS3OBJS = shgemm.$(SUFFIX) | SHBLAS3OBJS = shgemm.$(SUFFIX) | ||||
SHEXTOBJS = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | |||||
endif | endif | ||||
DBLAS1OBJS = \ | DBLAS1OBJS = \ | ||||
@@ -281,7 +283,9 @@ CSBLAS3OBJS = \ | |||||
cblas_sgeadd.$(SUFFIX) | cblas_sgeadd.$(SUFFIX) | ||||
ifeq ($(BUILD_HALF),1) | ifeq ($(BUILD_HALF),1) | ||||
CSHBLAS1OBJS = cblas_shdot.$(SUFFIX) | |||||
CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) | CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX) | ||||
CSHEXTOBJS = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | |||||
endif | endif | ||||
CDBLAS1OBJS = \ | CDBLAS1OBJS = \ | ||||
@@ -374,6 +378,7 @@ override CFLAGS += -I. | |||||
SBLAS1OBJS += $(CSBLAS1OBJS) | SBLAS1OBJS += $(CSBLAS1OBJS) | ||||
SBLAS2OBJS += $(CSBLAS2OBJS) | SBLAS2OBJS += $(CSBLAS2OBJS) | ||||
SBLAS3OBJS += $(CSBLAS3OBJS) | SBLAS3OBJS += $(CSBLAS3OBJS) | ||||
SHBLAS1OBJS += $(CSHBLAS1OBJS) | |||||
SHBLAS3OBJS += $(CSHBLAS3OBJS) | SHBLAS3OBJS += $(CSHBLAS3OBJS) | ||||
DBLAS1OBJS += $(CDBLAS1OBJS) | DBLAS1OBJS += $(CDBLAS1OBJS) | ||||
DBLAS2OBJS += $(CDBLAS2OBJS) | DBLAS2OBJS += $(CDBLAS2OBJS) | ||||
@@ -385,10 +390,11 @@ ZBLAS1OBJS += $(CZBLAS1OBJS) | |||||
ZBLAS2OBJS += $(CZBLAS2OBJS) | ZBLAS2OBJS += $(CZBLAS2OBJS) | ||||
ZBLAS3OBJS += $(CZBLAS3OBJS) | ZBLAS3OBJS += $(CZBLAS3OBJS) | ||||
SHEXTOBJS += $(CSHEXTOBJS) | |||||
endif | endif | ||||
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | ||||
SHBLASOBJS = $(SHBLAS3OBJS) | |||||
SHBLASOBJS = $(SHBLAS1OBJS) $(SHBLAS3OBJS) | |||||
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | ||||
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) | QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) | ||||
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) | CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) | ||||
@@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS) | |||||
endif | endif | ||||
FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||||
FUNCOBJS = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) | |||||
ifdef EXPRECISION | ifdef EXPRECISION | ||||
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) | FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS) | ||||
@@ -491,7 +497,7 @@ endif | |||||
clean :: | clean :: | ||||
@rm -f functable.h | @rm -f functable.h | ||||
level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) | |||||
# Archive all level-1 objects into the library.  The bfloat16 conversion
# objects live in SHEXTOBJS (the same variable FUNCOBJS uses), so that is
# the name that must be listed here.
level1 : $(SHEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
	$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | ||||
@@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c | |||||
dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c | dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c | ||||
$(CC) $(CFLAGS) -c $< -o $(@F) | $(CC) $(CFLAGS) -c $< -o $(@F) | ||||
ifeq ($(BUILD_HALF),1) | |||||
shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c | |||||
$(CC) $(CFLAGS) -c $< -o $(@F) | |||||
shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c | |||||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||||
shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c | |||||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||||
sbf16tos.$(SUFFIX) sbf16tos.$(PSUFFIX) : bf16to.c | |||||
$(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||||
dbf16tod.$(SUFFIX) dbf16tod.$(PSUFFIX) : bf16to.c | |||||
$(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||||
endif | |||||
sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c | sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c | ||||
$(CC) $(CFLAGS) -c $< -o $(@F) | $(CC) $(CFLAGS) -c $< -o $(@F) | ||||
@@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c | |||||
cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c | cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c | ||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | ||||
ifeq ($(BUILD_HALF),1) | |||||
cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c | |||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | |||||
cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c | |||||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||||
cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c | |||||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||||
cblas_sbf16tos.$(SUFFIX) cblas_sbf16tos.$(PSUFFIX) : bf16to.c | |||||
$(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F) | |||||
cblas_dbf16tod.$(SUFFIX) cblas_dbf16tod.$(PSUFFIX) : bf16to.c | |||||
$(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F) | |||||
endif | |||||
cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c | cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c | ||||
$(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) | ||||
@@ -0,0 +1,52 @@ | |||||
#include <stdio.h> | |||||
#include "common.h" | |||||
#ifdef FUNCTION_PROFILE | |||||
#include "functable.h" | |||||
#endif | |||||
#ifndef CBLAS | |||||
/*
 * By-reference (Fortran-style) entry point for the bfloat16 dot product.
 *
 * N, INCX and INCY are dereferenced once on entry.  The result of
 * BF16_DOT_K is returned as a float; when *N <= 0 the function returns 0
 * without calling the kernel.
 */
float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){
  BLASLONG n    = *N;
  BLASLONG incx = *INCX;
  BLASLONG incy = *INCY;
  float ret;

  PRINT_DEBUG_NAME;

  if (n < 1) {
    return 0.;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (incx < 0) {
    x = x - (n - 1) * incx;
  }
  if (incy < 0) {
    y = y - (n - 1) * incy;
  }

  ret = BF16_DOT_K(n, x, incx, y, incy);

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;

  return ret;
}
#else | |||||
/*
 * By-value (CBLAS-style) entry point for the bfloat16 dot product.
 *
 * Returns the float produced by BF16_DOT_K, or 0 without calling the
 * kernel when n <= 0.
 */
float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){
  float ret;

  PRINT_DEBUG_CNAME;

  if (n < 1) {
    return 0.;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (incx < 0) {
    x = x - (n - 1) * incx;
  }
  if (incy < 0) {
    y = y - (n - 1) * incy;
  }

  ret = BF16_DOT_K(n, x, incx, y, incy);

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;

  return ret;
}
#endif |
@@ -0,0 +1,62 @@ | |||||
#include <stdio.h> | |||||
#include "common.h" | |||||
#ifdef FUNCTION_PROFILE | |||||
#include "functable.h" | |||||
#endif | |||||
#if defined(DOUBLE_PREC) | |||||
#define FLOAT_TYPE double | |||||
#elif defined(SINGLE_PREC) | |||||
#define FLOAT_TYPE float | |||||
#else | |||||
#endif | |||||
#ifndef CBLAS | |||||
/*
 * By-reference (Fortran-style) entry point: bfloat16 input array `in`,
 * FLOAT_TYPE output array `out`.
 *
 * N, INC_IN and INC_OUT are dereferenced once on entry.  The work is
 * delegated to D_BF16_TO_K (DOUBLE_PREC) or S_BF16_TO_K (SINGLE_PREC).
 * Nothing is done when *N <= 0.
 */
void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){
  BLASLONG n       = *N;
  BLASLONG inc_in  = *INC_IN;
  BLASLONG inc_out = *INC_OUT;

  PRINT_DEBUG_NAME;

  if (n < 1) {
    return;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (inc_in < 0) {
    in = in - (n - 1) * inc_in;
  }
  if (inc_out < 0) {
    out = out - (n - 1) * inc_out;
  }

#if defined(DOUBLE_PREC)
  D_BF16_TO_K(n, in, inc_in, out, inc_out);
#elif defined(SINGLE_PREC)
  S_BF16_TO_K(n, in, inc_in, out, inc_out);
#endif

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;
}
#else | |||||
/*
 * By-value (CBLAS-style) entry point: bfloat16 input array `in`,
 * FLOAT_TYPE output array `out`.
 *
 * The work is delegated to D_BF16_TO_K (DOUBLE_PREC) or S_BF16_TO_K
 * (SINGLE_PREC).  Nothing is done when n <= 0.
 */
void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){
  PRINT_DEBUG_CNAME;

  if (n < 1) {
    return;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (inc_in < 0) {
    in = in - (n - 1) * inc_in;
  }
  if (inc_out < 0) {
    out = out - (n - 1) * inc_out;
  }

#if defined(DOUBLE_PREC)
  D_BF16_TO_K(n, in, inc_in, out, inc_out);
#elif defined(SINGLE_PREC)
  S_BF16_TO_K(n, in, inc_in, out, inc_out);
#endif

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;
}
#endif |
@@ -0,0 +1,61 @@ | |||||
#include <stdio.h> | |||||
#include "common.h" | |||||
#ifdef FUNCTION_PROFILE | |||||
#include "functable.h" | |||||
#endif | |||||
#if defined(DOUBLE_PREC) | |||||
#define FLOAT_TYPE double | |||||
#elif defined(SINGLE_PREC) | |||||
#define FLOAT_TYPE float | |||||
#else | |||||
#endif | |||||
#ifndef CBLAS | |||||
/*
 * By-reference (Fortran-style) entry point: FLOAT_TYPE input array `in`,
 * bfloat16 output array `out`.
 *
 * N, INC_IN and INC_OUT are dereferenced once on entry.  The work is
 * delegated to D_TO_BF16_K (DOUBLE_PREC) or S_TO_BF16_K (SINGLE_PREC).
 * Nothing is done when *N <= 0.
 */
void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){
  BLASLONG n       = *N;
  BLASLONG inc_in  = *INC_IN;
  BLASLONG inc_out = *INC_OUT;

  PRINT_DEBUG_NAME;

  if (n < 1) {
    return;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (inc_in < 0) {
    in = in - (n - 1) * inc_in;
  }
  if (inc_out < 0) {
    out = out - (n - 1) * inc_out;
  }

#if defined(DOUBLE_PREC)
  D_TO_BF16_K(n, in, inc_in, out, inc_out);
#elif defined(SINGLE_PREC)
  S_TO_BF16_K(n, in, inc_in, out, inc_out);
#endif

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;
}
#else | |||||
/*
 * By-value (CBLAS-style) entry point: FLOAT_TYPE input array `in`,
 * bfloat16 output array `out`.
 *
 * The work is delegated to D_TO_BF16_K (DOUBLE_PREC) or S_TO_BF16_K
 * (SINGLE_PREC).  Nothing is done when n <= 0.
 */
void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){
  PRINT_DEBUG_CNAME;

  if (n < 1) {
    return;
  }

  IDEBUG_START;

  FUNCTION_PROFILE_START();

  /* For a negative stride, move the pointer by (n - 1) * |inc| elements
     before handing it to the kernel. */
  if (inc_in < 0) {
    in = in - (n - 1) * inc_in;
  }
  if (inc_out < 0) {
    out = out - (n - 1) * inc_out;
  }

#if defined(DOUBLE_PREC)
  D_TO_BF16_K(n, in, inc_in, out, inc_out);
#elif defined(SINGLE_PREC)
  S_TO_BF16_K(n, in, inc_in, out, inc_out);
#endif

  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);

  IDEBUG_END;
}
#endif |
@@ -262,6 +262,20 @@ ifndef XDOTKERNEL | |||||
XDOTKERNEL = zdot.S | XDOTKERNEL = zdot.S | ||||
endif | endif | ||||
ifeq ($(BUILD_HALF),1) | |||||
ifndef SHDOTKERNEL | |||||
SHDOTKERNEL = ../x86_64/shdot.c | |||||
endif | |||||
ifndef TOBF16KERNEL | |||||
TOBF16KERNEL = ../x86_64/tobf16.c | |||||
endif | |||||
ifndef BF16TOKERNEL | |||||
BF16TOKERNEL = ../x86_64/bf16to.c | |||||
endif | |||||
endif | |||||
### NRM2 ### | ### NRM2 ### | ||||
ifndef SNRM2KERNEL | ifndef SNRM2KERNEL | ||||
@@ -516,6 +530,15 @@ XBLASOBJS += \ | |||||
xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \ | xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \ | ||||
xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX) | xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX) | ||||
ifeq ($(BUILD_HALF),1) | |||||
SHBLASOBJS += \ | |||||
shdot_k$(TSUFFIX).$(SUFFIX) | |||||
SHEXTOBJS += \ | |||||
shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX) | |||||
SHEXTOBJS += \ | |||||
sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX) | |||||
endif | |||||
### AMAX ### | ### AMAX ### | ||||
@@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL | |||||
$(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL) | $(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL) | ||||
$(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@ | $(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@ | ||||
ifeq ($(BUILD_HALF),1)
# bfloat16 kernels.  Every rule lists both the regular object and its
# profile counterpart ($(TPSUFFIX).$(PSUFFIX)), matching the other dot
# kernel rules in this file; the recipe writes to $@, so one recipe
# serves both targets.
$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL)
	$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
# TOBF16KERNEL is built twice: once with SINGLE, once with DOUBLE.
$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) $(KDIR)shstobf16_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
	$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdtobf16_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
	$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
# BF16TOKERNEL is built twice: once with SINGLE, once with DOUBLE.
$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) $(KDIR)sbf16tos_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
	$(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) $(KDIR)dbf16tod_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
	$(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
endif
$(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL) | $(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL) | ||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@ | $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@ | ||||
@@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = { | |||||
MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), | MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N), | ||||
#endif | #endif | ||||
shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS, | |||||
samax_kTS, samin_kTS, smax_kTS, smin_kTS, | samax_kTS, samin_kTS, smax_kTS, smin_kTS, | ||||
isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, | isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, | ||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS, | |||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS, | |||||
dsdot_kTS, | dsdot_kTS, | ||||
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, | srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, | ||||
sgemv_nTS, sgemv_tTS, sger_kTS, | sgemv_nTS, sgemv_tTS, sger_kTS, | ||||
@@ -146,6 +146,18 @@ ifndef XDOTKERNEL | |||||
XDOTKERNEL = zdot.S | XDOTKERNEL = zdot.S | ||||
endif | endif | ||||
ifndef SHDOTKERNEL | |||||
SHDOTKERNEL = shdot.c | |||||
endif | |||||
ifndef TOBF16KERNEL | |||||
TOBF16KERNEL = tobf16.c | |||||
endif | |||||
ifndef BF16TOKERNEL | |||||
BF16TOKERNEL = bf16to.c | |||||
endif | |||||
ifndef ISAMAXKERNEL | ifndef ISAMAXKERNEL | ||||
ISAMAXKERNEL = iamax_sse.S | ISAMAXKERNEL = iamax_sse.S | ||||
endif | endif | ||||
@@ -0,0 +1,114 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include <stddef.h> | |||||
#include "common.h" | |||||
#if defined(DOUBLE) | |||||
#define FLOAT_TYPE double | |||||
#elif defined(SINGLE) | |||||
#define FLOAT_TYPE float | |||||
#else | |||||
#endif | |||||
/* Notes for algorithm: | |||||
* - Input denormal treated as zero | |||||
* - Force to be QNAN | |||||
*/ | |||||
/*
 * Widen n bfloat16 values into FLOAT_TYPE values.
 *
 *   n        number of elements to convert
 *   in       source bfloat16 array, read at in[0], in[inc_in], in[2*inc_in], ...
 *   inc_in   stride (in elements) between consecutive source values
 *   out      destination array, written at out[0], out[inc_out], ...
 *   inc_out  stride (in elements) between consecutive destination values
 *
 * A bfloat16 value is the upper 16 bits of an IEEE-754 binary32 value, so the
 * conversion places the (possibly adjusted) 16 bits in the high half of a
 * 32-bit word whose low half is zero:
 *   - zero exponent field (zero or denormal input) -> signed zero
 *   - all-ones exponent with non-zero mantissa (NaN) -> quiet bit forced on
 *   - everything else (infinity, normal numbers)     -> copied unchanged
 *
 * The bit pattern is assembled in an integer and reinterpreted through a
 * union, so the routine neither depends on the host byte order nor accesses
 * a float object through an incompatible pointer type.
 */
static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
{
    BLASLONG index_in = 0;
    BLASLONG index_out = 0;
    BLASLONG index = 0;
    union {
        float value;
        uint32_t bits;
    } widened;

    while (index < n) {
        const uint16_t raw = *(in + index_in);
        uint16_t upper;

        /* Classify on the sign bit plus the 8 exponent bits */
        switch (raw & 0xff80u) {
        case (0x0000u):  /* Type 1: Positive zero or denormal */
            upper = 0x0000u;
            break;
        case (0x8000u):  /* Type 2: Negative zero or denormal */
            upper = 0x8000u;
            break;
        case (0x7f80u):  /* Type 3: Positive infinity or NAN */
        case (0xff80u):  /* Type 4: Negative infinity or NAN */
            upper = raw;
            /* A non-zero mantissa means NAN: force it to be QNAN */
            if ((raw & 0x007fu) != 0) {
                upper |= 0x0040u;
            }
            break;
        default:         /* Type 5: Normal case */
            upper = raw;
            break;
        }

        widened.bits = ((uint32_t)upper) << 16;
        *(out + index_out) = (FLOAT_TYPE)widened.value;

        index_in += inc_in;
        index_out += inc_out;
        index++;
    }
}
/*
 * Kernel entry point for the bfloat16 -> FLOAT_TYPE conversion.
 * Does nothing when n is zero or negative; otherwise hands the whole
 * request to bf16to_kernel_1().
 */
void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
{
    if (n > 0) {
        bf16to_kernel_1(n, in, inc_in, out, inc_out);
    }
}
@@ -0,0 +1,104 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
/* need a new enough GCC for avx512 support */ | |||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||||
#define HAVE_TOBF16_ACCL_KERNEL 1 | |||||
#include "common.h" | |||||
#include <immintrin.h> | |||||
/* Load 16 consecutive doubles starting at a 64-byte aligned address and
 * narrow them into one vector of 16 floats: src[0..7] end up in the lower
 * 256 bits, src[8..15] in the upper 256 bits. */
static inline __m512 dtobf16_load16_pd_as_ps(const double * src)
{
    return _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&src[0]))),
                              _mm512_cvtpd_ps(_mm512_load_pd(&src[8])), 1);
}

/*
 * AVX512-BF16 conversion of n contiguous doubles (in) to bfloat16 (out).
 *
 * Each double is first narrowed to float and then converted to bfloat16 with
 * the BF16 conversion instructions.  The leading elements up to the next
 * 64-byte boundary of `in` are handled with a masked unaligned load; all
 * following loads are aligned loads.
 *
 * NOTE: as in the original code, `in` is assumed to be naturally (8-byte)
 * aligned; otherwise the aligned loads below are not valid.
 *
 * All element indices are BLASLONG so that counts which do not fit in an
 * int are not truncated.
 */
static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out)
{
    /* Number of doubles in front of the first 64-byte boundary (0..7) */
    int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7;

    if (n < align_header) {
        align_header = (int)n;
    }

    if (align_header != 0) {
        /* align_header is 1..7 here, so the shift count is 1..7 */
        const __mmask8 align_mask8 = (__mmask8)(0xffu >> (8 - align_header));
        __m512d a = _mm512_maskz_loadu_pd(align_mask8, &in[0]);
        _mm_mask_storeu_epi16(&out[0], align_mask8, (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a)));
    }

    if (n == align_header) {
        return;
    }

    n -= align_header;
    in += align_header;
    out += align_header;

    const BLASLONG tail_index_8 = n & ~(BLASLONG)7;
    const BLASLONG tail_index_32 = n & ~(BLASLONG)31;
    const BLASLONG tail_index_128 = n & ~(BLASLONG)127;

    /* Processing the main chunk with 128-elements per round */
    for (BLASLONG i = 0; i < tail_index_128; i += 128) {
        /* Fold 1 */
        __m512 data1_512_low = dtobf16_load16_pd_as_ps(&in[i + 0]);
        __m512 data1_512_high = dtobf16_load16_pd_as_ps(&in[i + 16]);
        _mm512_storeu_si512(&out[i + 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));

        /* Fold 2 */
        __m512 data2_512_low = dtobf16_load16_pd_as_ps(&in[i + 32]);
        __m512 data2_512_high = dtobf16_load16_pd_as_ps(&in[i + 48]);
        _mm512_storeu_si512(&out[i + 32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low));

        /* Fold 3 */
        __m512 data3_512_low = dtobf16_load16_pd_as_ps(&in[i + 64]);
        __m512 data3_512_high = dtobf16_load16_pd_as_ps(&in[i + 80]);
        _mm512_storeu_si512(&out[i + 64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low));

        /* Fold 4 */
        __m512 data4_512_low = dtobf16_load16_pd_as_ps(&in[i + 96]);
        __m512 data4_512_high = dtobf16_load16_pd_as_ps(&in[i + 112]);
        _mm512_storeu_si512(&out[i + 96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low));
    }

    /* Processing the remaining <128 chunk with 32-elements per round */
    for (BLASLONG j = tail_index_128; j < tail_index_32; j += 32) {
        __m512 data1_512_low = dtobf16_load16_pd_as_ps(&in[j + 0]);
        __m512 data1_512_high = dtobf16_load16_pd_as_ps(&in[j + 16]);
        _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
    }

    /* Processing the remaining <32 chunk with 8-elements per round */
    for (BLASLONG j = tail_index_32; j < tail_index_8; j += 8) {
        _mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j]))));
    }

    /* Processing the remaining <8 chunk with masked processing */
    if ((n & 7) > 0) {
        const __mmask8 tail_mask8 = (__mmask8)(0xffu >> (8 - (n & 7)));
        __m512d data_512 = _mm512_maskz_load_pd(tail_mask8, &in[tail_index_8]);
        _mm_mask_storeu_epi16(&out[tail_index_8], tail_mask8, (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512)));
    }
}
#endif |
@@ -0,0 +1,115 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include "common.h" | |||||
#if defined(COOPERLAKE) | |||||
#include "shdot_microk_cooperlake.c" | |||||
#endif | |||||
/*
 * Compute the dot product of two bfloat16 vectors and return it as a float.
 *
 *   n       number of elements
 *   x, y    input vectors
 *   inc_x   stride (in elements) of x
 *   inc_y   stride (in elements) of y
 *
 * When an accelerated kernel is available and both vectors are contiguous,
 * it is used directly.  Otherwise both vectors are widened to float with
 * SBF16TOS_K and the single-precision dot kernel SDOTU_K does the work.
 *
 * The widened copies normally live in two heap buffers of n floats.  If
 * either allocation fails the routine no longer dereferences a NULL pointer:
 * it falls back to converting and accumulating the vectors piecewise through
 * two small stack buffers.
 */
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
{
    enum { SHDOT_FALLBACK_CHUNK = 256 };
    float d = 0.0;

#ifdef HAVE_SHDOT_ACCL_KERNEL
    if ((inc_x == 1) && (inc_y == 1)) {
        return shdot_accl_kernel(n, x, y);
    }
#endif

    float * x_fp32 = malloc(sizeof(float) * (size_t)n);
    float * y_fp32 = malloc(sizeof(float) * (size_t)n);

    if (x_fp32 && y_fp32) {
        SBF16TOS_K(n, x, inc_x, x_fp32, 1);
        SBF16TOS_K(n, y, inc_y, y_fp32, 1);

        d = SDOTU_K(n, x_fp32, 1, y_fp32, 1);
    } else {
        /* Out of memory: process the vectors in fixed-size pieces instead */
        float x_chunk[SHDOT_FALLBACK_CHUNK];
        float y_chunk[SHDOT_FALLBACK_CHUNK];
        BLASLONG done = 0;

        while (done < n) {
            BLASLONG len = n - done;
            if (len > SHDOT_FALLBACK_CHUNK) {
                len = SHDOT_FALLBACK_CHUNK;
            }

            SBF16TOS_K(len, x + done * inc_x, inc_x, x_chunk, 1);
            SBF16TOS_K(len, y + done * inc_y, inc_y, y_chunk, 1);

            d += SDOTU_K(len, x_chunk, 1, y_chunk, 1);
            done += len;
        }
    }

    /* free(NULL) is a no-op, so this is safe on the fallback path too */
    free(x_fp32);
    free(y_fp32);

    return d;
}
#if defined(SMP) | |||||
/*
 * Per-thread worker: computes the dot product of one slice of x and y and
 * stores it in the float that `result` points to.  The dummy parameters are
 * not used.  Always returns 0.
 */
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
                             bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
                             float *result, BLASLONG dummy3)
{
    const float partial = shdot_compute(n, x, inc_x, y, inc_y);

    result[0] = partial;
    return 0;
}
extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, | |||||
void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, | |||||
int (*function)(), int nthreads); | |||||
#endif | |||||
/*
 * Kernel entry point for the bfloat16 dot product.
 *
 * Returns 0.0 for n <= 0.  Without SMP the whole vector is handled by
 * shdot_compute().  With SMP the work is split across threads only when both
 * strides are non-zero and n exceeds the threshold of 40960 elements; the
 * thread count is additionally capped at round(n / 40960).  Each thread's
 * partial result is read from a slot of 2 * sizeof(double) bytes in the
 * result buffer, and the partial results are summed in thread order.
 */
float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
{
    float dot_result = 0.0;

    if (n <= 0) return 0.0;

#if defined(SMP)
    int thread_thres = 40960;
    bfloat16 dummy_alpha;
    int nthreads = 1;

    if (!(inc_x == 0 || inc_y == 0 || n <= thread_thres)) {
        nthreads = num_cpu_avail(1);
    }

    /* Never use more threads than there are threshold-sized pieces of work */
    int best_threads = (int) (n/(float)thread_thres + 0.5);
    if (nthreads > best_threads) {
        nthreads = best_threads;
    }

    if (nthreads > 1) {
        char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
        int mode = BLAS_BFLOAT16 | BLAS_REAL;

        blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
                                             x, inc_x, y, inc_y, thread_result, 0,
                                             (void *)shdot_thread_func, nthreads);

        /* One slot of 2 * sizeof(double) bytes per thread */
        for (int t = 0; t < nthreads; t++) {
            float * slot = (float *)(thread_result + (size_t)t * sizeof(double) * 2);
            dot_result += (*slot);
        }
    } else {
        dot_result = shdot_compute(n, x, inc_x, y, inc_y);
    }
#else
    dot_result = shdot_compute(n, x, inc_x, y, inc_y);
#endif

    return dot_result;
}
@@ -0,0 +1,159 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
/* need a new enough GCC for avx512 support */ | |||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||||
#define HAVE_SHDOT_ACCL_KERNEL 1 | |||||
#include "common.h" | |||||
#include <immintrin.h> | |||||
/* Accumulate the 8 elements that start at index (n & ~15) into accum.
 * Used for the "(n & 8) != 0" step of the short-vector paths. */
static inline __m128 shdot_accl_step8(__m128 accum, BLASLONG n, bfloat16 *x, bfloat16 *y)
{
    const BLASLONG tail_index_16 = n & ~(BLASLONG)15;

    return _mm_dpbf16_ps(accum,
                         (__m128bh) _mm_loadu_si128((const __m128i *)&x[tail_index_16]),
                         (__m128bh) _mm_loadu_si128((const __m128i *)&y[tail_index_16]));
}

/* Accumulate the last (n & 7) elements, which start at index (n & ~7), into
 * accum with a masked load.  Must only be called when (n & 7) != 0, so the
 * shift count below is always in 1..7. */
static inline __m128 shdot_accl_tail8(__m128 accum, BLASLONG n, bfloat16 *x, bfloat16 *y)
{
    const __mmask8 tail_mask = (__mmask8)(0xffu >> (8 - (n & 7)));
    const BLASLONG tail_index_8 = n & ~(BLASLONG)7;

    return _mm_dpbf16_ps(accum,
                         (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
                         (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
}

/*
 * AVX512-BF16 dot product of two contiguous bfloat16 vectors of length n
 * (n >= 1), returned as a float.
 *
 * The vector length selects one of five paths (>= 128, 32..127, 16..31,
 * 8..15, 1..7 elements); every path leaves four partial sums in accum128,
 * which are added together at the end.
 *
 * Compared with the previous revision:
 *   - the 32-bit tail mask is only computed when (n & 31) != 0, so the shift
 *     count can no longer reach 32 (undefined behaviour);
 *   - element indices are BLASLONG;
 *   - loads that take a typed vector pointer get an explicit pointer cast;
 *   - vector arithmetic and lane extraction use intrinsics instead of
 *     compiler vector extensions.
 */
static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
{
    __m128 accum128 = _mm_setzero_ps();

    if (n > 127) {                 /* n range from 128 to inf. */
        const BLASLONG tail_index_32 = n & ~(BLASLONG)31;
        const BLASLONG tail_index_128 = n & ~(BLASLONG)127;
        __m512 accum512_0 = _mm512_setzero_ps();
        __m512 accum512_1 = _mm512_setzero_ps();
        __m512 accum512_2 = _mm512_setzero_ps();
        __m512 accum512_3 = _mm512_setzero_ps();

        /* Processing the main chunk with 128-elements per round */
        for (BLASLONG i = 0; i < tail_index_128; i += 128) {
            accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0]));
            accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32]));
            accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64]));
            accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96]));
        }

        /* Processing the remaining <128 chunk with 32-elements per round */
        for (BLASLONG j = tail_index_128; j < tail_index_32; j += 32) {
            accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j]));
        }

        /* Processing the remaining <32 chunk with masked 32-elements processing */
        if ((n & 31) != 0) {
            /* (n & 31) is 1..31 here, so the shift count is 1..31 */
            const __mmask32 tail_mask = (__mmask32)(0xffffffffu >> (32 - (n & 31)));
            accum512_2 = _mm512_dpbf16_ps(accum512_2,
                                          (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]),
                                          (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32]));
        }

        /* Accumulate the 4 registers into 1 register */
        accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
        accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
        accum512_0 = _mm512_add_ps(accum512_0, accum512_2);

        __m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
        accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
    } else if (n > 31) {           /* n range from 32 to 127 */
        /* Processing <128 chunk with 32-elements per round */
        __m256 accum256 = _mm256_setzero_ps();
        __m256 accum256_1 = _mm256_setzero_ps();
        const BLASLONG tail_index_32 = n & ~(BLASLONG)31;

        for (BLASLONG j = 0; j < tail_index_32; j += 32) {
            accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((const __m256i *)&x[j+ 0]), (__m256bh) _mm256_loadu_si256((const __m256i *)&y[j+ 0]));
            accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256((const __m256i *)&x[j+16]), (__m256bh) _mm256_loadu_si256((const __m256i *)&y[j+16]));
        }
        accum256 = _mm256_add_ps(accum256, accum256_1);

        /* Processing the remaining <32 chunk with 16-elements processing */
        if ((n & 16) != 0) {
            accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((const __m256i *)&x[tail_index_32]), (__m256bh) _mm256_loadu_si256((const __m256i *)&y[tail_index_32]));
        }
        accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));

        /* Processing the remaining <16 chunk with 8-elements processing */
        if ((n & 8) != 0) {
            accum128 = shdot_accl_step8(accum128, n, x, y);
        }

        /* Processing the remaining <8 chunk with masked 8-elements processing */
        if ((n & 7) != 0) {
            accum128 = shdot_accl_tail8(accum128, n, x, y);
        }
    } else if (n > 15) {           /* n range from 16 to 31 */
        /* Processing <32 chunk with 16-elements processing */
        __m256 accum256 = _mm256_setzero_ps();
        accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((const __m256i *)&x[0]), (__m256bh) _mm256_loadu_si256((const __m256i *)&y[0]));
        accum128 = _mm_add_ps(accum128,
                              _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)));

        /* Processing the remaining <16 chunk with 8-elements processing */
        if ((n & 8) != 0) {
            accum128 = shdot_accl_step8(accum128, n, x, y);
        }

        /* Processing the remaining <8 chunk with masked 8-elements processing */
        if ((n & 7) != 0) {
            accum128 = shdot_accl_tail8(accum128, n, x, y);
        }
    } else if (n > 7) {            /* n range from 8 to 15 */
        /* Processing <16 chunk with 8-elements processing */
        accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((const __m128i *)&x[0]), (__m128bh) _mm_loadu_si128((const __m128i *)&y[0]));

        /* Processing the remaining <8 chunk with masked 8-elements processing */
        if ((n & 7) != 0) {
            accum128 = shdot_accl_tail8(accum128, n, x, y);
        }
    } else {                       /* n range from 1 to 7 */
        /* (n & ~7) is 0 and (n & 7) is n here, so this loads the first n elements */
        accum128 = shdot_accl_tail8(accum128, n, x, y);
    }

    /* Add up the 4 elements into lowest entry */
    __m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14);
    accum128 = _mm_add_ps(accum128, accum128_1);
    accum128_1 = _mm_shuffle_ps(accum128, accum128, 1);
    accum128 = _mm_add_ps(accum128, accum128_1);

    return _mm_cvtss_f32(accum128);
}
#endif |
@@ -0,0 +1,86 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
/* need a new enough GCC for avx512 support */ | |||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||||
#define HAVE_TOBF16_ACCL_KERNEL 1 | |||||
#include "common.h" | |||||
#include <immintrin.h> | |||||
/*
 * AVX512-BF16 conversion of n contiguous floats (in) to bfloat16 (out).
 *
 * The leading elements up to the next 64-byte boundary of `in` are handled
 * with a masked unaligned load; all following loads are aligned loads.
 *
 * NOTE: as in the original code, `in` is assumed to be naturally (4-byte)
 * aligned; otherwise the aligned loads below are not valid.
 *
 * Compared with the previous revision, the tail masks are only computed in
 * the branch that uses them, so the 32-bit mask is never shifted by 32
 * (undefined behaviour when n is a multiple of 32), and element indices are
 * BLASLONG so that counts which do not fit in an int are not truncated.
 */
static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out)
{
    /* Number of floats in front of the first 64-byte boundary (0..15) */
    int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf;

    if (n < align_header) {
        align_header = (int)n;
    }

    if (align_header != 0) {
        /* align_header is 1..15 here, so the shift count is 1..15 */
        const __mmask16 align_mask16 = (__mmask16)(0xffffu >> (16 - align_header));
        __m512 a = _mm512_maskz_loadu_ps(align_mask16, &in[0]);
        _mm256_mask_storeu_epi16(&out[0], align_mask16, (__m256i) _mm512_cvtneps_pbh(a));
    }

    if (n == align_header) {
        return;
    }

    n -= align_header;
    in += align_header;
    out += align_header;

    const BLASLONG tail_index_32 = n & ~(BLASLONG)31;
    const BLASLONG tail_index_128 = n & ~(BLASLONG)127;
    const int tail_count = (int)(n & 31);

    /* Processing the main chunk with 128-elements per round */
    for (BLASLONG i = 0; i < tail_index_128; i += 128) {
        _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0])));
        _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32])));
        _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64])));
        _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96])));
    }

    /* Processing the remaining <128 chunk with 32-elements per round */
    for (BLASLONG j = tail_index_128; j < tail_index_32; j += 32) {
        _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j])));
    }

    /* Processing the remaining <32 chunk with masked processing */
    if (tail_count > 15) {
        /* 16..31 elements left: one full vector plus a masked second vector.
         * The 16-bit mask is empty when exactly 16 elements are left. */
        const __mmask16 tail_mask16 = (__mmask16)(0xffffu >> (16 - (tail_count & 15)));
        const __mmask32 tail_mask32 = (__mmask32)(0xffffffffu >> (32 - tail_count));
        __m512 b = _mm512_load_ps(&in[tail_index_32]);
        __m512 a = _mm512_maskz_load_ps(tail_mask16, &in[tail_index_32+16]);
        _mm512_mask_storeu_epi16(&out[tail_index_32], tail_mask32, (__m512i) _mm512_cvtne2ps_pbh(a, b));
    } else if (tail_count > 0) {
        /* 1..15 elements left: a single masked vector */
        const __mmask16 tail_mask16 = (__mmask16)(0xffffu >> (16 - tail_count));
        __m512 a = _mm512_maskz_load_ps(tail_mask16, &in[tail_index_32]);
        _mm256_mask_storeu_epi16(&out[tail_index_32], tail_mask16, (__m256i) _mm512_cvtneps_pbh(a));
    }
}
#endif |
@@ -0,0 +1,170 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include <stddef.h> | |||||
#include "common.h" | |||||
#if defined(DOUBLE) | |||||
#define FLOAT_TYPE double | |||||
#elif defined(SINGLE) | |||||
#define FLOAT_TYPE float | |||||
#else | |||||
#endif | |||||
#if defined(COOPERLAKE) | |||||
#if defined(DOUBLE) | |||||
#include "dtobf16_microk_cooperlake.c" | |||||
#elif defined(SINGLE) | |||||
#include "stobf16_microk_cooperlake.c" | |||||
#endif | |||||
#endif | |||||
/* Notes for algorithm: | |||||
* - Round to Nearest Even used generally | |||||
* - QNAN for NAN case | |||||
* - Input denormals are treated as zero | |||||
*/ | |||||
/*
 * Portable conversion of n FLOAT_TYPE values to bfloat16.
 *
 *   n        number of elements to convert
 *   in       source array, read at in[0], in[inc_in], in[2*inc_in], ...
 *   inc_in   stride (in elements) between consecutive source values
 *   out      destination array, written at out[0], out[inc_out], ...
 *   inc_out  stride (in elements) between consecutive destination values
 *
 * Every input is first converted to float; its binary32 bit pattern is then
 * reduced to the upper 16 bits:
 *   - zero exponent field (zero or denormal) -> signed zero
 *   - infinity                               -> upper 16 bits unchanged
 *   - NaN                                    -> upper 16 bits with the quiet bit forced on
 *   - normal numbers                         -> round to nearest even
 *
 * The bit pattern is accessed through a union and extracted with a shift, so
 * the routine neither depends on the host byte order nor accesses a float
 * object through an incompatible pointer type.
 */
static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
{
    BLASLONG index_in = 0;
    BLASLONG index_out = 0;
    BLASLONG index = 0;
    union {
        float value;
        uint32_t bits;
    } src;

    while (index < n) {
        src.value = (float)(*(in + index_in));

        /* Classify on the sign bit plus the 8 exponent bits */
        switch (src.bits & 0xff800000u) {
        case (0x00000000u):  /* Type 1: Positive zero or denormal */
            *(out + index_out) = 0x0000u;
            break;
        case (0x80000000u):  /* Type 2: Negative zero or denormal */
            *(out + index_out) = 0x8000u;
            break;
        case (0x7f800000u):  /* Type 3: Positive infinity or NAN */
        case (0xff800000u):  /* Type 4: Negative infinity or NAN */
            *(out + index_out) = (bfloat16)(src.bits >> 16);
            /* A non-zero mantissa means NAN: force it to be QNAN */
            if ((src.bits & 0x007fffffu) != 0) {
                *(out + index_out) |= 0x0040u;
            }
            break;
        default:             /* Type 5: Normal case */
            /* Round to nearest even: add 0x7fff plus the lowest kept bit,
             * then drop the lower 16 bits */
            src.bits += (((src.bits >> 16) & 0x1u) + 0x7fffu);
            *(out + index_out) = (bfloat16)(src.bits >> 16);
            break;
        }

        index_in += inc_in;
        index_out += inc_out;
        index++;
    }
}
#if !defined(HAVE_TOBF16_ACCL_KERNEL)
/*
 * Fallback for the unit-stride kernel: used only when no other definition
 * has set HAVE_TOBF16_ACCL_KERNEL. It forwards to the generic kernel with
 * both strides fixed to 1.
 */
static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out)
{
  const BLASLONG unit_stride = 1;

  tobf16_generic_kernel(n, in, unit_stride, out, unit_stride);
}
#endif
/*
 * Convert n elements from in to out, choosing the kernel by stride:
 * the contiguous kernel when both strides are 1, the generic strided
 * kernel in every other case.
 */
static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
{
  if ((inc_in != 1) || (inc_out != 1)) {
    tobf16_generic_kernel(n, in, inc_in, out, inc_out);
    return;
  }

  tobf16_accl_kernel(n, in, out);
}
#ifdef SMP
/*
 * Per-thread worker handed to blas_level1_thread(). Only n, x/inc_x and
 * y/inc_y are used; the dummy parameters exist solely to fill out the
 * argument list and are ignored. Always returns 0.
 */
static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2,
                              FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
                              FLOAT_TYPE *dummy3, BLASLONG dummy4)
{
  tobf16_compute(n, x, inc_x, y, inc_y);

  return 0;
}

/* Level-1 threading helper, defined outside this file. */
extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
                              void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
                              int (*function)(), int nthreads);
#endif
/*
 * Public entry point: convert n elements of in (stride inc_in) to bfloat16
 * and store them in out (stride inc_out). Does nothing when n <= 0.
 *
 * In SMP builds the work is split across threads for large inputs:
 *   - a zero stride or n <= 100000  -> single-threaded
 *   - n / 100000 < 100              -> 4 threads
 *   - otherwise                     -> 16 threads
 * Non-SMP builds always convert in the calling thread.
 */
void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
{
  if (n <= 0) return;

#if defined(SMP)
  int nthreads;
  /* Placeholders for the alpha and c arguments of blas_level1_thread().
   * They are initialized because their addresses are handed to the
   * threading layer; the worker receives a scalar by value, so presumably
   * alpha gets dereferenced there -- an indeterminate value must not be read. */
  FLOAT_TYPE dummy_alpha = 0;
  FLOAT_TYPE dummy_c = 0;

  if (inc_in == 0 || inc_out == 0 || n <= 100000) {
    nthreads = 1;
  } else if (n / 100000 < 100) {
    nthreads = 4;
  } else {
    nthreads = 16;
  }

  if (nthreads == 1) {
    tobf16_compute(n, in, inc_in, out, inc_out);
  } else {
#if defined(DOUBLE)
    int mode = BLAS_REAL | BLAS_DTOBF16;
#elif defined(SINGLE)
    int mode = BLAS_REAL | BLAS_STOBF16;
#endif
    /* Cast the worker to the declared callback type: converting a function
     * pointer through void * is not guaranteed to work by the C standard. */
    blas_level1_thread(mode, n, 0, 0, &dummy_alpha,
                       in, inc_in, out, inc_out, &dummy_c, 0,
                       (int (*)())tobf16_thread_func, nthreads);
  }
#else
  tobf16_compute(n, in, inc_in, out, inc_out);
#endif
}
@@ -35,7 +35,8 @@ typedef unsigned long BLASULONG; | |||||
#endif | #endif | ||||
#ifndef BFLOAT16 | #ifndef BFLOAT16 | ||||
typedef unsigned short bfloat16; | |||||
#include <stdint.h> | |||||
typedef uint16_t bfloat16; | |||||
#endif | #endif | ||||
#ifdef OPENBLAS_USE64BITINT | #ifdef OPENBLAS_USE64BITINT | ||||