Browse Source

SME1 based direct kernel with alpha and beta for cblas_sgemm level 3 API.

pull/5380/head
Rajendra Prasad Matcha 2 months ago
parent
commit
eae0abfdb6
9 changed files with 266 additions and 1 deletions
  1. +7
    -0
      common_level3.h
  2. +1
    -0
      common_param.h
  3. +2
    -0
      common_s.h
  4. +3
    -0
      interface/gemm.c
  5. +2
    -0
      kernel/CMakeLists.txt
  6. +5
    -1
      kernel/Makefile.L3
  7. +199
    -0
      kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c
  8. +46
    -0
      kernel/arm64/sme_abi.h
  9. +1
    -0
      kernel/setparam-ref.c

+ 7
- 0
common_level3.h View File

@@ -52,6 +52,13 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
float * B, BLASLONG strideB,
float * R, BLASLONG strideR);

/*
 * Direct single-precision GEMM with scaling: R = alpha * A * B + beta * R.
 * M, N and K are the GEMM dimensions; strideA/strideB/strideR receive the
 * caller's lda/ldb/ldc (see the call in interface/gemm.c).
 * NOTE(review): the SME1 implementation added in this change never reads the
 * stride arguments -- confirm callers only pass densely packed matrices.
 */
void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,


+ 1
- 0
common_param.h View File

@@ -240,6 +240,7 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
#endif
#ifdef ARCH_ARM64
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
#endif



+ 2
- 0
common_s.h View File

@@ -49,6 +49,7 @@

#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT sgemm_direct
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta

#define SGEMM_ONCOPY sgemm_oncopy
#define SGEMM_OTCOPY sgemm_otcopy
@@ -218,6 +219,7 @@
#elif ARCH_ARM64
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy


+ 3
- 0
interface/gemm.c View File

@@ -436,6 +436,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}else if (order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return;
}
#endif
#endif


+ 2
- 0
kernel/CMakeLists.txt View File

@@ -222,9 +222,11 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
elseif (ARM64)
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c)
set (SGEMMDIRECTKERNEL_ALPHA_BETA sgemm_direct_alpha_beta_arm64_sme1.c)
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S)
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL_ALPHA_BETA}" "" "gemm_direct_alpha_beta" false "" "" false SINGLE)
if (HAVE_SME)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE)


+ 5
- 1
kernel/Makefile.L3 View File

@@ -105,6 +105,7 @@ ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c
SGEMMDIRECTKERNEL_ALPHA_BETA = sgemm_direct_alpha_beta_arm64_sme1.c
endif
endif
endif
@@ -164,7 +165,8 @@ SKERNELOBJS += \
endif
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
sgemm_direct$(TSUFFIX).$(SUFFIX)
sgemm_direct$(TSUFFIX).$(SUFFIX) \
sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX)
ifdef HAVE_SME
SKERNELOBJS += \
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \
@@ -904,6 +906,8 @@ endif
ifeq ($(ARCH), arm64)
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sgemm_direct_alpha_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
ifdef HAVE_SME
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) :
$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@


+ 199
- 0
kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c View File

@@ -0,0 +1,199 @@
/*
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/
#include "common.h"
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "sme_abi.h"
#if defined(HAVE_SME)
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
#include <arm_sme.h>
#endif
/*
 * Function prototypes
 *
 * sgemm_direct_sme1_preprocess is defined outside this file and bound to
 * the unmangled symbol name via __asm__. It is called below as
 * (M, K, A, A_mod): it reads the matrix `a` with `nbr` rows and `nbc`
 * columns and writes the re-arranged copy consumed by the outer-product
 * kernel into `a_mod`.
 * NOTE(review): the exact output layout is defined by the assembly source,
 * which is not visible here.
 */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
/* Function Definitions */
/*
 * Return the number of 32-bit elements in one streaming vector.
 *
 * RDSVL with multiplier #1 reads the streaming vector length in bytes; the
 * logical shift right by 2 divides that by sizeof(float).
 * NOTE(review): presumably RDSVL is used (rather than CNTW) so the value can
 * be queried without entering streaming mode -- confirm.
 */
static uint64_t sve_cntw() {
uint64_t cnt;
asm volatile(
"rdsvl %[res], #1\n"
"lsr %[res], %[res], #2\n"
: [res] "=r" (cnt) ::
);
return cnt;
}
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
/*
 * Outer-product micro-kernel.
 *
 * Updates one block of C of at most 2*SVL x 2*SVL floats (SVL = number of
 * 32-bit lanes in a streaming vector) using all four FP32 tiles of ZA:
 *
 *     C_block = alpha * A_block * B_block + beta * C_block
 *
 * Tile layout: tiles 0 and 1 hold the first SVL rows of the block (left and
 * right column halves), tiles 2 and 3 hold the remaining rows.
 *
 *   A           pre-processed A panel; for step k the column feeding the
 *               upper row half starts at A[k * svl] and the one feeding the
 *               lower row half at A[(k + shared_dim) * svl]
 *   B           first element of the B block; consecutive rows are ldb apart
 *   C           first element of the C block; consecutive rows are ldc apart
 *   shared_dim  K, the number of outer products accumulated
 *   ldc         row stride of C (B is read with the same stride)
 *   block_rows  number of valid rows in this block    (<= 2*SVL)
 *   block_cols  number of valid columns in this block (<= 2*SVL)
 *   alpha, beta scaling factors
 *
 * When beta is zero, C is never read: BLAS allows C to be uninitialised in
 * that case, and multiplying a stale NaN/Inf by zero would otherwise poison
 * the result.
 *
 * The function is `static` because it is private to this translation unit;
 * a plain `inline` definition provides no external definition in C99 and
 * can clash with same-named kernels in other files under gnu89 semantics.
 */
