Browse Source

Add optimized BGEMM kernel for NEOVERSEV1 target

This also improves the testing and generic kernel by re-using the BF16
conversion functions.

Built on top of https://github.com/OpenMathLib/OpenBLAS/pull/5357 and derived from https://github.com/OpenMathLib/OpenBLAS/pull/5287

Co-authored-by: Ye Tao <ye.tao@arm.com>
pull/5373/head
Chris Sidebottom 2 months ago
parent
commit
740efd71c4
13 changed files with 816 additions and 99 deletions
  1. +33
    -1
      benchmark/Makefile
  2. +12
    -0
      benchmark/gemm.c
  3. +10
    -4
      common_b.h
  4. +29
    -1
      interface/Makefile
  5. +20
    -1
      kernel/Makefile.L3
  6. +42
    -1
      kernel/arm64/KERNEL.NEOVERSEV1
  7. +107
    -0
      kernel/arm64/bgemm_beta_neon.c
  8. +50
    -0
      kernel/arm64/bgemm_kernel_4x4_neoversev1.c
  9. +429
    -0
      kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c
  10. +39
    -15
      kernel/generic/gemmkernel_2x2.c
  11. +4
    -0
      kernel/setparam-ref.c
  12. +7
    -0
      param.h
  13. +34
    -76
      test/compare_sgemm_bgemm.c

+ 33
- 1
benchmark/Makefile View File

@@ -1,3 +1,31 @@
###############################################################################
# Copyright (c) 2025, The OpenBLAS Project
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# 3. Neither the name of the OpenBLAS project nor the names of
# its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###############################################################################

TOPDIR = ..
include $(TOPDIR)/Makefile.system

@@ -56,7 +84,7 @@ GOTO_LAPACK_TARGETS=
endif

ifeq ($(BUILD_BFLOAT16),1)
GOTO_BFLOAT_TARGETS=sbgemm.goto
GOTO_BFLOAT_TARGETS=bgemm.goto sbgemm.goto
else
GOTO_BFLOAT_TARGETS=
endif
@@ -635,6 +663,8 @@ zcholesky.essl : zcholesky.$(SUFFIX)

##################################### Sgemm ####################################################
ifeq ($(BUILD_BFLOAT16),1)
bgemm.goto : bgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
sbgemm.goto : sbgemm.$(SUFFIX) ../$(LIBNAME)
$(CC) $(CFLAGS) -o $(@F) $^ $(CEXTRALIB) $(EXTRALIB) $(FEXTRALIB) -lm
endif
@@ -2970,6 +3000,8 @@ zcholesky.$(SUFFIX) : cholesky.c
$(CC) $(CFLAGS) -c -DCOMPLEX -DDOUBLE -o $(@F) $^

ifeq ($(BUILD_BFLOAT16),1)
bgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UCOMPLEX -UDOUBLE -o $(@F) $^
sbgemm.$(SUFFIX) : gemm.c
$(CC) $(CFLAGS) -c -DBFLOAT16 -UCOMPLEX -UDOUBLE -o $(@F) $^
endif


+ 12
- 0
benchmark/gemm.c View File

@@ -33,6 +33,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifdef DOUBLE
#define GEMM BLASFUNC(dgemm)
#elif defined(BFLOAT16) && defined(BGEMM)
#define GEMM BLASFUNC(bgemm)
#elif defined(BFLOAT16)
#define GEMM BLASFUNC(sbgemm)
#undef IFLOAT
@@ -60,8 +62,18 @@ int main(int argc, char *argv[]){

IFLOAT *a, *b;
FLOAT *c;
#ifdef BGEMM
blasint one=1;
blasint two=2;
float alpha_in[] = {1.0, 0.0};
float beta_in[] = {0.0, 0.0};
FLOAT alpha[2], beta[2];
sbstobf16_(&two, alpha_in, &one, alpha, &one);
sbstobf16_(&two, beta_in, &one, beta, &one);
#else
FLOAT alpha[] = {1.0, 0.0};
FLOAT beta [] = {0.0, 0.0};
#endif
char transa = 'N';
char transb = 'N';
blasint m, n, k, i, j, lda, ldb, ldc;


+ 10
- 4
common_b.h View File

@@ -30,10 +30,16 @@
#define COMMON_B_H

#ifndef DYNAMIC_ARCH
#define BGEMM_ONCOPY bgemm_oncopy
#define BGEMM_OTCOPY bgemm_otcopy
#define BGEMM_INCOPY bgemm_incopy
#define BGEMM_ITCOPY bgemm_itcopy
#define BGEMM_ONCOPY bgemm_oncopy
#define BGEMM_OTCOPY bgemm_otcopy

#if BGEMM_DEFAULT_UNROLL_M == BGEMM_DEFAULT_UNROLL_N
#define BGEMM_INCOPY bgemm_oncopy
#define BGEMM_ITCOPY bgemm_otcopy
#else
#define BGEMM_INCOPY bgemm_incopy
#define BGEMM_ITCOPY bgemm_itcopy
#endif

#define BGEMM_BETA bgemm_beta
#define BGEMM_KERNEL bgemm_kernel


+ 29
- 1
interface/Makefile View File

@@ -1,3 +1,31 @@
###############################################################################
# Copyright (c) 2025, The OpenBLAS Project
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# 3. Neither the name of the OpenBLAS project nor the names of
# its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###############################################################################

TOPDIR = ..
include $(TOPDIR)/Makefile.system

@@ -526,7 +554,7 @@ ifneq ($(BUILD_COMPLEX16),1)
ZBLASOBJS=
endif

FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS)
FUNCOBJS = $(SBEXTOBJS) $(CXERBLAOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(SHBLASOBJS)

ifeq ($(EXPRECISION), 1)
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)


