Browse Source

Support for SME1 based sgemm_direct kernel for cblas_sgemm level 3 API

* Added ARMV9SME target
* Added SGEMM_DIRECT kernel based on SME1
tags/v0.3.30
Vaisakh K V 10 months ago
parent
commit
d23eb3b93e
26 changed files with 694 additions and 36 deletions
  1. +2
    -1
      CMakeLists.txt
  2. +5
    -0
      Makefile.arm64
  3. +8
    -0
      Makefile.system
  4. +1
    -0
      TargetList.txt
  5. +19
    -0
      c_check
  6. +15
    -3
      cmake/arch.cmake
  7. +6
    -0
      cmake/cc.cmake
  8. +1
    -1
      cmake/prebuild.cmake
  9. +27
    -11
      cmake/system.cmake
  10. +11
    -0
      cmake/system_check.cmake
  11. +1
    -0
      common.h
  12. +1
    -1
      common_arm64.h
  13. +6
    -0
      common_param.h
  14. +2
    -2
      common_s.h
  15. +34
    -0
      driver/others/dynamic_arm64.c
  16. +13
    -0
      getarch.c
  17. +57
    -13
      interface/gemm.c
  18. +14
    -2
      kernel/CMakeLists.txt
  19. +4
    -0
      kernel/Makefile
  20. +32
    -1
      kernel/Makefile.L3
  21. +3
    -0
      kernel/arm64/KERNEL.ARMV9SME
  22. +59
    -0
      kernel/arm64/sgemm_direct_arm64_sme1.c
  23. +228
    -0
      kernel/arm64/sgemm_direct_sme1.S
  24. +133
    -0
      kernel/arm64/sgemm_direct_sme1_preprocess.S
  25. +5
    -0
      kernel/setparam-ref.c
  26. +7
    -1
      param.h

+ 2
- 1
CMakeLists.txt View File

@@ -4,11 +4,12 @@

cmake_minimum_required(VERSION 3.16.0)

set (CMAKE_ASM_SOURCE_FILE_EXTENSIONS "S")
project(OpenBLAS C ASM)

set(OpenBLAS_MAJOR_VERSION 0)
set(OpenBLAS_MINOR_VERSION 3)
set(OpenBLAS_PATCH_VERSION 28.dev)
set(OpenBLAS_PATCH_VERSION 29.dev)

set(OpenBLAS_VERSION "${OpenBLAS_MAJOR_VERSION}.${OpenBLAS_MINOR_VERSION}.${OpenBLAS_PATCH_VERSION}")



+ 5
- 0
Makefile.arm64 View File

@@ -30,6 +30,11 @@ FCOMMON_OPT += -march=armv8-a+sve
endif
endif

ifeq ($(CORE), ARMV9SME)
CCOMMON_OPT += -march=armv9-a+sve2+sme
FCOMMON_OPT += -march=armv9-a+sve2
endif

ifeq ($(CORE), CORTEXA53)
CCOMMON_OPT += -march=armv8-a -mtune=cortex-a53
ifneq ($(F_COMPILER), NAG)


+ 8
- 0
Makefile.system View File

@@ -420,6 +420,7 @@ ifeq ($(ARCH), arm64)
export MACOSX_DEPLOYMENT_TARGET=11.0
ifeq ($(C_COMPILER), GCC)
export NO_SVE = 1
export NO_SME = 1
endif
else
export MACOSX_DEPLOYMENT_TARGET=10.8
@@ -709,6 +710,9 @@ DYNAMIC_CORE += NEOVERSEN2
DYNAMIC_CORE += ARMV8SVE
DYNAMIC_CORE += A64FX
endif
ifneq ($(NO_SME), 1)
DYNAMIC_CORE += ARMV9SME
endif
DYNAMIC_CORE += THUNDERX
DYNAMIC_CORE += THUNDERX2T99
DYNAMIC_CORE += TSV110
@@ -1474,6 +1478,10 @@ ifeq ($(NO_SVE), 1)
CCOMMON_OPT += -DNO_SVE
endif

ifeq ($(NO_SME), 1)
CCOMMON_OPT += -DNO_SME
endif

ifdef SMP
CCOMMON_OPT += -DSMP_SERVER



+ 1
- 0
TargetList.txt View File

@@ -111,6 +111,7 @@ THUNDERX3T110
VORTEX
A64FX
ARMV8SVE
ARMV9SME
FT2000

9.System Z:


+ 19
- 0
c_check View File

@@ -331,6 +331,24 @@ if [ "$architecture" = "arm64" ]; then
rm -rf "$tmpd"
fi

# Probe whether the compiler can assemble SME instructions (arm64 only).
# A tiny assembly file containing smstart/smstop is assembled first with
# -march=armv9-a+sve2+sme and, if that fails, with -march=armv9-a+sme.
# no_sme is set to 1 only when both attempts fail.
no_sme=0
if [ "$architecture" = "arm64" ]; then
    tmpd=$(mktemp -d 2>/dev/null || mktemp -d -t 'OBC')
    tmpf="$tmpd/a.S"
    printf ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n" >> "$tmpf"
    args=" -march=armv9-a+sve2+sme -c -o $tmpf.o $tmpf"
    if ! $compiler_name $flags $args >/dev/null 2>&1; then
        # First flag set rejected: retry without the explicit +sve2.
        args=" -march=armv9-a+sme -c -o $tmpf.o $tmpf"
        if ! $compiler_name $flags $args >/dev/null 2>&1; then
            no_sme=1
        fi
    fi
    rm -rf "$tmpd"
fi

c11_atomics=0
case "$data" in
*HAVE_C11*)
@@ -472,6 +490,7 @@ done
printf "CEXTRALIB=%s %s %s\n" "$linker_L" "$linker_l" "$linker_a"
[ "$no_msa" -eq 1 ] && printf "NO_MSA=1\n"
[ "$no_sve" -eq 1 ] && printf "NO_SVE=1\n"
[ "$no_sme" -eq 1 ] && printf "NO_SME=1\n"
[ "$no_rv64gv" -eq 1 ] && printf "NO_RV64GV=1\n"
[ "$no_avx512" -eq 1 ] && printf "NO_AVX512=1\n"
[ "$no_avx512bf" -eq 1 ] && printf "NO_AVX512BF16=1\n"


+ 15
- 3
cmake/arch.cmake View File

@@ -44,9 +44,21 @@ endif ()

