1. Add a new API -- sbgemv to support bfloat16 based gemv 2. Implement a generic kernel for sbgemv 3. Implement an avx512-bf16 based kernel for sbgemv Signed-off-by: Chen, Guobing <guobing.chen@intel.com>tags/v0.3.13^2
@@ -393,6 +393,7 @@ void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE | |||||
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); | void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); | ||||
/* dot production of BFLOAT16 input arrays, and output as float */ | /* dot production of BFLOAT16 input arrays, and output as float */ | ||||
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); | float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); | ||||
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
@@ -184,8 +184,8 @@ macro(SetDefaultL2) | |||||
set(XHEMV_V_KERNEL ../generic/zhemv_k.c) | set(XHEMV_V_KERNEL ../generic/zhemv_k.c) | ||||
set(XHEMV_M_KERNEL ../generic/zhemv_k.c) | set(XHEMV_M_KERNEL ../generic/zhemv_k.c) | ||||
if (BUILD_BFLOAT16) | if (BUILD_BFLOAT16) | ||||
set(SBGEMVNKERNEL ../arm/gemv_n.c) | |||||
set(SBGEMVTKERNEL ../arm/gemv_t.c) | |||||
set(SBGEMVNKERNEL ../x86_64/sbgemv_n.c) | |||||
set(SBGEMVTKERNEL ../x86_64/sbgemv_t.c) | |||||
set(SHGERKERNEL ../generic/ger.c) | set(SHGERKERNEL ../generic/ger.c) | ||||
endif () | endif () | ||||
endmacro () | endmacro () | ||||
@@ -250,6 +250,8 @@ void BLASFUNC(xgeru)(blasint *, blasint *, xdouble *, xdouble *, blasint *, | |||||
void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *, | void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *, | ||||
xdouble *, blasint *, xdouble *, blasint *); | xdouble *, blasint *, xdouble *, blasint *); | ||||
void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *, | |||||
bfloat16 *, blasint *, float *, float *, blasint *); | |||||
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *, | void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *, | ||||
float *, blasint *, float *, float *, blasint *); | float *, blasint *, float *, float *, blasint *); | ||||
void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *, | void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *, | ||||
@@ -44,6 +44,10 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); | |||||
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); | |||||
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int); | |||||
int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int); | |||||
int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | ||||
int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *); | int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *); | ||||
int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *); | int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *); | ||||
@@ -646,10 +646,12 @@ | |||||
#elif defined(BFLOAT16) | #elif defined(BFLOAT16) | ||||
#define D_TO_BF16_K SBDTOBF16_K | |||||
#define D_BF16_TO_K DBF16TOD_K | |||||
#define S_TO_BF16_K SBSTOBF16_K | |||||
#define S_BF16_TO_K SBF16TOS_K | |||||
#define D_TO_BF16_K SBDTOBF16_K | |||||
#define D_BF16_TO_K DBF16TOD_K | |||||
#define S_TO_BF16_K SBSTOBF16_K | |||||
#define S_BF16_TO_K SBF16TOS_K | |||||
#define SBGEMV_N SBGEMV_N_K | |||||
#define SBGEMV_T SBGEMV_T_K | |||||
#define AMAX_K SAMAX_K | #define AMAX_K SAMAX_K | ||||
#define AMIN_K SAMIN_K | #define AMIN_K SAMIN_K | ||||
@@ -78,8 +78,8 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); | |||||
int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); | int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); | int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); | ||||
int (*sbgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | |||||
int (*sbgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | |||||
int (*sbgemv_n) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); | |||||
int (*sbgemv_t) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); | |||||
int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | ||||
int (*sbsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | int (*sbsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); | ||||
@@ -8,6 +8,8 @@ | |||||
#define SBDTOBF16_K sbdtobf16_k | #define SBDTOBF16_K sbdtobf16_k | ||||
#define SBF16TOS_K sbf16tos_k | #define SBF16TOS_K sbf16tos_k | ||||
#define DBF16TOD_K dbf16tod_k | #define DBF16TOD_K dbf16tod_k | ||||
#define SBGEMV_N_K sbgemv_n | |||||
#define SBGEMV_T_K sbgemv_t | |||||
#define SBGEMM_ONCOPY sbgemm_oncopy | #define SBGEMM_ONCOPY sbgemm_oncopy | ||||
#define SBGEMM_OTCOPY sbgemm_otcopy | #define SBGEMM_OTCOPY sbgemm_otcopy | ||||
@@ -29,6 +31,8 @@ | |||||
#define SBDTOBF16_K gotoblas -> sbdtobf16_k | #define SBDTOBF16_K gotoblas -> sbdtobf16_k | ||||
#define SBF16TOS_K gotoblas -> sbf16tos_k | #define SBF16TOS_K gotoblas -> sbf16tos_k | ||||
#define DBF16TOD_K gotoblas -> dbf16tod_k | #define DBF16TOD_K gotoblas -> dbf16tod_k | ||||
#define SBGEMV_N_K gotoblas -> sbgemv_n | |||||
#define SBGEMV_T_K gotoblas -> sbgemv_t | |||||
#define SBGEMM_ONCOPY gotoblas -> sbgemm_oncopy | #define SBGEMM_ONCOPY gotoblas -> sbgemm_oncopy | ||||
#define SBGEMM_OTCOPY gotoblas -> sbgemm_otcopy | #define SBGEMM_OTCOPY gotoblas -> sbgemm_otcopy | ||||
@@ -413,7 +413,13 @@ XBLASOBJS += \ | |||||
xtbmv_thread_RUU.$(SUFFIX) xtbmv_thread_RUN.$(SUFFIX) \ | xtbmv_thread_RUU.$(SUFFIX) xtbmv_thread_RUN.$(SUFFIX) \ | ||||
xtbmv_thread_RLU.$(SUFFIX) xtbmv_thread_RLN.$(SUFFIX) \ | xtbmv_thread_RLU.$(SUFFIX) xtbmv_thread_RLN.$(SUFFIX) \ | ||||
xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \ | xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \ | ||||
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) \ | |||||
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) | |||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
SBBLASOBJS += \ | |||||
sbgemv_thread_n$(TSUFFIX).$(SUFFIX) \ | |||||
sbgemv_thread_t$(TSUFFIX).$(SUFFIX) | |||||
endif | |||||
endif | endif | ||||
@@ -3693,4 +3699,12 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h | |||||
xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h | xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h | ||||
$(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F) | $(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F) | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
sbgemv_thread_n.$(SUFFIX) sbgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h | |||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F) | |||||
sbgemv_thread_t.$(SUFFIX) sbgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h | |||||
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F) | |||||
endif | |||||
include ../../Makefile.tail | include ../../Makefile.tail |
@@ -0,0 +1,149 @@ | |||||
/*********************************************************************/ | |||||
/* Copyright 2009, 2010 The University of Texas at Austin. */ | |||||
/* All rights reserved. */ | |||||
/* */ | |||||
/* Redistribution and use in source and binary forms, with or */ | |||||
/* without modification, are permitted provided that the following */ | |||||
/* conditions are met: */ | |||||
/* */ | |||||
/* 1. Redistributions of source code must retain the above */ | |||||
/* copyright notice, this list of conditions and the following */ | |||||
/* disclaimer. */ | |||||
/* */ | |||||
/* 2. Redistributions in binary form must reproduce the above */ | |||||
/* copyright notice, this list of conditions and the following */ | |||||
/* disclaimer in the documentation and/or other materials */ | |||||
/* provided with the distribution. */ | |||||
/* */ | |||||
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ | |||||
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ | |||||
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ | |||||
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ | |||||
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ | |||||
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ | |||||
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ | |||||
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ | |||||
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ | |||||
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ | |||||
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ | |||||
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ | |||||
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ | |||||
/* POSSIBILITY OF SUCH DAMAGE. */ | |||||
/* */ | |||||
/* The views and conclusions contained in the software and */ | |||||
/* documentation are those of the authors and should not be */ | |||||
/* interpreted as representing official policies, either expressed */ | |||||
/* or implied, of The University of Texas at Austin. */ | |||||
/*********************************************************************/ | |||||
#include <stdio.h> | |||||
#include <stdlib.h> | |||||
#include "common.h" | |||||
#ifndef TRANSA | |||||
#define SBGEMV SBGEMV_N | |||||
#else | |||||
#define SBGEMV SBGEMV_T | |||||
#endif | |||||
static int sbgemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *dummy2, BLASLONG dummy3){ | |||||
bfloat16 *a, *x; | |||||
float *y; | |||||
BLASLONG lda, incx, incy; | |||||
BLASLONG m_from, m_to, n_from, n_to; | |||||
a = (bfloat16 *)args->a; | |||||
x = (bfloat16 *)args->b; | |||||
y = (float *)args->c; | |||||
lda = args->lda; | |||||
incx = args->ldb; | |||||
incy = args->ldc; | |||||
#ifndef TRANSA // N | |||||
m_from = *(range_m + 0); | |||||
m_to = *(range_m + 1); | |||||
n_from = 0; | |||||
n_to = args -> n; | |||||
a += m_from; | |||||
y += m_from * incy; | |||||
#else // T | |||||
m_from = 0; | |||||
m_to = args->m; | |||||
n_from = *(range_n + 0); | |||||
n_to = *(range_n + 1); | |||||
a += n_from * lda; | |||||
y += n_from * incy; | |||||
#endif | |||||
SBGEMV(m_to - m_from, n_to - n_from, *((FLOAT *)(args->alpha)), a, lda, x, incx, *((FLOAT *)(args->beta)), y, incy); | |||||
return 0; | |||||
} | |||||
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy, int threads) | |||||
{ | |||||
blas_arg_t args; | |||||
blas_queue_t queue[MAX_CPU_NUMBER]; | |||||
BLASLONG range[MAX_CPU_NUMBER + 1]; | |||||
#ifndef TRANSA | |||||
BLASLONG width_for_split = m; | |||||
#else | |||||
BLASLONG width_for_split = n; | |||||
#endif | |||||
BLASLONG BLOCK_WIDTH = width_for_split/threads; | |||||
int mode = BLAS_BFLOAT16 | BLAS_REAL; | |||||
args.m = m; | |||||
args.n = n; | |||||
args.a = (void *)a; | |||||
args.b = (void *)x; | |||||
args.c = (void *)y; | |||||
args.lda = lda; | |||||
args.ldb = incx; | |||||
args.ldc = incy; | |||||
args.alpha = (void *)α | |||||
args.beta = (void *)β | |||||
range[0] = 0; | |||||
int thread_idx; | |||||
for (thread_idx=0; thread_idx<threads; thread_idx++) { | |||||
if (thread_idx != threads-1) { | |||||
range[thread_idx + 1] = range[thread_idx] + BLOCK_WIDTH; | |||||
} else { | |||||
range[thread_idx + 1] = range[thread_idx] + width_for_split; | |||||
} | |||||
queue[thread_idx].mode = mode; | |||||
queue[thread_idx].routine = sbgemv_kernel; | |||||
queue[thread_idx].args = &args; | |||||
#ifndef TRANSA | |||||
queue[thread_idx].range_m = &range[thread_idx]; | |||||
queue[thread_idx].range_n = NULL; | |||||
#else | |||||
queue[thread_idx].range_m = NULL; | |||||
queue[thread_idx].range_n = &range[thread_idx]; | |||||
#endif | |||||
queue[thread_idx].sa = NULL; | |||||
queue[thread_idx].sb = NULL; | |||||
queue[thread_idx].next = &queue[thread_idx + 1]; | |||||
width_for_split -= BLOCK_WIDTH; | |||||
} | |||||
if (thread_idx) { | |||||
queue[0].sa = NULL; | |||||
queue[0].sb = NULL; | |||||
queue[thread_idx - 1].next = NULL; | |||||
exec_blas(thread_idx, queue); | |||||
} | |||||
return 0; | |||||
} |
@@ -352,7 +352,6 @@ fprintf(stderr,"UNHANDLED COMPLEX\n"); | |||||
/* Other types in future */ | /* Other types in future */ | ||||
} | } | ||||
} | } | ||||
if (!sb) fprintf(stderr,"SB not declared!!!\n"); | |||||
queue->sb=sb; | queue->sb=sb; | ||||
} | } | ||||
} | } | ||||
@@ -51,7 +51,7 @@ | |||||
zgeadd, dzsum); | zgeadd, dzsum); | ||||
@blasobjs = (lsame, xerbla); | @blasobjs = (lsame, xerbla); | ||||
@bfblasobjs = (sbgemm, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); | |||||
@bfblasobjs = (sbgemm, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); | |||||
@cblasobjsc = ( | @cblasobjsc = ( | ||||
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, | cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, | ||||
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, | cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k, | ||||
@@ -94,7 +94,7 @@ | |||||
@cblasobjs = ( cblas_xerbla ); | @cblasobjs = ( cblas_xerbla ); | ||||
@bfcblasobjs = (cblas_sbgemm, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod); | |||||
@bfcblasobjs = (cblas_sbgemm, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod); | |||||
@exblasobjs = ( | @exblasobjs = ( | ||||
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, | qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm, | ||||
@@ -48,6 +48,7 @@ SBLAS3OBJS = \ | |||||
ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
SBBLAS1OBJS = sbdot.$(SUFFIX) | SBBLAS1OBJS = sbdot.$(SUFFIX) | ||||
SBBLAS2OBJS = sbgemv.$(SUFFIX) | |||||
SBBLAS3OBJS = sbgemm.$(SUFFIX) | SBBLAS3OBJS = sbgemm.$(SUFFIX) | ||||
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) | ||||
endif | endif | ||||
@@ -284,6 +285,7 @@ CSBLAS3OBJS = \ | |||||
ifeq ($(BUILD_BFLOAT16),1) | ifeq ($(BUILD_BFLOAT16),1) | ||||
CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) | CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) | ||||
CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) | |||||
CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) | CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) | ||||
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) | ||||
endif | endif | ||||
@@ -382,6 +384,7 @@ SBLAS1OBJS += $(CSBLAS1OBJS) | |||||
SBLAS2OBJS += $(CSBLAS2OBJS) | SBLAS2OBJS += $(CSBLAS2OBJS) | ||||
SBLAS3OBJS += $(CSBLAS3OBJS) | SBLAS3OBJS += $(CSBLAS3OBJS) | ||||
SBBLAS1OBJS += $(CSBBLAS1OBJS) | SBBLAS1OBJS += $(CSBBLAS1OBJS) | ||||
SBBLAS2OBJS += $(CSBBLAS2OBJS) | |||||
SBBLAS3OBJS += $(CSBBLAS3OBJS) | SBBLAS3OBJS += $(CSBBLAS3OBJS) | ||||
DBLAS1OBJS += $(CDBLAS1OBJS) | DBLAS1OBJS += $(CDBLAS1OBJS) | ||||
DBLAS2OBJS += $(CDBLAS2OBJS) | DBLAS2OBJS += $(CDBLAS2OBJS) | ||||
@@ -399,7 +402,7 @@ CBAUXOBJS += $(CXERBLAOBJ) | |||||
endif | endif | ||||
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) | ||||
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS3OBJS) | |||||
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) | |||||
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) | ||||
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) | QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS) | ||||
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) | CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS) | ||||
@@ -538,7 +541,7 @@ clean :: | |||||
level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) | level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) | ||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | |||||
level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) | |||||
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ | ||||
level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) | ||||
@@ -929,6 +932,11 @@ xgeru.$(SUFFIX) xgeru.$(PSUFFIX) : zger.c | |||||
xgerc.$(SUFFIX) xgerc.$(PSUFFIX) : zger.c | xgerc.$(SUFFIX) xgerc.$(PSUFFIX) : zger.c | ||||
$(CC) -c $(CFLAGS) -DCONJ $< -o $(@F) | $(CC) -c $(CFLAGS) -DCONJ $< -o $(@F) | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c | |||||
$(CC) $(CFLAGS) -c $< -o $(@F) | |||||
endif | |||||
ifndef USE_NETLIB_GEMV | ifndef USE_NETLIB_GEMV | ||||
sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c | sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c | ||||
$(CC) -c $(CFLAGS) -o $(@F) $< | $(CC) -c $(CFLAGS) -o $(@F) $< | ||||
@@ -1656,6 +1664,11 @@ cblas_csscal.$(SUFFIX) cblas_csscal.$(PSUFFIX) : zscal.c | |||||
cblas_zdscal.$(SUFFIX) cblas_zdscal.$(PSUFFIX) : zscal.c | cblas_zdscal.$(SUFFIX) cblas_zdscal.$(PSUFFIX) : zscal.c | ||||
$(CC) $(CFLAGS) -DCBLAS -c -DSSCAL $< -o $(@F) | $(CC) $(CFLAGS) -DCBLAS -c -DSSCAL $< -o $(@F) | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c | |||||
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) | |||||
endif | |||||
cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c | cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c | ||||
$(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $< | $(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $< | ||||
@@ -191,7 +191,6 @@ void CNAME(enum CBLAS_ORDER order, | |||||
} | } | ||||
#endif | #endif | ||||
//printf("m=%d, n=%d, trans=%d, incx=%d, incy=%d, alpha=%f, beta=%f\n", m, n, trans, incx, incy, alpha, beta); | |||||
if ((m==0) || (n==0)) return; | if ((m==0) || (n==0)) return; | ||||
lenx = n; | lenx = n; | ||||
@@ -0,0 +1,210 @@ | |||||
/*********************************************************************/ | |||||
/* Copyright 2009, 2010 The University of Texas at Austin. */ | |||||
/* All rights reserved. */ | |||||
/* */ | |||||
/* Redistribution and use in source and binary forms, with or */ | |||||
/* without modification, are permitted provided that the following */ | |||||
/* conditions are met: */ | |||||
/* */ | |||||
/* 1. Redistributions of source code must retain the above */ | |||||
/* copyright notice, this list of conditions and the following */ | |||||
/* disclaimer. */ | |||||
/* */ | |||||
/* 2. Redistributions in binary form must reproduce the above */ | |||||
/* copyright notice, this list of conditions and the following */ | |||||
/* disclaimer in the documentation and/or other materials */ | |||||
/* provided with the distribution. */ | |||||
/* */ | |||||
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ | |||||
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ | |||||
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ | |||||
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ | |||||
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ | |||||
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ | |||||
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ | |||||
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ | |||||
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ | |||||
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ | |||||
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ | |||||
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ | |||||
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ | |||||
/* POSSIBILITY OF SUCH DAMAGE. */ | |||||
/* */ | |||||
/* The views and conclusions contained in the software and */ | |||||
/* documentation are those of the authors and should not be */ | |||||
/* interpreted as representing official policies, either expressed */ | |||||
/* or implied, of The University of Texas at Austin. */ | |||||
/*********************************************************************/ | |||||
#include <stdio.h> | |||||
#include "common.h" | |||||
#include "l1param.h" | |||||
#ifdef FUNCTION_PROFILE | |||||
#include "functable.h" | |||||
#endif | |||||
#define ERROR_NAME "SBGEMV " | |||||
#ifdef SMP | |||||
static int (*sbgemv_thread[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG, int) = { | |||||
sbgemv_thread_n, sbgemv_thread_t, | |||||
}; | |||||
#endif | |||||
#ifndef CBLAS | |||||
void NAME(char *TRANS, blasint *M, blasint *N, float *ALPHA, bfloat16 *a, blasint *LDA, bfloat16 *x, blasint *INCX, float *BETA, float *y, blasint *INCY) | |||||
{ | |||||
char trans = *TRANS; | |||||
blasint m = *M; | |||||
blasint n = *N; | |||||
blasint lda = *LDA; | |||||
blasint incx = *INCX; | |||||
blasint incy = *INCY; | |||||
float alpha = *ALPHA; | |||||
float beta = *BETA; | |||||
#ifdef SMP | |||||
int nthreads; | |||||
#endif | |||||
int (*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = { | |||||
SBGEMV_N, SBGEMV_T, | |||||
}; | |||||
blasint info; | |||||
blasint lenx, leny; | |||||
blasint i; | |||||
PRINT_DEBUG_NAME; | |||||
TOUPPER(trans); | |||||
info = 0; | |||||
i = -1; | |||||
if (trans == 'N') {i = 0;} | |||||
if (trans == 'T') {i = 1;} | |||||
if (trans == 'R') {i = 0;} | |||||
if (trans == 'C') {i = 1;} | |||||
if (incy == 0) {info = 11;} | |||||
if (incx == 0) {info = 8;} | |||||
if (lda < MAX(1, m)) {info = 6;} | |||||
if (n < 0) {info = 3;} | |||||
if (m < 0) {info = 2;} | |||||
if (i < 0) {info = 1;} | |||||
trans = i; | |||||
if (info != 0) { | |||||
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); | |||||
return; | |||||
} | |||||
#else | |||||
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasint n, float alpha, bfloat16 *a, blasint lda, bfloat16 *x, blasint incx, float beta, float *y, blasint incy) | |||||
{ | |||||
blasint lenx, leny; | |||||
int trans; | |||||
blasint info, t; | |||||
#ifdef SMP | |||||
int nthreads; | |||||
#endif | |||||
int (*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = { | |||||
SBGEMV_N, SBGEMV_T, | |||||
}; | |||||
PRINT_DEBUG_CNAME; | |||||
trans = -1; | |||||
info = 0; | |||||
if (order == CblasColMajor) { // Column Major | |||||
if (TransA == CblasNoTrans || TransA == CblasConjNoTrans) { | |||||
trans = 0; | |||||
} else if (TransA == CblasTrans || TransA == CblasConjTrans) { | |||||
trans = 1; | |||||
} | |||||
} else { // Row Major | |||||
if (TransA == CblasNoTrans || TransA == CblasConjNoTrans) { | |||||
trans = 1; | |||||
} else if (TransA == CblasTrans || TransA == CblasConjTrans) { | |||||
trans = 0; | |||||
} | |||||
t = n; | |||||
n = m; | |||||
m = t; | |||||
} | |||||
info = -1; | |||||
if (incy == 0) {info = 11;} | |||||
if (incx == 0) {info = 8;} | |||||
if (lda < MAX(1, m)) {info = 6;} | |||||
if (n < 0) {info = 3;} | |||||
if (m < 0) {info = 2;} | |||||
if (trans < 0) {info = 1;} | |||||
if (info >= 0) { | |||||
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); | |||||
return; | |||||
} | |||||
#endif | |||||
if ((m==0) || (n==0)) return; | |||||
if (trans) { | |||||
lenx = m; | |||||
leny = n; | |||||
} else { | |||||
lenx = n; | |||||
leny = m; | |||||
} | |||||
if (alpha == ZERO) { | |||||
if (beta != ONE) SCAL_K(leny, 0, 0, beta, y, blasabs(incy), NULL, 0, NULL, 0); | |||||
return; | |||||
} | |||||
IDEBUG_START; | |||||
FUNCTION_PROFILE_START(); | |||||
if (incx < 0) {x -= (lenx - 1) * incx;} | |||||
if (incy < 0) {y -= (leny - 1) * incy;} | |||||
#ifdef SMP | |||||
int thread_thres_row = 20480; | |||||
if (trans) { | |||||
if (n <= thread_thres_row) { | |||||
nthreads = 1; | |||||
} else { | |||||
nthreads = num_cpu_avail(1); | |||||
} | |||||
} else { | |||||
if (m <= thread_thres_row) { | |||||
nthreads = 1; | |||||
} else { | |||||
nthreads = num_cpu_avail(1); | |||||
} | |||||
} | |||||
if (nthreads == 1) { | |||||
#endif | |||||
(sbgemv[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy); | |||||
#ifdef SMP | |||||
} else { | |||||
(sbgemv_thread[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy, nthreads); | |||||
} | |||||
#endif | |||||
FUNCTION_PROFILE_END(1, m * n + m + n, 2 * m * n); | |||||
IDEBUG_END; | |||||
return; | |||||
} |
@@ -48,6 +48,16 @@ ifndef XGEMVTKERNEL | |||||
XGEMVTKERNEL = zgemv_t.S | XGEMVTKERNEL = zgemv_t.S | ||||
endif | endif | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
ifndef SBGEMVNKERNEL | |||||
SBGEMVNKERNEL = ../x86_64/sbgemv_n.c | |||||
endif | |||||
ifndef SBGEMVTKERNEL | |||||
SBGEMVTKERNEL = ../x86_64/sbgemv_t.c | |||||
endif | |||||
endif | |||||
### GER ### | ### GER ### | ||||
ifndef SGERKERNEL | ifndef SGERKERNEL | ||||
@@ -234,6 +244,12 @@ XBLASOBJS += \ | |||||
xhemv_U$(TSUFFIX).$(SUFFIX) xhemv_L$(TSUFFIX).$(SUFFIX) xhemv_V$(TSUFFIX).$(SUFFIX) xhemv_M$(TSUFFIX).$(SUFFIX) \ | xhemv_U$(TSUFFIX).$(SUFFIX) xhemv_L$(TSUFFIX).$(SUFFIX) xhemv_V$(TSUFFIX).$(SUFFIX) xhemv_M$(TSUFFIX).$(SUFFIX) \ | ||||
xgeru_k$(TSUFFIX).$(SUFFIX) xgerc_k$(TSUFFIX).$(SUFFIX) xgerv_k$(TSUFFIX).$(SUFFIX) xgerd_k$(TSUFFIX).$(SUFFIX) | xgeru_k$(TSUFFIX).$(SUFFIX) xgerc_k$(TSUFFIX).$(SUFFIX) xgerv_k$(TSUFFIX).$(SUFFIX) xgerd_k$(TSUFFIX).$(SUFFIX) | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
SBBLASOBJS += \ | |||||
sbgemv_n$(TSUFFIX).$(SUFFIX) \ | |||||
sbgemv_t$(TSUFFIX).$(SUFFIX) | |||||
endif | |||||
ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" "" | ifneq "$(or $(BUILD_SINGLE), $(BUILD_DOUBLE), $(BUILD_COMPLEX))" "" | ||||
$(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP) | $(KDIR)sgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMVNKERNEL) $(TOPDIR)/common.h $(GEMVDEP) | ||||
$(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@ | $(CC) -c $(CFLAGS) -UDOUBLE -UCOMPLEX -UTRANS $< -o $@ | ||||
@@ -483,4 +499,10 @@ $(KDIR)xhemv_V$(TSUFFIX).$(SUFFIX) $(KDIR)xhemv_V$(TSUFFIX).$(PSUFFIX) : $(KER | |||||
$(KDIR)xhemv_M$(TSUFFIX).$(SUFFIX) $(KDIR)xhemv_M$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XHEMV_M_KERNEL) ../symcopy.h | $(KDIR)xhemv_M$(TSUFFIX).$(SUFFIX) $(KDIR)xhemv_M$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XHEMV_M_KERNEL) ../symcopy.h | ||||
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DLOWER -DHEMV -DHEMVREV $< -o $@ | $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DLOWER -DHEMV -DHEMVREV $< -o $@ | ||||
ifeq ($(BUILD_BFLOAT16),1) | |||||
$(KDIR)sbgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMVNKERNEL) | |||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ | |||||
$(KDIR)sbgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMVTKERNEL) | |||||
$(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ | |||||
endif | |||||
@@ -69,7 +69,7 @@ gotoblas_t TABLE_NAME = { | |||||
snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sbdot_kTS, | snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sbdot_kTS, | ||||
dsdot_kTS, | dsdot_kTS, | ||||
srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, | srot_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, | ||||
sgemv_nTS, sgemv_tTS, sger_kTS, | |||||
sbgemv_nTS, sbgemv_tTS, sger_kTS, | |||||
ssymv_LTS, ssymv_UTS, | ssymv_LTS, ssymv_UTS, | ||||
sbgemm_kernelTS, sbgemm_betaTS, | sbgemm_kernelTS, sbgemm_betaTS, | ||||
@@ -384,6 +384,14 @@ endif | |||||
GEMVDEP = ../l2param.h | GEMVDEP = ../l2param.h | ||||
ifndef SBGEMVNKERNEL | |||||
SBGEMVNKERNEL = sbgemv_n.c | |||||
endif | |||||
ifndef SBGEMVTKERNEL | |||||
SBGEMVTKERNEL = sbgemv_t.c | |||||
endif | |||||
ifndef SGEMVNKERNEL | ifndef SGEMVNKERNEL | ||||
SGEMVNKERNEL = sgemv_n.c | SGEMVNKERNEL = sgemv_n.c | ||||
endif | endif | ||||
@@ -0,0 +1,795 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#ifndef __BF16_COMMON_MACROS | |||||
#define __BF16_COMMON_MACROS | |||||
#include <immintrin.h> | |||||
#define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \ | |||||
reg256##_0 = _mm512_castps512_ps256(reg512##_0); \ | |||||
reg256##_1 = _mm512_castps512_ps256(reg512##_1); | |||||
#define BF16_MATRIX_LOAD_8x32(regArray, a, lda, idx_m, idx_n) \ | |||||
regArray##_0 = _mm512_loadu_si512(&a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm512_loadu_si512(&a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm512_loadu_si512(&a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm512_loadu_si512(&a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm512_loadu_si512(&a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm512_loadu_si512(&a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm512_loadu_si512(&a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm512_loadu_si512(&a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \ | |||||
regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \ | |||||
regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \ | |||||
regArray = _mm512_loadu_si512(&a[idx_m*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_8x32(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_8x16(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_8x8(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); \ | |||||
regArray##_4 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_5 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+5)*lda + idx_n]); \ | |||||
regArray##_6 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_7 = _mm_maskz_loadu_epi16(mask, &a[(idx_m+7)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_4x32(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_4x16(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+1)*lda + idx_n]); \ | |||||
regArray##_2 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_3 = _mm256_maskz_loadu_epi16(mask, &a[(idx_m+3)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_8x32_2(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); \ | |||||
regArray##_4 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+8)*lda + idx_n]); \ | |||||
regArray##_5 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+10)*lda + idx_n]); \ | |||||
regArray##_6 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+12)*lda + idx_n]); \ | |||||
regArray##_7 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+14)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_4x32_2(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray##_0 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+0)*lda + idx_n]); \ | |||||
regArray##_1 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+2)*lda + idx_n]); \ | |||||
regArray##_2 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+4)*lda + idx_n]); \ | |||||
regArray##_3 = _mm512_maskz_loadu_epi16(mask, &a[(idx_m+6)*lda + idx_n]); | |||||
#define BF16_MATRIX_MASKZ_LOAD_1x32(regArray, a, lda, idx_m, idx_n, mask) \ | |||||
regArray = _mm512_maskz_loadu_epi16(mask, &a[idx_m*lda + idx_n]); | |||||
#define BF16_VECTOR_LOAD_1x32(reg, x, idx_n) \ | |||||
reg = _mm512_loadu_si512(x + idx_n); | |||||
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \ | |||||
reg = _mm256_loadu_si256(x + idx_n); | |||||
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \ | |||||
reg = _mm_loadu_si128(x + idx_n); | |||||
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \ | |||||
reg = _mm512_maskz_loadu_epi16(mask, x + idx_n); | |||||
#define BF16_VECTOR_MASKZ_LOAD_1x16(reg, x, idx_n, mask) \ | |||||
reg = _mm256_maskz_loadu_epi16(mask, x + idx_n); | |||||
#define BF16_VECTOR_MASKZ_LOAD_1x8(reg, x, idx_n, mask) \ | |||||
reg = _mm_maskz_loadu_epi16(mask, x + idx_n); | |||||
/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row | |||||
Input - register array of 8 rows of raw-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27 | |||||
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27 | |||||
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11|e16|e17|f16|f17|e18|e19|f18|f19|e24|e25|f24|f25|e26|e27|f26|f27 | |||||
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11|g16|g17|h16|h17|g18|g19|h18|h19|g24|g25|h24|h25|g26|g27|h26|h27 | |||||
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31 | |||||
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31 | |||||
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15|e20|e21|f20|f21|e22|e23|f22|f23|e28|e29|f28|f29|e30|e31|f30|f31 | |||||
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15|g20|g21|h20|h21|g22|g23|h22|h23|g28|g29|h28|h29|g30|g31|h30|h31 | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25 | |||||
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27 | |||||
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 |e16|e17|f16|f17|g16|g17|h16|h17|e24|e25|f24|f25|g24|g25|h24|h25 | |||||
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11|e18|e19|f18|f19|g18|g19|h18|h19|e26|e27|f26|f27|g26|g27|h26|h27 | |||||
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29 | |||||
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31 | |||||
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13|e20|e21|f20|f21|g20|g21|h20|h21|e28|e29|f28|f29|g28|g29|h28|h29 | |||||
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15|e22|e23|f22|f23|g22|g23|h22|h23|e30|e31|f30|f31|g30|g31|h30|h31 | |||||
*/ | |||||
#define BF16_INTERLEAVE_8x32(regArray) \ | |||||
regArray##_8 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_9 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_10 = _mm512_unpacklo_epi32(regArray##_4, regArray##_5); \ | |||||
regArray##_11 = _mm512_unpacklo_epi32(regArray##_6, regArray##_7); \ | |||||
regArray##_12 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_13 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_14 = _mm512_unpackhi_epi32(regArray##_4, regArray##_5); \ | |||||
regArray##_15 = _mm512_unpackhi_epi32(regArray##_6, regArray##_7); \ | |||||
\ | |||||
regArray##_0 = _mm512_unpacklo_epi64(regArray##_8, regArray##_9); \ | |||||
regArray##_1 = _mm512_unpackhi_epi64(regArray##_8, regArray##_9); \ | |||||
regArray##_2 = _mm512_unpacklo_epi64(regArray##_10, regArray##_11); \ | |||||
regArray##_3 = _mm512_unpackhi_epi64(regArray##_10, regArray##_11); \ | |||||
regArray##_4 = _mm512_unpacklo_epi64(regArray##_12, regArray##_13); \ | |||||
regArray##_5 = _mm512_unpackhi_epi64(regArray##_12, regArray##_13); \ | |||||
regArray##_6 = _mm512_unpacklo_epi64(regArray##_14, regArray##_15); \ | |||||
regArray##_7 = _mm512_unpackhi_epi64(regArray##_14, regArray##_15); | |||||
/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row | |||||
Input - register array of 8 rows of raw-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11 | |||||
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11 | |||||
|e0|e1|f0|f1|e2|e3|f2|f3|e8 |e9 |f8 |f9 |e10|e11|f10|f11 | |||||
|g0|g1|h0|h1|g2|g3|h2|h3|g8 |g9 |h8 |h9 |g10|g11|h10|h11 | |||||
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15 | |||||
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15 | |||||
|e4|e5|f4|f5|e6|e7|f6|f7|e12|e13|f12|f13|e14|e15|f14|f15 | |||||
|g4|g5|h4|h5|g6|g7|h6|h7|g12|g13|h12|h13|g14|g15|h14|h15 | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 | |||||
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11 | |||||
|e0|e1|f0|f1|g0|g1|h0|h1|e8 |e9 |f8 |f9 |g8 |g9 |h8 |h9 | |||||
|e2|e3|f2|f3|g2|g3|h2|h3|e10|e11|f10|f11|g10|g11|h10|h11 | |||||
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13 | |||||
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15 | |||||
|e4|e5|f4|f5|g4|g5|h4|h5|e12|e13|f12|f13|g12|g13|h12|h13 | |||||
|e6|e7|f6|f7|g6|g7|h6|h7|e14|e15|f14|f15|g14|g15|h14|h15 | |||||
*/ | |||||
#define BF16_INTERLEAVE_8x16(regArray) \ | |||||
regArray##_8 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_9 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_10 = _mm256_unpacklo_epi32(regArray##_4, regArray##_5); \ | |||||
regArray##_11 = _mm256_unpacklo_epi32(regArray##_6, regArray##_7); \ | |||||
regArray##_12 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_13 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_14 = _mm256_unpackhi_epi32(regArray##_4, regArray##_5); \ | |||||
regArray##_15 = _mm256_unpackhi_epi32(regArray##_6, regArray##_7); \ | |||||
\ | |||||
regArray##_0 = _mm256_unpacklo_epi64(regArray##_8, regArray##_9); \ | |||||
regArray##_1 = _mm256_unpackhi_epi64(regArray##_8, regArray##_9); \ | |||||
regArray##_2 = _mm256_unpacklo_epi64(regArray##_10, regArray##_11); \ | |||||
regArray##_3 = _mm256_unpackhi_epi64(regArray##_10, regArray##_11); \ | |||||
regArray##_4 = _mm256_unpacklo_epi64(regArray##_12, regArray##_13); \ | |||||
regArray##_5 = _mm256_unpackhi_epi64(regArray##_12, regArray##_13); \ | |||||
regArray##_6 = _mm256_unpacklo_epi64(regArray##_14, regArray##_15); \ | |||||
regArray##_7 = _mm256_unpackhi_epi64(regArray##_14, regArray##_15); | |||||
/* 2-step interleave for matrix against 8 rows with 32 BF16 elements per row | |||||
Input - register array of 8 rows of raw-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11|a16|a17|b16|b17|a18|a19|b18|b19|a24|a25|b24|b25|a26|a27|b26|b27 | |||||
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11|c16|c17|d16|d17|c18|c19|d18|d19|c24|c25|d24|d25|c26|c27|d26|d27 | |||||
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15|a20|a21|b20|b21|a22|a23|b22|b23|a28|a29|b28|b29|a30|a31|b30|b31 | |||||
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15|c20|c21|d20|d21|c22|c23|d22|d23|c28|c29|d28|d29|c30|c31|d30|d31 | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 |a16|a17|b16|b17|c16|c17|d16|d17|a24|a25|b24|b25|c24|c25|d24|d25 | |||||
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11|a18|a19|b18|b19|c18|c19|d18|d19|a26|a27|b26|b27|c26|c27|d26|d27 | |||||
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13|a20|a21|b20|b21|c20|c21|d20|d21|a28|a29|b28|b29|c28|c29|d28|d29 | |||||
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15|a22|a23|b22|b23|c22|c23|d22|d23|a30|a31|b30|b31|c30|c31|d30|d31 | |||||
*/ | |||||
#define BF16_INTERLEAVE_4x32(regArray) \ | |||||
regArray##_4 = _mm512_unpacklo_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_5 = _mm512_unpacklo_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_6 = _mm512_unpackhi_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_7 = _mm512_unpackhi_epi32(regArray##_2, regArray##_3); \ | |||||
\ | |||||
regArray##_0 = _mm512_unpacklo_epi64(regArray##_4, regArray##_5); \ | |||||
regArray##_1 = _mm512_unpackhi_epi64(regArray##_4, regArray##_5); \ | |||||
regArray##_2 = _mm512_unpacklo_epi64(regArray##_6, regArray##_7); \ | |||||
regArray##_3 = _mm512_unpackhi_epi64(regArray##_6, regArray##_7); | |||||
/* 2-step interleave for matrix against 8 rows with 16 BF16 elements per row | |||||
Input - register array of 8 rows of raw-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|a1|b0|b1|a2|a3|b2|b3|a8 |a9 |b8 |b9 |a10|a11|b10|b11 | |||||
|c0|c1|d0|d1|c2|c3|d2|d3|c8 |c9 |d8 |d9 |c10|c11|d10|d11 | |||||
|a4|a5|b4|b5|a6|a7|b6|b7|a12|a13|b12|b13|a14|a15|b14|b15 | |||||
|c4|c5|d4|d5|c6|c7|d6|d7|c12|c13|d12|d13|c14|c15|d14|d15 | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|a1|b0|b1|c0|c1|d0|d1|a8 |a9 |b8 |b9 |c8 |c9 |d8 |d9 | |||||
|a2|a3|b2|b3|c2|c3|d2|d3|a10|a11|b10|b11|c10|c11|d10|d11 | |||||
|a4|a5|b4|b5|c4|c5|d4|d5|a12|a13|b12|b13|c12|c13|d12|d13 | |||||
|a6|a7|b6|b7|c6|c7|d6|d7|a14|a15|b14|b15|c14|c15|d14|d15 | |||||
*/ | |||||
#define BF16_INTERLEAVE_4x16(regArray) \ | |||||
regArray##_4 = _mm256_unpacklo_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_5 = _mm256_unpacklo_epi32(regArray##_2, regArray##_3); \ | |||||
regArray##_6 = _mm256_unpackhi_epi32(regArray##_0, regArray##_1); \ | |||||
regArray##_7 = _mm256_unpackhi_epi32(regArray##_2, regArray##_3); \ | |||||
\ | |||||
regArray##_0 = _mm256_unpacklo_epi64(regArray##_4, regArray##_5); \ | |||||
regArray##_1 = _mm256_unpackhi_epi64(regArray##_4, regArray##_5); \ | |||||
regArray##_2 = _mm256_unpacklo_epi64(regArray##_6, regArray##_7); \ | |||||
regArray##_3 = _mm256_unpackhi_epi64(regArray##_6, regArray##_7); | |||||
/* 2-step interleave for x with 32 BF16 elements | |||||
Input - original vector | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for x: | |||||
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11|x16|x17|x16|x17|x18|x19|x18|x19|x24|x25|x24|x25|x26|x27|x26|x27 | |||||
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15|x20|x21|x20|x21|x22|x23|x22|x23|x28|x29|x28|x29|x30|x31|x30|x31 | |||||
Step 2: 4-element interleave for x: | |||||
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 |x16|x17|x16|x17|x16|x17|x16|x17|x24|x25|x24|x25|x24|x25|x24|x25 | |||||
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11|x18|x19|x18|x19|x18|x19|x18|x19|x26|x27|x26|x27|x26|x27|x26|x27 | |||||
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13|x20|x21|x20|x21|x20|x21|x20|x21|x28|x29|x28|x29|x28|x29|x28|x29 | |||||
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15|x22|x23|x22|x23|x22|x23|x22|x23|x30|x31|x30|x31|x30|x31|x30|x31 | |||||
*/ | |||||
#define BF16_INTERLEAVE_1x32(regArray) \ | |||||
regArray##_1 = _mm512_unpacklo_epi32(regArray##_0, regArray##_0); \ | |||||
regArray##_3 = _mm512_unpackhi_epi32(regArray##_0, regArray##_0); \ | |||||
\ | |||||
regArray##_0 = _mm512_unpacklo_epi64(regArray##_1, regArray##_1); \ | |||||
regArray##_1 = _mm512_unpackhi_epi64(regArray##_1, regArray##_1); \ | |||||
regArray##_2 = _mm512_unpacklo_epi64(regArray##_3, regArray##_3); \ | |||||
regArray##_3 = _mm512_unpackhi_epi64(regArray##_3, regArray##_3); | |||||
/* 2-step interleave for x with 16 BF16 elements | |||||
Input - original vector | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for x: | |||||
|x0|x1|x0|x1|x2|x3|x2|x3|x8 |x9 |x8 |x9 |x10|x11|x10|x11 | |||||
|x4|x5|x4|x5|x6|x7|x6|x7|x12|x13|x12|x13|x14|x15|x14|x15 | |||||
Step 2: 4-element interleave for x: | |||||
|x0|x1|x0|x1|x0|x1|x0|x1|x8 |x9 |x8 |x9 |x8 |x9 |x8 |x9 | |||||
|x2|x3|x2|x3|x2|x3|x2|x3|x10|x11|x10|x11|x10|x11|x10|x11 | |||||
|x4|x5|x4|x5|x4|x5|x4|x5|x12|x13|x12|x13|x12|x13|x12|x13 | |||||
|x6|x7|x6|x7|x6|x7|x6|x7|x14|x15|x14|x15|x14|x15|x14|x15 | |||||
*/ | |||||
#define BF16_INTERLEAVE_1x16(regArray) \ | |||||
regArray##_1 = _mm256_unpacklo_epi32(regArray##_0, regArray##_0); \ | |||||
regArray##_3 = _mm256_unpackhi_epi32(regArray##_0, regArray##_0); \ | |||||
\ | |||||
regArray##_0 = _mm256_unpacklo_epi64(regArray##_1, regArray##_1); \ | |||||
regArray##_1 = _mm256_unpackhi_epi64(regArray##_1, regArray##_1); \ | |||||
regArray##_2 = _mm256_unpacklo_epi64(regArray##_3, regArray##_3); \ | |||||
regArray##_3 = _mm256_unpackhi_epi64(regArray##_3, regArray##_3); | |||||
/* 1-step interleave to exchange the high-256s bit and low-256 bits of 4 pair of registers | |||||
|a0|a1|...|a14|a15|i0|i1|...|i14|i15| | |||||
|b0|b1|...|b14|b15|j0|j1|...|j14|j15| | |||||
|c0|c1|...|c14|c15|k0|k1|...|k14|k15| | |||||
|d0|d1|...|d14|d15|l0|l1|...|l14|l15| | |||||
|e0|e1|...|e14|e15|m0|m1|...|m14|m15| | |||||
|f0|f1|...|f14|f15|n0|n1|...|n14|n15| | |||||
|g0|g1|...|g14|g15|o0|o1|...|o14|o15| | |||||
|h0|h1|...|h14|h15|p0|p1|...|p14|p15| | |||||
*/ | |||||
#define BF16_INTERLEAVE256_8x32(regArray) \ | |||||
regArray##_0 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0x44); \ | |||||
regArray##_1 = _mm512_shuffle_i32x4(regArray##_8, regArray##_12, 0xee); \ | |||||
regArray##_2 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0x44); \ | |||||
regArray##_3 = _mm512_shuffle_i32x4(regArray##_9, regArray##_13, 0xee); \ | |||||
regArray##_4 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0x44); \ | |||||
regArray##_5 = _mm512_shuffle_i32x4(regArray##_10, regArray##_14, 0xee); \ | |||||
regArray##_6 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0x44); \ | |||||
regArray##_7 = _mm512_shuffle_i32x4(regArray##_11, regArray##_15, 0xee); | |||||
/* 1-step interleave to exchange the high-256s bit and low-256 bits of 2 pair of registers | |||||
|a0|a1|...|a14|a15|e0|e1|...|e14|e15| | |||||
|b0|b1|...|b14|b15|f0|f1|...|f14|f15| | |||||
|c0|c1|...|c14|c15|g0|g1|...|g14|g15| | |||||
|d0|d1|...|d14|d15|h0|h1|...|h14|h15| | |||||
*/ | |||||
#define BF16_INTERLEAVE256_4x32(regArray) \ | |||||
regArray##_0 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0x44); \ | |||||
regArray##_1 = _mm512_shuffle_i32x4(regArray##_4, regArray##_6, 0xee); \ | |||||
regArray##_2 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0x44); \ | |||||
regArray##_3 = _mm512_shuffle_i32x4(regArray##_5, regArray##_7, 0xee); | |||||
#define BF16_PERMUTE_8x32(idx, regArray) \ | |||||
regArray##_8 = _mm512_permutexvar_epi16(idx, regArray##_0); \ | |||||
regArray##_9 = _mm512_permutexvar_epi16(idx, regArray##_1); \ | |||||
regArray##_10 = _mm512_permutexvar_epi16(idx, regArray##_2); \ | |||||
regArray##_11 = _mm512_permutexvar_epi16(idx, regArray##_3); \ | |||||
regArray##_12 = _mm512_permutexvar_epi16(idx, regArray##_4); \ | |||||
regArray##_13 = _mm512_permutexvar_epi16(idx, regArray##_5); \ | |||||
regArray##_14 = _mm512_permutexvar_epi16(idx, regArray##_6); \ | |||||
regArray##_15 = _mm512_permutexvar_epi16(idx, regArray##_7); | |||||
#define BF16_PERMUTE_8x32_2(idx, regArray) \ | |||||
regArray##_8 = _mm512_permutexvar_epi32(idx, regArray##_0); \ | |||||
regArray##_9 = _mm512_permutexvar_epi32(idx, regArray##_1); \ | |||||
regArray##_10 = _mm512_permutexvar_epi32(idx, regArray##_2); \ | |||||
regArray##_11 = _mm512_permutexvar_epi32(idx, regArray##_3); \ | |||||
regArray##_12 = _mm512_permutexvar_epi32(idx, regArray##_4); \ | |||||
regArray##_13 = _mm512_permutexvar_epi32(idx, regArray##_5); \ | |||||
regArray##_14 = _mm512_permutexvar_epi32(idx, regArray##_6); \ | |||||
regArray##_15 = _mm512_permutexvar_epi32(idx, regArray##_7); | |||||
#define BF16_PERMUTE_4x32(idx, regArray) \ | |||||
regArray##_4 = _mm512_permutexvar_epi16(idx, regArray##_0); \ | |||||
regArray##_5 = _mm512_permutexvar_epi16(idx, regArray##_1); \ | |||||
regArray##_6 = _mm512_permutexvar_epi16(idx, regArray##_2); \ | |||||
regArray##_7 = _mm512_permutexvar_epi16(idx, regArray##_3); | |||||
#define BF16_PERMUTE_4x32_2(idx, regArray) \ | |||||
regArray##_4 = _mm512_permutexvar_epi32(idx, regArray##_0); \ | |||||
regArray##_5 = _mm512_permutexvar_epi32(idx, regArray##_1); \ | |||||
regArray##_6 = _mm512_permutexvar_epi32(idx, regArray##_2); \ | |||||
regArray##_7 = _mm512_permutexvar_epi32(idx, regArray##_3); | |||||
/* Calculate the dot result for 2-step interleaved matrix and vector | |||||
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_2STEP_INTERLEAVED_DOT_8x32(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_2, (__m512bh) xArray##_0); \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_1); \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_4, (__m512bh) xArray##_2); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_6, (__m512bh) xArray##_2); \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_5, (__m512bh) xArray##_3); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_7, (__m512bh) xArray##_3); | |||||
/* Calculate the dot result for 2-step interleaved matrix and vector | |||||
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_2STEP_INTERLEAVED_DOT_8x16(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_2, (__m256bh) xArray##_0); \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_1); \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_4, (__m256bh) xArray##_2); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_6, (__m256bh) xArray##_2); \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_5, (__m256bh) xArray##_3); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_7, (__m256bh) xArray##_3); | |||||
/* Calculate the dot result for 2-step interleaved matrix and vector | |||||
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_2STEP_INTERLEAVED_DOT_4x32(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray##_0); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray##_1); \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_2, (__m512bh) xArray##_2); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_3, (__m512bh) xArray##_3); | |||||
/* Calculate the dot result for 2-step interleaved matrix and vector | |||||
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_2STEP_INTERLEAVED_DOT_4x16(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray##_0); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray##_1); \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_2, (__m256bh) xArray##_2); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_3, (__m256bh) xArray##_3); | |||||
/* Calculate the dot result for matrix and vector at 32 elements per row | |||||
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_DOT_8x32(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm512_dpbf16_ps(accumArray##_0, (__m512bh) matArray##_0, (__m512bh) xArray); \ | |||||
accumArray##_1 = _mm512_dpbf16_ps(accumArray##_1, (__m512bh) matArray##_1, (__m512bh) xArray); \ | |||||
accumArray##_2 = _mm512_dpbf16_ps(accumArray##_2, (__m512bh) matArray##_2, (__m512bh) xArray); \ | |||||
accumArray##_3 = _mm512_dpbf16_ps(accumArray##_3, (__m512bh) matArray##_3, (__m512bh) xArray); \ | |||||
accumArray##_4 = _mm512_dpbf16_ps(accumArray##_4, (__m512bh) matArray##_4, (__m512bh) xArray); \ | |||||
accumArray##_5 = _mm512_dpbf16_ps(accumArray##_5, (__m512bh) matArray##_5, (__m512bh) xArray); \ | |||||
accumArray##_6 = _mm512_dpbf16_ps(accumArray##_6, (__m512bh) matArray##_6, (__m512bh) xArray); \ | |||||
accumArray##_7 = _mm512_dpbf16_ps(accumArray##_7, (__m512bh) matArray##_7, (__m512bh) xArray); | |||||
/* Calculate the dot result for matrix and vector at 32 elements per row | |||||
(Assume throughput for _mm512_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_DOT_1x32(accumArray, matArray, xArray) \ | |||||
accumArray = _mm512_dpbf16_ps(accumArray, (__m512bh) matArray, (__m512bh) xArray); | |||||
/* Calculate the dot result for matrix and vector at 16 elements per row | |||||
(Assume throughput for _mm256_dpbf16_ps is 0.5, tunable per platform) | |||||
*/ | |||||
#define BF16_DOT_8x16(accumArray, matArray, xArray) \ | |||||
accumArray##_0 = _mm256_dpbf16_ps(accumArray##_0, (__m256bh) matArray##_0, (__m256bh) xArray); \ | |||||
accumArray##_1 = _mm256_dpbf16_ps(accumArray##_1, (__m256bh) matArray##_1, (__m256bh) xArray); \ | |||||
accumArray##_2 = _mm256_dpbf16_ps(accumArray##_2, (__m256bh) matArray##_2, (__m256bh) xArray); \ | |||||
accumArray##_3 = _mm256_dpbf16_ps(accumArray##_3, (__m256bh) matArray##_3, (__m256bh) xArray); \ | |||||
accumArray##_4 = _mm256_dpbf16_ps(accumArray##_4, (__m256bh) matArray##_4, (__m256bh) xArray); \ | |||||
accumArray##_5 = _mm256_dpbf16_ps(accumArray##_5, (__m256bh) matArray##_5, (__m256bh) xArray); \ | |||||
accumArray##_6 = _mm256_dpbf16_ps(accumArray##_6, (__m256bh) matArray##_6, (__m256bh) xArray); \ | |||||
accumArray##_7 = _mm256_dpbf16_ps(accumArray##_7, (__m256bh) matArray##_7, (__m256bh) xArray); | |||||
/* 2-step interleave for matrix against 8 rows with 16 fp32 elements per row | |||||
Input - register array of 8 rows of a row-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|b0|a1|b1|a4|b4|a5|b5|a8 |b8 |a9 |b9 |a12|b12|a13|b13| | |||||
|c0|d0|c1|d1|c4|d4|c5|d5|c8 |d8 |c9 |d9 |c12|d12|c13|d13| | |||||
|e0|f0|e1|f1|e4|f4|e5|f5|e8 |f8 |e9 |f9 |e12|f12|e13|f13| | |||||
|g0|h0|g1|h1|g4|h4|g5|h5|g8 |h8 |g9 |h9 |g12|h12|g13|h13| | |||||
|a2|b2|a3|b3|a6|b6|a7|b7|a10|b10|a11|b11|a14|b14|a15|b15| | |||||
|c2|d2|c3|d3|c6|d6|c7|d7|c10|d10|c11|d11|c14|d14|c15|d15| | |||||
|e2|f2|e3|f3|e6|f6|e7|f7|e10|f10|e11|f11|e14|f14|e15|f15| | |||||
|g2|h2|g3|h3|g6|h6|g7|h7|g10|h10|g11|h11|g14|h14|g15|h15| | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|b0|c0|d0|a4|b4|c4|d4|a8 |b8 |c8 |d8 |a12|b12|c12|d12| | |||||
|a1|b1|c1|d1|a5|b5|c5|d5|a9 |b9 |c9 |d9 |a13|b13|c13|d13| | |||||
|e0|f0|g0|h0|e4|f4|g4|h4|e8 |f8 |g8 |h8 |e12|f12|g12|h12| | |||||
|e1|f1|g1|h1|e5|f5|g5|h5|e9 |f9 |g9 |h9 |e13|f13|g13|h13| | |||||
|a2|b2|c2|d2|a6|b6|c6|d6|a10|b10|c10|d10|a14|b14|c14|d14| | |||||
|a3|b3|c3|d3|a7|b7|c7|d7|a11|b11|c11|d11|a15|b15|c15|d15| | |||||
|e2|f2|g2|h2|e6|f6|g6|h6|e10|f10|g10|h10|e14|f14|g14|h14| | |||||
|e3|f3|g3|h3|e7|f7|g7|h7|e11|f11|g11|h11|e15|f15|g15|h15| | |||||
*/ | |||||
#define FP32_INTERLEAVE_8x16(regArray) \ | |||||
regArray##_8 = _mm512_unpacklo_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_9 = _mm512_unpacklo_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_10 = _mm512_unpacklo_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_11 = _mm512_unpacklo_ps(regArray##_6, regArray##_7); \ | |||||
regArray##_12 = _mm512_unpackhi_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_13 = _mm512_unpackhi_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_14 = _mm512_unpackhi_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_15 = _mm512_unpackhi_ps(regArray##_6, regArray##_7); \ | |||||
\ | |||||
regArray##_0 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \ | |||||
regArray##_1 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_8, (__m512d) regArray##_9); \ | |||||
regArray##_4 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \ | |||||
regArray##_5 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_10, (__m512d) regArray##_11); \ | |||||
regArray##_2 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \ | |||||
regArray##_3 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_12, (__m512d) regArray##_13); \ | |||||
regArray##_6 = (__m512) _mm512_unpacklo_pd((__m512d) regArray##_14, (__m512d) regArray##_15); \ | |||||
regArray##_7 = (__m512) _mm512_unpackhi_pd((__m512d) regArray##_14, (__m512d) regArray##_15); | |||||
#define FP32_INTERLEAVE_8x16_ARRAY(regArray) \ | |||||
regArray[8] = _mm512_unpacklo_ps(regArray[0], regArray[1]); \ | |||||
regArray[9] = _mm512_unpacklo_ps(regArray[2], regArray[3]); \ | |||||
regArray[10] = _mm512_unpacklo_ps(regArray[4], regArray[5]); \ | |||||
regArray[11] = _mm512_unpacklo_ps(regArray[6], regArray[7]); \ | |||||
regArray[12] = _mm512_unpackhi_ps(regArray[0], regArray[1]); \ | |||||
regArray[13] = _mm512_unpackhi_ps(regArray[2], regArray[3]); \ | |||||
regArray[14] = _mm512_unpackhi_ps(regArray[4], regArray[5]); \ | |||||
regArray[15] = _mm512_unpackhi_ps(regArray[6], regArray[7]); \ | |||||
\ | |||||
regArray[0] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[8], (__m512d) regArray[9]); \ | |||||
regArray[1] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[8], (__m512d) regArray[9]); \ | |||||
regArray[4] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[10], (__m512d) regArray[11]); \ | |||||
regArray[5] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[10], (__m512d) regArray[11]); \ | |||||
regArray[2] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[12], (__m512d) regArray[13]); \ | |||||
regArray[3] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[12], (__m512d) regArray[13]); \ | |||||
regArray[6] = (__m512) _mm512_unpacklo_pd((__m512d) regArray[14], (__m512d) regArray[15]); \ | |||||
regArray[7] = (__m512) _mm512_unpackhi_pd((__m512d) regArray[14], (__m512d) regArray[15]); | |||||
/* 2-step interleave for matrix against 8 rows with 8 fp32 elements per row | |||||
Input - register array of 8 rows of a row-major matrix | |||||
Output - the output of Step 2 | |||||
Step 1: 2-element interleave for matrix | |||||
|a0|b0|a1|b1|a4|b4|a5|b5| | |||||
|c0|d0|c1|d1|c4|d4|c5|d5| | |||||
|e0|f0|e1|f1|e4|f4|e5|f5| | |||||
|g0|h0|g1|h1|g4|h4|g5|h5| | |||||
|a2|b2|a3|b3|a6|b6|a7|b7| | |||||
|c2|d2|c3|d3|c6|d6|c7|d7| | |||||
|e2|f2|e3|f3|e6|f6|e7|f7| | |||||
|g2|h2|g3|h3|g6|h6|g7|h7| | |||||
Step 2: 4-element interleave for matrix | |||||
|a0|b0|c0|d0|a4|b4|c4|d4| | |||||
|a1|b1|c1|d1|a5|b5|c5|d5| | |||||
|e0|f0|g0|h0|e4|f4|g4|h4| | |||||
|e1|f1|g1|h1|e5|f5|g5|h5| | |||||
|a2|b2|c2|d2|a6|b6|c6|d6| | |||||
|a3|b3|c3|d3|a7|b7|c7|d7| | |||||
|e2|f2|g2|h2|e6|f6|g6|h6| | |||||
|e3|f3|g3|h3|e7|f7|g7|h7| | |||||
*/ | |||||
#define FP32_INTERLEAVE_8x8(regArray) \ | |||||
regArray##_8 = _mm256_unpacklo_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_9 = _mm256_unpacklo_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_10 = _mm256_unpacklo_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_11 = _mm256_unpacklo_ps(regArray##_6, regArray##_7); \ | |||||
regArray##_12 = _mm256_unpackhi_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_13 = _mm256_unpackhi_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_14 = _mm256_unpackhi_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_15 = _mm256_unpackhi_ps(regArray##_6, regArray##_7); \ | |||||
\ | |||||
regArray##_0 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \ | |||||
regArray##_1 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_8, (__m256d) regArray##_9); \ | |||||
regArray##_4 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \ | |||||
regArray##_5 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_10, (__m256d) regArray##_11); \ | |||||
regArray##_2 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \ | |||||
regArray##_3 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_12, (__m256d) regArray##_13); \ | |||||
regArray##_6 = (__m256) _mm256_unpacklo_pd((__m256d) regArray##_14, (__m256d) regArray##_15); \ | |||||
regArray##_7 = (__m256) _mm256_unpackhi_pd((__m256d) regArray##_14, (__m256d) regArray##_15); | |||||
/* Accumulate the results for 2 batches of 4 registers | |||||
*/ | |||||
#define FP32_ACCUM2_8x16(regArray) \ | |||||
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_2 = _mm512_add_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_6 = _mm512_add_ps(regArray##_6, regArray##_7); \ | |||||
regArray##_0 = _mm512_add_ps(regArray##_0, regArray##_2); \ | |||||
regArray##_4 = _mm512_add_ps(regArray##_4, regArray##_6); | |||||
#define FP32_ACCUM2_8x16_ARRAY(regArray) \ | |||||
regArray[0] = _mm512_add_ps(regArray[0], regArray[1]); \ | |||||
regArray[2] = _mm512_add_ps(regArray[2], regArray[3]); \ | |||||
regArray[4] = _mm512_add_ps(regArray[4], regArray[5]); \ | |||||
regArray[6] = _mm512_add_ps(regArray[6], regArray[7]); \ | |||||
regArray[0] = _mm512_add_ps(regArray[0], regArray[2]); \ | |||||
regArray[4] = _mm512_add_ps(regArray[4], regArray[6]); | |||||
/* Accumulate the results for 2 batches of 4 registers | |||||
*/ | |||||
#define FP32_ACCUM2_8x8(regArray) \ | |||||
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_1); \ | |||||
regArray##_2 = _mm256_add_ps(regArray##_2, regArray##_3); \ | |||||
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_5); \ | |||||
regArray##_6 = _mm256_add_ps(regArray##_6, regArray##_7); \ | |||||
regArray##_0 = _mm256_add_ps(regArray##_0, regArray##_2); \ | |||||
regArray##_4 = _mm256_add_ps(regArray##_4, regArray##_6); | |||||
/* Store 16 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE16_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \ | |||||
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_loadu_ps(targetAddr))); \ | |||||
_mm512_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 16 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \ | |||||
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_mul_ps(BETAVECTOR, _mm512_maskz_loadu_ps(mask, targetAddr))); \ | |||||
_mm512_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 8 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE8_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \ | |||||
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_loadu_ps(targetAddr))); \ | |||||
_mm256_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 8 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \ | |||||
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_mul_ps(_mm512_castps512_ps256(BETAVECTOR), _mm256_maskz_loadu_ps(mask, targetAddr))); \ | |||||
_mm256_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 4 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE4_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr) \ | |||||
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_loadu_ps(targetAddr))); \ | |||||
_mm_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 4 (alpha * result + beta * y) to y | |||||
*/ | |||||
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA(regResult, targetAddr, mask) \ | |||||
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_mul_ps(_mm512_castps512_ps128(BETAVECTOR), _mm_maskz_loadu_ps(mask, targetAddr))); \ | |||||
_mm_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 16 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE16_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \ | |||||
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_loadu_ps(targetAddr)); \ | |||||
_mm512_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 16 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \ | |||||
regResult = _mm512_fmadd_ps(ALPHAVECTOR, regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \ | |||||
_mm512_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 8 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE8_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \ | |||||
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_loadu_ps(targetAddr)); \ | |||||
_mm256_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 8 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \ | |||||
regResult = _mm256_fmadd_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \ | |||||
_mm256_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 4 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE4_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr) \ | |||||
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_loadu_ps(targetAddr)); \ | |||||
_mm_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 4 (alpha * result + y) to y | |||||
*/ | |||||
#define STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE(regResult, targetAddr, mask) \ | |||||
regResult = _mm_fmadd_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \ | |||||
_mm_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 16 (alpha * result) to y | |||||
*/ | |||||
#define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \ | |||||
_mm512_storeu_ps(targetAddr, _mm512_mul_ps(ALPHAVECTOR, regResult)); | |||||
/* Masked store 16 (alpha * result) to y | |||||
*/ | |||||
#define STORE16_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \ | |||||
_mm512_mask_storeu_ps(targetAddr, mask, _mm512_mul_ps(ALPHAVECTOR, regResult)); | |||||
/* Store 8 (alpha * result) to y | |||||
*/ | |||||
#define STORE8_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \ | |||||
_mm256_storeu_ps(targetAddr, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult)); | |||||
/* Masked store 8 (alpha * result) to y | |||||
*/ | |||||
#define STORE8_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \ | |||||
_mm256_mask_storeu_ps(targetAddr, mask, _mm256_mul_ps(_mm512_castps512_ps256(ALPHAVECTOR), regResult)); | |||||
/* Store 4 (alpha * result) to y | |||||
*/ | |||||
#define STORE4_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \ | |||||
_mm_storeu_ps(targetAddr, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult)); | |||||
/* Masked store 4 (alpha * result) to y | |||||
*/ | |||||
#define STORE4_MASK_COMPLETE_RESULT_ALPHA(regResult, targetAddr, mask) \ | |||||
_mm_mask_storeu_ps(targetAddr, mask, _mm_mul_ps(_mm512_castps512_ps128(ALPHAVECTOR), regResult)); | |||||
/* Store 16 result to y | |||||
*/ | |||||
#define STORE16_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \ | |||||
_mm512_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 16 result to y | |||||
*/ | |||||
#define STORE16_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \ | |||||
_mm512_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 8 result to y | |||||
*/ | |||||
#define STORE8_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \ | |||||
_mm256_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 8 result to y | |||||
*/ | |||||
#define STORE8_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \ | |||||
_mm256_mask_storeu_ps(targetAddr, mask, regResult); | |||||
/* Store 4 result to y | |||||
*/ | |||||
#define STORE4_COMPLETE_RESULT_DIRECT(regResult, targetAddr) \ | |||||
_mm_storeu_ps(targetAddr, regResult); | |||||
/* Masked store 4 result to y | |||||
*/ | |||||
#define STORE4_MASK_COMPLETE_RESULT_DIRECT(regResult, targetAddr, mask) \ | |||||
_mm_mask_storeu_ps(targetAddr, mask, regResult); | |||||
#endif |
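
For reference, every alpha/beta store macro above reduces to the same per-element update, y[i] = alpha * acc[i] + beta * y[i], with beta fixed to one in the *_ALPHA_ONE variants, dropped in the *_ALPHA variants, and alpha dropped as well in the *_DIRECT variants. A minimal scalar sketch of that update, handy as a cross-check for the vectorized stores (the helper name below is illustrative, not part of this header):

#include <stddef.h>

/* Scalar equivalent of the STORE*_COMPLETE_RESULT_ALPHA_BETA macros, applied
 * element-wise to a block of accumulated results. */
static void store_result_alpha_beta_scalar(size_t n, const float *acc,
                                           float alpha, float beta, float *y)
{
    for (size_t i = 0; i < n; i++)
        y[i] = alpha * acc[i] + beta * y[i];
}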
@@ -0,0 +1,137 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include "common.h" | |||||
#if defined (COOPERLAKE) | |||||
#include "sbgemv_n_microk_cooperlake.c" | |||||
#endif | |||||
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \ | |||||
ptr = (TYPE *) malloc(sizeof(TYPE)*alloc_size + 63); \ | |||||
ptr_align = ((int)(((uintptr_t)ptr & (uintptr_t)0x3F))!=0) ? (TYPE *)((char *)ptr + (64 - (int)((uintptr_t)ptr & (uintptr_t)0x3F))) : ptr | |||||
#define ALIGN64_FREE(ptr) \ | |||||
free(ptr) | |||||
#ifndef HAVE_SBGEMV_N_ACCL_KERNEL | |||||
static void sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
{ | |||||
BLASLONG offset_lda, offset_m; | |||||
float accum = 0.0; | |||||
bfloat16 * a_bf16 = malloc(sizeof(bfloat16)*m*n); | |||||
float * a_fp32 = malloc(sizeof(float)*m*n); | |||||
float * x_fp32 = malloc(sizeof(float)*n); | |||||
for (BLASLONG j=0; j<n; j++) { | |||||
offset_lda = lda * j; | |||||
offset_m = m * j; | |||||
for (BLASLONG i=0; i<m; i++) { | |||||
a_bf16[offset_m + i] = a[offset_lda + i]; | |||||
} | |||||
} | |||||
SBF16TOS_K(n, x, 1, x_fp32, 1); | |||||
SBF16TOS_K(m*n, a_bf16, 1, a_fp32, 1); | |||||
for (BLASLONG i=0; i<m; i++) { | |||||
accum = 0.0; | |||||
for (BLASLONG j=0; j<n; j++) { | |||||
accum += a_fp32[j*m + i] * x_fp32[j]; | |||||
} | |||||
if (beta == ZERO) { | |||||
y[i] = alpha * accum; | |||||
} else { | |||||
y[i] = alpha * accum + beta * y[i]; | |||||
} | |||||
} | |||||
free(a_bf16); | |||||
free(a_fp32); | |||||
free(x_fp32); | |||||
} | |||||
#endif | |||||
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i] = src[i*inc]; | |||||
} | |||||
} | |||||
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i] = src[i*inc]; | |||||
} | |||||
} | |||||
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i*inc] = src[i]; | |||||
} | |||||
} | |||||
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy) | |||||
{ | |||||
if ( m < 1 || n < 1) return(0); | |||||
bfloat16 * xbuffer_align = x; | |||||
float * ybuffer_align = y; | |||||
bfloat16 * xbuffer = NULL; | |||||
float * ybuffer = NULL; | |||||
if (incx != 1) { | |||||
ALIGN64_ALLOC(n, bfloat16, xbuffer_align, xbuffer); | |||||
bf16_compress_vector(n, x, xbuffer_align, incx); | |||||
} | |||||
if (incy != 1) { | |||||
ALIGN64_ALLOC(m, float, ybuffer_align, ybuffer); | |||||
if (beta != ZERO) { | |||||
fp32_compress_vector(m, y, ybuffer_align, incy); | |||||
} | |||||
} | |||||
sbgemv_kernel_n(m, n, alpha, a, lda, xbuffer_align, beta, ybuffer_align); | |||||
if (incy != 1) { | |||||
fp32_expand_vector(m, ybuffer_align, y, incy); | |||||
ALIGN64_FREE(ybuffer); | |||||
} | |||||
if (incx != 1) { | |||||
ALIGN64_FREE(xbuffer); | |||||
} | |||||
return(0); | |||||
} |
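
The generic fallback above computes the non-transposed product over a column-major matrix (lda separates columns). A compact float32 reference of the same computation, which can be used to validate the bfloat16 kernels once their inputs are converted to float (the function name is illustrative and not part of this file):

/* y[i] = alpha * sum_j A[j*lda + i] * x[j] + beta * y[i], A column-major, lda >= m */
static void sgemv_n_reference(long m, long n, float alpha, const float *a, long lda,
                              const float *x, float beta, float *y)
{
    for (long i = 0; i < m; i++) {
        float acc = 0.0f;
        for (long j = 0; j < n; j++)
            acc += a[j * lda + i] * x[j];
        y[i] = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * y[i];
    }
}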
@@ -0,0 +1,76 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
/* need a new enough GCC for avx512 support */ | |||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||||
#define HAVE_SBGEMV_N_ACCL_KERNEL 1 | |||||
#include "common.h" | |||||
#include <immintrin.h> | |||||
// Define micro kernels for ALPHA not ONE && BETA non-zero && BETA not ONE scenarios | |||||
#undef ZERO_BETA | |||||
#undef ONE_BETA | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_n_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA not ONE && BETA as ONE scenarios | |||||
#undef ZERO_BETA | |||||
#define ONE_BETA 1 | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_n_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA not ONE && BETA ineffective (BETA == 0) scenarios | |||||
#define ZERO_BETA 1 | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_n_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA as ONE && BETA ineffective (BETA == 0) scenarios | |||||
#define ZERO_BETA 1 | |||||
#define ONE_ALPHA 1 | |||||
#include "sbgemv_n_microk_cooperlake_template.c" | |||||
static int sbgemv_kernel_n(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
{ | |||||
if (beta == ZERO) { // BETA == 0.0, no need to accumulate the original Y data | |||||
if (alpha == ONE) { // ALPHA == 1.0, no need to multiply by ALPHA | |||||
sbgemv_kernel_32xN_lda_direct(m, n, alpha, a, lda, x, y); | |||||
} else { // ALPHA != 1.0, need to multiply by ALPHA | |||||
sbgemv_kernel_32xN_lda_direct_alpha(m, n, alpha, a, lda, x, y); | |||||
} | |||||
} else { // BETA != 0.0, need to accumulate the original Y data no matter what ALPHA is | |||||
if (beta == ONE) { | |||||
sbgemv_kernel_32xN_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y); | |||||
} else { | |||||
sbgemv_kernel_32xN_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y); | |||||
} | |||||
} | |||||
return 0; | |||||
} | |||||
#endif |
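
The four #include lines above stamp out one kernel variant per ALPHA/BETA combination by re-including the same template with different ZERO_BETA/ONE_BETA/ONE_ALPHA settings. The sketch below shows the same specialization idea compressed into one self-contained file, using a macro generator instead of repeated #include; all names are hypothetical and only illustrate the pattern:

#include <stdio.h>

/* Generate one specialized function per compile-time BETA handling choice. */
#define DEFINE_AXPBY(NAME, USE_BETA)                                \
    static void NAME(int n, float alpha, const float *x,            \
                     float beta, float *y)                          \
    {                                                               \
        for (int i = 0; i < n; i++)                                 \
            y[i] = (USE_BETA) ? alpha * x[i] + beta * y[i]          \
                              : alpha * x[i];                       \
    }

DEFINE_AXPBY(axpby_beta,      1)   /* BETA != 0 variant */
DEFINE_AXPBY(axpby_zero_beta, 0)   /* BETA == 0 variant */

int main(void)
{
    float x[4] = {1, 2, 3, 4}, y[4] = {10, 10, 10, 10};
    axpby_zero_beta(4, 2.0f, x, 0.0f, y);   /* y = 2*x       */
    axpby_beta(4, 1.0f, x, 0.5f, y);        /* y = x + 0.5*y */
    for (int i = 0; i < 4; i++) printf("%g\n", y[i]);
    return 0;
}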
@@ -0,0 +1,234 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include <immintrin.h> | |||||
#include "common.h" | |||||
// Include common macros for BF16 based operations with IA intrinsics | |||||
#include "bf16_common_macros.h" | |||||
#ifndef ZERO_BETA // Beta is non-zero | |||||
#ifndef ONE_BETA // BETA is not ONE | |||||
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_BETA | |||||
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA | |||||
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_BETA | |||||
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA | |||||
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_BETA | |||||
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA | |||||
#else // BETA is ONE | |||||
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE | |||||
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE | |||||
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_ONE | |||||
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE | |||||
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_ONE | |||||
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE | |||||
#endif | |||||
#else // BETA is zero | |||||
#ifndef ONE_ALPHA // ALPHA is not ONE | |||||
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA | |||||
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA | |||||
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA | |||||
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA | |||||
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA | |||||
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA | |||||
#else // ALPHA is ONE | |||||
#define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_DIRECT | |||||
#define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT | |||||
#define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_DIRECT | |||||
#define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_DIRECT | |||||
#define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_DIRECT | |||||
#define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_DIRECT | |||||
#endif | |||||
#endif | |||||
// BF16 GEMV kernel for the big-N && effective-lda scenario, processing 32 output rows in parallel (process before interleave) | |||||
#ifndef ZERO_BETA | |||||
#ifndef ONE_BETA | |||||
static int sbgemv_kernel_32xN_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
#else | |||||
static int sbgemv_kernel_32xN_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
#endif | |||||
#else | |||||
#ifndef ONE_ALPHA | |||||
static int sbgemv_kernel_32xN_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y) | |||||
#else | |||||
static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y) | |||||
#endif | |||||
#endif | |||||
{ | |||||
BLASLONG tag_m_32x = m & (~31); | |||||
BLASLONG tag_m_128x = m & (~127); | |||||
__m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \ | |||||
accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15; | |||||
#ifndef ONE_ALPHA | |||||
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha); | |||||
#endif | |||||
#ifndef ZERO_BETA | |||||
__m512 BETAVECTOR = _mm512_set1_ps(beta); | |||||
#endif | |||||
__m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3; | |||||
__m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7; | |||||
__m512i xArray_0; | |||||
__m512i ZERO512 = _mm512_setzero_si512(); | |||||
unsigned int blend_hi_mask_value = ((unsigned int)0xaaaaaaaa); | |||||
__mmask32 blend_hi_mask = *((__mmask32*) &blend_hi_mask_value); | |||||
unsigned int blend_lo_mask_value = ((unsigned int)0x55555555); | |||||
__mmask32 blend_lo_mask = *((__mmask32*) &blend_lo_mask_value); | |||||
__m512i M512_EPI32_8 = _mm512_set1_epi32(8); | |||||
__m512i idx_base_0 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); | |||||
__m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_8); | |||||
for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m+=128) { | |||||
accum512_0 = _mm512_setzero_ps(); | |||||
accum512_1 = _mm512_setzero_ps(); | |||||
accum512_2 = _mm512_setzero_ps(); | |||||
accum512_3 = _mm512_setzero_ps(); | |||||
accum512_4 = _mm512_setzero_ps(); | |||||
accum512_5 = _mm512_setzero_ps(); | |||||
accum512_6 = _mm512_setzero_ps(); | |||||
accum512_7 = _mm512_setzero_ps(); | |||||
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) { | |||||
xArray_0 = _mm512_set1_epi16(x[idx_n]); | |||||
BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m + 0) | |||||
BF16_MATRIX_LOAD_1x32(matrixArray_seed_1, a, lda, idx_n, idx_m + 32) | |||||
BF16_MATRIX_LOAD_1x32(matrixArray_seed_2, a, lda, idx_n, idx_m + 64) | |||||
BF16_MATRIX_LOAD_1x32(matrixArray_seed_3, a, lda, idx_n, idx_m + 96) | |||||
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0); | |||||
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0); | |||||
matrixArray_2 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_1); | |||||
matrixArray_3 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_1); | |||||
matrixArray_4 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_2); | |||||
matrixArray_5 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_2); | |||||
matrixArray_6 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_3); | |||||
matrixArray_7 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_3); | |||||
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0) | |||||
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0) | |||||
BF16_DOT_1x32(accum512_2, matrixArray_2, xArray_0) | |||||
BF16_DOT_1x32(accum512_3, matrixArray_3, xArray_0) | |||||
BF16_DOT_1x32(accum512_4, matrixArray_4, xArray_0) | |||||
BF16_DOT_1x32(accum512_5, matrixArray_5, xArray_0) | |||||
BF16_DOT_1x32(accum512_6, matrixArray_6, xArray_0) | |||||
BF16_DOT_1x32(accum512_7, matrixArray_7, xArray_0) | |||||
} | |||||
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1); | |||||
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1); | |||||
accum512_10 = _mm512_permutex2var_ps(accum512_2, idx_base_0, accum512_3); | |||||
accum512_11 = _mm512_permutex2var_ps(accum512_2, idx_base_1, accum512_3); | |||||
accum512_12 = _mm512_permutex2var_ps(accum512_4, idx_base_0, accum512_5); | |||||
accum512_13 = _mm512_permutex2var_ps(accum512_4, idx_base_1, accum512_5); | |||||
accum512_14 = _mm512_permutex2var_ps(accum512_6, idx_base_0, accum512_7); | |||||
accum512_15 = _mm512_permutex2var_ps(accum512_6, idx_base_1, accum512_7); | |||||
STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0) | |||||
STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16) | |||||
STORE16_COMPLETE_RESULT(accum512_10, y+idx_m+32) | |||||
STORE16_COMPLETE_RESULT(accum512_11, y+idx_m+48) | |||||
STORE16_COMPLETE_RESULT(accum512_12, y+idx_m+64) | |||||
STORE16_COMPLETE_RESULT(accum512_13, y+idx_m+80) | |||||
STORE16_COMPLETE_RESULT(accum512_14, y+idx_m+96) | |||||
STORE16_COMPLETE_RESULT(accum512_15, y+idx_m+112) | |||||
} | |||||
for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_32x; idx_m+=32) { | |||||
accum512_0 = _mm512_setzero_ps(); | |||||
accum512_1 = _mm512_setzero_ps(); | |||||
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) { | |||||
xArray_0 = _mm512_set1_epi16(x[idx_n]); | |||||
BF16_MATRIX_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, idx_m) | |||||
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0); | |||||
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0); | |||||
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0) | |||||
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0) | |||||
} | |||||
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1); | |||||
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1); | |||||
STORE16_COMPLETE_RESULT(accum512_8, y+idx_m+0) | |||||
STORE16_COMPLETE_RESULT(accum512_9, y+idx_m+16) | |||||
} | |||||
if (tag_m_32x != m) { | |||||
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31))); | |||||
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value); | |||||
// Mask the shift count with 15 so a tail of exactly 16 elements yields the full 0xffff store mask instead of an empty one | |||||
unsigned short store_tail_mask_value = (unsigned short)(((unsigned int)0xffff) >> ((16-(m&15)) & 15)); | |||||
__mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value); | |||||
accum512_0 = _mm512_setzero_ps(); | |||||
accum512_1 = _mm512_setzero_ps(); | |||||
for (BLASLONG idx_n = 0; idx_n < n; idx_n++) { | |||||
xArray_0 = _mm512_set1_epi16(x[idx_n]); | |||||
BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_seed_0, a, lda, idx_n, tag_m_32x, tail_mask) | |||||
matrixArray_0 = _mm512_mask_blend_epi16(blend_lo_mask, ZERO512, matrixArray_seed_0); | |||||
matrixArray_1 = _mm512_mask_blend_epi16(blend_hi_mask, ZERO512, matrixArray_seed_0); | |||||
BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0) | |||||
BF16_DOT_1x32(accum512_1, matrixArray_1, xArray_0) | |||||
} | |||||
accum512_8 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1); | |||||
accum512_9 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1); | |||||
if ((m-tag_m_32x) > 16) { | |||||
STORE16_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0) | |||||
STORE16_MASK_COMPLETE_RESULT(accum512_9, y+tag_m_32x+16, store_tail_mask) | |||||
} else { | |||||
STORE16_MASK_COMPLETE_RESULT(accum512_8, y+tag_m_32x+0, store_tail_mask) | |||||
} | |||||
} | |||||
return 0; | |||||
} |
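
The tail block above builds a bit mask with the low (m & 31) bits set and finishes the last partial tile with masked loads and stores, so no element past row m is touched. A stripped-down sketch of the same masking technique on a plain float vector (AVX512F only; the function is illustrative, not part of this template):

#include <immintrin.h>

/* Scale the trailing (m % 16) floats of y by alpha using a masked
 * load/store pair, never touching memory past index m-1. */
static void scale_tail_avx512(float *y, long m, float alpha)
{
    long tail = m & 15;
    if (tail == 0) return;
    __mmask16 mask = (__mmask16)((1u << tail) - 1);      /* low 'tail' bits set */
    __m512 v = _mm512_maskz_loadu_ps(mask, y + (m - tail));
    _mm512_mask_storeu_ps(y + (m - tail), mask, _mm512_mul_ps(_mm512_set1_ps(alpha), v));
}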
@@ -0,0 +1,142 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
#include "common.h" | |||||
#if defined (COOPERLAKE) | |||||
#include "sbgemv_t_microk_cooperlake.c" | |||||
#endif | |||||
#define ALIGN64_ALLOC(alloc_size, TYPE, ptr_align, ptr) \ | |||||
ptr = (TYPE *) malloc(sizeof(TYPE)*alloc_size + 63); \ | |||||
ptr_align = ((int)(((uintptr_t)ptr & (uintptr_t)0x3F))!=0) ? (TYPE *)((char *)ptr + (64 - (int)((uintptr_t)ptr & (uintptr_t)0x3F))) : ptr | |||||
#define ALIGN64_FREE(ptr) \ | |||||
free(ptr) | |||||
#ifndef HAVE_SBGEMV_T_ACCL_KERNEL | |||||
static void sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
{ | |||||
BLASLONG offset_lda, offset_n; | |||||
float accum = 0.0; | |||||
bfloat16 * a_bf16 = malloc(sizeof(bfloat16)*m*n); | |||||
float * a_fp32 = malloc(sizeof(float)*m*n); | |||||
float * x_fp32 = malloc(sizeof(float)*n); | |||||
for (BLASLONG i=0; i<m; i++) { | |||||
offset_lda = lda * i; | |||||
offset_n = n * i; | |||||
for (BLASLONG j=0; j<n; j++) { | |||||
a_bf16[offset_n + j] = a[offset_lda + j]; | |||||
} | |||||
} | |||||
SBF16TOS_K(n, x, 1, x_fp32, 1); | |||||
SBF16TOS_K(m*n, a_bf16, 1, a_fp32, 1); | |||||
for (BLASLONG i=0; i<m; i++) { | |||||
offset_n = n * i; | |||||
accum = 0.0; | |||||
for (BLASLONG j=0; j<n; j++) { | |||||
accum += a_fp32[offset_n + j] * x_fp32[j]; | |||||
} | |||||
if (beta == ZERO) { | |||||
y[i] = alpha * accum; | |||||
} else { | |||||
y[i] = alpha * accum + beta * y[i]; | |||||
} | |||||
} | |||||
free(a_bf16); | |||||
free(a_fp32); | |||||
free(x_fp32); | |||||
} | |||||
#endif | |||||
static void bf16_compress_vector(BLASLONG n, bfloat16 * src, bfloat16 * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i] = src[i*inc]; | |||||
} | |||||
} | |||||
static void fp32_compress_vector(BLASLONG n, float * src, float * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i] = src[i*inc]; | |||||
} | |||||
} | |||||
static void fp32_expand_vector(BLASLONG n, float * src, float * target, BLASLONG inc) | |||||
{ | |||||
for(BLASLONG i=0; i<n; i++) { | |||||
target[i*inc] = src[i]; | |||||
} | |||||
} | |||||
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float * y, BLASLONG incy) | |||||
{ | |||||
if ( m < 1 || n < 1) return(0); | |||||
bfloat16 * xbuffer_align = x; | |||||
float * ybuffer_align = y; | |||||
bfloat16 * xbuffer = NULL; | |||||
float * ybuffer = NULL; | |||||
// Switch m and n | |||||
BLASLONG t = m; | |||||
m = n; | |||||
n = t; | |||||
if (incx != 1) { | |||||
ALIGN64_ALLOC(n, bfloat16, xbuffer_align, xbuffer); | |||||
bf16_compress_vector(n, x, xbuffer_align, incx); | |||||
} | |||||
if (incy != 1) { | |||||
ALIGN64_ALLOC(m, float, ybuffer_align, ybuffer); | |||||
if (beta != ZERO) { | |||||
fp32_compress_vector(m, y, ybuffer_align, incy); | |||||
} | |||||
} | |||||
sbgemv_kernel_t(m, n, alpha, a, lda, xbuffer_align, beta, ybuffer_align); | |||||
if (incy != 1) { | |||||
fp32_expand_vector(m, ybuffer_align, y, incy); | |||||
ALIGN64_FREE(ybuffer); | |||||
} | |||||
if (incx != 1) { | |||||
ALIGN64_FREE(xbuffer); | |||||
} | |||||
return(0); | |||||
} |
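
Both interface files rely on the ALIGN64_ALLOC/ALIGN64_FREE pair defined above: over-allocate by 63 bytes, round the pointer up to the next 64-byte boundary, and keep the raw pointer for free(). A standalone sketch of the same arithmetic (the helper name is hypothetical; the macros above are what this file actually uses):

#include <stdint.h>
#include <stdlib.h>

/* Return a 64-byte-aligned pointer into an over-allocated buffer and hand the
 * raw malloc() pointer back through raw_out so it can be freed later. */
static void *alloc_aligned64(size_t bytes, void **raw_out)
{
    void *raw = malloc(bytes + 63);
    if (raw == NULL) { *raw_out = NULL; return NULL; }
    uintptr_t aligned = ((uintptr_t)raw + 63) & ~(uintptr_t)0x3F;
    *raw_out = raw;
    return (void *)aligned;
}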
@@ -0,0 +1,202 @@ | |||||
/*************************************************************************** | |||||
Copyright (c) 2014, The OpenBLAS Project | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
1. Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
2. Redistributions in binary form must reproduce the above copyright | |||||
notice, this list of conditions and the following disclaimer in | |||||
the documentation and/or other materials provided with the | |||||
distribution. | |||||
3. Neither the name of the OpenBLAS project nor the names of | |||||
its contributors may be used to endorse or promote products | |||||
derived from this software without specific prior written permission. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
*****************************************************************************/ | |||||
/* need a new enough GCC for avx512 support */ | |||||
#if (( defined(__GNUC__) && __GNUC__ >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9)) | |||||
#define HAVE_SBGEMV_T_ACCL_KERNEL 1 | |||||
// Define micro kernels for ALPHA not ONE && BETA non-zero && BETA not ONE scenarios | |||||
#undef ZERO_BETA | |||||
#undef ONE_BETA | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_t_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA not ONE && BETA as ONE scenarios | |||||
#undef ZERO_BETA | |||||
#define ONE_BETA 1 | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_t_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA not ONE && BETA ineffective (BETA == 0) scenarios | |||||
#define ZERO_BETA 1 | |||||
#undef ONE_ALPHA | |||||
#include "sbgemv_t_microk_cooperlake_template.c" | |||||
// Define micro kernels for ALPHA as ONE && BETA ineffective (BETA == 0) scenarios | |||||
#define ZERO_BETA 1 | |||||
#define ONE_ALPHA 1 | |||||
#include "sbgemv_t_microk_cooperlake_template.c" | |||||
static int sbgemv_kernel_t(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y) | |||||
{ | |||||
if (beta == ZERO) { // BETA == 0.0, no need to accumulate the original Y data | |||||
if (alpha == ONE) { // ALPHA == 1.0, no need to multiply by ALPHA | |||||
if (n > 127) { | |||||
sbgemv_kernel_1x128_lda_direct(m, n, alpha, a, lda, x, y); | |||||
} else if (n > 32) { | |||||
sbgemv_kernel_8x32_lda_direct(m, n, alpha, a, lda, x, y); | |||||
} else { | |||||
if (n > 16) { | |||||
sbgemv_kernel_8x16p_lda(m, n, alpha, a, lda, x, y); | |||||
} else { | |||||
if (lda == n) { | |||||
switch(n) { | |||||
case 1: sbgemv_kernel_32x1 (m, alpha, a, x, y); break; | |||||
case 2: sbgemv_kernel_32x2 (m, alpha, a, x, y); break; | |||||
case 3: sbgemv_kernel_32x3 (m, alpha, a, x, y); break; | |||||
case 4: sbgemv_kernel_16x4 (m, alpha, a, x, y); break; | |||||
case 5: sbgemv_kernel_30x5 (m, alpha, a, x, y); break; | |||||
case 6: sbgemv_kernel_16x6 (m, alpha, a, x, y); break; | |||||
case 7: sbgemv_kernel_16x7 (m, alpha, a, x, y); break; | |||||
case 8: sbgemv_kernel_16x8 (m, alpha, a, x, y); break; | |||||
case 9: sbgemv_kernel_14x9 (m, alpha, a, x, y); break; | |||||
case 10: sbgemv_kernel_12x10(m, alpha, a, x, y); break; | |||||
case 11: sbgemv_kernel_15x11(m, alpha, a, x, y); break; | |||||
case 12: sbgemv_kernel_15x12(m, alpha, a, x, y); break; | |||||
case 13: sbgemv_kernel_16x13(m, alpha, a, x, y); break; | |||||
case 14: sbgemv_kernel_16x14(m, alpha, a, x, y); break; | |||||
case 15: sbgemv_kernel_16x15(m, alpha, a, x, y); break; | |||||
case 16: sbgemv_kernel_16x16(m, alpha, a, x, y); break; | |||||
default: break; | |||||
} | |||||
} else { | |||||
sbgemv_kernel_8x16m_lda(m, n, alpha, a, lda, x, y); | |||||
} | |||||
} | |||||
} | |||||
} else { // ALPHA != 1.0, need to multiply by ALPHA | |||||
if (n > 127) { | |||||
sbgemv_kernel_1x128_lda_direct_alpha(m, n, alpha, a, lda, x, y); | |||||
} else if (n > 32) { | |||||
sbgemv_kernel_8x32_lda_direct_alpha(m, n, alpha, a, lda, x, y); | |||||
} else { | |||||
if (n > 16) { | |||||
sbgemv_kernel_8x16p_lda_alpha(m, n, alpha, a, lda, x, y); | |||||
} else { | |||||
if (lda == n) { | |||||
switch(n) { | |||||
case 1: sbgemv_kernel_32x1_alpha (m, alpha, a, x, y); break; | |||||
case 2: sbgemv_kernel_32x2_alpha (m, alpha, a, x, y); break; | |||||
case 3: sbgemv_kernel_32x3_alpha (m, alpha, a, x, y); break; | |||||
case 4: sbgemv_kernel_16x4_alpha (m, alpha, a, x, y); break; | |||||
case 5: sbgemv_kernel_30x5_alpha (m, alpha, a, x, y); break; | |||||
case 6: sbgemv_kernel_16x6_alpha (m, alpha, a, x, y); break; | |||||
case 7: sbgemv_kernel_16x7_alpha (m, alpha, a, x, y); break; | |||||
case 8: sbgemv_kernel_16x8_alpha (m, alpha, a, x, y); break; | |||||
case 9: sbgemv_kernel_14x9_alpha (m, alpha, a, x, y); break; | |||||
case 10: sbgemv_kernel_12x10_alpha(m, alpha, a, x, y); break; | |||||
case 11: sbgemv_kernel_15x11_alpha(m, alpha, a, x, y); break; | |||||
case 12: sbgemv_kernel_15x12_alpha(m, alpha, a, x, y); break; | |||||
case 13: sbgemv_kernel_16x13_alpha(m, alpha, a, x, y); break; | |||||
case 14: sbgemv_kernel_16x14_alpha(m, alpha, a, x, y); break; | |||||
case 15: sbgemv_kernel_16x15_alpha(m, alpha, a, x, y); break; | |||||
case 16: sbgemv_kernel_16x16_alpha(m, alpha, a, x, y); break; | |||||
default: break; | |||||
} | |||||
} else { | |||||
sbgemv_kernel_8x16m_lda_alpha(m, n, alpha, a, lda, x, y); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} else { // BETA != 0.0, need to accumulate the original Y data no matter what ALPHA is | |||||
if (beta == ONE) { | |||||
if (n > 127) { | |||||
sbgemv_kernel_1x128_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y); | |||||
} else if (n > 32) { | |||||
sbgemv_kernel_8x32_lda_direct_alpha_one(m, n, alpha, a, lda, x, beta, y); | |||||
} else { | |||||
if (n > 16) { | |||||
sbgemv_kernel_8x16p_lda_alpha_one(m, n, alpha, a, lda, x, beta, y); | |||||
} else { | |||||
if (lda == n) { | |||||
switch(n) { | |||||
case 1: sbgemv_kernel_32x1_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 2: sbgemv_kernel_32x2_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 3: sbgemv_kernel_32x3_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 4: sbgemv_kernel_16x4_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 5: sbgemv_kernel_30x5_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 6: sbgemv_kernel_16x6_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 7: sbgemv_kernel_16x7_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 8: sbgemv_kernel_16x8_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 9: sbgemv_kernel_14x9_alpha_one (m, alpha, a, x, beta, y); break; | |||||
case 10: sbgemv_kernel_12x10_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 11: sbgemv_kernel_15x11_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 12: sbgemv_kernel_15x12_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 13: sbgemv_kernel_16x13_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 14: sbgemv_kernel_16x14_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 15: sbgemv_kernel_16x15_alpha_one(m, alpha, a, x, beta, y); break; | |||||
case 16: sbgemv_kernel_16x16_alpha_one(m, alpha, a, x, beta, y); break; | |||||
default: break; | |||||
} | |||||
} else { | |||||
sbgemv_kernel_8x16m_lda_alpha_one(m, n, alpha, a, lda, x, beta, y); | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
if (n > 127) { | |||||
sbgemv_kernel_1x128_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y); | |||||
} else if (n > 32) { | |||||
sbgemv_kernel_8x32_lda_direct_alpha_beta(m, n, alpha, a, lda, x, beta, y); | |||||
} else { | |||||
if (n > 16) { | |||||
sbgemv_kernel_8x16p_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y); | |||||
} else { | |||||
if (lda == n) { | |||||
switch(n) { | |||||
case 1: sbgemv_kernel_32x1_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 2: sbgemv_kernel_32x2_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 3: sbgemv_kernel_32x3_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 4: sbgemv_kernel_16x4_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 5: sbgemv_kernel_30x5_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 6: sbgemv_kernel_16x6_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 7: sbgemv_kernel_16x7_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 8: sbgemv_kernel_16x8_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 9: sbgemv_kernel_14x9_alpha_beta (m, alpha, a, x, beta, y); break; | |||||
case 10: sbgemv_kernel_12x10_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 11: sbgemv_kernel_15x11_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 12: sbgemv_kernel_15x12_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 13: sbgemv_kernel_16x13_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 14: sbgemv_kernel_16x14_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 15: sbgemv_kernel_16x15_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
case 16: sbgemv_kernel_16x16_alpha_beta(m, alpha, a, x, beta, y); break; | |||||
default: break; | |||||
} | |||||
} else { | |||||
sbgemv_kernel_8x16m_lda_alpha_beta(m, n, alpha, a, lda, x, beta, y); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return 0; | |||||
} | |||||
#endif |