From 2c3cdaf74ed3397ad75a15d8c7f64324ecf7ecf0 Mon Sep 17 00:00:00 2001
From: Chris Sidebottom
Date: Mon, 21 Jul 2025 17:09:47 +0000
Subject: [PATCH] Optimized BGEMV for NEOVERSEV1 target

- Adds bgemv T based on the sbgemv T kernel
- Adds bgemv N, slightly altered so that Y is not used as an
  accumulator, since the bf16 output would otherwise lose precision
- Enables BGEMM_GEMV_FORWARD to proxy BGEMM to BGEMV with the new
  kernels
---
 Makefile.system                 |   4 +
 benchmark/Makefile              |  13 +-
 benchmark/gemv.c                |  33 +++-
 kernel/arm64/KERNEL.NEOVERSEV1  |   3 +
 kernel/arm64/bgemv_n_sve_v3x4.c | 321 ++++++++++++++++++++++++++++++++
 kernel/arm64/sbgemv_t_bfdot.c   |  92 +++++----
 test/compare_sgemv_bgemv.c      |   2 +-
 7 files changed, 426 insertions(+), 42 deletions(-)
 create mode 100644 kernel/arm64/bgemv_n_sve_v3x4.c

diff --git a/Makefile.system b/Makefile.system
index 9e20c132c..b214006b1 100644
--- a/Makefile.system
+++ b/Makefile.system
@@ -277,6 +277,7 @@ endif
 ifeq ($(ARCH), arm64)
 GEMM_GEMV_FORWARD = 1
 SBGEMM_GEMV_FORWARD = 1
+BGEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), riscv)
 GEMM_GEMV_FORWARD = 1
@@ -296,6 +297,9 @@ endif
 ifeq ($(SBGEMM_GEMV_FORWARD), 1)
 CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
 endif
+ifeq ($(BGEMM_GEMV_FORWARD), 1)
+CCOMMON_OPT += -DBGEMM_GEMV_FORWARD
+endif
 endif

 # This operation is expensive, so execution should be once.
diff --git a/benchmark/Makefile b/benchmark/Makefile
index cdf87c0ab..d82a06af9 100644
--- a/benchmark/Makefile
+++ b/benchmark/Makefile
@@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS=
 endif

 ifeq ($(BUILD_BFLOAT16),1)
-GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto
+GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto
 else
 GOTO_BFLOAT_TARGETS=
 endif
@@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 endif

 ifeq ($(BUILD_HFLOAT16),1)
@@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c
 cgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^

+ifeq ($(BUILD_BFLOAT16),1)
+bgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^
+sbgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
+endif
+
 zgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^

diff --git a/benchmark/gemv.c b/benchmark/gemv.c
index fc39f3f3d..884a9bb27 100644
--- a/benchmark/gemv.c
+++ b/benchmark/gemv.c
@@ -1,5 +1,5 @@
 /***************************************************************************
-Copyright (c) 2014, The OpenBLAS Project
+Copyright (c) 2014, 2025 The OpenBLAS Project
 All rights reserved.
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
@@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #ifdef DOUBLE
 #define GEMV BLASFUNC(dgemv)
+#elif defined(BFLOAT16) && defined(BGEMM)
+#define GEMV BLASFUNC(bgemv)
+#elif defined(BFLOAT16)
+#define GEMV BLASFUNC(sbgemv)
+#undef IFLOAT
+#define IFLOAT bfloat16
 #else
 #define GEMV BLASFUNC(sgemv)
 #endif
@@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #endif

 int main(int argc, char *argv[]){

-  FLOAT *a, *x, *y;
-  FLOAT alpha[] = {1.0, 1.0};
-  FLOAT beta [] = {1.0, 0.0};
+  IFLOAT *a, *x;
+  FLOAT *y;
+#ifdef BGEMM
+  blasint one=1;
+  blasint two=2;
+  float alpha_in[] = {1.0, 0.0};
+  float beta_in[] = {0.0, 0.0};
+  FLOAT alpha[2], beta[2];
+  sbstobf16_(&two, alpha_in, &one, alpha, &one);
+  sbstobf16_(&two, beta_in, &one, beta, &one);
+#else
+  FLOAT alpha[] = {1.0, 0.0};
+  FLOAT beta [] = {0.0, 0.0};
+#endif
   char trans='N';
   blasint m, i, j;
   blasint inc_x=1,inc_y=1;
@@ -97,11 +114,11 @@ int main(int argc, char *argv[]){

   fprintf(stderr, "From : %3d  To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops);

-  if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){
+  if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }

-  if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
+  if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }

@@ -125,7 +142,7 @@ int main(int argc, char *argv[]){
     fprintf(stderr, " %6dx%d : ", (int)m,(int)n);
     for(j = 0; j < m; j++){
       for(i = 0; i < n * COMPSIZE; i++){
-        a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+        a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
       }
     }

@@ -133,7 +150,7 @@ int main(int argc, char *argv[]){
     {

      for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){
-       x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+       x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
      }

      for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){
diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1
index 8bc0f35e5..6f632f2cc 100644
--- a/kernel/arm64/KERNEL.NEOVERSEV1
+++ b/kernel/arm64/KERNEL.NEOVERSEV1
@@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
 BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
 BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)

+BGEMVTKERNEL = sbgemv_t_bfdot.c
+BGEMVNKERNEL = bgemv_n_sve_v3x4.c
+
 SBGEMM_BETA = sbgemm_beta_neoversev1.c
 SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
 ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))
diff --git a/kernel/arm64/bgemv_n_sve_v3x4.c b/kernel/arm64/bgemv_n_sve_v3x4.c
new file mode 100644
index 000000000..6347746d0
--- /dev/null
+++ b/kernel/arm64/bgemv_n_sve_v3x4.c
@@ -0,0 +1,321 @@
+/***************************************************************************
+Copyright (c) 2025 The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include "common.h"
+
+#include <arm_sve.h>
+
+#define UPDATE_PTRSx2 \
+  a_ptr1 = a_ptr0 + lda;
+
+#define UPDATE_PTRSx4 \
+  UPDATE_PTRSx2 \
+  a_ptr2 = a_ptr1 + lda; \
+  a_ptr3 = a_ptr2 + lda;
+
+#define UPDATE_PTRSx8 \
+  UPDATE_PTRSx4 \
+  a_ptr4 = a_ptr3 + lda; \
+  a_ptr5 = a_ptr4 + lda; \
+  a_ptr6 = a_ptr5 + lda; \
+  a_ptr7 = a_ptr6 + lda;
+
+#define LANESx2(MACRO, offset) \
+  MACRO(offset, 0) \
+  MACRO(offset, 1)
+
+#define LANESx4(MACRO, offset) \
+  LANESx2(MACRO, offset) \
+  MACRO(offset, 2) \
+  MACRO(offset, 3)
+
+#define LANESx8(MACRO, offset) \
+  LANESx4(MACRO, offset) \
+  MACRO(offset, 4) \
+  MACRO(offset, 5) \
+  MACRO(offset, 6) \
+  MACRO(offset, 7)
+
+#define LOAD_A_VEC(offset, vec) \
+  svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16);
+
+#define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \
+  acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \
+  acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane);
+
+#define INIT_ACCUMULATORS(offset) \
+  svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \
+  svfloat32_t acc ## offset ## 1 = svdup_f32(0.0);
+
+#define UPDATE_ACCUMULATORS(offset) \
+  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \
+  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec);
+
+#define STORE_ACCUMULATORS(offset) \
+  svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \
+  svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \
+  svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \
+  svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset);
+
+#define ALPHA_OP(offset) \
+  acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \
+  acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha);
+
+#define BETA_OP(offset) \
+  svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \
+  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \
+  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset);
+
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y)
+{
+  BLASLONG ix;
+  bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7;
+  BLASLONG sve_size_bf16 = svcnth();
+  BLASLONG sve_size2_bf16 = sve_size_bf16 * 2;
+  BLASLONG sve_size3_bf16 = sve_size_bf16 * 3;
+  svbool_t pg_full = svptrue_b16();
+  svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16);
+
+  BLASLONG n8 = n & -8;
+  BLASLONG n4 = n & -4;
+  BLASLONG n2 = n & -2;
+
+  bfloat16_t *a = (bfloat16_t*)a_in;
+  bfloat16_t *x = (bfloat16_t*)x_in;
+  bfloat16_t *y = (bfloat16_t*)y_in;
+
+  bfloat16_t alpha_bf16, beta_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
+  float beta_f32 = vcvtah_f32_bf16(beta_bf16);
+  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+  svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
+  svbfloat16_t svbeta16 = svdup_bf16(beta_bf16);
+
+  BLASLONG i = 0;
+  if (inc_y == 1) {
+    for (; (i + sve_size3_bf16 - 1) < m; i += sve_size3_bf16) {
+      INIT_ACCUMULATORS(0);
+      INIT_ACCUMULATORS(1);
+      INIT_ACCUMULATORS(2);
+
+      BLASLONG j = 0;
+      ix = 0;
+      a_ptr0 = a;
+      UPDATE_PTRSx8;
+
+      if (inc_x == 1) {
+        for (; j < n4; j += 4) {
+          uint64_t* x_u64 = (uint64_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
+
+          LANESx4(LOAD_A_VEC, 0);
+          LANESx4(LOAD_A_VEC, 1);
+          LANESx4(LOAD_A_VEC, 2);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2);
+
+          ix += 4;
+
+          a_ptr0 += 4 * lda;
+          UPDATE_PTRSx4;
+        }
+
+        for (; j < n2; j += 2) {
+          uint32_t* x_u32 = (uint32_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
+
+          LANESx2(LOAD_A_VEC, 0);
+          LANESx2(LOAD_A_VEC, 1);
+          LANESx2(LOAD_A_VEC, 2);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2);
+
+          ix += 2;
+
+          a_ptr0 += 2 * lda;
+          UPDATE_PTRSx2;
+        }
+      }
+
+      for (; j < n; j++) {
+        svbfloat16_t x_vec = svdup_bf16(x[ix]);
+        LOAD_A_VEC(0, 0);
+        LOAD_A_VEC(1, 0);
+        LOAD_A_VEC(2, 0);
+        UPDATE_ACCUMULATORS(0);
+        UPDATE_ACCUMULATORS(1);
+        UPDATE_ACCUMULATORS(2);
+
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+
+      if (alpha_f32 != ONE) {
+        ALPHA_OP(0);
+        ALPHA_OP(1);
+        ALPHA_OP(2);
+      }
+      if (beta_f32 != ZERO) {
+        BETA_OP(0);
+        BETA_OP(1);
+        BETA_OP(2);
+      }
+
+      STORE_ACCUMULATORS(0);
+      STORE_ACCUMULATORS(1);
+      STORE_ACCUMULATORS(2);
+    }
+
+    for (; (i + sve_size_bf16 - 1) < m; i += sve_size_bf16) {
+      INIT_ACCUMULATORS(0);
+
+      BLASLONG j = 0;
+      ix = 0;
+      a_ptr0 = a;
+      UPDATE_PTRSx8;
+
+      if (inc_x == 1) {
+        for (; j < n8; j += 8) {
+          svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]);
+          LANESx8(LOAD_A_VEC, 0);
+          LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+
+          ix += 8;
+
+          a_ptr0 += 8 * lda;
+          UPDATE_PTRSx8;
+        }
+
+        for (; j < n4; j += 4) {
+          uint64_t* x_u64 = (uint64_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
+
+          LANESx4(LOAD_A_VEC, 0);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+
+          ix += 4;
+
+          a_ptr0 += 4 * lda;
+          UPDATE_PTRSx4;
+        }
+
+        for (; j < n2; j += 2) {
+          uint32_t* x_u32 = (uint32_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
+
+          LANESx2(LOAD_A_VEC, 0);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+
+          ix += 2;
+
+          a_ptr0 += 2 * lda;
+          UPDATE_PTRSx2;
+        }
+      }
+
+      for (; j < n; j++) {
+        svbfloat16_t x_vec = svdup_bf16(x[ix]);
+        LOAD_A_VEC(0, 0);
+        UPDATE_ACCUMULATORS(0);
+
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+
+      if (alpha_f32 != ONE) {
+        ALPHA_OP(0);
+      }
+      if (beta_f32 != ZERO) {
+        BETA_OP(0);
+      }
+
+      STORE_ACCUMULATORS(0);
+    }
+
+    if (i < m) {
+      svfloat32_t acc0 = svdup_f32(0.0);
+      svfloat32_t acc1 = svdup_f32(0.0);
+
+      ix = 0;
+      a_ptr0 = a;
+      for (BLASLONG j = 0; j < n; j++) {
+        svbfloat16_t x_vec0 = svdup_bf16(x[ix]);
+        svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]);
+
+        acc0 = svbfmlalb(acc0, a_vec0, x_vec0);
+        acc1 = svbfmlalt(acc1, a_vec0, x_vec0);
+
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+
+      if (alpha_f32 != ONE) {
+        acc0 = svmul_x(pg_full, acc0, svalpha);
+        acc1 = svmul_x(pg_full, acc1, svalpha);
+      }
+      if (beta_f32 != ZERO) {
+        svbfloat16_t y_vec = svld1(pg_tail, &y[i]);
+        acc0 = svbfmlalb(acc0, svbeta16, y_vec);
+        acc1 = svbfmlalt(acc1, svbeta16, y_vec);
+      }
+
+      svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0);
+      svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1);
+      svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16);
+      svst1(pg_tail, &y[i], combined);
+    }
+
+    return 0;
+  }
+
+  // Scalar fallback
+  BLASLONG iy = 0;
+  for (; i < m; i++) {
+    float temp = 0.0;
+
+    ix = 0;
+    a_ptr0 = a;
+    for (BLASLONG j = 0; j < n; j++) {
+      temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]);
+      ix += inc_x;
+      a_ptr0 += lda;
+    }
+
+    if (beta_f32 == ZERO) {
+      y[iy] = vcvth_bf16_f32(alpha_f32 * temp);
+    } else {
+      y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy]));
+    }
+
+    iy += inc_y;
+  }
+
+  return (0);
+}
diff --git a/kernel/arm64/sbgemv_t_bfdot.c b/kernel/arm64/sbgemv_t_bfdot.c
index 672f70acf..4de245d3b 100644
--- a/kernel/arm64/sbgemv_t_bfdot.c
+++ b/kernel/arm64/sbgemv_t_bfdot.c
@@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <arm_neon.h>
 #include "common.h"

-int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
+#ifdef BGEMM
+#define INNER_FLOAT bfloat16_t
+#define TO32(x) vcvtah_f32_bf16(x)
+#define FROM32(x) vcvth_bf16_f32(x)
+#else
+#define INNER_FLOAT float
+#define TO32(x) x
+#define FROM32(x) x
+#endif
+
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy)
 {
 	if (m < 1 || n < 1) return(0);
+
 	BLASLONG i;
 	BLASLONG ix,iy;
 	BLASLONG j;
 	bfloat16_t *a_ptr;
 	bfloat16_t *x_ptr;
-	float *y_ptr;
-	float temp;
+	float temp, temp0, temp1, temp2, temp3;
+
+#ifdef BGEMM
+	bfloat16_t alpha_bf16, beta_bf16;
+	memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+	memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
+	float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+	float beta_f32 = vcvtah_f32_bf16(beta_bf16);
+#else
+	float alpha_f32 = alpha;
+	float beta_f32 = beta;
+#endif
+	INNER_FLOAT *y = (INNER_FLOAT *)y_in;
+	INNER_FLOAT *y_ptr;

 	iy = 0;
 	a_ptr = (bfloat16_t*)(a);
@@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
 	bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
 	bfloat16_t *a3_ptr = a_ptr + lda * width * 3;

-	float *y0_ptr = y + incy * width * 0;
-	float *y1_ptr = y + incy * width * 1;
-	float *y2_ptr = y + incy * width * 2;
-	float *y3_ptr = y + incy * width * 3;
+	INNER_FLOAT *y0_ptr = y + incy * width * 0;
+	INNER_FLOAT *y1_ptr = y + incy * width * 1;
+	INNER_FLOAT *y2_ptr = y + incy * width * 2;
+	INNER_FLOAT *y3_ptr = y + incy * width * 3;

 	for (j = 0; j < width; j++) {
 		float32x4_t temp0_vec = vdupq_n_f32(0.0f);
@@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
 			i += 4;
 		}

-		if (beta == 0.0f) {
-			y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-			y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
-			y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
-			y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
-		}
-		else {
-			y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
-			y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
-			y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
-			y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
+
+		if (beta_f32 == 0.0f) {
+			temp0 = alpha_f32 * vaddvq_f32(temp0_vec);
+			temp1 = alpha_f32 * vaddvq_f32(temp1_vec);
+			temp2 = alpha_f32 * vaddvq_f32(temp2_vec);
+			temp3 = alpha_f32 * vaddvq_f32(temp3_vec);
+		} else {
+			temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]);
+			temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]);
+			temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]);
+			temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]);
 		}

 		for (; i < m; ++i) {
-			y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-			y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-			y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-			y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+			temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+			temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+			temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+			temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
 		}

+		y0_ptr[iy] = FROM32(temp0);
+		y1_ptr[iy] = FROM32(temp1);
+		y2_ptr[iy] = FROM32(temp2);
+		y3_ptr[iy] = FROM32(temp3);
+
 		iy += incy;

 		a0_ptr += lda;
@@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
 			i += 4;
 		}

-		if (beta == 0.0f) {
-			y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-		}
-		else {
-			y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
-		}
+		if (beta_f32 == 0.0f) {
+			temp = alpha_f32 * vaddvq_f32(temp0_vec);
+		} else {
+			temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]);
+		}

 		for (; i < m; ++i) {
-			y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+			temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
 		}
+		y_ptr[iy] = FROM32(temp);

 		iy += incy;

@@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
 			temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]);
 			ix += incx;
 		}
-		if (beta == 0.0f) {
-			y[iy] = alpha * temp;
+		if (beta_f32 == 0.0f) {
+			y[iy] = FROM32(alpha_f32 * temp);
 		}
 		else {
-			y[iy] = alpha * temp + beta * y[iy];
+			y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy]));
 		}
 		iy += incy;
 		a_ptr += lda;
diff --git a/test/compare_sgemv_bgemv.c b/test/compare_sgemv_bgemv.c
index bab15deb1..17011a4b7 100644
--- a/test/compare_sgemv_bgemv.c
+++ b/test/compare_sgemv_bgemv.c
@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
       if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001))
       {
 #ifdef DEBUG
-        printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n",
+        printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n",
               transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]));
 #endif
         ret++;
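
Note on the bgemv N accumulator (see commit message): for BGEMM the Y
vector is bf16, so using it as a running accumulator would round the
partial sum to bf16 precision after every update. The kernel therefore
keeps the accumulators in fp32 SVE registers and converts to bf16 once,
in STORE_ACCUMULATORS. The standalone sketch below illustrates the
effect; bf16_round() is a hypothetical helper emulating the
f32->bf16->f32 round trip (round-to-nearest-even), not code from this
patch.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Emulate one f32 -> bf16 -> f32 round trip (round to nearest even). */
static float bf16_round(float x)
{
    uint32_t u;
    memcpy(&u, &x, sizeof(u));
    u = (u + 0x7fffu + ((u >> 16) & 1u)) & 0xffff0000u;
    float r;
    memcpy(&r, &u, sizeof(r));
    return r;
}

int main(void)
{
    float fp32_acc = 0.0f; /* what the kernel does: accumulate in fp32 */
    float bf16_acc = 0.0f; /* what accumulating through bf16 Y would do */

    for (int j = 0; j < 1000; j++) {
        float p = 0.001f;                    /* one a[i][j] * x[j] term */
        fp32_acc += p;
        bf16_acc = bf16_round(bf16_acc + p); /* rounded after every add */
    }

    printf("fp32 accumulate, convert once: %f\n", bf16_round(fp32_acc));
    printf("bf16 accumulate per step:      %f\n", bf16_acc);
    return 0;
}

With bf16's 8 significant bits, the running bf16 sum stalls once the
increment drops below half an ulp of the partial sum (here it sticks
near 0.5 instead of reaching 1.0), which is exactly the loss the fp32
accumulators avoid.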
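For completeness, a minimal caller of the new path, modeled on the
BGEMM branch added to benchmark/gemv.c above: alpha and beta are
converted to bf16 with the existing sbstobf16 utility and passed to
bgemv alongside bf16 A, x, and y. This is an illustrative sketch, not
part of the patch; common.h and BLASFUNC come from the OpenBLAS source
tree, and the library must be built with BUILD_BFLOAT16=1.

#include <stdlib.h>
#include "common.h" /* bfloat16, blasint, BLASFUNC (OpenBLAS tree) */

int main(void)
{
    blasint m = 128, n = 128, inc = 1, one = 1, two = 2;
    char trans = 'N';

    bfloat16 *a = malloc(sizeof(bfloat16) * m * n); /* column-major m x n */
    bfloat16 *x = malloc(sizeof(bfloat16) * n);
    bfloat16 *y = malloc(sizeof(bfloat16) * m);

    /* Convert fp32 scalars to bf16, as the benchmark does. */
    float alpha_f[2] = {1.0f, 0.0f}, beta_f[2] = {0.0f, 0.0f};
    bfloat16 alpha[2], beta[2];
    BLASFUNC(sbstobf16)(&two, alpha_f, &one, alpha, &one);
    BLASFUNC(sbstobf16)(&two, beta_f, &one, beta, &one);

    /* ... fill a, x (and y when beta != 0) with bf16 data ... */

    /* y = alpha * A * x + beta * y, served by the bgemv N kernel */
    BLASFUNC(bgemv)(&trans, &m, &n, alpha, a, &m, x, &inc, beta, y, &inc);

    free(a); free(x); free(y);
    return 0;
}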