|
|
@@ -93,7 +93,7 @@ static int sgemv_kernel_t_1(BLASLONG m, float alpha, float *a, float *x, float *
     }
 
     if (tag_m_32x != m) {
-        for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_16x; idx_m+=32) {
+        for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
             matrixArray_0 = _mm512_loadu_ps(&a[idx_m + 0]);
 
             _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(matrixArray_0, ALPHAXVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
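
Note: the tail loop above is the fix in this hunk. The old bounds started at tag_m_64x and stepped by 32, but under the `if (tag_m_32x != m)` guard the body consumes exactly one __m512 (16 floats) per iteration, so rows were skipped. A minimal sketch of the block decomposition these kernels assume (the tag_m_Nx definitions sit outside this hunk and are reconstructed here for illustration):

    BLASLONG tag_m_32x = m & (~31);   /* largest multiple of 32 <= m */
    BLASLONG tag_m_16x = m & (~15);   /* largest multiple of 16 <= m */
    /* rows [0, tag_m_32x)         : 32-wide main loop
     * rows [tag_m_32x, tag_m_16x) : 16-wide tail, one __m512 per step
     * rows [tag_m_16x, m)         : masked remainder */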
|
|
@@ -145,8 +145,8 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
     }
     if (tag_m_32x != m) {
         for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
-            m0 = _mm512_loadu_ps(&a[idx_m]);
-            m1 = _mm512_loadu_ps(&a[idx_m + 16]);
+            m0 = _mm512_loadu_ps(&a[idx_m*2]);
+            m1 = _mm512_loadu_ps(&a[idx_m*2 + 16]);
             col1_1 = _mm512_permutex2var_ps(m0, idx_base_0, m1);
             col1_2 = _mm512_permutex2var_ps(m0, idx_base_1, m1);
             _mm512_storeu_ps(&y[idx_m], _mm512_add_ps(_mm512_fmadd_ps(x2Array, col1_2, _mm512_mul_ps(col1_1, x1Array)), _mm512_loadu_ps(&y[idx_m])));
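
Note: sgemv_kernel_t_2 treats a as m rows of 2 floats each, so the block starting at row idx_m begins at a[idx_m*2]; the old code indexed a as if each row held a single float. A scalar reference for what this kernel computes, as a sketch only (in the kernel itself alpha is folded into the x1Array/x2Array broadcasts or applied via ALPHAVECTOR):

    for (BLASLONG i = 0; i < m; i++)
        y[i] += alpha * (a[i*2] * x[0] + a[i*2 + 1] * x[1]);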
|
|
@@ -157,7 +157,7 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
         __mmask8 load_mask = *((__mmask8*) &load_mask_value);
         x1Array = _mm512_broadcast_f32x2(_mm_maskz_loadu_ps(load_mask, x));
         for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
-            m0 = _mm512_loadu_ps(&a[idx_m]);
+            m0 = _mm512_loadu_ps(&a[idx_m*2]);
             m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
             m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
             __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
|
|
@@ -166,12 +166,12 @@ static int sgemv_kernel_t_2(BLASLONG m, float alpha, float *a, float *x, float *
     }
 
     if (tag_m_8x != m) {
-        unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
+        unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
         __mmask16 a_mask = *((__mmask16*) &tail_mask_value);
         unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
         __mmask8 y_mask = *((__mmask8*) &y_mask_value);
 
-        m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
+        m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
         m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
         m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
         __m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
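
Note: two fixes here. First, `&a[tag_m_8x*2]` is the same two-floats-per-row indexing fix as in the loops above. Second, the tail-mask expression is re-parenthesized: binary `-` binds tighter than `&` in C, so the old line computed `(16 - (m-tag_m_8x)*2) & 15`, while the new form subtracts the masked lane count, which is the intended shift. For the 1..7 leftover rows this branch handles the two parse trees happen to agree, but the new form states the formula directly. Worked example (m - tag_m_8x == 5, values chosen for illustration):

    /* 5 rows * 2 columns = 10 active a-lanes
     * shift  = 16 - ((5*2) & 15) = 6
     * a_mask = 0xffff >> 6       = 0x03ff  (low 10 bits set)
     * y_mask = 0xff >> (8 - 5)   = 0x1f    (low 5 bits set)  */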
|
|
@@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
 {
     BLASLONG tag_m_4x = m & (~3);
     BLASLONG tag_m_2x = m & (~1);
-    __m512 m0, m1, m2;
+    __m512 m0, m1;
     __m256 m256_0, m256_1, c256_1, c256_2;
     __m128 c1, c2, c3, c4, ret;
     __m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
|
|
@@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
         c3 = _mm256_extractf32x4_ps(c256_2, 0);
         c4 = _mm256_extractf32x4_ps(c256_2, 1);
 
-        ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
+        ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
         _mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
     }
 
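
Note: the old line loaded the accumulator from the start of y on every iteration, so each 4-row block after the first was summed against y[0..3] instead of y[idx_m..idx_m+3]. What the corrected line computes per block, as a scalar sketch (treating c1..c4 as float[4]):

    for (int k = 0; k < 4; k++)
        y[idx_m + k] += (c1[k] + c2[k]) + (c3[k] + c4[k]);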
|
|
|
|
|
|
@@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
         c256_1 = _mm512_extractf32x8_ps(tmp0, 1);
 
         c256_0 = _mm256_add_ps(c256_0, c256_1);
+        c256_0 = _mm256_mul_ps(c256_0, alpha256);
 
         __m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
         __m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
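
Note: as far as this hunk shows, sgemv_kernel_t_7 never applied alpha on this path; the inserted _mm256_mul_ps scales the eight partial sums by alpha (alpha256 is assumed to be a _mm256_set1_ps(alpha) broadcast declared above the hunk) before they are reduced into y.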
|
|
@@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
     __m512 m0, m1, m2, m3;
     __m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
     __m128 c128_0, c128_1, c128_2, c128_3;
-    __m128 alpha128 = _mm_set1_ps(alpha);
+    __m256 alpha256 = _mm256_set1_ps(alpha);
 
     __m256 x256 = _mm256_loadu_ps(x);
+    x256 = _mm256_mul_ps(x256, alpha256);
     __m512 x512 = _mm512_broadcast_f32x8(x256);
 
     for(BLASLONG idx_m=0; idx_m<tag_m_8x; idx_m+=8) {
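
Note: this hunk is the heart of the t_8 refactor: alpha is folded into x once per call (alpha256 replaces the now-unused alpha128), so the stores in the following hunks can drop their per-block fmadd by alpha. The identity being exploited, as a scalar sketch of the whole kernel (8 floats per row of a assumed, matching the x256 broadcast):

    /* y[i] += alpha * dot(row_i, x)  ==  y[i] += dot(row_i, alpha*x) */
    float ax[8];
    for (int j = 0; j < 8; j++) ax[j] = alpha * x[j];
    for (BLASLONG i = 0; i < m; i++) {
        float s = 0.0f;
        for (int j = 0; j < 8; j++) s += a[i*8 + j] * ax[j];
        y[i] += s;
    }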
|
|
@@ -1053,8 +1055,8 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
 
         c128_0 = _mm_add_ps(c128_0, c128_1);
         c128_2 = _mm_add_ps(c128_2, c128_3);
-        _mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
-        _mm_storeu_ps(&y[idx_m+4], _mm_fmadd_ps(c128_2, alpha128, _mm_loadu_ps(&y[idx_m+4])));
+        _mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
+        _mm_storeu_ps(&y[idx_m+4], _mm_add_ps(c128_2, _mm_loadu_ps(&y[idx_m+4])));
     }
 
     if (tag_m_8x !=m ){
|
|
@@ -1078,7 +1080,7 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
         c128_1 = _mm256_extractf32x4_ps(tmp1, 1);
 
         c128_0 = _mm_add_ps(c128_0, c128_1);
-        _mm_storeu_ps(&y[idx_m], _mm_fmadd_ps(c128_0, alpha128, _mm_loadu_ps(&y[idx_m])));
+        _mm_storeu_ps(&y[idx_m], _mm_add_ps(c128_0, _mm_loadu_ps(&y[idx_m])));
 
     }
|
|
|
|
|
|
@@ -1094,7 +1096,6 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
         c128_1 = _mm256_extractf32x4_ps(tmp0, 1);
 
         c128_0 = _mm_add_ps(c128_0, c128_1);
-        c128_0 = _mm_mul_ps(c128_0, alpha128);
 
         _mm_storeu_ps(ret, c128_0);
         y[idx_m] += (ret[0]+ret[1]);
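
Note: with x pre-scaled by alpha at the top of the function, the partial sums in c128_0 already include alpha; keeping this multiply would have applied alpha twice on the scalar tail.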
|
|
|