@@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, | |||||
float * B, BLASLONG strideB, | float * B, BLASLONG strideB, | ||||
float * R, BLASLONG strideR); | float * R, BLASLONG strideR); | ||||
void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K, | |||||
float alpha, | |||||
float * A, BLASLONG strideA, | |||||
float * B, BLASLONG strideB, | |||||
float beta, | |||||
float * R, BLASLONG strideR); | |||||
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); | ||||
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, | ||||
@@ -240,6 +240,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); | |||||
#endif | #endif | ||||
#ifdef ARCH_ARM64 | #ifdef ARCH_ARM64 | ||||
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | ||||
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG); | |||||
#endif | #endif | ||||
@@ -49,6 +49,7 @@ | |||||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | ||||
#define SGEMM_DIRECT sgemm_direct | #define SGEMM_DIRECT sgemm_direct | ||||
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta | |||||
#define SGEMM_ONCOPY sgemm_oncopy | #define SGEMM_ONCOPY sgemm_oncopy | ||||
#define SGEMM_OTCOPY sgemm_otcopy | #define SGEMM_OTCOPY sgemm_otcopy | ||||
@@ -218,6 +219,7 @@ | |||||
#elif ARCH_ARM64 | #elif ARCH_ARM64 | ||||
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | ||||
#define SGEMM_DIRECT gotoblas -> sgemm_direct | #define SGEMM_DIRECT gotoblas -> sgemm_direct | ||||
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta | |||||
#endif | #endif | ||||
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | ||||
@@ -436,6 +436,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | ||||
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | ||||
return; | return; | ||||
}else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | |||||
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); | |||||
return; | |||||
} | } | ||||
#endif | #endif | ||||
#endif | #endif | ||||
@@ -222,9 +222,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE) | ||||
elseif (ARM64) | elseif (ARM64) | ||||
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) | set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c) | ||||
set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c) | |||||
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) | set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S) | ||||
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) | set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S) | ||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE) | ||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE) | |||||
if (HAVE_SME) | if (HAVE_SME) | ||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE) | ||||
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) | GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE) | ||||
@@ -105,6 +105,7 @@ ifeq ($(TARGET_CORE), ARMV9SME) | |||||
HAVE_SME = 1 | HAVE_SME = 1 | ||||
endif | endif | ||||
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | ||||
SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c | |||||
endif | endif | ||||
endif | endif | ||||
endif | endif | ||||
@@ -164,7 +165,8 @@ SKERNELOBJS += \ | |||||
endif | endif | ||||
ifeq ($(ARCH), arm64) | ifeq ($(ARCH), arm64) | ||||
SKERNELOBJS += \ | SKERNELOBJS += \ | ||||
sgemm_direct$(TSUFFIX).$(SUFFIX) | |||||
sgemm_direct$(TSUFFIX).$(SUFFIX) \ | |||||
sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) | |||||
ifdef HAVE_SME | ifdef HAVE_SME | ||||
SKERNELOBJS += \ | SKERNELOBJS += \ | ||||
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | ||||
@@ -904,6 +906,8 @@ endif | |||||
ifeq ($(ARCH), arm64) | ifeq ($(ARCH), arm64) | ||||
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | $(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | ||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | ||||
$(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA) | |||||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||||
ifdef HAVE_SME | ifdef HAVE_SME | ||||
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : | $(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : | ||||
$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ | $(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@ | ||||
@@ -0,0 +1,199 @@ | |||||
/* | |||||
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | |||||
SPDX-License-Identifier: BSD-3-Clause-Clear | |||||
*/ | |||||
#include "common.h" | |||||
#include <stdlib.h> | |||||
#include <inttypes.h> | |||||
#include <math.h> | |||||
#include "sme_abi.h" | |||||
#if defined(HAVE_SME) | |||||
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16 | |||||
#include <arm_sme.h> | |||||
#endif | |||||
/* Function prototypes */ | |||||
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\ | |||||
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess"); | |||||
/* Function Definitions */ | |||||
static uint64_t sve_cntw() { | |||||
uint64_t cnt; | |||||
asm volatile( | |||||
"rdsvl %[res], #1\n" | |||||
"lsr %[res], %[res], #2\n" | |||||
: [res] "=r" (cnt) :: | |||||
); | |||||
return cnt; | |||||
} | |||||
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16 | |||||
// Outer product kernel. | |||||
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA. | |||||
__attribute__((always_inline)) inline void | |||||
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim, | |||||
size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta) | |||||
__arm_out("za") __arm_streaming { | |||||
const uint64_t svl = svcntw(); | |||||
size_t ldb = ldc; | |||||
// Predicate set-up | |||||
svbool_t pg = svptrue_b32(); | |||||
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows); | |||||
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows); | |||||
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols); | |||||
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols); | |||||
#define pg_c_0 pg_b_0 | |||||
#define pg_c_1 pg_b_1 | |||||
svzero_za(); | |||||
svfloat32_t beta_vec = svdup_f32(beta); | |||||
// Load C to ZA | |||||
for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||||
row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||||
svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0); | |||||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||||
row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||||
svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1); | |||||
} | |||||
for (size_t i = svl; i < block_rows; i++) { | |||||
svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]); | |||||
row_c_0 = svmul_x(pg, beta_vec, row_c_0); | |||||
svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0); | |||||
svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]); | |||||
row_c_1 = svmul_x(pg, beta_vec, row_c_1); | |||||
svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1); | |||||
} | |||||
svfloat32_t alpha_vec = svdup_f32(alpha); | |||||
// Iterate through shared dimension (K) | |||||
for (size_t k = 0; k < shared_dim; k++) { | |||||
// Load column of A | |||||
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]); | |||||
col_a_0 = svmul_x(pg, alpha_vec, col_a_0); | |||||
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]); | |||||
col_a_1 = svmul_x(pg, alpha_vec, col_a_1); | |||||
// Load row of B | |||||
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]); | |||||
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]); | |||||
// Perform outer product | |||||
svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0); | |||||
svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1); | |||||
svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0); | |||||
svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1); | |||||
} | |||||
// Store to C from ZA | |||||
for (size_t i = 0; i < MIN(svl, block_rows); i++) { | |||||
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]); | |||||
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||||
} | |||||
for (size_t i = svl; i < block_rows; i++) { | |||||
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]); | |||||
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]); | |||||
} | |||||
} | |||||
__arm_new("za") __arm_locally_streaming | |||||
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||||
const float *ba, const float *restrict bb, const float* beta,\ | |||||
float *restrict C) { | |||||
const uint64_t num_rows = m; | |||||
const uint64_t num_cols = n; | |||||
const float *restrict a_ptr = ba; | |||||
const float *restrict b_ptr = bb; | |||||
float *restrict c_ptr = C; | |||||
const uint64_t svl = svcntw(); | |||||
const uint64_t ldc = n; | |||||
// Block over rows of C (panels of A) | |||||
uint64_t row_idx = 0; | |||||
// 2x2 loop | |||||
uint64_t row_batch = 2*svl; | |||||
// Block over row dimension of C | |||||
for (; row_idx < num_rows; row_idx += row_batch) { | |||||
row_batch = MIN(row_batch, num_rows - row_idx); | |||||
uint64_t col_idx = 0; | |||||
uint64_t col_batch = 2*svl; | |||||
// Block over column dimension of C | |||||
for (; col_idx < num_cols; col_idx += col_batch) { | |||||
col_batch = MIN(col_batch, num_cols - col_idx); | |||||
kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx], | |||||
&c_ptr[row_idx * ldc + col_idx], k, | |||||
ldc, row_batch, col_batch, *alpha, *beta); | |||||
} | |||||
} | |||||
return; | |||||
} | |||||
#else | |||||
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ | |||||
const float *ba, const float *restrict bb, const float* beta,\ | |||||
float *restrict C){} | |||||
#endif | |||||
/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\ | |||||
float * __restrict A, BLASLONG strideA, float * __restrict B,\ | |||||
BLASLONG strideB , float * __restrict R, BLASLONG strideR) | |||||
*/ | |||||
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||||
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||||
float beta, float * __restrict R, BLASLONG strideR){ | |||||
uint64_t m_mod, vl_elms; | |||||
vl_elms = sve_cntw(); | |||||
m_mod = ceil((double)M/(double)vl_elms) * vl_elms; | |||||
float *A_mod = (float *) malloc(m_mod*K*sizeof(float)); | |||||
/* Prevent compiler optimization by reading from memory instead | |||||
* of reading directly from vector (z) registers. | |||||
* */ | |||||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||||
/* Pre-process the left matrix to make it suitable for | |||||
matrix sum of outer-product calculation | |||||
*/ | |||||
sgemm_direct_sme1_preprocess(M, K, A, A_mod); | |||||
asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", | |||||
"p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", | |||||
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", | |||||
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", | |||||
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", | |||||
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); | |||||
/* Calculate C = alpha*A*B + beta*C */ | |||||
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R); | |||||
free(A_mod); | |||||
} | |||||
#else | |||||
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ | |||||
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\ | |||||
float beta, float * __restrict R, BLASLONG strideR){} | |||||
#endif |
@@ -0,0 +1,46 @@ | |||||
/*************************************************************************** | |||||
* Copyright (c) 2024, The OpenBLAS Project | |||||
* All rights reserved. | |||||
* Redistribution and use in source and binary forms, with or without | |||||
* modification, are permitted provided that the following conditions are | |||||
* met: | |||||
* 1. Redistributions of source code must retain the above copyright | |||||
* notice, this list of conditions and the following disclaimer. | |||||
* 2. Redistributions in binary form must reproduce the above copyright | |||||
* notice, this list of conditions and the following disclaimer in | |||||
* the documentation and/or other materials provided with the | |||||
* distribution. | |||||
* 3. Neither the name of the OpenBLAS project nor the names of | |||||
* its contributors may be used to endorse or promote products | |||||
* derived from this software without specific prior written permission. | |||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE | |||||
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||||
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE | |||||
* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |||||
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF | |||||
* THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
* *****************************************************************************/ | |||||
#pragma once | |||||
#include <stdlib.h> | |||||
/** | |||||
* * These are SME ABI routines for saving & restoring SME state. | |||||
* * They are typically provided by a compiler runtime library such | |||||
* * as libgcc or compiler-rt, but support for these routines is not | |||||
* * yet available on all platforms. | |||||
* * | |||||
* * Define these as aborting stubs so that we loudly fail on nested | |||||
* * usage of SME state. | |||||
* * | |||||
* * These are defined as weak symbols so that a compiler runtime can | |||||
* * override them if supported. | |||||
* */ | |||||
__attribute__((weak)) void __arm_tpidr2_save() { abort(); } | |||||
__attribute__((weak)) void __arm_tpidr2_restore() { abort(); } | |||||
@@ -198,6 +198,7 @@ gotoblas_t TABLE_NAME = { | |||||
#endif | #endif | ||||
#ifdef ARCH_ARM64 | #ifdef ARCH_ARM64 | ||||
sgemm_directTS, | sgemm_directTS, | ||||
sgemm_direct_alpha_betaTS, | |||||
#endif | #endif | ||||
sgemm_kernelTS, sgemm_betaTS, | sgemm_kernelTS, sgemm_betaTS, | ||||