Add GEMM optimization for small matrix and single/double kernel for skylakextags/v0.3.18
@@ -244,6 +244,14 @@ else | |||
ONLY_CBLAS = 0 | |||
endif | |||
#For small matrix optimization | |||
ifeq ($(ARCH), x86_64) | |||
SMALL_MATRIX_OPT = 1 | |||
endif | |||
ifeq ($(SMALL_MATRIX_OPT), 1) | |||
CCOMMON_OPT += -DSMALL_MATRIX_OPT | |||
endif | |||
# This operation is expensive, so execution should be once. | |||
ifndef GOTOBLAS_MAKEFILE | |||
export GOTOBLAS_MAKEFILE = 1 | |||
@@ -258,6 +258,13 @@ if (NEED_PIC) | |||
endif() | |||
endif () | |||
if (X86_64) | |||
set(SMALL_MATRIX_OPT TRUE) | |||
endif () | |||
if (SMALL_MATRIX_OPT) | |||
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") | |||
endif () | |||
if (DYNAMIC_ARCH) | |||
if (X86 OR X86_64 OR ARM64 OR PPC) | |||
set(CCOMMON_OPT "${CCOMMON_OPT} -DDYNAMIC_ARCH") | |||
@@ -232,6 +232,8 @@ | |||
#define CGEADD_K cgeadd_k | |||
#define CGEMM_SMALL_MATRIX_PERMIT cgemm_small_matrix_permit | |||
#else | |||
#define CAMAX_K gotoblas -> camax_k | |||
@@ -426,8 +428,51 @@ | |||
#define CGEADD_K gotoblas -> cgeadd_k | |||
#define CGEMM_SMALL_MATRIX_PERMIT gotoblas -> cgemm_small_matrix_permit | |||
#endif | |||
#define CGEMM_SMALL_KERNEL_NN FUNC_OFFSET(cgemm_small_kernel_nn) | |||
#define CGEMM_SMALL_KERNEL_NT FUNC_OFFSET(cgemm_small_kernel_nt) | |||
#define CGEMM_SMALL_KERNEL_NR FUNC_OFFSET(cgemm_small_kernel_nr) | |||
#define CGEMM_SMALL_KERNEL_NC FUNC_OFFSET(cgemm_small_kernel_nc) | |||
#define CGEMM_SMALL_KERNEL_TN FUNC_OFFSET(cgemm_small_kernel_tn) | |||
#define CGEMM_SMALL_KERNEL_TT FUNC_OFFSET(cgemm_small_kernel_tt) | |||
#define CGEMM_SMALL_KERNEL_TR FUNC_OFFSET(cgemm_small_kernel_tr) | |||
#define CGEMM_SMALL_KERNEL_TC FUNC_OFFSET(cgemm_small_kernel_tc) | |||
#define CGEMM_SMALL_KERNEL_RN FUNC_OFFSET(cgemm_small_kernel_rn) | |||
#define CGEMM_SMALL_KERNEL_RT FUNC_OFFSET(cgemm_small_kernel_rt) | |||
#define CGEMM_SMALL_KERNEL_RR FUNC_OFFSET(cgemm_small_kernel_rr) | |||
#define CGEMM_SMALL_KERNEL_RC FUNC_OFFSET(cgemm_small_kernel_rc) | |||
#define CGEMM_SMALL_KERNEL_CN FUNC_OFFSET(cgemm_small_kernel_cn) | |||
#define CGEMM_SMALL_KERNEL_CT FUNC_OFFSET(cgemm_small_kernel_ct) | |||
#define CGEMM_SMALL_KERNEL_CR FUNC_OFFSET(cgemm_small_kernel_cr) | |||
#define CGEMM_SMALL_KERNEL_CC FUNC_OFFSET(cgemm_small_kernel_cc) | |||
#define CGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(cgemm_small_kernel_b0_nn) | |||
#define CGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(cgemm_small_kernel_b0_nt) | |||
#define CGEMM_SMALL_KERNEL_B0_NR FUNC_OFFSET(cgemm_small_kernel_b0_nr) | |||
#define CGEMM_SMALL_KERNEL_B0_NC FUNC_OFFSET(cgemm_small_kernel_b0_nc) | |||
#define CGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(cgemm_small_kernel_b0_tn) | |||
#define CGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(cgemm_small_kernel_b0_tt) | |||
#define CGEMM_SMALL_KERNEL_B0_TR FUNC_OFFSET(cgemm_small_kernel_b0_tr) | |||
#define CGEMM_SMALL_KERNEL_B0_TC FUNC_OFFSET(cgemm_small_kernel_b0_tc) | |||
#define CGEMM_SMALL_KERNEL_B0_RN FUNC_OFFSET(cgemm_small_kernel_b0_rn) | |||
#define CGEMM_SMALL_KERNEL_B0_RT FUNC_OFFSET(cgemm_small_kernel_b0_rt) | |||
#define CGEMM_SMALL_KERNEL_B0_RR FUNC_OFFSET(cgemm_small_kernel_b0_rr) | |||
#define CGEMM_SMALL_KERNEL_B0_RC FUNC_OFFSET(cgemm_small_kernel_b0_rc) | |||
#define CGEMM_SMALL_KERNEL_B0_CN FUNC_OFFSET(cgemm_small_kernel_b0_cn) | |||
#define CGEMM_SMALL_KERNEL_B0_CT FUNC_OFFSET(cgemm_small_kernel_b0_ct) | |||
#define CGEMM_SMALL_KERNEL_B0_CR FUNC_OFFSET(cgemm_small_kernel_b0_cr) | |||
#define CGEMM_SMALL_KERNEL_B0_CC FUNC_OFFSET(cgemm_small_kernel_b0_cc) | |||
#define CGEMM_NN cgemm_nn | |||
#define CGEMM_CN cgemm_cn | |||
#define CGEMM_TN cgemm_tn | |||
@@ -157,6 +157,8 @@ | |||
#define DIMATCOPY_K_RT dimatcopy_k_rt | |||
#define DGEADD_K dgeadd_k | |||
#define DGEMM_SMALL_MATRIX_PERMIT dgemm_small_matrix_permit | |||
#else | |||
#define DAMAX_K gotoblas -> damax_k | |||
@@ -281,8 +283,21 @@ | |||
#define DGEADD_K gotoblas -> dgeadd_k | |||
#define DGEMM_SMALL_MATRIX_PERMIT gotoblas -> dgemm_small_matrix_permit | |||
#endif | |||
#define DGEMM_SMALL_KERNEL_NN FUNC_OFFSET(dgemm_small_kernel_nn) | |||
#define DGEMM_SMALL_KERNEL_NT FUNC_OFFSET(dgemm_small_kernel_nt) | |||
#define DGEMM_SMALL_KERNEL_TN FUNC_OFFSET(dgemm_small_kernel_tn) | |||
#define DGEMM_SMALL_KERNEL_TT FUNC_OFFSET(dgemm_small_kernel_tt) | |||
#define DGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(dgemm_small_kernel_b0_nn) | |||
#define DGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(dgemm_small_kernel_b0_nt) | |||
#define DGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(dgemm_small_kernel_b0_tn) | |||
#define DGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(dgemm_small_kernel_b0_tt) | |||
#define DGEMM_NN dgemm_nn | |||
#define DGEMM_CN dgemm_tn | |||
#define DGEMM_TN dgemm_tn | |||
@@ -515,6 +515,117 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xidouble *, xidouble *, xidouble | |||
int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); | |||
#endif | |||
#ifdef SMALL_MATRIX_OPT | |||
int sgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); | |||
int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int dgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta); | |||
int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int sgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int sgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int dgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int cgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1); | |||
int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int zgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1); | |||
int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int cgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int zgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
#endif | |||
int cgemm_kernel_n(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); | |||
int cgemm_kernel_l(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); | |||
int cgemm_kernel_r(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG); | |||
@@ -644,6 +644,17 @@ | |||
#define GEADD_K DGEADD_K | |||
#define GEMM_SMALL_MATRIX_PERMIT DGEMM_SMALL_MATRIX_PERMIT | |||
#define GEMM_SMALL_KERNEL_NN DGEMM_SMALL_KERNEL_NN | |||
#define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT | |||
#define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN | |||
#define GEMM_SMALL_KERNEL_TT DGEMM_SMALL_KERNEL_TT | |||
#define GEMM_SMALL_KERNEL_B0_NN DGEMM_SMALL_KERNEL_B0_NN | |||
#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT | |||
#elif defined(BFLOAT16) | |||
#define D_TO_BF16_K SBDTOBF16_K | |||
@@ -931,6 +942,18 @@ | |||
#define GEADD_K SGEADD_K | |||
#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT | |||
#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN | |||
#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT | |||
#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN | |||
#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT | |||
#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN | |||
#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT | |||
#endif | |||
#else | |||
@@ -1236,6 +1259,19 @@ | |||
#define IMATCOPY_K_RT SIMATCOPY_K_RT | |||
#define GEADD_K SGEADD_K | |||
#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT | |||
#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN | |||
#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT | |||
#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN | |||
#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT | |||
#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN | |||
#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT | |||
#endif | |||
#else | |||
#ifdef XDOUBLE | |||
@@ -2063,6 +2099,48 @@ | |||
#define GEADD_K ZGEADD_K | |||
#define GEMM_SMALL_MATRIX_PERMIT ZGEMM_SMALL_MATRIX_PERMIT | |||
#define GEMM_SMALL_KERNEL_NN ZGEMM_SMALL_KERNEL_NN | |||
#define GEMM_SMALL_KERNEL_NT ZGEMM_SMALL_KERNEL_NT | |||
#define GEMM_SMALL_KERNEL_NR ZGEMM_SMALL_KERNEL_NR | |||
#define GEMM_SMALL_KERNEL_NC ZGEMM_SMALL_KERNEL_NC | |||
#define GEMM_SMALL_KERNEL_TN ZGEMM_SMALL_KERNEL_TN | |||
#define GEMM_SMALL_KERNEL_TT ZGEMM_SMALL_KERNEL_TT | |||
#define GEMM_SMALL_KERNEL_TR ZGEMM_SMALL_KERNEL_TR | |||
#define GEMM_SMALL_KERNEL_TC ZGEMM_SMALL_KERNEL_TC | |||
#define GEMM_SMALL_KERNEL_RN ZGEMM_SMALL_KERNEL_RN | |||
#define GEMM_SMALL_KERNEL_RT ZGEMM_SMALL_KERNEL_RT | |||
#define GEMM_SMALL_KERNEL_RR ZGEMM_SMALL_KERNEL_RR | |||
#define GEMM_SMALL_KERNEL_RC ZGEMM_SMALL_KERNEL_RC | |||
#define GEMM_SMALL_KERNEL_CN ZGEMM_SMALL_KERNEL_CN | |||
#define GEMM_SMALL_KERNEL_CT ZGEMM_SMALL_KERNEL_CT | |||
#define GEMM_SMALL_KERNEL_CR ZGEMM_SMALL_KERNEL_CR | |||
#define GEMM_SMALL_KERNEL_CC ZGEMM_SMALL_KERNEL_CC | |||
#define GEMM_SMALL_KERNEL_B0_NN ZGEMM_SMALL_KERNEL_B0_NN | |||
#define GEMM_SMALL_KERNEL_B0_NT ZGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_NR ZGEMM_SMALL_KERNEL_B0_NR | |||
#define GEMM_SMALL_KERNEL_B0_NC ZGEMM_SMALL_KERNEL_B0_NC | |||
#define GEMM_SMALL_KERNEL_B0_TN ZGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT ZGEMM_SMALL_KERNEL_B0_TT | |||
#define GEMM_SMALL_KERNEL_B0_TR ZGEMM_SMALL_KERNEL_B0_TR | |||
#define GEMM_SMALL_KERNEL_B0_TC ZGEMM_SMALL_KERNEL_B0_TC | |||
#define GEMM_SMALL_KERNEL_B0_RN ZGEMM_SMALL_KERNEL_B0_RN | |||
#define GEMM_SMALL_KERNEL_B0_RT ZGEMM_SMALL_KERNEL_B0_RT | |||
#define GEMM_SMALL_KERNEL_B0_RR ZGEMM_SMALL_KERNEL_B0_RR | |||
#define GEMM_SMALL_KERNEL_B0_RC ZGEMM_SMALL_KERNEL_B0_RC | |||
#define GEMM_SMALL_KERNEL_B0_CN ZGEMM_SMALL_KERNEL_B0_CN | |||
#define GEMM_SMALL_KERNEL_B0_CT ZGEMM_SMALL_KERNEL_B0_CT | |||
#define GEMM_SMALL_KERNEL_B0_CR ZGEMM_SMALL_KERNEL_B0_CR | |||
#define GEMM_SMALL_KERNEL_B0_CC ZGEMM_SMALL_KERNEL_B0_CC | |||
#else | |||
#define AMAX_K CAMAX_K | |||
@@ -2486,6 +2564,48 @@ | |||
#define GEADD_K CGEADD_K | |||
#define GEMM_SMALL_MATRIX_PERMIT CGEMM_SMALL_MATRIX_PERMIT | |||
#define GEMM_SMALL_KERNEL_NN CGEMM_SMALL_KERNEL_NN | |||
#define GEMM_SMALL_KERNEL_NT CGEMM_SMALL_KERNEL_NT | |||
#define GEMM_SMALL_KERNEL_NR CGEMM_SMALL_KERNEL_NR | |||
#define GEMM_SMALL_KERNEL_NC CGEMM_SMALL_KERNEL_NC | |||
#define GEMM_SMALL_KERNEL_TN CGEMM_SMALL_KERNEL_TN | |||
#define GEMM_SMALL_KERNEL_TT CGEMM_SMALL_KERNEL_TT | |||
#define GEMM_SMALL_KERNEL_TR CGEMM_SMALL_KERNEL_TR | |||
#define GEMM_SMALL_KERNEL_TC CGEMM_SMALL_KERNEL_TC | |||
#define GEMM_SMALL_KERNEL_RN CGEMM_SMALL_KERNEL_RN | |||
#define GEMM_SMALL_KERNEL_RT CGEMM_SMALL_KERNEL_RT | |||
#define GEMM_SMALL_KERNEL_RR CGEMM_SMALL_KERNEL_RR | |||
#define GEMM_SMALL_KERNEL_RC CGEMM_SMALL_KERNEL_RC | |||
#define GEMM_SMALL_KERNEL_CN CGEMM_SMALL_KERNEL_CN | |||
#define GEMM_SMALL_KERNEL_CT CGEMM_SMALL_KERNEL_CT | |||
#define GEMM_SMALL_KERNEL_CR CGEMM_SMALL_KERNEL_CR | |||
#define GEMM_SMALL_KERNEL_CC CGEMM_SMALL_KERNEL_CC | |||
#define GEMM_SMALL_KERNEL_B0_NN CGEMM_SMALL_KERNEL_B0_NN | |||
#define GEMM_SMALL_KERNEL_B0_NT CGEMM_SMALL_KERNEL_B0_NT | |||
#define GEMM_SMALL_KERNEL_B0_NR CGEMM_SMALL_KERNEL_B0_NR | |||
#define GEMM_SMALL_KERNEL_B0_NC CGEMM_SMALL_KERNEL_B0_NC | |||
#define GEMM_SMALL_KERNEL_B0_TN CGEMM_SMALL_KERNEL_B0_TN | |||
#define GEMM_SMALL_KERNEL_B0_TT CGEMM_SMALL_KERNEL_B0_TT | |||
#define GEMM_SMALL_KERNEL_B0_TR CGEMM_SMALL_KERNEL_B0_TR | |||
#define GEMM_SMALL_KERNEL_B0_TC CGEMM_SMALL_KERNEL_B0_TC | |||
#define GEMM_SMALL_KERNEL_B0_RN CGEMM_SMALL_KERNEL_B0_RN | |||
#define GEMM_SMALL_KERNEL_B0_RT CGEMM_SMALL_KERNEL_B0_RT | |||
#define GEMM_SMALL_KERNEL_B0_RR CGEMM_SMALL_KERNEL_B0_RR | |||
#define GEMM_SMALL_KERNEL_B0_RC CGEMM_SMALL_KERNEL_B0_RC | |||
#define GEMM_SMALL_KERNEL_B0_CN CGEMM_SMALL_KERNEL_B0_CN | |||
#define GEMM_SMALL_KERNEL_B0_CT CGEMM_SMALL_KERNEL_B0_CT | |||
#define GEMM_SMALL_KERNEL_B0_CR CGEMM_SMALL_KERNEL_B0_CR | |||
#define GEMM_SMALL_KERNEL_B0_CC CGEMM_SMALL_KERNEL_B0_CC | |||
#endif | |||
#endif | |||
@@ -207,6 +207,20 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); | |||
int (*sgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); | |||
#endif | |||
#ifdef BUILD_SINGLE | |||
#ifdef SMALL_MATRIX_OPT | |||
int (*sgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); | |||
int (*sgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*sgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
#endif | |||
int (*strsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
int (*strsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
int (*strsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
@@ -314,6 +328,19 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); | |||
int (*dgemm_otcopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); | |||
#endif | |||
#ifdef BUILD_DOUBLE | |||
#ifdef SMALL_MATRIX_OPT | |||
int (*dgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta); | |||
int (*dgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*dgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
#endif | |||
int (*dtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
int (*dtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
int (*dtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
@@ -513,6 +540,50 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); | |||
int (*cgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); | |||
int (*cgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); | |||
#ifdef SMALL_MATRIX_OPT | |||
int (*cgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1); | |||
int (*cgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
int (*cgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); | |||
#endif | |||
int (*ctrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
int (*ctrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
int (*ctrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); | |||
@@ -679,6 +750,50 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); | |||
int (*zgemm_oncopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); | |||
int (*zgemm_otcopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); | |||
#ifdef SMALL_MATRIX_OPT | |||
int (*zgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1); | |||
int (*zgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
int (*zgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); | |||
#endif | |||
int (*ztrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
int (*ztrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
int (*ztrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); | |||
@@ -1069,6 +1184,8 @@ BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); | |||
extern gotoblas_t *gotoblas; | |||
#define FUNC_OFFSET(func) (size_t)(&((gotoblas_t *)NULL)->func) | |||
#define DTB_ENTRIES gotoblas -> dtb_entries | |||
#define GEMM_OFFSET_A gotoblas -> offsetA | |||
#define GEMM_OFFSET_B gotoblas -> offsetB | |||
@@ -1174,6 +1291,8 @@ extern gotoblas_t *gotoblas; | |||
#else | |||
#define FUNC_OFFSET(func) (size_t)(func) | |||
#define DTB_ENTRIES DTB_DEFAULT_ENTRIES | |||
#define GEMM_OFFSET_A GEMM_DEFAULT_OFFSET_A | |||
@@ -164,6 +164,8 @@ | |||
#define SGEADD_K sgeadd_k | |||
#define SGEMM_SMALL_MATRIX_PERMIT sgemm_small_matrix_permit | |||
#else | |||
#define SAMAX_K gotoblas -> samax_k | |||
@@ -299,8 +301,21 @@ | |||
#define SGEADD_K gotoblas -> sgeadd_k | |||
#define SGEMM_SMALL_MATRIX_PERMIT gotoblas -> sgemm_small_matrix_permit | |||
#endif | |||
#define SGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sgemm_small_kernel_nn) | |||
#define SGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sgemm_small_kernel_nt) | |||
#define SGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sgemm_small_kernel_tn) | |||
#define SGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sgemm_small_kernel_tt) | |||
#define SGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sgemm_small_kernel_b0_nn) | |||
#define SGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sgemm_small_kernel_b0_nt) | |||
#define SGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sgemm_small_kernel_b0_tn) | |||
#define SGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sgemm_small_kernel_b0_tt) | |||
#define SGEMM_NN sgemm_nn | |||
#define SGEMM_CN sgemm_tn | |||
#define SGEMM_TN sgemm_tn | |||
@@ -232,6 +232,8 @@ | |||
#define ZGEADD_K zgeadd_k | |||
#define ZGEMM_SMALL_MATRIX_PERMIT zgemm_small_matrix_permit | |||
#else | |||
#define ZAMAX_K gotoblas -> zamax_k | |||
@@ -426,8 +428,51 @@ | |||
#define ZGEADD_K gotoblas -> zgeadd_k | |||
#define ZGEMM_SMALL_MATRIX_PERMIT gotoblas -> zgemm_small_matrix_permit | |||
#endif | |||
#define ZGEMM_SMALL_KERNEL_NN FUNC_OFFSET(zgemm_small_kernel_nn) | |||
#define ZGEMM_SMALL_KERNEL_NT FUNC_OFFSET(zgemm_small_kernel_nt) | |||
#define ZGEMM_SMALL_KERNEL_NR FUNC_OFFSET(zgemm_small_kernel_nr) | |||
#define ZGEMM_SMALL_KERNEL_NC FUNC_OFFSET(zgemm_small_kernel_nc) | |||
#define ZGEMM_SMALL_KERNEL_TN FUNC_OFFSET(zgemm_small_kernel_tn) | |||
#define ZGEMM_SMALL_KERNEL_TT FUNC_OFFSET(zgemm_small_kernel_tt) | |||
#define ZGEMM_SMALL_KERNEL_TR FUNC_OFFSET(zgemm_small_kernel_tr) | |||
#define ZGEMM_SMALL_KERNEL_TC FUNC_OFFSET(zgemm_small_kernel_tc) | |||
#define ZGEMM_SMALL_KERNEL_RN FUNC_OFFSET(zgemm_small_kernel_rn) | |||
#define ZGEMM_SMALL_KERNEL_RT FUNC_OFFSET(zgemm_small_kernel_rt) | |||
#define ZGEMM_SMALL_KERNEL_RR FUNC_OFFSET(zgemm_small_kernel_rr) | |||
#define ZGEMM_SMALL_KERNEL_RC FUNC_OFFSET(zgemm_small_kernel_rc) | |||
#define ZGEMM_SMALL_KERNEL_CN FUNC_OFFSET(zgemm_small_kernel_cn) | |||
#define ZGEMM_SMALL_KERNEL_CT FUNC_OFFSET(zgemm_small_kernel_ct) | |||
#define ZGEMM_SMALL_KERNEL_CR FUNC_OFFSET(zgemm_small_kernel_cr) | |||
#define ZGEMM_SMALL_KERNEL_CC FUNC_OFFSET(zgemm_small_kernel_cc) | |||
#define ZGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(zgemm_small_kernel_b0_nn) | |||
#define ZGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(zgemm_small_kernel_b0_nt) | |||
#define ZGEMM_SMALL_KERNEL_B0_NR FUNC_OFFSET(zgemm_small_kernel_b0_nr) | |||
#define ZGEMM_SMALL_KERNEL_B0_NC FUNC_OFFSET(zgemm_small_kernel_b0_nc) | |||
#define ZGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(zgemm_small_kernel_b0_tn) | |||
#define ZGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(zgemm_small_kernel_b0_tt) | |||
#define ZGEMM_SMALL_KERNEL_B0_TR FUNC_OFFSET(zgemm_small_kernel_b0_tr) | |||
#define ZGEMM_SMALL_KERNEL_B0_TC FUNC_OFFSET(zgemm_small_kernel_b0_tc) | |||
#define ZGEMM_SMALL_KERNEL_B0_RN FUNC_OFFSET(zgemm_small_kernel_b0_rn) | |||
#define ZGEMM_SMALL_KERNEL_B0_RT FUNC_OFFSET(zgemm_small_kernel_b0_rt) | |||
#define ZGEMM_SMALL_KERNEL_B0_RR FUNC_OFFSET(zgemm_small_kernel_b0_rr) | |||
#define ZGEMM_SMALL_KERNEL_B0_RC FUNC_OFFSET(zgemm_small_kernel_b0_rc) | |||
#define ZGEMM_SMALL_KERNEL_B0_CN FUNC_OFFSET(zgemm_small_kernel_b0_cn) | |||
#define ZGEMM_SMALL_KERNEL_B0_CT FUNC_OFFSET(zgemm_small_kernel_b0_ct) | |||
#define ZGEMM_SMALL_KERNEL_B0_CR FUNC_OFFSET(zgemm_small_kernel_b0_cr) | |||
#define ZGEMM_SMALL_KERNEL_B0_CC FUNC_OFFSET(zgemm_small_kernel_b0_cc) | |||
#define ZGEMM_NN zgemm_nn | |||
#define ZGEMM_CN zgemm_cn | |||
#define ZGEMM_TN zgemm_tn | |||
@@ -105,6 +105,55 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B | |||
#endif | |||
}; | |||
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE) && !defined(BFLOAT16) | |||
#define USE_SMALL_MATRIX_OPT 1 | |||
#else | |||
#define USE_SMALL_MATRIX_OPT 0 | |||
#endif | |||
#if USE_SMALL_MATRIX_OPT | |||
#ifndef DYNAMIC_ARCH | |||
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx])) | |||
#else | |||
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx])))) | |||
#endif | |||
#ifndef COMPLEX | |||
static size_t gemm_small_kernel[] = { | |||
GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0, | |||
GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0, | |||
}; | |||
static size_t gemm_small_kernel_b0[] = { | |||
GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0, | |||
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0, | |||
}; | |||
#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)) | |||
#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)) | |||
#else | |||
static size_t zgemm_small_kernel[] = { | |||
GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN, | |||
GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT, | |||
GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR, | |||
GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC, | |||
}; | |||
static size_t zgemm_small_kernel_b0[] = { | |||
GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN, | |||
GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT, | |||
GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR, | |||
GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC, | |||
}; | |||
#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)) | |||
#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)) | |||
#endif | |||
#endif | |||
#ifndef CBLAS | |||
void NAME(char *TRANSA, char *TRANSB, | |||
@@ -417,6 +466,28 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
FUNCTION_PROFILE_START(); | |||
#if USE_SMALL_MATRIX_OPT | |||
#if !defined(COMPLEX) | |||
if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){ | |||
if(*(FLOAT *)(args.beta) == 0.0){ | |||
(GEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc); | |||
}else{ | |||
(GEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc); | |||
} | |||
return; | |||
} | |||
#else | |||
if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){ | |||
if(beta[0] == 0.0 && beta[1] == 0.0){ | |||
(ZGEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc); | |||
}else{ | |||
(ZGEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc); | |||
} | |||
return; | |||
} | |||
#endif | |||
#endif | |||
buffer = (XFLOAT *)blas_memory_alloc(0); | |||
sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A); | |||
@@ -458,7 +458,117 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RN}" "UPPER;RN;TRSMKERNEL" "trsm_kernel_RN" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RT}" "RT;TRSMKERNEL" "trsm_kernel_RT" false "" "" false ${float_type}) | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_M_PERMIT) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_M_PERMIT ../generic/zgemm_small_matrix_permit.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_NN) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_NN ../generic/zgemm_small_matrix_kernel_nn.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_NT) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_NT ../generic/zgemm_small_matrix_kernel_nt.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_TN) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_TN ../generic/zgemm_small_matrix_kernel_tn.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_TT) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_TT ../generic/zgemm_small_matrix_kernel_tt.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NN) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_B0_NN ../generic/zgemm_small_matrix_kernel_nn.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NT) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_B0_NT ../generic/zgemm_small_matrix_kernel_nt.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TN) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_B0_TN ../generic/zgemm_small_matrix_kernel_tn.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TT) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
set(${float_char}GEMM_SMALL_K_B0_TT ../generic/zgemm_small_matrix_kernel_tt.c) | |||
else () | |||
set(${float_char}GEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c) | |||
endif () | |||
endif () | |||
if (SMALL_MATRIX_OPT) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false ${float_type}) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NN" "gemm_small_kernel_nn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NR" "gemm_small_kernel_nr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RN" "gemm_small_kernel_rn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RR" "gemm_small_kernel_rr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NT" "gemm_small_kernel_nt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NC" "gemm_small_kernel_nc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RT" "gemm_small_kernel_rt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RC" "gemm_small_kernel_rc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TN" "gemm_small_kernel_tn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TR" "gemm_small_kernel_tr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CN" "gemm_small_kernel_cn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CR" "gemm_small_kernel_cr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TT" "gemm_small_kernel_tt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TC" "gemm_small_kernel_tc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CT" "gemm_small_kernel_ct" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CC" "gemm_small_kernel_cc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NN;B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NR;B0" "gemm_small_kernel_b0_nr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RN;B0" "gemm_small_kernel_b0_rn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RR;B0" "gemm_small_kernel_b0_rr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NT;B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NC;B0" "gemm_small_kernel_b0_nc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RT;B0" "gemm_small_kernel_b0_rt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RC;B0" "gemm_small_kernel_b0_rc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TN;B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TR;B0" "gemm_small_kernel_b0_tr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CN;B0" "gemm_small_kernel_b0_cn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CR;B0" "gemm_small_kernel_b0_cr" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TT;B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TC;B0" "gemm_small_kernel_b0_tc" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CT;B0" "gemm_small_kernel_b0_ct" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CC;B0" "gemm_small_kernel_b0_cc" false "" "" false ${float_type}) | |||
else () | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_tt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type}) | |||
endif () | |||
endif () | |||
if (NOT DEFINED ${float_char}OMATCOPY_CN) | |||
if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C") | |||
@@ -447,6 +447,63 @@ XBLASOBJS += \ | |||
endif | |||
###### BLAS small matrix optimization ##### | |||
ifeq ($(SMALL_MATRIX_OPT), 1) | |||
SBLASOBJS += \ | |||
sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) | |||
DBLASOBJS += \ | |||
dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ | |||
dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ | |||
dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ | |||
dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ | |||
dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) | |||
CBLASOBJS += \ | |||
cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \ | |||
cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) | |||
ZBLASOBJS += \ | |||
zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \ | |||
zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) | |||
endif | |||
###### BLAS extensions ##### | |||
ifeq ($(BUILD_SINGLE),1) | |||
@@ -4237,3 +4294,403 @@ endif | |||
$(KDIR)zgeadd_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEADD_K) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -UROWM $< -o $@ | |||
###### BLAS small matrix optimization ##### | |||
ifndef DGEMM_SMALL_M_PERMIT | |||
DGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c | |||
endif | |||
ifndef DGEMM_SMALL_K_NN | |||
DGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef DGEMM_SMALL_K_NT | |||
DGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef DGEMM_SMALL_K_TN | |||
DGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef DGEMM_SMALL_K_TT | |||
DGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_M_PERMIT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@ | |||
ifndef DGEMM_SMALL_K_B0_NN | |||
DGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef DGEMM_SMALL_K_B0_NT | |||
DGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef DGEMM_SMALL_K_B0_TN | |||
DGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef DGEMM_SMALL_K_B0_TT | |||
DGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
ifndef SGEMM_SMALL_M_PERMIT | |||
SGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c | |||
endif | |||
ifndef SGEMM_SMALL_K_NN | |||
SGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef SGEMM_SMALL_K_NT | |||
SGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef SGEMM_SMALL_K_TN | |||
SGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef SGEMM_SMALL_K_TT | |||
SGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_M_PERMIT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
ifndef SGEMM_SMALL_K_B0_NN | |||
SGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef SGEMM_SMALL_K_B0_NT | |||
SGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef SGEMM_SMALL_K_B0_TN | |||
SGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef SGEMM_SMALL_K_B0_TT | |||
SGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
$(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@ | |||
ifndef CGEMM_SMALL_M_PERMIT | |||
CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c | |||
endif | |||
ifndef CGEMM_SMALL_K_NN | |||
CGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef CGEMM_SMALL_K_NT | |||
CGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef CGEMM_SMALL_K_TN | |||
CGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef CGEMM_SMALL_K_TT | |||
CGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_M_PERMIT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX $< -o $@ | |||
$(KDIR)cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@ | |||
$(KDIR)cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $@ | |||
$(KDIR)cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN $< -o $@ | |||
$(KDIR)cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR $< -o $@ | |||
$(KDIR)cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT $< -o $@ | |||
$(KDIR)cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $@ | |||
$(KDIR)cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT $< -o $@ | |||
$(KDIR)cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC $< -o $@ | |||
$(KDIR)cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN $< -o $@ | |||
$(KDIR)cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR $< -o $@ | |||
$(KDIR)cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN $< -o $@ | |||
$(KDIR)cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR $< -o $@ | |||
$(KDIR)cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT $< -o $@ | |||
$(KDIR)cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC $< -o $@ | |||
$(KDIR)cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT $< -o $@ | |||
$(KDIR)cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@ | |||
ifndef CGEMM_SMALL_K_B0_NN | |||
CGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef CGEMM_SMALL_K_B0_NT | |||
CGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef CGEMM_SMALL_K_B0_TN | |||
CGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef CGEMM_SMALL_K_B0_TT | |||
CGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@ | |||
$(KDIR)cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@ | |||
ifndef ZGEMM_SMALL_M_PERMIT | |||
ZGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_NN | |||
ZGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_NT | |||
ZGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_TN | |||
ZGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_TT | |||
ZGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_M_PERMIT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX $< -o $@ | |||
$(KDIR)zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@ | |||
$(KDIR)zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR $< -o $@ | |||
$(KDIR)zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN $< -o $@ | |||
$(KDIR)zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR $< -o $@ | |||
$(KDIR)zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT $< -o $@ | |||
$(KDIR)zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC $< -o $@ | |||
$(KDIR)zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT $< -o $@ | |||
$(KDIR)zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC $< -o $@ | |||
$(KDIR)zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN $< -o $@ | |||
$(KDIR)zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR $< -o $@ | |||
$(KDIR)zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN $< -o $@ | |||
$(KDIR)zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR $< -o $@ | |||
$(KDIR)zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT $< -o $@ | |||
$(KDIR)zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC $< -o $@ | |||
$(KDIR)zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT $< -o $@ | |||
$(KDIR)zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@ | |||
ifndef ZGEMM_SMALL_K_B0_NN | |||
ZGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_B0_NT | |||
ZGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_B0_TN | |||
ZGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c | |||
endif | |||
ifndef ZGEMM_SMALL_K_B0_TT | |||
ZGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c | |||
endif | |||
$(KDIR)zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@ | |||
$(KDIR)zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT) | |||
$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@ |
@@ -0,0 +1,56 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifdef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
//naive implemtation | |||
//Column major | |||
BLASLONG i,j,k; | |||
FLOAT result=0.0; | |||
for(i=0; i<M; i++){ | |||
for(j=0; j<N; j++){ | |||
result=0.0; | |||
for(k=0; k<K; k++){ | |||
result += A[i+k*lda] * B[k+j*ldb]; | |||
} | |||
#ifdef B0 | |||
C[i+j*ldc]=alpha * result; | |||
#else | |||
C[i+j*ldc]=C[i+j*ldc] * beta + alpha * result; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,56 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifdef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
//naive implemtation | |||
//Column major | |||
BLASLONG i,j,k; | |||
FLOAT result=0.0; | |||
for(i=0; i<M; i++){ | |||
for(j=0; j<N; j++){ | |||
result=0.0; | |||
for(k=0; k<K; k++){ | |||
result += A[i+k*lda] * B[k*ldb+j]; | |||
} | |||
#ifdef B0 | |||
C[i+j*ldc]=alpha * result; | |||
#else | |||
C[i+j*ldc]=C[i+j*ldc] * beta + alpha * result; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,57 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifdef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb,FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
//naive implemtation | |||
//Column major | |||
BLASLONG i,j,k; | |||
FLOAT result=0.0; | |||
for(i=0; i<M; i++){ | |||
for(j=0; j<N; j++){ | |||
result=0.0; | |||
for(k=0; k<K; k++){ | |||
result += A[i*lda+k] * B[k+j*ldb]; | |||
} | |||
#ifdef B0 | |||
C[i+j*ldc]=alpha * result; | |||
#else | |||
C[i+j*ldc]=C[i+j*ldc] * beta + alpha * result; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,57 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifdef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
//naive implemtation | |||
//Column major | |||
BLASLONG i,j,k; | |||
FLOAT result=0.0; | |||
for(i=0; i<M; i++){ | |||
for(j=0; j<N; j++){ | |||
result=0.0; | |||
for(k=0; k<K; k++){ | |||
result += A[i*lda+k] * B[k*ldb+j]; | |||
} | |||
#ifdef B0 | |||
C[i+j*ldc]=alpha * result; | |||
#else | |||
C[i+j*ldc]=C[i+j*ldc] * beta + alpha * result; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,40 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) | |||
{ | |||
return 0; | |||
/* | |||
double MNK = (double) M * (double) N * (double) K; | |||
if (MNK <= 100.0*100.0*100.0) | |||
return 1; | |||
else | |||
return 0; | |||
*/ | |||
} |
@@ -0,0 +1,89 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifndef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
FLOAT real, imag; | |||
#ifndef B0 | |||
FLOAT tmp0, tmp1; | |||
#endif | |||
int i, j, l; | |||
for(i = 0; i < M; i++){ | |||
for(j = 0; j < N; j++){ | |||
real=0; | |||
imag=0; | |||
for(l = 0; l < K; l++){ | |||
#if defined(NN) | |||
real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] | |||
-A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] | |||
+ A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(NR) | |||
real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] | |||
+A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] | |||
+ A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(RN) | |||
real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] | |||
+A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] | |||
- A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(RR) | |||
real += (A[l*2*lda + 2*i]*B[j*2*ldb + 2*l] | |||
-A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(-A[l*2*lda + 2*i] * B[j*2*ldb + 2*l + 1] | |||
- A[l*2*lda + 2*i + 1] * B[j*2*ldb + 2*l]); | |||
#endif | |||
} | |||
#ifndef B0 | |||
tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; | |||
tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; | |||
C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; | |||
#else | |||
C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,93 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifndef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
FLOAT real, imag; | |||
#ifndef B0 | |||
FLOAT tmp0, tmp1; | |||
#endif | |||
int i, j, l; | |||
for(i = 0; i < M; i++){ | |||
for(j = 0; j < N; j++){ | |||
real=0; | |||
imag=0; | |||
for(l = 0; l < K; l++){ | |||
#if defined(NT) | |||
real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] | |||
-A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] | |||
+ A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(NC) | |||
real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] | |||
+A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] | |||
+ A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(RT) | |||
real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] | |||
+A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] | |||
- A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(RC) | |||
real += (A[l*2*lda + 2*i]*B[l*2*ldb + 2*j] | |||
-A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(-A[l*2*lda + 2*i] * B[l*2*ldb + 2*j + 1] | |||
- A[l*2*lda + 2*i + 1] * B[l*2*ldb + 2*j]); | |||
#endif | |||
} | |||
#ifndef B0 | |||
tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; | |||
tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; | |||
C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; | |||
#else | |||
C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,93 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifndef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
FLOAT real, imag; | |||
#ifndef B0 | |||
FLOAT tmp0, tmp1; | |||
#endif | |||
int i, j, l; | |||
for(i = 0; i < M; i++){ | |||
for(j = 0; j < N; j++){ | |||
real=0; | |||
imag=0; | |||
for(l = 0; l < K; l++){ | |||
#if defined(TN) | |||
real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] | |||
-A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] | |||
+ A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(TR) | |||
real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] | |||
+A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] | |||
+ A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(CN) | |||
real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] | |||
+A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] | |||
- A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); | |||
#elif defined(CR) | |||
real += (A[i*2*lda + 2*l]*B[j*2*ldb + 2*l] | |||
-A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l + 1]); | |||
imag+=(-A[i*2*lda + 2*l] * B[j*2*ldb + 2*l + 1] | |||
- A[i*2*lda + 2*l + 1] * B[j*2*ldb + 2*l]); | |||
#endif | |||
} | |||
#ifndef B0 | |||
tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; | |||
tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; | |||
C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; | |||
#else | |||
C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,93 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2020, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#ifndef B0 | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT beta0, FLOAT beta1, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha0, FLOAT alpha1, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
FLOAT real, imag; | |||
#ifndef B0 | |||
FLOAT tmp0, tmp1; | |||
#endif | |||
int i, j, l; | |||
for(i = 0; i < M; i++){ | |||
for(j = 0; j < N; j++){ | |||
real=0; | |||
imag=0; | |||
for(l = 0; l < K; l++){ | |||
#if defined(TT) | |||
real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] | |||
-A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] | |||
+ A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(TC) | |||
real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] | |||
+A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] | |||
+ A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(CT) | |||
real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] | |||
+A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] | |||
- A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); | |||
#elif defined(CC) | |||
real += (A[i*2*lda + 2*l]*B[l*2*ldb + 2*j] | |||
-A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j + 1]); | |||
imag+=(-A[i*2*lda + 2*l] * B[l*2*ldb + 2*j + 1] | |||
- A[i*2*lda + 2*l + 1] * B[l*2*ldb + 2*j]); | |||
#endif | |||
} | |||
#ifndef B0 | |||
tmp0 = beta0*C[j*2*ldc + 2*i] - beta1*C[j*2*ldc+ 2*i + 1]; | |||
tmp1 = beta0*C[j*2*ldc+ 2*i + 1] + beta1*C[j*2*ldc + 2*i]; | |||
C[j*2*ldc + 2*i] =tmp0+ alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = tmp1+ alpha0*imag + real*alpha1; | |||
#else | |||
C[j*2*ldc + 2*i] = alpha0*real - alpha1*imag; | |||
C[j*2*ldc+ 2*i + 1] = alpha0*imag + real*alpha1; | |||
#endif | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,40 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha0, FLOAT alpha1, FLOAT beta0, FLOAT beta1) | |||
{ | |||
return 0; | |||
/* | |||
double MNK = (double) M * (double) N * (double) K; | |||
if (MNK <= 100.0*100.0*100.0) | |||
return 1; | |||
else | |||
return 0; | |||
*/ | |||
} |
@@ -171,6 +171,14 @@ gotoblas_t TABLE_NAME = { | |||
sgemm_oncopyTS, sgemm_otcopyTS, | |||
#endif | |||
#if BUILD_SINGLE == 1 | |||
#ifdef SMALL_MATRIX_OPT | |||
sgemm_small_matrix_permitTS, | |||
sgemm_small_kernel_nnTS, sgemm_small_kernel_ntTS, sgemm_small_kernel_tnTS, sgemm_small_kernel_ttTS, | |||
sgemm_small_kernel_b0_nnTS, sgemm_small_kernel_b0_ntTS, sgemm_small_kernel_b0_tnTS, sgemm_small_kernel_b0_ttTS, | |||
#endif | |||
#endif | |||
#if (BUILD_SINGLE==1) || (BUILD_DOUBLE==1) | |||
strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS, | |||
#if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N | |||
@@ -257,6 +265,11 @@ gotoblas_t TABLE_NAME = { | |||
#endif | |||
#if (BUILD_DOUBLE==1) | |||
#ifdef SMALL_MATRIX_OPT | |||
dgemm_small_matrix_permitTS, | |||
dgemm_small_kernel_nnTS, dgemm_small_kernel_ntTS, dgemm_small_kernel_tnTS, dgemm_small_kernel_ttTS, | |||
dgemm_small_kernel_b0_nnTS, dgemm_small_kernel_b0_ntTS, dgemm_small_kernel_b0_tnTS, dgemm_small_kernel_b0_ttTS, | |||
#endif | |||
dtrsm_kernel_LNTS, dtrsm_kernel_LTTS, dtrsm_kernel_RNTS, dtrsm_kernel_RTTS, | |||
#if DGEMM_DEFAULT_UNROLL_M != DGEMM_DEFAULT_UNROLL_N | |||
dtrsm_iunucopyTS, dtrsm_iunncopyTS, dtrsm_iutucopyTS, dtrsm_iutncopyTS, | |||
@@ -389,6 +402,18 @@ gotoblas_t TABLE_NAME = { | |||
#endif | |||
cgemm_oncopyTS, cgemm_otcopyTS, | |||
#ifdef SMALL_MATRIX_OPT | |||
cgemm_small_matrix_permitTS, | |||
cgemm_small_kernel_nnTS, cgemm_small_kernel_ntTS, cgemm_small_kernel_nrTS, cgemm_small_kernel_ncTS, | |||
cgemm_small_kernel_tnTS, cgemm_small_kernel_ttTS, cgemm_small_kernel_trTS, cgemm_small_kernel_tcTS, | |||
cgemm_small_kernel_rnTS, cgemm_small_kernel_rtTS, cgemm_small_kernel_rrTS, cgemm_small_kernel_rcTS, | |||
cgemm_small_kernel_cnTS, cgemm_small_kernel_ctTS, cgemm_small_kernel_crTS, cgemm_small_kernel_ccTS, | |||
cgemm_small_kernel_b0_nnTS, cgemm_small_kernel_b0_ntTS, cgemm_small_kernel_b0_nrTS, cgemm_small_kernel_b0_ncTS, | |||
cgemm_small_kernel_b0_tnTS, cgemm_small_kernel_b0_ttTS, cgemm_small_kernel_b0_trTS, cgemm_small_kernel_b0_tcTS, | |||
cgemm_small_kernel_b0_rnTS, cgemm_small_kernel_b0_rtTS, cgemm_small_kernel_b0_rrTS, cgemm_small_kernel_b0_rcTS, | |||
cgemm_small_kernel_b0_cnTS, cgemm_small_kernel_b0_ctTS, cgemm_small_kernel_b0_crTS, cgemm_small_kernel_b0_ccTS, | |||
#endif | |||
ctrsm_kernel_LNTS, ctrsm_kernel_LTTS, ctrsm_kernel_LRTS, ctrsm_kernel_LCTS, | |||
ctrsm_kernel_RNTS, ctrsm_kernel_RTTS, ctrsm_kernel_RRTS, ctrsm_kernel_RCTS, | |||
@@ -533,6 +558,18 @@ gotoblas_t TABLE_NAME = { | |||
#endif | |||
zgemm_oncopyTS, zgemm_otcopyTS, | |||
#ifdef SMALL_MATRIX_OPT | |||
zgemm_small_matrix_permitTS, | |||
zgemm_small_kernel_nnTS, zgemm_small_kernel_ntTS, zgemm_small_kernel_nrTS, zgemm_small_kernel_ncTS, | |||
zgemm_small_kernel_tnTS, zgemm_small_kernel_ttTS, zgemm_small_kernel_trTS, zgemm_small_kernel_tcTS, | |||
zgemm_small_kernel_rnTS, zgemm_small_kernel_rtTS, zgemm_small_kernel_rrTS, zgemm_small_kernel_rcTS, | |||
zgemm_small_kernel_cnTS, zgemm_small_kernel_ctTS, zgemm_small_kernel_crTS, zgemm_small_kernel_ccTS, | |||
zgemm_small_kernel_b0_nnTS, zgemm_small_kernel_b0_ntTS, zgemm_small_kernel_b0_nrTS, zgemm_small_kernel_b0_ncTS, | |||
zgemm_small_kernel_b0_tnTS, zgemm_small_kernel_b0_ttTS, zgemm_small_kernel_b0_trTS, zgemm_small_kernel_b0_tcTS, | |||
zgemm_small_kernel_b0_rnTS, zgemm_small_kernel_b0_rtTS, zgemm_small_kernel_b0_rrTS, zgemm_small_kernel_b0_rcTS, | |||
zgemm_small_kernel_b0_cnTS, zgemm_small_kernel_b0_ctTS, zgemm_small_kernel_b0_crTS, zgemm_small_kernel_b0_ccTS, | |||
#endif | |||
ztrsm_kernel_LNTS, ztrsm_kernel_LTTS, ztrsm_kernel_LRTS, ztrsm_kernel_LCTS, | |||
ztrsm_kernel_RNTS, ztrsm_kernel_RTTS, ztrsm_kernel_RRTS, ztrsm_kernel_RCTS, | |||
@@ -10,6 +10,15 @@ STRSMKERNEL_LN = ../generic/trsm_kernel_LN.c | |||
STRSMKERNEL_LT = ../generic/trsm_kernel_LT.c | |||
STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c | |||
STRSMKERNEL_RT = ../generic/trsm_kernel_RT.c | |||
SGEMM_SMALL_M_PERMIT = sgemm_small_kernel_permit_skylakex.c | |||
SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c | |||
SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_nn_skylakex.c | |||
SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c | |||
SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_nt_skylakex.c | |||
SGEMM_SMALL_K_TN = sgemm_small_kernel_tn_skylakex.c | |||
SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_tn_skylakex.c | |||
SGEMM_SMALL_K_TT = sgemm_small_kernel_tt_skylakex.c | |||
SGEMM_SMALL_K_B0_TT = sgemm_small_kernel_tt_skylakex.c | |||
DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c | |||
DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c | |||
@@ -18,6 +27,15 @@ DGEMMITCOPY = dgemm_tcopy_16_skylakex.c | |||
DGEMMONCOPY = ../generic/gemm_ncopy_2.c | |||
DGEMMOTCOPY = ../generic/gemm_tcopy_2.c | |||
DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c | |||
DGEMM_SMALL_M_PERMIT = dgemm_small_kernel_permit_skylakex.c | |||
DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c | |||
DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_nn_skylakex.c | |||
DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c | |||
DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_nt_skylakex.c | |||
DGEMM_SMALL_K_TN = dgemm_small_kernel_tn_skylakex.c | |||
DGEMM_SMALL_K_B0_TN = dgemm_small_kernel_tn_skylakex.c | |||
DGEMM_SMALL_K_TT = dgemm_small_kernel_tt_skylakex.c | |||
DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_tt_skylakex.c | |||
SGEMM_BETA = sgemm_beta_skylakex.c | |||
DGEMM_BETA = dgemm_beta_skylakex.c | |||
@@ -0,0 +1,590 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() | |||
#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)]) | |||
#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)]) | |||
#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_pd1(&B[k + ldb * (j+N)])) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) | |||
#if defined(B0) | |||
#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) | |||
#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) | |||
#else | |||
#define STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \ | |||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) | |||
#define MASK_STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ | |||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) | |||
#endif | |||
#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&mbuf[(mi + M)*K + k]); | |||
#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k]) | |||
#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &mbuf[(mi + M)*K + k]) | |||
#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k]) | |||
#define REDUCE_4(rr0, rr1, rr2, rr3) \ | |||
__m512d r0, r1, r2, r3, t0, t1, t2, t3;\ | |||
r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \ | |||
r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \ | |||
t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \ | |||
t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \ | |||
r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \ | |||
__m256d s0, s1; \ | |||
s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \ | |||
s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0); | |||
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) | |||
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) | |||
#if defined(B0) | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N); | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
_mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
_mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ | |||
} | |||
#else | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]; | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \ | |||
_mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
s1 = _mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8); \ | |||
s0 = _mm256_fmadd_pd(s1, beta_256, s0); \ | |||
_mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ | |||
} | |||
#endif | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m32 = M & ~31; | |||
BLASLONG m16 = M & ~15; | |||
BLASLONG m8 = M & ~7; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n6 = N - (N % 6); | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
__m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&alpha)); | |||
#if !defined(B0) | |||
__m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&beta)); | |||
#endif | |||
for (i = 0; i < m32; i += 32) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m16; i += 16) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < m8; i += 8) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
STORE_512(0, 4); | |||
STORE_512(0, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_512(0, 0); | |||
} | |||
} | |||
int mm = M - i; | |||
if (!mm) return 0; | |||
if (mm > 4 || K < 16) { | |||
register __mmask8 mask asm("k1") = (1UL << mm) - 1; | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
MASK_STORE_512(0, 4); | |||
MASK_STORE_512(0, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_STORE_512(0, 0); | |||
} | |||
} else { | |||
/* M => [1, 4] | |||
* | |||
* This kernel use dot-like style to calc a value - C(x, y): | |||
* C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) +....+ A(x, K)*B(K, y) | |||
* | |||
* Alloc a buf to copy rest of A as row major, | |||
* so memory access from 0 to K is continuous for both A & B. | |||
* | |||
* Loading to zmm and FMA 8 of k at one loop, | |||
* finally reduce_add zmm to a single float result in C(x, y). | |||
* | |||
* Note: performance is bad when K is small. | |||
*/ | |||
FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K); | |||
__mmask8 mask = (1UL << mm) - 1; | |||
BLASLONG k8 = K & ~7; | |||
BLASLONG k4 = K & ~3; | |||
for (k = 0; k < k4; k += 4) { | |||
__m256d r0, r1, r2, r3; | |||
__m256d t0, t1, t2, t3; | |||
r0 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(0 + k)]); | |||
r1 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(1 + k)]); | |||
r2 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(2 + k)]); | |||
r3 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(3 + k)]); | |||
t0 = _mm256_unpacklo_pd(r0, r1); | |||
t1 = _mm256_unpackhi_pd(r0, r1); | |||
t2 = _mm256_unpacklo_pd(r2, r3); | |||
t3 = _mm256_unpackhi_pd(r2, r3); | |||
r0 = _mm256_permute2f128_pd(t0, t2, 0x20); | |||
r1 = _mm256_permute2f128_pd(t1, t3, 0x20); | |||
r2 = _mm256_permute2f128_pd(t0, t2, 0x31); | |||
r3 = _mm256_permute2f128_pd(t1, t3, 0x31); | |||
switch (mm) { | |||
case 4: _mm256_storeu_pd(&mbuf[k + 3*K], r3); | |||
case 3: _mm256_storeu_pd(&mbuf[k + 2*K], r2); | |||
case 2: _mm256_storeu_pd(&mbuf[k + 1*K], r1); | |||
case 1: _mm256_storeu_pd(&mbuf[k + 0*K], r0); | |||
} | |||
} | |||
for (; k < K; k++) { | |||
for (int ii = 0; ii < mm; ii++) { | |||
mbuf[k + ii*K] = A[i + lda*k + ii]; | |||
} | |||
} | |||
int mi = 0; | |||
__m256d alpha_256 = _mm256_broadcast_sd(&alpha); | |||
#if !defined(B0) | |||
__m256d beta_256 = _mm256_broadcast_sd(&beta); | |||
#endif | |||
__m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc*1, 0); | |||
long long permute_table[] = { | |||
0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, | |||
2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, | |||
}; | |||
__m512i idx_lo = _mm512_loadu_si512(permute_table); | |||
__m512i idx_hi = _mm512_loadu_si512(permute_table + 8); | |||
for (; i < m4; i += 4, mi += 4) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_REDUCE_M4(0); | |||
} | |||
} | |||
for (; i < m2; i += 2, mi += 2) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1, mi += 1) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_REDUCE_N4(0); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_REDUCE(0, 0); | |||
STORE_REDUCE(0, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_REDUCE(0, 0); | |||
} | |||
} | |||
free(mbuf); | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,535 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() | |||
#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)]) | |||
#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)]) | |||
#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_sd(&B[ldb * k + j + N])) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) | |||
#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[lda * k + i + M])) | |||
#define LOAD_B_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)]) | |||
#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)]) | |||
#if defined(B0) | |||
#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) | |||
#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8) | |||
#else | |||
#define STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \ | |||
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) | |||
#define MASK_STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ | |||
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ | |||
result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
__m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ | |||
result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); | |||
#endif | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m32 = M & ~31; | |||
BLASLONG m16 = M & ~15; | |||
BLASLONG m8 = M & ~7; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n32 = N & ~31; | |||
BLASLONG n16 = N & ~15; | |||
BLASLONG n8 = N & ~7; | |||
BLASLONG n6 = N - (N % 6); | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
__m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); | |||
#if !defined(B0) | |||
__m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); | |||
#endif | |||
for (i = 0; i < m32; i += 32) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m16; i += 16) { | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); | |||
DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); | |||
DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); | |||
MATMUL_512(0, 6); MATMUL_512(1, 6); | |||
MATMUL_512(0, 7); MATMUL_512(1, 7); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); | |||
STORE_512(0, 6); STORE_512(1, 6); | |||
STORE_512(0, 7); STORE_512(1, 7); | |||
} | |||
for (;j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < m8; i += 8) { | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
DECLARE_RESULT_512(0, 6); | |||
DECLARE_RESULT_512(0, 7); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
MATMUL_512(0, 6); | |||
MATMUL_512(0, 7); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
STORE_512(0, 4); | |||
STORE_512(0, 5); | |||
STORE_512(0, 6); | |||
STORE_512(0, 7); | |||
} | |||
for (; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_512(0, 0); | |||
} | |||
} | |||
int mm = M - i; | |||
if (mm >= 6) { | |||
register __mmask16 mask asm("k1") = (1UL << mm) - 1; | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
DECLARE_RESULT_512(0, 6); | |||
DECLARE_RESULT_512(0, 7); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
MATMUL_512(0, 6); | |||
MATMUL_512(0, 7); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
MASK_STORE_512(0, 4); | |||
MASK_STORE_512(0, 5); | |||
MASK_STORE_512(0, 6); | |||
MASK_STORE_512(0, 7); | |||
} | |||
for (; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_STORE_512(0, 0); | |||
} | |||
} else if (mm > 0) { | |||
long long index_n[8]; | |||
for (int ii = 0; ii < 8; ii++) { | |||
index_n[ii] = ii * ldc; | |||
} | |||
__m512i vindex_n = _mm512_loadu_si512(index_n); | |||
for (; i < m4; i += 4) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); | |||
} | |||
__mmask8 mask = 0xff; | |||
for (; j < N; j += 8) { | |||
int remains = N - j; | |||
if (remains < 8) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
} | |||
__mmask8 mask = 0xff; | |||
for (; j < N; j += 8) { | |||
int remains = N - j; | |||
if (remains < 8) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
SCATTER_STORE_512(0, 2); | |||
SCATTER_STORE_512(0, 3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
} | |||
__mmask8 mask = 0xff; | |||
for (; j < N; j += 8) { | |||
int remains = N - j; | |||
if (remains < 8) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); | |||
} | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,44 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) | |||
{ | |||
double MNK = (double) M * (double) N * (double) K; | |||
if (MNK > 100.0*100.0*100.0) // disable for big size matrix | |||
return 0; | |||
if (transa && !transb) { | |||
/* TN kernel perform not good when: | |||
* 1. C matrix is too big | |||
* 2. K is too small | |||
*/ | |||
if (M * N > 1200 || K < 32) | |||
return 0; | |||
} | |||
return 1; | |||
} |
@@ -0,0 +1,322 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) | |||
#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[(i + M)*lda + k]); | |||
#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k]) | |||
#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[(i + M)*lda + k]) | |||
#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k]) | |||
#define REDUCE_4(rr0, rr1, rr2, rr3) \ | |||
__m512d r0, r1, r2, r3, t0, t1, t2, t3;\ | |||
r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \ | |||
r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \ | |||
t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \ | |||
t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \ | |||
r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \ | |||
__m256d s0, s1; \ | |||
s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \ | |||
s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0); | |||
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) | |||
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) | |||
#if defined(B0) | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) | |||
#define STORE_M4(N, s0) _mm256_storeu_pd(&C[(j + N)*ldc + i], s0); | |||
#define STORE_N4(M, s0) _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); | |||
#else | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M] | |||
#define STORE_M4(N, s0) \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \ | |||
_mm256_storeu_pd(&C[(j + N)*ldc + i], s0); | |||
#define STORE_N4(M, s0) \ | |||
s0 = _mm256_fmadd_pd(_mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8), beta_256, s0); \ | |||
_mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); | |||
#endif | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
STORE_M4(N, s0) \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
STORE_N4(M, s0) \ | |||
} | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
BLASLONG k8 = K & ~7; | |||
__mmask8 mask; | |||
__m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc, 0); | |||
__m256d alpha_256 = _mm256_broadcast_sd(&alpha); | |||
#if !defined(B0) | |||
__m256d beta_256 = _mm256_broadcast_sd(&beta); | |||
#endif | |||
long long permute_table[] = { | |||
0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, | |||
2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, | |||
}; | |||
__m512i idx_lo = _mm512_loadu_si512(permute_table); | |||
__m512i idx_hi = _mm512_loadu_si512(permute_table + 8); | |||
for (i = 0; i < m4; i += 4) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_REDUCE_M4(0); | |||
} | |||
} | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_REDUCE_N4(0); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_REDUCE(0, 0); | |||
STORE_REDUCE(0, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < k8; k += 8) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_REDUCE(0, 0); | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,392 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() | |||
#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[k + lda * (i+M)])) | |||
#define LOAD_B_512(M,N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)]) | |||
#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)]) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) | |||
#if defined(B0) | |||
#define STORE_8xy(v, N, x, y) _mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v) | |||
#define STORE_4xy(v, N, x, y) _mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
_mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); | |||
#else | |||
#define STORE_8xy(v, N, x, y) \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*8)*ldc + i]), "v"(beta_512)); \ | |||
_mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v) | |||
#define STORE_4xy(v, N, x, y) \ | |||
asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*4)*ldc + i]), "v"(beta_256)); \ | |||
_mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ | |||
result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ | |||
__m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ | |||
result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); | |||
#endif | |||
#define REORDER_8x8(r0, r1, r2, r3, r4, r5, r6, r7) \ | |||
__m512d t0, t1, t2, t3, t4, t5, t6, t7; \ | |||
t0 = _mm512_unpacklo_pd(r0, r1); \ | |||
t1 = _mm512_unpackhi_pd(r0, r1); \ | |||
t2 = _mm512_unpacklo_pd(r2, r3); \ | |||
t3 = _mm512_unpackhi_pd(r2, r3); \ | |||
t4 = _mm512_unpacklo_pd(r4, r5); \ | |||
t5 = _mm512_unpackhi_pd(r4, r5); \ | |||
t6 = _mm512_unpacklo_pd(r6, r7); \ | |||
t7 = _mm512_unpackhi_pd(r6, r7); \ | |||
r0 = _mm512_shuffle_f64x2(t0, t2, 0x88); \ | |||
r1 = _mm512_shuffle_f64x2(t1, t3, 0x88); \ | |||
r2 = _mm512_shuffle_f64x2(t0, t2, 0xdd); \ | |||
r3 = _mm512_shuffle_f64x2(t1, t3, 0xdd); \ | |||
r4 = _mm512_shuffle_f64x2(t4, t6, 0x88); \ | |||
r5 = _mm512_shuffle_f64x2(t5, t7, 0x88); \ | |||
r6 = _mm512_shuffle_f64x2(t4, t6, 0xdd); \ | |||
r7 = _mm512_shuffle_f64x2(t5, t7, 0xdd); \ | |||
t0 = _mm512_permutex2var_pd(r0, idx_lo, r4); \ | |||
t1 = _mm512_permutex2var_pd(r1, idx_lo, r5); \ | |||
t2 = _mm512_permutex2var_pd(r2, idx_lo, r6); \ | |||
t3 = _mm512_permutex2var_pd(r3, idx_lo, r7); \ | |||
t4 = _mm512_permutex2var_pd(r0, idx_hi, r4); \ | |||
t5 = _mm512_permutex2var_pd(r1, idx_hi, r5); \ | |||
t6 = _mm512_permutex2var_pd(r2, idx_hi, r6); \ | |||
t7 = _mm512_permutex2var_pd(r3, idx_hi, r7); \ | |||
t0 = _mm512_mul_pd(t0, alpha_512); \ | |||
t1 = _mm512_mul_pd(t1, alpha_512); \ | |||
t2 = _mm512_mul_pd(t2, alpha_512); \ | |||
t3 = _mm512_mul_pd(t3, alpha_512); \ | |||
t4 = _mm512_mul_pd(t4, alpha_512); \ | |||
t5 = _mm512_mul_pd(t5, alpha_512); \ | |||
t6 = _mm512_mul_pd(t6, alpha_512); \ | |||
t7 = _mm512_mul_pd(t7, alpha_512); | |||
#define SAVE_8(N, x) {\ | |||
STORE_8xy(t##x, N, x, 0); \ | |||
} | |||
#define REORDER_STORE_8x8(N) {\ | |||
REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ | |||
SAVE_8(N, 0); SAVE_8(N, 1); SAVE_8(N, 2); SAVE_8(N, 3); SAVE_8(N, 4); SAVE_8(N, 5); SAVE_8(N, 6); SAVE_8(N, 7); \ | |||
} | |||
#define MASK_SAVE_8() \ | |||
switch (nn) { \ | |||
case 8: SAVE_8(0, 7); \ | |||
case 7: SAVE_8(0, 6); \ | |||
case 6: SAVE_8(0, 5); \ | |||
case 5: SAVE_8(0, 4); \ | |||
case 4: SAVE_8(0, 3); \ | |||
case 3: SAVE_8(0, 2); \ | |||
case 2: SAVE_8(0, 1); \ | |||
case 1: SAVE_8(0, 0); \ | |||
} | |||
#define MASK_REORDER_STORE_8x8(N) {\ | |||
REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ | |||
MASK_SAVE_8(); \ | |||
} | |||
#define REORDER_4x8(r0, r1, r2, r3) \ | |||
__m512d t0, t1, t2, t3; \ | |||
t0 = _mm512_unpacklo_pd(r0, r1); \ | |||
t1 = _mm512_unpackhi_pd(r0, r1); \ | |||
t2 = _mm512_unpacklo_pd(r2, r3); \ | |||
t3 = _mm512_unpackhi_pd(r2, r3); \ | |||
r0 = _mm512_permutex2var_pd(t0, idx_lo, t2); \ | |||
r1 = _mm512_permutex2var_pd(t1, idx_lo, t3); \ | |||
r2 = _mm512_permutex2var_pd(t0, idx_hi, t2); \ | |||
r3 = _mm512_permutex2var_pd(t1, idx_hi, t3); \ | |||
t0 = _mm512_mul_pd(r0, alpha_512); \ | |||
t1 = _mm512_mul_pd(r1, alpha_512); \ | |||
t2 = _mm512_mul_pd(r2, alpha_512); \ | |||
t3 = _mm512_mul_pd(r3, alpha_512); | |||
#define SAVE_4(N, x, y) {\ | |||
__m256d v4 = _mm512_extractf64x4_pd(t##x, y); \ | |||
STORE_4xy(v4, N, x, y); \ | |||
} | |||
#define REORDER_STORE_4x8(N) {\ | |||
REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ | |||
SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ | |||
SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ | |||
} | |||
#define MASK_SAVE_4() \ | |||
switch (nn) { \ | |||
case 8: SAVE_4(0, 3, 1); \ | |||
case 7: SAVE_4(0, 2, 1); \ | |||
case 6: SAVE_4(0, 1, 1); \ | |||
case 5: SAVE_4(0, 0, 1); \ | |||
case 4: SAVE_4(0, 3, 0); \ | |||
case 3: SAVE_4(0, 2, 0); \ | |||
case 2: SAVE_4(0, 1, 0); \ | |||
case 1: SAVE_4(0, 0, 0); \ | |||
} | |||
#define MASK_REORDER_STORE_4x8(N) {\ | |||
REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ | |||
MASK_SAVE_4(); \ | |||
} | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m8 = M & ~7; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n32 = N & ~31; | |||
BLASLONG n16 = N & ~15; | |||
__m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); | |||
#if !defined(B0) | |||
__m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); | |||
__m256d beta_256 = _mm256_broadcastsd_pd(_mm_load_sd(&beta)); | |||
#endif | |||
long long permute_table[] = { | |||
0, 1, 4, 5, 0|8, 1|8, 4|8, 5|8, | |||
2, 3, 6, 7, 2|8, 3|8, 6|8, 7|8, | |||
}; | |||
__m512i idx_lo = _mm512_loadu_si512(permute_table); | |||
__m512i idx_hi = _mm512_loadu_si512(permute_table + 8); | |||
for (i = 0; i < m8; i += 8) { | |||
for (j = 0; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); | |||
} | |||
REORDER_STORE_8x8(0); | |||
REORDER_STORE_8x8(1); | |||
} | |||
__mmask8 mask = 0xff; | |||
int nn = 8; | |||
for (; j < N; j += 8) { | |||
if (N - j < 8) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); | |||
} | |||
MASK_REORDER_STORE_8x8(0); | |||
} | |||
} | |||
for (; i < m4; i += 4) { | |||
long long permute_table2[] = { | |||
0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, | |||
2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, | |||
}; | |||
idx_lo = _mm512_loadu_si512(permute_table2); | |||
idx_hi = _mm512_loadu_si512(permute_table2 + 8); | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
REORDER_STORE_4x8(0); | |||
REORDER_STORE_4x8(1); | |||
REORDER_STORE_4x8(2); | |||
REORDER_STORE_4x8(3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
REORDER_STORE_4x8(0); | |||
REORDER_STORE_4x8(1); | |||
} | |||
__mmask8 mask = 0xff; | |||
int nn = 8; | |||
for (; j < N; j += 8) { | |||
if (N - j < 8) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
MASK_REORDER_STORE_4x8(0); | |||
} | |||
} | |||
if (i < M) { | |||
long long index_n[8]; | |||
for (int ii = 0; ii < 8; ii++) { | |||
index_n[ii] = ii * ldc; | |||
} | |||
__m512i vindex_n = _mm512_loadu_si512(index_n); | |||
#if !defined(B0) | |||
__m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); | |||
#endif | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
} | |||
__mmask8 mask = 0xff; | |||
int nn = 8; | |||
for (; j < N; j += 8) { | |||
if (N - j < 8) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
SCATTER_STORE_512(0, 2); | |||
SCATTER_STORE_512(0, 3); | |||
} | |||
for (; j < n16; j += 16) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
} | |||
__mmask8 mask = 0xff; | |||
int nn = 8; | |||
for (; j < N; j += 8) { | |||
if (N - j < 8) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); | |||
} | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,612 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | |||
#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) | |||
#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) | |||
#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)])) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | |||
#if defined(B0) | |||
#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | |||
#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) | |||
#else | |||
#define STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \ | |||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | |||
#define MASK_STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ | |||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) | |||
#endif | |||
#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&mbuf[(mi + M)*K + k]); | |||
#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) | |||
#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &mbuf[(mi + M)*K + k]) | |||
#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) | |||
#define REDUCE_4(rr0, rr1, rr2, rr3) \ | |||
__m512 r0, r1, r2, r3, t0, t1, t2, t3;\ | |||
r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \ | |||
r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \ | |||
t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ | |||
t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ | |||
r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ | |||
__m128 s0, s1, s2, s3; \ | |||
s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ | |||
s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ | |||
s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); | |||
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) | |||
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) | |||
#if defined(B0) | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N); | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
_mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
_mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ | |||
} | |||
#else | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M]; | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ | |||
_mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
s1 = _mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4); \ | |||
s0 = _mm_fmadd_ps(s1, beta_128, s0); \ | |||
_mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ | |||
} | |||
#endif | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m64 = M & ~63; | |||
BLASLONG m32 = M & ~31; | |||
BLASLONG m16 = M & ~15; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n6 = N - (N % 6); | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); | |||
#if !defined(B0) | |||
__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); | |||
#endif | |||
for (i = 0; i < m64; i += 64) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m32; i += 32) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < m16; i += 16) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
STORE_512(0, 4); | |||
STORE_512(0, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_512(0, 0); | |||
} | |||
} | |||
int mm = M - i; | |||
if (!mm) return 0; | |||
if (mm > 8 || K < 32) { | |||
register __mmask16 mask asm("k1") = (1UL << mm) - 1; | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
MASK_STORE_512(0, 4); | |||
MASK_STORE_512(0, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_STORE_512(0, 0); | |||
} | |||
} else { | |||
/* M => [1, 8] | |||
* | |||
* This kernel use dot-like style to calc a value - C(x, y): | |||
* C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) +....+ A(x, K)*B(K, y) | |||
* | |||
* Alloc a buf to copy rest of A as row major, | |||
* so memory access from 0 to K is continuous for both A & B. | |||
* | |||
* Loading to zmm and FMA 16 of k at one loop, | |||
* finally reduce_add zmm to a single float result in C(x, y). | |||
* | |||
* Note: performance is bad when K is small. | |||
*/ | |||
FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K); | |||
__mmask8 mask8 = (1UL << mm) - 1; | |||
__mmask16 mask; | |||
BLASLONG k16 = K & ~15; | |||
BLASLONG k8 = K & ~7; | |||
for (k = 0; k < k8; k += 8) { | |||
__m256 r0, r1, r2, r3, r4, r5, r6, r7; | |||
__m256 t0, t1, t2, t3, t4, t5, t6, t7; | |||
r0 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(0 + k)]); | |||
r1 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(1 + k)]); | |||
r2 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(2 + k)]); | |||
r3 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(3 + k)]); | |||
r4 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(4 + k)]); | |||
r5 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(5 + k)]); | |||
r6 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(6 + k)]); | |||
r7 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(7 + k)]); | |||
t0 = _mm256_unpacklo_ps(r0, r1); | |||
t1 = _mm256_unpackhi_ps(r0, r1); | |||
t2 = _mm256_unpacklo_ps(r2, r3); | |||
t3 = _mm256_unpackhi_ps(r2, r3); | |||
t4 = _mm256_unpacklo_ps(r4, r5); | |||
t5 = _mm256_unpackhi_ps(r4, r5); | |||
t6 = _mm256_unpacklo_ps(r6, r7); | |||
t7 = _mm256_unpackhi_ps(r6, r7); | |||
r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0)); | |||
r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2)); | |||
r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0)); | |||
r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2)); | |||
r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0)); | |||
r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2)); | |||
r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0)); | |||
r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2)); | |||
t0 = _mm256_permute2f128_ps(r0, r4, 0x20); | |||
t1 = _mm256_permute2f128_ps(r1, r5, 0x20); | |||
t2 = _mm256_permute2f128_ps(r2, r6, 0x20); | |||
t3 = _mm256_permute2f128_ps(r3, r7, 0x20); | |||
t4 = _mm256_permute2f128_ps(r0, r4, 0x31); | |||
t5 = _mm256_permute2f128_ps(r1, r5, 0x31); | |||
t6 = _mm256_permute2f128_ps(r2, r6, 0x31); | |||
t7 = _mm256_permute2f128_ps(r3, r7, 0x31); | |||
switch (mm) { | |||
case 8: _mm256_storeu_ps(&mbuf[k + 7*K], t7); | |||
case 7: _mm256_storeu_ps(&mbuf[k + 6*K], t6); | |||
case 6: _mm256_storeu_ps(&mbuf[k + 5*K], t5); | |||
case 5: _mm256_storeu_ps(&mbuf[k + 4*K], t4); | |||
case 4: _mm256_storeu_ps(&mbuf[k + 3*K], t3); | |||
case 3: _mm256_storeu_ps(&mbuf[k + 2*K], t2); | |||
case 2: _mm256_storeu_ps(&mbuf[k + 1*K], t1); | |||
case 1: _mm256_storeu_ps(&mbuf[k + 0*K], t0); | |||
} | |||
} | |||
for (; k < K; k++) { | |||
for (int ii = 0; ii < mm; ii++) { | |||
mbuf[k + ii*K] = A[i + lda*k + ii]; | |||
} | |||
} | |||
int mi = 0; | |||
mask8 = 0xff; // just use to avoid SSE instruction | |||
__m128 alpha_128 = _mm_broadcast_ss(&alpha); | |||
#if !defined(B0) | |||
__m128 beta_128 = _mm_broadcast_ss(&beta); | |||
#endif | |||
__m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); | |||
for (; i < m4; i += 4, mi += 4) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_REDUCE_M4(0); | |||
} | |||
} | |||
for (; i < m2; i += 2, mi += 2) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1, mi += 1) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_REDUCE_N4(0); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_REDUCE(0, 0); | |||
STORE_REDUCE(0, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_REDUCE(0, 0); | |||
} | |||
} | |||
free(mbuf); | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,535 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | |||
#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) | |||
#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) | |||
#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[ldb * k + j + N])) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | |||
#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[lda * k + i + M])) | |||
#define LOAD_B_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) | |||
#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) | |||
#if defined(B0) | |||
#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | |||
#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4) | |||
#else | |||
#define STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \ | |||
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) | |||
#define MASK_STORE_512(M, N) \ | |||
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ | |||
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ | |||
result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
__m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ | |||
result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); | |||
#endif | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m64 = M & ~63; | |||
BLASLONG m32 = M & ~31; | |||
BLASLONG m16 = M & ~15; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n64 = N & ~63; | |||
BLASLONG n32 = N & ~31; | |||
BLASLONG n8 = N & ~7; | |||
BLASLONG n6 = N - (N % 6); | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); | |||
#if !defined(B0) | |||
__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); | |||
#endif | |||
for (i = 0; i < m64; i += 64) { | |||
for (j = 0; j < n6; j += 6) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m32; i += 32) { | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); | |||
DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); | |||
DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); | |||
DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
MATMUL_512(0, 4); MATMUL_512(1, 4); | |||
MATMUL_512(0, 5); MATMUL_512(1, 5); | |||
MATMUL_512(0, 6); MATMUL_512(1, 6); | |||
MATMUL_512(0, 7); MATMUL_512(1, 7); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
STORE_512(0, 4); STORE_512(1, 4); | |||
STORE_512(0, 5); STORE_512(1, 5); | |||
STORE_512(0, 6); STORE_512(1, 6); | |||
STORE_512(0, 7); STORE_512(1, 7); | |||
} | |||
for (;j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
STORE_512(0, 2); STORE_512(1, 2); | |||
STORE_512(0, 3); STORE_512(1, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
STORE_512(0, 1); STORE_512(1, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); LOAD_A_512(1, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_512(0, 0); STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < m16; i += 16) { | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
DECLARE_RESULT_512(0, 6); | |||
DECLARE_RESULT_512(0, 7); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
MATMUL_512(0, 6); | |||
MATMUL_512(0, 7); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
STORE_512(0, 4); | |||
STORE_512(0, 5); | |||
STORE_512(0, 6); | |||
STORE_512(0, 7); | |||
} | |||
for (; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
STORE_512(0, 2); | |||
STORE_512(0, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_512(0, 0); | |||
STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_512(0, 0); | |||
} | |||
} | |||
int mm = M - i; | |||
if (mm >= 12) { | |||
register __mmask16 mask asm("k1") = (1UL << mm) - 1; | |||
for (j = 0; j < n8; j += 8) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
DECLARE_RESULT_512(0, 4); | |||
DECLARE_RESULT_512(0, 5); | |||
DECLARE_RESULT_512(0, 6); | |||
DECLARE_RESULT_512(0, 7); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); | |||
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
MATMUL_512(0, 4); | |||
MATMUL_512(0, 5); | |||
MATMUL_512(0, 6); | |||
MATMUL_512(0, 7); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
MASK_STORE_512(0, 4); | |||
MASK_STORE_512(0, 5); | |||
MASK_STORE_512(0, 6); | |||
MASK_STORE_512(0, 7); | |||
} | |||
for (; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
MASK_STORE_512(0, 2); | |||
MASK_STORE_512(0, 3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
MASK_STORE_512(0, 0); | |||
MASK_STORE_512(0, 1); | |||
} | |||
for (; j < N; j++) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
MASK_LOAD_A_512(0, x); | |||
BROADCAST_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_STORE_512(0, 0); | |||
} | |||
} else if (mm > 0) { | |||
int index_n[16]; | |||
for (int ii = 0; ii < 16; ii++) { | |||
index_n[ii] = ii * ldc; | |||
} | |||
__m512i vindex_n = _mm512_loadu_si512(index_n); | |||
for (; i < m4; i += 4) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
for (; j < N; j += 16) { | |||
int remains = N - j; | |||
if (remains < 16) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); | |||
} | |||
} | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
for (; j < N; j += 16) { | |||
int remains = N - j; | |||
if (remains < 16) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
LOAD_B_512(x, 2); | |||
LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
SCATTER_STORE_512(0, 2); | |||
SCATTER_STORE_512(0, 3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); | |||
LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
for (; j < N; j += 16) { | |||
int remains = N - j; | |||
if (remains < 16) mask = (1UL << remains) - 1; | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); | |||
} | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,53 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) | |||
{ | |||
double MNK = (double) M * (double) N * (double) K; | |||
if (MNK > 100.0*100.0*100.0) // disable for big size matrix | |||
return 0; | |||
// tuning for A transpose | |||
if (transa) { | |||
if (transb) { | |||
/* TT kernel perform not good when: | |||
* 1. K is too small. | |||
*/ | |||
if (K < 4) return 0; | |||
} else { | |||
/* TN kernel perform not good when: | |||
* 1. C matrix is too big | |||
* 2. K is too small | |||
*/ | |||
if (M * N > 1200 || K < 32) | |||
return 0; | |||
} | |||
} | |||
return 1; | |||
} |
@@ -0,0 +1,316 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#include <memory.h> | |||
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | |||
#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[(i + M)*lda + k]); | |||
#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) | |||
#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[(i + M)*lda + k]) | |||
#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) | |||
#define REDUCE_4(rr0, rr1, rr2, rr3) \ | |||
__m512 r0, r1, r2, r3, t0, t1, t2, t3;\ | |||
r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \ | |||
r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \ | |||
t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ | |||
t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ | |||
r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ | |||
__m128 s0, s1, s2, s3; \ | |||
s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ | |||
s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ | |||
s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); | |||
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) | |||
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) | |||
#if defined(B0) | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) | |||
#define STORE_M4(N, s0) _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); | |||
#define STORE_N4(M, s0) _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); | |||
#else | |||
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M] | |||
#define STORE_M4(N, s0) \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ | |||
_mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); | |||
#define STORE_N4(M, s0) \ | |||
s0 = _mm_fmadd_ps(_mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4), beta_128, s0); \ | |||
_mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); | |||
#endif | |||
#define STORE_REDUCE_M4(N) {\ | |||
REDUCE_M4(N) \ | |||
STORE_M4(N, s0) \ | |||
} | |||
#define STORE_REDUCE_N4(M) {\ | |||
REDUCE_N4(M) \ | |||
STORE_N4(M, s0) \ | |||
} | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n4 = N & ~3; | |||
BLASLONG n2 = N & ~1; | |||
BLASLONG k16 = K & ~15; | |||
__mmask16 mask; | |||
__mmask8 mask8 = 0xff; // just use to avoid SSE instruction | |||
__m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); | |||
__m128 alpha_128 = _mm_broadcast_ss(&alpha); | |||
#if !defined(B0) | |||
__m128 beta_128 = _mm_broadcast_ss(&beta); | |||
#endif | |||
for (i = 0; i < m4; i += 4) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
STORE_REDUCE_M4(0); | |||
} | |||
} | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); LOAD_KA_512(1, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n4; j += 4) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
STORE_REDUCE_N4(0); | |||
} | |||
for (; j < n2; j += 2) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
STORE_REDUCE(0, 0); | |||
STORE_REDUCE(0, 1); | |||
} | |||
for (; j < N; j += 1) { | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < k16; k += 16) { | |||
LOAD_KA_512(0, x); | |||
LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
int remains = K - k; | |||
if (remains) { | |||
mask = (1UL << remains) - 1; | |||
MASK_LOAD_KA_512(0, x); | |||
MASK_LOAD_KB_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
STORE_REDUCE(0, 0); | |||
} | |||
} | |||
return 0; | |||
} |
@@ -0,0 +1,414 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2021, The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include <immintrin.h> | |||
#include "common.h" | |||
#include <stdio.h> | |||
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() | |||
#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + lda * (i+M)])) | |||
#define LOAD_B_512(M,N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) | |||
#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) | |||
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) | |||
#if defined(B0) | |||
#define STORE_8xy(v, N, x, y) _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) | |||
#define STORE_4xy(v, N, x, y) _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); | |||
#else | |||
#define STORE_8xy(v, N, x, y) \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*8)*ldc + i]), "v"(beta_256)); \ | |||
_mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) | |||
#define STORE_4xy(v, N, x, y) \ | |||
asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*4)*ldc + i]), "v"(beta_128)); \ | |||
_mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) | |||
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ | |||
result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); | |||
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ | |||
__m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ | |||
result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ | |||
_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); | |||
#endif | |||
#define REORDER_8x16(r0, r1, r2, r3, r4, r5, r6, r7) \ | |||
__m512 t0, t1, t2, t3, t4, t5, t6, t7, v; \ | |||
t0 = _mm512_unpacklo_ps(r0, r1); \ | |||
t1 = _mm512_unpackhi_ps(r0, r1); \ | |||
t2 = _mm512_unpacklo_ps(r2, r3); \ | |||
t3 = _mm512_unpackhi_ps(r2, r3); \ | |||
t4 = _mm512_unpacklo_ps(r4, r5); \ | |||
t5 = _mm512_unpackhi_ps(r4, r5); \ | |||
t6 = _mm512_unpacklo_ps(r6, r7); \ | |||
t7 = _mm512_unpackhi_ps(r6, r7); \ | |||
v = _mm512_shuffle_ps(t0, t2, 0x4E); \ | |||
r0 = _mm512_mask_blend_ps(kc, t0, v); \ | |||
r1 = _mm512_mask_blend_ps(k3, t2, v); \ | |||
v = _mm512_shuffle_ps(t1, t3, 0x4E); \ | |||
r2 = _mm512_mask_blend_ps(kc, t1, v); \ | |||
r3 = _mm512_mask_blend_ps(k3, t3, v); \ | |||
v = _mm512_shuffle_ps(t4, t6, 0x4E); \ | |||
r4 = _mm512_mask_blend_ps(kc, t4, v); \ | |||
r5 = _mm512_mask_blend_ps(k3, t6, v); \ | |||
v = _mm512_shuffle_ps(t5, t7, 0x4E); \ | |||
r6 = _mm512_mask_blend_ps(kc, t5, v); \ | |||
r7 = _mm512_mask_blend_ps(k3, t7, v); \ | |||
t0 = _mm512_permutex2var_ps(r0, idx_lo, r4); \ | |||
t1 = _mm512_permutex2var_ps(r1, idx_lo, r5); \ | |||
t2 = _mm512_permutex2var_ps(r2, idx_lo, r6); \ | |||
t3 = _mm512_permutex2var_ps(r3, idx_lo, r7); \ | |||
t4 = _mm512_permutex2var_ps(r0, idx_hi, r4); \ | |||
t5 = _mm512_permutex2var_ps(r1, idx_hi, r5); \ | |||
t6 = _mm512_permutex2var_ps(r2, idx_hi, r6); \ | |||
t7 = _mm512_permutex2var_ps(r3, idx_hi, r7); \ | |||
t0 = _mm512_mul_ps(t0, alpha_512); \ | |||
t1 = _mm512_mul_ps(t1, alpha_512); \ | |||
t2 = _mm512_mul_ps(t2, alpha_512); \ | |||
t3 = _mm512_mul_ps(t3, alpha_512); \ | |||
t4 = _mm512_mul_ps(t4, alpha_512); \ | |||
t5 = _mm512_mul_ps(t5, alpha_512); \ | |||
t6 = _mm512_mul_ps(t6, alpha_512); \ | |||
t7 = _mm512_mul_ps(t7, alpha_512); | |||
#define SAVE_8(N, x, y) {\ | |||
__m256 v8 = _mm512_extractf32x8_ps(t##x, y); \ | |||
STORE_8xy(v8, N, x, y); \ | |||
} | |||
#define REORDER_STORE_8x16(N) {\ | |||
REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ | |||
SAVE_8(N, 0, 0); SAVE_8(N, 1, 0); SAVE_8(N, 2, 0); SAVE_8(N, 3, 0); SAVE_8(N, 4, 0); SAVE_8(N, 5, 0); SAVE_8(N, 6, 0); SAVE_8(N, 7, 0); \ | |||
SAVE_8(N, 0, 1); SAVE_8(N, 1, 1); SAVE_8(N, 2, 1); SAVE_8(N, 3, 1); SAVE_8(N, 4, 1); SAVE_8(N, 5, 1); SAVE_8(N, 6, 1); SAVE_8(N, 7, 1); \ | |||
} | |||
#define MASK_SAVE_8() \ | |||
switch (nn) { \ | |||
case 16: SAVE_8(0, 7, 1); \ | |||
case 15: SAVE_8(0, 6, 1); \ | |||
case 14: SAVE_8(0, 5, 1); \ | |||
case 13: SAVE_8(0, 4, 1); \ | |||
case 12: SAVE_8(0, 3, 1); \ | |||
case 11: SAVE_8(0, 2, 1); \ | |||
case 10: SAVE_8(0, 1, 1); \ | |||
case 9: SAVE_8(0, 0, 1); \ | |||
case 8: SAVE_8(0, 7, 0); \ | |||
case 7: SAVE_8(0, 6, 0); \ | |||
case 6: SAVE_8(0, 5, 0); \ | |||
case 5: SAVE_8(0, 4, 0); \ | |||
case 4: SAVE_8(0, 3, 0); \ | |||
case 3: SAVE_8(0, 2, 0); \ | |||
case 2: SAVE_8(0, 1, 0); \ | |||
case 1: SAVE_8(0, 0, 0); \ | |||
} | |||
#define MASK_REORDER_STORE_8x16(N) {\ | |||
REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ | |||
MASK_SAVE_8(); \ | |||
} | |||
#define REORDER_4x16(r0, r1, r2, r3) \ | |||
__m512 t0, t1, t2, t3, v; \ | |||
t0 = _mm512_unpacklo_ps(r0, r1); \ | |||
t1 = _mm512_unpackhi_ps(r0, r1); \ | |||
t2 = _mm512_unpacklo_ps(r2, r3); \ | |||
t3 = _mm512_unpackhi_ps(r2, r3); \ | |||
v = _mm512_shuffle_ps(t0, t2, 0x4E); \ | |||
r0 = _mm512_mask_blend_ps(kc, t0, v); \ | |||
r1 = _mm512_mask_blend_ps(k3, t2, v); \ | |||
v = _mm512_shuffle_ps(t1, t3, 0x4E); \ | |||
r2 = _mm512_mask_blend_ps(kc, t1, v); \ | |||
r3 = _mm512_mask_blend_ps(k3, t3, v); \ | |||
t0 = _mm512_mul_ps(r0, alpha_512); \ | |||
t1 = _mm512_mul_ps(r1, alpha_512); \ | |||
t2 = _mm512_mul_ps(r2, alpha_512); \ | |||
t3 = _mm512_mul_ps(r3, alpha_512); | |||
#define SAVE_4(N, x, y) {\ | |||
__m128 v4 = _mm512_extractf32x4_ps(t##x, y); \ | |||
STORE_4xy(v4, N, x, y); \ | |||
} | |||
#define REORDER_STORE_4x16(N) {\ | |||
REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ | |||
SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ | |||
SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ | |||
SAVE_4(N, 0, 2); SAVE_4(N, 1, 2); SAVE_4(N, 2, 2); SAVE_4(N, 3, 2); \ | |||
SAVE_4(N, 0, 3); SAVE_4(N, 1, 3); SAVE_4(N, 2, 3); SAVE_4(N, 3, 3); \ | |||
} | |||
#define MASK_SAVE_4() \ | |||
switch (nn) { \ | |||
case 16: SAVE_4(0, 3, 3); \ | |||
case 15: SAVE_4(0, 2, 3); \ | |||
case 14: SAVE_4(0, 1, 3); \ | |||
case 13: SAVE_4(0, 0, 3); \ | |||
case 12: SAVE_4(0, 3, 2); \ | |||
case 11: SAVE_4(0, 2, 2); \ | |||
case 10: SAVE_4(0, 1, 2); \ | |||
case 9: SAVE_4(0, 0, 2); \ | |||
case 8: SAVE_4(0, 3, 1); \ | |||
case 7: SAVE_4(0, 2, 1); \ | |||
case 6: SAVE_4(0, 1, 1); \ | |||
case 5: SAVE_4(0, 0, 1); \ | |||
case 4: SAVE_4(0, 3, 0); \ | |||
case 3: SAVE_4(0, 2, 0); \ | |||
case 2: SAVE_4(0, 1, 0); \ | |||
case 1: SAVE_4(0, 0, 0); \ | |||
} | |||
#define MASK_REORDER_STORE_4x16(N) {\ | |||
REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ | |||
MASK_SAVE_4(); \ | |||
} | |||
#if defined(B0) | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) | |||
#else | |||
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) | |||
#endif | |||
{ | |||
// column major | |||
BLASLONG i, j, k; | |||
BLASLONG m8 = M & ~7; | |||
BLASLONG m4 = M & ~3; | |||
BLASLONG m2 = M & ~1; | |||
BLASLONG n64 = N & ~63; | |||
BLASLONG n32 = N & ~31; | |||
__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); | |||
#if !defined(B0) | |||
__m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); | |||
__m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta)); | |||
#endif | |||
int permute_table[] = { | |||
0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, | |||
0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, | |||
}; | |||
__m512i idx_lo = _mm512_loadu_si512(permute_table); | |||
__m512i idx_hi = _mm512_loadu_si512(permute_table + 16); | |||
__mmask16 kc = 0xcccc; | |||
__mmask16 k3 = 0x3333; | |||
__mmask8 mask8 = 0xff; // force use AVX128 instead of SSE | |||
for (i = 0; i < m8; i += 8) { | |||
for (j = 0; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); | |||
} | |||
REORDER_STORE_8x16(0); | |||
REORDER_STORE_8x16(1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
int nn = 16; | |||
for (; j < N; j += 16) { | |||
if (N - j < 16) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); | |||
} | |||
MASK_REORDER_STORE_8x16(0); | |||
} | |||
} | |||
for (; i < m4; i += 4) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); | |||
} | |||
REORDER_STORE_4x16(0); | |||
REORDER_STORE_4x16(1); | |||
REORDER_STORE_4x16(2); | |||
REORDER_STORE_4x16(3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); | |||
} | |||
REORDER_STORE_4x16(0); | |||
REORDER_STORE_4x16(1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
int nn = 16; | |||
for (; j < N; j += 16) { | |||
if (N - j < 16) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); | |||
} | |||
MASK_REORDER_STORE_4x16(0); | |||
} | |||
} | |||
if (i < M) { | |||
int index_n[16]; | |||
for (int ii = 0; ii < 16; ii++) { | |||
index_n[ii] = ii * ldc; | |||
} | |||
__m512i vindex_n = _mm512_loadu_si512(index_n); | |||
#if !defined(B0) | |||
__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); | |||
#endif | |||
for (; i < m2; i += 2) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); | |||
DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
MATMUL_512(0, 2); MATMUL_512(1, 2); | |||
MATMUL_512(0, 3); MATMUL_512(1, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); | |||
SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
MATMUL_512(0, 1); MATMUL_512(1, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); | |||
SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
int nn = 16; | |||
for (; j < N; j += 16) { | |||
if (N - j < 16) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); MATMUL_512(1, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); | |||
} | |||
} | |||
for (; i < M; i += 1) { | |||
for (j = 0; j < n64; j += 64) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
DECLARE_RESULT_512(0, 2); | |||
DECLARE_RESULT_512(0, 3); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
MATMUL_512(0, 2); | |||
MATMUL_512(0, 3); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
SCATTER_STORE_512(0, 2); | |||
SCATTER_STORE_512(0, 3); | |||
} | |||
for (; j < n32; j += 32) { | |||
DECLARE_RESULT_512(0, 0); | |||
DECLARE_RESULT_512(0, 1); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
LOAD_B_512(x, 0); LOAD_B_512(x, 1); | |||
MATMUL_512(0, 0); | |||
MATMUL_512(0, 1); | |||
} | |||
SCATTER_STORE_512(0, 0); | |||
SCATTER_STORE_512(0, 1); | |||
} | |||
__mmask16 mask = 0xffff; | |||
int nn = 16; | |||
for (; j < N; j += 16) { | |||
if (N - j < 16) { | |||
nn = N - j; | |||
mask = (1UL << nn) - 1; | |||
} | |||
DECLARE_RESULT_512(0, 0); | |||
for (k = 0; k < K; k++) { | |||
BROADCAST_LOAD_A_512(0, x); | |||
MASK_LOAD_B_512(x, 0); | |||
MATMUL_512(0, 0); | |||
} | |||
MASK_SCATTER_STORE_512(0, 0); | |||
} | |||
} | |||
} | |||
return 0; | |||
} |