Browse Source

Merge pull request #5276 from nakagawa-fj/gemm_2d_thread_partitioning

Improvement of 2D thread-partitioned GEMM for M << N case
tags/v0.3.30
Martin Kroeker GitHub 4 months ago
parent
commit
e2e6a4d90a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed file with 13 additions and 3 deletions
  1. +13
    -3
      driver/level3/level3_thread.c

+ 13
- 3
driver/level3/level3_thread.c View File

@@ -851,9 +851,19 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IF
/* The objective function comes from the sum of the partitions in m and n. */
/* (n / nthreads_n) + (m / nthreads_m) */
/* = (n * nthreads_m + m * nthreads_n) / (nthreads_n * nthreads_m) */
while (nthreads_m % 2 == 0 && n * nthreads_m + m * nthreads_n > n * (nthreads_m / 2) + m * (nthreads_n * 2)) {
nthreads_m /= 2;
nthreads_n *= 2;
BLASLONG cost = 0, div = 0;
for (BLASLONG i = 1; i <= sqrt(nthreads_m); i++) {
if (nthreads_m % i) continue;
BLASLONG j = nthreads_m / i;
BLASLONG cost_i = n * j + m * nthreads_n * i;
BLASLONG cost_j = n * i + m * nthreads_n * j;
if (cost == 0 ||
cost_i < cost) {cost = cost_i; div = i;}
if (cost_j < cost) {cost = cost_j; div = j;}
}
if (div > 1) {
nthreads_m /= div;
nthreads_n *= div;
}
}



Loading…
Cancel
Save