if (DYNAMIC_ARCH)
if (ARM64)
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER 9.99)
set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
# Extend the ARM64 DYNAMIC_ARCH core list with the SVE and SME targets, but
# only when the C compiler is new enough to provide the matching ACLE support.
# The expansions are quoted so that an empty compiler id/version yields a
# false condition instead of a malformed if() expression.
# NOTE(review): MATCHES "Clang" also matches AppleClang, whose version numbers
# presumably do not track upstream LLVM releases - confirm the thresholds.
if ("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU")
  if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER_EQUAL 10) # SVE ACLE supported in GCC >= 10
    list(APPEND DYNAMIC_CORE NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
  endif ()
  if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER_EQUAL 14) # SME ACLE supported in GCC >= 14
    list(APPEND DYNAMIC_CORE ARMV9SME)
  endif ()
elseif ("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
  if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER_EQUAL 11) # SVE ACLE supported in LLVM >= 11
    list(APPEND DYNAMIC_CORE NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
  endif ()
  if ("${CMAKE_C_COMPILER_VERSION}" VERSION_GREATER_EQUAL 19) # SME ACLE supported in LLVM >= 19
    list(APPEND DYNAMIC_CORE ARMV9SME)
  endif ()
endif ()
if (DYNAMIC_LIST)
set(DYNAMIC_CORE ARMV8 ${DYNAMIC_LIST})


+ 6
- 0
cmake/cc.cmake View File

@@ -238,6 +238,12 @@ if (${CORE} STREQUAL ARMV8SVE)
endif ()
endif ()

if (${CORE} STREQUAL ARMV9SME)
if (NOT DYNAMIC_ARCH)
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme")
endif ()
endif ()

if (${CORE} STREQUAL CORTEXA510)
if (NOT DYNAMIC_ARCH)
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve")


+ 1
- 1
cmake/prebuild.cmake View File

@@ -1014,7 +1014,7 @@ endif ()
set(ZGEMM_UNROLL_M 4)
set(ZGEMM_UNROLL_N 4)
set(SYMV_P 16)
elseif ("${TCORE}" STREQUAL "NEOVERSEN2")
elseif ("${TCORE}" STREQUAL "NEOVERSEN2" OR "${TCORE}" STREQUAL "ARMV9SME") # ARMV9SME shares the NEOVERSEN2 parameters; the operator must be upper-case OR
file(APPEND ${TARGET_CONF_TEMP}
"#define L1_CODE_SIZE\t65536\n"
"#define L1_CODE_LINESIZE\t64\n"


+ 27
- 11
cmake/system.cmake View File

@@ -21,7 +21,15 @@ endif()
# Other files expect CORE, which is actually TARGET and will become TARGET_CORE for kernel build. Confused yet?
# It seems we are meant to use TARGET as input and CORE internally as kernel.
if(NOT DEFINED CORE AND DEFINED TARGET)
set(CORE ${TARGET})
# Translate the LOONGSON* target names to the LA* core names; every other
# target name is used as the core name unchanged.
if ("${TARGET}" STREQUAL "LOONGSON3R5")
  set(CORE "LA464")
elseif ("${TARGET}" STREQUAL "LOONGSON2K1000")
  set(CORE "LA264")
elseif ("${TARGET}" STREQUAL "LOONGSONGENERIC")
  # No trailing ')' inside the string: the core name is LA64_GENERIC.
  set(CORE "LA64_GENERIC")
else ()
  set(CORE ${TARGET})
endif()
endif()

# TARGET_CORE will override TARGET which is used in DYNAMIC_ARCH=1.
@@ -310,6 +318,9 @@ if (${TARGET} STREQUAL NEOVERSEV1)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv8.2-a+sve")
endif()
endif()
if (${TARGET} STREQUAL ARMV9SME)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv9-a+sme -O3")
endif()
if (${TARGET} STREQUAL A64FX)
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -Msve-intrinsics -march=armv8.2-a+sve -mtune=a64fx")
@@ -382,6 +393,8 @@ if (NEED_PIC)
if (NOT NOFORTRAN)
if (${F_COMPILER} STREQUAL "SUN")
set(FCOMMON_OPT "${FCOMMON_OPT} -pic")
elseif (${F_COMPILER} STREQUAL "NAGFOR")
set(FCOMMON_OPT "${FCOMMON_OPT} -PIC")
else ()
set(FCOMMON_OPT "${FCOMMON_OPT} -fPIC")
endif ()
@@ -640,17 +653,17 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
endif ()

if (CMAKE_Fortran_COMPILER)
if ("${F_COMPILER}" STREQUAL "NAG" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx;-mavx2,-mskylake-avx512")
if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
message(STATUS "removing fortran flags")
set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64")
if ("${F_COMPILER}" STREQUAL "NAGFOR" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
# List items are ';'-separated; a ',' would fuse the last two flags into one item that never matches.
set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx;-mavx2;-mskylake-avx512")
if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
message(STATUS "removing fortran flags")
set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64")
endif ()
foreach (FILTER_FLAG ${FILTER_FLAGS})
string(REPLACE ${FILTER_FLAG} "" LAPACK_FFLAGS ${LAPACK_FFLAGS})
string(REPLACE ${FILTER_FLAG} "" LAPACK_FPFLAGS ${LAPACK_FPFLAGS})
endforeach ()
endif ()
foreach (FILTER_FLAG ${FILTER_FLAGS})
string(REPLACE ${FILTER_FLAG} "" LAPACK_FFLAGS ${LAPACK_FFLAGS})
string(REPLACE ${FILTER_FLAG} "" LAPACK_FPFLAGS ${LAPACK_FPFLAGS})
endforeach ()
endif ()
endif ()

if ("${F_COMPILER}" STREQUAL "GFORTRAN")
@@ -670,6 +683,9 @@ endif ()
if (${CMAKE_C_COMPILER} STREQUAL "LSB" OR ${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(LAPACK_CFLAGS "${LAPACK_CFLAGS} -DLAPACK_COMPLEX_STRUCTURE")
endif ()
if (${CMAKE_C_COMPILER_ID} MATCHES "IntelLLVM" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(LAPACK_CFLAGS "${LAPACK_CFLAGS} -DNOCHANGE")
endif ()

if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
if ("${F_COMPILER}" STREQUAL "FLANG")


+ 11
- 0
cmake/system_check.cmake View File

@@ -135,6 +135,17 @@ endif()
endif()
endif()

# Probe whether the C compiler driver can assemble SME instructions.
# On ARM64, unless the user already set NO_SME, a tiny assembly file holding
# smstart/smstop is assembled. When no -march variant works, NO_SME is set
# to 1 and -DNO_SME is appended to CCOMMON_OPT; otherwise NO_SME is set to 0.
if (ARM64)
  if (NOT NO_SME)
    # The probe is assembly, so it gets a .S suffix: with a .c suffix the
    # driver would parse it as C and the probe could never succeed.
    set(sme_probe_src "${PROJECT_BINARY_DIR}/sme.S")
    set(sme_probe_obj "${PROJECT_BINARY_DIR}/sme.o")
    file(WRITE "${sme_probe_src}" ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n")
    set(NO_SME 1)
    # Same two flag sets that the c_check script tries, in the same order.
    foreach (sme_probe_march "armv9-a+sve2+sme" "armv9-a+sme")
      execute_process(
        COMMAND "${CMAKE_C_COMPILER}" -march=${sme_probe_march} -c -v -o "${sme_probe_obj}" "${sme_probe_src}"
        OUTPUT_QUIET ERROR_QUIET
        RESULT_VARIABLE sme_probe_result)
      # Any result other than exit code 0 (including "could not run") is a failure.
      if ("${sme_probe_result}" STREQUAL "0")
        set(NO_SME 0)
        break()
      endif ()
    endforeach ()
    if (NO_SME)
      set (CCOMMON_OPT "${CCOMMON_OPT} -DNO_SME")
    endif ()
    file(REMOVE "${sme_probe_src}" "${sme_probe_obj}")
  endif ()
endif ()

include(CheckIncludeFile)
CHECK_INCLUDE_FILE("stdatomic.h" HAVE_C11)
if (HAVE_C11 EQUAL 1)


+ 1
- 0
common.h View File

@@ -696,6 +696,7 @@ void gotoblas_profile_init(void);
void gotoblas_profile_quit(void);
int support_avx512(void);
int support_sme1(void);

#ifdef USE_OPENMP



+ 1
- 1
common_arm64.h View File

@@ -175,7 +175,7 @@ REALNAME:
#define HUGE_PAGESIZE ( 4 << 20)

#ifndef BUFFERSIZE
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE)
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE) || defined(ARMV9SME)
#define BUFFER_SIZE (32 << 22)
#else
#define BUFFER_SIZE (32 << 20)


+ 6
- 0
common_param.h View File

@@ -221,6 +221,12 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
#endif
#ifdef ARCH_ARM64
#ifdef HAVE_SME
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
#endif
#endif

int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);


+ 2
- 2
common_s.h View File

@@ -213,9 +213,9 @@
#ifdef ARCH_X86_64
#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#else
#elif ARCH_ARM64
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT sgemm_direct
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy


+ 34
- 0
driver/others/dynamic_arm64.c View File

@@ -115,6 +115,11 @@ extern gotoblas_t gotoblas_ARMV8SVE;
#else
#define gotoblas_ARMV8SVE gotoblas_ARMV8
#endif
#ifdef DYN_ARMV9SME
extern gotoblas_t gotoblas_ARMV9SME;
#else
#define gotoblas_ARMV9SME gotoblas_ARMV8
#endif
#ifdef DYN_CORTEX_A55
extern gotoblas_t gotoblas_CORTEXA55;
#else
@@ -148,6 +153,13 @@ extern gotoblas_t gotoblas_A64FX;
#define gotoblas_ARMV8SVE gotoblas_ARMV8
#define gotoblas_A64FX gotoblas_ARMV8
#endif

#ifndef NO_SME
extern gotoblas_t gotoblas_ARMV9SME;
#else
#define gotoblas_ARMV9SME gotoblas_ARMV8SVE
#endif

extern gotoblas_t gotoblas_THUNDERX3T110;
#endif
#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1
@@ -168,6 +180,9 @@ extern void openblas_warning(int verbose, const char * msg);
#ifndef HWCAP_SVE
#define HWCAP_SVE (1 << 22)
#endif
/* Fallback for headers that lack the SME hwcap bit. Parenthesised like
   HWCAP_SVE so the macro expands safely inside any expression. */
#ifndef HWCAP2_SME
#define HWCAP2_SME (1 << 23)
#endif

#define get_cpu_ftr(id, var) ({ \
__asm__ __volatile__ ("mrs %0, "#id : "=r" (var)); \
@@ -393,6 +408,13 @@ static gotoblas_t *get_coretype(void) {
snprintf(coremsg, 128, "Unknown CPU model - implementer %x part %x\n",implementer,part);
openblas_warning(1, coremsg);
}

#if !defined(NO_SME) && defined(HWCAP2_SME)
if ((getauxval(AT_HWCAP2) & HWCAP2_SME)) {
return &gotoblas_ARMV9SME;
}
#endif

#ifndef NO_SVE
if ((getauxval(AT_HWCAP) & HWCAP_SVE)) {
return &gotoblas_ARMV8SVE;
@@ -443,3 +465,15 @@ void gotoblas_dynamic_init(void) {
void gotoblas_dynamic_quit(void) {
gotoblas = NULL;
}

/*
 * Report whether the SME1 code paths may be used at run time.
 *
 * Returns 1 when the build was not configured with NO_SME and, on Linux or
 * Android, the kernel advertises HWCAP2_SME in AT_HWCAP2; returns 0 in every
 * other case. NO_SME is honoured here in the same way as in get_coretype(),
 * so a library built without the SME kernels never reports SME support.
 */
int support_sme1(void) {
  int ret = 0;

#if !defined(NO_SME) && (defined OS_LINUX || defined OS_ANDROID)
  if (getauxval(AT_HWCAP2) & HWCAP2_SME) {
    ret = 1;
  }
#endif
  return ret;
}

+ 13
- 0
getarch.c View File

@@ -1289,6 +1289,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define CORENAME "ARMV8SVE"
#endif

#ifdef FORCE_ARMV9SME
#define FORCE
#define ARCHITECTURE "ARM64"
#define SUBARCHITECTURE "ARMV9SME"
#define SUBDIRNAME "arm64"
#define ARCHCONFIG "-DARMV9SME " \
"-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
"-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
"-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 -DL2_ASSOCIATIVE=32 " \
"-DHAVE_VFPV4 -DHAVE_VFPV3 -DHAVE_VFP -DHAVE_NEON -DHAVE_SVE -DHAVE_SME -DARMV8 -DARMV9"
#define LIBNAME "armv9sme"
#define CORENAME "ARMV9SME"
#endif

#ifdef FORCE_ARMV8
#define FORCE


+ 57
- 13
interface/gemm.c View File

@@ -1,5 +1,5 @@
/*********************************************************************/
/* Copyright 2024 The OpenBLAS Project */
/* Copyright 2024, 2025 The OpenBLAS Project */
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
@@ -86,7 +86,7 @@
#endif

static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
#ifndef GEMM3M
#if !defined(GEMM3M) || defined(GENERIC)
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
@@ -177,6 +177,49 @@ static int init_amxtile_permission() {
}
#endif

#ifdef DYNAMIC_ARCH
extern char* gotoblas_corename(void);
#endif

#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV1)
/*
 * Thread-count heuristic for Neoverse V1.
 *
 * MNK  - problem size as the product m*n*k
 * ncpu - number of CPUs that may be used
 *
 * Problems below 262144 run single-threaded; larger ones are capped by the
 * first table entry whose limit exceeds MNK, and anything beyond the last
 * limit may use all ncpu threads.
 */
static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) {
  static const struct {
    double below;   /* entry applies while MNK < below */
    int    threads; /* upper bound on the thread count */
  } caps[] = {
    {   1124864.0,  6 },
    {   7880599.0, 12 },
    {  17173512.0, 16 },
    {  33386248.0, 20 },
    {  57066625.0, 24 },
    {  91733851.0, 32 },
    { 265847707.0, 40 },
    { 458314011.0, 48 },
    { 729000000.0, 56 },
  };
  unsigned int idx;

  if (MNK < 262144.0) {
    return 1;
  }
  for (idx = 0; idx < sizeof(caps) / sizeof(caps[0]); idx++) {
    if (MNK < caps[idx].below) {
      return MIN(ncpu, caps[idx].threads);
    }
  }
  return ncpu;
}
#endif

/*
 * Choose the number of threads for a GEMM call of size MNK (= m*n*k).
 *
 * Single-precision real builds for NEOVERSEV1, or DYNAMIC_ARCH builds whose
 * run-time core name is "neoversev1", use the Neoverse V1 table. Otherwise:
 * at or below the multithreading threshold one thread is used; if the work
 * per available CPU is below the threshold, the count is scaled down to
 * MNK/threshold; else every available CPU is used.
 */
static inline int get_gemm_optimal_nthreads(double MNK) {
  int ncpu = num_cpu_avail(3);
  const double min_work = SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD;
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
  return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
  if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
    return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
  }
#endif
  if (MNK <= min_work) {
    return 1;
  }
  if (MNK / ncpu < min_work) {
    return MNK / min_work;
  }
  return ncpu;
}

#ifndef CBLAS

void NAME(char *TRANSA, char *TRANSB,
@@ -310,7 +353,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
FLOAT *beta = (FLOAT*) vbeta;
FLOAT *a = (FLOAT*) va;
FLOAT *b = (FLOAT*) vb;
FLOAT *c = (FLOAT*) vc;
FLOAT *c = (FLOAT*) vc;
#endif

blas_arg_t args;
@@ -350,14 +393,21 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
PRINT_DEBUG_CNAME;

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT)
#ifdef DYNAMIC_ARCH
#if defined(DYNAMIC_ARCH) && defined(ARCH_x86)
if (support_avx512() )
#endif
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}

#endif
#if defined(DYNAMIC_ARCH) && defined(ARCH_ARM64)
if (support_sme1()){
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
return;
}
}
#endif
#endif

#ifndef COMPLEX
@@ -604,13 +654,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
#endif

MNK = (double) args.m * (double) args.n * (double) args.k;
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) )
args.nthreads = 1;
else {
args.nthreads = num_cpu_avail(3);
if (MNK/args.nthreads < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD)
args.nthreads = MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD);
}
args.nthreads = get_gemm_optimal_nthreads(MNK);

args.common = NULL;



+ 14
- 2
kernel/CMakeLists.txt View File

@@ -65,6 +65,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${${float_char}COPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}NRM2KERNEL}" "" "nrm2_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}ROTKERNEL}" "" "rot_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}ROTMKERNEL}" "" "rotm_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}SCALKERNEL}" "" "scal_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}SWAPKERNEL}" "" "swap_k" false "" "" false ${float_type})
GenerateNamedObjects("${KERNELDIR}/${${float_char}AXPBYKERNEL}" "" "axpby_k" false "" "" false ${float_type})
@@ -125,6 +126,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${SNRM2KERNEL}" "" "nrm2_k" false "" "" false "SINGLE")
GenerateNamedObjects("${KERNELDIR}/${SDOTKERNEL}" "" "dot_k" false "" "" false "SINGLE")
GenerateNamedObjects("${KERNELDIR}/${SROTKERNEL}" "" "rot_k" false "" "" false "SINGLE")
GenerateNamedObjects("${KERNELDIR}/${SROTMKERNEL}" "" "rotm_k" false "" "" false "SINGLE")
endif ()
if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
GenerateNamedObjects("${KERNELDIR}/${DAMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false "DOUBLE")
@@ -148,6 +150,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE")
@@ -204,19 +207,27 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
set(USE_TRMM true)
endif ()

set(USE_DIRECT_SGEMM false)
if (X86_64)
if (X86_64 OR (ARM64 AND (UC_TARGET_CORE MATCHES ARMV9SME)))
set(USE_DIRECT_SGEMM true)
endif()

if (USE_DIRECT_SGEMM)
# if (NOT DEFINED SGEMMDIRECTKERNEL)
if (X86_64)
set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
# endif()
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
elseif (ARM64)
set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c)
set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S)
set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE)
endif ()
endif()

