* Added ARMV9SME target
* Added SGEMM_DIRECT kernel based on SME1

tags/v0.3.30
@@ -4,11 +4,12 @@ | |||
cmake_minimum_required(VERSION 3.16.0) | |||
set (CMAKE_ASM_SOURCE_FILE_EXTENSIONS "S") | |||
project(OpenBLAS C ASM) | |||
set(OpenBLAS_MAJOR_VERSION 0) | |||
set(OpenBLAS_MINOR_VERSION 3) | |||
set(OpenBLAS_PATCH_VERSION 28.dev) | |||
set(OpenBLAS_PATCH_VERSION 29.dev) | |||
set(OpenBLAS_VERSION "${OpenBLAS_MAJOR_VERSION}.${OpenBLAS_MINOR_VERSION}.${OpenBLAS_PATCH_VERSION}") | |||
@@ -30,6 +30,11 @@ FCOMMON_OPT += -march=armv8-a+sve | |||
endif | |||
endif | |||
ifeq ($(CORE), ARMV9SME) | |||
CCOMMON_OPT += -march=armv9-a+sve2+sme | |||
FCOMMON_OPT += -march=armv9-a+sve2 | |||
endif | |||
ifeq ($(CORE), CORTEXA53) | |||
CCOMMON_OPT += -march=armv8-a -mtune=cortex-a53 | |||
ifneq ($(F_COMPILER), NAG) | |||
@@ -420,6 +420,7 @@ ifeq ($(ARCH), arm64) | |||
export MACOSX_DEPLOYMENT_TARGET=11.0 | |||
ifeq ($(C_COMPILER), GCC) | |||
export NO_SVE = 1 | |||
export NO_SME = 1 | |||
endif | |||
else | |||
export MACOSX_DEPLOYMENT_TARGET=10.8 | |||
@@ -709,6 +710,9 @@ DYNAMIC_CORE += NEOVERSEN2 | |||
DYNAMIC_CORE += ARMV8SVE | |||
DYNAMIC_CORE += A64FX | |||
endif | |||
ifneq ($(NO_SME), 1) | |||
DYNAMIC_CORE += ARMV9SME | |||
endif | |||
DYNAMIC_CORE += THUNDERX | |||
DYNAMIC_CORE += THUNDERX2T99 | |||
DYNAMIC_CORE += TSV110 | |||
@@ -1474,6 +1478,10 @@ ifeq ($(NO_SVE), 1) | |||
CCOMMON_OPT += -DNO_SVE | |||
endif | |||
ifeq ($(NO_SME), 1) | |||
CCOMMON_OPT += -DNO_SME | |||
endif | |||
ifdef SMP | |||
CCOMMON_OPT += -DSMP_SERVER | |||
@@ -111,6 +111,7 @@ THUNDERX3T110 | |||
VORTEX | |||
A64FX | |||
ARMV8SVE | |||
ARMV9SME | |||
FT2000 | |||
9.System Z: | |||
@@ -331,6 +331,24 @@ if [ "$architecture" = "arm64" ]; then | |||
rm -rf "$tmpd" | |||
fi | |||
# Probe whether the C compiler can assemble SME instructions on arm64.
# A tiny assembly file containing smstart/smstop is assembled first with
# -march=armv9-a+sve2+sme and, if that fails, with -march=armv9-a+sme.
# no_sme is set to 1 only when both attempts fail.
no_sme=0
if [ "$architecture" = "arm64" ]; then
    tmpd=$(mktemp -d 2>/dev/null || mktemp -d -t 'OBC')
    tmpf="$tmpd/a.S"
    printf ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n">> "$tmpf"
    args=" -march=armv9-a+sve2+sme -c -o $tmpf.o $tmpf"
    # $flags and $args are intentionally unquoted so they split into words.
    if ! $compiler_name $flags $args >/dev/null 2>&1; then
        args=" -march=armv9-a+sme -c -o $tmpf.o $tmpf"
        if ! $compiler_name $flags $args >/dev/null 2>&1; then
            no_sme=1
        fi
    fi
    rm -rf "$tmpd"
fi
c11_atomics=0 | |||
case "$data" in | |||
*HAVE_C11*) | |||
@@ -472,6 +490,7 @@ done | |||
printf "CEXTRALIB=%s %s %s\n" "$linker_L" "$linker_l" "$linker_a" | |||
[ "$no_msa" -eq 1 ] && printf "NO_MSA=1\n" | |||
[ "$no_sve" -eq 1 ] && printf "NO_SVE=1\n" | |||
[ "$no_sme" -eq 1 ] && printf "NO_SME=1\n" | |||
[ "$no_rv64gv" -eq 1 ] && printf "NO_RV64GV=1\n" | |||
[ "$no_avx512" -eq 1 ] && printf "NO_AVX512=1\n" | |||
[ "$no_avx512bf" -eq 1 ] && printf "NO_AVX512BF16=1\n" | |||
@@ -44,9 +44,21 @@ endif () | |||
if (DYNAMIC_ARCH) | |||
if (ARM64) | |||
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110) | |||
if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER 9.99) | |||
set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX) | |||
# Cores that are always part of an ARM64 DYNAMIC_ARCH build.
set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
# SVE and SME cores are appended only when the C compiler is new enough.
if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU")
  # SVE ACLE supported in GCC >= 10
  if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 10)
    list(APPEND DYNAMIC_CORE NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
  endif ()
  # SME ACLE supported in GCC >= 14
  if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 14)
    list(APPEND DYNAMIC_CORE ARMV9SME)
  endif ()
elseif (${CMAKE_C_COMPILER_ID} MATCHES "Clang")
  # SVE ACLE supported in LLVM >= 11
  if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 11)
    list(APPEND DYNAMIC_CORE NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
  endif ()
  # SME ACLE supported in LLVM >= 19
  if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 19)
    list(APPEND DYNAMIC_CORE ARMV9SME)
  endif ()
