
fix bugs in aarch64 sbgemv_n kernel

tags/v0.3.30
Ye Tao, 6 months ago
commit f27ba5efd1
1 changed file with 34 additions and 49 deletions

--- a/kernel/arm64/sbgemv_n_neon.c
+++ b/kernel/arm64/sbgemv_n_neon.c

@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
     x += 4;
   }
 
-  if (rest_n & 3) {
-    x[0] *= beta;
-    if ((rest_n & 3) > 1)
-      x[1] *= beta;
-    if ((rest_n & 3) > 2)
-      x[2] *= beta;
-  }
+  for (BLASLONG i = 0; i < (rest_n & 3); i ++) {
+    x[i] *= beta;
+  }
   return;
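
For reference, a minimal scalar model of what beta_op computes after this change. This is a hypothetical standalone sketch, not the kernel code: the real main loop is NEON-vectorized (vld1q_f32/vmulq_n_f32), and BLASLONG/FLOAT come from OpenBLAS headers.

#include <stdint.h>

typedef int64_t BLASLONG; /* assumption: LP64 OpenBLAS build */

static void beta_op_ref(float *x, BLASLONG n, float beta) {
  BLASLONG i = 0;
  /* 4 elements per step; NEON in the real kernel */
  for (; i + 4 <= n; i += 4) {
    x[i + 0] *= beta;
    x[i + 1] *= beta;
    x[i + 2] *= beta;
    x[i + 3] *= beta;
  }
  /* the new unified tail loop covers the remaining n & 3 elements */
  for (; i < n; i++)
    x[i] *= beta;
}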
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
   bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
   bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+
+  bfloat16x8_t x_vec;
+  bfloat16x4_t x_vecx4;
 
   float32x4_t y1_vec, y2_vec;
   float32x4_t fp32_low, fp32_high;
 
@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
   if (incx == 1 && incy == 1) {
     if (beta != 1) {
-      beta_op(y, n, beta);
+      beta_op(y, m, beta);
    }
 
     for (i = 0; i < n / 8; i++) {
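
The beta_op(y, n, beta) call was the functional bug here: for a non-transposed gemv, y has one entry per row of A, i.e. m entries, so beta must scale m elements. A scalar reference model (hypothetical sketch; float stands in for bfloat16, and A is column-major with leading dimension lda as in OpenBLAS):

#include <stdint.h>

typedef int64_t BLASLONG; /* assumption: LP64 OpenBLAS build */

/* y := beta*y + alpha*A*x with A of size m x n */
static void sbgemv_n_ref(BLASLONG m, BLASLONG n, float alpha, const float *a,
                         BLASLONG lda, const float *x, float beta, float *y) {
  for (BLASLONG i = 0; i < m; i++)
    y[i] *= beta;                 /* all m outputs, hence beta_op(y, m, beta) */
  for (BLASLONG j = 0; j < n; j++)
    for (BLASLONG i = 0; i < m; i++)
      y[i] += alpha * a[i + j * lda] * x[j];
}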
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
       a_ptr += 4 * lda;
 
-      bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
+      x_vecx4 = vld1_bf16(x_ptr);
       if (alpha != 1) {
-        x_vec = vcombine_bf16(x_vecx4, bf16_zero);
-        fp32_low = vreinterpretq_f32_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
-                       vreinterpretq_u16_bf16(x_vec)));
+        fp32_low = vcvt_f32_bf16(x_vecx4);
         fp32_low = vmulq_n_f32(fp32_low, alpha);
         x_vecx4 = vcvt_bf16_f32(fp32_low);
       }
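
The replacement drops the zip-with-zero bit trick in favor of the dedicated bf16 conversion intrinsics. A minimal sketch of the round trip (scale_bf16x4 is a hypothetical helper name; assumes <arm_neon.h> and a compiler targeting armv8.2-a+bf16):

#include <arm_neon.h>

/* Scale four bf16 lanes by alpha: widen to f32, multiply, narrow back. */
static inline bfloat16x4_t scale_bf16x4(bfloat16x4_t v, float alpha) {
  float32x4_t f = vcvt_f32_bf16(v); /* 4 x bf16 -> 4 x f32 */
  f = vmulq_n_f32(f, alpha);        /* per-lane multiply */
  return vcvt_bf16_f32(f);          /* narrow back to bf16 */
}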
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
         y1_vec = vld1q_f32(y_ptr);
 
-        a0 = vcombine_bf16(a0x4, bf16_zero);
-        a1 = vcombine_bf16(a1x4, bf16_zero);
-        a2 = vcombine_bf16(a2x4, bf16_zero);
-        a3 = vcombine_bf16(a3x4, bf16_zero);
+        a0 = vcombine_bf16(a0x4, a2x4);
+        a1 = vcombine_bf16(a1x4, a3x4);
 
-        t0 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
-        t1 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
+        t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
+        t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
 
         y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
         y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
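
Why the new interleave is correct: vcombine_bf16(a0x4, a2x4) and vcombine_bf16(a1x4, a3x4) pack columns 0/2 and 1/3 into one register each, so vzip1q_u16 yields columns 0 and 1 interleaved lane-by-lane and vzip2q_u16 yields columns 2 and 3. BFMLALB multiplies the even-numbered bf16 lanes and BFMLALT the odd-numbered ones. A sketch of the idiom for one column pair (two_col_fma is a hypothetical helper, not a kernel name; assumes <arm_neon.h> and armv8.2-a+bf16):

#include <arm_neon.h>

/* y[i] += c0[i]*xv[0] + c1[i]*xv[1] for i = 0..3.
 * The zip puts c0 in even lanes and c1 in odd lanes, which is exactly
 * what BFMLALB (even) and BFMLALT (odd) consume. */
static inline float32x4_t two_col_fma(float32x4_t y, bfloat16x4_t c0,
                                      bfloat16x4_t c1, bfloat16x4_t xv) {
  bfloat16x8_t t = vreinterpretq_bf16_u16(vzip1q_u16(
      vreinterpretq_u16_bf16(vcombine_bf16(c0, c0)),
      vreinterpretq_u16_bf16(vcombine_bf16(c1, c1))));
  y = vbfmlalbq_lane_f32(y, t, xv, 0); /* even lanes (c0) * xv[0] */
  y = vbfmlaltq_lane_f32(y, t, xv, 1); /* odd lanes  (c1) * xv[1] */
  return y;
}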
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
       }
 
       if (rest_m) {
-        x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
-        x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
-        x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
-        x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
+        fp32_low = vcvt_f32_bf16(x_vecx4);
+
+        x0 = vgetq_lane_f32(fp32_low, 0);
+        x1 = vgetq_lane_f32(fp32_low, 1);
+        x2 = vgetq_lane_f32(fp32_low, 2);
+        x3 = vgetq_lane_f32(fp32_low, 3);
 
         for (BLASLONG j = 0; j < rest_m; j++) {
           y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
       a_ptr += 2 * lda;
 
-      bfloat16_t tmp_buffer[4];
-      memset((void*)tmp_buffer, 0, sizeof(bfloat16_t));
-      tmp_buffer[0] = x_ptr[0];
-      tmp_buffer[1] = x_ptr[1];
+      x_vecx4 = vreinterpret_bf16_u16(vzip1_u16(
+          vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])),
+          vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1]))
+      ));
 
-      bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer);
       if (alpha != 1) {
-        x_vec = vcombine_bf16(x_vecx4, bf16_zero);
-        fp32_low = vreinterpretq_f32_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
-                       vreinterpretq_u16_bf16(x_vec)));
+        fp32_low = vcvt_f32_bf16(x_vecx4);
         fp32_low = vmulq_n_f32(fp32_low, alpha);
         x_vecx4 = vcvt_bf16_f32(fp32_low);
       }
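
Beyond saving the stack round trip, this fixes a latent bug visible in the removed lines: memset cleared only sizeof(bfloat16_t) bytes (2, not 8), and only tmp_buffer[0..1] were then assigned, so lanes 2 and 3 of x_vecx4 were loaded from uninitialized memory. The replacement builds {x0, x1, x0, x1} entirely in registers; a minimal sketch (pair_splat_bf16 is a hypothetical helper name, same target assumptions as above):

#include <arm_neon.h>

/* {x0, x1, x0, x1}: zip the lanes of two broadcasts, no memory traffic. */
static inline bfloat16x4_t pair_splat_bf16(bfloat16_t x0, bfloat16_t x1) {
  return vreinterpret_bf16_u16(vzip1_u16(
      vreinterpret_u16_bf16(vdup_n_bf16(x0)),   /* {x0, x0, x0, x0} */
      vreinterpret_u16_bf16(vdup_n_bf16(x1)))); /* {x1, x1, x1, x1} */
}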
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
         t0 = vreinterpretq_bf16_u16(
             vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
-        t4 = vreinterpretq_bf16_u16(
+        t1 = vreinterpretq_bf16_u16(
             vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
 
         y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
         y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
 
-        y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
-        y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
+        y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0);
+        y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1);
 
         vst1q_f32(y_ptr, y1_vec);
         vst1q_f32(y_ptr + 4, y2_vec);
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
         a0 = vcombine_bf16(a0x4, bf16_zero);
         a1 = vcombine_bf16(a1x4, bf16_zero);
 
-        t0 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
-        t1 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
+        t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
 
         y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
         y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
-        y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
-        y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);
 
         vst1q_f32(y_ptr, y1_vec);
 
         a_ptr0 += 4;
         a_ptr1 += 4;
-        a_ptr2 += 4;
-        a_ptr3 += 4;
 
         y_ptr += 4;
       }
 
       if (m & 2) {
-        x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
-        x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));
+        fp32_low = vcvt_f32_bf16(x_vecx4);
+        x0 = vgetq_lane_f32(fp32_low, 0);
+        x1 = vgetq_lane_f32(fp32_low, 1);
 
         y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
         y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
       }
 
       if (m & 1) {
-        x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
-        x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
+        fp32_low = vcvt_f32_bf16(x_vecx4);
+        x0 = vgetq_lane_f32(fp32_low, 0);
+        x1 = vgetq_lane_f32(fp32_low, 1);
 
         y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
         y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);

