|
@@ -81,6 +81,8 @@ float16to32 (bfloat16_bits f16) |
|
|
return f32.v; |
|
|
return f32.v; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
#define SBGEMM_LARGEST 256 |
|
|
|
|
|
|
|
|
int |
|
|
int |
|
|
main (int argc, char *argv[]) |
|
|
main (int argc, char *argv[]) |
|
|
{ |
|
|
{ |
|
@@ -88,12 +90,39 @@ main (int argc, char *argv[]) |
|
|
int i, j, l; |
|
|
int i, j, l; |
|
|
blasint x, y; |
|
|
blasint x, y; |
|
|
int ret = 0; |
|
|
int ret = 0; |
|
|
int loop = 100; |
|
|
|
|
|
|
|
|
int loop = SBGEMM_LARGEST; |
|
|
char transA = 'N', transB = 'N'; |
|
|
char transA = 'N', transB = 'N'; |
|
|
float alpha = 1.0, beta = 0.0; |
|
|
float alpha = 1.0, beta = 0.0; |
|
|
|
|
|
|
|
|
for (x = 0; x <= loop; x++) |
|
|
for (x = 0; x <= loop; x++) |
|
|
{ |
|
|
{ |
|
|
|
|
|
if ((x > 100) && (x != SBGEMM_LARGEST)) continue; |
|
|
|
|
|
m = k = n = x; |
|
|
|
|
|
float *A = (float *)malloc(m * k * sizeof(FLOAT)); |
|
|
|
|
|
float *B = (float *)malloc(k * n * sizeof(FLOAT)); |
|
|
|
|
|
float *C = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
|
|
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits)); |
|
|
|
|
|
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits)); |
|
|
|
|
|
float *DD = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
|
|
float *CC = (float *)malloc(m * n * sizeof(FLOAT)); |
|
|
|
|
|
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || |
|
|
|
|
|
(DD == NULL) || (CC == NULL)) |
|
|
|
|
|
return 1; |
|
|
|
|
|
bfloat16 atmp,btmp; |
|
|
|
|
|
blasint one=1; |
|
|
|
|
|
|
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
{ |
|
|
|
|
|
for (i = 0; i < n; i++) |
|
|
|
|
|
{ |
|
|
|
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); |
|
|
|
|
|
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); |
|
|
|
|
|
AA[j * k + i].v = atmp; |
|
|
|
|
|
BB[j * k + i].v = btmp; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
for (y = 0; y < 4; y++) |
|
|
for (y = 0; y < 4; y++) |
|
|
{ |
|
|
{ |
|
|
if ((y == 0) || (y == 2)) { |
|
|
if ((y == 0) || (y == 2)) { |
|
@@ -106,34 +135,16 @@ main (int argc, char *argv[]) |
|
|
} else { |
|
|
} else { |
|
|
transB = 'T'; |
|
|
transB = 'T'; |
|
|
} |
|
|
} |
|
|
m = k = n = x; |
|
|
|
|
|
float A[m * k]; |
|
|
|
|
|
float B[k * n]; |
|
|
|
|
|
float C[m * n]; |
|
|
|
|
|
bfloat16_bits AA[m * k], BB[k * n]; |
|
|
|
|
|
float DD[m * n], CC[m * n]; |
|
|
|
|
|
bfloat16 atmp,btmp; |
|
|
|
|
|
blasint one=1; |
|
|
|
|
|
|
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
{ |
|
|
|
|
|
for (i = 0; i < m; i++) |
|
|
|
|
|
{ |
|
|
|
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
C[j * k + i] = 0; |
|
|
|
|
|
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one); |
|
|
|
|
|
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one); |
|
|
|
|
|
AA[j * k + i].v = atmp; |
|
|
|
|
|
BB[j * k + i].v = btmp; |
|
|
|
|
|
CC[j * k + i] = 0; |
|
|
|
|
|
DD[j * k + i] = 0; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
memset(CC, 0, m * n * sizeof(FLOAT)); |
|
|
|
|
|
memset(DD, 0, m * n * sizeof(FLOAT)); |
|
|
|
|
|
memset(C, 0, m * n * sizeof(FLOAT)); |
|
|
|
|
|
|
|
|
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, |
|
|
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, |
|
|
&m, B, &k, &beta, C, &m); |
|
|
&m, B, &k, &beta, C, &m); |
|
|
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA, |
|
|
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA, |
|
|
&m, (bfloat16*)BB, &k, &beta, CC, &m); |
|
|
&m, (bfloat16*)BB, &k, &beta, CC, &m); |
|
|
|
|
|
|
|
|
for (i = 0; i < n; i++) |
|
|
for (i = 0; i < n; i++) |
|
|
for (j = 0; j < m; j++) |
|
|
for (j = 0; j < m; j++) |
|
|
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) |
|
|
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) |
|
@@ -160,9 +171,16 @@ main (int argc, char *argv[]) |
|
|
} |
|
|
} |
|
|
for (i = 0; i < n; i++) |
|
|
for (i = 0; i < n; i++) |
|
|
for (j = 0; j < m; j++) |
|
|
for (j = 0; j < m; j++) |
|
|
if (CC[i * m + j] != DD[i * m + j]) |
|
|
|
|
|
|
|
|
if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0) |
|
|
ret++; |
|
|
ret++; |
|
|
} |
|
|
} |
|
|
|
|
|
free(A); |
|
|
|
|
|
free(B); |
|
|
|
|
|
free(C); |
|
|
|
|
|
free(AA); |
|
|
|
|
|
free(BB); |
|
|
|
|
|
free(DD); |
|
|
|
|
|
free(CC); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (ret != 0) |
|
|
if (ret != 0) |
|
|