You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bgemm_tcopy_2vl_neoversev1.c 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. /***************************************************************************
  2. * Copyright (c) 2025, The OpenBLAS Project
  3. * All rights reserved.
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions are
  6. * met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in
  11. * the documentation and/or other materials provided with the
  12. * distribution.
  13. * 3. Neither the name of the OpenBLAS project nor the names of
  14. * its contributors may be used to endorse or promote products
  15. * derived from this software without specific prior written permission.
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. * POSSIBILITY OF SUCH DAMAGE.
  27. * *****************************************************************************/
  28. #include <arm_sve.h>
  29. #include <arm_neon.h>
  30. #include "common.h"
  31. int CNAME(BLASLONG m, BLASLONG n, IFLOAT *input, BLASLONG lda, IFLOAT *output) {
  32. const int sve_size_bf16 = svcnth();
  33. const int num_accumulators_sve = sve_size_bf16 >> 1;
  34. const int num_accumulators = num_accumulators_sve;
  35. const int incr_accumulators = 4;
  36. const int n_sve_accumulators = (n & -num_accumulators);
  37. const int n2 = n & -2;
  38. const int n_rest = n - n2;
  39. const int m4 = m & -4;
  40. const int m_rest = m - m4;
  41. size_t n_step = 0;
  42. for (; n_step < n_sve_accumulators; n_step += num_accumulators) {
  43. const uint16_t* inner_input = input;
  44. // Full 4x4 item transposes down the M dimension
  45. for (size_t m_step = 0; m_step < m4; m_step += 4) {
  46. const uint16_t* tile = inner_input;
  47. for (size_t line = 0; line < num_accumulators; line += incr_accumulators) {
  48. // Load 4x4 block
  49. uint16x4_t a_vec0 = vld1_u16(tile);
  50. uint16x4_t a_vec1 = vld1_u16(tile + lda);
  51. uint16x4_t a_vec2 = vld1_u16(tile + 2 * lda);
  52. uint16x4_t a_vec3 = vld1_u16(tile + 3 * lda);
  53. // Transpose 4x4 blocks
  54. uint16x4_t out_vec0 = vzip1_u16(a_vec0, a_vec1);
  55. uint16x4_t out_vec1 = vzip2_u16(a_vec0, a_vec1);
  56. uint16x4_t out_vec2 = vzip1_u16(a_vec2, a_vec3);
  57. uint16x4_t out_vec3 = vzip2_u16(a_vec2, a_vec3);
  58. // Transpose 8x4 blocks
  59. a_vec0 = vreinterpret_u16_u32(vzip1_u32(vreinterpret_u32_u16(out_vec0), vreinterpret_u32_u16(out_vec2)));
  60. a_vec1 = vreinterpret_u16_u32(vzip2_u32(vreinterpret_u32_u16(out_vec0), vreinterpret_u32_u16(out_vec2)));
  61. a_vec2 = vreinterpret_u16_u32(vzip1_u32(vreinterpret_u32_u16(out_vec1), vreinterpret_u32_u16(out_vec3)));
  62. a_vec3 = vreinterpret_u16_u32(vzip2_u32(vreinterpret_u32_u16(out_vec1), vreinterpret_u32_u16(out_vec3)));
  63. vst1_u16(output, a_vec0);
  64. vst1_u16(output + 4, a_vec1);
  65. vst1_u16(output + 8, a_vec2);
  66. vst1_u16(output + 12, a_vec3);
  67. tile += incr_accumulators;
  68. output += 16;
  69. }
  70. inner_input += incr_accumulators * lda;
  71. }
  72. if (m_rest) {
  73. for (BLASLONG line = 0; line < num_accumulators; line++) {
  74. output[0] = inner_input[0];
  75. output[1] = m_rest == 1 ? 0 : *(inner_input + lda);
  76. output[2] = m_rest <= 2 ? 0 : *(inner_input + 2 * lda);
  77. output[3] = m_rest <= 3 ? 0 : *(inner_input + 3 * lda);
  78. inner_input++;
  79. output += 4;
  80. }
  81. }
  82. input += num_accumulators;
  83. }
  84. for (; n_step < n2; n_step += 2) {
  85. const uint16_t* inner_input = input;
  86. for (size_t m_step = 0; m_step < m4; m_step += 4) {
  87. for (BLASLONG line = 0; line < 2; line++) {
  88. output[0] = *(inner_input + line);
  89. output[1] = *(inner_input + line + lda);
  90. output[2] = *(inner_input + line + 2 * lda);
  91. output[3] = *(inner_input + line + 3 * lda);
  92. output += 4;
  93. }
  94. inner_input += 4 * lda;
  95. }
  96. if (m_rest) {
  97. for (BLASLONG line = 0; line < 2; line++) {
  98. output[0] = *(inner_input + line);
  99. output[1] = m_rest == 1 ? 0 : *(inner_input + line + lda);
  100. output[2] = m_rest <= 2 ? 0 : *(inner_input + line + 2 * lda);
  101. output[3] = m_rest <= 3 ? 0 : *(inner_input + line + 3 * lda);
  102. output += 4;
  103. }
  104. }
  105. input += 2;
  106. }
  107. if (n_rest & 1) {
  108. const uint16_t* inner_input = input;
  109. for (size_t m_step = 0; m_step < m4; m_step += 4) {
  110. output[0] = *inner_input;
  111. output[1] = *(inner_input + lda);
  112. output[2] = *(inner_input + 2 * lda);
  113. output[3] = *(inner_input + 3 * lda);
  114. inner_input += 4 * lda;
  115. output += 4;
  116. }
  117. if (m_rest) {
  118. output[0] = inner_input[0];
  119. output[1] = m_rest == 1 ? 0 : *(inner_input + lda);
  120. output[2] = m_rest <= 2 ? 0 : *(inner_input + 2 * lda);
  121. output[3] = m_rest <= 3 ? 0 : *(inner_input + 3 * lda);
  122. output += 4;
  123. }
  124. }
  125. return 0;
  126. }