You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

sgemm_direct_sme1_preprocess.S 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /*
  2. Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
  3. SPDX-License-Identifier: BSD-3-Clause-Clear
  4. */
  5. /*----------------------------------------------------------------------------
  6. * This function is used to re-arrange the elements of input matrix to
  7. * make it suitable for matrix outer product computation using SME for matrix
  8. * multiplication. It should be used to pre-process the left matrix (A) in the
  9. * matrix multiplication (C = A*B) using sgemm_direct_sme1_2VLx2VL()
  10. *
  11. * The pre-processing transposes a block of SVLs rows of the input matrix and
  12. * stores it contiguously. The same is applied to remaining blocks of SVLs
  13. * rows. The last block of SVLs rows is zero-padded to SVLs rows if needed.
  14. *
  15. * Usage of function:
  16. * sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol, \
  17. * const float * restrict mat, float * mat_mod);
  18. *
  19. ----------------------------------------------------------------------------*/
  /*----------------------------------------------------------------------------
   * Target: AArch64 with SME, GNU assembler syntax run through the C
   * preprocessor (.S), AAPCS64 calling convention.
   *
   * Register aliases. Only caller-saved general-purpose registers are used,
   * so no x19-x28 spill is needed. w12 is used directly as the ZA tile slice
   * index, which is why x12 has no alias below.
   *--------------------------------------------------------------------------*/
#define nrow            x0      /* Arg 0: number of rows of input matrix */
#define ncol            x1      /* Arg 1: number of columns of input matrix */
#define mat             x2      /* Arg 2: input matrix base address (advanced per row block) */
#define mat_mod         x3      /* Arg 3: output (re-arranged) matrix base address (advanced per row block) */
#define mat_mod_ptr     x4      /* Running store pointer into the output matrix */
#define mat_ptr0        x5      /* Start of the current column tile in the first row of the block */
#define mat_ptr1        x6      /* Row-walking load pointer inside the current tile */
#define outer_loop_cntr x7      /* Index of the first row of the current row block */
#define inner_loop_exit x8      /* One past the last element of the first row of the block */
#define C1              x9      /* Constant1: SVLs - number of 32-bit elements per vector */
#define C2              x10     /* Constant2: 3*SVLs */
#define C3              x11     /* Constant3: ncol*SVLs */
#define C4              x13     /* Constant4: 2*SVLs */
#define C5              x14     /* Constant5: 2*ncol */
#define C6              x15     /* Constant6: 3*ncol */

        .text
        .global sgemm_direct_sme1_preprocess

/*----------------------------------------------------------------------------
 * sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol,
 *                              const float * restrict mat, float * mat_mod)
 *
 * In:     x0 = nrow, x1 = ncol, x2 = mat (row-major, ncol floats per row),
 *         x3 = mat_mod
 * Out:    mat_mod holds, for every block of SVLs input rows, the transposed
 *         block: ncol groups of SVLs consecutive floats, one group per input
 *         column. Rows past nrow in the last block are written as zeros.
 *         Nothing is written when nrow is 0. x0 and x1 are left unchanged.
 * Clobb:  x2-x11, x12 (via w12), x13-x15, NZCV, ZA, and every Z/P register
 *         (streaming mode is entered and left). d8-d15 are preserved.
 * Stack:  64 bytes, sp stays 16-byte aligned.
 *
 * NOTE(review): ZA is enabled and disabled unconditionally; the AAPCS64
 * lazy-save protocol (TPIDR2_EL0) is not implemented, so this assumes the
 * caller has no live ZA state - TODO confirm against the callers.
 *--------------------------------------------------------------------------*/