endif ()
if (DYNAMIC_LIST) | |||
set(DYNAMIC_CORE ARMV8 ${DYNAMIC_LIST}) | |||
@@ -238,6 +238,12 @@ if (${CORE} STREQUAL ARMV8SVE) | |||
endif () | |||
endif () | |||
if (${CORE} STREQUAL ARMV9SME) | |||
if (NOT DYNAMIC_ARCH) | |||
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme") | |||
endif () | |||
endif () | |||
if (${CORE} STREQUAL CORTEXA510) | |||
if (NOT DYNAMIC_ARCH) | |||
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve") | |||
@@ -1014,7 +1014,7 @@ endif () | |||
set(ZGEMM_UNROLL_M 4) | |||
set(ZGEMM_UNROLL_N 4) | |||
set(SYMV_P 16) | |||
elseif ("${TCORE}" STREQUAL "NEOVERSEN2") | |||
# ARMV9SME shares the NEOVERSEN2 configuration block that follows.
# CMake's logical operators are case-sensitive, so the keyword must be OR.
elseif ("${TCORE}" STREQUAL "NEOVERSEN2" OR "${TCORE}" STREQUAL "ARMV9SME")
file(APPEND ${TARGET_CONF_TEMP} | |||
"#define L1_CODE_SIZE\t65536\n" | |||
"#define L1_CODE_LINESIZE\t64\n" | |||
@@ -21,7 +21,15 @@ endif() | |||
# Other files expect CORE, which is actually TARGET and will become TARGET_CORE for kernel build. Confused yet? | |||
# It seems we are meant to use TARGET as input and CORE internally as kernel. | |||
if(NOT DEFINED CORE AND DEFINED TARGET) | |||
set(CORE ${TARGET}) | |||
# Translate the LOONGSON* target names into the LA* core names; any other
# target is used as the core name unchanged. TARGET is quoted so that an
# empty value cannot break the comparison.
if ("${TARGET}" STREQUAL "LOONGSON3R5")
  set(CORE "LA464")
elseif ("${TARGET}" STREQUAL "LOONGSON2K1000")
  set(CORE "LA264")
elseif ("${TARGET}" STREQUAL "LOONGSONGENERIC")
  set(CORE "LA64_GENERIC")
else ()
  set(CORE ${TARGET})
endif()
endif() | |||
# TARGET_CORE will override TARGET which is used in DYNAMIC_ARCH=1. | |||
@@ -310,6 +318,9 @@ if (${TARGET} STREQUAL NEOVERSEV1) | |||
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv8.2-a+sve") | |||
endif() | |||
endif() | |||
if (${TARGET} STREQUAL ARMV9SME) | |||
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv9-a+sme -O3") | |||
endif() | |||
if (${TARGET} STREQUAL A64FX) | |||
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE) | |||
set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -Msve-intrinsics -march=armv8.2-a+sve -mtune=a64fx") | |||
@@ -382,6 +393,8 @@ if (NEED_PIC) | |||
if (NOT NOFORTRAN) | |||
if (${F_COMPILER} STREQUAL "SUN") | |||
set(FCOMMON_OPT "${FCOMMON_OPT} -pic") | |||
elseif (${F_COMPILER} STREQUAL "NAGFOR") | |||
set(FCOMMON_OPT "${FCOMMON_OPT} -PIC") | |||
else () | |||
set(FCOMMON_OPT "${FCOMMON_OPT} -fPIC") | |||
endif () | |||
@@ -640,17 +653,17 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Windows") | |||
endif () | |||
if (CMAKE_Fortran_COMPILER) | |||
if ("${F_COMPILER}" STREQUAL "NAG" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*") | |||
set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx;-mavx2,-mskylake-avx512") | |||
if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*") | |||
message(STATUS "removing fortran flags") | |||
set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64") | |||
# For NAG, Cray and LLVM Flang, strip x86 SIMD options (plus -m32/-m64 for
# LLVM Flang) from the flags used to build the LAPACK Fortran sources.
if ("${F_COMPILER}" STREQUAL "NAGFOR" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
  # string(REPLACE) removes substrings, so -mavx2 must be listed before
  # -mavx; otherwise "-mavx2" would be reduced to a stray "2".
  set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx2;-mavx;-mskylake-avx512")
  if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
    message(STATUS "removing fortran flags")
    set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64")
  endif ()
  foreach (FILTER_FLAG ${FILTER_FLAGS})
    # Quote the inputs so empty or ;-separated flag variables are handled.
    string(REPLACE "${FILTER_FLAG}" "" LAPACK_FFLAGS "${LAPACK_FFLAGS}")
    string(REPLACE "${FILTER_FLAG}" "" LAPACK_FPFLAGS "${LAPACK_FPFLAGS}")
  endforeach ()
endif ()
foreach (FILTER_FLAG ${FILTER_FLAGS}) | |||
string(REPLACE ${FILTER_FLAG} "" LAPACK_FFLAGS ${LAPACK_FFLAGS}) | |||
string(REPLACE ${FILTER_FLAG} "" LAPACK_FPFLAGS ${LAPACK_FPFLAGS}) | |||
endforeach () | |||
endif () | |||
endif () | |||
if ("${F_COMPILER}" STREQUAL "GFORTRAN") | |||
@@ -670,6 +683,9 @@ endif () | |||
if (${CMAKE_C_COMPILER} STREQUAL "LSB" OR ${CMAKE_SYSTEM_NAME} STREQUAL "Windows") | |||
set(LAPACK_CFLAGS "${LAPACK_CFLAGS} -DLAPACK_COMPLEX_STRUCTURE") | |||
endif () | |||
if (${CMAKE_C_COMPILER_ID} MATCHES "IntelLLVM" AND ${CMAKE_SYSTEM_NAME} STREQUAL "Windows") | |||
set(LAPACK_CFLAGS "${LAPACK_CFLAGS} -DNOCHANGE") | |||
endif () | |||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||
if ("${F_COMPILER}" STREQUAL "FLANG") | |||
@@ -135,6 +135,17 @@ endif() | |||
endif() | |||
endif() | |||
# Probe whether the C compiler can assemble SME instructions on ARM64.
# The probe is skipped when NO_SME is already set. On failure NO_SME is set
# to 1 and -DNO_SME is appended to CCOMMON_OPT; on success NO_SME is 0.
if (ARM64)
  if (NOT NO_SME)
    # The probe source is assembly, so it needs a .S suffix; with a .c suffix
    # the compiler driver would parse it as C and the probe would always fail.
    set(sme_probe_src "${PROJECT_BINARY_DIR}/sme.S")
    set(sme_probe_obj "${PROJECT_BINARY_DIR}/sme.o")
    file(WRITE "${sme_probe_src}" ".text \n.global sme_test\n\nsme_test:\nsmstart\nsmstop\nret\n")
    execute_process(
      COMMAND ${CMAKE_C_COMPILER} -march=armv9-a+sve2+sme -c -v -o "${sme_probe_obj}" "${sme_probe_src}"
      OUTPUT_QUIET
      ERROR_QUIET
      RESULT_VARIABLE sme_probe_result)
    # Any result other than 0 (non-zero exit code or a launch error string)
    # means the compiler cannot build SME code.
    if (sme_probe_result EQUAL 0)
      set(NO_SME 0)
    else ()
      set(NO_SME 1)
      set (CCOMMON_OPT "${CCOMMON_OPT} -DNO_SME")
    endif ()
    file(REMOVE "${sme_probe_src}" "${sme_probe_obj}")
    unset(sme_probe_src)
    unset(sme_probe_obj)
    unset(sme_probe_result)
  endif ()
endif ()
include(CheckIncludeFile) | |||
CHECK_INCLUDE_FILE("stdatomic.h" HAVE_C11) | |||
if (HAVE_C11 EQUAL 1) | |||
@@ -696,6 +696,7 @@ void gotoblas_profile_init(void); | |||
void gotoblas_profile_quit(void); | |||
int support_avx512(void); | |||
int support_sme1(void); | |||
#ifdef USE_OPENMP | |||
@@ -175,7 +175,7 @@ REALNAME: | |||
#define HUGE_PAGESIZE ( 4 << 20) | |||
#ifndef BUFFERSIZE | |||
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE) | |||
#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE) || defined(ARMV9SME) | |||
#define BUFFER_SIZE (32 << 22) | |||
#else | |||
#define BUFFER_SIZE (32 << 20) | |||
@@ -221,6 +221,12 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); | |||
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | |||
int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K); | |||
#endif | |||
#ifdef ARCH_ARM64 | |||
#ifdef HAVE_SME | |||
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG); | |||
#endif | |||
#endif | |||
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); | |||
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG); | |||
@@ -213,9 +213,9 @@ | |||
#ifdef ARCH_X86_64 | |||
#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant | |||
#define SGEMM_DIRECT gotoblas -> sgemm_direct | |||
#else | |||
/* Use defined(): ARCH_ARM64 is tested with #ifdef elsewhere, and a bare
   "#elif ARCH_ARM64" is ill-formed if the macro is defined without a value. */
