| @@ -46,83 +46,50 @@ typedef union | |||
| } bits; | |||
| } bfloat16_bits; | |||
| typedef union | |||
| { | |||
| float v; | |||
| struct | |||
| { | |||
| uint32_t m:23; | |||
| uint32_t e:8; | |||
| uint32_t s:1; | |||
| } bits; | |||
| } float32_bits; | |||
| float | |||
| float16to32 (bfloat16_bits f16) | |||
| { | |||
| float32_bits f32; | |||
| f32.bits.s = f16.bits.s; | |||
| f32.bits.e = f16.bits.e; | |||
| f32.bits.m = (uint32_t) f16.bits.m << 16; | |||
| return f32.v; | |||
| } | |||
| int | |||
| main (int argc, char *argv[]) | |||
| { | |||
| int m, n, k; | |||
| int i, j, l; | |||
| int x; | |||
| int ret = 0; | |||
| int loop = 100; | |||
| char transA = 'N', transB = 'N'; | |||
| float alpha = 1.0, beta = 0.0; | |||
| char transa = 'N'; | |||
| char transb = 'N'; | |||
| for (int x = 0; x <= loop; x++) | |||
| for (x = 0; x <= loop; x++) | |||
| { | |||
| m = k = n = x; | |||
| float A[m * k]; | |||
| float B[k * n]; | |||
| float C[m * n]; | |||
| bfloat16_bits AA[m * k], BB[k * n]; | |||
| float DD[m * n], CC[m * n]; | |||
| float CC[m * n]; | |||
| for (int j = 0; j < m; j++) | |||
| for (j = 0; j < m; j++) | |||
| { | |||
| for (int i = 0; i < m; i++) | |||
| for (i = 0; i < m; i++) | |||
| { | |||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||
| A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; | |||
| B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; | |||
| C[j * k + i] = 0; | |||
| AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; | |||
| BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; | |||
| CC[j * k + i] = 0; | |||
| DD[j * k + i] = 0; | |||
| } | |||
| } | |||
| SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | |||
| &m, B, &k, &beta, C, &m); | |||
| &m, B, &k, &beta, C, &m); | |||
| SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, | |||
| &m, BB, &k, &beta, CC, &m); | |||
| &m, BB, &k, &beta, CC, &m); | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | |||
| ret++; | |||
| if (transA == 'N' && transB == 'N') | |||
| { | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| { | |||
| DD[i * m + j] += | |||
| float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); | |||
| } | |||
| for (i = 0; i < n; i++) | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| if (CC[i * m + j] != DD[i * m + j]) | |||
| ret++; | |||
| } | |||
| for (j = 0; j < m; j++) | |||
| for (l = 0; l < k; l++) | |||
| if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0) | |||
| ret++; | |||
| } | |||
| if (ret != 0) | |||
| fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); | |||