diff --git a/.gitignore b/.gitignore index 1807d4496..dcbc73dc8 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,7 @@ test/ZBLAT3_3M.SUMM test/SHBLAT3.SUMM test/SBBLAT2.SUMM test/SBBLAT3.SUMM +test/BBLAT2.SUMM test/BBLAT3.SUMM test/cblat1 test/cblat2 @@ -100,6 +101,7 @@ test/test_shgemm test/test_sbgemm test/test_sbgemv test/test_bgemm +test/test_bgemv test/zblat1 test/zblat2 test/zblat3 diff --git a/cblas.h b/cblas.h index f48d5da1e..7503e43f7 100644 --- a/cblas.h +++ b/cblas.h @@ -465,6 +465,7 @@ void cblas_sbdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPEN void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float *out, OPENBLAS_CONST blasint incout); /* convert BFLOAT16 array to double array */ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout); +void cblas_bgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 beta, bfloat16 *y, OPENBLAS_CONST blasint incy); /* dot production of BFLOAT16 input arrays, and output as float */ float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy); void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy); diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake index c0dd32ca0..ab548cfb0 100644 --- a/cmake/kernel.cmake 
+++ b/cmake/kernel.cmake @@ -169,8 +169,6 @@ if (BUILD_BFLOAT16) SetFallback(SHSWAPKERNEL ../arm/swap.c) SetFallback(TOBF16KERNEL ../x86_64/tobf16.c) SetFallback(BF16TOKERNEL ../x86_64/bf16to.c) - SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c) - SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c) endif () endmacro () @@ -221,6 +219,8 @@ macro(SetDefaultL2) SetFallback(XHEMV_V_KERNEL ../generic/zhemv_k.c) SetFallback(XHEMV_M_KERNEL ../generic/zhemv_k.c) if (BUILD_BFLOAT16) + SetFallback(BGEMVNKERNEL ../generic/gemv_n.c) + SetFallback(BGEMVTKERNEL ../generic/gemv_t.c) SetFallback(SBGEMVNKERNEL ../x86_64/sbgemv_n.c) SetFallback(SBGEMVTKERNEL ../x86_64/sbgemv_t.c) SetFallback(SHGERKERNEL ../generic/ger.c) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 1b19f41bc..52e5b5ee3 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -375,7 +375,7 @@ function(GenerateNamedObjects sources_in) if (NOT no_float_type) string(SUBSTRING ${float_type} 0 1 float_char) string(TOLOWER ${float_char} float_char) - if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEMM") + if (${float_type} STREQUAL "BFLOAT16" AND NOT "${defines_in}" MATCHES "BGEM") set (float_char "sb") endif () endif () diff --git a/common_b.h b/common_b.h index 4d77ec4fa..1921c3a69 100644 --- a/common_b.h +++ b/common_b.h @@ -30,6 +30,11 @@ #define COMMON_B_H #ifndef DYNAMIC_ARCH +#define BGEMV_N_K bgemv_n +#define BGEMV_T_K bgemv_t + +#define BSCAL_K bscal_k + #define BGEMM_ONCOPY bgemm_oncopy #define BGEMM_OTCOPY bgemm_otcopy @@ -45,6 +50,10 @@ #define BGEMM_KERNEL bgemm_kernel #else +#define BGEMV_N_K gotoblas->bgemv_n +#define BGEMV_T_K gotoblas->bgemv_t + +#define BSCAL_K gotoblas->bscal_k #define BGEMM_ONCOPY gotoblas->bgemm_oncopy #define BGEMM_OTCOPY gotoblas->bgemm_otcopy diff --git a/common_interface.h b/common_interface.h index f69baab1c..945b6c8a1 100644 --- a/common_interface.h +++ b/common_interface.h @@ -60,6 +60,7 @@ double BLASFUNC(dsdot) (blasint *, float *, blasint *, 
float *, blasint *); double BLASFUNC(ddot) (blasint *, double *, blasint *, double *, blasint *); xdouble BLASFUNC(qdot) (blasint *, xdouble *, blasint *, xdouble *, blasint *); +void BLASFUNC(bscal) (blasint *, bfloat16 *, bfloat16 *, blasint *); float BLASFUNC(sbdot) (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *); void BLASFUNC(sbstobf16) (blasint *, float *, blasint *, bfloat16 *, blasint *); void BLASFUNC(sbdtobf16) (blasint *, double *, blasint *, bfloat16 *, blasint *); @@ -256,6 +257,8 @@ void BLASFUNC(xgeru)(blasint *, blasint *, xdouble *, xdouble *, blasint *, void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *, xdouble *, blasint *, xdouble *, blasint *); +void BLASFUNC(bgemv)(char *, blasint *, blasint *, bfloat16 *, bfloat16 *, blasint *, + bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *); void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *, diff --git a/common_level1.h b/common_level1.h index 85b39f7a7..7ab45a472 100644 --- a/common_level1.h +++ b/common_level1.h @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2025 The OpenBLAS Project. */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. 
*/ /* */ @@ -169,6 +170,9 @@ BLASLONG icmin_k(BLASLONG, float *, BLASLONG); BLASLONG izmin_k(BLASLONG, double *, BLASLONG); BLASLONG ixmin_k(BLASLONG, xdouble *, BLASLONG); + +int bscal_k(BLASLONG, BLASLONG, BLASLONG, bfloat16, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); int sscal_k(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int dscal_k(BLASLONG, BLASLONG, BLASLONG, double, diff --git a/common_level2.h b/common_level2.h index 9a5ebb4d9..eea5e43f3 100644 --- a/common_level2.h +++ b/common_level2.h @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2025 The OpenBLAS Project */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. */ /* */ @@ -44,6 +45,11 @@ extern "C" { #endif + +int bgemv_n(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG); +int bgemv_t(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG); +int bgemv_thread_n(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG, int); +int bgemv_thread_t(BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG, int); int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int); diff --git a/common_macro.h b/common_macro.h index 22c1e14a2..f9c22089b 100644 --- a/common_macro.h +++ b/common_macro.h @@ -705,6 +705,11 @@ #elif defined(BFLOAT16) && defined(BGEMM) +#define SCAL_K BSCAL_K + +#define GEMV_N BGEMV_N_K +#define GEMV_T BGEMV_T_K + #define GEMM_BETA BGEMM_BETA #define 
GEMM_KERNEL_N BGEMM_KERNEL #define GEMM_KERNEL_L BGEMM_KERNEL @@ -754,8 +759,8 @@ #define D_BF16_TO_K DBF16TOD_K #define S_TO_BF16_K SBSTOBF16_K #define S_BF16_TO_K SBF16TOS_K -#define SBGEMV_N SBGEMV_N_K -#define SBGEMV_T SBGEMV_T_K +#define GEMV_N SBGEMV_N_K +#define GEMV_T SBGEMV_T_K #define AMAX_K SAMAX_K #define AMIN_K SAMIN_K @@ -773,8 +778,6 @@ #define AXPYC_K SAXPYC_K #define AXPBY_K SAXPBY_K #define SCAL_K SSCAL_K -#define GEMV_N SGEMV_N -#define GEMV_T SGEMV_T #define SYMV_U SSYMV_U #define SYMV_L SSYMV_L #define GERU_K SGERU_K diff --git a/common_param.h b/common_param.h index 503525dd2..d6b8d9bad 100644 --- a/common_param.h +++ b/common_param.h @@ -98,10 +98,14 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); int (*sbrot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float); int (*sbrotm_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); + int (*bscal_k) (BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); int (*sbaxpy_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); + int (*bgemv_n) (BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG); + int (*bgemv_t) (BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16, bfloat16 *, BLASLONG); + int (*sbgemv_n) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); int (*sbgemv_t) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG); int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *); diff --git 
a/driver/level2/Makefile b/driver/level2/Makefile index 5f8c712a8..3f3731d3f 100644 --- a/driver/level2/Makefile +++ b/driver/level2/Makefile @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025 The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + TOPDIR = ../.. 
include ../../Makefile.system @@ -423,6 +451,9 @@ XBLASOBJS += \ xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += \ + bgemv_thread_n$(TSUFFIX).$(SUFFIX) \ + bgemv_thread_t$(TSUFFIX).$(SUFFIX) SBBLASOBJS += \ sbgemv_thread_n$(TSUFFIX).$(SUFFIX) \ sbgemv_thread_t$(TSUFFIX).$(SUFFIX) @@ -3707,6 +3738,10 @@ xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h $(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +bgemv_thread_n.$(SUFFIX) bgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h + $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F) +bgemv_thread_t.$(SUFFIX) bgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h + $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F) sbgemv_thread_n.$(SUFFIX) sbgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F) sbgemv_thread_t.$(SUFFIX) sbgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h diff --git a/driver/level2/sbgemv_thread.c b/driver/level2/sbgemv_thread.c index 534c60f95..c7fc90a35 100644 --- a/driver/level2/sbgemv_thread.c +++ b/driver/level2/sbgemv_thread.c @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2025 The OpenBLAS Project. */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. 
*/ /* */ @@ -41,21 +42,21 @@ #include "common.h" #ifndef TRANSA -#define SBGEMV SBGEMV_N +#define GEMV GEMV_N #else -#define SBGEMV SBGEMV_T +#define GEMV GEMV_T #endif static int sbgemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *dummy2, BLASLONG dummy3){ - bfloat16 *a, *x; - float *y; + IFLOAT *a, *x; + FLOAT *y; BLASLONG lda, incx, incy; BLASLONG m_from, m_to, n_from, n_to; - a = (bfloat16 *)args->a; - x = (bfloat16 *)args->b; - y = (float *)args->c; + a = (IFLOAT *)args->a; + x = (IFLOAT *)args->b; + y = (FLOAT *)args->c; lda = args->lda; incx = args->ldb; @@ -77,12 +78,12 @@ static int sbgemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, y += n_from * incy; #endif - SBGEMV(m_to - m_from, n_to - n_from, *((FLOAT *)(args->alpha)), a, lda, x, incx, *((FLOAT *)(args->beta)), y, incy); + GEMV(m_to - m_from, n_to - n_from, *((FLOAT *)(args->alpha)), a, lda, x, incx, *((FLOAT *)(args->beta)), y, incy); return 0; } -int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy, int threads) +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG incx, FLOAT beta, FLOAT *y, BLASLONG incy, int threads) { blas_arg_t args; blas_queue_t queue[MAX_CPU_NUMBER]; diff --git a/exports/gensymbol b/exports/gensymbol index 17fbd2877..40e13e623 100755 --- a/exports/gensymbol +++ b/exports/gensymbol @@ -1,5 +1,33 @@ #!/bin/sh +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. 
Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+############################################################################### + # Changelog # 2017/09/03 staticfloat # Added zsymv and csymv into @lapackobjs2 so they are properly renamed @@ -51,7 +79,7 @@ blasobjsz=" zgeadd dzsum zgemmt zgemmtr" blasobjs="lsame xerbla" -bfblasobjs="bgemm sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" +bfblasobjs="bgemm bgemv sbgemm sbgemmt sbgemmtr sbgemv sbdot sbstobf16 sbdtobf16 sbf16tos dbf16tod" hfblasobjs="shgemm" cblasobjsc=" cblas_caxpy cblas_ccopy cblas_cdotc cblas_cdotu cblas_cgbmv cblas_cgemm cblas_cgemv diff --git a/exports/gensymbol.pl b/exports/gensymbol.pl index 01f68fbb3..3447a4e51 100644 --- a/exports/gensymbol.pl +++ b/exports/gensymbol.pl @@ -1,5 +1,33 @@ #!/usr/bin/env perl +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + # Changelog # 2017/09/03 staticfloat # Added zsymv and csymv into @lapackobjs2 so they are properly renamed @@ -51,7 +79,7 @@ zgeadd, dzsum, zgemmt,zgemmtr); @blasobjs = (lsame, xerbla); -@bfblasobjs = (bgemm, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); +@bfblasobjs = (bgemm, bgemv, sbgemm, sbgemmt, sbgemmtr, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod); @hfblasobjs = (shgemm); @cblasobjsc = ( cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv, diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index 995bebec3..0fcc79bfc 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -155,6 +155,7 @@ if (BUILD_BFLOAT16) GenerateNamedObjects("gemm.c" "" "sbgemm" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("sbgemmt.c" "" "sbgemmt" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("sbgemmt.c" "RNAME" "sbgemmtr" ${CBLAS_FLAG} "" "" true "BFLOAT16") + GenerateNamedObjects("sbgemv.c" "BGEMM" "bgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("sbgemv.c" "" "sbgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("tobf16.c" "SINGLE_PREC" "sbstobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") diff --git 
a/interface/Makefile b/interface/Makefile index 3af12748f..999d23e9a 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -75,7 +75,9 @@ SBLAS3OBJS = \ sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) -BBLAS3OBJ = bgemm.$(SUFFIX) +BBLAS3OBJS = bgemm.$(SUFFIX) +BBLAS2OBJS = bgemv.$(SUFFIX) +BBLAS1OBJS = bscal.$(SUFFIX) SBBLAS1OBJS = sbdot.$(SUFFIX) SBBLAS2OBJS = sbgemv.$(SUFFIX) SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) @@ -319,6 +321,8 @@ CSBLAS3OBJS = \ ifeq ($(BUILD_BFLOAT16),1) CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) +CBBLAS2OBJS = cblas_bgemv.$(SUFFIX) +CBBLAS1OBJS = cblas_bscal.$(SUFFIX) CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX) CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX) CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(SUFFIX) cblas_sbgemm_batch.$(SUFFIX) @@ -423,7 +427,9 @@ override CFLAGS += -I. SBLAS1OBJS += $(CSBLAS1OBJS) SBLAS2OBJS += $(CSBLAS2OBJS) SBLAS3OBJS += $(CSBLAS3OBJS) -BBLAS3OBJ += $(CBBLAS3OBJS) +BBLAS3OBJS += $(CBBLAS3OBJS) +BBLAS2OBJS += $(CBBLAS2OBJS) +BBLAS1OBJS += $(CBBLAS1OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) @@ -443,7 +449,7 @@ SBEXTOBJS += $(CSBEXTOBJS) CBAUXOBJS += $(CXERBLAOBJ) endif -BBLASOBJS = $(BBLAS3OBJ) +BBLASOBJS = $(BBLAS3OBJS) $(BBLAS2OBJS) $(BBLAS1OBJS) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) SHBLASOBJS = $(SHBLAS3OBJS) @@ -589,7 +595,7 @@ clean :: level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) +level2 : $(SBBLAS2OBJS) $(BBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ level3 : $(SBBLAS3OBJS) 
$(BBLAS3OBJ) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(SHBLAS3OBJS) @@ -824,6 +830,8 @@ dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c $(CC) $(CFLAGS) -c $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +bscal.$(SUFFIX) bscal.$(PSUFFIX) : scal.c + $(CC) $(CFLAGS) -DBGEMM -c $< -o $(@F) sbdot.$(SUFFIX) sbdot.$(PSUFFIX) : bf16dot.c $(CC) $(CFLAGS) -c $< -o $(@F) sbstobf16.$(SUFFIX) sbstobf16.$(PSUFFIX) : tobf16.c @@ -981,6 +989,8 @@ xgerc.$(SUFFIX) xgerc.$(PSUFFIX) : zger.c $(CC) -c $(CFLAGS) -DCONJ $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +bgemv.$(SUFFIX) bgemv.$(PSUFFIX) : sbgemv.c + $(CC) $(CFLAGS) -DBGEMM -c $< -o $(@F) sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c $(CC) $(CFLAGS) -c $< -o $(@F) endif @@ -1653,6 +1663,8 @@ cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +cblas_bscal.$(SUFFIX) cblas_bscal.$(PSUFFIX) : scal.c + $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) cblas_sbdot.$(SUFFIX) cblas_sbdot.$(PSUFFIX) : bf16dot.c $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) cblas_sbstobf16.$(SUFFIX) cblas_sbstobf16.$(PSUFFIX) : tobf16.c @@ -1807,6 +1819,8 @@ cblas_zdrot.$(SUFFIX) cblas_zdrot.$(PSUFFIX) : zrot.c $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F) ifeq ($(BUILD_BFLOAT16),1) +cblas_bgemv.$(SUFFIX) cblas_bgemv.$(PSUFFIX) : sbgemv.c + $(CC) -DCBLAS -DBGEMM -c $(CFLAGS) $< -o $(@F) cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif diff --git a/interface/sbgemmt.c b/interface/sbgemmt.c index 759af4bfb..67914fe65 100644 --- a/interface/sbgemmt.c +++ b/interface/sbgemmt.c @@ -1,5 +1,5 @@ /*********************************************************************/ -/* Copyright 2024, The OpenBLAS Project. */ +/* Copyright 2024-2025 The OpenBLAS Project. */ /* All rights reserved. 
*/ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -305,7 +305,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, #endif int (*gemv[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = { - SBGEMV_N, SBGEMV_T,}; + GEMV_N, GEMV_T,}; if (m == 0) diff --git a/interface/sbgemv.c b/interface/sbgemv.c index fce86f8e4..cee3e80fc 100644 --- a/interface/sbgemv.c +++ b/interface/sbgemv.c @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2025 The OpenBLAS Project. */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. */ /* */ @@ -43,17 +44,25 @@ #include "functable.h" #endif +#ifdef BGEMM +#define GEMV_THREAD_N bgemv_thread_n +#define GEMV_THREAD_T bgemv_thread_t +#define ERROR_NAME "BGEMV " +#else +#define GEMV_THREAD_N sbgemv_thread_n +#define GEMV_THREAD_T sbgemv_thread_t #define ERROR_NAME "SBGEMV " +#endif #ifdef SMP -static int (*sbgemv_thread[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG, int) = { - sbgemv_thread_n, sbgemv_thread_t, +static int (*gemv_thread[])(BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG, IFLOAT * , BLASLONG, FLOAT, FLOAT *, BLASLONG, int) = { + GEMV_THREAD_N, GEMV_THREAD_T, }; #endif #ifndef CBLAS -void NAME(char *TRANS, blasint *M, blasint *N, float *ALPHA, bfloat16 *a, blasint *LDA, bfloat16 *x, blasint *INCX, float *BETA, float *y, blasint *INCY) +void NAME(char *TRANS, blasint *M, blasint *N, FLOAT *ALPHA, IFLOAT *a, blasint *LDA, IFLOAT *x, blasint *INCX, FLOAT *BETA, FLOAT *y, blasint *INCY) { char trans = *TRANS; blasint m = *M; @@ -61,14 +70,14 @@ void NAME(char *TRANS, blasint *M, blasint *N, float *ALPHA, bfloat16 *a, blasin blasint lda = *LDA; blasint incx = *INCX; blasint incy = *INCY; - float alpha = *ALPHA; - float beta = *BETA; + FLOAT alpha = *ALPHA; + FLOAT beta = *BETA; #ifdef SMP int nthreads; #endif - int 
(*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = { - SBGEMV_N, SBGEMV_T, + int (*gemv[])(BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG, IFLOAT * , BLASLONG, FLOAT, FLOAT *, BLASLONG) = { + GEMV_N, GEMV_T, }; blasint info; @@ -104,7 +113,7 @@ void NAME(char *TRANS, blasint *M, blasint *N, float *ALPHA, bfloat16 *a, blasin #else -void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasint n, float alpha, bfloat16 *a, blasint lda, bfloat16 *x, blasint incx, float beta, float *y, blasint incy) +void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasint n, FLOAT alpha, IFLOAT *a, blasint lda, IFLOAT *x, blasint incx, FLOAT beta, FLOAT *y, blasint incy) { blasint lenx, leny; int trans; @@ -113,8 +122,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasi int nthreads; #endif - int (*sbgemv[])(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 * , BLASLONG, float, float *, BLASLONG) = { - SBGEMV_N, SBGEMV_T, + int (*gemv[])(BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG, IFLOAT * , BLASLONG, FLOAT, FLOAT *, BLASLONG) = { + GEMV_N, GEMV_T, }; PRINT_DEBUG_CNAME; @@ -166,8 +175,17 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasi leny = m; } - if (alpha == ZERO) { - if (beta != ONE) SCAL_K(leny, 0, 0, beta, y, blasabs(incy), NULL, 0, NULL, 0); +#ifdef BGEMM + float alpha_float, beta_float; + SBF16TOS_K(1, &alpha, 1, &alpha_float, 1); + SBF16TOS_K(1, &beta, 1, &beta_float, 1); +#else + float alpha_float = alpha; + float beta_float = beta; +#endif + + if (alpha_float == ZERO) { + if (beta_float != ONE) SCAL_K(leny, 0, 0, beta, y, blasabs(incy), NULL, 0, NULL, 0); return; } @@ -185,10 +203,10 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, blasint m, blasi if (nthreads == 1) { #endif - (sbgemv[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy); + (gemv[(int)trans])(m, n, 
alpha, a, lda, x, incx, beta, y, incy); #ifdef SMP } else { - (sbgemv_thread[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy, nthreads); + (gemv_thread[(int)trans])(m, n, alpha, a, lda, x, incx, beta, y, incy, nthreads); } #endif diff --git a/interface/scal.c b/interface/scal.c index c6638a62d..4f12df7c0 100644 --- a/interface/scal.c +++ b/interface/scal.c @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2025 The OpenBLAS Project. */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. */ /* */ @@ -68,7 +69,14 @@ void CNAME(blasint n, FLOAT alpha, FLOAT *x, blasint incx){ if (incx <= 0 || n <= 0) return; - if (alpha == ONE) return; +#ifdef BGEMM + float alpha_float; + SBF16TOS_K(1, &alpha, 1, &alpha_float, 1); +#else + FLOAT alpha_float = alpha; /* same type as alpha: narrowing a double/xdouble alpha to float would make values that only round to 1.0f take the early return below */ +#endif + + if (alpha_float == ONE) return; IDEBUG_START; diff --git a/kernel/Makefile.L1 b/kernel/Makefile.L1 index 0fc672094..221cc5127 100644 --- a/kernel/Makefile.L1 +++ b/kernel/Makefile.L1 @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025 The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + FMAFLAG= ifndef OLDGCC ifdef HAVE_FMA3 @@ -271,6 +299,10 @@ XDOTKERNEL = zdot.S endif ifeq ($(BUILD_BFLOAT16),1) +ifndef BSCALKERNEL +BSCALKERNEL = ../generic/scal.c +endif + ifndef SBDOTKERNEL SBDOTKERNEL = ../x86_64/sbdot.c endif @@ -551,6 +583,8 @@ XBLASOBJS += \ xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += \ + bscal_k$(TSUFFIX).$(SUFFIX) SBBLASOBJS += \ sbdot_k$(TSUFFIX).$(SUFFIX) SBEXTOBJS += \ @@ -778,6 +812,8 @@ $(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL $(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@ ifeq ($(BUILD_BFLOAT16),1) +$(KDIR)bscal_k$(TSUFFIX).$(SUFFIX) $(KDIR)bscal_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BSCALKERNEL) + $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@ $(KDIR)sbdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sbdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBDOTKERNEL) $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ $(KDIR)sbstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL) diff --git a/kernel/Makefile.L2 b/kernel/Makefile.L2 index 0332ba722..a9fcf9225 100644 --- 
a/kernel/Makefile.L2 +++ b/kernel/Makefile.L2 @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025 The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+############################################################################### + FMAFLAG= ifndef OLDGCC ifdef HAVE_FMA3 @@ -56,6 +84,14 @@ XGEMVTKERNEL = zgemv_t.S endif ifeq ($(BUILD_BFLOAT16),1) +ifndef BGEMVNKERNEL +BGEMVNKERNEL = ../generic/gemv_n.c +endif + +ifndef BGEMVTKERNEL +BGEMVTKERNEL = ../generic/gemv_t.c +endif + ifndef SBGEMVNKERNEL SBGEMVNKERNEL = ../x86_64/sbgemv_n.c endif @@ -255,6 +291,9 @@ XBLASOBJS += \ xgeru_k$(TSUFFIX).$(SUFFIX) xgerc_k$(TSUFFIX).$(SUFFIX) xgerv_k$(TSUFFIX).$(SUFFIX) xgerd_k$(TSUFFIX).$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) +BBLASOBJS += \ + bgemv_n$(TSUFFIX).$(SUFFIX) \ + bgemv_t$(TSUFFIX).$(SUFFIX) SBBLASOBJS += \ sbgemv_n$(TSUFFIX).$(SUFFIX) \ sbgemv_t$(TSUFFIX).$(SUFFIX) @@ -513,5 +552,9 @@ $(KDIR)sbgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KE $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ $(KDIR)sbgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)sbgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SBGEMVTKERNEL) $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@ +$(KDIR)bgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)bgemv_n$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMVNKERNEL) + $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX $< -o $@ +$(KDIR)bgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)bgemv_t$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(BGEMVTKERNEL) + $(CC) -c $(CFLAGS) -DBGEMM -UCOMPLEX $< -o $@ endif diff --git a/kernel/generic/bf16_macros.h b/kernel/generic/bf16_macros.h new file mode 100644 index 000000000..f1b02cea4 --- /dev/null +++ b/kernel/generic/bf16_macros.h @@ -0,0 +1,64 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/
+/* Shared bfloat16 <-> float helpers plus the ALPHA/BETA/BF16TOF32/F32TOBF16 macros used by the generic kernels. */
+#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
+static float
+bfloat16tof32 (bfloat16 value)
+{
+  blasint one = 1;
+  float result;
+  sbf16tos_(&one, &value, &one, &result, &one); /* one element, unit strides */
+  return result;
+}
+/* The float -> bfloat16 helper is compiled only when BGEMM is defined. */
+#ifdef BGEMM
+static bfloat16 f32tobfloat16(float value) {
+  blasint one = 1;
+  bfloat16 result;
+  sbstobf16_(&one, &value, &one, &result, &one); /* one element, unit strides */
+  return result;
+}
+#endif
+/* Pass-through variants are parenthesized so compound arguments expand safely. */
+#ifdef BGEMM
+#define ALPHA bfloat16tof32(alpha)
+#define BETA bfloat16tof32(beta)
+#define BF16TOF32(x) (bfloat16tof32(x))
+#define F32TOBF16(x) (f32tobfloat16(x))
+#else
+#define ALPHA alpha
+#define BETA beta
+#define BF16TOF32(x) (bfloat16tof32(x))
+#define F32TOBF16(x) (x)
+#endif
+#else
+#define ALPHA alpha
+#define BETA beta
+#define BF16TOF32(x) (x)
+#define F32TOBF16(x) (x)
+#endif
diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index 8872f2f56..c24370c89 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -27,39 +27,8 @@ * *****************************************************************************/ #include "common.h" -#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) -static float -bfloat16tof32 (bfloat16 value) -{ - blasint one = 1; - float result; - sbf16tos_(&one, &value, &one, &result, &one); - return result; -} - -#ifdef BGEMM -static bfloat16 f32tobfloat16(float value) { - blasint one = 1; - bfloat16 result; - sbstobf16_(&one, &value, &one, &result, &one); - return result; -} -#endif +#include "bf16_macros.h" -#ifdef BGEMM -#define ALPHA bfloat16tof32(alpha) -#define BF16TOF32(x) (bfloat16tof32(x)) -#define F32TOBF16(x) (f32tobfloat16(x)) -#else -#define ALPHA alpha -#define BF16TOF32(x) (bfloat16tof32(x)) -#define F32TOBF16(x) x -#endif -#else -#define ALPHA alpha -#define BF16TOF32(x) x -#define F32TOBF16(x) x -#endif int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc 
#ifdef TRMMKERNEL ,BLASLONG offset diff --git a/kernel/generic/gemv_n.c b/kernel/generic/gemv_n.c new file mode 100644 index 000000000..1c72b07af --- /dev/null +++ b/kernel/generic/gemv_n.c @@ -0,0 +1,70 @@ +/*************************************************************************** +Copyright (c) 2013-2014, 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include "common.h"
+#include "bf16_macros.h"
+/* Generic non-transposed GEMV: y[i*inc_y] = alpha * sum_j a[i + j*lda] * x[j*inc_x] + beta * y[i*inc_y], i < m, j < n. Returns 0. */
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
+{
+    BLASLONG i;
+    BLASLONG ix, iy;
+    BLASLONG j;
+    IFLOAT *a_ptr; /* must have the element type of a (IFLOAT), not of y */
+#ifdef BGEMM
+    float temp; /* BGEMM build: the dot product is accumulated in float */
+#else
+    FLOAT temp;
+#endif
+
+    iy = 0;
+    for (i = 0; i < m; i++) /* one output element per row index i */
+    {
+        temp = 0.0;
+
+        ix = 0;
+        a_ptr = a;
+        for (j = 0; j < n; j++)
+        {
+            temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
+            ix += inc_x;
+            a_ptr += lda; /* step to the same row in the next column */
+        }
+
+        if (BETA == ZERO)
+        {
+            y[iy] = F32TOBF16(ALPHA * temp); /* y is overwritten without being read */
+        }
+        else
+        {
+            y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
+        }
+
+        iy += inc_y;
+    }
+
+    return (0);
+}
diff --git a/kernel/generic/gemv_t.c b/kernel/generic/gemv_t.c new file mode 100644 index 000000000..3b651b5c1 --- /dev/null +++ b/kernel/generic/gemv_t.c @@ -0,0 +1,60 @@ +/*************************************************************************** +Copyright (c) 2013, 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/
+
+#include "common.h"
+#include "bf16_macros.h"
+/* Generic transposed GEMV: y[j*inc_y] = alpha * sum_i a[i + j*lda] * x[i*inc_x] + beta * y[j*inc_y], j < n, i < m. Returns 0. */
+int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
+{
+    BLASLONG i;
+    BLASLONG ix, iy;
+    BLASLONG j;
+    IFLOAT *a_ptr; /* must have the element type of a (IFLOAT), not of y */
+#ifdef BGEMM
+    float temp; /* BGEMM build: the dot product is accumulated in float */
+#else
+    FLOAT temp;
+#endif
+
+    iy = 0;
+    a_ptr = a;
+    for (j = 0; j < n; j++) /* one output element per column j of a */
+    {
+        temp = 0.0;
+        ix = 0;
+        for (i = 0; i < m; i++)
+        {
+            temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
+            ix += inc_x;
+        }
+        if (BETA == ZERO) y[iy] = F32TOBF16(ALPHA * temp); /* y is overwritten without being read */
+        else y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy])); /* combine in float; never add stored bfloat16 values directly */
+        iy += inc_y;
+        a_ptr += lda;
+    }
+    return (0);
+}
diff --git a/kernel/generic/scal.c b/kernel/generic/scal.c new file mode 100644 index 000000000..fef0c7bf9 --- /dev/null +++ b/kernel/generic/scal.c @@ -0,0 +1,106 @@ +/*************************************************************************** +Copyright (c) 2013, 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. 
Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include "common.h"
+/* Scales n elements of x (stride inc_x) by da, computing in float; dummy2 selects how da == 0 is handled. y, inc_y, dummy, dummy0 and dummy1 are unused. Returns 0. */
+int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *dummy, BLASLONG dummy2)
+{
+    BLASLONG i = 0, j = 0;
+#if defined(BFLOAT16)
+    float x_float, da_float;
+    SBF16TOS_K(1, &da, 1, &da_float, 1); /* da arrives as bfloat16; convert it once */
+#else
+    float x_float;
+    float da_float = da;
+#endif
+
+    if ((n <= 0) || (inc_x <= 0))
+        return (0);
+
+    if (dummy2 == 0)
+    {
+        while (j < n)
+        {
+            /* dummy2 == 0: a zero factor stores zero without reading x */
+            if (da_float == 0.0)
+                x_float = 0.0;
+            else
+            {
+#if defined(BFLOAT16)
+                SBF16TOS_K(1, &x[i], 1, &x_float, 1);
+#else
+                x_float = x[i]; /* assign the outer variable: the store below reads it */
+#endif
+                x_float = da_float * x_float;
+            }
+
+#if defined(BFLOAT16)
+            SBSTOBF16_K(1, &x_float, 1, &x[i], 1);
+#else
+            x[i] = x_float;
+#endif
+
+            i += inc_x;
+            j++;
+        }
+    }
+    else
+    {
+        /* dummy2 != 0: a zero factor still turns NaN/Inf elements into NaN */
+        while (j < n)
+        {
+#if defined(BFLOAT16)
+            SBF16TOS_K(1, &x[i], 1, &x_float, 1);
+#else
+            x_float = x[i];
+#endif
+            if (da_float == 0.0) /* test the converted factor, not the raw bfloat16 storage */
+                if (!isnan(x_float) && !isinf(x_float))
+                {
+                    x_float = 0.0;
+                }
+                else
+                {
+                    x_float = NAN;
+                }
+            else
+            {
+                x_float = da_float * x_float;
+            }
+
+#if defined(BFLOAT16)
+            SBSTOBF16_K(1, &x_float, 1, &x[i], 1);
+#else
+            x[i] = x_float;
+#endif
+
+            i += inc_x;
+            j++;
+        }
+    }
+    return 0;
+}
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 886895acc..c09472e76 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -83,8 +83,8 @@ gotoblas_t TABLE_NAME = { isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS, snrm2_kTS, sasum_kTS, ssum_kTS, scopy_kTS, sbdot_kTS, dsdot_kTS, - srot_kTS, srotm_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, - sbgemv_nTS, sbgemv_tTS, sger_kTS, + srot_kTS, srotm_kTS, bscal_kTS, saxpy_kTS, sscal_kTS, sswap_kTS, + bgemv_nTS, bgemv_tTS, sbgemv_nTS, sbgemv_tTS, sger_kTS, ssymv_LTS, ssymv_UTS, bgemm_kernelTS, bgemm_betaTS, diff --git a/test/Makefile b/test/Makefile index 144738eb2..d8e12058a 100644 --- a/test/Makefile +++ b/test/Makefile @@ -120,6 +120,7 @@ endif 
endif ifeq ($(BUILD_BFLOAT16), 1) +BB2 = test_bgemv B2 = test_sbgemv endif ifeq ($(BUILD_SINGLE),1) @@ -135,12 +136,14 @@ ifeq ($(BUILD_COMPLEX16),1) Z2=zblat2 endif -level2: $(B2) $(S2) $(D2) $(C2) $(Z2) +level2: $(BB2) $(B2) $(S2) $(D2) $(C2) $(Z2) ifneq ($(CROSS), 1) rm -f ?BLAT2.SUMM ifeq ($(BUILD_BFLOAT16),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemv > BBLAT2.SUMM + @$(GREP) -q FATAL BBLAT2.SUMM && cat BBLAT2.SUMM || exit 0 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif @@ -164,6 +167,8 @@ ifdef SMP rm -f ?BLAT2.SUMM ifeq ($(USE_OPENMP), 1) ifeq ($(BUILD_BFLOAT16),1) + OMP_NUM_THREADS=2 ./test_bgemv > BBLAT2.SUMM + @$(GREP) -q FATAL BBLAT2.SUMM && cat BBLAT2.SUMM || exit 0 OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif @@ -185,6 +190,8 @@ ifeq ($(BUILD_COMPLEX16),1) endif else ifeq ($(BUILD_BFLOAT16),1) + OMP_NUM_THREADS=2 ./test_bgemv > BBLAT2.SUMM + @$(GREP) -q FATAL BBLAT2.SUMM && cat BBLAT2.SUMM || exit 0 OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM @$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0 endif @@ -419,13 +426,16 @@ endif ifeq ($(BUILD_BFLOAT16),1) test_bgemm : compare_sgemm_bgemm.c test_helpers.h ../$(LIBNAME) - $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + $(CC) $(CLDFLAGS) -DIBFLOAT16 -DOBFLOAT16 -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + +test_bgemv : compare_sgemv_bgemv.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -DIBFLOAT16 -DOBFLOAT16 -o test_bgemv compare_sgemv_bgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) test_sbgemm : compare_sgemm_sbgemm.c test_helpers.h ../$(LIBNAME) - $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + $(CC) $(CLDFLAGS) -DIBFLOAT16 -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) 
test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) - $(CC) $(CLDFLAGS) -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + $(CC) $(CLDFLAGS) -DIBFLOAT16 -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif ifeq ($(BUILD_COMPLEX),1) @@ -444,7 +454,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_bgemv test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c index bc8a0b468..f18fe1201 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -34,15 +34,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define BGEMM BLASFUNC(bgemm) #define BGEMM_LARGEST 256 -static float truncate_float32_to_bfloat16(float value) { - blasint one = 1; - bfloat16 tmp; - float result; - sbstobf16_(&one, &value, &one, &tmp, &one); - sbf16tos_(&one, &tmp, &one, &result, &one); - return result; -} - int main (int argc, char *argv[]) { diff --git a/test/compare_sgemv_bgemv.c b/test/compare_sgemv_bgemv.c new file mode 100644 index 000000000..aac98760f --- /dev/null +++ b/test/compare_sgemv_bgemv.c @@ -0,0 +1,149 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include +#include +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMV BLASFUNC(sgemv) +#define BGEMV BLASFUNC(bgemv) +#define BGEMV_LARGEST 256
+/* Checks BGEMV against SGEMV and a float reference (DD) for square sizes 1..BGEMV_LARGEST; exits 0 when every element matches, 1 otherwise. */
+int main(int argc, char *argv[])
+{
+    blasint k;
+    int i, j, l;
+    blasint x, y;
+    blasint one = 1;
+    int ret = 0;
+    int loop = BGEMV_LARGEST;
+    char transA = 'N';
+    float alpha = 1.0, beta = 0.0;
+    bfloat16 alpha_bf16, beta_bf16;
+
+    for (beta = 0; beta < 3; beta += 1)
+    {
+        for (alpha = 0; alpha < 3; alpha += 1)
+        {
+            for (l = 0; l < 2; l++)
+            { // l = 1 to test inc_x & inc_y not equal to one.
+                for (x = 1; x <= loop; x++)
+                {
+                    k = (x == 0) ? 0 : l + 1; /* stride passed as both incx and incy */
+                    float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
+                    float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
+                    float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
+                    bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16));
+                    bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l);
+                    bfloat16 *CC = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l);
+                    float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
+                    if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
+                        (CC == NULL) || (DD == NULL))
+                        return 1;
+
+                    for (j = 0; j < x; j++)
+                    {
+                        for (i = 0; i < x; i++)
+                        {
+                            A[j * x + i] = ((FLOAT)rand() / (FLOAT)RAND_MAX) + 0.5;
+                            sbstobf16_(&one, &A[j * x + i], &one, &AA[j * x + i], &one);
+                        }
+                        B[j << l] = ((FLOAT)rand() / (FLOAT)RAND_MAX) + 0.5;
+                        sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one);
+
+                        C[j << l] = ((FLOAT)rand() / (FLOAT)RAND_MAX) + 0.5;
+                        sbstobf16_(&one, &C[j << l], &one, &CC[j << l], &one); /* CC mirrors C, not B */
+                    }
+
+                    for (y = 0; y < 2; y++)
+                    {
+                        if (y == 0)
+                        {
+                            transA = 'N';
+                        }
+                        else
+                        {
+                            transA = 'T';
+                        }
+                        /* NOTE(review): y is zeroed before every call, so beta * y is never exercised with non-zero y. */
+                        memset(C, 0, x * sizeof(FLOAT) << l);
+                        memset(CC, 0, x * sizeof(bfloat16) << l);
+                        memset(DD, 0, x * sizeof(FLOAT));
+
+                        sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one);
+                        sbstobf16_(&one, &beta, &one, &beta_bf16, &one);
+                        SGEMV(&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
+                        BGEMV(&transA, &x, &x, &alpha_bf16, AA, &x, BB, &k, &beta_bf16, CC, &k);
+
+                        for (i = 0; i < x; i++)
+                            DD[i] *= beta;
+                        /* DD: float reference built from the bfloat16-rounded inputs */
+                        for (j = 0; j < x; j++)
+                            for (i = 0; i < x; i++)
+                                if (transA == 'N')
+                                {
+                                    DD[i] += alpha * float16to32(AA[j * x + i]) * float16to32(BB[j << l]);
+                                }
+                                else if (transA == 'T')
+                                {
+                                    DD[j] += alpha * float16to32(AA[j * x + i]) * float16to32(BB[i << l]);
+                                }
+
+                        for (j = 0; j < x; j++)
+                        {
+                            if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(C[j << l]), 0.01, 0.001))
+                            {
+                                printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, x=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n",
+                                       transA, alpha, beta, (int)x, j, (int)k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(C[j << l]));
+                                ret++;
+                            }
+                            if (!is_close(float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]), 0.001, 0.0001))
+                            {
+                                printf("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, x=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n",
+                                       transA, alpha, beta, (int)x, j, (int)k, float16to32(CC[j << l]), truncate_float32_to_bfloat16(DD[j]));
+                                ret++;
+                            }
+                        }
+                    }
+
+                    free(A);
+                    free(B);
+                    free(C);
+                    free(AA);
+                    free(BB);
+                    free(CC);
+                    free(DD);
+                } // x
+            } // l
+        } // alpha
+    } // beta
+
+    if (ret != 0)
+        fprintf(stderr, "FATAL ERROR BGEMV - Return code: %d\n", ret);
+    return ret != 0; /* a raw count could wrap to exit status 0 at a multiple of 256 */
+}
diff --git a/test/compare_sgemv_sbgemv.c b/test/compare_sgemv_sbgemv.c index 5fa2d5f66..15cdce6cb 100644 --- a/test/compare_sgemv_sbgemv.c +++ b/test/compare_sgemv_sbgemv.c @@ -56,8 +56,8 @@ main (int argc, char *argv[]) float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); - float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; diff --git a/test/test_helpers.h b/test/test_helpers.h index 2bb3f7acd..fcec86e10 100644 --- a/test/test_helpers.h +++ b/test/test_helpers.h @@ -31,7 +31,7 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#include "../common.h" -#if IFLOAT == bfloat16 +#ifdef IBFLOAT16 static float float16to32(bfloat16 value) { blasint one = 1; @@ -41,6 +41,17 @@ static float float16to32(bfloat16 value) } #endif +#ifdef OBFLOAT16 +static float truncate_float32_to_bfloat16(float value) { /* returns value after a float -> bfloat16 -> float round trip through sbstobf16_ and sbf16tos_ */ + blasint one = 1; + bfloat16 tmp; + float result; + sbstobf16_(&one, &value, &one, &tmp, &one); + sbf16tos_(&one, &tmp, &one, &result, &one); + return result; +} +#endif + static void *malloc_safe(size_t size) { if (size == 0) return malloc(1);