You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sbgemm_kernel_8x4_neoversen2_impl.c 12 kB


  1. /***************************************************************************
  2. * Copyright (c) 2022, The OpenBLAS Project
  3. * All rights reserved.
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions are
  6. * met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in
  11. * the documentation and/or other materials provided with the
  12. * distribution.
  13. * 3. Neither the name of the OpenBLAS project nor the names of
  14. * its contributors may be used to endorse or promote products
  15. * derived from this software without specific prior written permission.
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. * POSSIBILITY OF SUCH DAMAGE.
  27. * *****************************************************************************/
  28. #include <arm_sve.h>
  29. #include "common.h"
  30. #define INIT_C(M, N) mc##M##N = svdup_f32(0);
  31. #define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N);
  32. #define INIT_C_8x4 \
  33. do { \
  34. INIT_C(0, 0); \
  35. INIT_C(0, 1); \
  36. INIT_C(1, 0); \
  37. INIT_C(1, 1); \
  38. INIT_C(2, 0); \
  39. INIT_C(2, 1); \
  40. INIT_C(3, 0); \
  41. INIT_C(3, 1); \
  42. } while (0);
  43. #ifdef ALPHA_ONE
  44. #define UPDATE_C(PG, PTR, DST, SRC) \
  45. do { \
  46. DST = svld1_f32((PG), (PTR)); \
  47. DST = svadd_z((PG), SRC, DST); \
  48. svst1_f32((PG), (PTR), DST); \
  49. } while (0);
  50. #else
  51. #define UPDATE_C(PG, PTR, DST, SRC) \
  52. do { \
  53. DST = svld1_f32((PG), (PTR)); \
  54. DST = svmad_z((PG), svalpha, SRC, DST); \
  55. svst1_f32((PG), (PTR), DST); \
  56. } while (0);
  57. #endif
  58. #ifdef ALPHA_ONE
  59. int sbgemm_kernel_neoversen2_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
  60. #else
  61. int sbgemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc)
  62. #endif
  63. {
  64. BLASLONG pad_k = (k + 3) & ~3;
  65. svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1;
  66. svfloat32_t mc00, mc01, mc10, mc11, mc20, mc21, mc30, mc31,
  67. vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7,
  68. oc0, oc1, oc2, oc3, oc4, oc5, oc6, oc7;
  69. svfloat32_t svalpha = svdup_f32(alpha);
  70. svbool_t pg16 = svptrue_b16();
  71. svbool_t pg16_low = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0);
  72. svbool_t pg32 = svptrue_b32();
  73. svbool_t pg32_low = svdupq_b32(1, 1, 0, 0);
  74. svbool_t pg32_first = svdupq_b32(1, 0, 0, 0);
  75. bfloat16_t *ptr_a = (bfloat16_t *)A;
  76. bfloat16_t *ptr_b = (bfloat16_t *)B;
  77. FLOAT *ptr_c = C;
  78. bfloat16_t *ptr_a0, *ptr_a1, *ptr_a2, *ptr_a3;
  79. bfloat16_t *ptr_b0, *ptr_b1;
  80. FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
  81. for (BLASLONG j = 0; j < n / 4; j++) {
  82. ptr_c0 = ptr_c;
  83. ptr_c1 = ptr_c0 + ldc;
  84. ptr_c2 = ptr_c1 + ldc;
  85. ptr_c3 = ptr_c2 + ldc;
  86. ptr_c += 4 * ldc;
  87. ptr_a = (bfloat16_t *)A;
  88. for (BLASLONG i = 0; i < m / 8; i++) {
  89. ptr_a0 = ptr_a;
  90. ptr_a += 8 * pad_k;
  91. ptr_b0 = ptr_b;
  92. INIT_C_8x4;
  93. for (BLASLONG p = 0; p < pad_k; p += 4) {
  94. ma0 = svld1_bf16(pg16, ptr_a0);
  95. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  96. ma2 = svld1_bf16(pg16, ptr_a0 + 16);
  97. ma3 = svld1_bf16(pg16, ptr_a0 + 24);
  98. mb0 = svld1_bf16(pg16, ptr_b0);
  99. mb1 = svld1_bf16(pg16, ptr_b0 + 8);
  100. MATMUL(0, 0); MATMUL(0, 1);
  101. MATMUL(1, 0); MATMUL(1, 1);
  102. MATMUL(2, 0); MATMUL(2, 1);
  103. MATMUL(3, 0); MATMUL(3, 1);
  104. ptr_a0 += 32;
  105. ptr_b0 += 16;
  106. }
  107. vc0 = svuzp1(mc00, mc10);
  108. vc1 = svuzp1(mc20, mc30);
  109. vc2 = svuzp2(mc00, mc10);
  110. vc3 = svuzp2(mc20, mc30);
  111. vc4 = svuzp1(mc01, mc11);
  112. vc5 = svuzp1(mc21, mc31);
  113. vc6 = svuzp2(mc01, mc11);
  114. vc7 = svuzp2(mc21, mc31);
  115. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  116. UPDATE_C(pg32, ptr_c0+4, oc1, vc1);
  117. UPDATE_C(pg32, ptr_c1, oc2, vc2);
  118. UPDATE_C(pg32, ptr_c1+4, oc3, vc3);
  119. UPDATE_C(pg32, ptr_c2, oc4, vc4)
  120. UPDATE_C(pg32, ptr_c2+4, oc5, vc5);
  121. UPDATE_C(pg32, ptr_c3, oc6, vc6)
  122. UPDATE_C(pg32, ptr_c3+4, oc7, vc7);
  123. ptr_c0 += 8;
  124. ptr_c1 += 8;
  125. ptr_c2 += 8;
  126. ptr_c3 += 8;
  127. }
  128. if (m & 4) {
  129. ptr_a0 = ptr_a;
  130. ptr_a += 4 * pad_k;
  131. ptr_b0 = ptr_b;
  132. INIT_C(0, 0); INIT_C(0, 1);
  133. INIT_C(1, 0); INIT_C(1, 1);
  134. for (BLASLONG p = 0; p < pad_k; p += 4) {
  135. ma0 = svld1_bf16(pg16, ptr_a0);
  136. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  137. mb0 = svld1_bf16(pg16, ptr_b0);
  138. mb1 = svld1_bf16(pg16, ptr_b0 + 8);
  139. MATMUL(0, 0); MATMUL(0, 1);
  140. MATMUL(1, 0); MATMUL(1, 1);
  141. ptr_a0 += 16;
  142. ptr_b0 += 16;
  143. }
  144. vc0 = svuzp1(mc00, mc10);
  145. vc1 = svuzp2(mc00, mc10);
  146. vc2 = svuzp1(mc01, mc11);
  147. vc3 = svuzp2(mc01, mc11);
  148. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  149. UPDATE_C(pg32, ptr_c1, oc1, vc1);
  150. UPDATE_C(pg32, ptr_c2, oc2, vc2);
  151. UPDATE_C(pg32, ptr_c3, oc3, vc3);
  152. ptr_c0 += 4;
  153. ptr_c1 += 4;
  154. ptr_c2 += 4;
  155. ptr_c3 += 4;
  156. }
  157. if (m & 2) {
  158. ptr_a0 = ptr_a;
  159. ptr_a += 2 * pad_k;
  160. ptr_b0 = ptr_b;
  161. INIT_C(0, 0); INIT_C(0, 1);
  162. for (BLASLONG p = 0; p < pad_k; p += 4) {
  163. ma0 = svld1_bf16(pg16, ptr_a0);
  164. mb0 = svld1_bf16(pg16, ptr_b0);
  165. mb1 = svld1_bf16(pg16, ptr_b0 + 8);
  166. MATMUL(0, 0); MATMUL(0, 1);
  167. ptr_a0 += 8;
  168. ptr_b0 += 16;
  169. }
  170. vc0 = svuzp1(mc00, mc00);
  171. vc1 = svuzp2(mc00, mc00);
  172. vc2 = svuzp1(mc01, mc01);
  173. vc3 = svuzp2(mc01, mc01);
  174. UPDATE_C(pg32_low, ptr_c0, oc0, vc0);
  175. UPDATE_C(pg32_low, ptr_c1, oc1, vc1);
  176. UPDATE_C(pg32_low, ptr_c2, oc2, vc2);
  177. UPDATE_C(pg32_low, ptr_c3, oc3, vc3);
  178. ptr_c0 += 2;
  179. ptr_c1 += 2;
  180. ptr_c2 += 2;
  181. ptr_c3 += 2;
  182. }
  183. if (m & 1) {
  184. ptr_a0 = ptr_a;
  185. ptr_b0 = ptr_b;
  186. INIT_C(0, 0); INIT_C(0, 1);
  187. for (BLASLONG p = 0; p < pad_k; p += 4) {
  188. ma0 = svld1_bf16(pg16_low, ptr_a0);
  189. mb0 = svld1_bf16(pg16, ptr_b0);
  190. mb1 = svld1_bf16(pg16, ptr_b0 + 8);
  191. MATMUL(0, 0); MATMUL(0, 1);
  192. ptr_a0 += 4;
  193. ptr_b0 += 16;
  194. }
  195. vc1 = svuzp2(mc00, mc00);
  196. vc3 = svuzp2(mc01, mc01);
  197. UPDATE_C(pg32_first, ptr_c0, oc0, mc00);
  198. UPDATE_C(pg32_first, ptr_c1, oc1, vc1);
  199. UPDATE_C(pg32_first, ptr_c2, oc2, mc01);
  200. UPDATE_C(pg32_first, ptr_c3, oc3, vc3);
  201. }
  202. ptr_b += 4 * pad_k;
  203. }
  204. if (n & 2) {
  205. ptr_c0 = ptr_c;
  206. ptr_c1 = ptr_c0 + ldc;
  207. ptr_c += 2 * ldc;
  208. ptr_a = (bfloat16_t *)A;
  209. for (BLASLONG i = 0; i < m / 8; i++) {
  210. ptr_a0 = ptr_a;
  211. ptr_a += 8 * pad_k;
  212. ptr_b0 = ptr_b;
  213. INIT_C(0, 0);
  214. INIT_C(1, 0);
  215. INIT_C(2, 0);
  216. INIT_C(3, 0);
  217. for (BLASLONG p = 0; p < pad_k; p += 4) {
  218. ma0 = svld1_bf16(pg16, ptr_a0);
  219. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  220. ma2 = svld1_bf16(pg16, ptr_a0 + 16);
  221. ma3 = svld1_bf16(pg16, ptr_a0 + 24);
  222. mb0 = svld1_bf16(pg16, ptr_b0);
  223. MATMUL(0, 0);
  224. MATMUL(1, 0);
  225. MATMUL(2, 0);
  226. MATMUL(3, 0);
  227. ptr_a0 += 32;
  228. ptr_b0 += 8;
  229. }
  230. vc0 = svuzp1(mc00, mc10);
  231. vc1 = svuzp1(mc20, mc30);
  232. vc2 = svuzp2(mc00, mc10);
  233. vc3 = svuzp2(mc20, mc30);
  234. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  235. UPDATE_C(pg32, ptr_c0 + 4, oc1, vc1);
  236. UPDATE_C(pg32, ptr_c1, oc2, vc2);
  237. UPDATE_C(pg32, ptr_c1 + 4, oc3, vc3);
  238. ptr_c0 += 8;
  239. ptr_c1 += 8;
  240. }
  241. if (m & 4) {
  242. ptr_a0 = ptr_a;
  243. ptr_a += 4 * pad_k;
  244. ptr_b0 = ptr_b;
  245. INIT_C(0, 0);
  246. INIT_C(1, 0);
  247. for (BLASLONG p = 0; p < pad_k; p += 4) {
  248. ma0 = svld1_bf16(pg16, ptr_a0);
  249. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  250. mb0 = svld1_bf16(pg16, ptr_b0);
  251. MATMUL(0, 0);
  252. MATMUL(1, 0);
  253. ptr_a0 += 16;
  254. ptr_b0 += 8;
  255. }
  256. vc0 = svuzp1(mc00, mc10);
  257. vc1 = svuzp2(mc00, mc10);
  258. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  259. UPDATE_C(pg32, ptr_c1, oc1, vc1);
  260. ptr_c0 += 4;
  261. ptr_c1 += 4;
  262. }
  263. if (m & 2) {
  264. ptr_a0 = ptr_a;
  265. ptr_a += 2 * pad_k;
  266. ptr_b0 = ptr_b;
  267. INIT_C(0, 0);
  268. for (BLASLONG p = 0; p < pad_k; p += 4) {
  269. ma0 = svld1_bf16(pg16, ptr_a0);
  270. mb0 = svld1_bf16(pg16, ptr_b0);
  271. MATMUL(0, 0);
  272. ptr_a0 += 8;
  273. ptr_b0 += 8;
  274. }
  275. vc0 = svuzp1(mc00, mc00);
  276. vc1 = svuzp2(mc00, mc00);
  277. UPDATE_C(pg32_low, ptr_c0, oc0, vc0);
  278. UPDATE_C(pg32_low, ptr_c1, oc1, vc1);
  279. ptr_c0 += 2;
  280. ptr_c1 += 2;
  281. }
  282. if (m & 1) {
  283. ptr_a0 = ptr_a;
  284. ptr_b0 = ptr_b;
  285. INIT_C(0, 0);
  286. for (BLASLONG p = 0; p < pad_k; p += 4) {
  287. ma0 = svld1_bf16(pg16_low, ptr_a0);
  288. mb0 = svld1_bf16(pg16, ptr_b0);
  289. MATMUL(0, 0);
  290. ptr_a0 += 4;
  291. ptr_b0 += 8;
  292. }
  293. vc1 = svuzp2(mc00, mc00);
  294. UPDATE_C(pg32_first, ptr_c0, oc0, mc00);
  295. UPDATE_C(pg32_first, ptr_c1, oc1, vc1);
  296. }
  297. ptr_b += 2 * pad_k;
  298. }
  299. if (n & 1) {
  300. ptr_c0 = ptr_c;
  301. ptr_a = (bfloat16_t *)A;
  302. for (BLASLONG i = 0; i < m / 8; i++) {
  303. ptr_a0 = ptr_a;
  304. ptr_a += 8 * pad_k;
  305. ptr_b0 = ptr_b;
  306. INIT_C(0, 0);
  307. INIT_C(1, 0);
  308. INIT_C(2, 0);
  309. INIT_C(3, 0);
  310. for (BLASLONG p = 0; p < pad_k; p += 4) {
  311. ma0 = svld1_bf16(pg16, ptr_a0);
  312. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  313. ma2 = svld1_bf16(pg16, ptr_a0 + 16);
  314. ma3 = svld1_bf16(pg16, ptr_a0 + 24);
  315. mb0 = svld1_bf16(pg16_low, ptr_b0);
  316. MATMUL(0, 0);
  317. MATMUL(1, 0);
  318. MATMUL(2, 0);
  319. MATMUL(3, 0);
  320. ptr_a0 += 32;
  321. ptr_b0 += 4;
  322. }
  323. vc0 = svuzp1(mc00, mc10);
  324. vc1 = svuzp1(mc20, mc30);
  325. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  326. UPDATE_C(pg32, ptr_c0 + 4, oc1, vc1);
  327. ptr_c0 += 8;
  328. }
  329. if (m & 4) {
  330. ptr_a0 = ptr_a;
  331. ptr_a += 4 * pad_k;
  332. ptr_b0 = ptr_b;
  333. INIT_C(0, 0);
  334. INIT_C(1, 0);
  335. for (BLASLONG p = 0; p < pad_k; p += 4) {
  336. ma0 = svld1_bf16(pg16, ptr_a0);
  337. ma1 = svld1_bf16(pg16, ptr_a0 + 8);
  338. mb0 = svld1_bf16(pg16_low, ptr_b0);
  339. MATMUL(0, 0);
  340. MATMUL(1, 0);
  341. ptr_a0 += 16;
  342. ptr_b0 += 4;
  343. }
  344. vc0 = svuzp1(mc00, mc10);
  345. UPDATE_C(pg32, ptr_c0, oc0, vc0);
  346. ptr_c0 += 4;
  347. }
  348. if (m & 2) {
  349. ptr_a0 = ptr_a;
  350. ptr_a += 2 * pad_k;
  351. ptr_b0 = ptr_b;
  352. INIT_C(0, 0);
  353. for (BLASLONG p = 0; p < pad_k; p += 4) {
  354. ma0 = svld1_bf16(pg16, ptr_a0);
  355. mb0 = svld1_bf16(pg16_low, ptr_b0);
  356. MATMUL(0, 0);
  357. ptr_a0 += 8;
  358. ptr_b0 += 4;
  359. }
  360. vc0 = svuzp1(mc00, mc00);
  361. UPDATE_C(pg32_low, ptr_c0, oc0, vc0);
  362. ptr_c0 += 2;
  363. }
  364. if (m & 1) {
  365. ptr_a0 = ptr_a;
  366. ptr_b0 = ptr_b;
  367. INIT_C(0, 0);
  368. for (BLASLONG p = 0; p < pad_k; p += 4) {
  369. ma0 = svld1_bf16(pg16_low, ptr_a0);
  370. mb0 = svld1_bf16(pg16_low, ptr_b0);
  371. MATMUL(0, 0);
  372. ptr_a0 += 4;
  373. ptr_b0 += 4;
  374. }
  375. UPDATE_C(pg32_first, ptr_c0, oc0, mc00);
  376. }
  377. }
  378. return 0;
  379. }