diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 041582892..938a3bf91 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -237,6 +237,7 @@ In chronological order: * [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1 * [2025-01-21] Optimize gemv_t_sve_v1x3 kernel * [2025-02-26] Add sbgemv_t_bfdot kernel + * [2025-03-12] Fix aarch64 sbgemv_t compilation error for GCC < 13 * Marek Michalowski * [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1` diff --git a/kernel/arm64/sbgemv_t_bfdot.c b/kernel/arm64/sbgemv_t_bfdot.c index 0751690fc..fc4ae019e 100644 --- a/kernel/arm64/sbgemv_t_bfdot.c +++ b/kernel/arm64/sbgemv_t_bfdot.c @@ -33,11 +33,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include "common.h" -static inline float bf16_to_fp32(bfloat16 bf16) { - uint32_t fp32 = (uint32_t)bf16 << 16; - return *((float*)&fp32); -} - int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) { if (m < 1 || n < 1) return(0); @@ -132,10 +127,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat } for (; i < m; ++i) { - y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i]; - y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i]; - y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i]; - y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i]; + y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); + y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); + y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); + y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); } iy += incy; @@ -177,7 +172,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat } for (; i < m; ++i) { - y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i]; + y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); } iy += incy; @@ -191,7 +186,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat temp = 0.0; ix = 0; for (i = 0; i < m; i++) { - temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]); + temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]); ix += incx; } if (beta == 0.0f) {