static __attribute__((always_inline)) inline void
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
           size_t ldc, size_t block_rows, size_t block_cols, float alpha, float beta)
           __arm_out("za") __arm_streaming {
  const uint64_t svl = svcntw();
  /* B is addressed with the same row stride as C. */
  const size_t ldb = ldc;
  /* Rows that land in the upper tile pair (tiles 0 and 1). */
  const size_t upper_rows = MIN(svl, block_rows);

  /* Predicate set-up: *_0 covers the first SVL rows/columns, *_1 the rest. */
  svbool_t pg = svptrue_b32();
  svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
  svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
  svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
  svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
  /* C shares its column predicates with B (plain locals, no leaking macros). */
  svbool_t pg_c_0 = pg_b_0;
  svbool_t pg_c_1 = pg_b_1;

  /* Start from a zero accumulator; this is already the beta == 0 result. */
  svzero_za();

  /* Seed ZA with beta * C, unless beta is zero (C must not be read then). */
  if (beta != 0.0f) {
    svfloat32_t beta_vec = svdup_f32(beta);

    for (size_t i = 0; i < upper_rows; i++) {
      svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
      row_c_0 = svmul_x(pg, beta_vec, row_c_0);
      svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);

      svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
      row_c_1 = svmul_x(pg, beta_vec, row_c_1);
      svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
    }
    /* Row i of the block is slice (i - svl) of the lower tile pair. */
    for (size_t i = svl; i < block_rows; i++) {
      svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
      row_c_0 = svmul_x(pg, beta_vec, row_c_0);
      svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i - svl, pg_c_0, row_c_0);

      svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
      row_c_1 = svmul_x(pg, beta_vec, row_c_1);
      svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i - svl, pg_c_1, row_c_1);
    }
  }

  svfloat32_t alpha_vec = svdup_f32(alpha);

  /* Iterate through the shared dimension (K), one outer product per step. */
  for (size_t k = 0; k < shared_dim; k++) {
    /* Load one column of A per row half and fold alpha into it. */
    svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
    col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
    svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
    col_a_1 = svmul_x(pg, alpha_vec, col_a_1);

    /* Load one row of B per column half. */
    svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
    svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);

    /*
     * Accumulate the outer products. Lanes outside the block were loaded
     * as zero by the predicated loads, so an all-true predicate is safe.
     */
    svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
    svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
    svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
    svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
  }

  /* Store the accumulated block from ZA back to C. */
  for (size_t i = 0; i < upper_rows; i++) {
    svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
    svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
  }
  for (size_t i = svl; i < block_rows; i++) {
    svst1_hor_za32(/*tile*/2, /*slice*/i - svl, pg_c_0, &C[i * ldc]);
    svst1_hor_za32(/*tile*/3, /*slice*/i - svl, pg_c_1, &C[i * ldc + svl]);
  }
}
/*
 * Drive kernel_2x2 over the whole output matrix.
 *
 * C (m x n, row stride n) is walked in blocks of up to 2*SVL x 2*SVL
 * elements, rows outermost. For each block the kernel receives the matching
 * slice of the pre-processed A panel (`ba`, offset by row * k), the matching
 * columns of B (`bb`, offset by column) and the block's origin in C.
 * alpha and beta are passed by pointer and dereferenced per kernel call.
 *
 * Runs in locally-streaming mode with a fresh ZA state.
 */
__arm_new("za") __arm_locally_streaming
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,
                                          const float *ba, const float *restrict bb, const float* beta,
                                          float *restrict C) {
  /* One kernel invocation covers two vectors' worth of rows and columns. */
  const uint64_t full_span = 2 * svcntw();
  /* C is stored densely: its row stride equals the column count. */
  const uint64_t c_stride = n;

  uint64_t row_start = 0;
  while (row_start < m) {
    /* The last row block may be partial. */
    const uint64_t rows_here = MIN(full_span, m - row_start);

    uint64_t col_start = 0;
    while (col_start < n) {
      /* The last column block may be partial. */
      const uint64_t cols_here = MIN(full_span, n - col_start);

      kernel_2x2(ba + row_start * k,
                 bb + col_start,
                 C + row_start * c_stride + col_start,
                 k, c_stride, rows_here, cols_here, *alpha, *beta);

      col_start += cols_here;
    }
    row_start += rows_here;
  }
}
#else
/*
 * Fallback used when the compiler does not satisfy the SME /
 * locally-streaming feature test guarding the real implementation above:
 * an empty stub with the same signature so the file still compiles.
 * NOTE(review): this performs no computation, so C is left untouched if
 * this path is ever reached at run time -- confirm it cannot be.
 */
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
const float *ba, const float *restrict bb, const float* beta,\
float *restrict C){}
#endif
/*
 * Tell the compiler that every SVE predicate and vector register is
 * clobbered at this point. Placed on both sides of the call into the
 * hand-written preprocessing routine so that no live value is kept in a
 * p/z register across it; values are re-read from memory instead.
 */
