|
|
@@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16) |
|
|
|
|
|
|
|
/* The one value of x above 100 that the `continue` guard in main() does not
   skip; x is used there as the m, k and n dimensions.  */
#define SBGEMM_LARGEST 256 |
|
|
|
|
|
|
|
/*
 * Allocate SIZE bytes with malloc(), asking for a single byte instead when
 * SIZE is zero.
 *
 * The C standard lets malloc(0) return NULL, and the callers in this file
 * treat a NULL result as an allocation failure.  Bumping a zero request to
 * one byte keeps empty problem sizes from tripping that check.
 *
 * Returns whatever malloc() returns for the adjusted size (NULL on failure);
 * the caller owns the memory and releases it with free().
 */
void *malloc_safe(size_t size)
{
  /* Never hand a zero-byte request to malloc().  */
  size_t request = (size != 0) ? size : 1;

  return malloc(request);
}
|
|
|
|
|
|
|
int |
|
|
|
main (int argc, char *argv[]) |
|
|
|
{ |
|
|
@@ -100,13 +108,13 @@ main (int argc, char *argv[]) |
|
|
|
{ |
|
|
|
if ((x > 100) && (x != SBGEMM_LARGEST)) continue; |
|
|
|
m = k = n = x; |
|
|
|
float *A = (float *)malloc(m * k * sizeof(FLOAT)); |
|
|
|
float *B = (float *)malloc(k * n * sizeof(FLOAT)); |
|
|
|
float *C = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits)); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits)); |
|
|
|
float *DD = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
float *CC = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); |
|
|
|
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); |
|
|
|
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); |
|
|
|
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
|
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); |
|
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
|
(DD == NULL) || (CC == NULL)) |
|
|
|
return 1; |
|
|
@@ -194,16 +202,16 @@ main (int argc, char *argv[]) |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
k = 1; |
|
|
|
for (x = 1; x <= loop; x++) |
|
|
|
{ |
|
|
|
float *A = (float *)malloc(x * x * sizeof(FLOAT)); |
|
|
|
float *B = (float *)malloc(x * sizeof(FLOAT)); |
|
|
|
float *C = (float *)malloc(x * sizeof(FLOAT)); |
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits)); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits)); |
|
|
|
float *DD = (float *)malloc(x * sizeof(FLOAT)); |
|
|
|
float *CC = (float *)malloc(x * sizeof(FLOAT)); |
|
|
|
k = (x == 0) ? 0 : 1; |
|
|
|
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); |
|
|
|
float *B = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
|
float *C = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); |
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); |
|
|
|
float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
|
float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); |
|
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
|
(DD == NULL) || (CC == NULL)) |
|
|
|
return 1; |
|
|
|