Browse Source

Fix race conditions in multithreaded GEMM3M

by adding barriers (and a mutex lock for the non-OpenMP case) like it was already done for GEMM in level3_thread.c some time ago
tags/v0.3.8^2
Martin Kroeker GitHub 6 years ago
parent
commit
f3065a0eed
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 11 deletions
  1. +42
    -11
      driver/level3/level3_gemm3m_thread.c

+ 42
- 11
driver/level3/level3_gemm3m_thread.c View File

@@ -408,7 +408,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


/* Make sure if no one is using another buffer */ /* Make sure if no one is using another buffer */
for (i = 0; i < args -> nthreads; i++) for (i = 0; i < args -> nthreads; i++)
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};


STOP_RPCC(waiting1); STOP_RPCC(waiting1);


@@ -441,7 +441,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


for (i = 0; i < args -> nthreads; i++) for (i = 0; i < args -> nthreads; i++)
job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside]; job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside];
}
WMB;
}


current = mypos; current = mypos;


@@ -458,7 +459,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
START_RPCC(); START_RPCC();


/* thread has to wait */ /* thread has to wait */
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};


STOP_RPCC(waiting2); STOP_RPCC(waiting2);


@@ -477,6 +478,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


if (m_to - m_from == min_i) { if (m_to - m_from == min_i) {
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0; job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
WMB;
} }
} }
} while (current != mypos); } while (current != mypos);
@@ -517,6 +519,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
if (is + min_i >= m_to) { if (is + min_i >= m_to) {
/* Thread doesn't need this buffer any more */ /* Thread doesn't need this buffer any more */
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0; job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
WMB;
} }
} }


@@ -541,7 +544,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


/* Make sure if no one is using another buffer */ /* Make sure if no one is using another buffer */
for (i = 0; i < args -> nthreads; i++) for (i = 0; i < args -> nthreads; i++)
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};


STOP_RPCC(waiting1); STOP_RPCC(waiting1);


@@ -595,7 +598,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
START_RPCC(); START_RPCC();


/* thread has to wait */ /* thread has to wait */
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};


STOP_RPCC(waiting2); STOP_RPCC(waiting2);


@@ -613,6 +616,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


if (m_to - m_from == min_i) { if (m_to - m_from == min_i) {
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0; job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
WMB;
} }
} }
} while (current != mypos); } while (current != mypos);
@@ -677,7 +681,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


/* Make sure if no one is using another buffer */ /* Make sure if no one is using another buffer */
for (i = 0; i < args -> nthreads; i++) for (i = 0; i < args -> nthreads; i++)
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};


STOP_RPCC(waiting1); STOP_RPCC(waiting1);


@@ -731,7 +735,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
START_RPCC(); START_RPCC();


/* thread has to wait */ /* thread has to wait */
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};


STOP_RPCC(waiting2); STOP_RPCC(waiting2);


@@ -748,8 +752,9 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
} }


if (m_to - m_from == min_i) { if (m_to - m_from == min_i) {
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
}
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
WMB;
}
} }
} while (current != mypos); } while (current != mypos);


@@ -787,7 +792,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
#endif #endif
if (is + min_i >= m_to) { if (is + min_i >= m_to) {
/* Thread doesn't need this buffer any more */ /* Thread doesn't need this buffer any more */
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
WMB;
} }
} }


@@ -804,7 +810,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,


for (i = 0; i < args -> nthreads; i++) { for (i = 0; i < args -> nthreads; i++) {
for (xxx = 0; xxx < DIVIDE_RATE; xxx++) { for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;};
while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;MB;};
} }
} }


@@ -840,6 +846,15 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
*range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){ *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){


#ifndef USE_OPENMP
#ifndef OS_WINDOWS
static pthread_mutex_t level3_lock = PTHREAD_MUTEX_INITIALIZER;
#else
CRITICAL_SECTION level3_lock;
InitializeCriticalSection((PCRITICAL_SECTION)&level3_lock);
#endif
#endif

blas_arg_t newarg; blas_arg_t newarg;


blas_queue_t queue[MAX_CPU_NUMBER]; blas_queue_t queue[MAX_CPU_NUMBER];
@@ -869,6 +884,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
mode = BLAS_SINGLE | BLAS_REAL | BLAS_NODE; mode = BLAS_SINGLE | BLAS_REAL | BLAS_NODE;
#endif #endif


#ifndef USE_OPENMP
#ifndef OS_WINDOWS
pthread_mutex_lock(&level3_lock);
#else
EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
#endif
#endif

newarg.m = args -> m; newarg.m = args -> m;
newarg.n = args -> n; newarg.n = args -> n;
newarg.k = args -> k; newarg.k = args -> k;
@@ -973,6 +996,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
free(job); free(job);
#endif #endif


#ifndef USE_OPENMP
#ifndef OS_WINDOWS
pthread_mutex_unlock(&level3_lock);
#else
LeaveCriticalSection((PCRITICAL_SECTION)&level3_lock);
#endif
#endif

return 0; return 0;
} }




Loading…
Cancel
Save