+ 20
- 1
kernel/Makefile.L3 View File

@@ -674,6 +674,10 @@ ZBLASOBJS += \
endif

ifeq ($(BUILD_BFLOAT16), 1)
BGEMMINCOPYOBJ_P = $(BGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
BGEMMITCOPYOBJ_P = $(BGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
BGEMMONCOPYOBJ_P = $(BGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
BGEMMOTCOPYOBJ_P = $(BGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SBGEMMINCOPYOBJ_P = $(SBGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SBGEMMITCOPYOBJ_P = $(SBGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SBGEMMONCOPYOBJ_P = $(SBGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
@@ -2998,6 +3002,20 @@ $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)


ifeq ($(BUILD_BFLOAT16), 1)
$(BGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMONCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@

$(BGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMOTCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@

ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N))
$(BGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMINCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@

$(BGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(BGEMMITCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
endif

$(SBGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMONCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

@@ -3010,7 +3028,6 @@ $(SBGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMINCOPY)

$(SBGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SBGEMMITCOPY)
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@

endif
endif

@@ -3137,6 +3154,8 @@ endif


ifeq ($(BUILD_BFLOAT16), 1)
$(KDIR)bgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) $(BGEMMDEPEND)
$(CC) $(PFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sbgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND)
$(CC) $(PFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
endif


+ 42
- 1
kernel/arm64/KERNEL.NEOVERSEV1 View File

@@ -1,3 +1,31 @@
###############################################################################
# Copyright (c) 2025, The OpenBLAS Project
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# 3. Neither the name of the OpenBLAS project nor the names of
# its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###############################################################################

include $(KERNELDIR)/KERNEL.ARMV8SVE

SGEMVNKERNEL = gemv_n_sve_v1x3.c
@@ -5,6 +33,19 @@ DGEMVNKERNEL = gemv_n_sve_v1x3.c
SGEMVTKERNEL = gemv_t_sve_v1x3.c
DGEMVTKERNEL = gemv_t_sve_v1x3.c
ifeq ($(BUILD_BFLOAT16), 1)
# BGEMM (bf16 inputs, bf16 output) kernel selection.
# The packing (copy) sources are the existing sbgemm_*copy_*_neoversev1.c
# files; only the beta and compute kernels are BGEMM specific.
BGEMM_BETA = bgemm_beta_neon.c
BGEMMKERNEL = bgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversev1.c
# Separate "inner" copies are only needed when the M and N unroll factors
# differ. They are selected by the BGEMM M unroll factor, mirroring how the
# "outer" copies below are selected by the BGEMM N unroll factor.
ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N))
BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversev1.c
BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversev1.c
BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX)
BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX)
endif
BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversev1.c
BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c
BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX)
BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX)

SBGEMM_BETA = sbgemm_beta_neoversev1.c
SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversev1.c
ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N))
@@ -21,4 +62,4 @@ SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
SBGEMVNKERNEL = sbgemv_n_neon.c
SBGEMVTKERNEL = sbgemv_t_bfdot.c

endif
endif

+ 107
- 0
kernel/arm64/bgemm_beta_neon.c View File

@@ -0,0 +1,107 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#include "common.h"

#include <arm_neon.h>

/*
 * BGEMM beta kernel: scales a bfloat16 matrix C in place by beta.
 *
 * C is walked as n runs of m contiguous bfloat16 elements, with consecutive
 * runs starting ldc elements apart.
 *
 *   m, n     extent of C (elements per run, number of runs)
 *   beta_in  scaling factor; its leading sizeof(bfloat16_t) bytes are
 *            reinterpreted as a bfloat16 value
 *   c, ldc   matrix to scale and the distance between runs, in elements
 *   dummy*   unused
 *
 * When beta converts to exactly 0.0f, C is overwritten with zeros without
 * being read. Otherwise each element is widened to f32, multiplied by beta
 * and converted back to bfloat16.
 *
 * Always returns 0.
 */
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta_in, IFLOAT *dummy2,
          BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c,
          BLASLONG ldc) {
  BLASLONG i, j;
  BLASLONG chunk, remain;
  bfloat16_t *ptr_c, *ptr_c0;

  bfloat16x8_t x0, z0;
  float32x4_t y0, y1;
  float x;
  bfloat16_t z;

  bfloat16_t zero_bf16 = vcvth_bf16_f32(0.0f);
  bfloat16x8_t zeros = vdupq_n_bf16(zero_bf16);

  /* Reinterpret the incoming scalar's bits as bf16, then widen it to f32. */
  bfloat16_t beta_bf16;
  memcpy(&beta_bf16, &beta_in, sizeof(bfloat16_t));
  float beta = vcvtah_f32_bf16(beta_bf16);
  float32x4_t beta_neon = vdupq_n_f32(beta);

  ptr_c = (bfloat16_t *)c;
  chunk = m >> 3;  /* number of full 8-element vectors per run */
  remain = m & 7;  /* scalar tail per run */

  if (beta == 0.0f) {
    /* Pure store path: C is never loaded. */
    for (j = 0; j < n; j++) {
      ptr_c0 = ptr_c;
      ptr_c += ldc;

      for (i = 0; i < chunk; i++) {
        vst1q_bf16(ptr_c0, zeros);
        ptr_c0 += 8;
      }

      for (i = 0; i < remain; i++) {
        ptr_c0[0] = zero_bf16;
        ptr_c0++;
      }
    }
  } else {
    for (j = 0; j < n; j++) {
      ptr_c0 = ptr_c;
      ptr_c += ldc;

      for (i = 0; i < chunk; i++) {
        x0 = vld1q_bf16(ptr_c0);

        /* Widen both halves to f32, scale, then narrow back to bf16. */
        y0 = vcvtq_low_f32_bf16(x0);
        y1 = vcvtq_high_f32_bf16(x0);

        y0 = vmulq_f32(y0, beta_neon);
        y1 = vmulq_f32(y1, beta_neon);

        z0 = vcvtq_low_bf16_f32(y0);
        z0 = vcvtq_high_bf16_f32(z0, y1);
        vst1q_bf16(ptr_c0, z0);

        ptr_c0 += 8;
      }

      for (i = 0; i < remain; i++) {
        x = vcvtah_f32_bf16(ptr_c0[0]);
        z = vcvth_bf16_f32(x * beta);

        ptr_c0[0] = z;
        ptr_c0++;
      }
    }
  }
  return 0;
}

+ 50
- 0
kernel/arm64/bgemm_kernel_4x4_neoversev1.c View File

@@ -0,0 +1,50 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#include <arm_sve.h>

#include "common.h"

/* First inclusion: with ALPHA_ONE defined the implementation file emits
 * bgemm_kernel_neoversev1_alpha_one. */
#define ALPHA_ONE
#include "bgemm_kernel_4x4_neoversev1_impl.c"
#undef ALPHA_ONE
/* UPDATE_C has a different body per variant, so it must be undefined before
 * the implementation file is included again. Second inclusion: without
 * ALPHA_ONE it emits bgemm_kernel_neoversev1_alpha. */
#undef UPDATE_C
#include "bgemm_kernel_4x4_neoversev1_impl.c"

/*
 * BGEMM kernel entry point for the NEOVERSEV1 target.
 *
 * The leading sizeof(bfloat16_t) bytes of alpha are reinterpreted as a
 * bfloat16 value and widened to f32. When that value is exactly 1.0f the
 * call is forwarded to bgemm_kernel_neoversev1_alpha_one, otherwise to
 * bgemm_kernel_neoversev1_alpha. All arguments are passed through unchanged.
 *
 * Returns the value returned by the selected variant.
 */
int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B,
          FLOAT *C, BLASLONG ldc) {
  bfloat16_t alpha_bf16;
  memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
  float alpha_f32 = vcvtah_f32_bf16(alpha_bf16);

  if (alpha_f32 == 1.0f)
    return bgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc);

  return bgemm_kernel_neoversev1_alpha(m, n, k, alpha, A, B, C, ldc);
}

