- Adds bgemv T based on the sbgemv T kernel
- Adds bgemv N, slightly altered so that Y is not used as an accumulator: the output is bf16, and accumulating in bf16 loses precision
- Enables BGEMM_GEMV_FORWARD to forward BGEMM calls to BGEMV using the new kernels
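For context, the precision point in the second bullet can be illustrated with a small stand-alone sketch (not the kernel code itself; the helper names and the truncating bf16 conversion are invented for illustration): products are accumulated in a float32 temporary and converted to bf16 only at the final store, instead of repeatedly updating the bf16 Y vector and rounding on every step.

```c
/* Hypothetical sketch of the "accumulate in float32, convert once at the end"
 * pattern; bf16 is modelled as the upper 16 bits of an IEEE float32 and the
 * conversion simply truncates (illustration only, not the SVE kernel). */
#include <stdint.h>
#include <string.h>

typedef uint16_t bf16;

static float bf16_to_f32(bf16 v) {
    uint32_t bits = (uint32_t)v << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static bf16 f32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return (bf16)(bits >> 16);          /* truncate, ignoring rounding */
}

/* y := alpha*A*x + beta*y for column-major A (m x n), unit strides. */
static void bgemv_n_ref(int m, int n, float alpha, const bf16 *a, int lda,
                        const bf16 *x, float beta, bf16 *y) {
    for (int i = 0; i < m; i++) {
        float acc = 0.0f;                               /* float32 accumulator */
        for (int j = 0; j < n; j++)
            acc += bf16_to_f32(a[i + j * lda]) * bf16_to_f32(x[j]);
        acc = alpha * acc + (beta == 0.0f ? 0.0f : beta * bf16_to_f32(y[i]));
        y[i] = f32_to_bf16(acc);                        /* convert to bf16 only once */
    }
}
```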
@@ -277,6 +277,7 @@ endif
 ifeq ($(ARCH), arm64)
 GEMM_GEMV_FORWARD = 1
 SBGEMM_GEMV_FORWARD = 1
+BGEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), riscv)
 GEMM_GEMV_FORWARD = 1
@@ -296,6 +297,9 @@ endif
 ifeq ($(SBGEMM_GEMV_FORWARD), 1)
 CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
 endif
+ifeq ($(BGEMM_GEMV_FORWARD), 1)
+CCOMMON_OPT += -DBGEMM_GEMV_FORWARD
+endif
 endif
 # This operation is expensive, so execution should be once.
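For reference, a rough sketch of what the forwarding flag is meant to do at the interface level, based on how the existing GEMM_GEMV_FORWARD / SBGEMM_GEMV_FORWARD options behave (the entry-point and kernel names below are invented for this illustration and are not the actual OpenBLAS interface code): a BGEMM call whose result matrix has a single column degenerates to a matrix-vector product and can be handed to the new BGEMV kernels.

```c
/* Hypothetical sketch of GEMV forwarding in a GEMM front end. */
typedef unsigned short bfloat16;   /* raw 16-bit bf16 storage */

/* Assumed prototypes for the sketch: the new BGEMV N kernel and the
 * regular blocked BGEMM path. */
void bgemv_n(long m, long n, bfloat16 alpha, const bfloat16 *a, long lda,
             const bfloat16 *x, long incx, bfloat16 beta, bfloat16 *y, long incy);
void bgemm_blocked(char transa, long m, long n, long k, bfloat16 alpha,
                   const bfloat16 *a, long lda, const bfloat16 *b, long ldb,
                   bfloat16 beta, bfloat16 *c, long ldc);

void bgemm_entry(char transa, char transb, long m, long n, long k,
                 bfloat16 alpha, const bfloat16 *a, long lda,
                 const bfloat16 *b, long ldb,
                 bfloat16 beta, bfloat16 *c, long ldc)
{
#ifdef BGEMM_GEMV_FORWARD
  /* C is m x n: with a single column and untransposed inputs, A*B is really
   * a matrix-vector product, so forward to the BGEMV N kernel.
   * (A transposed A would be forwarded to the BGEMV T kernel instead.) */
  if (n == 1 && (transa == 'N' || transa == 'n') && (transb == 'N' || transb == 'n')) {
    bgemv_n(m, k, alpha, a, lda, b, 1, beta, c, 1);
    return;
  }
#endif
  bgemm_blocked(transa, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
```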
@@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS=
 endif
 ifeq ($(BUILD_BFLOAT16),1)
-GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto
+GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto
 else
 GOTO_BFLOAT_TARGETS=
 endif
@@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 endif
 ifeq ($(BUILD_HFLOAT16),1)
@@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c
 cgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^
+ifeq ($(BUILD_BFLOAT16),1)
+bgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^
+sbgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
+endif
 zgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^
@@ -1,5 +1,5 @@
 /***************************************************************************
-Copyright (c) 2014, The OpenBLAS Project
+Copyright (c) 2014, 2025 The OpenBLAS Project
 All rights reserved.
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
@@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifdef DOUBLE
 #define GEMV BLASFUNC(dgemv)
+#elif defined(BFLOAT16) && defined(BGEMM)
+#define GEMV BLASFUNC(bgemv)
+#elif defined(BFLOAT16)
+#define GEMV BLASFUNC(sbgemv)
+#undef IFLOAT
+#define IFLOAT bfloat16
 #else
 #define GEMV BLASFUNC(sgemv)
 #endif
@@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #endif
 int main(int argc, char *argv[]){
-  FLOAT *a, *x, *y;
-  FLOAT alpha[] = {1.0, 1.0};
-  FLOAT beta [] = {1.0, 0.0};
+  IFLOAT *a, *x;
+  FLOAT *y;
+#ifdef BGEMM
+  blasint one=1;
+  blasint two=2;
+  float alpha_in[] = {1.0, 0.0};
+  float beta_in[] = {0.0, 0.0};
+  FLOAT alpha[2], beta[2];
+  sbstobf16_(&two, alpha_in, &one, alpha, &one);
+  sbstobf16_(&two, beta_in, &one, beta, &one);
+#else
+  FLOAT alpha[] = {1.0, 0.0};
+  FLOAT beta [] = {0.0, 0.0};
+#endif
   char trans='N';
   blasint m, i, j;
   blasint inc_x=1,inc_y=1;
@@ -97,11 +114,11 @@ int main(int argc, char *argv[]){
   fprintf(stderr, "From : %3d To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops);
-  if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){
+  if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }
-  if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
+  if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }
@@ -125,7 +142,7 @@ int main(int argc, char *argv[]){
     fprintf(stderr, " %6dx%d : ", (int)m,(int)n);
     for(j = 0; j < m; j++){
       for(i = 0; i < n * COMPSIZE; i++){
-        a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+        a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
       }
     }
@@ -133,7 +150,7 @@ int main(int argc, char *argv[]){
   {
     for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){
-      x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+      x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
     }
     for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){
@@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
 BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
 BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)
+BGEMVTKERNEL = sbgemv_t_bfdot.c
+BGEMVNKERNEL = bgemv_n_sve_v3x4.c
 SBGEMM_BETA = sbgemm_beta_neoversev1.c
 SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
 ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))
@@ -0,0 +1,321 @@
+/***************************************************************************
+Copyright (c) 2025 The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include "common.h"
+#include <arm_sve.h>
+
+#define UPDATE_PTRSx2 \
+  a_ptr1 = a_ptr0 + lda;
+
+#define UPDATE_PTRSx4 \
+  UPDATE_PTRSx2 \
+  a_ptr2 = a_ptr1 + lda; \
+  a_ptr3 = a_ptr2 + lda;
+
+#define UPDATE_PTRSx8 \
+  UPDATE_PTRSx4 \
+  a_ptr4 = a_ptr3 + lda; \
+  a_ptr5 = a_ptr4 + lda; \
+  a_ptr6 = a_ptr5 + lda; \
+  a_ptr7 = a_ptr6 + lda;
+
+#define LANESx2(MACRO, offset) \
+  MACRO(offset, 0) \
+  MACRO(offset, 1)
+
+#define LANESx4(MACRO, offset) \
+  LANESx2(MACRO, offset) \
+  MACRO(offset, 2) \
+  MACRO(offset, 3)
+
+#define LANESx8(MACRO, offset) \
+  LANESx4(MACRO, offset) \
+  MACRO(offset, 4) \
+  MACRO(offset, 5) \
+  MACRO(offset, 6) \
+  MACRO(offset, 7)
+
+#define LOAD_A_VEC(offset, vec) \
+  svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16);
+
+#define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \
+  acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \
+  acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane);
+
+#define INIT_ACCUMULATORS(offset) \
+  svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \
+  svfloat32_t acc ## offset ## 1 = svdup_f32(0.0);
+
+#define UPDATE_ACCUMULATORS(offset) \
+  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \
+  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec);
+
+#define STORE_ACCUMULATORS(offset) \
+  svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \
+  svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \
+  svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \
+  svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset);
+
+#define ALPHA_OP(offset) \
+  acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \
+  acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha);
+
+#define BETA_OP(offset) \
+  svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \
+  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \
+  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset);
+
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y)
+{
+  BLASLONG ix;
+  bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7;
+
+  BLASLONG sve_size_bf16 = svcnth();
+  BLASLONG sve_size2_bf16 = sve_size_bf16 * 2;
+  BLASLONG sve_size3_bf16 = sve_size_bf16 * 3;
+  svbool_t pg_full = svptrue_b16();
+  svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16);
+
+  BLASLONG n8 = n & -8;
+  BLASLONG n4 = n & -4;
+  BLASLONG n2 = n & -2;
+
+  bfloat16_t *a = (bfloat16_t*)a_in;
+  bfloat16_t *x = (bfloat16_t*)x_in;
+  bfloat16_t *y = (bfloat16_t*)y_in;
+
+  bfloat16_t alpha_bf16, beta_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
+  float beta_f32 = vcvtah_f32_bf16(beta_bf16);
+  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+  svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
+  svbfloat16_t svbeta16 = svdup_bf16(beta_bf16);
+
+  BLASLONG i = 0;
+  if (inc_y == 1) {
+    for (; (i + sve_size3_bf16 - 1) < m; i += sve_size3_bf16) {
+      INIT_ACCUMULATORS(0);
+      INIT_ACCUMULATORS(1);
+      INIT_ACCUMULATORS(2);
+
+      BLASLONG j = 0;
+      ix = 0;
+      a_ptr0 = a;
+      UPDATE_PTRSx8;
+
+      if (inc_x == 1) {
+        for (; j < n4; j += 4) {
+          uint64_t* x_u64 = (uint64_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
+
+          LANESx4(LOAD_A_VEC, 0);
+          LANESx4(LOAD_A_VEC, 1);
+          LANESx4(LOAD_A_VEC, 2);
+
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2);
+
+          ix += 4;
+          a_ptr0 += 4 * lda;
+          UPDATE_PTRSx4;
+        }
+        for (; j < n2; j += 2) {
+          uint32_t* x_u32 = (uint32_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
+
+          LANESx2(LOAD_A_VEC, 0);
+          LANESx2(LOAD_A_VEC, 1);
+          LANESx2(LOAD_A_VEC, 2);
+
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2);
+
+          ix += 2;
+          a_ptr0 += 2 * lda;
+          UPDATE_PTRSx2;
+        }
+      }
+      for (; j < n; j++) {
+        svbfloat16_t x_vec = svdup_bf16(x[ix]);
+        LOAD_A_VEC(0, 0);
+        LOAD_A_VEC(1, 0);
+        LOAD_A_VEC(2, 0);
+        UPDATE_ACCUMULATORS(0);
+        UPDATE_ACCUMULATORS(1);
+        UPDATE_ACCUMULATORS(2);
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+
+      if (alpha_f32 != ONE) {
+        ALPHA_OP(0);
+        ALPHA_OP(1);
+        ALPHA_OP(2);
+      }
+      if (beta_f32 != ZERO) {
+        BETA_OP(0);
+        BETA_OP(1);
+        BETA_OP(2);
+      }
+
+      STORE_ACCUMULATORS(0);
+      STORE_ACCUMULATORS(1);
+      STORE_ACCUMULATORS(2);
+    }
+
+    for (; (i + sve_size_bf16 - 1) < m; i += sve_size_bf16) {
+      INIT_ACCUMULATORS(0);
+
+      BLASLONG j = 0;
+      ix = 0;
+      a_ptr0 = a;
+      UPDATE_PTRSx8;
+
+      if (inc_x == 1) {
+        for (; j < n8; j += 8) {
+          svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]);
+          LANESx8(LOAD_A_VEC, 0);
+          LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          ix += 8;
+          a_ptr0 += 8 * lda;
+          UPDATE_PTRSx8;
+        }
+        for (; j < n4; j += 4) {
+          uint64_t* x_u64 = (uint64_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
+          LANESx4(LOAD_A_VEC, 0);
+          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          ix += 4;
+          a_ptr0 += 4 * lda;
+          UPDATE_PTRSx4;
+        }
+        for (; j < n2; j += 2) {
+          uint32_t* x_u32 = (uint32_t*)&x[ix];
+          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
+          LANESx2(LOAD_A_VEC, 0);
+          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
+          ix += 2;
+          a_ptr0 += 2 * lda;
+          UPDATE_PTRSx2;
+        }
+      }
+      for (; j < n; j++) {
+        svbfloat16_t x_vec = svdup_bf16(x[ix]);
+        LOAD_A_VEC(0, 0);
+        UPDATE_ACCUMULATORS(0);
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+
+      if (alpha_f32 != ONE) {
+        ALPHA_OP(0);
+      }
+      if (beta_f32 != ZERO) {
+        BETA_OP(0);
+      }
+
+      STORE_ACCUMULATORS(0);
+    }
+
+    if (i < m) {
+      svfloat32_t acc0 = svdup_f32(0.0);
+      svfloat32_t acc1 = svdup_f32(0.0);
+      ix = 0;
+      a_ptr0 = a;
+      for (BLASLONG j = 0; j < n; j++) {
+        svbfloat16_t x_vec0 = svdup_bf16(x[ix]);
+        svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]);
+        acc0 = svbfmlalb(acc0, a_vec0, x_vec0);
+        acc1 = svbfmlalt(acc1, a_vec0, x_vec0);
+        ix += inc_x;
+        a_ptr0 += lda;
+      }
+      if (alpha_f32 != ONE) {
+        acc0 = svmul_x(pg_full, acc0, svalpha);
+        acc1 = svmul_x(pg_full, acc1, svalpha);
+      }
+      if (beta_f32 != ZERO) {
+        svbfloat16_t y_vec = svld1(pg_tail, &y[i]);
+        acc0 = svbfmlalb(acc0, svbeta16, y_vec);
+        acc1 = svbfmlalt(acc1, svbeta16, y_vec);
+      }
+      svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0);
+      svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1);
+      svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16);
+      svst1(pg_tail, &y[i], combined);
+    }
+
+    return 0;
+  }
+
+  // Scalar fallback
+  BLASLONG iy = 0;
+  for (; i < m; i++) {
+    float temp = 0.0;
+    ix = 0;
+    a_ptr0 = a;
+    for (BLASLONG j = 0; j < n; j++) {
+      temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]);
+      ix += inc_x;
+      a_ptr0 += lda;
+    }
+    if (beta_f32 == ZERO) {
+      y[iy] = vcvth_bf16_f32(alpha_f32 * temp);
+    } else {
+      y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy]));
+    }
+    iy += inc_y;
+  }
+
+  return (0);
+}
@@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include <arm_neon.h>
 #include "common.h"
-int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
+#ifdef BGEMM
+#define INNER_FLOAT bfloat16_t
+#define TO32(x) vcvtah_f32_bf16(x)
+#define FROM32(x) vcvth_bf16_f32(x)
+#else
+#define INNER_FLOAT float
+#define TO32(x) x
+#define FROM32(x) x
+#endif
+
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy)
 {
   if (m < 1 || n < 1) return(0);
   BLASLONG i;
   BLASLONG ix,iy;
   BLASLONG j;
   bfloat16_t *a_ptr;
   bfloat16_t *x_ptr;
-  float *y_ptr;
-  float temp;
+  float temp, temp0, temp1, temp2, temp3;
+#ifdef BGEMM
+  bfloat16_t alpha_bf16, beta_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
+  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+  float beta_f32 = vcvtah_f32_bf16(beta_bf16);
+#else
+  float alpha_f32 = alpha;
+  float beta_f32 = beta;
+#endif
+  INNER_FLOAT *y = (INNER_FLOAT *)y_in;
+  INNER_FLOAT *y_ptr;
   iy = 0;
   a_ptr = (bfloat16_t*)(a);
@@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
     bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
     bfloat16_t *a3_ptr = a_ptr + lda * width * 3;
-    float *y0_ptr = y + incy * width * 0;
-    float *y1_ptr = y + incy * width * 1;
-    float *y2_ptr = y + incy * width * 2;
-    float *y3_ptr = y + incy * width * 3;
+    INNER_FLOAT *y0_ptr = y + incy * width * 0;
+    INNER_FLOAT *y1_ptr = y + incy * width * 1;
+    INNER_FLOAT *y2_ptr = y + incy * width * 2;
+    INNER_FLOAT *y3_ptr = y + incy * width * 3;
     for (j = 0; j < width; j++) {
       float32x4_t temp0_vec = vdupq_n_f32(0.0f);
@@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
         i += 4;
       }
-      if (beta == 0.0f) {
-        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
-        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
-        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
-      }
-      else {
-        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
-        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
-        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
-        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
+      if (beta_f32 == 0.0f) {
+        temp0 = alpha_f32 * vaddvq_f32(temp0_vec);
+        temp1 = alpha_f32 * vaddvq_f32(temp1_vec);
+        temp2 = alpha_f32 * vaddvq_f32(temp2_vec);
+        temp3 = alpha_f32 * vaddvq_f32(temp3_vec);
+      } else {
+        temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]);
+        temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]);
+        temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]);
+        temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]);
       }
       for (; i < m; ++i) {
-        y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
       }
+      y0_ptr[iy] = FROM32(temp0);
+      y1_ptr[iy] = FROM32(temp1);
+      y2_ptr[iy] = FROM32(temp2);
+      y3_ptr[iy] = FROM32(temp3);
       iy += incy;
       a0_ptr += lda;
@@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
       i += 4;
     }
-    if (beta == 0.0f) {
-      y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-    }
-    else {
-      y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
-    }
+    if (beta_f32 == 0.0f) {
+      temp = alpha_f32 * vaddvq_f32(temp0_vec);
+    } else {
+      temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]);
+    }
     for (; i < m; ++i) {
-      y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+      temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
     }
+    y_ptr[iy] = FROM32(temp);
     iy += incy;
@@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
       temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]);
       ix += incx;
     }
-    if (beta == 0.0f) {
-      y[iy] = alpha * temp;
+    if (beta_f32 == 0.0f) {
+      y[iy] = FROM32(alpha_f32 * temp);
     }
     else {
-      y[iy] = alpha * temp + beta * y[iy];
+      y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy]));
     }
     iy += incy;
     a_ptr += lda;
@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
           if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001))
           {
 #ifdef DEBUG
-            printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n",
+            printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n",
                    transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]));
 #endif
             ret++;