
sgemm_direct_alpha_beta_arm64_sme1.c

/*
   Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
   SPDX-License-Identifier: BSD-3-Clause-Clear
*/
#include "common.h"
#include <stdlib.h>
#include <inttypes.h>
#include <math.h>
#include "sme_abi.h"

#if defined(HAVE_SME)

#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
#include <arm_sme.h>
#endif

/* Function prototypes */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,
                                         const float * restrict a,
                                         float *a_mod) __asm__("sgemm_direct_sme1_preprocess");

/* Function definitions */

/* Number of 32-bit elements per streaming vector: RDSVL reads the
 * streaming vector length in bytes, and the shift by two converts
 * bytes to FP32 lanes. */
static uint64_t sve_cntw() {
  uint64_t cnt;
  asm volatile(
    "rdsvl %[res], #1\n"
    "lsr %[res], %[res], #2\n"
    : [res] "=r" (cnt) ::
  );
  return cnt;
}
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16

// Outer product kernel.
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
__attribute__((always_inline)) inline void
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
           size_t ldc, size_t block_rows, size_t block_cols, float alpha,
           float beta) __arm_out("za") __arm_streaming {
  const uint64_t svl = svcntw();
  size_t ldb = ldc;

  // Predicate set-up
  svbool_t pg = svptrue_b32();
  svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
  svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
  svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
  svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
#define pg_c_0 pg_b_0
#define pg_c_1 pg_b_1

  svzero_za();
  svfloat32_t beta_vec = svdup_f32(beta);

  // Load C to ZA
  for (size_t i = 0; i < MIN(svl, block_rows); i++) {
    svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
    row_c_0 = svmul_x(pg, beta_vec, row_c_0);
    svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);
    svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
    row_c_1 = svmul_x(pg, beta_vec, row_c_1);
    svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
  }
  for (size_t i = svl; i < block_rows; i++) {
    svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
    row_c_0 = svmul_x(pg, beta_vec, row_c_0);
    svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0);
    svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
    row_c_1 = svmul_x(pg, beta_vec, row_c_1);
    svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1);
  }

  svfloat32_t alpha_vec = svdup_f32(alpha);
  // Iterate through shared dimension (K)
  for (size_t k = 0; k < shared_dim; k++) {
    // Load column of A
    svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
    col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
    svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
    col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
    // Load row of B
    svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
    svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
    // Perform outer product
    svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
    svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
    svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
    svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
  }

  // Store to C from ZA
  for (size_t i = 0; i < MIN(svl, block_rows); i++) {
    svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
    svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
  }
  for (size_t i = svl; i < block_rows; i++) {
    svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
    svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
  }
}
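
/*
 * For reference only (not part of the original kernel; kept out of the
 * build with #if 0): a scalar model of the block update above, assuming
 * A is in the packed layout that the loads imply -- column k of the
 * upper SVL-row panel at &A[k * svl], column k of the lower panel at
 * &A[(k + shared_dim) * svl]. Tiles 0/1 cover the upper half of the
 * block, tiles 2/3 the lower half.
 */
#if 0
static void kernel_2x2_ref(const float *A, const float *B, float *C,
                           size_t shared_dim, size_t ldc, size_t svl,
                           size_t block_rows, size_t block_cols,
                           float alpha, float beta) {
  for (size_t i = 0; i < block_rows; i++) {
    for (size_t j = 0; j < block_cols; j++) {
      float acc = 0.0f;
      for (size_t k = 0; k < shared_dim; k++) {
        /* Row i comes from the upper (i < svl) or lower panel of A. */
        float a = (i < svl) ? A[k * svl + i]
                            : A[(k + shared_dim) * svl + (i - svl)];
        acc += a * B[k * ldc + j];
      }
      C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
    }
  }
}
#endif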
__arm_new("za") __arm_locally_streaming
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,
                                          const float *alpha, const float *ba,
                                          const float *restrict bb,
                                          const float *beta, float *restrict C) {
  const uint64_t num_rows = m;
  const uint64_t num_cols = n;
  const float *restrict a_ptr = ba;
  const float *restrict b_ptr = bb;
  float *restrict c_ptr = C;
  const uint64_t svl = svcntw();
  const uint64_t ldc = n;

  // Block over the row dimension of C (panels of A), 2SVL rows at a time
  uint64_t row_idx = 0;
  uint64_t row_batch = 2 * svl;
  for (; row_idx < num_rows; row_idx += row_batch) {
    row_batch = MIN(row_batch, num_rows - row_idx);
    // Block over the column dimension of C, 2SVL columns at a time
    uint64_t col_idx = 0;
    uint64_t col_batch = 2 * svl;
    for (; col_idx < num_cols; col_idx += col_batch) {
      col_batch = MIN(col_batch, num_cols - col_idx);
      kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx],
                 &c_ptr[row_idx * ldc + col_idx], k,
                 ldc, row_batch, col_batch, *alpha, *beta);
    }
  }
}
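
/* Blocking note: C is tiled into (2*SVL) x (2*SVL) blocks; ragged edge
 * blocks need no scalar cleanup loop because kernel_2x2 masks them with
 * its whilelt predicates. A must already be in the packed panel layout
 * produced by sgemm_direct_sme1_preprocess (see the reference model
 * above). */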
#else
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,
                                          const float *alpha, const float *ba,
                                          const float *restrict bb,
                                          const float *beta, float *restrict C) {}
#endif

/* void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
       float * __restrict A, BLASLONG strideA, float * __restrict B,
       BLASLONG strideB, float * __restrict R, BLASLONG strideR)
*/
void CNAME(BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,
           BLASLONG strideA, float * __restrict B, BLASLONG strideB,
           float beta, float * __restrict R, BLASLONG strideR) {
  uint64_t m_mod, vl_elms;

  vl_elms = sve_cntw();
  /* Round M up to a whole multiple of the vector length for the packed
   * copy of A. */
  m_mod = ceil((double)M / (double)vl_elms) * vl_elms;

  float *A_mod = (float *)malloc(m_mod * K * sizeof(float));

  /* Clobber every SVE vector (z) and predicate (p) register so the
   * compiler keeps no live values in them across the hand-written
   * assembly routine, forcing it to read from memory instead. */
  asm volatile("" : : : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

  /* Pre-process the left matrix into the panel layout expected by the
   * matrix sum of outer-product calculation. */
  sgemm_direct_sme1_preprocess(M, K, A, A_mod);

  asm volatile("" : : : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

  /* Calculate C = alpha*A*B + beta*C */
  sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R);

  free(A_mod);
}
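
/*
 * Minimal usage sketch (illustrative only; this example and its guard
 * macro SGEMM_DIRECT_SME1_EXAMPLE are hypothetical and not part of the
 * library build). CNAME is bound to a concrete symbol by the OpenBLAS
 * build system. The stride arguments are ignored by this kernel: the
 * matrices are treated as dense row-major with leading dimensions K, N
 * and N respectively.
 */
#ifdef SGEMM_DIRECT_SME1_EXAMPLE
#include <stdio.h>
int main(void) {
  enum { M = 8, N = 8, K = 8 };
  static float A[M * K], B[K * N], R[M * N];
  for (int i = 0; i < M * K; i++) A[i] = 1.0f;
  for (int i = 0; i < K * N; i++) B[i] = 1.0f;
  for (int i = 0; i < M * N; i++) R[i] = 1.0f;
  /* R = 2*A*B + 3*R; with all-ones inputs every entry becomes 2*K + 3. */
  CNAME(M, N, K, 2.0f, A, 0, B, 0, 3.0f, R, 0);
  printf("R[0][0] = %.1f (expected %.1f)\n", R[0], 2.0f * K + 3.0f);
  return 0;
}
#endif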
#else
void CNAME(BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,
           BLASLONG strideA, float * __restrict B, BLASLONG strideB,
           float beta, float * __restrict R, BLASLONG strideR) {}
#endif