| @@ -46,83 +46,50 @@ typedef union | |||||
| } bits; | } bits; | ||||
| } bfloat16_bits; | } bfloat16_bits; | ||||
| typedef union | |||||
| { | |||||
| float v; | |||||
| struct | |||||
| { | |||||
| uint32_t m:23; | |||||
| uint32_t e:8; | |||||
| uint32_t s:1; | |||||
| } bits; | |||||
| } float32_bits; | |||||
| float | |||||
| float16to32 (bfloat16_bits f16) | |||||
| { | |||||
| float32_bits f32; | |||||
| f32.bits.s = f16.bits.s; | |||||
| f32.bits.e = f16.bits.e; | |||||
| f32.bits.m = (uint32_t) f16.bits.m << 16; | |||||
| return f32.v; | |||||
| } | |||||
| int | int | ||||
| main (int argc, char *argv[]) | main (int argc, char *argv[]) | ||||
| { | { | ||||
| int m, n, k; | int m, n, k; | ||||
| int i, j, l; | int i, j, l; | ||||
| int x; | |||||
| int ret = 0; | int ret = 0; | ||||
| int loop = 100; | int loop = 100; | ||||
| char transA = 'N', transB = 'N'; | char transA = 'N', transB = 'N'; | ||||
| float alpha = 1.0, beta = 0.0; | float alpha = 1.0, beta = 0.0; | ||||
| char transa = 'N'; | |||||
| char transb = 'N'; | |||||
| for (int x = 0; x <= loop; x++) | |||||
| for (x = 0; x <= loop; x++) | |||||
| { | { | ||||
| m = k = n = x; | m = k = n = x; | ||||
| float A[m * k]; | float A[m * k]; | ||||
| float B[k * n]; | float B[k * n]; | ||||
| float C[m * n]; | float C[m * n]; | ||||
| bfloat16_bits AA[m * k], BB[k * n]; | bfloat16_bits AA[m * k], BB[k * n]; | ||||
| float DD[m * n], CC[m * n]; | |||||
| float CC[m * n]; | |||||
| for (int j = 0; j < m; j++) | |||||
| for (j = 0; j < m; j++) | |||||
| { | { | ||||
| for (int i = 0; i < m; i++) | |||||
| for (i = 0; i < m; i++) | |||||
| { | { | ||||
| A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||||
| B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; | |||||
| A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; | |||||
| B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5; | |||||
| C[j * k + i] = 0; | C[j * k + i] = 0; | ||||
| AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; | AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16; | ||||
| BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; | BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16; | ||||
| CC[j * k + i] = 0; | CC[j * k + i] = 0; | ||||
| DD[j * k + i] = 0; | |||||
| } | } | ||||
| } | } | ||||
| SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, | ||||
| &m, B, &k, &beta, C, &m); | |||||
| &m, B, &k, &beta, C, &m); | |||||
| SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, | SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, | ||||
| &m, BB, &k, &beta, CC, &m); | |||||
| &m, BB, &k, &beta, CC, &m); | |||||
| for (i = 0; i < n; i++) | for (i = 0; i < n; i++) | ||||
| for (j = 0; j < m; j++) | |||||
| for (l = 0; l < k; l++) | |||||
| if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0) | |||||
| ret++; | |||||
| if (transA == 'N' && transB == 'N') | |||||
| { | |||||
| for (i = 0; i < n; i++) | |||||
| for (j = 0; j < m; j++) | |||||
| for (l = 0; l < k; l++) | |||||
| { | |||||
| DD[i * m + j] += | |||||
| float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]); | |||||
| } | |||||
| for (i = 0; i < n; i++) | |||||
| for (j = 0; j < m; j++) | |||||
| for (l = 0; l < k; l++) | |||||
| if (CC[i * m + j] != DD[i * m + j]) | |||||
| ret++; | |||||
| } | |||||
| for (j = 0; j < m; j++) | |||||
| for (l = 0; l < k; l++) | |||||
| if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0) | |||||
| ret++; | |||||
| } | } | ||||
| if (ret != 0) | if (ret != 0) | ||||
| fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); | fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret); | ||||