Browse Source

Optimized BGEMV for NEOVERSEV1 target

- Adds bgemv T based off of sbgemv T kernel
- Adds bgemv N, which is slightly altered to not use Y as an
accumulator, because the output is bf16 and accumulating in it would
result in a loss of precision
- Enables BGEMM_GEMV_FORWARD to proxy BGEMM to BGEMV with new kernels
pull/5394/head
Chris Sidebottom 2 months ago
parent
commit
2c3cdaf74e
7 changed files with 426 additions and 42 deletions
  1. +4
    -0
      Makefile.system
  2. +12
    -1
      benchmark/Makefile
  3. +25
    -8
      benchmark/gemv.c
  4. +3
    -0
      kernel/arm64/KERNEL.NEOVERSEV1
  5. +321
    -0
      kernel/arm64/bgemv_n_sve_v3x4.c
  6. +60
    -32
      kernel/arm64/sbgemv_t_bfdot.c
  7. +1
    -1
      test/compare_sgemv_bgemv.c

+ 4
- 0
Makefile.system View File

@@ -277,6 +277,7 @@ endif
ifeq ($(ARCH), arm64)
GEMM_GEMV_FORWARD = 1
SBGEMM_GEMV_FORWARD = 1
BGEMM_GEMV_FORWARD = 1
endif
ifeq ($(ARCH), riscv)
GEMM_GEMV_FORWARD = 1
@@ -296,6 +297,9 @@ endif
ifeq ($(SBGEMM_GEMV_FORWARD), 1)
CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
endif
ifeq ($(BGEMM_GEMV_FORWARD), 1)
CCOMMON_OPT += -DBGEMM_GEMV_FORWARD
endif
endif

# This operation is expensive, so execution should be once.


+ 12
- 1
benchmark/Makefile View File

@@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS=
endif

ifeq ($(BUILD_BFLOAT16),1)
GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto
GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto
else
GOTO_BFLOAT_TARGETS=
endif
@@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif

