@@ -47,27 +47,59 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) | |||
if ( (inc_x == 1) && (inc_y == 1) ) | |||
{ | |||
int n1 = n & -4; | |||
while(i < n1) | |||
#if V_SIMD && !defined(DSDOT) | |||
const int vstep = v_nlanes_f32; | |||
const int unrollx4 = n & (-vstep * 4); | |||
const int unrollx = n & -vstep; | |||
v_f32 vsum0 = v_zero_f32(); | |||
v_f32 vsum1 = v_zero_f32(); | |||
v_f32 vsum2 = v_zero_f32(); | |||
v_f32 vsum3 = v_zero_f32(); | |||
while(i < unrollx4) | |||
{ | |||
vsum0 = v_muladd_f32( | |||
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0 | |||
); | |||
vsum1 = v_muladd_f32( | |||
v_loadu_f32(x + i + vstep), v_loadu_f32(y + i + vstep), vsum1 | |||
); | |||
vsum2 = v_muladd_f32( | |||
v_loadu_f32(x + i + vstep*2), v_loadu_f32(y + i + vstep*2), vsum2 | |||
); | |||
vsum3 = v_muladd_f32( | |||
v_loadu_f32(x + i + vstep*3), v_loadu_f32(y + i + vstep*3), vsum3 | |||
); | |||
i += vstep*4; | |||
} | |||
vsum0 = v_add_f32( | |||
v_add_f32(vsum0, vsum1), v_add_f32(vsum2 , vsum3) | |||
); | |||
while(i < unrollx) | |||
{ | |||
vsum0 = v_muladd_f32( | |||
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0 | |||
); | |||
i += vstep; | |||
} | |||
dot = v_sum_f32(vsum0); | |||
#elif defined(DSDOT) | |||
for (; i < n1; i += 4) | |||
{ | |||
#if defined(DSDOT) | |||
dot += (double) y[i] * (double) x[i] | |||
+ (double) y[i+1] * (double) x[i+1] | |||
+ (double) y[i+2] * (double) x[i+2] | |||
+ (double) y[i+3] * (double) x[i+3] ; | |||
} | |||
#else | |||
for (; i < n1; i += 4) | |||
{ | |||
dot += y[i] * x[i] | |||
+ y[i+1] * x[i+1] | |||
+ y[i+2] * x[i+2] | |||
+ y[i+3] * x[i+3] ; | |||
#endif | |||
i+=4 ; | |||
} | |||
#endif | |||
while(i < n) | |||
{ | |||
@@ -51,6 +51,11 @@ extern "C" { | |||
#include <immintrin.h> | |||
#endif | |||
/** NEON **/ | |||
#ifdef HAVE_NEON | |||
#include <arm_neon.h> | |||
#endif | |||
// distribute | |||
#if defined(HAVE_AVX512VL) || defined(HAVE_AVX512BF16) | |||
#include "intrin_avx512.h" | |||
@@ -60,6 +65,10 @@ extern "C" { | |||
#include "intrin_sse.h" | |||
#endif | |||
#ifdef HAVE_NEON | |||
#include "intrin_neon.h" | |||
#endif | |||
#ifndef V_SIMD | |||
#define V_SIMD 0 | |||
#define V_SIMD_F64 0 | |||
@@ -1,13 +1,13 @@ | |||
#define V_SIMD 256 | |||
#define V_SIMD_F64 1 | |||
/* | |||
Data Type | |||
*/ | |||
/*************************** | |||
* Data Type | |||
***************************/ | |||
typedef __m256 v_f32; | |||
#define v_nlanes_f32 8 | |||
/* | |||
arithmetic | |||
*/ | |||
/*************************** | |||
* Arithmetic | |||
***************************/ | |||
#define v_add_f32 _mm256_add_ps | |||
#define v_mul_f32 _mm256_mul_ps | |||
@@ -20,10 +20,22 @@ arithmetic | |||
{ return v_add_f32(v_mul_f32(a, b), c); } | |||
#endif // !HAVE_FMA3 | |||
/* | |||
memory | |||
*/ | |||
// Horizontal add: returns the sum of all eight float lanes of `a`.
// _mm256_hadd_ps operates separately inside each 128-bit half, so after two
// rounds every element of a half holds that half's total; the two halves are
// then added and element 0 is extracted.
BLAS_FINLINE float v_sum_f32(__m256 a)
{
    __m256 pair_sums = _mm256_hadd_ps(a, a);
    __m256 half_sums = _mm256_hadd_ps(pair_sums, pair_sums);
    __m128 upper = _mm256_extractf128_ps(half_sums, 1);
    __m128 lower = _mm256_castps256_ps128(half_sums);
    return _mm_cvtss_f32(_mm_add_ps(lower, upper));
}
/*************************** | |||
* memory | |||
***************************/ | |||
// unaligned load | |||
#define v_loadu_f32 _mm256_loadu_ps | |||
#define v_storeu_f32 _mm256_storeu_ps | |||
#define v_setall_f32(VAL) _mm256_set1_ps(VAL) | |||
#define v_setall_f32(VAL) _mm256_set1_ps(VAL) | |||
#define v_zero_f32 _mm256_setzero_ps |
@@ -1,21 +1,35 @@ | |||
#define V_SIMD 512 | |||
#define V_SIMD_F64 1 | |||
/* | |||
Data Type | |||
*/ | |||
/*************************** | |||
* Data Type | |||
***************************/ | |||
typedef __m512 v_f32; | |||
#define v_nlanes_f32 16 | |||
/* | |||
arithmetic | |||
*/ | |||
/*************************** | |||
* Arithmetic | |||
***************************/ | |||
#define v_add_f32 _mm512_add_ps | |||
#define v_mul_f32 _mm512_mul_ps | |||
// multiply and add, a*b + c | |||
#define v_muladd_f32 _mm512_fmadd_ps | |||
/* | |||
memory | |||
*/ | |||
// Horizontal add: returns the sum of all sixteen float lanes of `a`.
// The vector is folded in halves: 512 -> 256 -> 128 bits using 128-bit block
// shuffles, then 4 -> 2 -> 1 elements using in-block permutes. Only element 0
// of the final vector is meaningful and is returned.
BLAS_FINLINE float v_sum_f32(v_f32 a)
{
    // Bring 128-bit blocks 2,3 down onto blocks 0,1:
    // block0 = a0 + a2, block1 = a1 + a3 (blocks 2,3 become don't-care).
    __m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
    __m512 sum32 = _mm512_add_ps(a, h64);
    // Swap neighbouring 128-bit blocks so that block 1 lands on block 0.
    // The control must select block 1 for destination block 0; selecting
    // block 2 here would add the don't-care value 2*a2 and drop a1 and a3.
    __m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(2, 3, 0, 1));
    __m512 sum16 = _mm512_add_ps(sum32, h32);
    // Inside each 128-bit block: add elements 2,3 onto elements 0,1.
    __m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
    __m512 sum8 = _mm512_add_ps(sum16, h16);
    // Inside each 128-bit block: add element 1 onto element 0.
    __m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
    __m512 sum4 = _mm512_add_ps(sum8, h4);
    return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
}
/*************************** | |||
* memory | |||
***************************/ | |||
// unaligned load | |||
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR)) | |||
#define v_storeu_f32 _mm512_storeu_ps | |||
#define v_setall_f32(VAL) _mm512_set1_ps(VAL) | |||
#define v_zero_f32 _mm512_setzero_ps |
@@ -0,0 +1,42 @@ | |||
#define V_SIMD 128 | |||
#ifdef __aarch64__ | |||
#define V_SIMD_F64 1 | |||
#else | |||
#define V_SIMD_F64 0 | |||
#endif | |||
/*************************** | |||
* Data Type | |||
***************************/ | |||
typedef float32x4_t v_f32; | |||
#define v_nlanes_f32 4 | |||
/*************************** | |||
* Arithmetic | |||
***************************/ | |||
#define v_add_f32 vaddq_f32 | |||
#define v_mul_f32 vmulq_f32 | |||
// FUSED F32
// multiply and add, a*b + c
// With HAVE_VFPV4 the fused intrinsic vfmaq_f32 is used; otherwise the
// computation falls back to vmlaq_f32. Both take the accumulator first.
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
{
#ifdef HAVE_VFPV4 // FMA
    return vfmaq_f32(c, a, b);
#else
    return vmlaq_f32(c, a, b);
#endif
}
// Horizontal add: returns the sum of all four float lanes of `a`.
// The high and low 64-bit halves are added element-wise, then the two
// remaining elements are combined with a pairwise add.
BLAS_FINLINE float v_sum_f32(float32x4_t a)
{
    float32x2_t upper = vget_high_f32(a);
    float32x2_t lower = vget_low_f32(a);
    float32x2_t halves = vadd_f32(upper, lower);
    float32x2_t total = vpadd_f32(halves, halves);
    return vget_lane_f32(total, 0);
}
/*************************** | |||
* memory | |||
***************************/ | |||
// unaligned load
// The argument is parenthesized so that an expression argument such as
// `p + i` is cast as a whole rather than only its first operand.
#define v_loadu_f32(a) vld1q_f32((const float*)(a))
// unaligned store
#define v_storeu_f32 vst1q_f32
// set every lane to VAL
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
// set every lane to zero
#define v_zero_f32() vdupq_n_f32(0.0f)
@@ -1,13 +1,13 @@ | |||
#define V_SIMD 128 | |||
#define V_SIMD_F64 1 | |||
/* | |||
Data Type | |||
*/ | |||
/*************************** | |||
* Data Type | |||
***************************/ | |||
typedef __m128 v_f32; | |||
#define v_nlanes_f32 4 | |||
/* | |||
arithmetic | |||
*/ | |||
/*************************** | |||
* Arithmetic | |||
***************************/ | |||
#define v_add_f32 _mm_add_ps | |||
#define v_mul_f32 _mm_mul_ps | |||
#ifdef HAVE_FMA3 | |||
@@ -21,10 +21,26 @@ arithmetic | |||
// multiply and add, a*b + c, built from a separate multiply and add
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
{
    v_f32 product = v_mul_f32(a, b);
    return v_add_f32(product, c);
}
#endif // HAVE_FMA3 | |||
/* | |||
memory | |||
*/ | |||
// Horizontal add: returns the sum of all four float lanes of `a`.
BLAS_FINLINE float v_sum_f32(__m128 a)
{
#ifdef HAVE_SSE3
    // two rounds of pairwise adds leave the total in every element
    __m128 pair_sums = _mm_hadd_ps(a, a);
    __m128 total = _mm_hadd_ps(pair_sums, pair_sums);
    return _mm_cvtss_f32(total);
#else
    // fold elements 2,3 onto elements 0,1 ...
    __m128 upper = _mm_movehl_ps(a, a);
    __m128 folded = _mm_add_ps(a, upper);
    // ... then bring element 1 down to element 0 and add the two scalars
    __m128 odd = _mm_shuffle_ps(folded, folded, _MM_SHUFFLE(0, 0, 0, 1));
    return _mm_cvtss_f32(_mm_add_ss(folded, odd));
#endif
}
/*************************** | |||
* memory | |||
***************************/ | |||
// unaligned load | |||
#define v_loadu_f32 _mm_loadu_ps | |||
#define v_storeu_f32 _mm_storeu_ps | |||
#define v_setall_f32(VAL) _mm_set1_ps(VAL) | |||
#define v_setall_f32(VAL) _mm_set1_ps(VAL) | |||
#define v_zero_f32 _mm_setzero_ps |
@@ -47,3 +47,17 @@ CTEST(dsdot,dsdot_n_1) | |||
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS); | |||
} | |||
CTEST(dsdot,dsdot_n_2) | |||
{ | |||
float x[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F}; | |||
float y[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F}; | |||
blasint incx=1; | |||
blasint incy=1; | |||
blasint n=8; | |||
double res1=0.0f, res2= 2.0400000444054616; | |||
res1=BLASFUNC(dsdot)(&n, &x, &incx, &y, &incy); | |||
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS); | |||
} |