From eae0abfdb6153a4f8619927e46c797859e55d48c Mon Sep 17 00:00:00 2001 From: Rajendra Prasad Matcha Date: Fri, 11 Jul 2025 14:51:16 +0530 Subject: [PATCH] SME1 based direct kernel with alpha and beta for cblas_sgemm level 3 API. --- common_level3.h | 7 + common_param.h | 1 + common_s.h | 2 + interface/gemm.c | 3 + kernel/CMakeLists.txt | 2 + kernel/Makefile.L3 | 6 +- .../sgemm_direct_alpha_beta_arm64_sme1.c | 199 ++++++++++++++++++ kernel/arm64/sme_abi.h | 46 ++++ kernel/setparam-ref.c | 1 + 9 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c create mode 100644 kernel/arm64/sme_abi.h diff --git a/common_level3.h b/common_level3.h index 1838b4bf6..6ca62dc88 100644 --- a/common_level3.h +++ b/common_level3.h @@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, float * B, BLASLONG strideB, float * R, BLASLONG strideR); +void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K, + float alpha, + float * A, BLASLONG strideA, + float * B, BLASLONG strideB, + float beta, + float * R, BLASLONG strideR); + int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, diff --git a/common_param.h b/common_param.h index f82b73a72..8923fa430 100644 --- a/common_param.h +++ b/common_param.h @@ -240,6 +240,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); #endif #ifdef ARCH_ARM64 void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); + void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG); #endif diff --git a/common_s.h b/common_s.h index 1dede1e36..88b4732f5 100644 --- a/common_s.h +++ b/common_s.h @@ -49,6 +49,7 @@ #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant #define SGEMM_DIRECT sgemm_direct +#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta #define SGEMM_ONCOPY sgemm_oncopy #define SGEMM_OTCOPY sgemm_otcopy @@ -218,6 +219,7 @@ #elif ARCH_ARM64 #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant #define SGEMM_DIRECT gotoblas -> sgemm_direct +#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta #endif #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy diff --git a/interface/gemm.c b/interface/gemm.c index d79282e13..0b29a4dbb 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -436,6 +436,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); return; + }else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { + SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + return; } #endif #endif diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 9434f114e..2ec3bf77c 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -222,9 +222,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) elseif (ARM64) set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) + set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c) set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE) if (HAVE_SME) GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 6afb49a77..ebdfe7b06 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -105,6 +105,7 @@ ifeq ($(TARGET_CORE), ARMV9SME) HAVE_SME = 1 endif SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c +SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c endif endif endif @@ -164,7 +165,8 @@ SKERNELOBJS += \ endif ifeq ($(ARCH), arm64) SKERNELOBJS += \ - sgemm_direct$(TSUFFIX).$(SUFFIX) + sgemm_direct$(TSUFFIX).$(SUFFIX) \ + sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) ifdef HAVE_SME SKERNELOBJS += \ sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ @@ -904,6 +906,8 @@ endif ifeq ($(ARCH), arm64) $(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ +$(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ ifdef HAVE_SME $(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : $(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ diff --git a/kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c b/kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c new file mode 100644 index 000000000..d9de3ace3 --- /dev/null +++ b/kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c @@ -0,0 +1,199 @@ +/* + Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. + SPDX-License-Identifier: BSD-3-Clause-Clear +*/ + +#include "common.h" +#include +#include +#include +#include "sme_abi.h" +#if defined(HAVE_SME) + +#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16 +#include +#endif + +/* Function prototypes */ +extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\ + const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess"); + +/* Function Definitions */ +static uint64_t sve_cntw() { + uint64_t cnt; + asm volatile( + "rdsvl %[res], #1\n" + "lsr %[res], %[res], #2\n" + : [res] "=r" (cnt) :: + ); + return cnt; +} + +#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16 +// Outer product kernel. +// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA. +__attribute__((always_inline)) inline void +kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim, + size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta) + __arm_out("za") __arm_streaming { + + const uint64_t svl = svcntw(); + size_t ldb = ldc; + // Predicate set-up + svbool_t pg = svptrue_b32(); + svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows); + svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows); + + svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols); + svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols); + +#define pg_c_0 pg_b_0 +#define pg_c_1 pg_b_1 + + svzero_za(); + svfloat32_t beta_vec = svdup_f32(beta); + // Load C to ZA + for (size_t i = 0; i < MIN(svl, block_rows); i++) { + svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); + row_c_0 = svmul_x(pg, beta_vec, row_c_0); + svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0); + + svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); + row_c_1 = svmul_x(pg, beta_vec, row_c_1); + svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1); + } + for (size_t i = svl; i < block_rows; i++) { + svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); + row_c_0 = svmul_x(pg, beta_vec, row_c_0); + svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0); + + svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); + row_c_1 = svmul_x(pg, beta_vec, row_c_1); + svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1); + } + + svfloat32_t alpha_vec = svdup_f32(alpha); + // Iterate through shared dimension (K) + for (size_t k = 0; k < shared_dim; k++) { + // Load column of A + svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]); + col_a_0 = svmul_x(pg, alpha_vec, col_a_0); + svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]); + col_a_1 = svmul_x(pg, alpha_vec, col_a_1); + // Load row of B + svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]); + svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]); + // Perform outer product + svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0); + svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1); + svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0); + svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1); + } + + // Store to C from ZA + for (size_t i = 0; i < MIN(svl, block_rows); i++) { + svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]); + svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]); + } + for (size_t i = svl; i < block_rows; i++) { + svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]); + svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]); + } +} + +__arm_new("za") __arm_locally_streaming +void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ + const float *ba, const float *restrict bb, const float* beta,\ + float *restrict C) { + + const uint64_t num_rows = m; + const uint64_t num_cols = n; + + const float *restrict a_ptr = ba; + const float *restrict b_ptr = bb; + float *restrict c_ptr = C; + + const uint64_t svl = svcntw(); + const uint64_t ldc = n; + + // Block over rows of C (panels of A) + uint64_t row_idx = 0; + + // 2x2 loop + uint64_t row_batch = 2*svl; + + // Block over row dimension of C + for (; row_idx < num_rows; row_idx += row_batch) { + row_batch = MIN(row_batch, num_rows - row_idx); + uint64_t col_idx = 0; + uint64_t col_batch = 2*svl; + + // Block over column dimension of C + for (; col_idx < num_cols; col_idx += col_batch) { + col_batch = MIN(col_batch, num_cols - col_idx); + + kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx], + &c_ptr[row_idx * ldc + col_idx], k, + ldc, row_batch, col_batch, *alpha, *beta); + } + } + return; +} + +#else +void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ + const float *ba, const float *restrict bb, const float* beta,\ + float *restrict C){} +#endif + +/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\ + float * __restrict A, BLASLONG strideA, float * __restrict B,\ + BLASLONG strideB , float * __restrict R, BLASLONG strideR) +*/ +void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ + BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ + float beta, float * __restrict R, BLASLONG strideR){ + + uint64_t m_mod, vl_elms; + + vl_elms = sve_cntw(); + + m_mod = ceil((double)M/(double)vl_elms) * vl_elms; + + float *A_mod = (float *) malloc(m_mod*K*sizeof(float)); + + /* Prevent compiler optimization by reading from memory instead + * of reading directly from vector (z) registers. + * */ + asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + /* Pre-process the left matrix to make it suitable for + matrix sum of outer-product calculation + */ + sgemm_direct_sme1_preprocess(M, K, A, A_mod); + + asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", + "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + /* Calculate C = alpha*A*B + beta*C */ + sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R); + + free(A_mod); +} + +#else + +void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ + BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ + float beta, float * __restrict R, BLASLONG strideR){} + +#endif diff --git a/kernel/arm64/sme_abi.h b/kernel/arm64/sme_abi.h new file mode 100644 index 000000000..07bba4895 --- /dev/null +++ b/kernel/arm64/sme_abi.h @@ -0,0 +1,46 @@ +/*************************************************************************** + * Copyright (c) 2024, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF + * THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#pragma once + +#include + +/** + * * These are SME ABI routines for saving & restoring SME state. + * * They are typically provided by a compiler runtime library such + * * as libgcc or compiler-rt, but support for these routines is not + * * yet available on all platforms. + * * + * * Define these as aborting stubs so that we loudly fail on nested + * * usage of SME state. + * * + * * These are defined as weak symbols so that a compiler runtime can + * * override them if supported. + * */ +__attribute__((weak)) void __arm_tpidr2_save() { abort(); } +__attribute__((weak)) void __arm_tpidr2_restore() { abort(); } + diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 20a1df261..d807071a4 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -198,6 +198,7 @@ gotoblas_t TABLE_NAME = { #endif #ifdef ARCH_ARM64 sgemm_directTS, + sgemm_direct_alpha_betaTS, #endif sgemm_kernelTS, sgemm_betaTS,