Browse Source

Use correct constants for per-target BGEMM/SBGEMM

This fixes the build and tests on the `NEOVERSEV1` target, which were failing
because the `SBGEMM`-specific constants were used unconditionally, even for `BGEMM`.

Co-authored-by: Ye Tao <ye.tao@arm.com>
pull/5357/head
Chris Sidebottom 2 months ago
parent
commit
48394384ef
2 changed files with 34 additions and 10 deletions
  1. +17
    -5
      driver/level3/level3.c
  2. +17
    -5
      driver/level3/level3_thread.c

+ 17
- 5
driver/level3/level3.c View File

@@ -170,6 +170,22 @@
#define STOP_RPCC(COUNTER)
#endif

/*
 * BFLOAT16_ALIGN_K: alignment value used when padding min_l into pad_min_l
 * in the BFLOAT16 code path of this file.
 *
 * Only defined when BUILD_BFLOAT16 is set. The value is chosen along two axes:
 *   - DYNAMIC_ARCH defined:     read from the gotoblas table at run time;
 *     DYNAMIC_ARCH not defined: use the compile-time *_ALIGN_K constant.
 *   - BGEMM defined:     use the bgemm variant;
 *     BGEMM not defined: use the sbgemm variant.
 */
#ifdef BUILD_BFLOAT16
#  if defined(DYNAMIC_ARCH) && defined(BGEMM)
#    define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k
#  elif defined(DYNAMIC_ARCH)
#    define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k
#  elif defined(BGEMM)
#    define BFLOAT16_ALIGN_K BGEMM_ALIGN_K
#  else
#    define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K
#  endif
#endif /* BUILD_BFLOAT16 */

int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
BLASLONG k, lda, ldb, ldc;
@@ -307,11 +323,7 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,

BLASLONG pad_min_l = min_l;
#if defined(BFLOAT16)
#if defined(DYNAMIC_ARCH)
pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
#else
pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);;
#endif
pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1);
#endif

/* First, we have to move data A to L2 cache */


+ 17
- 5
driver/level3/level3_thread.c View File

@@ -216,6 +216,22 @@ typedef struct {
#define STOP_RPCC(COUNTER)
#endif

/*
 * BFLOAT16_ALIGN_K: alignment value used when padding min_l into pad_min_l
 * in the BFLOAT16 code path of inner_thread().
 *
 * Only defined when BUILD_BFLOAT16 is set. The value is chosen along two axes:
 *   - DYNAMIC_ARCH defined:     read from the gotoblas table at run time;
 *     DYNAMIC_ARCH not defined: use the compile-time *_ALIGN_K constant.
 *   - BGEMM defined:     use the bgemm variant;
 *     BGEMM not defined: use the sbgemm variant.
 */
#ifdef BUILD_BFLOAT16
#  if defined(DYNAMIC_ARCH) && defined(BGEMM)
#    define BFLOAT16_ALIGN_K gotoblas->bgemm_align_k
#  elif defined(DYNAMIC_ARCH)
#    define BFLOAT16_ALIGN_K gotoblas->sbgemm_align_k
#  elif defined(BGEMM)
#    define BFLOAT16_ALIGN_K BGEMM_ALIGN_K
#  else
#    define BFLOAT16_ALIGN_K SBGEMM_ALIGN_K
#  endif
#endif /* BUILD_BFLOAT16 */

static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){

IFLOAT *buffer[DIVIDE_RATE];
@@ -325,11 +341,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
BLASLONG pad_min_l = min_l;

#if defined(BFLOAT16)
#if defined(DYNAMIC_ARCH)
pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1);
#else
pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);;
#endif
pad_min_l = (min_l + BFLOAT16_ALIGN_K - 1) & ~(BFLOAT16_ALIGN_K - 1);
#endif

/* Determine step size in m


Loading…
Cancel
Save