#elif defined(ARCH_ARM64)
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant | |||
#define SGEMM_DIRECT sgemm_direct | |||
#define SGEMM_DIRECT gotoblas -> sgemm_direct | |||
#endif | |||
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy | |||
@@ -115,6 +115,11 @@ extern gotoblas_t gotoblas_ARMV8SVE; | |||
#else | |||
#define gotoblas_ARMV8SVE gotoblas_ARMV8 | |||
#endif | |||
#ifdef DYN_ARMV9SME | |||
extern gotoblas_t gotoblas_ARMV9SME; | |||
#else | |||
#define gotoblas_ARMV9SME gotoblas_ARMV8 | |||
#endif | |||
#ifdef DYN_CORTEX_A55 | |||
extern gotoblas_t gotoblas_CORTEXA55; | |||
#else | |||
@@ -148,6 +153,13 @@ extern gotoblas_t gotoblas_A64FX; | |||
#define gotoblas_ARMV8SVE gotoblas_ARMV8 | |||
#define gotoblas_A64FX gotoblas_ARMV8 | |||
#endif | |||
#ifndef NO_SME | |||
extern gotoblas_t gotoblas_ARMV9SME; | |||
#else | |||
#define gotoblas_ARMV9SME gotoblas_ARMV8SVE | |||
#endif | |||
extern gotoblas_t gotoblas_THUNDERX3T110; | |||
#endif | |||
#define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1 | |||
@@ -168,6 +180,9 @@ extern void openblas_warning(int verbose, const char * msg); | |||
#ifndef HWCAP_SVE | |||
#define HWCAP_SVE (1 << 22) | |||
#endif | |||
/* Fallback for system headers that do not define the SME bit of AT_HWCAP2.
   Parenthesized, like HWCAP_SVE above, so the macro is safe in any expression. */
