Browse Source

Fix aarch64 sbgemv_t compilation error for GCC < 13

tags/v0.3.30
Annop Wongwathanarat 6 months ago
parent
commit
a085b6c9ec
2 changed files with 7 additions and 11 deletions
  1. +1
    -0
      CONTRIBUTORS.md
  2. +6
    -11
      kernel/arm64/sbgemv_t_bfdot.c

+ 1
- 0
CONTRIBUTORS.md View File

@@ -237,6 +237,7 @@ In chronological order:
* [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
* [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
* [2025-02-26] Add sbgemv_t_bfdot kernel
* [2025-03-12] Fix aarch64 sbgemv_t compilation error for GCC < 13

* Marek Michalowski <marek.michalowski@arm.com>
* [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`


+ 6
- 11
kernel/arm64/sbgemv_t_bfdot.c View File

@@ -33,11 +33,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <arm_neon.h>
#include "common.h"

/*
 * Build a float from a bfloat16 argument by placing its value in the upper
 * 16 bits of a 32-bit word (lower 16 bits zero) and reinterpreting that word
 * as a float.
 *
 * The reinterpretation goes through a union rather than a pointer cast:
 * reading a uint32_t object through a float lvalue breaks the strict-aliasing
 * rule and is undefined behaviour, whereas reading a different union member
 * than the one last written is a defined type-pun in C.
 *
 * NOTE(review): assumes bfloat16 is a 16-bit unsigned integer typedef holding
 * the raw bit pattern, so the cast below copies bits rather than converting a
 * floating value - confirm against the project's definition.
 */
static inline float bf16_to_fp32(bfloat16 bf16) {
    union {
        uint32_t bits;
        float    value;
    } pun;

    /* Widen first so the shift happens in 32 bits, not in promoted int. */
    pun.bits = (uint32_t)bf16 << 16;
    return pun.value;
}

int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
{
if (m < 1 || n < 1) return(0);
@@ -132,10 +127,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
}

for (; i < m; ++i) {
y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i];
y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i];
y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i];
y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i];
y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
}

iy += incy;
@@ -177,7 +172,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
}

for (; i < m; ++i) {
y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i];
y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
}

iy += incy;
@@ -191,7 +186,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat
temp = 0.0;
ix = 0;
for (i = 0; i < m; i++) {
temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]);
temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]);
ix += incx;
}
if (beta == 0.0f) {


Loading…
Cancel
Save