+ 429
- 0
kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c View File

@@ -0,0 +1,429 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#include <arm_sve.h>

#include "common.h"

/*
 * Helper macros for the kernel body below. The statement-like macros carry no
 * trailing semicolon; the caller supplies it, so they behave as a single
 * statement in any context.
 */

/* Zero the f32 accumulator mc<M><N>. */
#define INIT_C(M, N) mc##M##N = svdup_f32(0)

/* mc<M><N> += ma<M> x mb<N> using the bf16 matrix multiply-accumulate. */
#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N)

/* Zero the four accumulators used by the 4x4 tile. */
#define INIT_C_4x4                                                             \
  do {                                                                         \
    INIT_C(0, 0);                                                              \
    INIT_C(0, 1);                                                              \
    INIT_C(1, 0);                                                              \
    INIT_C(1, 1);                                                              \
  } while (0)

/*
 * UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32)
 *
 * Accumulates the f32 results in SRC32 into the bf16 values at PTR:
 *   ALPHA_ONE defined:   C = C + SRC32
 *   otherwise:           C = C + svalpha * SRC32   (svalpha comes from the
 *                                                   expanding scope)
 *
 * PG32 selects the active 32-bit lanes and PG16 the same number of leading
 * 16-bit lanes; TMP32 / TMP16 are caller-provided scratch vectors.
 *
 * C is loaded with a zero-extending halfword load into 32-bit lanes, so the
 * load is governed by the 32-bit predicate, and the loaded bits are shifted
 * left by 16 to form the f32 value of each bf16 element (a bf16 is the upper
 * half of an f32). After the arithmetic the result is converted back to bf16,
 * the converted halfwords are gathered into the leading lanes with uzp1 and
 * stored under PG16.
 */
