SME1 based direct kernel (with alpha and beta) for cblas_sgemm level 3pull/5386/head
@@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||
float * B, BLASLONG strideB, | |||
float * R, BLASLONG strideR); | |||
void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K, | |||
float alpha, | |||
float * A, BLASLONG strideA, | |||
float * B, BLASLONG strideB, | |||
float beta, | |||
float * R, BLASLONG strideR); | |||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | |||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | |||
@@ -256,6 +256,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||
#endif | |||
#ifdef ARCH_ARM64 | |||
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | |||
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG); | |||
#endif | |||
@@ -49,6 +49,7 @@ | |||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | |||
#define SGEMM_DIRECT sgemm_direct | |||
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta | |||
#define SGEMM_ONCOPY sgemm_oncopy | |||
#define SGEMM_OTCOPY sgemm_otcopy | |||
@@ -218,6 +219,7 @@ | |||
#elif ARCH_ARM64 | |||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | |||
#define SGEMM_DIRECT gotoblas -> sgemm_direct | |||
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta | |||
#endif | |||
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | |||
@@ -441,6 +441,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | |||
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | |||
return; | |||
}else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | |||
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); | |||
return; | |||
} | |||
#endif | |||
#endif | |||
@@ -255,9 +255,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) | |||
elseif (ARM64) | |||
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) | |||
set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c) | |||
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) | |||
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) | |||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) | |||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE) | |||
if (HAVE_SME) | |||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) | |||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) | |||
@@ -132,6 +132,7 @@ ifeq ($(TARGET_CORE), ARMV9SME) | |||
HAVE_SME = 1 | |||
endif | |||
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | |||
SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c | |||
endif | |||
endif | |||
endif | |||
@@ -208,7 +209,8 @@ SKERNELOBJS += \ | |||
endif | |||
ifeq ($(ARCH), arm64) | |||
SKERNELOBJS += \ | |||
sgemm_direct$(TSUFFIX).$(SUFFIX) | |||
sgemm_direct$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) | |||
ifdef HAVE_SME | |||
SKERNELOBJS += \ | |||
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | |||
@@ -969,6 +971,8 @@ endif | |||
ifeq ($(ARCH), arm64) | |||
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
ifdef HAVE_SME | |||
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : | |||
$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ | |||
@@ -0,0 +1,199 @@ | |||
/* | |||
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | |||
SPDX-License-Identifier: BSD-3-Clause-Clear | |||
*/ | |||
#include "common.h" | |||
#include <stdlib.h> | |||
#include <inttypes.h> | |||
#include <math.h> | |||
#include "sme_abi.h" | |||
#if defined(HAVE_SME) | |||
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16 | |||
#include <arm_sme.h> | |||
#endif | |||
/* Function prototypes */ | |||
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\ | |||
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess"); | |||
/* Function Definitions */ | |||
static uint64_t sve_cntw() { | |||
uint64_t cnt; | |||
asm volatile( | |||
"rdsvl %[res], #1\n" | |||
"lsr %[res], %[res], #2\n" | |||
: [res] "=r" (cnt) :: | |||
); | |||
return cnt; | |||
} | |||
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16 | |||
// Outer product kernel. | |||
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA. | |||
__attribute__((always_inline)) inline void | |||
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim, | |||
size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta) | |||
__arm_out("za") __arm_streaming { | |||
const uint64_t svl = svcntw(); | |||
size_t ldb = ldc; | |||
// Predicate set-up | |||
svbool_t pg = svptrue_b32(); | |||
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows); | |||
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows); | |||
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols); | |||
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols); | |||
#define pg_c_0 pg_b_0 | |||
#define pg_c_1 pg_b_1 | |||
svzero_za(); | |||
svfloat32_t beta_vec = svdup_f32(beta); | |||
// Load C to ZA | |||
for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||
row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||
svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0); | |||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||
row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||
svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1); | |||
} | |||
for (size_t i = svl; i < block_rows; i++) { | |||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||
row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||
svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0); | |||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||
row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||
svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1); | |||
} | |||
svfloat32_t alpha_vec = svdup_f32(alpha); | |||
// Iterate through shared dimension (K) | |||
for (size_t k = 0; k < shared_dim; k++) { | |||
// Load column of A | |||
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]); | |||
col_a_0 = svmul_x(pg, alpha_vec, col_a_0); | |||
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]); | |||
col_a_1 = svmul_x(pg, alpha_vec, col_a_1); | |||
// Load row of B | |||
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]); | |||
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]); | |||
// Perform outer product | |||
svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0); | |||
svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1); | |||
svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0); | |||
svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1); | |||
} | |||
// Store to C from ZA | |||
for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]); | |||
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||
} | |||
for (size_t i = svl; i < block_rows; i++) { | |||
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]); | |||
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||
} | |||
} | |||
__arm_new("za") __arm_locally_streaming | |||
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||
const float *ba, const float *restrict bb, const float* beta,\ | |||
float *restrict C) { | |||
const uint64_t num_rows = m; | |||
const uint64_t num_cols = n; | |||
const float *restrict a_ptr = ba; | |||
const float *restrict b_ptr = bb; | |||
float *restrict c_ptr = C; | |||
const uint64_t svl = svcntw(); | |||
const uint64_t ldc = n; | |||
// Block over rows of C (panels of A) | |||
uint64_t row_idx = 0; | |||
// 2x2 loop | |||
uint64_t row_batch = 2*svl; | |||
// Block over row dimension of C | |||
for (; row_idx < num_rows; row_idx += row_batch) { | |||
row_batch = MIN(row_batch, num_rows - row_idx); | |||
uint64_t col_idx = 0; | |||
uint64_t col_batch = 2*svl; | |||
// Block over column dimension of C | |||
for (; col_idx < num_cols; col_idx += col_batch) { | |||
col_batch = MIN(col_batch, num_cols - col_idx); | |||
kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx], | |||
&c_ptr[row_idx * ldc + col_idx], k, | |||
ldc, row_batch, col_batch, *alpha, *beta); | |||
} | |||
} | |||
return; | |||
} | |||
#else | |||
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||
const float *ba, const float *restrict bb, const float* beta,\ | |||
float *restrict C){} | |||
#endif | |||
/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\ | |||
float * __restrict A, BLASLONG strideA, float * __restrict B,\ | |||
BLASLONG strideB , float * __restrict R, BLASLONG strideR) | |||
*/ | |||
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||
float beta, float * __restrict R, BLASLONG strideR){ | |||
uint64_t m_mod, vl_elms; | |||
vl_elms = sve_cntw(); | |||
m_mod = ceil((double)M/(double)vl_elms) * vl_elms; | |||
float *A_mod = (float *) malloc(m_mod*K*sizeof(float)); | |||
/* Prevent compiler optimization by reading from memory instead | |||
* of reading directly from vector (z) registers. | |||
* */ | |||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||
/* Pre-process the left matrix to make it suitable for | |||
matrix sum of outer-product calculation | |||
*/ | |||
sgemm_direct_sme1_preprocess(M, K, A, A_mod); | |||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||
/* Calculate C = alpha*A*B + beta*C */ | |||
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R); | |||
free(A_mod); | |||
} | |||
#else | |||
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||
float beta, float * __restrict R, BLASLONG strideR){} | |||
#endif |
@@ -0,0 +1,46 @@ | |||
/*************************************************************************** | |||
* Copyright (c) 2024, The OpenBLAS Project | |||
* All rights reserved. | |||
* Redistribution and use in source and binary forms, with or without | |||
* modification, are permitted provided that the following conditions are | |||
* met: | |||
* 1. Redistributions of source code must retain the above copyright | |||
* notice, this list of conditions and the following disclaimer. | |||
* 2. Redistributions in binary form must reproduce the above copyright | |||
* notice, this list of conditions and the following disclaimer in | |||
* the documentation and/or other materials provided with the | |||
* distribution. | |||
* 3. Neither the name of the OpenBLAS project nor the names of | |||
* its contributors may be used to endorse or promote products | |||
* derived from this software without specific prior written permission. | |||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE | |||
* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |||
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF | |||
* THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
* *****************************************************************************/ | |||
#pragma once | |||
#include <stdlib.h> | |||
/** | |||
* * These are SME ABI routines for saving & restoring SME state. | |||
* * They are typically provided by a compiler runtime library such | |||
* * as libgcc or compiler-rt, but support for these routines is not | |||
* * yet available on all platforms. | |||
* * | |||
* * Define these as aborting stubs so that we loudly fail on nested | |||
* * usage of SME state. | |||
* * | |||
* * These are defined as weak symbols so that a compiler runtime can | |||
* * override them if supported. | |||
* */ | |||
__attribute__((weak)) void __arm_tpidr2_save() { abort(); } | |||
__attribute__((weak)) void __arm_tpidr2_restore() { abort(); } | |||
@@ -215,6 +215,7 @@ gotoblas_t TABLE_NAME = { | |||
#endif | |||
#ifdef ARCH_ARM64 | |||
sgemm_directTS, | |||
sgemm_direct_alpha_betaTS, | |||
#endif | |||
sgemm_kernelTS, sgemm_betaTS, | |||