foreach (float_type SINGLE DOUBLE)
@@ -1105,6 +1116,7 @@ endif ()
GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE")
GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE")


+ 4
- 0
kernel/Makefile View File

@@ -24,7 +24,11 @@ ifdef NO_AVX2
AVX2OPT=
endif


ifdef TARGET_CORE
ifeq ($(TARGET_CORE), ARMV9SME)
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE) -DHAVE_SME -march=armv9-a+sve2+sme
endif
ifeq ($(TARGET_CORE), SAPPHIRERAPIDS)
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE)
ifeq (1, $(filter 1,$(GCCVERSIONGTEQ11) $(CLANGVERSIONGTEQ12)))


+ 32
- 1
kernel/Makefile.L3 View File

@@ -24,6 +24,7 @@ endif

ifeq ($(ARCH), arm64)
USE_TRMM = 1
USE_DIRECT_SGEMM = 1
endif

ifeq ($(ARCH), riscv64)
@@ -95,9 +96,17 @@ endif

ifdef USE_DIRECT_SGEMM
ifndef SGEMMDIRECTKERNEL
ifeq ($(ARCH), x86_64)
SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c
SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c
endif
ifeq ($(ARCH), arm64)
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c
endif
endif
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)
@@ -128,9 +137,19 @@ SKERNELOBJS += \
$(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ)