#ifdef ALPHA_ONE
#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32)                         \
  do {                                                                         \
    TMP32 = svreinterpret_f32_u32(svlsl_n_u32_x(                               \
        (PG32), svld1uh_u32((PG32), (const uint16_t *)(PTR)), 16));            \
    TMP32 = svadd_z((PG32), (SRC32), TMP32);                                   \
    TMP16 = svcvt_bf16_f32_z((PG32), TMP32);                                   \
    TMP16 = svuzp1_bf16(TMP16, TMP16);                                         \
    svst1_bf16((PG16), (PTR), TMP16);                                          \
  } while (0)
#else
#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32)                         \
  do {                                                                         \
    TMP32 = svreinterpret_f32_u32(svlsl_n_u32_x(                               \
        (PG32), svld1uh_u32((PG32), (const uint16_t *)(PTR)), 16));            \
    TMP32 = svmad_z((PG32), svalpha, (SRC32), TMP32);                          \
    TMP16 = svcvt_bf16_f32_z((PG32), TMP32);                                   \
    TMP16 = svuzp1_bf16(TMP16, TMP16);                                         \
    svst1_bf16((PG16), (PTR), TMP16);                                          \
  } while (0)
#endif

/* vc = the even-indexed lanes of mc0 followed by those of mc1 (uzp1), with
 * the lanes selected by PG then packed to the front (compact). */
#define ZIP_EVEN_ELEMENTS(PG, mc0, mc1, tmp, vc)                               \
  do {                                                                         \
    (tmp) = svuzp1_f32((mc0), (mc1));                                          \
    (vc) = svcompact_f32((PG), (tmp));                                         \
  } while (0)

/* Same as ZIP_EVEN_ELEMENTS but for the odd-indexed lanes (uzp2). */
#define ZIP_ODD_ELEMENTS(PG, mc0, mc1, tmp, vc)                                \
  do {                                                                         \
    (tmp) = svuzp2_f32((mc0), (mc1));                                          \
    (vc) = svcompact_f32((PG), (tmp));                                         \
  } while (0)

/* Rotate mc<M><N> by four f32 lanes and add it to itself, so lanes 0..3 end
 * up holding lane[i] + lane[i + 4]. */
