- Adds bgemv T based off of the sbgemv T kernel
- Adds bgemv N, slightly altered so that Y is not used as an accumulator, since the output is bf16 and accumulating in it would lose precision
- Enables BGEMM_GEMV_FORWARD to forward BGEMM to BGEMV with the new kernels
@@ -277,6 +277,7 @@ endif | |||
ifeq ($(ARCH), arm64) | |||
GEMM_GEMV_FORWARD = 1 | |||
SBGEMM_GEMV_FORWARD = 1 | |||
BGEMM_GEMV_FORWARD = 1 | |||
endif | |||
ifeq ($(ARCH), riscv) | |||
GEMM_GEMV_FORWARD = 1 | |||
@@ -296,6 +297,9 @@ endif | |||
ifeq ($(SBGEMM_GEMV_FORWARD), 1) | |||
CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD | |||
endif | |||
ifeq ($(BGEMM_GEMV_FORWARD), 1) | |||
CCOMMON_OPT += -DBGEMM_GEMV_FORWARD | |||
endif | |||
endif | |||
# This operation is expensive, so execution should be once. | |||
@@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS= | |||
endif | |||
ifeq ($(BUILD_BFLOAT16),1) | |||
GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto | |||
GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto | |||
else | |||
GOTO_BFLOAT_TARGETS= | |||
endif | |||
@@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME) | |||
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME) | |||
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME) | |||
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME) | |||
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
endif | |||
ifeq ($(BUILD_HFLOAT16),1) | |||
@@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c | |||
cgemv.$(SUFFIX) : gemv.c | |||
$(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^ | |||
ifeq ($(BUILD_BFLOAT16),1) | |||
bgemv.$(SUFFIX) : gemv.c | |||
$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^ | |||
sbgemv.$(SUFFIX) : gemv.c | |||
$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^ | |||
endif () | |||
zgemv.$(SUFFIX) : gemv.c | |||
$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^ | |||
@@ -1,5 +1,5 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2014, The OpenBLAS Project | |||
Copyright (c) 2014, 2025 The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
@@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#ifdef DOUBLE | |||
#define GEMV BLASFUNC(dgemv) | |||
#elif defined(BFLOAT16) && defined(BGEMM) | |||
#define GEMV BLASFUNC(bgemv) | |||
#elif defined(BFLOAT16) | |||
#define GEMV BLASFUNC(sbgemv) | |||
#undef IFLOAT | |||
#define IFLOAT bfloat16 | |||
#else | |||
#define GEMV BLASFUNC(sgemv) | |||
#endif | |||
@@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#endif | |||
int main(int argc, char *argv[]){ | |||
FLOAT *a, *x, *y; | |||
FLOAT alpha[] = {1.0, 1.0}; | |||
FLOAT beta [] = {1.0, 0.0}; | |||
IFLOAT *a, *x; | |||
FLOAT *y; | |||
#ifdef BGEMM | |||
blasint one=1; | |||
blasint two=2; | |||
float alpha_in[] = {1.0, 0.0}; | |||
float beta_in[] = {0.0, 0.0}; | |||
FLOAT alpha[2], beta[2]; | |||
sbstobf16_(&two, alpha_in, &one, alpha, &one); | |||
sbstobf16_(&two, beta_in, &one, beta, &one); | |||
#else | |||
FLOAT alpha[] = {1.0, 0.0}; | |||
FLOAT beta [] = {0.0, 0.0}; | |||
#endif | |||
char trans='N'; | |||
blasint m, i, j; | |||
blasint inc_x=1,inc_y=1; | |||
@@ -97,11 +114,11 @@ int main(int argc, char *argv[]){ | |||
fprintf(stderr, "From : %3d To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops); | |||
if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){ | |||
if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){ | |||
fprintf(stderr,"Out of Memory!!\n");exit(1); | |||
} | |||
if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){ | |||
if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){ | |||
fprintf(stderr,"Out of Memory!!\n");exit(1); | |||
} | |||
@@ -125,7 +142,7 @@ int main(int argc, char *argv[]){ | |||
fprintf(stderr, " %6dx%d : ", (int)m,(int)n); | |||
for(j = 0; j < m; j++){ | |||
for(i = 0; i < n * COMPSIZE; i++){ | |||
a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
} | |||
} | |||
@@ -133,7 +150,7 @@ int main(int argc, char *argv[]){ | |||
{ | |||
for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){ | |||
x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
} | |||
for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){ | |||
@@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c | |||
BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||
BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||
BGEMVTKERNEL = sbgemv_t_bfdot.c | |||
BGEMVNKERNEL = bgemv_n_sve_v3x4.c | |||
SBGEMM_BETA = sbgemm_beta_neoversev1.c | |||
SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c | |||
ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) | |||
@@ -0,0 +1,321 @@ | |||
/*************************************************************************** | |||
Copyright (c) 2025 The OpenBLAS Project | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in | |||
the documentation and/or other materials provided with the | |||
distribution. | |||
3. Neither the name of the OpenBLAS project nor the names of | |||
its contributors may be used to endorse or promote products | |||
derived from this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
*****************************************************************************/ | |||
#include "common.h" | |||
#include <arm_sve.h> | |||
#define UPDATE_PTRSx2 \ | |||
a_ptr1 = a_ptr0 + lda; | |||
#define UPDATE_PTRSx4 \ | |||
UPDATE_PTRSx2 \ | |||
a_ptr2 = a_ptr1 + lda; \ | |||
a_ptr3 = a_ptr2 + lda; | |||
#define UPDATE_PTRSx8 \ | |||
UPDATE_PTRSx4 \ | |||
a_ptr4 = a_ptr3 + lda; \ | |||
a_ptr5 = a_ptr4 + lda; \ | |||
a_ptr6 = a_ptr5 + lda; \ | |||
a_ptr7 = a_ptr6 + lda; | |||
#define LANESx2(MACRO, offset) \ | |||
MACRO(offset, 0) \ | |||
MACRO(offset, 1) | |||
#define LANESx4(MACRO, offset) \ | |||
LANESx2(MACRO, offset) \ | |||
MACRO(offset, 2) \ | |||
MACRO(offset, 3) | |||
#define LANESx8(MACRO, offset) \ | |||
LANESx4(MACRO, offset) \ | |||
MACRO(offset, 4) \ | |||
MACRO(offset, 5) \ | |||
MACRO(offset, 6) \ | |||
MACRO(offset, 7) | |||
#define LOAD_A_VEC(offset, vec) \ | |||
svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16); | |||
#define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \ | |||
acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \ | |||
acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane); | |||
#define INIT_ACCUMULATORS(offset) \ | |||
svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \ | |||
svfloat32_t acc ## offset ## 1 = svdup_f32(0.0); | |||
#define UPDATE_ACCUMULATORS(offset) \ | |||
acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \ | |||
acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec); | |||
#define STORE_ACCUMULATORS(offset) \ | |||
svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \ | |||
svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \ | |||
svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \ | |||
svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset); | |||
#define ALPHA_OP(offset) \ | |||
acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \ | |||
acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha); | |||
#define BETA_OP(offset) \ | |||
svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \ | |||
acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \ | |||
acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset); | |||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y) | |||
{ | |||
BLASLONG ix; | |||
bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7; | |||
BLASLONG sve_size_bf16 = svcnth(); | |||
BLASLONG sve_size2_bf16 = sve_size_bf16 * 2; | |||
BLASLONG sve_size3_bf16 = sve_size_bf16 * 3; | |||
svbool_t pg_full = svptrue_b16(); | |||
svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16); | |||
BLASLONG n8 = n & -8; | |||
BLASLONG n4 = n & -4; | |||
BLASLONG n2 = n & -2; | |||
bfloat16_t *a = (bfloat16_t*)a_in; | |||
bfloat16_t *x = (bfloat16_t*)x_in; | |||
bfloat16_t *y = (bfloat16_t*)y_in; | |||
bfloat16_t alpha_bf16, beta_bf16; | |||
memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||
memcpy(&beta_bf16, &beta, sizeof(bfloat16_t)); | |||
float beta_f32 = vcvtah_f32_bf16(beta_bf16); | |||
float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); | |||
svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16)); | |||
svbfloat16_t svbeta16 = svdup_bf16(beta_bf16); | |||
BLASLONG i = 0; | |||
if (inc_y == 1) { | |||
for (; (i + sve_size3_bf16 - 1) < m; i+= sve_size3_bf16) { | |||
INIT_ACCUMULATORS(0); | |||
INIT_ACCUMULATORS(1); | |||
INIT_ACCUMULATORS(2); | |||
BLASLONG j = 0; | |||
ix = 0; | |||
a_ptr0 = a; | |||
UPDATE_PTRSx8; | |||
if (inc_x == 1) { | |||
for (; j < n4; j+= 4) { | |||
uint64_t* x_u64 = (uint64_t*)&x[ix]; | |||
svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64)); | |||
LANESx4(LOAD_A_VEC, 0); | |||
LANESx4(LOAD_A_VEC, 1); | |||
LANESx4(LOAD_A_VEC, 2); | |||
LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1); | |||
LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2); | |||
ix += 4; | |||
a_ptr0 += 4 * lda; | |||
UPDATE_PTRSx4; | |||
} | |||
for (; j < n2; j+= 2) { | |||
uint32_t* x_u32 = (uint32_t*)&x[ix]; | |||
svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32)); | |||
LANESx2(LOAD_A_VEC, 0); | |||
LANESx2(LOAD_A_VEC, 1); | |||
LANESx2(LOAD_A_VEC, 2); | |||
LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1); | |||
LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2); | |||
ix += 2; | |||
a_ptr0 += 2 * lda; | |||
UPDATE_PTRSx2; | |||
} | |||
} | |||
for (; j < n; j++) { | |||
svbfloat16_t x_vec = svdup_bf16(x[ix]); | |||
LOAD_A_VEC(0, 0); | |||
LOAD_A_VEC(1, 0); | |||
LOAD_A_VEC(2, 0); | |||
UPDATE_ACCUMULATORS(0); | |||
UPDATE_ACCUMULATORS(1); | |||
UPDATE_ACCUMULATORS(2); | |||
ix += inc_x; | |||
a_ptr0 += lda; | |||
} | |||
if (alpha_f32 != ONE) { | |||
ALPHA_OP(0); | |||
ALPHA_OP(1); | |||
ALPHA_OP(2); | |||
} | |||
if (beta_f32 != ZERO) { | |||
BETA_OP(0); | |||
BETA_OP(1); | |||
BETA_OP(2); | |||
} | |||
STORE_ACCUMULATORS(0); | |||
STORE_ACCUMULATORS(1); | |||
STORE_ACCUMULATORS(2); | |||
} | |||
for (; (i + sve_size_bf16 - 1) < m; i+= sve_size_bf16) { | |||
INIT_ACCUMULATORS(0); | |||
BLASLONG j = 0; | |||
ix = 0; | |||
a_ptr0 = a; | |||
UPDATE_PTRSx8; | |||
if (inc_x == 1) { | |||
for (; j < n8; j+= 8) { | |||
svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]); | |||
LANESx8(LOAD_A_VEC, 0); | |||
LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
ix += 8; | |||
a_ptr0 += 8 * lda; | |||
UPDATE_PTRSx8; | |||
} | |||
for (; j < n4; j+= 4) { | |||
uint64_t* x_u64 = (uint64_t*)&x[ix]; | |||
svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64)); | |||
LANESx4(LOAD_A_VEC, 0); | |||
LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
ix += 4; | |||
a_ptr0 += 4 * lda; | |||
UPDATE_PTRSx4; | |||
} | |||
for (; j < n2; j+= 2) { | |||
uint32_t* x_u32 = (uint32_t*)&x[ix]; | |||
svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32)); | |||
LANESx2(LOAD_A_VEC, 0); | |||
LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
ix += 2; | |||
a_ptr0 += 2 * lda; | |||
UPDATE_PTRSx2; | |||
} | |||
} | |||
for (; j < n; j++) { | |||
svbfloat16_t x_vec = svdup_bf16(x[ix]); | |||
LOAD_A_VEC(0, 0); | |||
UPDATE_ACCUMULATORS(0); | |||
ix += inc_x; | |||
a_ptr0 += lda; | |||
} | |||
if (alpha_f32 != ONE) { | |||
ALPHA_OP(0); | |||
} | |||
if (beta_f32 != ZERO) { | |||
BETA_OP(0); | |||
} | |||
STORE_ACCUMULATORS(0); | |||
} | |||
if (i < m) { | |||
svfloat32_t acc0 = svdup_f32(0.0); | |||
svfloat32_t acc1 = svdup_f32(0.0); | |||
ix = 0; | |||
a_ptr0 = a; | |||
for (BLASLONG j = 0; j < n; j++) { | |||
svbfloat16_t x_vec0 = svdup_bf16(x[ix]); | |||
svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]); | |||
acc0 = svbfmlalb(acc0, a_vec0, x_vec0); | |||
acc1 = svbfmlalt(acc1, a_vec0, x_vec0); | |||
ix += inc_x; | |||
a_ptr0 += lda; | |||
} | |||
if (alpha_f32 != ONE) { | |||
acc0 = svmul_x(pg_full, acc0, svalpha); | |||
acc1 = svmul_x(pg_full, acc1, svalpha); | |||
} | |||
if (beta_f32 != ZERO) { | |||
svbfloat16_t y_vec = svld1(pg_tail, &y[i]); | |||
acc0 = svbfmlalb(acc0, svbeta16, y_vec); | |||
acc1 = svbfmlalt(acc1, svbeta16, y_vec); | |||
} | |||
svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0); | |||
svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1); | |||
svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16); | |||
svst1(pg_tail, &y[i], combined); | |||
} | |||
return 0; | |||
} | |||
// Scalar fallback | |||
BLASLONG iy = 0; | |||
for (; i < m; i++) { | |||
float temp = 0.0; | |||
ix = 0; | |||
a_ptr0 = a; | |||
for (BLASLONG j = 0; j < n; j++) { | |||
temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]); | |||
ix += inc_x; | |||
a_ptr0 += lda; | |||
} | |||
if (beta_f32 == ZERO) { | |||
y[iy] = vcvth_bf16_f32(alpha_f32 * temp); | |||
} else { | |||
y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy])); | |||
} | |||
iy += inc_y; | |||
} | |||
return (0); | |||
} |
@@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#include <arm_neon.h> | |||
#include "common.h" | |||
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) | |||
#ifdef BGEMM | |||
#define INNER_FLOAT bfloat16_t | |||
#define TO32(x) vcvtah_f32_bf16(x) | |||
#define FROM32(x) vcvth_bf16_f32(x) | |||
#else | |||
#define INNER_FLOAT float | |||
#define TO32(x) x | |||
#define FROM32(x) x | |||
#endif | |||
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy) | |||
{ | |||
if (m < 1 || n < 1) return(0); | |||
BLASLONG i; | |||
BLASLONG ix,iy; | |||
BLASLONG j; | |||
bfloat16_t *a_ptr; | |||
bfloat16_t *x_ptr; | |||
float *y_ptr; | |||
float temp; | |||
float temp, temp0, temp1, temp2, temp3; | |||
#ifdef BGEMM | |||
bfloat16_t alpha_bf16, beta_bf16; | |||
memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||
memcpy(&beta_bf16, &beta, sizeof(bfloat16_t)); | |||
float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); | |||
float beta_f32 = vcvtah_f32_bf16(beta_bf16); | |||
#else | |||
float alpha_f32 = alpha; | |||
float beta_f32 = beta; | |||
#endif | |||
INNER_FLOAT *y = (INNER_FLOAT *)y_in; | |||
INNER_FLOAT *y_ptr; | |||
iy = 0; | |||
a_ptr = (bfloat16_t*)(a); | |||
@@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
bfloat16_t *a2_ptr = a_ptr + lda * width * 2; | |||
bfloat16_t *a3_ptr = a_ptr + lda * width * 3; | |||
float *y0_ptr = y + incy * width * 0; | |||
float *y1_ptr = y + incy * width * 1; | |||
float *y2_ptr = y + incy * width * 2; | |||
float *y3_ptr = y + incy * width * 3; | |||
INNER_FLOAT *y0_ptr = y + incy * width * 0; | |||
INNER_FLOAT *y1_ptr = y + incy * width * 1; | |||
INNER_FLOAT *y2_ptr = y + incy * width * 2; | |||
INNER_FLOAT *y3_ptr = y + incy * width * 3; | |||
for (j = 0; j < width; j++) { | |||
float32x4_t temp0_vec = vdupq_n_f32(0.0f); | |||
@@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
i += 4; | |||
} | |||
if (beta == 0.0f) { | |||
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec); | |||
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec); | |||
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec); | |||
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec); | |||
} | |||
else { | |||
y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy]; | |||
y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy]; | |||
y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy]; | |||
y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy]; | |||
if (beta_f32 == 0.0f) { | |||
temp0 = alpha_f32 * vaddvq_f32(temp0_vec); | |||
temp1 = alpha_f32 * vaddvq_f32(temp1_vec); | |||
temp2 = alpha_f32 * vaddvq_f32(temp2_vec); | |||
temp3 = alpha_f32 * vaddvq_f32(temp3_vec); | |||
} else { | |||
temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]); | |||
temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]); | |||
temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]); | |||
temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]); | |||
} | |||
for (; i < m; ++i) { | |||
y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
} | |||
y0_ptr[iy] = FROM32(temp0); | |||
y1_ptr[iy] = FROM32(temp1); | |||
y2_ptr[iy] = FROM32(temp2); | |||
y3_ptr[iy] = FROM32(temp3); | |||
iy += incy; | |||
a0_ptr += lda; | |||
@@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
i += 4; | |||
} | |||
if (beta == 0.0f) { | |||
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec); | |||
} | |||
else { | |||
y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy]; | |||
} | |||
if (beta_f32 == 0.0f) { | |||
temp = alpha_f32 * vaddvq_f32(temp0_vec); | |||
} else { | |||
temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]); | |||
} | |||
for (; i < m; ++i) { | |||
y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
} | |||
y_ptr[iy] = FROM32(temp); | |||
iy += incy; | |||
@@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]); | |||
ix += incx; | |||
} | |||
if (beta == 0.0f) { | |||
y[iy] = alpha * temp; | |||
if (beta_f32 == 0.0f) { | |||
y[iy] = FROM32(alpha_f32 * temp); | |||
} | |||
else { | |||
y[iy] = alpha * temp + beta * y[iy]; | |||
y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy])); | |||
} | |||
iy += incy; | |||
a_ptr += lda; | |||
@@ -127,7 +127,7 @@ int main(int argc, char *argv[]) | |||
if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001)) | |||
{ | |||
#ifdef DEBUG | |||
printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n", | |||
printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n", | |||
transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j])); | |||
#endif | |||
ret++; | |||