|
|
@@ -33,11 +33,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
#include <arm_neon.h> |
|
|
|
#include "common.h" |
|
|
|
|
|
|
|
/*
 * Widen a bfloat16 value to a float.
 *
 * The argument is cast to uint32_t and shifted left by 16, so its bits land
 * in the upper half of a 32-bit word whose lower half is zero; that word is
 * then reinterpreted as a float.
 *
 * The reinterpretation is done through a union instead of dereferencing a
 * casted pointer: reading a uint32_t object through a float lvalue breaks
 * the strict-aliasing rule, whereas reading a different union member than
 * the one last written is well defined in C.
 *
 * NOTE(review): the cast is a value conversion, so this presumes bfloat16 is
 * an integer typedef carrying the raw bf16 bit pattern -- confirm against the
 * definition pulled in by common.h.
 */
static inline float bf16_to_fp32(bfloat16 bf16) {
  union {
    uint32_t bits;
    float value;
  } pun;

  pun.bits = (uint32_t)bf16 << 16;
  return pun.value;
}
|
|
|
|
|
|
|
int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy) |
|
|
|
{ |
|
|
|
if (m < 1 || n < 1) return(0); |
|
|
@@ -132,10 +127,10 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat |
|
|
|
} |
|
|
|
|
|
|
|
for (; i < m; ++i) { |
|
|
|
y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i]; |
|
|
|
y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i]; |
|
|
|
y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i]; |
|
|
|
y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i]; |
|
|
|
y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); |
|
|
|
y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); |
|
|
|
y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); |
|
|
|
y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); |
|
|
|
} |
|
|
|
|
|
|
|
iy += incy; |
|
|
@@ -177,7 +172,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat |
|
|
|
} |
|
|
|
|
|
|
|
for (; i < m; ++i) { |
|
|
|
y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i]; |
|
|
|
y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]); |
|
|
|
} |
|
|
|
|
|
|
|
iy += incy; |
|
|
@@ -191,7 +186,7 @@ int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat |
|
|
|
temp = 0.0; |
|
|
|
ix = 0; |
|
|
|
for (i = 0; i < m; i++) { |
|
|
|
temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]); |
|
|
|
temp += vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[ix]); |
|
|
|
ix += incx; |
|
|
|
} |
|
|
|
if (beta == 0.0f) { |
|
|
|