| @@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||||
| float * B, BLASLONG strideB, | float * B, BLASLONG strideB, | ||||
| float * R, BLASLONG strideR); | float * R, BLASLONG strideR); | ||||
| void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K, | |||||
| float alpha, | |||||
| float * A, BLASLONG strideA, | |||||
| float * B, BLASLONG strideB, | |||||
| float beta, | |||||
| float * R, BLASLONG strideR); | |||||
| int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | ||||
| int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | ||||
| @@ -240,6 +240,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||||
| #endif | #endif | ||||
| #ifdef ARCH_ARM64 | #ifdef ARCH_ARM64 | ||||
| void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | ||||
| void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG); | |||||
| #endif | #endif | ||||
| @@ -49,6 +49,7 @@ | |||||
| #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | ||||
| #define SGEMM_DIRECT sgemm_direct | #define SGEMM_DIRECT sgemm_direct | ||||
| #define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta | |||||
| #define SGEMM_ONCOPY sgemm_oncopy | #define SGEMM_ONCOPY sgemm_oncopy | ||||
| #define SGEMM_OTCOPY sgemm_otcopy | #define SGEMM_OTCOPY sgemm_otcopy | ||||
| @@ -218,6 +219,7 @@ | |||||
| #elif ARCH_ARM64 | #elif ARCH_ARM64 | ||||
| #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | ||||
| #define SGEMM_DIRECT gotoblas -> sgemm_direct | #define SGEMM_DIRECT gotoblas -> sgemm_direct | ||||
| #define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta | |||||
| #endif | #endif | ||||
| #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | ||||
| @@ -436,6 +436,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||||
| if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | ||||
| SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | ||||
| return; | return; | ||||
| }else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | |||||
| SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); | |||||
| return; | |||||
| } | } | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| @@ -222,9 +222,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||||
| GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) | ||||
| elseif (ARM64) | elseif (ARM64) | ||||
| set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) | set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) | ||||
| set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c) | |||||
| set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) | set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) | ||||
| set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) | set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) | ||||
| GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) | ||||
| GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE) | |||||
| if (HAVE_SME) | if (HAVE_SME) | ||||
| GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) | ||||
| GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) | ||||
| @@ -105,6 +105,7 @@ ifeq ($(TARGET_CORE), ARMV9SME) | |||||
| HAVE_SME = 1 | HAVE_SME = 1 | ||||
| endif | endif | ||||
| SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | ||||
| SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c | |||||
| endif | endif | ||||
| endif | endif | ||||
| endif | endif | ||||
| @@ -164,7 +165,8 @@ SKERNELOBJS += \ | |||||
| endif | endif | ||||
| ifeq ($(ARCH), arm64) | ifeq ($(ARCH), arm64) | ||||
| SKERNELOBJS += \ | SKERNELOBJS += \ | ||||
| sgemm_direct$(TSUFFIX).$(SUFFIX) | |||||
| sgemm_direct$(TSUFFIX).$(SUFFIX) \ | |||||
| sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) | |||||
| ifdef HAVE_SME | ifdef HAVE_SME | ||||
| SKERNELOBJS += \ | SKERNELOBJS += \ | ||||
| sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | ||||
| @@ -904,6 +906,8 @@ endif | |||||
| ifeq ($(ARCH), arm64) | ifeq ($(ARCH), arm64) | ||||
| $(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | $(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | ||||
| $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | ||||
| $(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA) | |||||
| $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||||
| ifdef HAVE_SME | ifdef HAVE_SME | ||||
| $(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : | $(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : | ||||
| $(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ | $(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ | ||||
| @@ -0,0 +1,199 @@ | |||||
| /* | |||||
| Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | |||||
| SPDX-License-Identifier: BSD-3-Clause-Clear | |||||
| */ | |||||
| #include "common.h" | |||||
| #include <stdlib.h> | |||||
| #include <inttypes.h> | |||||
| #include <math.h> | |||||
| #include "sme_abi.h" | |||||
| #if defined(HAVE_SME) | |||||
| #if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16 | |||||
| #include <arm_sme.h> | |||||
| #endif | |||||
| /* Function prototypes */ | |||||
| extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\ | |||||
| const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess"); | |||||
| /* Function Definitions */ | |||||
| static uint64_t sve_cntw() { | |||||
| uint64_t cnt; | |||||
| asm volatile( | |||||
| "rdsvl %[res], #1\n" | |||||
| "lsr %[res], %[res], #2\n" | |||||
| : [res] "=r" (cnt) :: | |||||
| ); | |||||
| return cnt; | |||||
| } | |||||
| #if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16 | |||||
| // Outer product kernel. | |||||
| // Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA. | |||||
| __attribute__((always_inline)) inline void | |||||
| kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim, | |||||
| size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta) | |||||
| __arm_out("za") __arm_streaming { | |||||
| const uint64_t svl = svcntw(); | |||||
| size_t ldb = ldc; | |||||
| // Predicate set-up | |||||
| svbool_t pg = svptrue_b32(); | |||||
| svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows); | |||||
| svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows); | |||||
| svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols); | |||||
| svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols); | |||||
| #define pg_c_0 pg_b_0 | |||||
| #define pg_c_1 pg_b_1 | |||||
| svzero_za(); | |||||
| svfloat32_t beta_vec = svdup_f32(beta); | |||||
| // Load C to ZA | |||||
| for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||||
| svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||||
| row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||||
| svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0); | |||||
| svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||||
| row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||||
| svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1); | |||||
| } | |||||
| for (size_t i = svl; i < block_rows; i++) { | |||||
| svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||||
| row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||||
| svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0); | |||||
| svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||||
| row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||||
| svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1); | |||||
| } | |||||
| svfloat32_t alpha_vec = svdup_f32(alpha); | |||||
| // Iterate through shared dimension (K) | |||||
| for (size_t k = 0; k < shared_dim; k++) { | |||||
| // Load column of A | |||||
| svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]); | |||||
| col_a_0 = svmul_x(pg, alpha_vec, col_a_0); | |||||
| svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]); | |||||
| col_a_1 = svmul_x(pg, alpha_vec, col_a_1); | |||||
| // Load row of B | |||||
| svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]); | |||||
| svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]); | |||||
| // Perform outer product | |||||
| svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0); | |||||
| svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1); | |||||
| svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0); | |||||
| svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1); | |||||
| } | |||||
| // Store to C from ZA | |||||
| for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||||
| svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]); | |||||
| svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||||
| } | |||||
| for (size_t i = svl; i < block_rows; i++) { | |||||
| svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]); | |||||
| svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||||
| } | |||||
| } | |||||
| __arm_new("za") __arm_locally_streaming | |||||
| void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||||
| const float *ba, const float *restrict bb, const float* beta,\ | |||||
| float *restrict C) { | |||||
| const uint64_t num_rows = m; | |||||
| const uint64_t num_cols = n; | |||||
| const float *restrict a_ptr = ba; | |||||
| const float *restrict b_ptr = bb; | |||||
| float *restrict c_ptr = C; | |||||
| const uint64_t svl = svcntw(); | |||||
| const uint64_t ldc = n; | |||||
| // Block over rows of C (panels of A) | |||||
| uint64_t row_idx = 0; | |||||
| // 2x2 loop | |||||
| uint64_t row_batch = 2*svl; | |||||
| // Block over row dimension of C | |||||
| for (; row_idx < num_rows; row_idx += row_batch) { | |||||
| row_batch = MIN(row_batch, num_rows - row_idx); | |||||
| uint64_t col_idx = 0; | |||||
| uint64_t col_batch = 2*svl; | |||||
| // Block over column dimension of C | |||||
| for (; col_idx < num_cols; col_idx += col_batch) { | |||||
| col_batch = MIN(col_batch, num_cols - col_idx); | |||||
| kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx], | |||||
| &c_ptr[row_idx * ldc + col_idx], k, | |||||
| ldc, row_batch, col_batch, *alpha, *beta); | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| #else | |||||
| void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||||
| const float *ba, const float *restrict bb, const float* beta,\ | |||||
| float *restrict C){} | |||||
| #endif | |||||
| /*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\ | |||||
| float * __restrict A, BLASLONG strideA, float * __restrict B,\ | |||||
| BLASLONG strideB , float * __restrict R, BLASLONG strideR) | |||||
| */ | |||||
| void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||||
| BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||||
| float beta, float * __restrict R, BLASLONG strideR){ | |||||
| uint64_t m_mod, vl_elms; | |||||
| vl_elms = sve_cntw(); | |||||
| m_mod = ceil((double)M/(double)vl_elms) * vl_elms; | |||||
| float *A_mod = (float *) malloc(m_mod*K*sizeof(float)); | |||||
| /* Prevent compiler optimization by reading from memory instead | |||||
| * of reading directly from vector (z) registers. | |||||
| * */ | |||||
| asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||||
| "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||||
| "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||||
| "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||||
| "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||||
| "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||||
| /* Pre-process the left matrix to make it suitable for | |||||
| matrix sum of outer-product calculation | |||||
| */ | |||||
| sgemm_direct_sme1_preprocess(M, K, A, A_mod); | |||||
| asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||||
| "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||||
| "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||||
| "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||||
| "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||||
| "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||||
| /* Calculate C = alpha*A*B + beta*C */ | |||||
| sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R); | |||||
| free(A_mod); | |||||
| } | |||||
| #else | |||||
| void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||||
| BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||||
| float beta, float * __restrict R, BLASLONG strideR){} | |||||
| #endif | |||||
| @@ -0,0 +1,46 @@ | |||||
| /*************************************************************************** | |||||
| * Copyright (c) 2024, The OpenBLAS Project | |||||
| * All rights reserved. | |||||
| * Redistribution and use in source and binary forms, with or without | |||||
| * modification, are permitted provided that the following conditions are | |||||
| * met: | |||||
| * 1. Redistributions of source code must retain the above copyright | |||||
| * notice, this list of conditions and the following disclaimer. | |||||
| * 2. Redistributions in binary form must reproduce the above copyright | |||||
| * notice, this list of conditions and the following disclaimer in | |||||
| * the documentation and/or other materials provided with the | |||||
| * distribution. | |||||
| * 3. Neither the name of the OpenBLAS project nor the names of | |||||
| * its contributors may be used to endorse or promote products | |||||
| * derived from this software without specific prior written permission. | |||||
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE | |||||
| * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |||||
| * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
| * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF | |||||
| * THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
| * *****************************************************************************/ | |||||
| #pragma once | |||||
| #include <stdlib.h> | |||||
| /** | |||||
| * * These are SME ABI routines for saving & restoring SME state. | |||||
| * * They are typically provided by a compiler runtime library such | |||||
| * * as libgcc or compiler-rt, but support for these routines is not | |||||
| * * yet available on all platforms. | |||||
| * * | |||||
| * * Define these as aborting stubs so that we loudly fail on nested | |||||
| * * usage of SME state. | |||||
| * * | |||||
| * * These are defined as weak symbols so that a compiler runtime can | |||||
| * * override them if supported. | |||||
| * */ | |||||
| __attribute__((weak)) void __arm_tpidr2_save() { abort(); } | |||||
| __attribute__((weak)) void __arm_tpidr2_restore() { abort(); } | |||||
| @@ -198,6 +198,7 @@ gotoblas_t TABLE_NAME = { | |||||
| #endif | #endif | ||||
| #ifdef ARCH_ARM64 | #ifdef ARCH_ARM64 | ||||
| sgemm_directTS, | sgemm_directTS, | ||||
| sgemm_direct_alpha_betaTS, | |||||
| #endif | #endif | ||||
| sgemm_kernelTS, sgemm_betaTS, | sgemm_kernelTS, sgemm_betaTS, | ||||