#ifndef HWCAP2_SME
#define HWCAP2_SME (1 << 23)
#endif
#define get_cpu_ftr(id, var) ({ \ | |||
__asm__ __volatile__ ("mrs %0, "#id : "=r" (var)); \ | |||
@@ -393,6 +408,13 @@ static gotoblas_t *get_coretype(void) { | |||
snprintf(coremsg, 128, "Unknown CPU model - implementer %x part %x\n",implementer,part); | |||
openblas_warning(1, coremsg); | |||
} | |||
#if !defined(NO_SME) && defined(HWCAP2_SME) | |||
if ((getauxval(AT_HWCAP2) & HWCAP2_SME)) { | |||
return &gotoblas_ARMV9SME; | |||
} | |||
#endif | |||
#ifndef NO_SVE | |||
if ((getauxval(AT_HWCAP) & HWCAP_SVE)) { | |||
return &gotoblas_ARMV8SVE; | |||
@@ -443,3 +465,15 @@ void gotoblas_dynamic_init(void) { | |||
void gotoblas_dynamic_quit(void) { | |||
gotoblas = NULL; | |||
} | |||
/*
 * Report whether the running CPU advertises SME.
 *
 * On Linux and Android this tests the HWCAP2_SME bit of getauxval(AT_HWCAP2)
 * and returns 1 when it is set, 0 otherwise. On every other OS it returns 0.
 */
int support_sme1(void) {
#if (defined OS_LINUX || defined OS_ANDROID)
	/* Normalize the masked bit to a strict 0/1 result. */
	return (getauxval(AT_HWCAP2) & (HWCAP2_SME)) ? 1 : 0;
#else
	return 0;
#endif
}
@@ -1289,6 +1289,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#define CORENAME "ARMV8SVE" | |||
#endif | |||
#ifdef FORCE_ARMV9SME | |||
#define FORCE | |||
#define ARCHITECTURE "ARM64" | |||
#define SUBARCHITECTURE "ARMV9SME" | |||
#define SUBDIRNAME "arm64" | |||
#define ARCHCONFIG "-DARMV9SME " \ | |||
"-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \ | |||
"-DL2_SIZE=262144 -DL2_LINESIZE=64 " \ | |||
"-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 -DL2_ASSOCIATIVE=32 " \ | |||
"-DHAVE_VFPV4 -DHAVE_VFPV3 -DHAVE_VFP -DHAVE_NEON -DHAVE_SVE -DHAVE_SME -DARMV8 -DARMV9" | |||
#define LIBNAME "armv9sme" | |||
#define CORENAME "ARMV9SME" | |||
#endif | |||
#ifdef FORCE_ARMV8 | |||
#define FORCE | |||
@@ -1,5 +1,5 @@ | |||
/*********************************************************************/ | |||
/* Copyright 2024 The OpenBLAS Project */ | |||
/* Copyright 2024, 2025 The OpenBLAS Project */ | |||
/* Copyright 2009, 2010 The University of Texas at Austin. */ | |||
/* All rights reserved. */ | |||
/* */ | |||
@@ -86,7 +86,7 @@ | |||
#endif | |||
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = { | |||
#ifndef GEMM3M | |||
#if !defined(GEMM3M) || defined(GENERIC) | |||
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN, | |||
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT, | |||
GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR, | |||
@@ -177,6 +177,49 @@ static int init_amxtile_permission() { | |||
} | |||
#endif | |||
#ifdef DYNAMIC_ARCH | |||
extern char* gotoblas_corename(void); | |||
#endif | |||
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV1)
/*
 * Thread-count heuristic for Neoverse V1.
 *
 * MNK   product m*n*k of the GEMM dimensions
 * ncpu  number of CPUs available
 *
 * Returns 1 for the smallest problems, then a thread cap that grows with MNK
 * (never more than ncpu), and ncpu once MNK reaches 729000000.
 */
static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) {
  /* Upper MNK bound (exclusive) and the thread cap that applies below it. */
  static const struct { long limit; int cap; } steps[] = {
    {   1124864L,  6 },
    {   7880599L, 12 },
    {  17173512L, 16 },
    {  33386248L, 20 },
    {  57066625L, 24 },
    {  91733851L, 32 },
    { 265847707L, 40 },
    { 458314011L, 48 },
    { 729000000L, 56 },
  };
  unsigned int i;

  if (MNK < 262144L) {
    return 1;
  }
  for (i = 0; i < sizeof(steps) / sizeof(steps[0]); i++) {
    if (MNK < steps[i].limit) {
      return MIN(ncpu, steps[i].cap);
    }
  }
  return ncpu;
}
#endif
static inline int get_gemm_optimal_nthreads(double MNK) { | |||
int ncpu = num_cpu_avail(3); | |||
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) | |||
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); | |||
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) | |||
if (strcmp(gotoblas_corename(), "neoversev1") == 0) { | |||
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu); | |||
} | |||
#endif | |||
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) { | |||
return 1; | |||
} | |||
else { | |||
if (MNK/ncpu < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD) { | |||
return MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD); | |||
} | |||
else { | |||
return ncpu; | |||
} | |||
} | |||
} | |||
#ifndef CBLAS | |||
void NAME(char *TRANSA, char *TRANSB, | |||
@@ -310,7 +353,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
FLOAT *beta = (FLOAT*) vbeta; | |||
FLOAT *a = (FLOAT*) va; | |||
FLOAT *b = (FLOAT*) vb; | |||
FLOAT *c = (FLOAT*) vc; | |||
FLOAT *c = (FLOAT*) vc; | |||
#endif | |||
blas_arg_t args; | |||
@@ -350,14 +393,21 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
PRINT_DEBUG_CNAME; | |||
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT) | |||
#ifdef DYNAMIC_ARCH | |||
#if defined(DYNAMIC_ARCH) && defined(ARCH_x86) | |||
if (support_avx512() ) | |||
#endif | |||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) { | |||
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | |||
return; | |||
} | |||
#endif | |||
#if defined(DYNAMIC_ARCH) && defined(ARCH_ARM64) | |||
if (support_sme1()){ | |||
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) { | |||
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc); | |||
return; | |||
} | |||
} | |||
#endif | |||
#endif | |||
#ifndef COMPLEX | |||
@@ -604,13 +654,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS | |||
#endif | |||
MNK = (double) args.m * (double) args.n * (double) args.k; | |||
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) | |||
args.nthreads = 1; | |||
else { | |||
args.nthreads = num_cpu_avail(3); | |||
if (MNK/args.nthreads < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD) | |||
args.nthreads = MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD); | |||
} | |||
args.nthreads = get_gemm_optimal_nthreads(MNK); | |||
args.common = NULL; | |||
@@ -65,6 +65,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}COPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}NRM2KERNEL}" "" "nrm2_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}ROTKERNEL}" "" "rot_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}ROTMKERNEL}" "" "rotm_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}SCALKERNEL}" "" "scal_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}SWAPKERNEL}" "" "swap_k" false "" "" false ${float_type}) | |||
GenerateNamedObjects("${KERNELDIR}/${${float_char}AXPBYKERNEL}" "" "axpby_k" false "" "" false ${float_type}) | |||
@@ -125,6 +126,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
GenerateNamedObjects("${KERNELDIR}/${SNRM2KERNEL}" "" "nrm2_k" false "" "" false "SINGLE") | |||
GenerateNamedObjects("${KERNELDIR}/${SDOTKERNEL}" "" "dot_k" false "" "" false "SINGLE") | |||
GenerateNamedObjects("${KERNELDIR}/${SROTKERNEL}" "" "rot_k" false "" "" false "SINGLE") | |||
GenerateNamedObjects("${KERNELDIR}/${SROTMKERNEL}" "" "rotm_k" false "" "" false "SINGLE") | |||
endif () | |||
if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) | |||
GenerateNamedObjects("${KERNELDIR}/${DAMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false "DOUBLE") | |||
@@ -148,6 +150,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE") | |||
@@ -204,19 +207,27 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) | |||
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10)) | |||
set(USE_TRMM true) | |||
endif () | |||
set(USE_DIRECT_SGEMM false) | |||
if (X86_64) | |||
if (X86_64 OR (ARM64 AND (UC_TARGET_CORE MATCHES ARMV9SME))) | |||
set(USE_DIRECT_SGEMM true) | |||
endif() | |||
# Generate the objects for the direct SGEMM path. The kernel file names are
# stored in the SGEMMDIRECT* variables before the objects are generated.
if (USE_DIRECT_SGEMM)
  if (X86_64)
    # x86_64: direct kernel plus its "performant" predicate.
    set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
    set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
    GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
  elseif (ARM64)
    # ARM64: C driver plus the two SME1 assembly sources.
    set (SGEMMDIRECTKERNEL sgemm_direct_arm64_sme1.c)
    set (SGEMMDIRECTSMEKERNEL sgemm_direct_sme1.S)
    set (SGEMMDIRECTPREKERNEL sgemm_direct_sme1_preprocess.S)
    GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTSMEKERNEL}" "" "gemm_direct_sme1" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPREKERNEL}" "" "gemm_direct_sme1_preprocess" false "" "" false SINGLE)
  endif ()
endif ()
foreach (float_type SINGLE DOUBLE) | |||
@@ -1105,6 +1116,7 @@ endif () | |||
GenerateNamedObjects("${KERNELDIR}/${DCOPYKERNEL}" "C_INTERFACE" "copy_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DNRM2KERNEL}" "" "nrm2_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DROTKERNEL}" "" "rot_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DROTMKERNEL}" "" "rotm_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DDOTKERNEL}" "" "dot_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DSWAPKERNEL}" "" "swap_k" false "" "" false "DOUBLE") | |||
GenerateNamedObjects("${KERNELDIR}/${DAXPYKERNEL}" "" "axpy_k" false "" "" false "DOUBLE") | |||
@@ -24,7 +24,11 @@ ifdef NO_AVX2 | |||
AVX2OPT= | |||
endif | |||
ifdef TARGET_CORE | |||
ifeq ($(TARGET_CORE), ARMV9SME) | |||
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE) -DHAVE_SME -march=armv9-a+sve2+sme | |||
endif | |||
ifeq ($(TARGET_CORE), SAPPHIRERAPIDS) | |||
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE) | |||
ifeq (1, $(filter 1,$(GCCVERSIONGTEQ11) $(CLANGVERSIONGTEQ12))) | |||
@@ -24,6 +24,7 @@ endif | |||
ifeq ($(ARCH), arm64) | |||
USE_TRMM = 1 | |||
USE_DIRECT_SGEMM = 1 | |||
endif | |||
ifeq ($(ARCH), riscv64) | |||
@@ -95,9 +96,17 @@ endif | |||
ifdef USE_DIRECT_SGEMM | |||
ifndef SGEMMDIRECTKERNEL | |||
ifeq ($(ARCH), x86_64) | |||
SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c | |||
SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c | |||
endif | |||
ifeq ($(ARCH), arm64) | |||
ifeq ($(TARGET_CORE), ARMV9SME) | |||
HAVE_SME = 1 | |||
SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c | |||
endif | |||
endif | |||
endif | |||
endif | |||
ifeq ($(BUILD_BFLOAT16), 1) | |||
@@ -128,9 +137,19 @@ SKERNELOBJS += \ | |||
$(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ) | |||
ifdef USE_DIRECT_SGEMM | |||
ifeq ($(ARCH), x86_64) | |||
SKERNELOBJS += \ | |||
sgemm_direct$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_direct_performant$(TSUFFIX).$(SUFFIX) | |||
endif | |||
ifeq ($(ARCH), arm64) | |||
ifdef HAVE_SME | |||
SKERNELOBJS += \ | |||
sgemm_direct$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_direct_performant$(TSUFFIX).$(SUFFIX) | |||
sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) \ | |||
sgemm_direct_sme1_preprocess$(TSUFFIX).$(SUFFIX) | |||
endif | |||
endif | |||
endif | |||
endif | |||
@@ -809,11 +828,23 @@ else | |||
endif | |||
ifdef USE_DIRECT_SGEMM | |||
ifeq ($(ARCH), x86_64) | |||
$(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL) | |||
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@ | |||
endif | |||
ifeq ($(ARCH), arm64)
ifdef HAVE_SME
# C driver for the SME1 direct SGEMM kernel.
$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
# The SME1 assembly sources are listed as prerequisites so that editing them
# triggers a rebuild of the corresponding objects.
$(KDIR)sgemm_direct_sme1$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1.S
	$(CC) $(CFLAGS) -c $< -UDOUBLE -UCOMPLEX -o $@