#define ACCUMULATE_LAST4_TO_FIRST4(M, N, TMP)                                  \
  do {                                                                         \
    TMP = svext_f32(mc##M##N, mc##M##N, 4);                                    \
    mc##M##N = svadd_f32_z(svptrue_b32(), mc##M##N, (TMP));                    \
  } while (0)

/*
 * BGEMM micro-kernel body. This file is meant to be compiled in two variants:
 * with ALPHA_ONE defined it produces bgemm_kernel_neoversev1_alpha_one,
 * without it bgemm_kernel_neoversev1_alpha.
 *
 * The function walks the bf16 panels A and B and updates the bf16 matrix C
 * (m rows, n columns, columns ldc elements apart) through UPDATE_C. Products
 * are accumulated in f32 with svbfmmla. Both m and n are processed in strips
 * of 4, then an optional strip of 2, then an optional strip of 1.
 *
 * The panels are consumed in steps of pad_k (k rounded up to a multiple of 8).
 * NOTE(review): presumably the packing routines zero-pad K to that length --
 * confirm against the copy kernels.
 * NOTE(review): the fixed 16-element offsets between bf16 vector loads and
 * the rotate-by-4 f32 lanes in ACCUMULATE_LAST4_TO_FIRST4 look like they
 * assume a 256-bit SVE vector length -- confirm.
 *
 * In the ALPHA_ONE variant the alpha argument is not read by this function.
 * Always returns 0.
 */
#ifdef ALPHA_ONE
int bgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k,
FLOAT alpha, IFLOAT *A, IFLOAT *B,
FLOAT *C, BLASLONG ldc)
#else
int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k,
FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C,
BLASLONG ldc)
#endif
{
// k rounded up to the next multiple of 8; all panel strides use this value.
BLASLONG pad_k = (k + 7) & ~7;
svbfloat16_t ma0, ma1, mb0, mb1;
svfloat32_t mc00, mc01, mc10, mc11, vc0, vc1, vc2, vc3;
svfloat32_t tmp;
#ifndef ALPHA_ONE
// Reinterpret the leading bytes of alpha as bf16, widen to f32 and broadcast.
bfloat16_t alpha_bf16;
memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t));
svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16));
#endif

svbool_t pg16_all = svptrue_b16();

// Predicates selecting the first 1, 2 or 4 lanes of 32-bit elements.
svbool_t pg32_first_1 = svwhilelt_b32(0, 1);
svbool_t pg32_first_2 = svwhilelt_b32(0, 2);
svbool_t pg32_first_4 = svwhilelt_b32(0, 4);

// Predicates selecting the first 1, 2 or 4 lanes of 16-bit elements.
svbool_t pg16_first_1 = svwhilelt_b16(0, 1);
svbool_t pg16_first_2 = svwhilelt_b16(0, 2);
svbool_t pg16_first_4 = svwhilelt_b16(0, 4);

// Selects 32-bit lanes 0 and 1 of every 128-bit quadword.
svbool_t pg32_select_first_2_per_quadword = svdupq_b32(1, 1, 0, 0);

bfloat16_t *ptr_a = (bfloat16_t *)A;
bfloat16_t *ptr_b = (bfloat16_t *)B;
bfloat16_t *ptr_c = (bfloat16_t *)C;

bfloat16_t *ptr_a0;
bfloat16_t *ptr_b0;
bfloat16_t *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;

svfloat32_t tmp32;
svbfloat16_t tmp16;

// ---- Strips of 4 columns of C ----
for (BLASLONG j = 0; j < n / 4; j++) {
ptr_c0 = ptr_c;
ptr_c1 = ptr_c0 + ldc;
ptr_c2 = ptr_c1 + ldc;
ptr_c3 = ptr_c2 + ldc;
ptr_c += 4 * ldc;
// Every column strip restarts from the beginning of the A panel.
ptr_a = (bfloat16_t *)A;

// 4 rows x 4 columns.
for (BLASLONG i = 0; i < m / 4; i++) {
ptr_a0 = ptr_a;
ptr_a += 4 * pad_k;

ptr_b0 = ptr_b;

INIT_C_4x4;

// Each iteration consumes 8 steps of k: 32 bf16 values from A and from B.
for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);

mb0 = svld1_bf16(pg16_all, ptr_b0);
mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);

MATMUL(0, 0);
MATMUL(0, 1);
MATMUL(1, 0);
MATMUL(1, 1);

ptr_a0 += 32;
ptr_b0 += 32;
}

// Fold f32 lanes 4..7 of every accumulator onto lanes 0..3.
ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);
ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(1, 1, tmp);