ifdef USE_DIRECT_SGEMM
ifeq ($(ARCH), x86_64)
SKERNELOBJS += \
sgemm_direct$(TSUFFIX).$(SUFFIX) \
sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
endif
ifeq ($(ARCH), arm64)
ifdef HAVE_SME
SKERNELOBJS += \
sgemm_direct$(TSUFFIX).$(SUFFIX) \
sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \
sgemm_direct_sme1_preprocess$(TSUFFIX).$(SUFFIX)
endif
endif
endif
endif

@@ -809,11 +828,23 @@ else
endif

ifdef USE_DIRECT_SGEMM
ifeq ($(ARCH), x86_64)
$(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
endif
# Build rules for the SME1 direct SGEMM kernel objects (arm64 with HAVE_SME).
# Each object lists its source file as a prerequisite so that it is rebuilt
# whenever that source changes.
ifeq ($(ARCH), arm64)
ifdef HAVE_SME
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1.S
	$(CC) $(CFLAGS) -c $< -UDOUBLE -UCOMPLEX -o $@
$(KDIR)sgemm_direct_sme1_preprocess$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1_preprocess.S
	$(CC) $(CFLAGS) -c $< -UDOUBLE -UCOMPLEX -o $@
endif
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)



+ 3
- 0
kernel/arm64/KERNEL.ARMV9SME View File