#define SGEMM_DIRECT_AB_CLOBBER_SVE_REGS()                                   \
  asm volatile("" : : :                                                      \
               "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",               \
               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",         \
               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",               \
               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",         \
               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",       \
               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31")

/*
 * SME1 direct SGEMM with scaling: R = alpha * A * B + beta * R.
 *
 *   M, N, K   GEMM dimensions (A is M x K, B is K x N, R is M x N)
 *   alpha     scale applied to the product A * B
 *   A, B      input matrices
 *   beta      scale applied to the existing contents of R
 *   R         output matrix, updated in place
 *   strideA, strideB, strideR
 *             accepted for signature compatibility but not used here: B and
 *             R are addressed with row stride N and the A panel with K.
 *             NOTE(review): callers must therefore pass densely packed
 *             matrices -- confirm the dispatch in interface/gemm.c
 *             guarantees this.
 *
 * A is first copied into a temporary buffer whose row count is M rounded
 * up to a whole number of vector lengths, in the layout the outer-product
 * kernel expects. If that buffer cannot be allocated the function prints a
 * diagnostic and aborts, instead of handing a NULL pointer to the
 * preprocessing routine.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,
            BLASLONG strideA, float * __restrict B, BLASLONG strideB,
            float beta, float * __restrict R, BLASLONG strideR) {
  /*
   * Empty output or invalid depth: nothing to compute, and the unsigned
   * size arithmetic below requires non-negative dimensions.
   */
  if (M <= 0 || N <= 0 || K < 0) {
    return;
  }

  const uint64_t vl_elms = sve_cntw();
  /* Round M up to a multiple of the vector length using integer maths. */
  const uint64_t m_mod = ((uint64_t)M + vl_elms - 1) / vl_elms * vl_elms;
  const uint64_t depth = (uint64_t)K;

  /* Reject sizes whose byte count would not fit in size_t. */
  const int size_ok = (depth == 0) || (m_mod <= SIZE_MAX / sizeof(float) / depth);
  size_t a_mod_bytes = 0;
  float *A_mod = NULL;
  if (size_ok) {
    a_mod_bytes = (size_t)(m_mod * depth) * sizeof(float);
    A_mod = malloc(a_mod_bytes);
  }
  /* malloc(0) may legitimately return NULL, so only a non-empty request fails. */
  if (!size_ok || (A_mod == NULL && a_mod_bytes != 0)) {
    fprintf(stderr,
            "OpenBLAS: sgemm_direct_alpha_beta: cannot allocate the packed A buffer\n");
    abort();
  }

  SGEMM_DIRECT_AB_CLOBBER_SVE_REGS();

  /* Pre-process the left matrix into the layout used for the
     sum-of-outer-products calculation. */
  sgemm_direct_sme1_preprocess(M, K, A, A_mod);

  SGEMM_DIRECT_AB_CLOBBER_SVE_REGS();

  /* Calculate C = alpha*A*B + beta*C */
  sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R);

  free(A_mod);
}
#else
/*
 * Build without HAVE_SME: provide an empty function with the same signature
 * so the symbol still exists.
 * NOTE(review): it performs no computation and leaves R untouched -- confirm
 * the dispatcher never selects it at run time.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
float beta, float * __restrict R, BLASLONG strideR){}
#endif

+ 46
- 0
kernel/arm64/sme_abi.h View File

@@ -0,0 +1,46 @@
/***************************************************************************
* Copyright (c) 2024, The OpenBLAS Project
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* 3. Neither the name of the OpenBLAS project nor the names of
* its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
* THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *****************************************************************************/

#pragma once

#include <stdlib.h>

/**
 * SME ABI routines for saving and restoring SME state.
 *
 * They are typically provided by a compiler runtime library such as libgcc
 * or compiler-rt, but support for these routines is not yet available on
 * all platforms.
 *
 * Define them as aborting stubs so that nested usage of SME state fails
 * loudly instead of silently corrupting it.
 *
 * They are weak symbols so that a compiler runtime can override them where
 * supported; weak linkage also keeps the definitions safe when this header
 * is included from more than one translation unit.
 */
__attribute__((weak)) void __arm_tpidr2_save(void) { abort(); }
__attribute__((weak)) void __arm_tpidr2_restore(void) { abort(); }


+ 1
- 0
kernel/setparam-ref.c View File

@@ -198,6 +198,7 @@ gotoblas_t TABLE_NAME = {
#endif
#ifdef ARCH_ARM64
sgemm_directTS,
sgemm_direct_alpha_betaTS,
#endif

sgemm_kernelTS, sgemm_betaTS,


Loading…
Cancel
Save