// Gather the even / odd lanes of an accumulator pair into one vector per
// column of C (vc0..vc3 feed ptr_c0..ptr_c3 below).
ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);
ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc1);

ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc2);
ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc3);

UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc1);
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, tmp32, tmp16, vc2);
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, tmp32, tmp16, vc3);

ptr_c0 += 4;
ptr_c1 += 4;
ptr_c2 += 4;
ptr_c3 += 4;
}

// 2 rows x 4 columns.
if (m & 2) {
ptr_a0 = ptr_a;
ptr_a += 2 * pad_k;

ptr_b0 = ptr_b;
INIT_C(0, 0);
INIT_C(0, 1);
for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);
mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);

MATMUL(0, 0);
MATMUL(0, 1);

ptr_a0 += 16;
ptr_b0 += 32;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);

// Even lanes feed ptr_c0 / ptr_c2, odd lanes feed ptr_c1 / ptr_c3.
vc0 = svuzp1(mc00, mc00);
vc1 = svuzp2(mc00, mc00);
vc2 = svuzp1(mc01, mc01);
vc3 = svuzp2(mc01, mc01);

UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);
UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1);
UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, tmp32, tmp16, vc2);
UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, tmp32, tmp16, vc3);

ptr_c0 += 2;
ptr_c1 += 2;
ptr_c2 += 2;
ptr_c3 += 2;
}

// 1 row x 4 columns. A is still read 16 bf16 values per 8 steps of k.
if (m & 1) {
ptr_a0 = ptr_a;
ptr_b0 = ptr_b;

INIT_C(0, 0);
INIT_C(0, 1);
for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);
mb1 = svld1_bf16(pg16_all, ptr_b0 + 16);

MATMUL(0, 0);
MATMUL(0, 1);

ptr_a0 += 16;
ptr_b0 += 32;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp);

// Only lane 0 is stored per column: lane 0 of the accumulator itself for
// ptr_c0 / ptr_c2, and lane 1 (moved to lane 0 by uzp2) for ptr_c1 / ptr_c3.
// Using svcompact here would be more straightforward.
vc1 = svuzp2(mc00, mc00);
vc3 = svuzp2(mc01, mc01);

UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1);
UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, tmp32, tmp16, mc01);
UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, tmp32, tmp16, vc3);
}

ptr_b += 4 * pad_k;
}

// ---- Optional strip of 2 columns of C ----
if (n & 2) {
ptr_c0 = ptr_c;
ptr_c1 = ptr_c0 + ldc;
ptr_c += 2 * ldc;
ptr_a = (bfloat16_t *)A;

// 4 rows x 2 columns.
for (BLASLONG i = 0; i < m / 4; i++) {
ptr_a0 = ptr_a;
ptr_a += 4 * pad_k;

ptr_b0 = ptr_b;

INIT_C(0, 0);
INIT_C(1, 0);

for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);

mb0 = svld1_bf16(pg16_all, ptr_b0);

MATMUL(0, 0);
MATMUL(1, 0);

ptr_a0 += 32;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);

ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);
ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc2);

UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc2);

ptr_c0 += 4;
ptr_c1 += 4;
}

// 2 rows x 2 columns.
if (m & 2) {
ptr_a0 = ptr_a;
ptr_a += 2 * pad_k;
ptr_b0 = ptr_b;

INIT_C(0, 0);

for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);

MATMUL(0, 0);

ptr_a0 += 16;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
vc0 = svuzp1(mc00, mc00);
vc1 = svuzp2(mc00, mc00);

UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);
UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1);

ptr_c0 += 2;
ptr_c1 += 2;
}

// 1 row x 2 columns.
if (m & 1) {
ptr_a0 = ptr_a;
ptr_b0 = ptr_b;
INIT_C(0, 0);
for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);
MATMUL(0, 0);
ptr_a0 += 16;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
vc1 = svuzp2(mc00, mc00);

UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1);
}

ptr_b += 2 * pad_k;
}

// ---- Optional single column of C ----
if (n & 1) { // TODO: this case seems to be an overhead; find out whether it
// occurs in our use case.
ptr_c0 = ptr_c;
ptr_a = (bfloat16_t *)A;

// 4 rows x 1 column. B is still read 16 bf16 values per 8 steps of k.
for (BLASLONG i = 0; i < m / 4; i++) {
ptr_a0 = ptr_a;
ptr_a += 4 * pad_k;

ptr_b0 = ptr_b;

INIT_C(0, 0);
INIT_C(1, 0);

for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
ma1 = svld1_bf16(pg16_all, ptr_a0 + 16);

mb0 = svld1_bf16(pg16_all, ptr_b0);

MATMUL(0, 0);
MATMUL(1, 0);

ptr_a0 += 32;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);
ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp);

