Browse Source

Merge pull request #4382 from Mousius/sve-dot-again

Tweak SVE dot kernel
tags/v0.3.26
Martin Kroeker GitHub 1 year ago
parent
commit
fa220b2969
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 74 additions and 26 deletions
  1. +74
    -26
      kernel/arm64/dot_kernel_sve.c

+ 74
- 26
kernel/arm64/dot_kernel_sve.c View File

@@ -1,4 +1,5 @@
/*************************************************************************** /***************************************************************************
Copyright (c) 2023, The OpenBLAS Project
Copyright (c) 2022, Arm Ltd Copyright (c) 2022, Arm Ltd
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
@@ -30,37 +31,84 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <arm_sve.h> #include <arm_sve.h>


/*
 * Precision-dependent configuration.
 *
 * When DOUBLE is defined the kernel works on 64-bit elements, otherwise on
 * 32-bit elements.  Two families of macros are provided:
 *
 *   SVE_TYPE / SVE_ZERO / SVE_WHILELT / SVE_ALL / SVE_WIDTH
 *       names of the <arm_sve.h> type and intrinsics for the element size.
 *
 *   DTYPE / WIDTH / SHIFT
 *       string fragments pasted into the inline-assembly text:
 *       DTYPE is the element suffix of vector/predicate registers and the
 *       prefix of scalar FP register names, WIDTH is the suffix of the
 *       cnt/ld1 mnemonics, SHIFT is the "lsl #" amount applied to element
 *       indices when forming addresses.
 */
#ifdef DOUBLE
#define SVE_TYPE svfloat64_t
#define SVE_ZERO svdup_f64(0.0)
#define SVE_WHILELT svwhilelt_b64
#define SVE_ALL svptrue_b64()
#define SVE_WIDTH svcntd()
#define DTYPE "d"
#define WIDTH "d"
#define SHIFT "3"
#else
#define SVE_TYPE svfloat32_t
#define SVE_ZERO svdup_f32(0.0)
#define SVE_WHILELT svwhilelt_b32
#define SVE_ALL svptrue_b32()
#define SVE_WIDTH svcntw()
#define DTYPE "s"
#define WIDTH "w"
#define SHIFT "2"
#endif