ifeq ($(BUILD_HFLOAT16),1)
@@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c
cgemv.$(SUFFIX) : gemv.c
$(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^

# bfloat16 GEMV benchmark objects, both built from gemv.c:
#   bgemv  : -DBFLOAT16 -DBGEMM
#   sbgemv : -DBFLOAT16 only
ifeq ($(BUILD_BFLOAT16),1)
bgemv.$(SUFFIX) : gemv.c
	$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^
sbgemv.$(SUFFIX) : gemv.c
	$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif

zgemv.$(SUFFIX) : gemv.c
$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^



+ 25
- 8
benchmark/gemv.c View File

@@ -1,5 +1,5 @@
/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
Copyright (c) 2014, 2025 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifdef DOUBLE
#define GEMV BLASFUNC(dgemv)
#elif defined(BFLOAT16) && defined(BGEMM)
#define GEMV BLASFUNC(bgemv)
#elif defined(BFLOAT16)
#define GEMV BLASFUNC(sbgemv)
#undef IFLOAT
#define IFLOAT bfloat16
#else
#define GEMV BLASFUNC(sgemv)
#endif
@@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endif
int main(int argc, char *argv[]){

FLOAT *a, *x, *y;
FLOAT alpha[] = {1.0, 1.0};
FLOAT beta [] = {1.0, 0.0};
IFLOAT *a, *x;
FLOAT *y;
#ifdef BGEMM
blasint one=1;
blasint two=2;
float alpha_in[] = {1.0, 0.0};
float beta_in[] = {0.0, 0.0};
FLOAT alpha[2], beta[2];
sbstobf16_(&two, alpha_in, &one, alpha, &one);
sbstobf16_(&two, beta_in, &one, beta, &one);
#else
FLOAT alpha[] = {1.0, 0.0};
FLOAT beta [] = {0.0, 0.0};
#endif
char trans='N';
blasint m, i, j;
blasint inc_x=1,inc_y=1;
@@ -97,11 +114,11 @@ int main(int argc, char *argv[]){

fprintf(stderr, "From : %3d To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops);

if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){
if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){
fprintf(stderr,"Out of Memory!!\n");exit(1);
}

if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
fprintf(stderr,"Out of Memory!!\n");exit(1);
}

@@ -125,7 +142,7 @@ int main(int argc, char *argv[]){
fprintf(stderr, " %6dx%d : ", (int)m,(int)n);
for(j = 0; j < m; j++){
for(i = 0; i < n * COMPSIZE; i++){
a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
}
}

@@ -133,7 +150,7 @@ int main(int argc, char *argv[]){
{

for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){
x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
}

for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){


+ 3
- 0
kernel/arm64/KERNEL.NEOVERSEV1 View File

@@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)

BGEMVTKERNEL = sbgemv_t_bfdot.c
BGEMVNKERNEL = bgemv_n_sve_v3x4.c

SBGEMM_BETA = sbgemm_beta_neoversev1.c
SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))


+ 321
- 0
kernel/arm64/bgemv_n_sve_v3x4.c View File

@@ -0,0 +1,321 @@
/***************************************************************************
Copyright (c) 2025 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#include "common.h"

#include <arm_sve.h>

/*
 * Helper macros for the SVE bgemv N kernel below. They are textual building
 * blocks: each one expects names such as a_ptr0..a_ptr7, lda, i, y, x_vec,
 * pg_full, sve_size_bf16, svalpha and svbeta16 to exist in the scope where
 * it is expanded.
 */

/* Set a_ptr1 to a_ptr0 advanced by lda elements. */
#define UPDATE_PTRSx2 \
a_ptr1 = a_ptr0 + lda;

/* Set a_ptr1..a_ptr3, each lda elements after the previous pointer. */
#define UPDATE_PTRSx4 \
UPDATE_PTRSx2 \
a_ptr2 = a_ptr1 + lda; \
a_ptr3 = a_ptr2 + lda;

/* Set a_ptr1..a_ptr7, each lda elements after the previous pointer. */
#define UPDATE_PTRSx8 \
UPDATE_PTRSx4 \
a_ptr4 = a_ptr3 + lda; \
a_ptr5 = a_ptr4 + lda; \
a_ptr6 = a_ptr5 + lda; \
a_ptr7 = a_ptr6 + lda;

/* Expand MACRO(offset, lane) for lanes 0 and 1. */
#define LANESx2(MACRO, offset) \
MACRO(offset, 0) \
MACRO(offset, 1)

/* Expand MACRO(offset, lane) for lanes 0 through 3. */
#define LANESx4(MACRO, offset) \
LANESx2(MACRO, offset) \
MACRO(offset, 2) \
MACRO(offset, 3)

/* Expand MACRO(offset, lane) for lanes 0 through 7. */
#define LANESx8(MACRO, offset) \
LANESx4(MACRO, offset) \
MACRO(offset, 4) \
MACRO(offset, 5) \
MACRO(offset, 6) \
MACRO(offset, 7)

/*
 * Declare a_vec<offset><vec> and load it (all lanes active) from a_ptr<vec>,
 * starting at element i + offset * sve_size_bf16.
 */
#define LOAD_A_VEC(offset, vec) \
svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16);

/*
 * Multiply a_vec<offset><lane> by element `lane` of x_vec and accumulate in
 * f32: acc<offset>0 via svbfmlalb_lane and acc<offset>1 via svbfmlalt_lane
 * (per the ACLE naming, the "bottom"/even and "top"/odd bf16 elements).
 */
#define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \
acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \
acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane);

/* Declare the f32 accumulator pair acc<offset>0 / acc<offset>1, zeroed. */
#define INIT_ACCUMULATORS(offset) \
svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \
svfloat32_t acc ## offset ## 1 = svdup_f32(0.0);

/*
 * Non-lane variant of the update: multiply a_vec<offset>0 element-wise by
 * x_vec and accumulate into the acc<offset>0 / acc<offset>1 pair.
 */
#define UPDATE_ACCUMULATORS(offset) \
acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \
acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec);

/*
 * Convert the accumulator pair back to bf16, merge the two converted vectors
 * with svtrn1 and store one full vector to y at element
 * i + offset * sve_size_bf16.
 */
#define STORE_ACCUMULATORS(offset) \
svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \
svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \
svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \
svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset);

/* Scale both f32 accumulators of the pair by svalpha. */
#define ALPHA_OP(offset) \
acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \
acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha);