// Only the even lanes are kept for the single column.
ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0);

UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0);

ptr_c0 += 4;
}

// 2 rows x 1 column.
if (m & 2) {
ptr_a0 = ptr_a;
ptr_a += 2 * pad_k;
ptr_b0 = ptr_b;

INIT_C(0, 0);

for (BLASLONG p = 0; p < pad_k; p += 8) {
ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);

MATMUL(0, 0);

ptr_a0 += 16;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);

vc0 = svuzp1(mc00, mc00);

UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0);

ptr_c0 += 2;
}

// 1 row x 1 column.
if (m & 1) {
ptr_a0 = ptr_a;
ptr_b0 = ptr_b;

INIT_C(0, 0);
for (BLASLONG p = 0; p < pad_k; p += 8) {

ma0 = svld1_bf16(pg16_all, ptr_a0);
mb0 = svld1_bf16(pg16_all, ptr_b0);

MATMUL(0, 0);
ptr_a0 += 16;
ptr_b0 += 16;
}

ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp);

UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00);
}
}

return 0;
}

+ 39
- 15
kernel/generic/gemmkernel_2x2.c View File

@@ -1,26 +1,50 @@
/***************************************************************************
* Copyright (c) 2025, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#include "common.h"
#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
static float
bfloat16tof32 (bfloat16 f16)
bfloat16tof32 (bfloat16 value)
{
float result = 0;
unsigned short* q = (unsigned short*)(&result);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
q[0] = f16;
#else
q[1] = f16;
#endif
blasint one = 1;
float result;
sbf16tos_(&one, &value, &one, &result, &one);
return result;
}

static bfloat16 f32tobfloat16(float f32) {
unsigned short *q = (unsigned short *)(&f32);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
return q[0];
#else
return q[1];
#endif
#ifdef BGEMM
/* Convert a single float to bfloat16 by handing one element to sbstobf16_. */
static bfloat16 f32tobfloat16(float value) {
  bfloat16 converted;
  blasint unit = 1; /* passed as the count and as both increments */

  sbstobf16_(&unit, &value, &unit, &converted, &unit);
  return converted;
}
#endif

#ifdef BGEMM
#define ALPHA bfloat16tof32(alpha)


+ 4
- 0
kernel/setparam-ref.c View File

