diff --git a/driver/level3/level3.c b/driver/level3/level3.c index 5d3438450..78bc6aa52 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -170,6 +170,22 @@ #define STOP_RPCC(COUNTER) #endif +#if defined(BUILD_BFLOAT16) +#if defined(DYNAMIC_ARCH) + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k + #else + #define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k + #endif +#else + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K BGEMM_ALIGN_K + #else + #define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K + #endif +#endif +#endif + int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; @@ -307,11 +323,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG pad_min_l = min_l; #if defined(BFLOAT16) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1); #endif /* First, we have to move data A to L2 cache */ diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index 5ede6153e..cb93591ab 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -216,6 +216,22 @@ typedef struct { #define STOP_RPCC(COUNTER) #endif +#if defined(BUILD_BFLOAT16) +#if defined(DYNAMIC_ARCH) + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k + #else + #define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k + #endif +#else + #if defined(BGEMM) + #define BFLOAT16_ALIGN_K BGEMM_ALIGN_K + #else + #define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K + #endif +#endif +#endif + static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ IFLOAT *buffer[DIVIDE_RATE]; @@ -325,11 +341,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG pad_min_l = min_l; #if defined(BFLOAT16) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1); #endif /* Determine step size in m