|
|
@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) { |
|
|
|
x += 4; |
|
|
|
} |
|
|
|
|
|
|
|
if (rest_n & 3) { |
|
|
|
x[0] *= beta; |
|
|
|
if ((rest_n & 3) > 1) |
|
|
|
x[1] *= beta; |
|
|
|
if ((rest_n & 3) > 2) |
|
|
|
x[2] *= beta; |
|
|
|
for (BLASLONG i = 0; i < (rest_n & 3); i ++) { |
|
|
|
x[i] *= beta; |
|
|
|
} |
|
|
|
} |
|
|
|
return; |
|
|
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7; |
|
|
|
bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7; |
|
|
|
|
|
|
|
bfloat16x8_t x_vec; |
|
|
|
bfloat16x4_t x_vecx4; |
|
|
|
|
|
|
|
float32x4_t y1_vec, y2_vec; |
|
|
|
float32x4_t fp32_low, fp32_high; |
|
|
|
|
|
|
@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
if (incx == 1 && incy == 1) { |
|
|
|
if (beta != 1) { |
|
|
|
beta_op(y, n, beta); |
|
|
|
beta_op(y, m, beta); |
|
|
|
} |
|
|
|
|
|
|
|
for (i = 0; i < n / 8; i++) { |
|
|
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
a_ptr += 4 * lda; |
|
|
|
|
|
|
|
bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr); |
|
|
|
x_vecx4 = vld1_bf16(x_ptr); |
|
|
|
if (alpha != 1) { |
|
|
|
x_vec = vcombine_bf16(x_vecx4, bf16_zero); |
|
|
|
fp32_low = vreinterpretq_f32_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), |
|
|
|
vreinterpretq_u16_bf16(x_vec))); |
|
|
|
fp32_low = vcvt_f32_bf16(x_vecx4); |
|
|
|
fp32_low = vmulq_n_f32(fp32_low, alpha); |
|
|
|
x_vecx4 = vcvt_bf16_f32(fp32_low); |
|
|
|
} |
|
|
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
y1_vec = vld1q_f32(y_ptr); |
|
|
|
|
|
|
|
a0 = vcombine_bf16(a0x4, bf16_zero); |
|
|
|
a1 = vcombine_bf16(a1x4, bf16_zero); |
|
|
|
a2 = vcombine_bf16(a2x4, bf16_zero); |
|
|
|
a3 = vcombine_bf16(a3x4, bf16_zero); |
|
|
|
a0 = vcombine_bf16(a0x4, a2x4); |
|
|
|
a1 = vcombine_bf16(a1x4, a3x4); |
|
|
|
|
|
|
|
t0 = vreinterpretq_bf16_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
t1 = vreinterpretq_bf16_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); |
|
|
|
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
|
|
|
|
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); |
|
|
|
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); |
|
|
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (rest_m) { |
|
|
|
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); |
|
|
|
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); |
|
|
|
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); |
|
|
|
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); |
|
|
|
fp32_low = vcvt_f32_bf16(x_vecx4); |
|
|
|
|
|
|
|
x0 = vgetq_lane_f32(fp32_low, 0); |
|
|
|
x1 = vgetq_lane_f32(fp32_low, 1); |
|
|
|
x2 = vgetq_lane_f32(fp32_low, 2); |
|
|
|
x3 = vgetq_lane_f32(fp32_low, 3); |
|
|
|
|
|
|
|
for (BLASLONG j = 0; j < rest_m; j++) { |
|
|
|
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); |
|
|
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
a_ptr += 2 * lda; |
|
|
|
|
|
|
|
bfloat16_t tmp_buffer[4]; |
|
|
|
memset((void*)tmp_buffer, 0, sizeof(bfloat16_t)); |
|
|
|
|
|
|
|
tmp_buffer[0] = x_ptr[0]; |
|
|
|
tmp_buffer[1] = x_ptr[1]; |
|
|
|
x_vecx4 = vreinterpret_bf16_u16(vzip1_u16( |
|
|
|
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])), |
|
|
|
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1])) |
|
|
|
)); |
|
|
|
|
|
|
|
bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer); |
|
|
|
if (alpha != 1) { |
|
|
|
x_vec = vcombine_bf16(x_vecx4, bf16_zero); |
|
|
|
fp32_low = vreinterpretq_f32_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), |
|
|
|
vreinterpretq_u16_bf16(x_vec))); |
|
|
|
fp32_low = vcvt_f32_bf16(x_vecx4); |
|
|
|
fp32_low = vmulq_n_f32(fp32_low, alpha); |
|
|
|
x_vecx4 = vcvt_bf16_f32(fp32_low); |
|
|
|
} |
|
|
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
|
|
|
|
t0 = vreinterpretq_bf16_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
t4 = vreinterpretq_bf16_u16( |
|
|
|
t1 = vreinterpretq_bf16_u16( |
|
|
|
vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
|
|
|
|
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); |
|
|
|
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); |
|
|
|
|
|
|
|
y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0); |
|
|
|
y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1); |
|
|
|
y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0); |
|
|
|
y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1); |
|
|
|
|
|
|
|
vst1q_f32(y_ptr, y1_vec); |
|
|
|
vst1q_f32(y_ptr + 4, y2_vec); |
|
|
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
a0 = vcombine_bf16(a0x4, bf16_zero); |
|
|
|
a1 = vcombine_bf16(a1x4, bf16_zero); |
|
|
|
|
|
|
|
t0 = vreinterpretq_bf16_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
t1 = vreinterpretq_bf16_u16( |
|
|
|
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); |
|
|
|
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); |
|
|
|
|
|
|
|
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); |
|
|
|
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); |
|
|
|
y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2); |
|
|
|
y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3); |
|
|
|
|
|
|
|
vst1q_f32(y_ptr, y1_vec); |
|
|
|
|
|
|
|
a_ptr0 += 4; |
|
|
|
a_ptr1 += 4; |
|
|
|
a_ptr2 += 4; |
|
|
|
a_ptr3 += 4; |
|
|
|
|
|
|
|
y_ptr += 4; |
|
|
|
} |
|
|
|
|
|
|
|
if (m & 2) { |
|
|
|
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0])); |
|
|
|
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1])); |
|
|
|
fp32_low = vcvt_f32_bf16(x_vecx4); |
|
|
|
x0 = vgetq_lane_f32(fp32_low, 0); |
|
|
|
x1 = vgetq_lane_f32(fp32_low, 1); |
|
|
|
|
|
|
|
|
|
|
|
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); |
|
|
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, |
|
|
|
} |
|
|
|
|
|
|
|
if (m & 1) { |
|
|
|
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); |
|
|
|
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); |
|
|
|
fp32_low = vcvt_f32_bf16(x_vecx4); |
|
|
|
x0 = vgetq_lane_f32(fp32_low, 0); |
|
|
|
x1 = vgetq_lane_f32(fp32_low, 1); |
|
|
|
|
|
|
|
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); |
|
|
|
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); |
|
|
|