|
@@ -46,83 +46,50 @@ typedef union |
|
|
} bits; |
|
|
} bits; |
|
|
} bfloat16_bits; |
|
|
} bfloat16_bits; |
|
|
|
|
|
|
|
|
typedef union |
|
|
|
|
|
{ |
|
|
|
|
|
float v; |
|
|
|
|
|
struct |
|
|
|
|
|
{ |
|
|
|
|
|
uint32_t m:23; |
|
|
|
|
|
uint32_t e:8; |
|
|
|
|
|
uint32_t s:1; |
|
|
|
|
|
} bits; |
|
|
|
|
|
} float32_bits; |
|
|
|
|
|
|
|
|
|
|
|
float |
|
|
|
|
|
float16to32 (bfloat16_bits f16) |
|
|
|
|
|
{ |
|
|
|
|
|
float32_bits f32; |
|
|
|
|
|
f32.bits.s = f16.bits.s; |
|
|
|
|
|
f32.bits.e = f16.bits.e; |
|
|
|
|
|
f32.bits.m = (uint32_t) f16.bits.m << 16; |
|
|
|
|
|
return f32.v; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int |
|
|
int |
|
|
main (int argc, char *argv[]) |
|
|
main (int argc, char *argv[]) |
|
|
{ |
|
|
{ |
|
|
int m, n, k; |
|
|
int m, n, k; |
|
|
int i, j, l; |
|
|
int i, j, l; |
|
|
|
|
|
int x; |
|
|
int ret = 0; |
|
|
int ret = 0; |
|
|
int loop = 100; |
|
|
int loop = 100; |
|
|
char transA = 'N', transB = 'N'; |
|
|
char transA = 'N', transB = 'N'; |
|
|
float alpha = 1.0, beta = 0.0; |
|
|
float alpha = 1.0, beta = 0.0; |
|
|
|
|
|
char transa = 'N'; |
|
|
|
|
|
char transb = 'N'; |
|
|
|
|
|
|
|
|
for (int x = 0; x <= loop; x++) |
|
|
|
|
|
|
|
|
for (x = 0; x <= loop; x++) |
|
|
{ |
|
|
{ |
|
|
m = k = n = x; |
|
|
m = k = n = x; |
|
|
float A[m * k]; |
|
|
float A[m * k]; |
|
|
float B[k * n]; |
|
|
float B[k * n]; |
|
|
float C[m * n]; |
|
|
float C[m * n]; |
|
|
bfloat16_bits AA[m * k], BB[k * n]; |
|
|
bfloat16_bits AA[m * k], BB[k * n]; |
|
|
float DD[m * n], CC[m * n]; |
|
|
|
|
|
|
|
|
float CC[m * n]; |
|
|
|
|
|
|
|
|
for (int j = 0; j < m; j++) |
|
|
|
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
{ |
|
|
{ |
|
|
for (int i = 0; i < m; i++) |
|
|
|
|
|
|
|
|
for (i = 0; i < m; i++) |
|
|
{ |
|
|
{ |
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
|
|
|
A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; |
|
|
|
|
|
B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; |
|
|
C[j * k + i] = 0; |
|
|
C[j * k + i] = 0; |
|
|
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; |
|
|
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; |
|
|
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; |
|
|
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; |
|
|
CC[j * k + i] = 0; |
|
|
CC[j * k + i] = 0; |
|
|
DD[j * k + i] = 0; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, |
|
|
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, |
|
|
&m, B, &k, &beta, C, &m); |
|
|
|
|
|
|
|
|
&m, B, &k, &beta, C, &m); |
|
|
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, |
|
|
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, |
|
|
&m, BB, &k, &beta, CC, &m); |
|
|
|
|
|
|
|
|
&m, BB, &k, &beta, CC, &m); |
|
|
|
|
|
|
|
|
for (i = 0; i < n; i++) |
|
|
for (i = 0; i < n; i++) |
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
for (l = 0; l < k; l++) |
|
|
|
|
|
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) |
|
|
|
|
|
ret++; |
|
|
|
|
|
if (transA == 'N' && transB == 'N') |
|
|
|
|
|
{ |
|
|
|
|
|
for (i = 0; i < n; i++) |
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
for (l = 0; l < k; l++) |
|
|
|
|
|
{ |
|
|
|
|
|
DD[i * m + j] += |
|
|
|
|
|
float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); |
|
|
|
|
|
} |
|
|
|
|
|
for (i = 0; i < n; i++) |
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
for (l = 0; l < k; l++) |
|
|
|
|
|
if (CC[i * m + j] != DD[i * m + j]) |
|
|
|
|
|
ret++; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
for (j = 0; j < m; j++) |
|
|
|
|
|
for (l = 0; l < k; l++) |
|
|
|
|
|
if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0) |
|
|
|
|
|
ret++; |
|
|
} |
|
|
} |
|
|
if (ret != 0) |
|
|
if (ret != 0) |
|
|
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); |
|
|
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); |
|
|