/*
 * Load one full vector of y (declared as y_vec<offset>) and add
 * svbeta16 * y to the accumulator pair.
 */
#define BETA_OP(offset) \
svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \
acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \
acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset);

/*
 * bf16 GEMV, non-transposed case:
 *
 *     y[i] = alpha * sum_j( a[i + j * lda] * x[j * inc_x] ) + beta * y[i]
 *
 * for i in [0, m) and j in [0, n). All of a, x and y hold bf16 values;
 * products are accumulated in f32 and only rounded to bf16 when y is
 * written, so y itself is never used as the running accumulator.
 *
 * m, n      number of rows / columns processed
 * alpha     scale for A*x; its leading bytes are read as a bf16 value
 * a_in      matrix data, element (i, j) at a_in[i + j * lda]
 * lda       element distance between consecutive columns of a
 * x_in      input vector, stride inc_x
 * beta      scale for the incoming y; its leading bytes are read as bf16.
 *           When it is zero y is not read at all.
 * y_in      output vector, stride inc_y
 *
 * inc_y == 1 takes the SVE path (row blocks of three vectors, then one
 * vector, then a predicated tail); any other inc_y uses the scalar loop at
 * the end. Within the SVE path, inc_x == 1 additionally enables processing
 * several columns per iteration.
 *
 * Always returns 0.
 */
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y)
{
    BLASLONG ix;
    bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7;
    BLASLONG sve_size_bf16 = svcnth();
    BLASLONG sve_size3_bf16 = sve_size_bf16 * 3;
    svbool_t pg_full = svptrue_b16();
    /* Predicate covering only the rows left over after the full-vector loops. */
    svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16);

    /* n rounded down to a multiple of 8 / 4 / 2 for the unrolled column loops. */
    BLASLONG n8 = n & -8;
    BLASLONG n4 = n & -4;
    BLASLONG n2 = n & -2;
    bfloat16_t *a = (bfloat16_t*)a_in;
    bfloat16_t *x = (bfloat16_t*)x_in;
    bfloat16_t *y = (bfloat16_t*)y_in;

    /* alpha/beta arrive as FLOAT; reinterpret their leading bytes as bf16. */
    bfloat16_t alpha_bf16, beta_bf16;
    memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
    memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
    float beta_f32 = vcvtah_f32_bf16(beta_bf16);
    float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
    svfloat32_t svalpha = svdup_f32(alpha_f32);
    svbfloat16_t svbeta16 = svdup_bf16(beta_bf16);
    BLASLONG i = 0;
    if (inc_y == 1) {
        /* Row blocks of three vectors. */
        for (; (i + sve_size3_bf16 - 1) < m; i+= sve_size3_bf16) {
            INIT_ACCUMULATORS(0);
            INIT_ACCUMULATORS(1);
            INIT_ACCUMULATORS(2);

            BLASLONG j = 0;
            ix = 0;
            a_ptr0 = a;
            /* At most four columns are processed at once in this loop. */
            UPDATE_PTRSx4;

            if (inc_x == 1) {
                for (; j < n4; j+= 4) {
                    /*
                     * Broadcast four consecutive x values. x is only
                     * bf16-aligned, so copy the bytes instead of
                     * dereferencing a uint64_t pointer.
                     */
                    uint64_t x_u64;
                    memcpy(&x_u64, &x[ix], sizeof(x_u64));
                    svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(x_u64));

                    LANESx4(LOAD_A_VEC, 0);
                    LANESx4(LOAD_A_VEC, 1);
                    LANESx4(LOAD_A_VEC, 2);
                    LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
                    LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1);
                    LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2);

                    ix += 4;

                    a_ptr0 += 4 * lda;
                    UPDATE_PTRSx4;
                }

                for (; j < n2; j+= 2) {
                    /* Broadcast two consecutive x values (see note above). */
                    uint32_t x_u32;
                    memcpy(&x_u32, &x[ix], sizeof(x_u32));
                    svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(x_u32));

                    LANESx2(LOAD_A_VEC, 0);
                    LANESx2(LOAD_A_VEC, 1);
                    LANESx2(LOAD_A_VEC, 2);
                    LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
                    LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1);
                    LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2);

                    ix += 2;

                    a_ptr0 += 2 * lda;
                    UPDATE_PTRSx2;
                }
            }

            /* Remaining columns (all of them when inc_x != 1), one at a time. */
            for (; j < n; j++) {
                svbfloat16_t x_vec = svdup_bf16(x[ix]);
                LOAD_A_VEC(0, 0);
                LOAD_A_VEC(1, 0);
                LOAD_A_VEC(2, 0);
                UPDATE_ACCUMULATORS(0);
                UPDATE_ACCUMULATORS(1);
                UPDATE_ACCUMULATORS(2);

                ix += inc_x;
                a_ptr0 += lda;
            }

            if (alpha_f32 != ONE) {
                ALPHA_OP(0);
                ALPHA_OP(1);
                ALPHA_OP(2);
            }
            if (beta_f32 != ZERO) {
                BETA_OP(0);
                BETA_OP(1);
                BETA_OP(2);
            }

            STORE_ACCUMULATORS(0);
            STORE_ACCUMULATORS(1);
            STORE_ACCUMULATORS(2);
        }

        /* Row blocks of a single vector. */
        for (; (i + sve_size_bf16 - 1) < m; i+= sve_size_bf16) {
            INIT_ACCUMULATORS(0);

            BLASLONG j = 0;
            ix = 0;
            a_ptr0 = a;
            UPDATE_PTRSx8;

            if (inc_x == 1) {
                for (; j < n8; j+= 8) {
                    /* Eight consecutive x values replicated per 128 bits. */
                    svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]);
                    LANESx8(LOAD_A_VEC, 0);
                    LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0);

                    ix += 8;

                    a_ptr0 += 8 * lda;
                    UPDATE_PTRSx8;
                }

                for (; j < n4; j+= 4) {
                    uint64_t x_u64;
                    memcpy(&x_u64, &x[ix], sizeof(x_u64));
                    svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(x_u64));

                    LANESx4(LOAD_A_VEC, 0);
                    LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);

                    ix += 4;

                    a_ptr0 += 4 * lda;
                    UPDATE_PTRSx4;
                }

                for (; j < n2; j+= 2) {
                    uint32_t x_u32;
                    memcpy(&x_u32, &x[ix], sizeof(x_u32));
                    svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(x_u32));

                    LANESx2(LOAD_A_VEC, 0);
                    LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);

                    ix += 2;

                    a_ptr0 += 2 * lda;
                    UPDATE_PTRSx2;
                }
            }

            for (; j < n; j++) {
                svbfloat16_t x_vec = svdup_bf16(x[ix]);
                LOAD_A_VEC(0, 0);
                UPDATE_ACCUMULATORS(0);

                ix += inc_x;
                a_ptr0 += lda;
            }

            if (alpha_f32 != ONE) {
                ALPHA_OP(0);
            }
            if (beta_f32 != ZERO) {
                BETA_OP(0);
            }

            STORE_ACCUMULATORS(0);
        }

        /* Fewer than one vector of rows left: predicated loads and store. */
        if (i < m) {
            svfloat32_t acc0 = svdup_f32(0.0);
            svfloat32_t acc1 = svdup_f32(0.0);

            ix = 0;
            a_ptr0 = a;
            for (BLASLONG j = 0; j < n; j++) {
                svbfloat16_t x_vec0 = svdup_bf16(x[ix]);
                svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]);

                acc0 = svbfmlalb(acc0, a_vec0, x_vec0);
                acc1 = svbfmlalt(acc1, a_vec0, x_vec0);

                ix += inc_x;
                a_ptr0 += lda;
            }

            if (alpha_f32 != ONE) {
                acc0 = svmul_x(pg_full, acc0, svalpha);
                acc1 = svmul_x(pg_full, acc1, svalpha);
            }
            if (beta_f32 != ZERO) {
                svbfloat16_t y_vec = svld1(pg_tail, &y[i]);
                acc0 = svbfmlalb(acc0, svbeta16, y_vec);
                acc1 = svbfmlalt(acc1, svbeta16, y_vec);
            }

            svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0);
            svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1);
            svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16);
            svst1(pg_tail, &y[i], combined);
        }

        return 0;
    }

    // Scalar fallback (inc_y != 1): one output element per iteration
    BLASLONG iy = 0;
    for (; i < m; i++) {
        float temp = 0.0;

        ix = 0;
        a_ptr0 = a;
        for (BLASLONG j = 0; j < n; j++) {
            temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]);
            ix += inc_x;
            a_ptr0 += lda;
        }

        /* Skip reading y when beta is zero. */
        if (beta_f32 == ZERO) {
            y[iy] = vcvth_bf16_f32(alpha_f32 * temp);
        } else {
            y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy]));
        }

        iy += inc_y;
    }

    return (0);
}