@@ -0,0 +1,3 @@
include $(KERNELDIR)/KERNEL.ARMV8SVE



+ 59
- 0
kernel/arm64/sgemm_direct_arm64_sme1.c View File

@@ -0,0 +1,59 @@
/*
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/
#include "common.h"
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#if defined(HAVE_SME)
/* Function prototypes */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
extern void sgemm_direct_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,\
const float * matLeft,\
const float * restrict matRight,\
const float * restrict matResult) __asm__("sgemm_direct_sme1_2VLx2VL");
/* Function Definitions */
/*
 * Returns the number of 32-bit elements in one streaming vector:
 * RDSVL #1 reads the streaming vector length in bytes and the shift
 * right by 2 divides it by sizeof(float).
 */
uint64_t sve_cntw() {
uint64_t cnt;
asm volatile(
"rdsvl %[res], #1\n"
"lsr %[res], %[res], #2\n"
: [res] "=r" (cnt) ::
);
return cnt;
}
/*
 * SME1 direct SGEMM kernel: R = A * B for row-major FP32 matrices
 * (A is M x K, B is K x N, R is M x N).
 *
 * A is first re-arranged into a temporary buffer whose row count is M
 * rounded up to a multiple of the streaming vector length in 32-bit
 * elements; the assembly kernel then accumulates outer products from it.
 *
 * NOTE(review): strideA and strideB are never used and strideR is not passed
 * to the assembly routines, which derive every offset from M, N and K. This
 * presumably requires lda == K, ldb == N and ldc == N - confirm with callers.
 *
 * Empty problems are handled without touching the kernels: nothing is
 * written when M or N is not positive, and R is zero-filled when K is not
 * positive (an empty sum). If the temporary buffer cannot be allocated, a
 * message is printed and the process is aborted instead of handing a NULL
 * pointer to the assembly code.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
            BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
            float * __restrict R, BLASLONG strideR){
    uint64_t m_mod, vl_elms;
    float *A_mod;
    BLASLONG row, col;

    if (M <= 0 || N <= 0) {
        return;
    }
    if (K <= 0) {
        for (row = 0; row < M; row++) {
            for (col = 0; col < N; col++) {
                R[row * strideR + col] = 0.0f;
            }
        }
        return;
    }

    vl_elms = sve_cntw();
    /* Round M up to a multiple of vl_elms with integer arithmetic. */
    m_mod = (((uint64_t)M + vl_elms - 1) / vl_elms) * vl_elms;

    A_mod = (float *) malloc(m_mod * (uint64_t)K * sizeof(float));
    if (A_mod == NULL) {
        fprintf(stderr, "OpenBLAS: sgemm_direct (SME1) failed to allocate the packing buffer\n");
        abort();
    }

    /* Pre-process the left matrix to make it suitable for
       matrix sum of outer-product calculation
    */
    sgemm_direct_sme1_preprocess(M, K, A, A_mod);

    /* Calculate C = A*B */
    sgemm_direct_sme1_2VLx2VL(M, K, N, A_mod, B, R);

    free(A_mod);
}
#endif

+ 228
- 0
kernel/arm64/sgemm_direct_sme1.S View File

@@ -0,0 +1,228 @@
/*
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/

/*--------------------------------------------------------------------------
* SME1 based Matrix multiplication code for FP32 input matrices to FP32
* output matrix
* C = A*B
* A: Left input matrix of dimension M x K
* B: Right input matrix of dimension K x N
* C: Result matrix of dimension M x N
*
* Usage of function:
* sgemm_direct_sme1_2VLx2VL( uint64_t M , uint64_t K, uint64_t N,\
const float * restrict A_base,\
const float * restrict B_base,\
const float * restrict C_base);
----------------------------------------------------------------------------*/

// Register aliases for sgemm_direct_sme1_2VLx2VL.
// x0-x5 hold the incoming arguments. x19-x23 are callee-saved and are
// stored/restored by the function's prologue/epilogue. x18 is deliberately
// not used: AAPCS64 designates it the platform register and some platforms
// reserve it, so Constant2 lives in x23 instead.
#define M x0 //M dimension
#define K x1 //K dimension
#define N x2 //N dimension
#define A_base x3 //Pointer to left matrix(A)
#define B_base x4 //Pointer to right matrix(B)
#define C_base x5 //Pointer to result matrix(C)
#define Aptr x6 //Pointer to traverse A
#define Aptr_end x7 //Pointer to end of row of A
#define Cptr x8 //Pointer to traverse C
#define Cptr0 x9 //2nd Pointer to traverse C
#define Cptr1 x10 //3rd Pointer to traverse C
#define Bptr x11 //Pointer to traverse B
#define Bptr0 x12 //2nd Pointer to traverse B
#define N_exit x14 //Exit condition for N loop
#define K_exit x15 //Exit condition for K loop
#define M_cntr x16 //M loop counter
#define C1 x17 //Constant1: N*(SVLs+1);SVLs-No. of 32-bit elements
#define C2 x23 //Constant2: N + SVLs
#define C3 x19 //Constant3: K*SVLs + SVLs
#define C4 x20 //Constant4: SVLs-2
#define C5 x21 //Constant5: K*SVLs
#define C6 x22 //Constant6: N*SVLs

