Browse Source

Redefined threading logic for WoA (Windows on ARM64)

tags/v0.3.30
Harishmcw 7 months ago
parent
commit
030ae1fd97
3 changed files with 21 additions and 9 deletions
  1. +5
    -0
      interface/gemv.c
  2. +6
    -4
      interface/lapack/gesv.c
  3. +10
    -5
      interface/zgemv.c

+ 5
- 0
interface/gemv.c View File

@@ -79,6 +79,11 @@ static inline int get_gemv_optimal_nthreads_neoversev1(BLASLONG MN, int ncpu) {


static inline int get_gemv_optimal_nthreads(BLASLONG MN) { static inline int get_gemv_optimal_nthreads(BLASLONG MN) {
int ncpu = num_cpu_avail(3); int ncpu = num_cpu_avail(3);
#if defined(_WIN64) && defined(_M_ARM64)
if (MN > 100000000L)
return num_cpu_avail(4);
return 1;
#endif
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) #if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
return get_gemv_optimal_nthreads_neoversev1(MN, ncpu); return get_gemv_optimal_nthreads_neoversev1(MN, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) #elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)


+ 6
- 4
interface/lapack/gesv.c View File

@@ -117,13 +117,15 @@ int NAME(blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, blasint *ipiv,


#if defined(_WIN64) && defined(_M_ARM64) #if defined(_WIN64) && defined(_M_ARM64)
#ifdef COMPLEX #ifdef COMPLEX
if (args.m * args.n > 600)
if (args.m * args.n <= 300)
#else #else
if (args.m * args.n > 1000)
if (args.m * args.n <= 500)
#endif #endif
args.nthreads = num_cpu_avail(4);
else
args.nthreads = 1; args.nthreads = 1;
else if (args.m * args.n <= 1000)
args.nthreads = 4;
else
args.nthreads = num_cpu_avail(4);
#else #else
#ifndef DOUBLE #ifndef DOUBLE
if (args.m * args.n < 40000) if (args.m * args.n < 40000)


+ 10
- 5
interface/zgemv.c View File

@@ -252,25 +252,30 @@ void CNAME(enum CBLAS_ORDER order,


#ifdef SMP #ifdef SMP


if ( 1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD )
#if defined(_WIN64) && defined(_M_ARM64)
if (m*n > 25000000L)
nthreads = num_cpu_avail(4);
else
nthreads = 1;
#else
if (1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD)
nthreads = 1; nthreads = 1;
else else
nthreads = num_cpu_avail(2); nthreads = num_cpu_avail(2);
#endif


if (nthreads == 1) { if (nthreads == 1) {
#endif
#endif


(gemv[(int)trans])(m, n, 0, alpha_r, alpha_i, a, lda, x, incx, y, incy, buffer); (gemv[(int)trans])(m, n, 0, alpha_r, alpha_i, a, lda, x, incx, y, incy, buffer);


#ifdef SMP #ifdef SMP

} else { } else {

(gemv_thread[(int)trans])(m, n, ALPHA, a, lda, x, incx, y, incy, buffer, nthreads); (gemv_thread[(int)trans])(m, n, ALPHA, a, lda, x, incx, y, incy, buffer, nthreads);

} }
#endif #endif



STACK_FREE(buffer); STACK_FREE(buffer);


FUNCTION_PROFILE_END(4, m * n + m + n, 2 * m * n); FUNCTION_PROFILE_END(4, m * n + m + n, 2 * m * n);


Loading…
Cancel
Save