+ 60
- 32
kernel/arm64/sbgemv_t_bfdot.c View File

@@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <arm_neon.h>
#include "common.h"

int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
#ifdef BGEMM
#define INNER_FLOAT bfloat16_t
#define TO32(x) vcvtah_f32_bf16(x)
#define FROM32(x) vcvth_bf16_f32(x)
#else
#define INNER_FLOAT float
#define TO32(x) x
#define FROM32(x) x
#endif

int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy)
{
if (m < 1 || n < 1) return(0);

BLASLONG i;
BLASLONG ix,iy;
BLASLONG j;
bfloat16_t *a_ptr;
bfloat16_t *x_ptr;
float *y_ptr;
float temp;
float temp, temp0, temp1, temp2, temp3;

#ifdef BGEMM
bfloat16_t alpha_bf16, beta_bf16;
memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
float beta_f32 = vcvtah_f32_bf16(beta_bf16);
#else
float alpha_f32 = alpha;
float beta_f32 = beta;
#endif
INNER_FLOAT *y = (INNER_FLOAT *)y_in;
INNER_FLOAT *y_ptr;

iy = 0;
a_ptr = (bfloat16_t*)(a);
@@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
bfloat16_t *a3_ptr = a_ptr + lda * width * 3;

float *y0_ptr = y + incy * width * 0;
float *y1_ptr = y + incy * width * 1;
float *y2_ptr = y + incy * width * 2;
float *y3_ptr = y + incy * width * 3;
INNER_FLOAT *y0_ptr = y + incy * width * 0;
INNER_FLOAT *y1_ptr = y + incy * width * 1;
INNER_FLOAT *y2_ptr = y + incy * width * 2;
INNER_FLOAT *y3_ptr = y + incy * width * 3;

for (j = 0; j < width; j++) {
float32x4_t temp0_vec = vdupq_n_f32(0.0f);
@@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat

i += 4;
}
if (beta == 0.0f) {
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
}
else {
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
if (beta_f32 == 0.0f) {
temp0 = alpha_f32 * vaddvq_f32(temp0_vec);
temp1 = alpha_f32 * vaddvq_f32(temp1_vec);
temp2 = alpha_f32 * vaddvq_f32(temp2_vec);
temp3 = alpha_f32 * vaddvq_f32(temp3_vec);
} else {
temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]);
temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]);
temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]);
temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]);
}

for (; i < m; ++i) {
y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
}

y0_ptr[iy] = FROM32(temp0);
y1_ptr[iy] = FROM32(temp1);
y2_ptr[iy] = FROM32(temp2);
y3_ptr[iy] = FROM32(temp3);

iy += incy;

a0_ptr += lda;
@@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat

i += 4;
}
if (beta == 0.0f) {
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
}
else {
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
}

if (beta_f32 == 0.0f) {
temp = alpha_f32 * vaddvq_f32(temp0_vec);
} else {
temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]);
}
for (; i < m; ++i) {
y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
}
y_ptr[iy] = FROM32(temp);

iy += incy;

@@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]);
ix += incx;
}
if (beta == 0.0f) {
y[iy] = alpha * temp;
if (beta_f32 == 0.0f) {
y[iy] = FROM32(alpha_f32 * temp);
}
else {
y[iy] = alpha * temp + beta * y[iy];
y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy]));
}
iy += incy;
a_ptr += lda;


+ 1
- 1
test/compare_sgemv_bgemv.c View File

@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001))
{
#ifdef DEBUG
printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n",
printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n",
transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]));
#endif
ret++;


Loading…
Cancel
Save