Forward GEMM to GEMV when one argument is actually a vector (tags/v0.3.28^2)
| @@ -274,9 +274,18 @@ endif | |||
| ifeq ($(ARCH), loongarch64) | |||
| SMALL_MATRIX_OPT = 1 | |||
| endif | |||
| ifeq ($(ARCH), arm64) | |||
| GEMM_GEMV_FORWARD = 1 | |||
| endif | |||
| ifeq ($(SMALL_MATRIX_OPT), 1) | |||
| CCOMMON_OPT += -DSMALL_MATRIX_OPT | |||
| endif | |||
| ifeq ($(GEMM_GEMV_FORWARD), 1) | |||
| ifneq ($(ONLY_CBLAS), 1) | |||
| CCOMMON_OPT += -DGEMM_GEMV_FORWARD | |||
| endif | |||
| endif | |||
| # This operation is expensive, so it should be executed only once. | |||
| ifndef GOTOBLAS_MAKEFILE | |||
| @@ -391,6 +391,13 @@ endif () | |||
| if (X86_64 OR ${CORE} STREQUAL POWER10) | |||
| set(SMALL_MATRIX_OPT TRUE) | |||
| endif () | |||
| if (ARM64) | |||
| set(GEMM_GEMV_FORWARD TRUE) | |||
| endif () | |||
| if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) | |||
| set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") | |||
| endif () | |||
| if (SMALL_MATRIX_OPT) | |||
| set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") | |||
| endif () | |||
| @@ -1,4 +1,5 @@ | |||
| /*********************************************************************/ | |||
| /* Copyright 2024 The OpenBLAS Project */ | |||
| /* Copyright 2009, 2010 The University of Texas at Austin. */ | |||
| /* All rights reserved. */ | |||
| /* */ | |||
| @@ -47,12 +48,16 @@ | |||
| #define SMP_THRESHOLD_MIN 65536.0 | |||
| #ifdef XDOUBLE | |||
| #define ERROR_NAME "QGEMM " | |||
| #define GEMV BLASFUNC(qgemv) | |||
| #elif defined(DOUBLE) | |||
| #define ERROR_NAME "DGEMM " | |||
| #define GEMV BLASFUNC(dgemv) | |||
| #elif defined(BFLOAT16) | |||
| #define ERROR_NAME "SBGEMM " | |||
| #define GEMV BLASFUNC(sbgemv) | |||
| #else | |||
| #define ERROR_NAME "SGEMM " | |||
| #define GEMV BLASFUNC(sgemv) | |||
| #endif | |||
| #else | |||
| #define SMP_THRESHOLD_MIN 8192.0 | |||
| @@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
| args.m, args.n, args.k, args.lda, args.ldb, args.ldc); | |||
| #endif | |||
| #if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) | |||
| // Check if we can convert GEMM -> GEMV | |||
| if (args.k != 0) { | |||
| if (args.n == 1) { | |||
| blasint inc_x = 1; | |||
| blasint inc_y = 1; | |||
| // These were passed in as blasint, but the struct translates them to blaslong | |||
| blasint m = args.m; | |||
| blasint n = args.k; | |||
| blasint lda = args.lda; | |||
| // Create new transpose parameters | |||
| char NT = 'N'; | |||
| if (transa & 1) { | |||
| NT = 'T'; | |||
| m = args.k; | |||
| n = args.m; | |||
| } | |||
| if (transb & 1) { | |||
| inc_x = args.ldb; | |||
| } | |||
| GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); | |||
| return; | |||
| } | |||
| if (args.m == 1) { | |||
| blasint inc_x = args.lda; | |||
| blasint inc_y = args.ldc; | |||
| // These were passed in as blasint, but the struct translates them to blaslong | |||
| blasint m = args.k; | |||
| blasint n = args.n; | |||
| blasint ldb = args.ldb; | |||
| // Create new transpose parameters | |||
| char NT = 'T'; | |||
| if (transa & 1) { | |||
| inc_x = 1; | |||
| } | |||
| if (transb & 1) { | |||
| NT = 'N'; | |||
| m = args.n; | |||
| n = args.k; | |||
| } | |||
| GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); | |||
| return; | |||
| } | |||
| } | |||
| #endif | |||
| IDEBUG_START; | |||
| FUNCTION_PROFILE_START(); | |||
| @@ -1 +1,4 @@ | |||
| include $(KERNELDIR)/KERNEL.ARMV8SVE | |||
| SGEMVTKERNEL = gemv_t_sve.c | |||
| DGEMVTKERNEL = gemv_t_sve.c | |||
| @@ -1,5 +1,5 @@ | |||
| /******************************************************************************* | |||
| Copyright (c) 2015, The OpenBLAS Project | |||
| Copyright (c) 2015, 2024 The OpenBLAS Project | |||
| All rights reserved. | |||
| Redistribution and use in source and binary forms, with or without | |||
| modification, are permitted provided that the following conditions are | |||
| @@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| .macro KERNEL_F32_FINALIZE | |||
| #if !defined(DOUBLE) | |||
| fadd v1.4s, v1.4s, v2.4s | |||
| // F8 only has 2 accumulators | |||
| // so add into those pairs | |||
| fadd v1.4s, v1.4s, v3.4s | |||
| fadd v1.4s, v1.4s, v4.4s | |||
| #else | |||
| fadd v1.2d, v1.2d, v2.2d | |||
| fadd v1.2d, v1.2d, v3.2d | |||
| fadd v1.2d, v1.2d, v4.2d | |||
| fadd v2.4s, v2.4s, v4.4s | |||
| #endif | |||
| .endm | |||
| .macro KERNEL_F4 | |||
| .macro KERNEL_F8 | |||
| #if !defined(DOUBLE) | |||
| ld1 {v2.4s}, [A_PTR], #16 | |||
| ld1 {v3.4s}, [X_PTR], #16 | |||
| fmla v1.4s, v2.4s, v3.4s | |||
| #else | |||
| ld1 {v2.2d}, [A_PTR], #16 | |||
| ld1 {v3.2d}, [X_PTR], #16 | |||
| fmla v1.2d, v2.2d, v3.2d | |||
| ld1 {v4.2d}, [A_PTR], #16 | |||
| ld1 {v5.2d}, [X_PTR], #16 | |||
| fmla v1.2d, v4.2d, v5.2d | |||
| ld1 {v13.4s, v14.4s}, [A_PTR], #32 | |||
| ld1 {v17.4s, v18.4s}, [X_PTR], #32 | |||
| fmla v1.4s, v13.4s, v17.4s | |||
| fmla v2.4s, v14.4s, v18.4s | |||
| #else | |||
| ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64 | |||
| ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64 | |||
| fmla v1.2d, v13.2d, v17.2d | |||
| fmla v2.2d, v14.2d, v18.2d | |||
| fmla v3.2d, v15.2d, v19.2d | |||
| fmla v4.2d, v16.2d, v20.2d | |||
| #endif | |||
| .endm | |||
| .macro KERNEL_F4_FINALIZE | |||
| .macro KERNEL_F8_FINALIZE | |||
| #if !defined(DOUBLE) | |||
| ext v2.16b, v1.16b, v1.16b, #8 | |||
| // Take the top two elements of v1 and | |||
| // put them into the first two lanes of v3 | |||
| ext v3.16b, v1.16b, v1.16b, #8 | |||
| fadd v1.2s, v1.2s, v3.2s | |||
| ext v4.16b, v2.16b, v2.16b, #8 | |||
| fadd v2.2s, v2.2s, v4.2s | |||
| // Final pair | |||
| fadd v1.2s, v1.2s, v2.2s | |||
| faddp TEMP, v1.2s | |||
| #else | |||
| faddp TEMP, v1.2d | |||
| faddp TEMP1, v2.2d | |||
| faddp TEMP2, v3.2d | |||
| faddp TEMP3, v4.2d | |||
| fadd TEMP, TEMP, TEMP1 | |||
| fadd TEMP2, TEMP2, TEMP3 | |||
| fadd TEMP, TEMP, TEMP2 | |||
| #endif | |||
| .endm | |||
| @@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| asr I, M, #5 | |||
| cmp I, xzr | |||
| beq .Lgemv_t_kernel_F4 | |||
| beq .Lgemv_t_kernel_F8 | |||
| .Lgemv_t_kernel_F320: | |||
| @@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
| KERNEL_F32_FINALIZE | |||
| .Lgemv_t_kernel_F4: | |||
| .Lgemv_t_kernel_F8: | |||
| ands I, M, #31 | |||
| asr I, I, #2 | |||
| asr I, I, #3 | |||
| cmp I, xzr | |||
| beq .Lgemv_t_kernel_F1 | |||
| .Lgemv_t_kernel_F40: | |||
| .Lgemv_t_kernel_F80: | |||
| KERNEL_F4 | |||
| KERNEL_F8 | |||
| subs I, I, #1 | |||
| bne .Lgemv_t_kernel_F40 | |||
| bne .Lgemv_t_kernel_F80 | |||
| .Lgemv_t_kernel_F1: | |||
| KERNEL_F4_FINALIZE | |||
| KERNEL_F8_FINALIZE | |||
| ands I, M, #3 | |||
| ands I, M, #7 | |||
| ble .Lgemv_t_kernel_F_END | |||
| .Lgemv_t_kernel_F10: | |||
| @@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO | |||
| a_ptr = a; | |||
| if (inc_x == 1) { | |||
| svbool_t pg_true = SV_TRUE(); | |||
| uint64_t sve_size = SV_COUNT(); | |||
| uint64_t sve_size2 = sve_size * 2; | |||
| BLASLONG m1 = m & -sve_size; | |||
| BLASLONG m2 = m & -sve_size2; | |||
| for (j = 0; j < n; j++) { | |||
| BLASLONG i = 0; | |||
| SV_TYPE temp_vec_v2_0 = SV_DUP(0.0); | |||
| SV_TYPE temp_vec_v2_1 = SV_DUP(0.0); | |||
| for (; i < m2; i += sve_size2) { | |||
| SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); | |||
| SV_TYPE x_vec0 = svld1(pg_true, x + i); | |||
| SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size); | |||
| SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size); | |||
| temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0); | |||
| temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1); | |||
| } | |||
| SV_TYPE temp_vec_v1 = SV_DUP(0.0); | |||
| for (; i < m1; i += sve_size) { | |||
| SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); | |||
| SV_TYPE x_vec0 = svld1(pg_true, x + i); | |||
| temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0); | |||
| } | |||
| SV_TYPE temp_vec = SV_DUP(0.0); | |||
| i = 0; | |||
| svbool_t pg = SV_WHILE(i, m); | |||
| while (svptest_any(SV_TRUE(), pg)) { | |||
| for (; i < m; i += sve_size) { | |||
| svbool_t pg = SV_WHILE(i, m); | |||
| SV_TYPE a_vec = svld1(pg, a_ptr + i); | |||
| SV_TYPE x_vec = svld1(pg, x + i); | |||
| temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec); | |||
| i += sve_size; | |||
| pg = SV_WHILE(i, m); | |||
| } | |||
| temp = svaddv(SV_TRUE(), temp_vec); | |||
| y[iy] += alpha * temp; | |||
| y[iy] += alpha * ( | |||
| (svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) + | |||
| (svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1)) | |||
| ); | |||
| iy += inc_y; | |||
| a_ptr += lda; | |||
| } | |||