|
|
@@ -28,16 +28,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
|
#include "common.h" |
|
|
|
|
|
|
|
#if defined(COOPERLAKE) |
|
|
|
#include "shdot_microk_cooperlake.c" |
|
|
|
#include "sbdot_microk_cooperlake.c" |
|
|
|
#endif |
|
|
|
|
|
|
|
static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) |
|
|
|
static float sbdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y) |
|
|
|
{ |
|
|
|
float d = 0.0; |
|
|
|
|
|
|
|
#ifdef HAVE_SHDOT_ACCL_KERNEL |
|
|
|
#ifdef HAVE_SBDOT_ACCL_KERNEL |
|
|
|
if ((inc_x == 1) && (inc_y == 1)) { |
|
|
|
return shdot_accl_kernel(n, x, y); |
|
|
|
return sbdot_accl_kernel(n, x, y); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
@@ -56,11 +56,11 @@ static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, |
|
|
|
} |
|
|
|
|
|
|
|
#if defined(SMP) |
|
|
|
static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2, |
|
|
|
static int sbdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2, |
|
|
|
bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y, |
|
|
|
float *result, BLASLONG dummy3) |
|
|
|
{ |
|
|
|
*(float *)result = shdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
*(float *)result = sbdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
@@ -94,13 +94,13 @@ float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y |
|
|
|
} |
|
|
|
|
|
|
|
if (nthreads <= 1) { |
|
|
|
dot_result = shdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
dot_result = sbdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
} else { |
|
|
|
char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2]; |
|
|
|
int mode = BLAS_BFLOAT16 | BLAS_REAL; |
|
|
|
blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha, |
|
|
|
x, inc_x, y, inc_y, thread_result, 0, |
|
|
|
(void *)shdot_thread_func, nthreads); |
|
|
|
(void *)sbdot_thread_func, nthreads); |
|
|
|
float * ptr = (float *)thread_result; |
|
|
|
for (int i = 0; i < nthreads; i++) { |
|
|
|
dot_result += (*ptr); |
|
|
@@ -108,7 +108,7 @@ float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y |
|
|
|
} |
|
|
|
} |
|
|
|
#else |
|
|
|
dot_result = shdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
dot_result = sbdot_compute(n, x, inc_x, y, inc_y); |
|
|
|
#endif |
|
|
|
|
|
|
|
return dot_result; |
|
|
|