.text
.global sgemm_direct_sme1_2VLx2VL

// sgemm_direct_sme1_2VLx2VL: C = A*B on a 2x2 arrangement of ZA tiles
// (za0..za3), using the register aliases defined above this function.
// The M loop advances two vector-lengths of rows per iteration, the N loop
// two vector-lengths of columns, and the K loop accumulates outer products.
sgemm_direct_sme1_2VLx2VL:

    // Save the callee-saved registers. x19-x24 are saved as before; d8-d15
    // are added because AAPCS64 requires their contents to be preserved and
    // the SIMD&FP/SVE register state does not survive the streaming-mode
    // changes made by smstart/smstop.
    stp x19, x20, [sp, #-112]!
    stp x21, x22, [sp, #16]
    stp x23, x24, [sp, #32]
    stp d8, d9, [sp, #48]
    stp d10, d11, [sp, #64]
    stp d12, d13, [sp, #80]
    stp d14, d15, [sp, #96]

    smstart

    cntw C4 //SVLs
    mul C5, C4, K //K*SVLs
    mul C6, C4, N //N*SVLs
    add C1, C6, N //N*SVLs + N
    add N_exit, B_base, N, lsl #2 //N_Loop exit condition
    mov M_cntr, #0
    add C2, N, C4 //N + SVLs
    add C3, C5, C4 //K*SVLs + SVLs
    whilelt p2.s, M_cntr, M //Tile 0,1 predicate (M dimension)
    sub w20, w20, #2 //C4 (x20) becomes SVLs-2

.M_Loop:
    incw M_cntr
    whilelt p3.s, M_cntr, M //Tile 2,3 predicate (M dimension)
    mov Bptr, B_base //B_base
    mov Cptr, C_base //C_base
    whilelt p0.b, Bptr, N_exit //Tile 0/2 predicate (N dimension)

.N_Loop:
    mov Aptr, A_base //Aptr = A_base
    mov Bptr0, Bptr //Bptr0 = Bptr
    mov Cptr0, Cptr //Cptr0 = Cptr
    addvl Cptr1, Cptr, #1 //Cptr1 = Cptr + SVLb
    addvl Bptr, Bptr, #1
    whilelt p1.b, Bptr, N_exit //Tile 1,3 predicate (N dimension)
    add Aptr_end, A_base, C5, lsl #2 //A_base + K*SVLs
    addvl K_exit, Aptr_end, #-1 //Exit condition for K loop
    //Load 1st vector from Aptr
    ld1w {z1.s}, p2/z, [Aptr]
    zero {za}
    // Load 1st vector from Bptr
    ld1w {z2.s}, p0/z, [Bptr0]
    // ZA0 += 1st Aptr vector OP 1st Bptr vector
    fmopa za0.s, p2/m, p0/m, z1.s, z2.s
    // Load 2nd vector from Aptr
    ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2]
    // Aptr += SVLb
    addvl Aptr, Aptr, #1

.K_Loop:
    // ZA2 += 2nd Aptr vector OP 1st Bptr vector
    fmopa za2.s, p3/m, p0/m, z5.s, z2.s
    // Load 2nd vector from Bptr
    ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
    // ZA1 += 1st Aptr vector OP 2nd Bptr vector
    fmopa za1.s, p2/m, p1/m, z1.s, z3.s
    // Load next 1st vector from Aptr
    ld1w {z0.s}, p2/z, [Aptr]
    // ZA3 += 2nd Aptr vector OP 2nd Bptr vector
    fmopa za3.s, p3/m, p1/m, z5.s, z3.s
    cmp K, #2
    b.le process_K_less_than_equal_2
    // Load next 1st vector from Bptr
    ld1w {z6.s}, p0/z, [Bptr0, N, lsl #2]
    // ZA0 += 1st Aptr vector OP 1st Bptr vector
    fmopa za0.s, p2/m, p0/m, z0.s, z6.s
    // Load next 2nd vector from Aptr
    ld1w {z4.s}, p3/z, [Aptr, C5, lsl #2]
    // ZA2 += 2nd Aptr vector OP 1st Bptr vector
    fmopa za2.s, p3/m, p0/m, z4.s, z6.s
    // Load next 2nd vector from Bptr
    ld1w {z7.s}, p1/z, [Bptr0, C2, lsl #2]
    // Bptr0 += 2*N FP32 elements
    add Bptr0, Bptr0, N, lsl #3
    // ZA1 += 1st Aptr vector OP 2nd Bptr vector
    fmopa za1.s, p2/m, p1/m, z0.s, z7.s
    // Load next 1st vector from Aptr
    ld1w {z1.s}, p2/z, [Aptr, #1, MUL VL]
    // ZA3 += 2nd Aptr vector OP 2nd Bptr vector
    fmopa za3.s, p3/m, p1/m, z4.s, z7.s
    // Load next 1st vector from Bptr
    ld1w {z2.s}, p0/z, [Bptr0]
    // ZA0 += 1st Aptr vector OP 1st Bptr vector
    fmopa za0.s, p2/m, p0/m, z1.s, z2.s
    // Load next 2nd vector from Aptr
    ld1w {z5.s}, p3/z, [Aptr, C3, lsl #2]
    // Aptr += 2*SVLb [Bytes]
    addvl Aptr, Aptr, #2
    cmp Aptr, K_exit
    b.mi .K_Loop
    // ZA2 += 2nd Aptr vector OP 1st Bptr vector
    fmopa za2.s, p3/m, p0/m, z5.s, z2.s
    // Load next 2nd vector from Bptr
    ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
    // ZA1 += 1st Aptr vector OP 2nd Bptr vector
    fmopa za1.s, p2/m, p1/m, z1.s, z3.s
    // ZA3 += 2nd Aptr vector OP 2nd Bptr vector
    fmopa za3.s, p3/m, p1/m, z5.s, z3.s

process_K_less_than_equal_2:
    // Bptr0 += N FP32 elements
    add Bptr0, Bptr0, N, lsl #2
    cmp Aptr, Aptr_end
    b.pl .Ktail_end

.Ktail_start:
    ld1w {z1.s}, p2/z, [Aptr]
    ld1w {z2.s}, p0/z, [Bptr0]
    ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
    fmopa za0.s, p2/m, p0/m, z1.s, z2.s
    ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2]
    fmopa za2.s, p3/m, p0/m, z5.s, z2.s
    fmopa za1.s, p2/m, p1/m, z1.s, z3.s
    fmopa za3.s, p3/m, p1/m, z5.s, z3.s

.Ktail_end:
    mov w13, #0
    psel p4, p0, p2.s[w13, 0]
    psel p5, p1, p2.s[w13, 0]
    psel p6, p0, p3.s[w13, 0]
    psel p7, p1, p3.s[w13, 0]
    // Store to Cptr0
    st1w {za0h.s[w13, #0]}, p4, [Cptr0]
    // Store to Cptr1
    st1w {za1h.s[w13, #0]}, p5, [Cptr1]
    // Store to Cptr0 + N*SVLs
    st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2]
    // Store to Cptr1 + N*SVLs
    st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2]

.Loop_store_ZA:
    psel p4, p0, p2.s[w13, 1]
    psel p5, p1, p2.s[w13, 1]
    psel p6, p0, p3.s[w13, 1]
    psel p7, p1, p3.s[w13, 1]
    // Store to Cptr0 + N
    st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2]
    // Store to Cptr1 + N
    st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2]
    // Store to Cptr0 + N*(SVLs+1)
    st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2]
    // Store to Cptr1 + N*(SVLs+1)
    st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2]

    add Cptr0, Cptr0, N, lsl #3 //Cptr0 += 2*N FP32 elements
    add Cptr1, Cptr1, N, lsl #3 //Cptr1 += 2*N FP32 elements
    add w13, w13, #2

    psel p4, p0, p2.s[w13, 0]
    psel p5, p1, p2.s[w13, 0]
    psel p6, p0, p3.s[w13, 0]
    psel p7, p1, p3.s[w13, 0]
    st1w {za0h.s[w13, #0]}, p4, [Cptr0]
    st1w {za1h.s[w13, #0]}, p5, [Cptr1]
    st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2]
    st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2]
    cmp w13, w20
    b.mi .Loop_store_ZA
    psel p4, p0, p2.s[w13, 1]
    psel p5, p1, p2.s[w13, 1]
    psel p6, p0, p3.s[w13, 1]
    psel p7, p1, p3.s[w13, 1]
    st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2]
    st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2]
    st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2]
    st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2]
    addvl Cptr, Cptr, #2
    addvl Bptr, Bptr, #1
    whilelt p0.b, Bptr, N_exit //1st Tile predicate (N dimension)
    b.first .N_Loop
    add A_base, A_base, C5, lsl #3 //A_base += 2*K*SVLs FP32 elements
    add C_base, C_base, C6, lsl #3 //C_base += 2*N*SVLs FP32 elements
    incw M_cntr
    whilelt p2.s, M_cntr, M //1st Tile predicate (M dimension)
    b.first .M_Loop

    smstop

    // Restore the callee-saved registers in reverse order of the prologue.
    ldp d14, d15, [sp, #96]
    ldp d12, d13, [sp, #80]
    ldp d10, d11, [sp, #64]
    ldp d8, d9, [sp, #48]
    ldp x23, x24, [sp, #32]
    ldp x21, x22, [sp, #16]
    ldp x19, x20, [sp], #112

    ret


+ 133
- 0
kernel/arm64/sgemm_direct_sme1_preprocess.S View File

@@ -0,0 +1,133 @@
/*
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/

/*----------------------------------------------------------------------------
* This function re-arranges the elements of the input matrix to make it
* suitable for matrix outer-product computation using SME for matrix
* multiplication. It should be used to pre-process the left matrix (A) in
* the matrix multiplication (C = A*B) using sgemm_direct_sme1_2VLx2VL().
*
* The pre-processing transposes a block of SVLs rows of the input matrix and
* stores it contiguously. The same is applied to remaining blocks of SVLs
* rows. The last block of SVLs rows is zero-padded to SVLs rows if needed.
*
* Usage of function:
* sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol, \
* const float * restrict mat, float * mat_mod);
*
----------------------------------------------------------------------------*/


#define nrow x0 //Number of rows of input matrix
#define ncol x1 //Number of columns of input matrix
#define mat x2 //Input matrix base address
#define mat_mod x3 //Output matrix (re-arranged matrix) base address
#define mat_mod_ptr x4 //Pointer to output matrix
#define mat_ptr0 x5 //Pointer to input matrix
#define mat_ptr1 x6 //2nd pointer to input matrix
#define outer_loop_cntr x7 //Outer loop counter
#define inner_loop_exit x8 //Inner loop exit condition
#define C1 x9 //Constant1: SVLs - No. of 32-bit elements
#define C2 x10 //Constant2: 3*SVLs
#define C3 x11 //Constant3: ncol*SVLs
#define C4 x13 //Constant4: 2*SVLs
#define C5 x14 //Constant5: 2*ncol
#define C6 x15 //Constant6: 3*ncol

    .text
    .global sgemm_direct_sme1_preprocess

//----------------------------------------------------------------------------
// sgemm_direct_sme1_preprocess(nrow, ncol, mat, mat_mod)
//
// Walks the row-major FP32 matrix `mat` (nrow x ncol) in blocks of SVLs rows.
// Within a block, SVLs columns at a time are loaded as horizontal slices of
// ZA tile 0 and written back as vertical slices, i.e. transposed, so that
// each column of the block becomes one contiguous SVLs-element vector in
// `mat_mod`. Rows past nrow are loaded with an all-false predicate and
// therefore stored as zeros; columns past ncol are not stored. Each block of
// rows advances both `mat` and `mat_mod` by SVLs*ncol FP32 elements.
//
// Only x0-x11, x13-x15 and w12 are used as scratch, so no general-purpose
// callee-saved register needs to be preserved. d8-d15 are preserved
// explicitly, see the prologue.
//----------------------------------------------------------------------------
sgemm_direct_sme1_preprocess:

    // Entering and leaving streaming mode (smstart/smstop) zeroes Z0-Z31,
    // which wipes d8-d15. The procedure call standard requires the callee
    // to preserve those, so spill them around the streaming region.
    stp d8, d9, [sp, #-64]!
    stp d10, d11, [sp, #16]
    stp d12, d13, [sp, #32]
    stp d14, d15, [sp, #48]

    smstart

    // Constants must be computed after smstart so that cntw/cnth report the
    // streaming vector length.
    cntw C1                     //SVLs
    mul C3, C1, ncol            //SVLs*ncol
    lsl C5, ncol, #1            //2*ncol
    add C6, C5, ncol            //3*ncol
    cnth C4                     //2*SVLs
    add C2, C1, C1, lsl #1      //3*SVLs

    mov outer_loop_cntr, #0
    //Tile predicate (M dimension): one lane per valid row of this block
    whilelt p0.s, outer_loop_cntr, nrow
    //Predicate for stores: all lanes, so padded rows are written as zeros
    ptrue p9.s

.M_Loop:
    mov mat_ptr0, mat                       //Load base address of mat
    mov mat_mod_ptr, mat_mod                //a_mod store base address
    add inner_loop_exit, mat, ncol, lsl #2  //Exit condition for inner loop
    //Tile predicate (K dimension), built from byte addresses of the first row
    whilelt p8.b, mat_ptr0, inner_loop_exit

.Loop_process:
    mov mat_ptr1, mat_ptr0
    //Load_to_tile loop counter
    mov w12, #0

.Load_to_tile:
    // psel yields p8 when row (w12 + i) is inside the matrix, otherwise an
    // all-false predicate; with /z the whole slice is then zero-filled.
    psel p2, p8, p0.s[w12, 0]
    psel p3, p8, p0.s[w12, 1]
    psel p4, p8, p0.s[w12, 2]
    psel p5, p8, p0.s[w12, 3]
    //Load 1st row from mat_ptr1
    ld1w {za0h.s[w12, #0]}, p2/z, [mat_ptr1]
    //Load 2nd row from mat_ptr1 + ncol
    ld1w {za0h.s[w12, #1]}, p3/z, [mat_ptr1, ncol, lsl #2]
    //Load 3rd row from mat_ptr1 + 2*ncol
    ld1w {za0h.s[w12, #2]}, p4/z, [mat_ptr1, C5, lsl #2]
    //Load 4th row from mat_ptr1 + 3*ncol
    ld1w {za0h.s[w12, #3]}, p5/z, [mat_ptr1, C6, lsl #2]
    //mat_ptr1+=4*ncol FP32 elements
    add mat_ptr1, mat_ptr1, ncol, lsl #4
    //Increment counter
    add w12, w12, #4
    cmp w12, w9                 //w9 is the low half of C1 (SVLs)
    b.mi .Load_to_tile
    // Store_from_tile loop counter
    mov w12, #0

.Store_from_tile:
    // psel yields the all-true p9 when column (w12 + i) is inside the
    // matrix, otherwise an all-false predicate so nothing is written.
    psel p2, p9, p8.s[w12, 0]
    psel p3, p9, p8.s[w12, 1]
    psel p4, p9, p8.s[w12, 2]
    psel p5, p9, p8.s[w12, 3]
    //Store 1st col to mat_mod
    st1w {za0v.s[w12, #0]}, p2, [mat_mod_ptr]
    //Store 2nd col to mat_mod + SVLs
    st1w {za0v.s[w12, #1]}, p3, [mat_mod_ptr, C1, lsl #2]
    //Store 3rd col to mat_mod + 2*SVLs
    st1w {za0v.s[w12, #2]}, p4, [mat_mod_ptr, C4, lsl #2]
    //Store 4th col to mat_mod + 3*SVLs
    st1w {za0v.s[w12, #3]}, p5, [mat_mod_ptr, C2, lsl #2]

    addvl mat_mod_ptr, mat_mod_ptr, #4      //mat_mod_ptr += 4*SVLb
    add w12, w12, #4                        //Increment counter
    cmp w12, w9                             //w9 is the low half of C1 (SVLs)
    b.mi .Store_from_tile

    addvl mat_ptr0, mat_ptr0, #1            //mat_ptr0 += SVLb
    whilelt p8.b, mat_ptr0, inner_loop_exit
    b.first .Loop_process

    add mat_mod, mat_mod, C3, lsl #2        //mat_mod+=SVLs*ncol FP32 elements
    add mat, mat, C3, lsl #2                //mat+=SVLs*ncol FP32 elements
    incw outer_loop_cntr

    whilelt p0.s, outer_loop_cntr, nrow
    b.first .M_Loop

    smstop

    // Back in non-streaming mode: restore the callee-saved FP registers.
    ldp d14, d15, [sp, #48]
    ldp d12, d13, [sp, #32]
    ldp d10, d11, [sp, #16]
    ldp d8, d9, [sp], #64

    ret


+ 5
- 0
kernel/setparam-ref.c View File

@@ -178,6 +178,11 @@ gotoblas_t TABLE_NAME = {
#ifdef ARCH_X86_64
sgemm_directTS,
sgemm_direct_performantTS,
#endif
#ifdef ARCH_ARM64
#ifdef HAVE_SME
sgemm_directTS,
#endif
#endif

sgemm_kernelTS, sgemm_betaTS,


+ 7
- 1
param.h View File

@@ -3303,6 +3303,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#define GEMM_DEFAULT_OFFSET_A 0
#define GEMM_DEFAULT_OFFSET_B 0


#ifdef _WIN64
/* Use explicit casting for win64 as LLP64 datamodel is used */
#define GEMM_DEFAULT_ALIGN (BLASULONG)0x03fffUL
@@ -3667,7 +3669,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
#define CGEMM_DEFAULT_R 4096
#define ZGEMM_DEFAULT_R 4096

#elif defined(ARMV8SVE) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE
#elif defined(ARMV8SVE) || defined(ARMV9SME) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE

#if defined(XDOUBLE) || defined(DOUBLE)
#define SWITCH_RATIO 8
@@ -3738,6 +3740,10 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout

#endif /* ARMv8 */

/* Define USE_SGEMM_KERNEL_DIRECT when building for the ARMV9SME target.
 * NOTE(review): presumably this is what lets the SGEMM interface dispatch to
 * the SME1-based sgemm_direct kernel - confirm against interface/gemm.c. */
#if defined(ARMV9SME) /* ARMv9 SME */
#define USE_SGEMM_KERNEL_DIRECT 1
#endif /* ARMv9 SME */

#if defined(ARMV5)
#define SNUMOPT 2
#define DNUMOPT 2


Loading…
Cancel
Save