- Adds bgemv T based off of the sbgemv T kernel - Adds bgemv N which is slightly altered to not use Y as an accumulator, due to the output being bf16, which results in loss of precision - Enables BGEMM_GEMV_FORWARD to proxy BGEMM to BGEMV with new kernels (pull/5394/head)
| @@ -277,6 +277,7 @@ endif | |||
| ifeq ($(ARCH), arm64) | |||
| GEMM_GEMV_FORWARD = 1 | |||
| SBGEMM_GEMV_FORWARD = 1 | |||
| BGEMM_GEMV_FORWARD = 1 | |||
| endif | |||
| ifeq ($(ARCH), riscv) | |||
| GEMM_GEMV_FORWARD = 1 | |||
| @@ -296,6 +297,9 @@ endif | |||
| ifeq ($(SBGEMM_GEMV_FORWARD), 1) | |||
| CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD | |||
| endif | |||
| ifeq ($(BGEMM_GEMV_FORWARD), 1) | |||
| CCOMMON_OPT += -DBGEMM_GEMV_FORWARD | |||
| endif | |||
| endif | |||
| # This operation is expensive, so execution should be once. | |||
| @@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS= | |||
| endif | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto | |||
| GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto | |||
| else | |||
| GOTO_BFLOAT_TARGETS= | |||
| endif | |||
| @@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME) | |||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
| sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME) | |||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
| bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME) | |||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
| sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME) | |||
| $(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm | |||
| endif | |||
| ifeq ($(BUILD_HFLOAT16),1) | |||
| @@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c | |||
| cgemv.$(SUFFIX) : gemv.c | |||
| $(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^ | |||
| ifeq ($(BUILD_BFLOAT16),1) | |||
| bgemv.$(SUFFIX) : gemv.c | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^ | |||
| sbgemv.$(SUFFIX) : gemv.c | |||
| $(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^ | |||
| endif | |||
| zgemv.$(SUFFIX) : gemv.c | |||
| $(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^ | |||
| @@ -1,5 +1,5 @@ | |||
| /*************************************************************************** | |||
| Copyright (c) 2014, The OpenBLAS Project | |||
| Copyright (c) 2014, 2025 The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| @@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| #ifdef DOUBLE | |||
| #define GEMV BLASFUNC(dgemv) | |||
| #elif defined(BFLOAT16) && defined(BGEMM) | |||
| #define GEMV BLASFUNC(bgemv) | |||
| #elif defined(BFLOAT16) | |||
| #define GEMV BLASFUNC(sbgemv) | |||
| #undef IFLOAT | |||
| #define IFLOAT bfloat16 | |||
| #else | |||
| #define GEMV BLASFUNC(sgemv) | |||
| #endif | |||
| @@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| #endif | |||
| int main(int argc, char *argv[]){ | |||
| FLOAT *a, *x, *y; | |||
| FLOAT alpha[] = {1.0, 1.0}; | |||
| FLOAT beta [] = {1.0, 0.0}; | |||
| IFLOAT *a, *x; | |||
| FLOAT *y; | |||
| #ifdef BGEMM | |||
| blasint one=1; | |||
| blasint two=2; | |||
| float alpha_in[] = {1.0, 0.0}; | |||
| float beta_in[] = {0.0, 0.0}; | |||
| FLOAT alpha[2], beta[2]; | |||
| sbstobf16_(&two, alpha_in, &one, alpha, &one); | |||
| sbstobf16_(&two, beta_in, &one, beta, &one); | |||
| #else | |||
| FLOAT alpha[] = {1.0, 0.0}; | |||
| FLOAT beta [] = {0.0, 0.0}; | |||
| #endif | |||
| char trans='N'; | |||
| blasint m, i, j; | |||
| blasint inc_x=1,inc_y=1; | |||
| @@ -97,11 +114,11 @@ int main(int argc, char *argv[]){ | |||
| fprintf(stderr, "From : %3d To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops); | |||
| if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){ | |||
| if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){ | |||
| fprintf(stderr,"Out of Memory!!\n");exit(1); | |||
| } | |||
| if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){ | |||
| if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){ | |||
| fprintf(stderr,"Out of Memory!!\n");exit(1); | |||
| } | |||
| @@ -125,7 +142,7 @@ int main(int argc, char *argv[]){ | |||
| fprintf(stderr, " %6dx%d : ", (int)m,(int)n); | |||
| for(j = 0; j < m; j++){ | |||
| for(i = 0; i < n * COMPSIZE; i++){ | |||
| a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
| a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
| } | |||
| } | |||
| @@ -133,7 +150,7 @@ int main(int argc, char *argv[]){ | |||
| { | |||
| for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){ | |||
| x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
| x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5; | |||
| } | |||
| for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){ | |||
| @@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c | |||
| BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) | |||
| BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) | |||
| BGEMVTKERNEL = sbgemv_t_bfdot.c | |||
| BGEMVNKERNEL = bgemv_n_sve_v3x4.c | |||
| SBGEMM_BETA = sbgemm_beta_neoversev1.c | |||
| SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c | |||
| ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) | |||
| @@ -0,0 +1,321 @@ | |||
| /*************************************************************************** | |||
| Copyright (c) 2025 The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| met: | |||
| 1. Redistributions of source code must retain the above copyright | |||
| notice, this list of conditions and the following disclaimer. | |||
| 2. Redistributions in binary form must reproduce the above copyright | |||
| notice, this list of conditions and the following disclaimer in | |||
| the documentation and/or other materials provided with the | |||
| distribution. | |||
| 3. Neither the name of the OpenBLAS project nor the names of | |||
| its contributors may be used to endorse or promote products | |||
| derived from this software without specific prior written permission. | |||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |||
| OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE | |||
| USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| *****************************************************************************/ | |||
| #include "common.h" | |||
| #include <arm_sve.h> | |||
| #define UPDATE_PTRSx2 \ | |||
| a_ptr1 = a_ptr0 + lda; | |||
| #define UPDATE_PTRSx4 \ | |||
| UPDATE_PTRSx2 \ | |||
| a_ptr2 = a_ptr1 + lda; \ | |||
| a_ptr3 = a_ptr2 + lda; | |||
| #define UPDATE_PTRSx8 \ | |||
| UPDATE_PTRSx4 \ | |||
| a_ptr4 = a_ptr3 + lda; \ | |||
| a_ptr5 = a_ptr4 + lda; \ | |||
| a_ptr6 = a_ptr5 + lda; \ | |||
| a_ptr7 = a_ptr6 + lda; | |||
| #define LANESx2(MACRO, offset) \ | |||
| MACRO(offset, 0) \ | |||
| MACRO(offset, 1) | |||
| #define LANESx4(MACRO, offset) \ | |||
| LANESx2(MACRO, offset) \ | |||
| MACRO(offset, 2) \ | |||
| MACRO(offset, 3) | |||
| #define LANESx8(MACRO, offset) \ | |||
| LANESx4(MACRO, offset) \ | |||
| MACRO(offset, 4) \ | |||
| MACRO(offset, 5) \ | |||
| MACRO(offset, 6) \ | |||
| MACRO(offset, 7) | |||
| #define LOAD_A_VEC(offset, vec) \ | |||
| svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16); | |||
| #define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \ | |||
| acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \ | |||
| acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane); | |||
| #define INIT_ACCUMULATORS(offset) \ | |||
| svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \ | |||
| svfloat32_t acc ## offset ## 1 = svdup_f32(0.0); | |||
| #define UPDATE_ACCUMULATORS(offset) \ | |||
| acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \ | |||
| acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec); | |||
| #define STORE_ACCUMULATORS(offset) \ | |||
| svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \ | |||
| svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \ | |||
| svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \ | |||
| svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset); | |||
| #define ALPHA_OP(offset) \ | |||
| acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \ | |||
| acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha); | |||
| #define BETA_OP(offset) \ | |||
| svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \ | |||
| acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \ | |||
| acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset); | |||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y) | |||
| { | |||
| BLASLONG ix; | |||
| bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7; | |||
| BLASLONG sve_size_bf16 = svcnth(); | |||
| BLASLONG sve_size2_bf16 = sve_size_bf16 * 2; | |||
| BLASLONG sve_size3_bf16 = sve_size_bf16 * 3; | |||
| svbool_t pg_full = svptrue_b16(); | |||
| svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16); | |||
| BLASLONG n8 = n & -8; | |||
| BLASLONG n4 = n & -4; | |||
| BLASLONG n2 = n & -2; | |||
| bfloat16_t *a = (bfloat16_t*)a_in; | |||
| bfloat16_t *x = (bfloat16_t*)x_in; | |||
| bfloat16_t *y = (bfloat16_t*)y_in; | |||
| bfloat16_t alpha_bf16, beta_bf16; | |||
| memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||
| memcpy(&beta_bf16, &beta, sizeof(bfloat16_t)); | |||
| float beta_f32 = vcvtah_f32_bf16(beta_bf16); | |||
| float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); | |||
| svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16)); | |||
| svbfloat16_t svbeta16 = svdup_bf16(beta_bf16); | |||
| BLASLONG i = 0; | |||
| if (inc_y == 1) { | |||
| for (; (i + sve_size3_bf16 - 1) < m; i+= sve_size3_bf16) { | |||
| INIT_ACCUMULATORS(0); | |||
| INIT_ACCUMULATORS(1); | |||
| INIT_ACCUMULATORS(2); | |||
| BLASLONG j = 0; | |||
| ix = 0; | |||
| a_ptr0 = a; | |||
| UPDATE_PTRSx8; | |||
| if (inc_x == 1) { | |||
| for (; j < n4; j+= 4) { | |||
| uint64_t* x_u64 = (uint64_t*)&x[ix]; | |||
| svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64)); | |||
| LANESx4(LOAD_A_VEC, 0); | |||
| LANESx4(LOAD_A_VEC, 1); | |||
| LANESx4(LOAD_A_VEC, 2); | |||
| LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
| LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1); | |||
| LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2); | |||
| ix += 4; | |||
| a_ptr0 += 4 * lda; | |||
| UPDATE_PTRSx4; | |||
| } | |||
| for (; j < n2; j+= 2) { | |||
| uint32_t* x_u32 = (uint32_t*)&x[ix]; | |||
| svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32)); | |||
| LANESx2(LOAD_A_VEC, 0); | |||
| LANESx2(LOAD_A_VEC, 1); | |||
| LANESx2(LOAD_A_VEC, 2); | |||
| LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
| LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1); | |||
| LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2); | |||
| ix += 2; | |||
| a_ptr0 += 2 * lda; | |||
| UPDATE_PTRSx2; | |||
| } | |||
| } | |||
| for (; j < n; j++) { | |||
| svbfloat16_t x_vec = svdup_bf16(x[ix]); | |||
| LOAD_A_VEC(0, 0); | |||
| LOAD_A_VEC(1, 0); | |||
| LOAD_A_VEC(2, 0); | |||
| UPDATE_ACCUMULATORS(0); | |||
| UPDATE_ACCUMULATORS(1); | |||
| UPDATE_ACCUMULATORS(2); | |||
| ix += inc_x; | |||
| a_ptr0 += lda; | |||
| } | |||
| if (alpha_f32 != ONE) { | |||
| ALPHA_OP(0); | |||
| ALPHA_OP(1); | |||
| ALPHA_OP(2); | |||
| } | |||
| if (beta_f32 != ZERO) { | |||
| BETA_OP(0); | |||
| BETA_OP(1); | |||
| BETA_OP(2); | |||
| } | |||
| STORE_ACCUMULATORS(0); | |||
| STORE_ACCUMULATORS(1); | |||
| STORE_ACCUMULATORS(2); | |||
| } | |||
| for (; (i + sve_size_bf16 - 1) < m; i+= sve_size_bf16) { | |||
| INIT_ACCUMULATORS(0); | |||
| BLASLONG j = 0; | |||
| ix = 0; | |||
| a_ptr0 = a; | |||
| UPDATE_PTRSx8; | |||
| if (inc_x == 1) { | |||
| for (; j < n8; j+= 8) { | |||
| svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]); | |||
| LANESx8(LOAD_A_VEC, 0); | |||
| LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
| ix += 8; | |||
| a_ptr0 += 8 * lda; | |||
| UPDATE_PTRSx8; | |||
| } | |||
| for (; j < n4; j+= 4) { | |||
| uint64_t* x_u64 = (uint64_t*)&x[ix]; | |||
| svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64)); | |||
| LANESx4(LOAD_A_VEC, 0); | |||
| LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
| ix += 4; | |||
| a_ptr0 += 4 * lda; | |||
| UPDATE_PTRSx4; | |||
| } | |||
| for (; j < n2; j+= 2) { | |||
| uint32_t* x_u32 = (uint32_t*)&x[ix]; | |||
| svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32)); | |||
| LANESx2(LOAD_A_VEC, 0); | |||
| LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0); | |||
| ix += 2; | |||
| a_ptr0 += 2 * lda; | |||
| UPDATE_PTRSx2; | |||
| } | |||
| } | |||
| for (; j < n; j++) { | |||
| svbfloat16_t x_vec = svdup_bf16(x[ix]); | |||
| LOAD_A_VEC(0, 0); | |||
| UPDATE_ACCUMULATORS(0); | |||
| ix += inc_x; | |||
| a_ptr0 += lda; | |||
| } | |||
| if (alpha_f32 != ONE) { | |||
| ALPHA_OP(0); | |||
| } | |||
| if (beta_f32 != ZERO) { | |||
| BETA_OP(0); | |||
| } | |||
| STORE_ACCUMULATORS(0); | |||
| } | |||
| if (i < m) { | |||
| svfloat32_t acc0 = svdup_f32(0.0); | |||
| svfloat32_t acc1 = svdup_f32(0.0); | |||
| ix = 0; | |||
| a_ptr0 = a; | |||
| for (BLASLONG j = 0; j < n; j++) { | |||
| svbfloat16_t x_vec0 = svdup_bf16(x[ix]); | |||
| svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]); | |||
| acc0 = svbfmlalb(acc0, a_vec0, x_vec0); | |||
| acc1 = svbfmlalt(acc1, a_vec0, x_vec0); | |||
| ix += inc_x; | |||
| a_ptr0 += lda; | |||
| } | |||
| if (alpha_f32 != ONE) { | |||
| acc0 = svmul_x(pg_full, acc0, svalpha); | |||
| acc1 = svmul_x(pg_full, acc1, svalpha); | |||
| } | |||
| if (beta_f32 != ZERO) { | |||
| svbfloat16_t y_vec = svld1(pg_tail, &y[i]); | |||
| acc0 = svbfmlalb(acc0, svbeta16, y_vec); | |||
| acc1 = svbfmlalt(acc1, svbeta16, y_vec); | |||
| } | |||
| svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0); | |||
| svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1); | |||
| svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16); | |||
| svst1(pg_tail, &y[i], combined); | |||
| } | |||
| return 0; | |||
| } | |||
| // Scalar fallback | |||
| BLASLONG iy = 0; | |||
| for (; i < m; i++) { | |||
| float temp = 0.0; | |||
| ix = 0; | |||
| a_ptr0 = a; | |||
| for (BLASLONG j = 0; j < n; j++) { | |||
| temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]); | |||
| ix += inc_x; | |||
| a_ptr0 += lda; | |||
| } | |||
| if (beta_f32 == ZERO) { | |||
| y[iy] = vcvth_bf16_f32(alpha_f32 * temp); | |||
| } else { | |||
| y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy])); | |||
| } | |||
| iy += inc_y; | |||
| } | |||
| return (0); | |||
| } | |||
| @@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| #include <arm_neon.h> | |||
| #include "common.h" | |||
| int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) | |||
| #ifdef BGEMM | |||
| #define INNER_FLOAT bfloat16_t | |||
| #define TO32(x) vcvtah_f32_bf16(x) | |||
| #define FROM32(x) vcvth_bf16_f32(x) | |||
| #else | |||
| #define INNER_FLOAT float | |||
| #define TO32(x) x | |||
| #define FROM32(x) x | |||
| #endif | |||
| int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy) | |||
| { | |||
| if (m < 1 || n < 1) return(0); | |||
| BLASLONG i; | |||
| BLASLONG ix,iy; | |||
| BLASLONG j; | |||
| bfloat16_t *a_ptr; | |||
| bfloat16_t *x_ptr; | |||
| float *y_ptr; | |||
| float temp; | |||
| float temp, temp0, temp1, temp2, temp3; | |||
| #ifdef BGEMM | |||
| bfloat16_t alpha_bf16, beta_bf16; | |||
| memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); | |||
| memcpy(&beta_bf16, &beta, sizeof(bfloat16_t)); | |||
| float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); | |||
| float beta_f32 = vcvtah_f32_bf16(beta_bf16); | |||
| #else | |||
| float alpha_f32 = alpha; | |||
| float beta_f32 = beta; | |||
| #endif | |||
| INNER_FLOAT *y = (INNER_FLOAT *)y_in; | |||
| INNER_FLOAT *y_ptr; | |||
| iy = 0; | |||
| a_ptr = (bfloat16_t*)(a); | |||
| @@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
| bfloat16_t *a2_ptr = a_ptr + lda * width * 2; | |||
| bfloat16_t *a3_ptr = a_ptr + lda * width * 3; | |||
| float *y0_ptr = y + incy * width * 0; | |||
| float *y1_ptr = y + incy * width * 1; | |||
| float *y2_ptr = y + incy * width * 2; | |||
| float *y3_ptr = y + incy * width * 3; | |||
| INNER_FLOAT *y0_ptr = y + incy * width * 0; | |||
| INNER_FLOAT *y1_ptr = y + incy * width * 1; | |||
| INNER_FLOAT *y2_ptr = y + incy * width * 2; | |||
| INNER_FLOAT *y3_ptr = y + incy * width * 3; | |||
| for (j = 0; j < width; j++) { | |||
| float32x4_t temp0_vec = vdupq_n_f32(0.0f); | |||
| @@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
| i += 4; | |||
| } | |||
| if (beta == 0.0f) { | |||
| y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec); | |||
| y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec); | |||
| y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec); | |||
| y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec); | |||
| } | |||
| else { | |||
| y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy]; | |||
| y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy]; | |||
| y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy]; | |||
| y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy]; | |||
| if (beta_f32 == 0.0f) { | |||
| temp0 = alpha_f32 * vaddvq_f32(temp0_vec); | |||
| temp1 = alpha_f32 * vaddvq_f32(temp1_vec); | |||
| temp2 = alpha_f32 * vaddvq_f32(temp2_vec); | |||
| temp3 = alpha_f32 * vaddvq_f32(temp3_vec); | |||
| } else { | |||
| temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]); | |||
| temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]); | |||
| temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]); | |||
| temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]); | |||
| } | |||
| for (; i < m; ++i) { | |||
| y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| } | |||
| y0_ptr[iy] = FROM32(temp0); | |||
| y1_ptr[iy] = FROM32(temp1); | |||
| y2_ptr[iy] = FROM32(temp2); | |||
| y3_ptr[iy] = FROM32(temp3); | |||
| iy += incy; | |||
| a0_ptr += lda; | |||
| @@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
| i += 4; | |||
| } | |||
| if (beta == 0.0f) { | |||
| y_ptr[iy] = alpha * vaddvq_f32(temp0_vec); | |||
| } | |||
| else { | |||
| y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy]; | |||
| } | |||
| if (beta_f32 == 0.0f) { | |||
| temp = alpha_f32 * vaddvq_f32(temp0_vec); | |||
| } else { | |||
| temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]); | |||
| } | |||
| for (; i < m; ++i) { | |||
| y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); | |||
| } | |||
| y_ptr[iy] = FROM32(temp); | |||
| iy += incy; | |||
| @@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat | |||
| temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]); | |||
| ix += incx; | |||
| } | |||
| if (beta == 0.0f) { | |||
| y[iy] = alpha * temp; | |||
| if (beta_f32 == 0.0f) { | |||
| y[iy] = FROM32(alpha_f32 * temp); | |||
| } | |||
| else { | |||
| y[iy] = alpha * temp + beta * y[iy]; | |||
| y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy])); | |||
| } | |||
| iy += incy; | |||
| a_ptr += lda; | |||
| @@ -127,7 +127,7 @@ int main(int argc, char *argv[]) | |||
| if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001)) | |||
| { | |||
| #ifdef DEBUG | |||
| printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n", | |||
| printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n", | |||
| transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j])); | |||
| #endif | |||
| ret++; | |||