
Optimize the zgemm_tcopy_4_rvv function to be compatible with vector lengths (VLEN) of 128 and 256 bits.
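Background for the change (explanatory note, not part of the original commit message): an RVV register group holds VLEN * LMUL / SEW elements, and the intrinsics clamp any requested vl to that capacity. The old kernel requested fixed vl values (4 for the LMUL=1 segment loads/stores, 8 for the LMUL=1 unit-stride tail) that exceed what LMUL=1 registers hold at the smaller vector lengths, so elements were silently dropped. The new macros keep the same vl values but raise LMUL (m2/m1 for single precision, m4/m2 for double), which fits at VLEN=128. A minimal sketch of the capacity arithmetic, plain C, illustration only:

    #include <stdio.h>

    /* Elements per RVV register group: VLEN * LMUL / SEW.
     * The kernel's main loop moves vl = 8 reals per width-4 column
     * block; the width-2 tail moves vl = 4. */
    int main(void)
    {
        const int vlens[] = { 128, 256 };
        for (int i = 0; i < 2; i++) {
            int vlen = vlens[i];
            printf("VLEN=%d: e64m1=%d e64m2=%d e64m4=%d doubles; "
                   "e32m1=%d e32m2=%d floats\n",
                   vlen,
                   vlen * 1 / 64, vlen * 2 / 64, vlen * 4 / 64,
                   vlen * 1 / 32, vlen * 2 / 32);
        }
        /* At VLEN=128: e64m4 holds 8 doubles and e64m2 holds 4, so the
         * new FLOAT_V_T / FLOAT_V_T_HALF types fit vl = 8 and vl = 4
         * exactly; e64m1 holds only 2, which is why the old LMUL=1
         * code needed a larger VLEN to be correct. */
        return 0;
    }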

Signed-off-by: tingbo.liao <tingbo.liao@starfivetech.com>
tags/v0.3.29
tingbo.liao 11 months ago
commit 0bea1cfd9d
1 changed file with 25 additions and 111 deletions

kernel/riscv64/zgemm_tcopy_4_rvv.c (+25, -111)

@@ -28,35 +28,22 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "common.h"
 
 #if !defined(DOUBLE)
-#define VSETVL(n) __riscv_vsetvl_e32m1(n)
-#define FLOAT_V_T vfloat32m1_t
-#define FLOAT_VX2_T vfloat32m1x2_t
-#define FLOAT_VX4_T vfloat32m1x4_t
-#define FLOAT_VX8_T vfloat32m1x8_t
-#define VLEV_FLOAT __riscv_vle32_v_f32m1
-#define VSEV_FLOAT __riscv_vse32_v_f32m1
-#define VLSSEG2_FLOAT __riscv_vlsseg2e32_v_f32m1x2
-#define VLSSEG4_FLOAT __riscv_vlsseg4e32_v_f32m1x4
-#define VLSSEG8_FLOAT __riscv_vlsseg8e32_v_f32m1x8
-#define VSSEG2_FLOAT __riscv_vsseg2e32_v_f32m1x2
-#define VSSEG4_FLOAT __riscv_vsseg4e32_v_f32m1x4
-#define VSSEG8_FLOAT __riscv_vsseg8e32_v_f32m1x8
+#define FLOAT_V_T vfloat32m2_t
+#define FLOAT_V_T_HALF vfloat32m1_t
+#define VLEV_FLOAT __riscv_vle32_v_f32m2
+#define VLEV_FLOAT_HALF __riscv_vle32_v_f32m1
+#define VSEV_FLOAT __riscv_vse32_v_f32m2
+#define VSEV_FLOAT_HALF __riscv_vse32_v_f32m1
 #else
-#define VSETVL(n) __riscv_vsetvl_e64m1(n)
-#define FLOAT_V_T vfloat64m1_t
-#define FLOAT_VX2_T vfloat64m1x2_t
-#define FLOAT_VX4_T vfloat64m1x4_t
-#define FLOAT_VX8_T vfloat64m1x8_t
-#define VLEV_FLOAT __riscv_vle64_v_f64m1
-#define VSEV_FLOAT __riscv_vse64_v_f64m1
-#define VLSSEG2_FLOAT __riscv_vlsseg2e64_v_f64m1x2
-#define VLSSEG4_FLOAT __riscv_vlsseg4e64_v_f64m1x4
-#define VLSSEG8_FLOAT __riscv_vlsseg8e64_v_f64m1x8
-#define VSSEG2_FLOAT __riscv_vsseg2e64_v_f64m1x2
-#define VSSEG4_FLOAT __riscv_vsseg4e64_v_f64m1x4
-#define VSSEG8_FLOAT __riscv_vsseg8e64_v_f64m1x8
+#define FLOAT_V_T vfloat64m4_t
+#define FLOAT_V_T_HALF vfloat64m2_t
+#define VLEV_FLOAT __riscv_vle64_v_f64m4
+#define VLEV_FLOAT_HALF __riscv_vle64_v_f64m2
+#define VSEV_FLOAT __riscv_vse64_v_f64m4
+#define VSEV_FLOAT_HALF __riscv_vse64_v_f64m2
 #endif
 
 int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
 
 	BLASLONG i, j;
@@ -67,9 +54,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
 	IFLOAT *boffset, *boffset1, *boffset2, *boffset3;
 
 	FLOAT_V_T v0;
-	FLOAT_VX2_T vx2;
-	FLOAT_VX4_T vx4;
-	FLOAT_VX8_T vx8;
+	FLOAT_V_T_HALF v1;
 
 	size_t vl;
@@ -80,86 +65,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
 	boffset2 = b + 2 * m * (n & ~3);
 	boffset3 = b + 2 * m * (n & ~1);
 
-	for(j = (m >> 2); j > 0; j--) {
-
-		aoffset1 = aoffset;
-		aoffset += 8 * lda;
-
-		boffset1 = boffset;
-		boffset += 32;
-
-		for(i = (n >> 2); i > 0; i--) {
-			vl = 4;
-
-			vx8 = VLSSEG8_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG8_FLOAT(boffset1, vx8, vl);
-
-			aoffset1 += 8;
-			boffset1 += m * 8;
-		}
-
-		if (n & 2) {
-			vl = 4;
-
-			vx4 = VLSSEG4_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG4_FLOAT(boffset2, vx4, vl);
-
-			aoffset1 += 4;
-			boffset2 += 16;
-		}
-
-		if (n & 1) {
-			vl = 4;
-
-			vx2 = VLSSEG2_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG2_FLOAT(boffset3, vx2, vl);
-
-			aoffset1 += 2;
-			boffset3 += 8;
-		}
-	}
-
-	if (m & 2) {
-		aoffset1 = aoffset;
-		aoffset += 4 * lda;
-
-		boffset1 = boffset;
-		boffset += 16;
-
-		for(i = (n >> 2); i > 0; i--) {
-			vl = 2;
-
-			vx8 = VLSSEG8_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG8_FLOAT(boffset1, vx8, vl);
-
-			aoffset1 += 8;
-			boffset1 += m * 8;
-		}
-
-		if (n & 2) {
-			vl = 2;
-
-			vx4 = VLSSEG4_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG4_FLOAT(boffset2, vx4, vl);
-
-			aoffset1 += 4;
-			boffset2 += 8;
-		}
-
-		if (n & 1) {
-			vl = 2;
-
-			vx2 = VLSSEG2_FLOAT(aoffset1, lda * sizeof(FLOAT) * 2, vl);
-			VSSEG2_FLOAT(boffset3, vx2, vl);
-
-			//aoffset1 += 2;
-			boffset3 += 4;
-		}
-	}
-
-	if (m & 1) {
-		aoffset1 = aoffset;
-		boffset1 = boffset;
+	for(j = m; j > 0; j--) {
+		aoffset1 = aoffset;
+		aoffset += 2 * lda;
+
+		boffset1 = boffset;
+		boffset += 8;
 
 		for(i = (n >> 2); i > 0; i--) {
 			vl = 8;
@@ -174,16 +85,19 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
 		if (n & 2) {
 			vl = 4;
 
-			v0 = VLEV_FLOAT(aoffset1, vl);
-			VSEV_FLOAT(boffset2, v0, vl);
+			v1 = VLEV_FLOAT_HALF(aoffset1, vl);
+			VSEV_FLOAT_HALF(boffset2, v1, vl);
 
 			aoffset1 += 4;
-			//boffset2 += 4;
+			boffset2 += 4;
 		}
 
 		if (n & 1) {
 			*(boffset3) = *(aoffset1);
 			*(boffset3 + 1) = *(aoffset1 + 1);
+
+			aoffset1 += 2;
+			boffset3 += 2;
 		}
 	}
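For reference, a plain-C model of the packing the rewritten kernel performs (an illustrative sketch in double precision, not part of the commit; ztcopy4_ref is a hypothetical helper name). Each row of A contributes 8 reals (4 complex values) per width-4 column block; full blocks are interleaved across all m rows, and the n remainder columns go into the width-2 and width-1 panels addressed by boffset2 and boffset3:

    #include <stdio.h>
    #include <string.h>

    /* Plain-C model of zgemm_tcopy_4_rvv's packing (double precision;
     * each complex value is 2 reals). Row j of A starts at a + j * 2 * lda. */
    static void ztcopy4_ref(long m, long n, const double *a, long lda, double *b)
    {
        double *b1 = b;                      /* width-4 panels            */
        double *b2 = b + 2 * m * (n & ~3L);  /* width-2 remainder panel   */
        double *b3 = b + 2 * m * (n & ~1L);  /* width-1 remainder panel   */

        for (long j = 0; j < m; j++) {
            const double *row = a + j * 2 * lda;
            double *bp = b1 + j * 8;         /* this row's slot, block 0  */
            long i = 0;

            for (; i + 4 <= n; i += 4) {     /* the vl = 8 vector loop    */
                memcpy(bp, row + 2 * i, 8 * sizeof(double));
                bp += m * 8;                 /* same row slot, next block */
            }
            if (n & 2) {                     /* the vl = 4 (HALF) branch  */
                memcpy(b2, row + 2 * i, 4 * sizeof(double));
                b2 += 4;
                i += 2;
            }
            if (n & 1) {                     /* scalar remainder column   */
                b3[0] = row[2 * i];
                b3[1] = row[2 * i + 1];
                b3 += 2;
            }
        }
    }

    int main(void)
    {
        /* 2 x 3 complex matrix, lda = 3; values 0..11 for easy reading.
         * Expected output: 0 1 2 3 6 7 8 9 4 5 10 11 */
        double a[12], b[12];
        for (int k = 0; k < 12; k++) a[k] = k;
        ztcopy4_ref(2, 3, a, 3, b);
        for (int k = 0; k < 12; k++) printf("%g ", b[k]);
        printf("\n");
        return 0;
    }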