sgemm_direct_sme1_preprocess:
        // An empty matrix has no row block to emit. Returning here keeps the
        // routine from writing a zero-padded block into an output buffer that
        // the caller sized for zero rows, and skips the mode switches.
        cbz     nrow, .Preprocess_exit

        // Changing PSTATE.SM (smstart/smstop) zeroes the vector registers,
        // but AAPCS64 requires the low 64 bits of v8-v15 to survive a call.
        stp     d8, d9, [sp, #-64]!
        stp     d10, d11, [sp, #16]
        stp     d12, d13, [sp, #32]
        stp     d14, d15, [sp, #48]

        smstart                                 // Streaming mode + ZA on

        // Loop-invariant constants
        cntw    C1                              // SVLs
        cnth    C4                              // 2*SVLs
        add     C2, C1, C1, lsl #1              // 3*SVLs
        mul     C3, C1, ncol                    // SVLs*ncol
        lsl     C5, ncol, #1                    // 2*ncol
        add     C6, C5, ncol                    // 3*ncol

        mov     outer_loop_cntr, #0
        // p0: which of the SVLs rows of this block exist (M dimension)
        whilelt p0.s, outer_loop_cntr, nrow
        // p9: all-true predicate, stores always write full SVLs-element columns
        ptrue   p9.s

        // ---- One iteration per block of SVLs input rows --------------------
.M_Loop:
        mov     mat_ptr0, mat                   // First column tile of this row block
        mov     mat_mod_ptr, mat_mod            // Output position of this row block
        add     inner_loop_exit, mat, ncol, lsl #2      // End of the first row of the block
        // p8: bytes of the current tile row that lie inside the row (K dimension)
        whilelt p8.b, mat_ptr0, inner_loop_exit

        // ---- One iteration per tile of SVLs columns ------------------------
.Loop_process:
        mov     mat_ptr1, mat_ptr0
        mov     w12, #0                         // Horizontal slice index

        // Fill ZA0 horizontally, four input rows per pass. A row beyond nrow
        // gets an all-false predicate, so its slice is zero-filled (padding).
.Load_to_tile:
        psel    p2, p8, p0.s[w12, 0]
        psel    p3, p8, p0.s[w12, 1]
        psel    p4, p8, p0.s[w12, 2]
        psel    p5, p8, p0.s[w12, 3]
        ld1w    {za0h.s[w12, #0]}, p2/z, [mat_ptr1]                 // Row r
        ld1w    {za0h.s[w12, #1]}, p3/z, [mat_ptr1, ncol, lsl #2]   // Row r+1
        ld1w    {za0h.s[w12, #2]}, p4/z, [mat_ptr1, C5, lsl #2]     // Row r+2
        ld1w    {za0h.s[w12, #3]}, p5/z, [mat_ptr1, C6, lsl #2]     // Row r+3
        add     mat_ptr1, mat_ptr1, ncol, lsl #4        // Advance by 4 rows of FP32
        add     w12, w12, #4
        cmp     w12, w9                         // w9 is the low half of C1 (SVLs)
        b.mi    .Load_to_tile

        mov     w12, #0                         // Vertical slice index

        // Drain ZA0 vertically, four tile columns per pass. A column beyond
        // ncol gets an all-false predicate and is not written.
.Store_from_tile:
        psel    p2, p9, p8.s[w12, 0]
        psel    p3, p9, p8.s[w12, 1]
        psel    p4, p9, p8.s[w12, 2]
        psel    p5, p9, p8.s[w12, 3]
        st1w    {za0v.s[w12, #0]}, p2, [mat_mod_ptr]                // Column c
        st1w    {za0v.s[w12, #1]}, p3, [mat_mod_ptr, C1, lsl #2]    // Column c+1, at +SVLs
        st1w    {za0v.s[w12, #2]}, p4, [mat_mod_ptr, C4, lsl #2]    // Column c+2, at +2*SVLs
        st1w    {za0v.s[w12, #3]}, p5, [mat_mod_ptr, C2, lsl #2]    // Column c+3, at +3*SVLs
        addvl   mat_mod_ptr, mat_mod_ptr, #4    // Advance by 4 vectors (4*SVLs floats)
        add     w12, w12, #4
        cmp     w12, w9
        b.mi    .Store_from_tile

        addvl   mat_ptr0, mat_ptr0, #1          // Next tile: SVLs columns to the right
        whilelt p8.b, mat_ptr0, inner_loop_exit
        b.first .Loop_process                   // Loop while any column remains

        add     mat_mod, mat_mod, C3, lsl #2    // Next output block: +SVLs*ncol floats
        add     mat, mat, C3, lsl #2            // Next input block:  +SVLs*ncol floats
        incw    outer_loop_cntr                 // Row index += SVLs
        whilelt p0.s, outer_loop_cntr, nrow
        b.first .M_Loop                         // Loop while any row remains

        smstop                                  // Streaming mode + ZA off

        ldp     d14, d15, [sp, #48]
        ldp     d12, d13, [sp, #32]
        ldp     d10, d11, [sp, #16]
        ldp     d8, d9, [sp], #64

.Preprocess_exit:
        ret