
Merge pull request #5427 from yuanjia111/develop

Optimize the gemv_t_vector.c kernel for RISCV64_ZVL256B target
Martin Kroeker · 1 month ago · commit da7d0f4a38
1 changed file with 168 additions and 79 deletions: kernel/riscv64/gemv_t_vector.c
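For reviewers coming to this kernel cold: gemv_t computes y += alpha * A^T * x with A stored column-major, so each element of y accumulates the dot product of one column of A with x. The patch below widens LMUL from m2 to m8 and, when x has unit stride, processes four columns of A per pass so that each vector load of x is reused four times. A minimal scalar sketch of the operation (not part of the patch; the typedefs stand in for OpenBLAS's FLOAT and BLASLONG in a double-precision build):

#include <stdio.h>

typedef double FLOAT;    /* assumption: DOUBLE build; float otherwise */
typedef long   BLASLONG;

/* Reference gemv_t: y[i] += alpha * dot(column i of A, x).
 * A is column-major with leading dimension lda, so column i
 * starts at a[i * lda]; inc_x/inc_y are element strides. */
static void gemv_t_ref(BLASLONG m, BLASLONG n, FLOAT alpha,
                       const FLOAT *a, BLASLONG lda,
                       const FLOAT *x, BLASLONG inc_x,
                       FLOAT *y, BLASLONG inc_y)
{
    for (BLASLONG i = 0; i < n; i++) {
        FLOAT temp = 0.0;
        for (BLASLONG j = 0; j < m; j++)
            temp += a[i * lda + j] * x[j * inc_x];
        y[i * inc_y] += alpha * temp;
    }
}

int main(void)
{
    /* 2x2 column-major A = [[1, 3], [2, 4]], x = (1, 1), alpha = 1 */
    FLOAT a[4] = { 1.0, 2.0, 3.0, 4.0 };
    FLOAT x[2] = { 1.0, 1.0 };
    FLOAT y[2] = { 0.0, 0.0 };
    gemv_t_ref(2, 2, 1.0, a, 2, x, 1, y, 1);
    printf("%g %g\n", y[0], y[1]);   /* 3 7: the two column dot products */
    return 0;
}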

@@ -27,110 +27,199 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
 #if !defined(DOUBLE)
-#define VSETVL(n) RISCV_RVV(vsetvl_e32m2)(n)
-#define FLOAT_V_T vfloat32m2_t
+#define VSETVL(n) RISCV_RVV(vsetvl_e32m8)(n)
+#define VSETVL_MAX_M1 RISCV_RVV(vsetvlmax_e32m1)
+#define FLOAT_V_T vfloat32m8_t
 #define FLOAT_V_T_M1 vfloat32m1_t
-#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m2)
-#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m2)
+#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m8)
 #ifdef RISCV_0p10_INTRINSICS
-#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f32m2_f32m1(v_res, va, vb, gvl)
+#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f32m8_f32m1(v_res, va, vb, gvl)
 #else
-#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f32m2_f32m1)
+#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f32m8_f32m1)
 #endif
-#define VFMACCVV_FLOAT RISCV_RVV(vfmacc_vv_f32m2)
-#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f32m2)
+#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f32m8)
+#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f32m8)
 #define VFMVVF_FLOAT_M1 RISCV_RVV(vfmv_v_f_f32m1)
-#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f32m2)
 #define xint_t int
 #else
-#define VSETVL(n) RISCV_RVV(vsetvl_e64m2)(n)
-#define FLOAT_V_T vfloat64m2_t
+#define VSETVL(n) RISCV_RVV(vsetvl_e64m8)(n)
+#define VSETVL_MAX_M1 RISCV_RVV(vsetvlmax_e64m1)
+#define FLOAT_V_T vfloat64m8_t
 #define FLOAT_V_T_M1 vfloat64m1_t
-#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m2)
-#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m2)
+#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m8)
 #ifdef RISCV_0p10_INTRINSICS
-#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f64m2_f64m1(v_res, va, vb, gvl)
+#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f64m8_f64m1(v_res, va, vb, gvl)
 #else
-#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f64m2_f64m1)
+#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f64m8_f64m1)
 #endif
-#define VFMACCVV_FLOAT RISCV_RVV(vfmacc_vv_f64m2)
-#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f64m2)
+#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f64m8)
+#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f64m8)
 #define VFMVVF_FLOAT_M1 RISCV_RVV(vfmv_v_f_f64m1)
-#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f64m2)
 #define xint_t long long
 #endif
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
 	BLASLONG i = 0, j = 0, k = 0;
 	BLASLONG ix = 0, iy = 0;
 	FLOAT *a_ptr = a;
 	FLOAT temp;
 
 	FLOAT_V_T va, vr, vx;
 	unsigned int gvl = 0;
 	FLOAT_V_T_M1 v_res;
+	size_t vlmax = VSETVL_MAX_M1();
+
+#ifndef RISCV_0p10_INTRINSICS
+	FLOAT_V_T va0, va1, va2, va3, vr0, vr1, vr2, vr3;
+	FLOAT_V_T_M1 vec0, vec1, vec2, vec3;
+	FLOAT *a_ptrs[4], *y_ptrs[4];
+#endif
 
 	if(inc_x == 1){
-		for(i = 0; i < n; i++){
-			v_res = VFMVVF_FLOAT_M1(0, 1);
-			gvl = VSETVL(m);
-			j = 0;
-			vr = VFMVVF_FLOAT(0, gvl);
-			for(k = 0; k < m/gvl; k++){
-				va = VLEV_FLOAT(&a_ptr[j], gvl);
-				vx = VLEV_FLOAT(&x[j], gvl);
-				vr = VFMULVV_FLOAT(va, vx, gvl); // could vfmacc here and reduce outside loop
-				v_res = VFREDSUM_FLOAT(vr, v_res, gvl); // but that reordering diverges far enough from scalar path to make tests fail
-				j += gvl;
-			}
-			if(j < m){
-				gvl = VSETVL(m-j);
-				va = VLEV_FLOAT(&a_ptr[j], gvl);
-				vx = VLEV_FLOAT(&x[j], gvl);
-				vr = VFMULVV_FLOAT(va, vx, gvl);
-				v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
-			}
-			temp = (FLOAT)EXTRACT_FLOAT(v_res);
-			y[iy] += alpha * temp;
-			iy += inc_y;
-			a_ptr += lda;
-		}
-	}else{
+#ifndef RISCV_0p10_INTRINSICS
+		BLASLONG anr = n - n % 4;
+		for (; i < anr; i += 4) {
+			gvl = VSETVL(m);
+			j = 0;
+			for (int l = 0; l < 4; l++) {
+				a_ptrs[l] = a + (i + l) * lda;
+				y_ptrs[l] = y + (i + l) * inc_y;
+			}
+			vec0 = VFMVVF_FLOAT_M1(0.0, vlmax);
+			vec1 = VFMVVF_FLOAT_M1(0.0, vlmax);
+			vec2 = VFMVVF_FLOAT_M1(0.0, vlmax);
+			vec3 = VFMVVF_FLOAT_M1(0.0, vlmax);
+			vr0 = VFMVVF_FLOAT(0.0, gvl);
+			vr1 = VFMVVF_FLOAT(0.0, gvl);
+			vr2 = VFMVVF_FLOAT(0.0, gvl);
+			vr3 = VFMVVF_FLOAT(0.0, gvl);
+			for (k = 0; k < m / gvl; k++) {
+				va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+				va1 = VLEV_FLOAT(a_ptrs[1] + j, gvl);
+				va2 = VLEV_FLOAT(a_ptrs[2] + j, gvl);
+				va3 = VLEV_FLOAT(a_ptrs[3] + j, gvl);
+				vx = VLEV_FLOAT(x + j, gvl);
+				vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+				vr1 = VFMULVV_FLOAT(va1, vx, gvl);
+				vr2 = VFMULVV_FLOAT(va2, vx, gvl);
+				vr3 = VFMULVV_FLOAT(va3, vx, gvl);
+				// Floating-point addition does not satisfy the associative law, that is, (a + b) + c ≠ a + (b + c),
+				// so piecewise multiplication and reduction must be performed inside the loop body.
+				vec0 = VFREDSUM_FLOAT(vr0, vec0, gvl);
+				vec1 = VFREDSUM_FLOAT(vr1, vec1, gvl);
+				vec2 = VFREDSUM_FLOAT(vr2, vec2, gvl);
+				vec3 = VFREDSUM_FLOAT(vr3, vec3, gvl);
+				j += gvl;
+			}
+			if (j < m) {
+				gvl = VSETVL(m - j);
+				va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+				va1 = VLEV_FLOAT(a_ptrs[1] + j, gvl);
+				va2 = VLEV_FLOAT(a_ptrs[2] + j, gvl);
+				va3 = VLEV_FLOAT(a_ptrs[3] + j, gvl);
+				vx = VLEV_FLOAT(x + j, gvl);
+				vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+				vr1 = VFMULVV_FLOAT(va1, vx, gvl);
+				vr2 = VFMULVV_FLOAT(va2, vx, gvl);
+				vr3 = VFMULVV_FLOAT(va3, vx, gvl);
+				vec0 = VFREDSUM_FLOAT(vr0, vec0, gvl);
+				vec1 = VFREDSUM_FLOAT(vr1, vec1, gvl);
+				vec2 = VFREDSUM_FLOAT(vr2, vec2, gvl);
+				vec3 = VFREDSUM_FLOAT(vr3, vec3, gvl);
+			}
+			*y_ptrs[0] += alpha * (FLOAT)(EXTRACT_FLOAT(vec0));
+			*y_ptrs[1] += alpha * (FLOAT)(EXTRACT_FLOAT(vec1));
+			*y_ptrs[2] += alpha * (FLOAT)(EXTRACT_FLOAT(vec2));
+			*y_ptrs[3] += alpha * (FLOAT)(EXTRACT_FLOAT(vec3));
+		}
+		// deal with the tail
+		for (; i < n; i++) {
+			v_res = VFMVVF_FLOAT_M1(0, vlmax);
+			gvl = VSETVL(m);
+			j = 0;
+			a_ptrs[0] = a + i * lda;
+			y_ptrs[0] = y + i * inc_y;
+			vr0 = VFMVVF_FLOAT(0, gvl);
+			for (k = 0; k < m / gvl; k++) {
+				va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+				vx = VLEV_FLOAT(x + j, gvl);
+				vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+				v_res = VFREDSUM_FLOAT(vr0, v_res, gvl);
+				j += gvl;
+			}
+			if (j < m) {
+				gvl = VSETVL(m - j);
+				va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+				vx = VLEV_FLOAT(x + j, gvl);
+				vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+				v_res = VFREDSUM_FLOAT(vr0, v_res, gvl);
+			}
+			*y_ptrs[0] += alpha * (FLOAT)(EXTRACT_FLOAT(v_res));
+		}
+#else
+		for(i = 0; i < n; i++){
+			v_res = VFMVVF_FLOAT_M1(0, 1);
+			gvl = VSETVL(m);
+			j = 0;
+			vr = VFMVVF_FLOAT(0, gvl);
+			for(k = 0; k < m/gvl; k++){
+				va = VLEV_FLOAT(&a_ptr[j], gvl);
+				vx = VLEV_FLOAT(&x[j], gvl);
+				vr = VFMULVV_FLOAT(va, vx, gvl); // could vfmacc here and reduce outside loop
+				v_res = VFREDSUM_FLOAT(vr, v_res, gvl); // but that reordering diverges far enough from scalar path to make tests fail
+				j += gvl;
+			}
+			if(j < m){
+				gvl = VSETVL(m-j);
+				va = VLEV_FLOAT(&a_ptr[j], gvl);
+				vx = VLEV_FLOAT(&x[j], gvl);
+				vr = VFMULVV_FLOAT(va, vx, gvl);
+				v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
+			}
+			temp = (FLOAT)EXTRACT_FLOAT(v_res);
+			y[iy] += alpha * temp;
+			iy += inc_y;
+			a_ptr += lda;
+		}
+#endif
+	} else {
 		BLASLONG stride_x = inc_x * sizeof(FLOAT);
 		for(i = 0; i < n; i++){
 			v_res = VFMVVF_FLOAT_M1(0, 1);
 			gvl = VSETVL(m);
 			j = 0;
 			ix = 0;
 			vr = VFMVVF_FLOAT(0, gvl);
 			for(k = 0; k < m/gvl; k++){
 				va = VLEV_FLOAT(&a_ptr[j], gvl);
 				vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
 				vr = VFMULVV_FLOAT(va, vx, gvl);
 				v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
 				j += gvl;
 				ix += inc_x * gvl;
 			}
 			if(j < m){
 				gvl = VSETVL(m-j);
 				va = VLEV_FLOAT(&a_ptr[j], gvl);
 				vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
 				vr = VFMULVV_FLOAT(va, vx, gvl);
 				v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
 			}
 			temp = (FLOAT)EXTRACT_FLOAT(v_res);
 			y[iy] += alpha * temp;
 			iy += inc_y;
 			a_ptr += lda;
 		}
 	}
 
-	return(0);
+	return (0);
 }
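The reduction-order comments in the diff ("could vfmacc here and reduce outside loop" in the fallback path, and the non-associativity note in the unrolled path) come down to the same point: reordering a floating-point sum changes its rounding, and the kernel's accumulation order must stay close to the scalar path's. A tiny standalone demonstration, independent of the patch:

#include <stdio.h>

int main(void)
{
    float a = 1.0e8f, b = -1.0e8f, c = 1.0f;

    /* Left-to-right: a + b cancels exactly, then c survives. */
    float sum1 = (a + b) + c;   /* 1.0f */

    /* Reassociated: c is absorbed into b (float spacing near 1e8 is 8),
     * and the rounded sum then cancels against a completely. */
    float sum2 = a + (b + c);   /* 0.0f */

    printf("(a + b) + c = %g\n", sum1);
    printf("a + (b + c) = %g\n", sum2);
    return 0;
}

Keeping vfredusum inside the per-chunk loop preserves an order close to the scalar reference; accumulating with vfmacc and reducing once after the loop would save reductions but, per the comment in the diff, diverges far enough from the scalar path to make the tests fail.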
