Browse Source

Optimize gemv_n_sve_v1x3 kernel

- Calculate predicate outside the loop
- Divide matrix in blocks of 3
pull/5292/head
Sharif Inamdar 3 months ago
parent
commit
8279e68805
2 changed files with 61 additions and 40 deletions
  1. +4
    -1
      CONTRIBUTORS.md
  2. +57
    -39
      kernel/arm64/gemv_n_sve_v1x3.c

+ 4
- 1
CONTRIBUTORS.md View File

@@ -253,4 +253,7 @@ In chronological order:
* [2025-02-27] Add sbgemv_n_neon kernel

* Abhishek Kumar <https://github.com/abhishek-iitmadras>
* [2025-04-22] Optimise dot kernel for NEOVERSE V1

* Sharif Inamdar <sharif.inamdar@arm.com>
* [2025-06-05] Optimize gemv_n_sve_v1x3 kernel

+ 57
- 39
kernel/arm64/gemv_n_sve_v1x3.c View File

@@ -52,17 +52,17 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y,
FLOAT *buffer)
{
BLASLONG i;
BLASLONG ix,iy;
BLASLONG j;
FLOAT *a_ptr;
BLASLONG i, j;
BLASLONG ix = 0;
BLASLONG iy;
FLOAT *a_ptr = a;
FLOAT temp;

ix = 0;
a_ptr = a;

if (inc_y == 1) {
BLASLONG width = (n + 3 - 1) / 3;
BLASLONG width = n / 3; // Only process full 3-column blocks
BLASLONG sve_size = SV_COUNT();
svbool_t pg_full = SV_TRUE();
svbool_t pg_tail = SV_WHILE(0, m % sve_size);

FLOAT *a0_ptr = a_ptr + lda * width * 0;
FLOAT *a1_ptr = a_ptr + lda * width * 1;
@@ -73,57 +73,75 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
FLOAT *x2_ptr = x + inc_x * width * 2;

for (j = 0; j < width; j++) {
svbool_t pg00 = ((j + width * 0) < n) ? SV_TRUE() : svpfalse();
svbool_t pg01 = ((j + width * 1) < n) ? SV_TRUE() : svpfalse();
svbool_t pg02 = ((j + width * 2) < n) ? SV_TRUE() : svpfalse();
SV_TYPE temp0_vec = SV_DUP(alpha * x0_ptr[ix]);
SV_TYPE temp1_vec = SV_DUP(alpha * x1_ptr[ix]);
SV_TYPE temp2_vec = SV_DUP(alpha * x2_ptr[ix]);

SV_TYPE temp0_vec = ((j + width * 0) < n) ? SV_DUP(alpha * x0_ptr[ix]) : SV_DUP(0.0);
SV_TYPE temp1_vec = ((j + width * 1) < n) ? SV_DUP(alpha * x1_ptr[ix]) : SV_DUP(0.0);
SV_TYPE temp2_vec = ((j + width * 2) < n) ? SV_DUP(alpha * x2_ptr[ix]) : SV_DUP(0.0);
i = 0;
BLASLONG sve_size = SV_COUNT();
while ((i + sve_size * 1 - 1) < m) {
SV_TYPE y0_vec = svld1_vnum(SV_TRUE(), y + i, 0);
while ((i + sve_size - 1) < m) {
SV_TYPE y0_vec = svld1(pg_full, y + i);

SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
SV_TYPE a00_vec = svld1(pg_full, a0_ptr + i);
SV_TYPE a01_vec = svld1(pg_full, a1_ptr + i);
SV_TYPE a02_vec = svld1(pg_full, a2_ptr + i);

y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
y0_vec = svmla_x(pg_full, y0_vec, temp0_vec, a00_vec);
y0_vec = svmla_x(pg_full, y0_vec, temp1_vec, a01_vec);
y0_vec = svmla_x(pg_full, y0_vec, temp2_vec, a02_vec);

svst1_vnum(SV_TRUE(), y + i, 0, y0_vec);
i += sve_size * 1;
svst1(pg_full, y + i, y0_vec);
i += sve_size;
}

if (i < m) {
svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);

pg00 = svand_z(SV_TRUE(), pg0, pg00);
pg01 = svand_z(SV_TRUE(), pg0, pg01);
pg02 = svand_z(SV_TRUE(), pg0, pg02);
SV_TYPE y0_vec = svld1(pg_tail, y + i);

SV_TYPE y0_vec = svld1_vnum(pg0, y + i, 0);
SV_TYPE a00_vec = svld1(pg_tail, a0_ptr + i);
SV_TYPE a01_vec = svld1(pg_tail, a1_ptr + i);
SV_TYPE a02_vec = svld1(pg_tail, a2_ptr + i);

SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
y0_vec = svmla_m(pg_tail, y0_vec, temp0_vec, a00_vec);
y0_vec = svmla_m(pg_tail, y0_vec, temp1_vec, a01_vec);
y0_vec = svmla_m(pg_tail, y0_vec, temp2_vec, a02_vec);

y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);

svst1_vnum(pg0, y + i, 0, y0_vec);
svst1(pg_tail, y + i, y0_vec);
}
a0_ptr += lda;
a1_ptr += lda;
a2_ptr += lda;
ix += inc_x;
}
// Handle remaining n % 3 columns
for (j = width * 3; j < n; j++) {
FLOAT *a_col = a + j * lda;
temp = alpha * x[j * inc_x];
SV_TYPE temp_vec = SV_DUP(temp);

i = 0;
while ((i + sve_size - 1) < m) {
SV_TYPE y_vec = svld1(pg_full, y + i);

SV_TYPE a_vec = svld1(pg_full, a_col + i);

y_vec = svmla_x(pg_full, y_vec, temp_vec, a_vec);

svst1(pg_full, y + i, y_vec);
i += sve_size;
}
if (i < m) {
SV_TYPE y_vec = svld1(pg_tail, y + i);

SV_TYPE a_vec = svld1(pg_tail, a_col + i);

y_vec = svmla_m(pg_tail, y_vec, temp_vec, a_vec);

svst1(pg_tail, y + i, y_vec);
}
}
return(0);
}

// Fallback scalar loop
for (j = 0; j < n; j++) {
temp = alpha * x[ix];
iy = 0;


Loading…
Cancel
Save