Browse Source

optimize gemv forwarding on ARM64 systems

tags/v0.3.29
Chris Daley 11 months ago
parent
commit
cb48505251
2 changed files with 23 additions and 4 deletions
  1. +3
    -0
      CONTRIBUTORS.md
  2. +20
    -4
      interface/gemm.c

+ 3
- 0
CONTRIBUTORS.md View File

@@ -226,3 +226,6 @@ In chronological order:


* Dirreke <https://github.com/mseminatore> * Dirreke <https://github.com/mseminatore>
* [2024-01-16] Add basic support for the CSKY architecture * [2024-01-16] Add basic support for the CSKY architecture

* Christopher Daley <https://github.com/cdaley>
* [2024-01-24] Optimize GEMV forwarding on ARM64 systems

+ 20
- 4
interface/gemm.c View File

@@ -39,6 +39,7 @@


#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h>
#include "common.h" #include "common.h"
#ifdef FUNCTION_PROFILE #ifdef FUNCTION_PROFILE
#include "functable.h" #include "functable.h"
@@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
#endif #endif


#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) #if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
#if defined(ARCH_ARM64)
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
// perform poorly for some combinations of transpose mode and vector stride.
// The boolean below is combined with the gemv argument values so that GEMM
// is only forwarded to GEMV in the efficient cases; see GitHub issue #4951.
bool have_tuned_gemv = false;
#else
bool have_tuned_gemv = true;
#endif
// Check if we can convert GEMM -> GEMV // Check if we can convert GEMM -> GEMV
if (args.k != 0) { if (args.k != 0) {
if (args.n == 1) { if (args.n == 1) {
@@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
if (transb & 1) { if (transb & 1) {
inc_x = args.ldb; inc_x = args.ldb;
} }
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
if (is_efficient_gemv) {
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
}
} }
if (args.m == 1) { if (args.m == 1) {
blasint inc_x = args.lda; blasint inc_x = args.lda;
@@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
m = args.n; m = args.n;
n = args.k; n = args.k;
} }
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
if (is_efficient_gemv) {
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
}
} }
} }
#endif #endif


Loading…
Cancel
Save