$(KDIR)sgemm_direct_sme1_preprocess$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1_preprocess.S
	$(CC) $(CFLAGS) -c $< -UDOUBLE -UCOMPLEX -o $@
endif
endif
endif | |||
ifeq ($(BUILD_BFLOAT16), 1) | |||
@@ -0,0 +1,3 @@ | |||
include $(KERNELDIR)/KERNEL.ARMV8SVE | |||
@@ -0,0 +1,59 @@ | |||
/* | |||
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. | |||
SPDX-License-Identifier: BSD-3-Clause-Clear | |||
*/ | |||
#include "common.h"
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#if defined(HAVE_SME) | |||
/* Function prototypes */ | |||
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\ | |||
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess"); | |||
extern void sgemm_direct_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,\ | |||
const float * matLeft,\ | |||
const float * restrict matRight,\ | |||
const float * restrict matResult) __asm__("sgemm_direct_sme1_2VLx2VL"); | |||
/* Function Definitions */ | |||
uint64_t sve_cntw() { | |||
uint64_t cnt; | |||
asm volatile( | |||
"rdsvl %[res], #1\n" | |||
"lsr %[res], %[res], #2\n" | |||
: [res] "=r" (cnt) :: | |||
); | |||
return cnt; | |||
} | |||
/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\ | |||
float * __restrict A, BLASLONG strideA, float * __restrict B,\ | |||
BLASLONG strideB , float * __restrict R, BLASLONG strideR) | |||
*/ | |||
/*
 * SME1 direct SGEMM driver: computes R = A*B by first rearranging A into a
 * temporary buffer and then running the 2VLx2VL outer-product kernel.
 *
 * M, N, K   dimensions handed to the assembly kernels
 * A, B      left and right input matrices
 * R         result matrix
 * strideA, strideB, strideR
 *           accepted for interface compatibility but not used here.
 *           NOTE(review): the kernels only receive M, K and N, so they
 *           presumably assume densely packed rows (lda == K,
 *           ldb == ldc == N) - confirm that callers guarantee this.
 *
 * The process is aborted with a diagnostic if the temporary buffer cannot
 * be allocated, instead of passing a NULL pointer to the kernels.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
            BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
            float * __restrict R, BLASLONG strideR){

   /* Number of 32-bit elements per streaming vector. */
   uint64_t vl_elms = sve_cntw();

   /* Round M up to a whole number of vectors using integer arithmetic. */
   uint64_t m_mod = (((uint64_t)M + vl_elms - 1) / vl_elms) * vl_elms;

   size_t a_mod_bytes = (size_t)m_mod * (size_t)K * sizeof(float);
   float *A_mod = (float *) malloc(a_mod_bytes);

   /* malloc(0) may legitimately return NULL, so only a non-empty request
      that fails is treated as an error. */
   if (A_mod == NULL && a_mod_bytes != 0) {
      fprintf(stderr, "OpenBLAS: sgemm_direct failed to allocate %zu bytes\n",
              a_mod_bytes);
      abort();
   }

   /* Pre-process the left matrix to make it suitable for
      matrix sum of outer-product calculation
    */
   sgemm_direct_sme1_preprocess(M, K, A, A_mod);

   /* Calculate C = A*B */
   sgemm_direct_sme1_2VLx2VL(M, K, N, A_mod, B, R);

   free(A_mod);
}
#endif |
@@ -0,0 +1,228 @@ | |||
/* | |||
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. | |||
SPDX-License-Identifier: BSD-3-Clause-Clear | |||
*/ | |||
/*-------------------------------------------------------------------------- | |||
* SME1 based Matrix multiplication code for FP32 input matrices to FP32 | |||
* output matrix | |||
* C = A*B | |||
* A: Left input matrix of dimension M x K | |||
* B: Right input matrix of dimension K x N | |||
* C: Result matrix of dimension M x N | |||
* | |||
* Usage of function: | |||
* sgemm_direct_sme1_2VLx2VL( uint64_t M , uint64_t K, uint64_t N,\ | |||
const float * restrict A_base,\ | |||
const float * restrict B_base,\ | |||
const float * restrict C_base); | |||
----------------------------------------------------------------------------*/ | |||
#define M x0 //M dimension | |||
#define K x1 //K dimension | |||
#define N x2 //N dimension | |||
#define A_base x3 //Pointer to left matrix(A) | |||
#define B_base x4 //Pointer to right matrix(B) | |||
#define C_base x5 //Pointer to result matrix(C) | |||
#define Aptr x6 //Pointer to traverse A | |||
#define Aptr_end x7 //Pointer to end of row of A | |||
#define Cptr x8 //Pointer to traverse C | |||
#define Cptr0 x9 //2nd Pointer to traverse C | |||
#define Cptr1 x10 //3rd Pointer to traverse C | |||
#define Bptr x11 //Pointer to traverse B | |||
#define Bptr0 x12 //2nd Pointer to traverse B | |||
#define N_exit x14 //Exit condition for N loop | |||
#define K_exit x15 //Exit condition for K loop | |||
#define M_cntr x16 //M loop counter | |||
#define C1 x17 //Constant1: N*(SVLs+1);SVLs-No. of 32-bit elements | |||
#define C2 x18 //Constant2: N + SVLs | |||
#define C3 x19 //Constant3: K*SVLs + SVLs | |||
#define C4 x20 //Constant4: SVLs-2 | |||
#define C5 x21 //Constant5: K*SVLs | |||
#define C6 x22 //Constant6: N*SVLs | |||
.text | |||
.global sgemm_direct_sme1_2VLx2VL | |||
sgemm_direct_sme1_2VLx2VL: | |||
stp x19, x20, [sp, #-48]! | |||
stp x21, x22, [sp, #16] | |||
stp x23, x24, [sp, #32] | |||
smstart | |||
cntw C4 //SVLs | |||
mul C5, C4, K //K*SVLs | |||
mul C6, C4, N //N*SVLs | |||
add C1, C6, N //N*SVLs + N | |||
add N_exit, B_base, N, lsl #2 //N_Loop exit conditon | |||
mov M_cntr, #0 | |||
add C2, N, C4 //N + SVLs | |||
add C3, C5, C4 //K*SVLs + SVLs | |||
whilelt p2.s, M_cntr, M //Tile 0,1 predicate (M dimension) | |||
sub w20, w20, #2 //SVLs-2 | |||
.M_Loop: | |||
incw M_cntr | |||
whilelt p3.s, M_cntr, M //Tile 2,3 predicate (M dimension) | |||
mov Bptr, B_base //B_base | |||
mov Cptr, C_base //C_base | |||
whilelt p0.b, Bptr, N_exit //Tile 0/2 predicate (N dimension) | |||
.N_Loop: | |||
mov Aptr, A_base //Aptr = A_base | |||
mov Bptr0, Bptr //Bptr = B_base | |||
mov Cptr0, Cptr //Cptr0 = C_base | |||
addvl Cptr1, Cptr, #1 //Cptr1 = C_base + SVLb | |||
addvl Bptr, Bptr, #1 | |||
whilelt p1.b, Bptr, N_exit //Tile 1,3 predicate (N dimension) | |||
add Aptr_end, A_base, C5, lsl #2 //A_base + K*SVLs | |||
addvl K_exit, Aptr_end, #-1 //Exit condition for K loop | |||
//Load 1st vector from Aptr | |||
ld1w {z1.s}, p2/z, [Aptr] | |||
zero {za} | |||
// Load 1st vector from Bptr | |||
ld1w {z2.s}, p0/z, [Bptr0] | |||
// ZA0 += 1st Aptr vector OP 1st Bptr vector | |||
fmopa za0.s, p2/m, p0/m, z1.s, z2.s | |||
// Load 2nd vector from Aptr | |||
ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2] | |||
// Aptr += SVLb | |||
addvl Aptr, Aptr, #1 | |||
.K_Loop: | |||
// ZA2 += 2nd Aptr vector OP 1st Bptr vector | |||
fmopa za2.s, p3/m, p0/m, z5.s, z2.s | |||
// Load 2nd vector from Bptr | |||
ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL] | |||
// ZA1 += 1st Aptr vector OP 2nd Bptr vector | |||
fmopa za1.s, p2/m, p1/m, z1.s, z3.s | |||
// Load next 1st vector from Aptr | |||
ld1w {z0.s}, p2/z, [Aptr] | |||
// ZA3 += 2nd Aptr vector OP 2nd Bptr vector | |||
fmopa za3.s, p3/m, p1/m, z5.s, z3.s | |||
cmp K, #2 | |||
b.le process_K_less_than_equal_2 | |||
// Load next 1st vector from Bptr | |||
ld1w {z6.s}, p0/z, [Bptr0, N, lsl #2] | |||
// ZA0 += 1st Aptr vector OP 1st Bptr vector | |||
fmopa za0.s, p2/m, p0/m, z0.s, z6.s | |||
// Load next 2nd vector from Aptr | |||
ld1w {z4.s}, p3/z, [Aptr, C5, lsl #2] | |||
// ZA2 += 2nd Aptr vector OP 1st Bptr vector | |||
fmopa za2.s, p3/m, p0/m, z4.s, z6.s | |||
// Load next 2nd vector from Bptr | |||
ld1w {z7.s}, p1/z, [Bptr0, C2, lsl #2] | |||
// Bptr += 2*ldb FP32 elms [Bytes] | |||
add Bptr0, Bptr0, N, lsl #3 | |||
// ZA1 += 1st Aptr vector OP 2nd Bptr vector | |||
fmopa za1.s, p2/m, p1/m, z0.s, z7.s | |||
// Load next 2nd vector from Aptr | |||
ld1w {z1.s}, p2/z, [Aptr, #1, MUL VL] | |||
// ZA3 += 2nd Aptr vector OP 2nd Bptr vector | |||
fmopa za3.s, p3/m, p1/m, z4.s, z7.s | |||
// Load next 1st vector from Bptr | |||
ld1w {z2.s}, p0/z, [Bptr0] | |||
// ZA0 += 1st Aptr vector OP 1st Bptr vector | |||
fmopa za0.s, p2/m, p0/m, z1.s, z2.s | |||
// Load next 2nd vector from Aptr | |||
ld1w {z5.s}, p3/z, [Aptr, C3, lsl #2] | |||
// Aptr += 2*SVLb [Bytes] | |||
addvl Aptr, Aptr, #2 | |||
cmp Aptr, K_exit | |||
b.mi .K_Loop | |||
// ZA2 += 2nd Aptr vector OP 1st Bptr vector | |||
fmopa za2.s, p3/m, p0/m, z5.s, z2.s | |||
// Load next 2nd vector from Bptr | |||
ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL] | |||
// ZA1 += 1st Aptr vector OP 2nd Bptr vector | |||
fmopa za1.s, p2/m, p1/m, z1.s, z3.s | |||
// ZA3 += 2nd Aptr vector OP 2nd Bptr vector | |||
fmopa za3.s, p3/m, p1/m, z5.s, z3.s | |||
process_K_less_than_equal_2: | |||
// Bptr += ldb FP32 elements (N*4 bytes, i.e. one row of B) | |||
add Bptr0, Bptr0, N, lsl #2 | |||
cmp Aptr, Aptr_end | |||
b.pl .Ktail_end | |||
.Ktail_start: | |||
ld1w {z1.s}, p2/z, [Aptr] | |||
ld1w {z2.s}, p0/z, [Bptr0] | |||
ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL] | |||
fmopa za0.s, p2/m, p0/m, z1.s, z2.s | |||
ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2] | |||
fmopa za2.s, p3/m, p0/m, z5.s, z2.s | |||
fmopa za1.s, p2/m, p1/m, z1.s, z3.s | |||
fmopa za3.s, p3/m, p1/m, z5.s, z3.s | |||
.Ktail_end: | |||
mov w13, #0 | |||
psel p4, p0, p2.s[w13, 0] | |||
psel p5, p1, p2.s[w13, 0] | |||
psel p6, p0, p3.s[w13, 0] | |||
psel p7, p1, p3.s[w13, 0] | |||
// Store to Cptr0 | |||
st1w {za0h.s[w13, #0]}, p4, [Cptr0] | |||
// Store to Cptr1 | |||
st1w {za1h.s[w13, #0]}, p5, [Cptr1] | |||
// Store to Cptr0 + N*SVLs | |||
st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2] | |||
// Store to Cptr1 + N*SVLs | |||
st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2] | |||
.Loop_store_ZA: | |||
psel p4, p0, p2.s[w13, 1] | |||
psel p5, p1, p2.s[w13, 1] | |||
psel p6, p0, p3.s[w13, 1] | |||
psel p7, p1, p3.s[w13, 1] | |||
// Store to Cptr0 + N | |||
st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2] | |||
// Store to Cptr1 + N | |||
st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2] | |||
// Store to Cptr0 + N*(SVLs+1) | |||
st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2] | |||
// Store to Cptr1 + N*(SVLs+1) | |||
st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2] | |||
add Cptr0, Cptr0, N, lsl #3 //Cptr0 += 2*N FP32 elements | |||
add Cptr1, Cptr1, N, lsl #3 //Cptr1 += 2*N FP32 elements | |||
add w13, w13, #2 | |||
psel p4, p0, p2.s[w13, 0] | |||
psel p5, p1, p2.s[w13, 0] | |||
psel p6, p0, p3.s[w13, 0] | |||
psel p7, p1, p3.s[w13, 0] | |||
st1w {za0h.s[w13, #0]}, p4, [Cptr0] | |||
st1w {za1h.s[w13, #0]}, p5, [Cptr1] | |||
st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2] | |||
st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2] | |||
cmp w13, w20 | |||
b.mi .Loop_store_ZA | |||
psel p4, p0, p2.s[w13, 1] | |||
psel p5, p1, p2.s[w13, 1] | |||
psel p6, p0, p3.s[w13, 1] | |||
psel p7, p1, p3.s[w13, 1] | |||
st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2] | |||
st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2] | |||
st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2] | |||
st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2] | |||
addvl Cptr, Cptr, #2 | |||
addvl Bptr, Bptr, #1 | |||
whilelt p0.b, Bptr, N_exit //1st Tile predicate (N dimension) | |||
b.first .N_Loop | |||
add A_base, A_base, C5, lsl #3 //A_base += 2*K*SVLs FP32 elements | |||
add C_base, C_base, C6, lsl #3 //C_base += 2*N*SVLs FP32 elements | |||
incw M_cntr | |||
whilelt p2.s, M_cntr, M //1st Tile predicate (M dimension) | |||
b.first .M_Loop | |||
smstop | |||
ldp x23, x24, [sp, #32] | |||
ldp x21, x22, [sp, #16] | |||
ldp x19, x20, [sp], #48 | |||
ret | |||
@@ -0,0 +1,133 @@ | |||
/* | |||
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. | |||
SPDX-License-Identifier: BSD-3-Clause-Clear | |||
*/ | |||
/*---------------------------------------------------------------------------- | |||
* This function is used to re-arrange the elements of input matrix to | |||
* make it suitable for matrix outer product computation using SME for matrix | |||
* multiplication. It should be used to pre-process the left matrix (A) in the | |||
* matrix multiplication (C = A*B) using sgemm_direct_sme1_2VLx2VL() | |||
* | |||
* The pre-processing transposes a block of SVLs rows of the input matrix and | |||
* stores it contiguously. The same is applied to remaining blocks of SVLs | |||
* rows. The last block of SVLs rows is zero-padded to SVLs rows if needed. | |||
* | |||
* Usage of function: | |||
* sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol, \ | |||
* const float * restrict mat, float * mat_mod); | |||
* | |||
----------------------------------------------------------------------------*/ | |||
// Register aliases used by sgemm_direct_sme1_preprocess. | |||
// x0-x3 hold the four arguments in the order given in the usage note above; | |||
// x4-x15 are scratch registers. x12 has no alias because w12 is used | |||
// directly as the ZA slice index in the tile load and store loops. | |||
#define nrow x0 //Number of rows of input matrix | |||
#define ncol x1 //Number of columns of input matrix | |||
#define mat x2 //Input matrix base address (advanced by SVLs rows per outer iteration) | |||
#define mat_mod x3 //Output matrix (re-arranged matrix) base address (advanced in step with mat) | |||
#define mat_mod_ptr x4 //Pointer to output matrix: store position inside the current row block | |||
#define mat_ptr0 x5 //Pointer to input matrix: start of the current column block | |||
#define mat_ptr1 x6 //2nd pointer to input matrix: walks down the rows while the tile is loaded | |||
#define outer_loop_cntr x7 //Outer loop counter: rows handled so far, stepped by SVLs | |||
#define inner_loop_exit x8 //Inner loop exit condition: address just past the first row of the row block | |||
#define C1 x9 //Constant1: SVLs - No. of 32-bit elements (set with cntw; w9 is its 32-bit view) | |||
#define C2 x10 //Constant2: 3*SVLs | |||
#define C3 x11 //Constant3: ncol*SVLs | |||
#define C4 x13 //Constant4: 2*SVLs (set with cnth) | |||
#define C5 x14 //Constant5: 2*ncol | |||
#define C6 x15 //Constant6: 3*ncol | |||
.text | |||
.global sgemm_direct_sme1_preprocess | |||
//---------------------------------------------------------------------------- | |||
// sgemm_direct_sme1_preprocess(nrow, ncol, mat, mat_mod) | |||
// | |||
// For every block of SVLs rows of the row-major input matrix, each | |||
// SVLs x SVLs sub-block is loaded into ZA tile 0 as horizontal slices and | |||
// written back as vertical slices, so the block ends up transposed and | |||
// contiguous in mat_mod. Rows past nrow are loaded as zeros (zero padding); | |||
// columns past ncol are neither loaded nor stored. | |||
// | |||
// In:       x0 = nrow, x1 = ncol, x2 = mat (input), x3 = mat_mod (output) | |||
// Clobbers: x2-x11, x13-x15, w12, p0, p2-p5, p8, p9, ZA, condition flags | |||
// | |||
// The routine switches streaming mode and ZA on and off itself. Changing | |||
// PSTATE.SM zeroes the vector registers, so the callee-saved halves d8-d15 | |||
// are spilled before smstart and reloaded after smstop. The general-purpose | |||
// callee-saved registers x19-x28 are never written, so they need no spill. | |||
//---------------------------------------------------------------------------- | |||
sgemm_direct_sme1_preprocess: | |||
    // With no rows there is nothing to re-arrange. Without this check the | |||
    // loop body would still run once and store a zero block into mat_mod. | |||
    cbz nrow, .Lpreprocess_return | |||
    // Spill d8-d15 (callee-saved per AAPCS64) in a 64-byte frame. | |||
    stp d8, d9, [sp, #-64]! | |||
    stp d10, d11, [sp, #16] | |||
    stp d12, d13, [sp, #32] | |||
    stp d14, d15, [sp, #48] | |||
    smstart | |||
    cntw C1 //SVLs | |||
    mul C3, C1, ncol //SVLs*ncol | |||
    lsl C5, ncol, #1 //2*ncol | |||
    add C6, C5, ncol //3*ncol | |||
    cnth C4 //2*SVLs | |||
    add C2, C1, C1, lsl #1 //3*SVLs | |||
    mov outer_loop_cntr, #0 | |||
    //Tile predicate (M dimension): which of the next SVLs rows exist | |||
    whilelt p0.s, outer_loop_cntr, nrow | |||
    //Predicate for stores (all elements active) | |||
    ptrue p9.s | |||
.M_Loop: | |||
    mov mat_ptr0, mat //Load base address of mat | |||
    mov mat_mod_ptr, mat_mod //a_mod store base address | |||
    add inner_loop_exit, mat, ncol, lsl #2 //Exit condition for inner loop | |||
    //Tile predicate (K dimension): which of the next SVLs columns exist | |||
    whilelt p8.b, mat_ptr0, inner_loop_exit | |||
.Loop_process: | |||
    mov mat_ptr1, mat_ptr0 | |||
    //Load_to_tile loop counter (ZA slice index) | |||
    mov w12, #0 | |||
.Load_to_tile: | |||
    //Row predicate = K predicate if the row exists, otherwise all-false, | |||
    //so missing rows are written into the tile as zeros | |||
    psel p2, p8, p0.s[w12, 0] | |||
    psel p3, p8, p0.s[w12, 1] | |||
    psel p4, p8, p0.s[w12, 2] | |||
    psel p5, p8, p0.s[w12, 3] | |||
    //Load 1st row from mat_ptr1 | |||
    ld1w {za0h.s[w12, #0]}, p2/z, [mat_ptr1] | |||
    //Load 2nd row from mat_ptr1 + ncol | |||
    ld1w {za0h.s[w12, #1]}, p3/z, [mat_ptr1, ncol, lsl #2] | |||
    //Load 3rd row from mat_ptr1 + 2*ncol | |||
    ld1w {za0h.s[w12, #2]}, p4/z, [mat_ptr1, C5, lsl #2] | |||
    //Load 4th row from mat_ptr1 + 3*ncol | |||
    ld1w {za0h.s[w12, #3]}, p5/z, [mat_ptr1, C6, lsl #2] | |||
    //mat_ptr1+=4*ncol FP32 elements | |||
    add mat_ptr1, mat_ptr1, ncol, lsl #4 | |||
    //Increment counter | |||
    add w12, w12, #4 | |||
    //w9 is the 32-bit view of C1 (SVLs) | |||
    cmp w12, w9 | |||
    b.mi .Load_to_tile | |||
    // Store_from_tile loop counter (ZA slice index) | |||
    mov w12, #0 | |||
.Store_from_tile: | |||
    //Column predicate = all-true if the column exists, otherwise all-false, | |||
    //so columns past ncol are not written | |||
    psel p2, p9, p8.s[w12, 0] | |||
    psel p3, p9, p8.s[w12, 1] | |||
    psel p4, p9, p8.s[w12, 2] | |||
    psel p5, p9, p8.s[w12, 3] | |||
    //Store 1st col to mat_mod | |||
    st1w {za0v.s[w12, #0]}, p2, [mat_mod_ptr] | |||
    //Store 2nd col to mat_mod + SVLs | |||
    st1w {za0v.s[w12, #1]}, p3, [mat_mod_ptr, C1, lsl #2] | |||
    //Store 3rd col to mat_mod + 2*SVLs | |||
    st1w {za0v.s[w12, #2]}, p4, [mat_mod_ptr, C4, lsl #2] | |||
    //Store 4th col to mat_mod + 3*SVLs | |||
    st1w {za0v.s[w12, #3]}, p5, [mat_mod_ptr, C2, lsl #2] | |||
    addvl mat_mod_ptr, mat_mod_ptr, #4 //mat_mod_ptr += 4*SVLb | |||
    add w12, w12, #4 //Increment counter | |||
    cmp w12, w9 | |||
    b.mi .Store_from_tile | |||
    addvl mat_ptr0, mat_ptr0, #1 //mat_ptr0 += SVLb | |||
    whilelt p8.b, mat_ptr0, inner_loop_exit | |||
    b.first .Loop_process | |||
    add mat_mod, mat_mod, C3, lsl #2 //mat_mod+=SVLs*ncol FP32 elements | |||
    add mat, mat, C3, lsl #2 //mat+=SVLs*ncol FP32 elements | |||
    incw outer_loop_cntr | |||
    whilelt p0.s, outer_loop_cntr, nrow | |||
    b.first .M_Loop | |||
    smstop | |||
    // Reload d8-d15 only after smstop, because leaving streaming mode | |||
    // zeroes the vector registers again. | |||
    ldp d14, d15, [sp, #48] | |||
    ldp d12, d13, [sp, #32] | |||
    ldp d10, d11, [sp, #16] | |||
    ldp d8, d9, [sp], #64 | |||
.Lpreprocess_return: | |||
    ret | |||
@@ -178,6 +178,11 @@ gotoblas_t TABLE_NAME = { | |||
#ifdef ARCH_X86_64 | |||
sgemm_directTS, | |||
sgemm_direct_performantTS, | |||
#endif | |||
#ifdef ARCH_ARM64 | |||
#ifdef HAVE_SME | |||
sgemm_directTS, | |||
#endif | |||
#endif | |||
sgemm_kernelTS, sgemm_betaTS, | |||
@@ -3303,6 +3303,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||
#define GEMM_DEFAULT_OFFSET_A 0 | |||
#define GEMM_DEFAULT_OFFSET_B 0 | |||
#ifdef _WIN64 | |||
/* Use explicit casting for win64 as LLP64 datamodel is used */ | |||
#define GEMM_DEFAULT_ALIGN (BLASULONG)0x03fffUL | |||
@@ -3667,7 +3669,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout | |||
#define CGEMM_DEFAULT_R 4096 | |||
#define ZGEMM_DEFAULT_R 4096 | |||
#elif defined(ARMV8SVE) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE | |||
#elif defined(ARMV8SVE) || defined(ARMV9SME) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE | |||
#if defined(XDOUBLE) || defined(DOUBLE) | |||
#define SWITCH_RATIO 8 | |||
@@ -3738,6 +3740,10 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout | |||
#endif /* ARMv8 */ | |||
/* Builds for the ARMV9SME target define USE_SGEMM_KERNEL_DIRECT. | |||
   NOTE(review): presumably this is what lets SGEMM calls be routed to the | |||
   SME-based sgemm_direct kernel; the consumer is not in this file - confirm. */ | |||
#if defined(ARMV9SME) /* ARMv9 SME */ | |||
#define USE_SGEMM_KERNEL_DIRECT 1 | |||
#endif /* ARMv9 SME */ | |||
#if defined(ARMV5) | |||
#define SNUMOPT 2 | |||
#define DNUMOPT 2 | |||