- Adds bgemv T, based on the sbgemv T kernel.
- Adds bgemv N, altered so that it does not use Y as an accumulator: the output is bf16, and accumulating into it would lose precision.
- Enables BGEMM_GEMV_FORWARD to proxy BGEMM calls to BGEMV using the new kernels.
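To make the precision point concrete, here is a minimal scalar sketch of the accumulation strategy for the non-transposed (N) case: partial sums stay in float32 and are rounded to bfloat16 exactly once when written back to Y. This is an illustration only, not the SVE/NEON kernels in this PR; the helper names `bf16_to_f32`/`f32_to_bf16` and the float `alpha`/`beta` parameters are assumptions made to keep the sketch self-contained (the real kernels take bf16 alpha/beta and use the Arm conversion intrinsics, which round rather than truncate).

```c
#include <stdint.h>
#include <string.h>

typedef uint16_t bf16_raw;                    /* raw bfloat16 storage (illustrative) */

static float bf16_to_f32(bf16_raw v)
{
    uint32_t u = (uint32_t)v << 16;           /* bf16 is the high half of an f32 */
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static bf16_raw f32_to_bf16(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    return (bf16_raw)(u >> 16);               /* truncation for brevity; hardware converts with rounding */
}

/* y[i] = alpha * sum_j A[i + j*lda] * x[j] + beta * y[i], column-major A (m x n) */
void bgemv_n_reference(int m, int n, float alpha, const bf16_raw *a, int lda,
                       const bf16_raw *x, float beta, bf16_raw *y)
{
    for (int i = 0; i < m; i++) {
        float acc = 0.0f;                     /* fp32 accumulator, never y[i] itself */
        for (int j = 0; j < n; j++)
            acc += bf16_to_f32(a[i + (size_t)j * lda]) * bf16_to_f32(x[j]);
        float out = alpha * acc;
        if (beta != 0.0f)
            out += beta * bf16_to_f32(y[i]);  /* y is read once, widened to fp32 */
        y[i] = f32_to_bf16(out);              /* single rounding at the very end */
    }
}
```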
@@ -277,6 +277,7 @@ endif
 ifeq ($(ARCH), arm64)
 GEMM_GEMV_FORWARD = 1
 SBGEMM_GEMV_FORWARD = 1
+BGEMM_GEMV_FORWARD = 1
 endif
 ifeq ($(ARCH), riscv)
 GEMM_GEMV_FORWARD = 1
@@ -296,6 +297,9 @@ endif
 ifeq ($(SBGEMM_GEMV_FORWARD), 1)
 CCOMMON_OPT += -DSBGEMM_GEMV_FORWARD
 endif
+ifeq ($(BGEMM_GEMV_FORWARD), 1)
+CCOMMON_OPT += -DBGEMM_GEMV_FORWARD
+endif
 endif
 # This operation is expensive, so execution should be once.
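For context (not shown in the hunks above): a `*_GEMV_FORWARD` flag of this kind is consumed by the GEMM interface, which can route a degenerate GEMM (one output dimension equal to 1) to the matching GEMV kernel. The sketch below shows only the shape of that dispatch, under assumed stand-in functions (`bgemv_stub`, `bgemm_kernel_stub`, `bgemm_dispatch`); it is not the OpenBLAS interface code touched by this PR.

```c
#include <stdio.h>

static void bgemv_stub(int m, int k)        { printf("bgemv path:  (%d x %d) * %d-vector\n", m, k, k); }
static void bgemm_kernel_stub(int m, int n, int k) { printf("bgemm path: %d x %d x %d\n", m, n, k); }

static void bgemm_dispatch(int m, int n, int k)
{
#ifdef BGEMM_GEMV_FORWARD
    if (n == 1) {               /* C is m x 1: the product is really a matrix-vector product */
        bgemv_stub(m, k);
        return;
    }
#endif
    bgemm_kernel_stub(m, n, k); /* regular blocked GEMM otherwise */
}

int main(void)
{
    bgemm_dispatch(128, 1, 64);   /* takes the GEMV path when built with -DBGEMM_GEMV_FORWARD */
    bgemm_dispatch(128, 32, 64);
    return 0;
}
```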
@@ -84,7 +84,7 @@ GOTO_LAPACK_TARGETS=
 endif
 ifeq ($(BUILD_BFLOAT16),1)
-GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto
+GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto bgemv.goto sbgemv.goto
 else
 GOTO_BFLOAT_TARGETS=
 endif
@@ -667,6 +667,10 @@ bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
 	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+bgemv.goto : bgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
+sbgemv.goto : sbgemv.$(SUFFIX) ../$(LIBNAME)
+	$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
 endif
 ifeq ($(BUILD_HFLOAT16),1)
@@ -3146,6 +3150,13 @@ dgemv.$(SUFFIX) : gemv.c
 cgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -UDOUBLE -o $(@F) $^
+ifeq ($(BUILD_BFLOAT16),1)
+bgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^
+sbgemv.$(SUFFIX) : gemv.c
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
+endif
 zgemv.$(SUFFIX) : gemv.c
 	$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^
@@ -1,5 +1,5 @@
 /***************************************************************************
-Copyright (c) 2014, The OpenBLAS Project
+Copyright (c) 2014, 2025 The OpenBLAS Project
 All rights reserved.
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are
@@ -34,6 +34,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifdef DOUBLE
 #define GEMV BLASFUNC(dgemv)
+#elif defined(BFLOAT16) && defined(BGEMM)
+#define GEMV BLASFUNC(bgemv)
+#elif defined(BFLOAT16)
+#define GEMV BLASFUNC(sbgemv)
+#undef IFLOAT
+#define IFLOAT bfloat16
 #else
 #define GEMV BLASFUNC(sgemv)
 #endif
@@ -49,9 +55,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #endif
 int main(int argc, char *argv[]){
-  FLOAT *a, *x, *y;
-  FLOAT alpha[] = {1.0, 1.0};
-  FLOAT beta [] = {1.0, 0.0};
+  IFLOAT *a, *x;
+  FLOAT *y;
+#ifdef BGEMM
+  blasint one=1;
+  blasint two=2;
+  float alpha_in[] = {1.0, 0.0};
+  float beta_in[] = {0.0, 0.0};
+  FLOAT alpha[2], beta[2];
+  sbstobf16_(&two, alpha_in, &one, alpha, &one);
+  sbstobf16_(&two, beta_in, &one, beta, &one);
+#else
+  FLOAT alpha[] = {1.0, 0.0};
+  FLOAT beta [] = {0.0, 0.0};
+#endif
   char trans='N';
   blasint m, i, j;
   blasint inc_x=1,inc_y=1;
@@ -97,11 +114,11 @@ int main(int argc, char *argv[]){
   fprintf(stderr, "From : %3d To : %3d Step = %3d Trans = '%c' Inc_x = %d Inc_y = %d Loops = %d\n", from, to, step,trans,inc_x,inc_y,loops);
-  if (( a = (FLOAT *)malloc(sizeof(FLOAT) * tomax * tomax * COMPSIZE)) == NULL){
+  if (( a = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * tomax * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }
-  if (( x = (FLOAT *)malloc(sizeof(FLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
+  if (( x = (IFLOAT *)malloc(sizeof(IFLOAT) * tomax * abs(inc_x) * COMPSIZE)) == NULL){
     fprintf(stderr,"Out of Memory!!\n");exit(1);
   }
@@ -125,7 +142,7 @@ int main(int argc, char *argv[]){
       fprintf(stderr, " %6dx%d : ", (int)m,(int)n);
       for(j = 0; j < m; j++){
         for(i = 0; i < n * COMPSIZE; i++){
-          a[(long)i + (long)j * (long)m * COMPSIZE] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+          a[(long)i + (long)j * (long)m * COMPSIZE] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
         }
       }
@@ -133,7 +150,7 @@ int main(int argc, char *argv[]){
       {
         for(i = 0; i < n * COMPSIZE * abs(inc_x); i++){
-          x[i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
+          x[i] = ((IFLOAT) rand() / (FLOAT) RAND_MAX) - 0.5;
         }
         for(i = 0; i < m * COMPSIZE * abs(inc_y); i++){
@@ -46,6 +46,9 @@ BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
 BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
 BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)
+BGEMVTKERNEL = sbgemv_t_bfdot.c
+BGEMVNKERNEL = bgemv_n_sve_v3x4.c
 SBGEMM_BETA = sbgemm_beta_neoversev1.c
 SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
 ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))
@@ -0,0 +1,321 @@
/***************************************************************************
Copyright (c) 2025 The OpenBLAS Project
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.
3. Neither the name of the OpenBLAS project nor the names of
   its contributors may be used to endorse or promote products
   derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include "common.h"

#include <arm_sve.h>
#include <arm_neon.h>  // scalar bf16 <-> f32 conversions (vcvtah_f32_bf16, vcvth_bf16_f32)
#define UPDATE_PTRSx2 \
  a_ptr1 = a_ptr0 + lda;

#define UPDATE_PTRSx4 \
  UPDATE_PTRSx2 \
  a_ptr2 = a_ptr1 + lda; \
  a_ptr3 = a_ptr2 + lda;

#define UPDATE_PTRSx8 \
  UPDATE_PTRSx4 \
  a_ptr4 = a_ptr3 + lda; \
  a_ptr5 = a_ptr4 + lda; \
  a_ptr6 = a_ptr5 + lda; \
  a_ptr7 = a_ptr6 + lda;

#define LANESx2(MACRO, offset) \
  MACRO(offset, 0) \
  MACRO(offset, 1)

#define LANESx4(MACRO, offset) \
  LANESx2(MACRO, offset) \
  MACRO(offset, 2) \
  MACRO(offset, 3)

#define LANESx8(MACRO, offset) \
  LANESx4(MACRO, offset) \
  MACRO(offset, 4) \
  MACRO(offset, 5) \
  MACRO(offset, 6) \
  MACRO(offset, 7)

#define LOAD_A_VEC(offset, vec) \
  svbfloat16_t a_vec ## offset ## vec = svld1(pg_full, &a_ptr ## vec[i] + offset * sve_size_bf16);

#define UPDATE_ACCUMULATORS_FROM_LANE(offset, lane) \
  acc ## offset ## 0 = svbfmlalb_lane(acc ## offset ## 0, a_vec ## offset ## lane, x_vec, lane); \
  acc ## offset ## 1 = svbfmlalt_lane(acc ## offset ## 1, a_vec ## offset ## lane, x_vec, lane);

#define INIT_ACCUMULATORS(offset) \
  svfloat32_t acc ## offset ## 0 = svdup_f32(0.0); \
  svfloat32_t acc ## offset ## 1 = svdup_f32(0.0);

#define UPDATE_ACCUMULATORS(offset) \
  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, a_vec ## offset ## 0, x_vec); \
  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, a_vec ## offset ## 0, x_vec);

#define STORE_ACCUMULATORS(offset) \
  svbfloat16_t acc ## offset ## 0_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 0); \
  svbfloat16_t acc ## offset ## 1_bf16 = svcvt_bf16_x(pg_full, acc ## offset ## 1); \
  svbfloat16_t combined ## offset = svtrn1(acc ## offset ## 0_bf16, acc ## offset ## 1_bf16); \
  svst1(pg_full, &y[i] + offset * sve_size_bf16, combined ## offset);

#define ALPHA_OP(offset) \
  acc ## offset ## 0 = svmul_x(pg_full, acc ## offset ## 0, svalpha); \
  acc ## offset ## 1 = svmul_x(pg_full, acc ## offset ## 1, svalpha);

#define BETA_OP(offset) \
  svbfloat16_t y_vec ## offset = svld1(pg_full, &y[i] + offset * sve_size_bf16); \
  acc ## offset ## 0 = svbfmlalb(acc ## offset ## 0, svbeta16, y_vec ## offset); \
  acc ## offset ## 1 = svbfmlalt(acc ## offset ## 1, svbeta16, y_vec ## offset);
int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a_in, BLASLONG lda, IFLOAT *x_in, BLASLONG inc_x, FLOAT beta, FLOAT *y_in, BLASLONG inc_y)
{
  BLASLONG ix;
  bfloat16_t *a_ptr0, *a_ptr1, *a_ptr2, *a_ptr3, *a_ptr4, *a_ptr5, *a_ptr6, *a_ptr7;

  BLASLONG sve_size_bf16 = svcnth();
  BLASLONG sve_size2_bf16 = sve_size_bf16 * 2;
  BLASLONG sve_size3_bf16 = sve_size_bf16 * 3;

  svbool_t pg_full = svptrue_b16();
  svbool_t pg_tail = svwhilelt_b16_s64(0, m % sve_size_bf16);

  BLASLONG n8 = n & -8;
  BLASLONG n4 = n & -4;
  BLASLONG n2 = n & -2;

  bfloat16_t *a = (bfloat16_t*)a_in;
  bfloat16_t *x = (bfloat16_t*)x_in;
  bfloat16_t *y = (bfloat16_t*)y_in;

  bfloat16_t alpha_bf16, beta_bf16;
  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
  memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
  float beta_f32 = vcvtah_f32_bf16(beta_bf16);
  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);

  svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
  svbfloat16_t svbeta16 = svdup_bf16(beta_bf16);

  BLASLONG i = 0;

  if (inc_y == 1) {
    // Main path: y is contiguous. Process three SVE vectors of rows per iteration,
    // accumulating in f32 and converting to bf16 only on the final store.
    for (; (i + sve_size3_bf16 - 1) < m; i += sve_size3_bf16) {
      INIT_ACCUMULATORS(0);
      INIT_ACCUMULATORS(1);
      INIT_ACCUMULATORS(2);

      BLASLONG j = 0;
      ix = 0;
      a_ptr0 = a;
      UPDATE_PTRSx8;

      if (inc_x == 1) {
        for (; j < n4; j += 4) {
          uint64_t* x_u64 = (uint64_t*)&x[ix];
          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
          LANESx4(LOAD_A_VEC, 0);
          LANESx4(LOAD_A_VEC, 1);
          LANESx4(LOAD_A_VEC, 2);
          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 1);
          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 2);
          ix += 4;
          a_ptr0 += 4 * lda;
          UPDATE_PTRSx4;
        }
        for (; j < n2; j += 2) {
          uint32_t* x_u32 = (uint32_t*)&x[ix];
          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
          LANESx2(LOAD_A_VEC, 0);
          LANESx2(LOAD_A_VEC, 1);
          LANESx2(LOAD_A_VEC, 2);
          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 1);
          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 2);
          ix += 2;
          a_ptr0 += 2 * lda;
          UPDATE_PTRSx2;
        }
      }
      for (; j < n; j++) {
        svbfloat16_t x_vec = svdup_bf16(x[ix]);
        LOAD_A_VEC(0, 0);
        LOAD_A_VEC(1, 0);
        LOAD_A_VEC(2, 0);
        UPDATE_ACCUMULATORS(0);
        UPDATE_ACCUMULATORS(1);
        UPDATE_ACCUMULATORS(2);
        ix += inc_x;
        a_ptr0 += lda;
      }

      if (alpha_f32 != ONE) {
        ALPHA_OP(0);
        ALPHA_OP(1);
        ALPHA_OP(2);
      }
      if (beta_f32 != ZERO) {
        BETA_OP(0);
        BETA_OP(1);
        BETA_OP(2);
      }

      STORE_ACCUMULATORS(0);
      STORE_ACCUMULATORS(1);
      STORE_ACCUMULATORS(2);
    }

    // Remaining full vectors of rows, one SVE vector at a time.
    for (; (i + sve_size_bf16 - 1) < m; i += sve_size_bf16) {
      INIT_ACCUMULATORS(0);

      BLASLONG j = 0;
      ix = 0;
      a_ptr0 = a;
      UPDATE_PTRSx8;

      if (inc_x == 1) {
        for (; j < n8; j += 8) {
          svbfloat16_t x_vec = svld1rq(pg_full, &x[ix]);
          LANESx8(LOAD_A_VEC, 0);
          LANESx8(UPDATE_ACCUMULATORS_FROM_LANE, 0);
          ix += 8;
          a_ptr0 += 8 * lda;
          UPDATE_PTRSx8;
        }
        for (; j < n4; j += 4) {
          uint64_t* x_u64 = (uint64_t*)&x[ix];
          svbfloat16_t x_vec = svreinterpret_bf16_u64(svdup_u64(*x_u64));
          LANESx4(LOAD_A_VEC, 0);
          LANESx4(UPDATE_ACCUMULATORS_FROM_LANE, 0);
          ix += 4;
          a_ptr0 += 4 * lda;
          UPDATE_PTRSx4;
        }
        for (; j < n2; j += 2) {
          uint32_t* x_u32 = (uint32_t*)&x[ix];
          svbfloat16_t x_vec = svreinterpret_bf16_u32(svdup_u32(*x_u32));
          LANESx2(LOAD_A_VEC, 0);
          LANESx2(UPDATE_ACCUMULATORS_FROM_LANE, 0);
          ix += 2;
          a_ptr0 += 2 * lda;
          UPDATE_PTRSx2;
        }
      }
      for (; j < n; j++) {
        svbfloat16_t x_vec = svdup_bf16(x[ix]);
        LOAD_A_VEC(0, 0);
        UPDATE_ACCUMULATORS(0);
        ix += inc_x;
        a_ptr0 += lda;
      }

      if (alpha_f32 != ONE) {
        ALPHA_OP(0);
      }
      if (beta_f32 != ZERO) {
        BETA_OP(0);
      }
      STORE_ACCUMULATORS(0);
    }

    // Tail: fewer than one full vector of rows left, handled with a tail predicate.
    if (i < m) {
      svfloat32_t acc0 = svdup_f32(0.0);
      svfloat32_t acc1 = svdup_f32(0.0);

      ix = 0;
      a_ptr0 = a;
      for (BLASLONG j = 0; j < n; j++) {
        svbfloat16_t x_vec0 = svdup_bf16(x[ix]);
        svbfloat16_t a_vec0 = svld1(pg_tail, &a_ptr0[i]);
        acc0 = svbfmlalb(acc0, a_vec0, x_vec0);
        acc1 = svbfmlalt(acc1, a_vec0, x_vec0);
        ix += inc_x;
        a_ptr0 += lda;
      }

      if (alpha_f32 != ONE) {
        acc0 = svmul_x(pg_full, acc0, svalpha);
        acc1 = svmul_x(pg_full, acc1, svalpha);
      }
      if (beta_f32 != ZERO) {
        svbfloat16_t y_vec = svld1(pg_tail, &y[i]);
        acc0 = svbfmlalb(acc0, svbeta16, y_vec);
        acc1 = svbfmlalt(acc1, svbeta16, y_vec);
      }

      svbfloat16_t acc0_bf16 = svcvt_bf16_x(pg_full, acc0);
      svbfloat16_t acc1_bf16 = svcvt_bf16_x(pg_full, acc1);
      svbfloat16_t combined = svtrn1(acc0_bf16, acc1_bf16);
      svst1(pg_tail, &y[i], combined);
    }

    return 0;
  }

  // Scalar fallback for non-unit inc_y.
  BLASLONG iy = 0;
  for (; i < m; i++) {
    float temp = 0.0;
    ix = 0;
    a_ptr0 = a;
    for (BLASLONG j = 0; j < n; j++) {
      temp += vcvtah_f32_bf16(a_ptr0[i]) * vcvtah_f32_bf16(x[ix]);
      ix += inc_x;
      a_ptr0 += lda;
    }
    if (beta_f32 == ZERO) {
      y[iy] = vcvth_bf16_f32(alpha_f32 * temp);
    } else {
      y[iy] = vcvth_bf16_f32(alpha_f32 * temp + beta_f32 * vcvtah_f32_bf16(y[iy]));
    }
    iy += inc_y;
  }

  return (0);
}
@@ -33,16 +33,39 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include <arm_neon.h>
 #include "common.h"
-int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
+#ifdef BGEMM
+#define INNER_FLOAT bfloat16_t
+#define TO32(x) vcvtah_f32_bf16(x)
+#define FROM32(x) vcvth_bf16_f32(x)
+#else
+#define INNER_FLOAT float
+#define TO32(x) x
+#define FROM32(x) x
+#endif
+
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, FLOAT beta, FLOAT *y_in, BLASLONG incy)
 {
   if (m < 1 || n < 1) return(0);
   BLASLONG i;
   BLASLONG ix,iy;
   BLASLONG j;
   bfloat16_t *a_ptr;
   bfloat16_t *x_ptr;
-  float *y_ptr;
-  float temp;
+  float temp, temp0, temp1, temp2, temp3;
+#ifdef BGEMM
+  bfloat16_t alpha_bf16, beta_bf16;
+  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
+  memcpy(&beta_bf16, &beta, sizeof(bfloat16_t));
+  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);
+  float beta_f32 = vcvtah_f32_bf16(beta_bf16);
+#else
+  float alpha_f32 = alpha;
+  float beta_f32 = beta;
+#endif
+  INNER_FLOAT *y = (INNER_FLOAT *)y_in;
+  INNER_FLOAT *y_ptr;
   iy = 0;
   a_ptr = (bfloat16_t*)(a);
@@ -56,10 +79,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
     bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
     bfloat16_t *a3_ptr = a_ptr + lda * width * 3;
-    float *y0_ptr = y + incy * width * 0;
-    float *y1_ptr = y + incy * width * 1;
-    float *y2_ptr = y + incy * width * 2;
-    float *y3_ptr = y + incy * width * 3;
+    INNER_FLOAT *y0_ptr = y + incy * width * 0;
+    INNER_FLOAT *y1_ptr = y + incy * width * 1;
+    INNER_FLOAT *y2_ptr = y + incy * width * 2;
+    INNER_FLOAT *y3_ptr = y + incy * width * 3;
     for (j = 0; j < width; j++) {
       float32x4_t temp0_vec = vdupq_n_f32(0.0f);
@@ -113,26 +136,31 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
         i += 4;
       }
-      if (beta == 0.0f) {
-        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
-        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
-        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
-      }
-      else {
-        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
-        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
-        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
-        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
+      if (beta_f32 == 0.0f) {
+        temp0 = alpha_f32 * vaddvq_f32(temp0_vec);
+        temp1 = alpha_f32 * vaddvq_f32(temp1_vec);
+        temp2 = alpha_f32 * vaddvq_f32(temp2_vec);
+        temp3 = alpha_f32 * vaddvq_f32(temp3_vec);
+      } else {
+        temp0 = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y0_ptr[iy]);
+        temp1 = alpha_f32 * vaddvq_f32(temp1_vec) + beta_f32 * TO32(y1_ptr[iy]);
+        temp2 = alpha_f32 * vaddvq_f32(temp2_vec) + beta_f32 * TO32(y2_ptr[iy]);
+        temp3 = alpha_f32 * vaddvq_f32(temp3_vec) + beta_f32 * TO32(y3_ptr[iy]);
       }
      for (; i < m; ++i) {
-        y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
-        y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp0 = temp0 + alpha_f32 * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp1 = temp1 + alpha_f32 * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp2 = temp2 + alpha_f32 * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp3 = temp3 + alpha_f32 * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
       }
+      y0_ptr[iy] = FROM32(temp0);
+      y1_ptr[iy] = FROM32(temp1);
+      y2_ptr[iy] = FROM32(temp2);
+      y3_ptr[iy] = FROM32(temp3);
       iy += incy;
       a0_ptr += lda;
@@ -164,16 +192,16 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
         i += 4;
       }
-      if (beta == 0.0f) {
-        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
-      }
-      else {
-        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
-      }
+      if (beta_f32 == 0.0f) {
+        temp = alpha_f32 * vaddvq_f32(temp0_vec);
+      } else {
+        temp = alpha_f32 * vaddvq_f32(temp0_vec) + beta_f32 * TO32(y_ptr[iy]);
+      }
       for (; i < m; ++i) {
-        y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
+        temp += alpha_f32 * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
       }
+      y_ptr[iy] = FROM32(temp);
       iy += incy;
@@ -189,11 +217,11 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
         temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]);
         ix += incx;
       }
-      if (beta == 0.0f) {
-        y[iy] = alpha * temp;
+      if (beta_f32 == 0.0f) {
+        y[iy] = FROM32(alpha_f32 * temp);
       }
       else {
-        y[iy] = alpha * temp + beta * y[iy];
+        y[iy] = FROM32(alpha_f32 * temp + beta_f32 * TO32(y[iy]));
       }
       iy += incy;
       a_ptr += lda;
@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
           if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001))
           {
 #ifdef DEBUG
-            printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n",
+            printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, DD=%.6f\n",
                    transA, alpha, beta, i, j, k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]));
 #endif
             ret++;