static FLOAT dot_kernel_sve(BLASLONG n, FLOAT *x, FLOAT *y) {
SVE_TYPE acc_a = SVE_ZERO;
SVE_TYPE acc_b = SVE_ZERO;
/*
 * Assembly-text building blocks for the SVE dot-product kernel.
 *
 * Each macro expands to string literals that are concatenated into the
 * template of a single inline-asm statement.  The named operands %[N_],
 * %[X_], %[Y_] and %[RET_] are supplied by that statement.  The element
 * count is always taken from %[N_], so the kernel does not depend on which
 * general-purpose register the compiler chose for n.
 *
 * Register roles in the generated code:
 *   x8        current element index
 *   x9        elements per vector (result of cnt)
 *   x10, x11  loop bounds / masks derived from x9 and n
 *   x12, x13  X and Y advanced by one vector (second half of the 2x loop)
 *   p0        all-true predicate, p1 tail predicate
 *   z0, z1    accumulators, z2, z3 loaded operands
 */

/* x9 = number of elements in one vector. */
#define COUNT \
" cnt"WIDTH" x9 \n"

/* p0 = predicate with every lane active. */
#define SETUP_TRUE \
" ptrue p0."DTYPE" \n"

/* x12 / x13 = X / Y advanced by x9 elements. */
#define OFFSET_INPUTS \
" add x12, %[X_], x9, lsl #"SHIFT" \n" \
" add x13, %[Y_], x9, lsl #"SHIFT" \n"

/* p1 = lanes whose index (starting at x8) is still below n. */
#define TAIL_WHILE \
" whilelo p1."DTYPE", x8, %[N_] \n"

/*
 * Load one vector from x and one from y at element index x8 under
 * predicate pg (inactive lanes are zeroed by /z), then multiply-accumulate
 * them into the vector register out (merging predication, /m).
 */
#define UPDATE(pg, x,y,out) \
" ld1"WIDTH" { z2."DTYPE" }, "pg"/z, ["x", x8, lsl #"SHIFT"] \n" \
" ld1"WIDTH" { z3."DTYPE" }, "pg"/z, ["y", x8, lsl #"SHIFT"] \n" \
" fmla "out"."DTYPE", "pg"/m, z2."DTYPE", z3."DTYPE" \n"

/* Reduce vector register z<v> under p0 into the scalar FP register <v>. */
#define SUM_VECTOR(v) \
" faddv "DTYPE""v", p0, z"v"."DTYPE" \n"

/* Output operand = scalar register 1 + scalar register 0. */
#define RET \
" fadd %"DTYPE"[RET_], "DTYPE"1, "DTYPE"0 \n"


BLASLONG sve_width = SVE_WIDTH;
/*
 * Complete kernel text.
 *
 *   1: ("vector_2x") while x8 < (n & -(2*x9)): two full-vector updates per
 *      iteration, accumulating into z1 and z0; afterwards z1 is reduced
 *      into scalar register 1.  Skipped entirely when n & -(2*x9) is zero.
 *   3: ("vector_1x") while x8 < (n & -x9): one full-vector update into z0.
 *   4: ("tail") if (n & -x9) != n: one update under the whilelo predicate.
 *   5: ("end") reduce z0 and add both partial sums into %[RET_].
 *
 * The masks equal n rounded down to a multiple of the step when the step
 * is a power of two.  Scalar register 1 is zeroed up front ("movi d1")
 * because RET reads it even when the 2x loop, and with it
 * SUM_VECTOR("1"), is skipped.
 */
#define DOT_KERNEL \
COUNT \
" mov z1.d, #0 \n" \
" mov z0.d, #0 \n" \
" mov x8, #0 \n" \
" movi d1, #0x0 \n" \
SETUP_TRUE \
" neg x10, x9, lsl #1 \n" \
" ands x11, x10, %[N_] \n" \
" b.eq 2f // skip_2x \n" \
OFFSET_INPUTS \
"1: // vector_2x \n" \
UPDATE("p0", "%[X_]", "%[Y_]", "z1") \
UPDATE("p0", "x12", "x13", "z0") \
" sub x8, x8, x10 \n" \
" cmp x8, x11 \n" \
" b.lo 1b // vector_2x \n" \
SUM_VECTOR("1") \
"2: // skip_2x \n" \
" neg x10, x9 \n" \
" and x10, x10, %[N_] \n" \
" cmp x8, x10 \n" \
" b.hs 4f // tail \n" \
"3: // vector_1x \n" \
UPDATE("p0", "%[X_]", "%[Y_]", "z0") \
" add x8, x8, x9 \n" \
" cmp x8, x10 \n" \
" b.lo 3b // vector_1x \n" \
"4: // tail \n" \
" cmp x10, %[N_] \n" \
" b.eq 5f // end \n" \
TAIL_WHILE \
UPDATE("p1", "%[X_]", "%[Y_]", "z0") \
"5: // end \n" \
SUM_VECTOR("0") \
RET


for (BLASLONG i = 0; i < n; i += sve_width * 2) {
svbool_t pg_a = SVE_WHILELT((uint64_t)i, (uint64_t)n);
svbool_t pg_b = SVE_WHILELT((uint64_t)(i + sve_width), (uint64_t)n);
static
FLOAT
dot_kernel_sve(BLASLONG n, FLOAT* x, FLOAT* y)
{
FLOAT ret;


SVE_TYPE x_vec_a = svld1(pg_a, &x[i]);
SVE_TYPE y_vec_a = svld1(pg_a, &y[i]);
SVE_TYPE x_vec_b = svld1(pg_b, &x[i + sve_width]);
SVE_TYPE y_vec_b = svld1(pg_b, &y[i + sve_width]);
/*
 * Run the assembly kernel.
 *
 * Output:  RET_ receives the scalar result in an FP/SIMD register; it is
 *          early-clobber because the kernel writes registers before it has
 *          finished reading its inputs.
 * Inputs:  N_ element count, X_ and Y_ base addresses of the two vectors.
 * Clobbers: everything the kernel text overwrites or depends on that is
 *          not an operand - the condition flags (ands/cmp), memory (the
 *          vectors are read through X_ and Y_), scratch registers x8-x13,
 *          vector registers z0-z3 and predicate registers p0-p1.  Without
 *          this list the compiler may keep live values, or place the
 *          operands themselves, in registers the kernel destroys.
 */
asm(DOT_KERNEL
:
[RET_] "=&w" (ret)
:
[N_] "r" (n),
[X_] "r" (x),
[Y_] "r" (y)
:
"cc", "memory",
"x8", "x9", "x10", "x11", "x12", "x13",
"z0", "z1", "z2", "z3",
"p0", "p1");


acc_a = svmla_m(pg_a, acc_a, x_vec_a, y_vec_a);
acc_b = svmla_m(pg_b, acc_b, x_vec_b, y_vec_b);
}

return svaddv(SVE_ALL, acc_a) + svaddv(SVE_ALL, acc_b);
return ret;
} }

Loading…
Cancel
Save