@@ -88,7 +88,11 @@ gotoblas_t TABLE_NAME = {
ssymv_LTS, ssymv_UTS,

bgemm_kernelTS, bgemm_betaTS,
#if BGEMM_DEFAULT_UNROLL_M != BGEMM_DEFAULT_UNROLL_N
bgemm_incopyTS, bgemm_itcopyTS,
#else
bgemm_oncopyTS, bgemm_otcopyTS,
#endif
bgemm_oncopyTS, bgemm_otcopyTS,

sbgemm_kernelTS, sbgemm_betaTS,


+ 7
- 0
param.h View File

@@ -3593,6 +3593,13 @@ is a big desktop or server with abundant cache rather than a phone or embedded d
#define GEMM_PREFERED_SIZE 8
#endif

#undef BGEMM_ALIGN_K
#undef BGEMM_DEFAULT_UNROLL_M
#undef BGEMM_DEFAULT_UNROLL_N
#define BGEMM_ALIGN_K 8
#define BGEMM_DEFAULT_UNROLL_N 4
#define BGEMM_DEFAULT_UNROLL_M 4

#undef SBGEMM_ALIGN_K
#undef SBGEMM_DEFAULT_UNROLL_M
#undef SBGEMM_DEFAULT_UNROLL_N


+ 34
- 76
test/compare_sgemm_bgemm.c View File

@@ -33,92 +33,49 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define BGEMM BLASFUNC(bgemm)
#define BGEMM_LARGEST 256

typedef union
static float float16to32(bfloat16 value)
{
unsigned short v;
#if defined(_AIX)
struct __attribute__((packed))
#else
struct
#endif
{
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
unsigned short s:1;
unsigned short e:8;
unsigned short m:7;
#else
unsigned short m:7;
unsigned short e:8;
unsigned short s:1;
#endif
} bits;
} bfloat16_bits;

typedef union
{
float v;
#if defined(_AIX)
struct __attribute__((packed))
#else
struct
#endif
{
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
uint32_t s:1;
uint32_t e:8;
uint32_t m:23;
#else
uint32_t m:23;
uint32_t e:8;
uint32_t s:1;
#endif
} bits;
} float32_bits;

float
float16to32 (bfloat16_bits f16)
{
float32_bits f32;
f32.bits.s = f16.bits.s;
f32.bits.e = f16.bits.e;
f32.bits.m = (uint32_t) f16.bits.m << 16;
return f32.v;
}

bfloat16
float32to16 (float32_bits f32)
{
bfloat16_bits f16;
f16.bits.s = f32.bits.s;
f16.bits.e = f32.bits.e;
f16.bits.m = (f32.bits.m >> 16) & 0x7f;
return f16.v;
blasint one = 1;
float result;
sbf16tos_(&one, &value, &one, &result, &one);
return result;
}

static float truncate_float(float value) {
bfloat16_bits f16 = (bfloat16_bits)float32to16((float32_bits)value);
return float16to32(f16);
blasint one = 1;
bfloat16 tmp;
float result;
sbstobf16_(&one, &value, &one, &tmp, &one);
sbf16tos_(&one, &tmp, &one, &result, &one);
return result;
}

void *malloc_safe(size_t size) {
/* malloc wrapper that asks for one byte when a zero-sized block is requested,
 * so the size passed to malloc is never 0. */
static void *malloc_safe(size_t size) {
  size_t request = (size == 0) ? 1 : size;
  return malloc(request);
}

/*
 * Tolerance comparison used by the test: returns non-zero when
 * |a - b| <= atol + rtol * |b|, zero otherwise. b acts as the reference
 * value that the relative tolerance is scaled by.
 */
static int is_close(float a, float b, float rtol, float atol) {
  return fabs(a - b) <= (atol + rtol*fabs(b));
}

int
main (int argc, char *argv[])
{
blasint m, n, k;
int i, j, l;
blasint x, y;
blasint one = 1;
int ret = 0;
int loop = BGEMM_LARGEST;
char transA = 'N', transB = 'N';
float alpha = 1.0, beta = 0.0;
bfloat16 alpha_bf16 = float32to16((float32_bits)alpha);
bfloat16 beta_bf16 = float32to16((float32_bits)beta);
bfloat16 alpha_bf16;
sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one);
bfloat16 beta_bf16;
sbstobf16_(&one, &beta, &one, &beta_bf16, &one);

for (x = 0; x <= loop; x++)
{
@@ -127,23 +84,20 @@ main (int argc, char *argv[])
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
bfloat16_bits *CC = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16));
bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16));
bfloat16 *CC = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16));
FLOAT *DD = (FLOAT *)malloc_safe(m * n * sizeof(FLOAT));
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
(DD == NULL) || (CC == NULL))
return 1;
bfloat16 atmp,btmp;
blasint one=1;

for (j = 0; j < m; j++)
{
for (i = 0; i < k; i++)
{
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
AA[j * k + i].v = atmp;
sbstobf16_(&one, &A[j*k+i], &one, &AA[j * k + i], &one);
}
}
for (j = 0; j < n; j++)
@@ -151,8 +105,7 @@ main (int argc, char *argv[])
for (i = 0; i < k; i++)
{
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
BB[j * k + i].v = btmp;
sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one);
}
}
for (y = 0; y < 4; y++)
@@ -168,7 +121,7 @@ main (int argc, char *argv[])
transB = 'T';
}

memset(CC, 0, m * n * sizeof(bfloat16_bits));
memset(CC, 0, m * n * sizeof(bfloat16));
memset(DD, 0, m * n * sizeof(FLOAT));
memset(C, 0, m * n * sizeof(FLOAT));

@@ -198,10 +151,15 @@ main (int argc, char *argv[])
DD[i * m + j] +=
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]);
}
if (fabs(float16to32(CC[i * m + j]) - truncate_float(C[i * m + j])) > 2.0) {
if (!is_close(float16to32(CC[i * m + j]), truncate_float(C[i * m + j]), 0.01, 0.001)) {
printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n",
i, j, k, float16to32(CC[i * m + j]), truncate_float(C[i * m + j]));
ret++;
}
if (fabs(float16to32(CC[i * m + j]) - truncate_float(DD[i * m + j])) > 1.0) {

if (!is_close(float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]), 0.0001, 0.00001)) {
printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n",
i, j, k, float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]));
ret++;
}


Loading…
Cancel
Save