From 3f110c827294b54ba7438b1763584c7c01b1ee64 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Sun, 13 Jul 2025 12:48:09 +0000 Subject: [PATCH] Improve bgemm and sbgemm testing - Fixes wrong return type for `is_close` - Adds stricter compiler flags for test files so we don't see the above issue again - Re-uses test helper functions between compare_sgemm_sbgemm/bgemm.c --- test/Makefile | 7 +-- test/compare_sgemm_bgemm.c | 30 +++--------- test/compare_sgemm_sbgemm.c | 97 ++++++++----------------------------- test/test_helpers.h | 55 +++++++++++++++++++++ 4 files changed, 84 insertions(+), 105 deletions(-) create mode 100644 test/test_helpers.h diff --git a/test/Makefile b/test/Makefile index 7ac87f7f6..cd8006c04 100644 --- a/test/Makefile +++ b/test/Makefile @@ -13,7 +13,7 @@ # 3. Neither the name of the OpenBLAS project nor the names of # its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. -# +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE @@ -34,6 +34,7 @@ ifneq (, $(filter $(CORE),LOONGSON3R3 LOONGSON3R4)) endif override FFLAGS += -fno-tree-vectorize endif +override CFLAGS += -std=c11 -Wall -Werror SUPPORT_GEMM3M = 0 @@ -402,10 +403,10 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) endif ifeq ($(BUILD_BFLOAT16),1) -test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) +test_bgemm : compare_sgemm_bgemm.c test_helpers.h ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) -test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) +test_sbgemm : compare_sgemm_sbgemm.c test_helpers.h ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c 
index 69210b98e..8ece63841 100644 --- a/test/compare_sgemm_bgemm.c +++ b/test/compare_sgemm_bgemm.c @@ -28,20 +28,13 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include +#include "test_helpers.h" #define SGEMM BLASFUNC(sgemm) #define BGEMM BLASFUNC(bgemm) #define BGEMM_LARGEST 256 -static float float16to32(bfloat16 value) -{ - blasint one = 1; - float result; - sbf16tos_(&one, &value, &one, &result, &one); - return result; -} - -static float truncate_float(float value) { +static float truncate_float32_to_bfloat16(float value) { blasint one = 1; bfloat16 tmp; float result; @@ -50,17 +43,6 @@ static float truncate_float(float value) { return result; } -static void *malloc_safe(size_t size) { - if (size == 0) - return malloc(1); - else - return malloc(size); -} - -static float is_close(float a, float b, float rtol, float atol) { - return fabs(a - b) <= (atol + rtol*fabs(b)); -} - int main (int argc, char *argv[]) { @@ -151,15 +133,15 @@ main (int argc, char *argv[]) DD[i * m + j] += float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); } - if (!is_close(float16to32(CC[i * m + j]), truncate_float(C[i * m + j]), 0.01, 0.001)) { + if (!is_close(float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(C[i * m + j]), 0.01, 0.001)) { printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n", - i, j, k, float16to32(CC[i * m + j]), truncate_float(C[i * m + j])); + i, j, k, float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(C[i * m + j])); ret++; } - if (!is_close(float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]), 0.0001, 0.00001)) { + if (!is_close(float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(DD[i * m + j]), 0.0001, 0.00001)) { printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n", - i, j, k, float16to32(CC[i * m + j]), truncate_float(DD[i * m + j])); + i, j, k, float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(DD[i * m + j])); ret++; } diff --git a/test/compare_sgemm_sbgemm.c 
b/test/compare_sgemm_sbgemm.c index ae109c1a5..4fa24b9ce 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -27,72 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include "../common.h" + +#include "test_helpers.h" + #define SGEMM BLASFUNC(sgemm) #define SBGEMM BLASFUNC(sbgemm) #define SGEMV BLASFUNC(sgemv) #define SBGEMV BLASFUNC(sbgemv) -typedef union -{ - unsigned short v; -#if defined(_AIX) - struct __attribute__((packed)) -#else - struct -#endif - { -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - unsigned short s:1; - unsigned short e:8; - unsigned short m:7; -#else - unsigned short m:7; - unsigned short e:8; - unsigned short s:1; -#endif - } bits; -} bfloat16_bits; - -typedef union -{ - float v; -#if defined(_AIX) - struct __attribute__((packed)) -#else - struct -#endif - { -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - uint32_t s:1; - uint32_t e:8; - uint32_t m:23; -#else - uint32_t m:23; - uint32_t e:8; - uint32_t s:1; -#endif - } bits; -} float32_bits; - -float -float16to32 (bfloat16_bits f16) -{ - float32_bits f32; - f32.bits.s = f16.bits.s; - f32.bits.e = f16.bits.e; - f32.bits.m = (uint32_t) f16.bits.m << 16; - return f32.v; -} - #define SBGEMM_LARGEST 256 -void *malloc_safe(size_t size) -{ - if (size == 0) - return malloc(1); - else - return malloc(size); -} - int main (int argc, char *argv[]) { @@ -111,14 +54,13 @@ main (int argc, char *argv[]) float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); - bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); + bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); + bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); float *CC = 
(float *)malloc_safe(m * n * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; - bfloat16 atmp,btmp; blasint one=1; for (j = 0; j < m; j++) @@ -126,8 +68,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); - AA[j * k + i].v = atmp; + sbstobf16_(&one, &A[j*k+i], &one, &AA[j * k + i], &one); } } for (j = 0; j < n; j++) @@ -135,8 +76,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); - BB[j * k + i].v = btmp; + sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one); } } for (y = 0; y < 4; y++) @@ -182,10 +122,12 @@ main (int argc, char *argv[]) DD[i * m + j] += float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]); } - if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { ret++; - if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0) + } + if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { ret++; + } } } free(A); @@ -211,14 +153,13 @@ main (int argc, char *argv[]) float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); - bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); + bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); + bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; - bfloat16 atmp, btmp; blasint 
one = 1; for (j = 0; j < x; j++) @@ -226,12 +167,10 @@ main (int argc, char *argv[]) for (i = 0; i < x; i++) { A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); - AA[j * x + i].v = atmp; + sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one); } B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j << l], &one, &btmp, &one); - BB[j << l].v = btmp; + sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one); CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; } @@ -262,10 +201,12 @@ main (int argc, char *argv[]) } for (j = 0; j < x; j++) { - if (fabs (CC[j << l] - C[j << l]) > 1.0) + if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { ret++; - if (fabs (CC[j << l] - DD[j]) > 1.0) + } + if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { ret++; + } } } free(A); diff --git a/test/test_helpers.h b/test/test_helpers.h new file mode 100644 index 000000000..2bb3f7acd --- /dev/null +++ b/test/test_helpers.h @@ -0,0 +1,55 @@ +/*************************************************************************** +Copyright (c) 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#ifndef TEST_HELPERS_H
+#define TEST_HELPERS_H
+#include <stdbool.h> /* bool, the return type of is_close(); malloc()/fabs() are presumably declared via ../common.h or the including .c file - TODO confirm */
+
+#include "../common.h"
+
+#if IFLOAT == bfloat16 /* NOTE(review): #if replaces non-macro identifiers with 0, so unless these names expand to integer constants this is 0 == 0 and always true - confirm intent */
+static float float16to32(bfloat16 value) /* returns the float that sbf16tos_ writes for the single input `value` */
+{
+  blasint one = 1;
+  float result;
+  sbf16tos_(&one, &value, &one, &result, &one);
+  return result;
+}
+#endif
+
+static void *malloc_safe(size_t size) { /* malloc() wrapper that requests 1 byte when size is 0; result may still be NULL */
+  if (size == 0)
+    return malloc(1);
+  else
+    return malloc(size);
+}
+
+static bool is_close(float a, float b, float rtol, float atol) { /* true when |a - b| <= atol + rtol * |b|; the tolerance scales with b only, so the test is not symmetric in a and b */
+  return fabs(a - b) <= (atol + rtol*fabs(b));
+}
+
+#endif