Browse Source

Redefined threading logic for Windows on Arm (WoA)

tags/v0.3.30
Harishmcw 7 months ago
parent
commit
030ae1fd97
3 changed files with 21 additions and 9 deletions
  1. +5
    -0
      interface/gemv.c
  2. +6
    -4
      interface/lapack/gesv.c
  3. +10
    -5
      interface/zgemv.c

+ 5
- 0
interface/gemv.c View File

@@ -79,6 +79,11 @@ static inline int get_gemv_optimal_nthreads_neoversev1(BLASLONG MN, int ncpu) {

static inline int get_gemv_optimal_nthreads(BLASLONG MN) {
int ncpu = num_cpu_avail(3);
#if defined(_WIN64) && defined(_M_ARM64)
if (MN > 100000000L)
return num_cpu_avail(4);
return 1;
#endif
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
return get_gemv_optimal_nthreads_neoversev1(MN, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)


+ 6
- 4
interface/lapack/gesv.c View File

@@ -117,13 +117,15 @@ int NAME(blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, blasint *ipiv,

#if defined(_WIN64) && defined(_M_ARM64)
#ifdef COMPLEX
if (args.m * args.n > 600)
if (args.m * args.n <= 300)
#else
if (args.m * args.n > 1000)
if (args.m * args.n <= 500)
#endif
args.nthreads = num_cpu_avail(4);
else
args.nthreads = 1;
else if (args.m * args.n <= 1000)
args.nthreads = 4;
else
args.nthreads = num_cpu_avail(4);
#else
#ifndef DOUBLE
if (args.m * args.n < 40000)


+ 10
- 5
interface/zgemv.c View File

@@ -252,25 +252,30 @@ void CNAME(enum CBLAS_ORDER order,

#ifdef SMP

if ( 1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD )
#if defined(_WIN64) && defined(_M_ARM64)
if (m*n > 25000000L)
nthreads = num_cpu_avail(4);
else
nthreads = 1;
#else
if (1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD)
nthreads = 1;
else
nthreads = num_cpu_avail(2);
#endif

if (nthreads == 1) {
#endif
#endif

(gemv[(int)trans])(m, n, 0, alpha_r, alpha_i, a, lda, x, incx, y, incy, buffer);

#ifdef SMP

} else {

(gemv_thread[(int)trans])(m, n, ALPHA, a, lda, x, incx, y, incy, buffer, nthreads);

}
#endif


STACK_FREE(buffer);

FUNCTION_PROFILE_END(4, m * n + m + n, 2 * m * n);


Loading…
Cancel
Save