You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

shgemm_kernel_8x8_zvl128b.c 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. #include "common.h"
  2. #include <riscv_vector.h>
  3. int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
  4. {
  5. BLASLONG gvl = 0;
  6. BLASLONG m_top = 0;
  7. BLASLONG n_top = 0;
  8. // -- MAIN PASS
  9. for (BLASLONG j=0; j<N/8; j+=1) {
  10. m_top = 0;
  11. BLASLONG gvl = __riscv_vsetvl_e16m1(8);
  12. for (BLASLONG i=0; i<M/8; i+=1) {
  13. BLASLONG ai=m_top*K;
  14. BLASLONG bi=n_top*K;
  15. _Float16 B0 = B[bi+0];
  16. _Float16 B1 = B[bi+1];
  17. _Float16 B2 = B[bi+2];
  18. _Float16 B3 = B[bi+3];
  19. _Float16 B4 = B[bi+4];
  20. _Float16 B5 = B[bi+5];
  21. _Float16 B6 = B[bi+6];
  22. _Float16 B7 = B[bi+7];
  23. bi += 8;
  24. vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  25. ai += 8;
  26. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  27. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  28. vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
  29. vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
  30. vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
  31. vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
  32. vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
  33. vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
  34. for(BLASLONG k=1; k<K; k++) {
  35. B0 = B[bi+0];
  36. B1 = B[bi+1];
  37. B2 = B[bi+2];
  38. B3 = B[bi+3];
  39. B4 = B[bi+4];
  40. B5 = B[bi+5];
  41. B6 = B[bi+6];
  42. B7 = B[bi+7];
  43. bi += 8;
  44. A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  45. ai += 8;
  46. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  47. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  48. result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
  49. result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
  50. result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
  51. result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
  52. result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
  53. result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
  54. }
  55. BLASLONG ci=n_top*ldc+m_top;
  56. vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  57. vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  58. vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  59. vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  60. vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  61. vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  62. vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  63. vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
  64. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  65. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  66. c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
  67. c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
  68. c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
  69. c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
  70. c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
  71. c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);
  72. ci = n_top * ldc + m_top;
  73. __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
  74. __riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
  75. __riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
  76. __riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
  77. __riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
  78. __riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
  79. __riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
  80. __riscv_vse32_v_f32m2( &C[ci], c7, gvl); ci += ldc-gvl*0;
  81. m_top += 8;
  82. }
  83. // -- tails for main pass --
  84. if( M & 4 ) {
  85. gvl = __riscv_vsetvl_e16m1(4);
  86. BLASLONG ai=m_top*K;
  87. BLASLONG bi=n_top*K;
  88. _Float16 B0 = B[bi+0];
  89. _Float16 B1 = B[bi+1];
  90. _Float16 B2 = B[bi+2];
  91. _Float16 B3 = B[bi+3];
  92. _Float16 B4 = B[bi+4];
  93. _Float16 B5 = B[bi+5];
  94. _Float16 B6 = B[bi+6];
  95. _Float16 B7 = B[bi+7];
  96. bi += 8;
  97. vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  98. ai += 4;
  99. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  100. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  101. vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
  102. vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
  103. vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
  104. vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
  105. vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
  106. vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
  107. for(BLASLONG k=1; k < K; ++k) {
  108. B0 = B[bi+0];
  109. B1 = B[bi+1];
  110. B2 = B[bi+2];
  111. B3 = B[bi+3];
  112. B4 = B[bi+4];
  113. B5 = B[bi+5];
  114. B6 = B[bi+6];
  115. B7 = B[bi+7];
  116. bi += 8;
  117. A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  118. ai += 4;
  119. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  120. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  121. result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
  122. result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
  123. result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
  124. result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
  125. result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
  126. result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
  127. }
  128. BLASLONG ci = n_top * ldc + m_top;
  129. vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  130. ci += ldc - gvl * 0;
  131. vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  132. ci += ldc - gvl * 0;
  133. vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  134. ci += ldc - gvl * 0;
  135. vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  136. ci += ldc - gvl * 0;
  137. vfloat32m2_t c4 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  138. ci += ldc - gvl * 0;
  139. vfloat32m2_t c5 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  140. ci += ldc - gvl * 0;
  141. vfloat32m2_t c6 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  142. ci += ldc - gvl * 0;
  143. vfloat32m2_t c7 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  144. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  145. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  146. c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
  147. c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
  148. c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
  149. c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
  150. c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
  151. c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);
  152. ci= n_top * ldc + m_top;
  153. __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
  154. __riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
  155. __riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
  156. __riscv_vse32_v_f32m2(&C[ci], c3, gvl); ci += ldc - gvl * 0;
  157. __riscv_vse32_v_f32m2(&C[ci], c4, gvl); ci += ldc - gvl * 0;
  158. __riscv_vse32_v_f32m2(&C[ci], c5, gvl); ci += ldc - gvl * 0;
  159. __riscv_vse32_v_f32m2(&C[ci], c6, gvl); ci += ldc - gvl * 0;
  160. __riscv_vse32_v_f32m2(&C[ci], c7, gvl);
  161. m_top += 4;
  162. }
  163. if( M & 2 ) {
  164. BLASLONG ai = m_top * K;
  165. BLASLONG bi = n_top * K;
  166. float result0 = 0;
  167. float result1 = 0;
  168. float result2 = 0;
  169. float result3 = 0;
  170. float result4 = 0;
  171. float result5 = 0;
  172. float result6 = 0;
  173. float result7 = 0;
  174. float result8 = 0;
  175. float result9 = 0;
  176. float result10 = 0;
  177. float result11 = 0;
  178. float result12 = 0;
  179. float result13 = 0;
  180. float result14 = 0;
  181. float result15 = 0;
  182. for(BLASLONG k=0; k<K; k++) {
  183. result0+=(float)(A[ai+0]*B[bi+0]);
  184. result1+=(float)(A[ai+1]*B[bi+0]);
  185. result2+=(float)(A[ai+0]*B[bi+1]);
  186. result3+=(float)(A[ai+1]*B[bi+1]);
  187. result4+=(float)(A[ai+0]*B[bi+2]);
  188. result5+=(float)(A[ai+1]*B[bi+2]);
  189. result6+=(float)(A[ai+0]*B[bi+3]);
  190. result7+=(float)(A[ai+1]*B[bi+3]);
  191. result8+=(float)(A[ai+0]*B[bi+4]);
  192. result9+=(float)(A[ai+1]*B[bi+4]);
  193. result10+=(float)(A[ai+0]*B[bi+5]);
  194. result11+=(float)(A[ai+1]*B[bi+5]);
  195. result12+=(float)(A[ai+0]*B[bi+6]);
  196. result13+=(float)(A[ai+1]*B[bi+6]);
  197. result14+=(float)(A[ai+0]*B[bi+7]);
  198. result15+=(float)(A[ai+1]*B[bi+7]);
  199. ai+=2;
  200. bi+=8;
  201. }
  202. BLASLONG ci=n_top*ldc+m_top;
  203. C[ci + 0 * ldc + 0] += alpha * result0;
  204. C[ci + 0 * ldc + 1] += alpha * result1;
  205. C[ci + 1 * ldc + 0] += alpha * result2;
  206. C[ci + 1 * ldc + 1] += alpha * result3;
  207. C[ci + 2 * ldc + 0] += alpha * result4;
  208. C[ci + 2 * ldc + 1] += alpha * result5;
  209. C[ci + 3 * ldc + 0] += alpha * result6;
  210. C[ci + 3 * ldc + 1] += alpha * result7;
  211. C[ci + 4 * ldc + 0] += alpha * result8;
  212. C[ci + 4 * ldc + 1] += alpha * result9;
  213. C[ci + 5 * ldc + 0] += alpha * result10;
  214. C[ci + 5 * ldc + 1] += alpha * result11;
  215. C[ci + 6 * ldc + 0] += alpha * result12;
  216. C[ci + 6 * ldc + 1] += alpha * result13;
  217. C[ci + 7 * ldc + 0] += alpha * result14;
  218. C[ci + 7 * ldc + 1] += alpha * result15;
  219. m_top+=2;
  220. }
  221. if( M & 1 ) {
  222. float result0 = 0;
  223. float result1 = 0;
  224. float result2 = 0;
  225. float result3 = 0;
  226. float result4 = 0;
  227. float result5 = 0;
  228. float result6 = 0;
  229. float result7 = 0;
  230. BLASLONG ai = m_top * K;
  231. BLASLONG bi = n_top * K;
  232. for(BLASLONG k=0; k<K; k++) {
  233. result0+=(float)(A[ai+0]*B[bi+0]);
  234. result1+=(float)(A[ai+0]*B[bi+1]);
  235. result2+=(float)(A[ai+0]*B[bi+2]);
  236. result3+=(float)(A[ai+0]*B[bi+3]);
  237. result4+=(float)(A[ai+0]*B[bi+4]);
  238. result5+=(float)(A[ai+0]*B[bi+5]);
  239. result6+=(float)(A[ai+0]*B[bi+6]);
  240. result7+=(float)(A[ai+0]*B[bi+7]);
  241. ai+=1;
  242. bi+=8;
  243. }
  244. BLASLONG ci = n_top * ldc + m_top;
  245. C[ci + 0 * ldc + 0] += alpha * result0;
  246. C[ci + 1 * ldc + 0] += alpha * result1;
  247. C[ci + 2 * ldc + 0] += alpha * result2;
  248. C[ci + 3 * ldc + 0] += alpha * result3;
  249. C[ci + 4 * ldc + 0] += alpha * result4;
  250. C[ci + 5 * ldc + 0] += alpha * result5;
  251. C[ci + 6 * ldc + 0] += alpha * result6;
  252. C[ci + 7 * ldc + 0] += alpha * result7;
  253. m_top+=1;
  254. }
  255. n_top += 8;
  256. }
  257. // -- tails for N=4
  258. if( N & 4 ) {
  259. gvl = __riscv_vsetvl_e16m1(8);
  260. m_top = 0;
  261. for (BLASLONG i=0; i<M/8; i+=1) {
  262. BLASLONG ai=m_top*K;
  263. BLASLONG bi=n_top*K;
  264. _Float16 B0 = B[bi+0];
  265. _Float16 B1 = B[bi+1];
  266. _Float16 B2 = B[bi+2];
  267. _Float16 B3 = B[bi+3];
  268. bi += 4;
  269. vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  270. ai += 8;
  271. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  272. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  273. vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
  274. vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
  275. for(BLASLONG k=1; k<K; k++) {
  276. B0 = B[bi+0];
  277. B1 = B[bi+1];
  278. B2 = B[bi+2];
  279. B3 = B[bi+3];
  280. bi += 4;
  281. A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  282. ai += 8;
  283. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  284. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  285. result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
  286. result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
  287. }
  288. BLASLONG ci=n_top*ldc+m_top;
  289. vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
  290. vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
  291. vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
  292. vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
  293. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  294. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  295. c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
  296. c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
  297. ci = n_top * ldc + m_top;
  298. __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
  299. __riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
  300. __riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
  301. __riscv_vse32_v_f32m2( &C[ci], c3, gvl);
  302. m_top += 8;
  303. }
  304. if( M & 4 ) {
  305. gvl = __riscv_vsetvl_e16m1(4);
  306. BLASLONG ai=m_top*K;
  307. BLASLONG bi=n_top*K;
  308. _Float16 B0 = B[bi+0];
  309. _Float16 B1 = B[bi+1];
  310. _Float16 B2 = B[bi+2];
  311. _Float16 B3 = B[bi+3];
  312. bi += 4;
  313. vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  314. ai += 4;
  315. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  316. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  317. vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
  318. vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
  319. for(BLASLONG k=1; k < K; ++k) {
  320. B0 = B[bi+0];
  321. B1 = B[bi+1];
  322. B2 = B[bi+2];
  323. B3 = B[bi+3];
  324. bi += 4;
  325. A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  326. ai += 4;
  327. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  328. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  329. result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
  330. result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
  331. }
  332. BLASLONG ci = n_top * ldc + m_top;
  333. vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  334. ci += ldc - gvl * 0;
  335. vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  336. ci += ldc - gvl * 0;
  337. vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  338. ci += ldc - gvl * 0;
  339. vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  340. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  341. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  342. c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
  343. c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
  344. ci= n_top * ldc + m_top;
  345. __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
  346. __riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
  347. __riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
  348. __riscv_vse32_v_f32m2(&C[ci], c3, gvl);
  349. m_top += 4;
  350. }
  351. if( M & 2 ) {
  352. BLASLONG ai = m_top * K;
  353. BLASLONG bi = n_top * K;
  354. float result0 = 0;
  355. float result1 = 0;
  356. float result2 = 0;
  357. float result3 = 0;
  358. float result4 = 0;
  359. float result5 = 0;
  360. float result6 = 0;
  361. float result7 = 0;
  362. for(BLASLONG k=0; k<K; k++) {
  363. result0+=(float)(A[ai+0]*B[bi+0]);
  364. result1+=(float)(A[ai+1]*B[bi+0]);
  365. result2+=(float)(A[ai+0]*B[bi+1]);
  366. result3+=(float)(A[ai+1]*B[bi+1]);
  367. result4+=(float)(A[ai+0]*B[bi+2]);
  368. result5+=(float)(A[ai+1]*B[bi+2]);
  369. result6+=(float)(A[ai+0]*B[bi+3]);
  370. result7+=(float)(A[ai+1]*B[bi+3]);
  371. ai+=2;
  372. bi+=4;
  373. }
  374. BLASLONG ci=n_top*ldc+m_top;
  375. C[ci + 0 * ldc + 0] += alpha * result0;
  376. C[ci + 0 * ldc + 1] += alpha * result1;
  377. C[ci + 1 * ldc + 0] += alpha * result2;
  378. C[ci + 1 * ldc + 1] += alpha * result3;
  379. C[ci + 2 * ldc + 0] += alpha * result4;
  380. C[ci + 2 * ldc + 1] += alpha * result5;
  381. C[ci + 3 * ldc + 0] += alpha * result6;
  382. C[ci + 3 * ldc + 1] += alpha * result7;
  383. m_top += 2;
  384. }
  385. if( M & 1 ) {
  386. float result0 = 0;
  387. float result1 = 0;
  388. float result2 = 0;
  389. float result3 = 0;
  390. BLASLONG ai = m_top * K;
  391. BLASLONG bi = n_top * K;
  392. for(BLASLONG k=0; k<K; k++) {
  393. result0+=(float)(A[ai+0]*B[bi+0]);
  394. result1+=(float)(A[ai+0]*B[bi+1]);
  395. result2+=(float)(A[ai+0]*B[bi+2]);
  396. result3+=(float)(A[ai+0]*B[bi+3]);
  397. ai+=1;
  398. bi+=4;
  399. }
  400. BLASLONG ci = n_top * ldc + m_top;
  401. C[ci + 0 * ldc + 0] += alpha * result0;
  402. C[ci + 1 * ldc + 0] += alpha * result1;
  403. C[ci + 2 * ldc + 0] += alpha * result2;
  404. C[ci + 3 * ldc + 0] += alpha * result3;
  405. m_top += 1;
  406. }
  407. n_top += 4;
  408. }
  409. // -- tails for N=2
  410. if( N & 2 ) {
  411. gvl = __riscv_vsetvl_e16m1(8);
  412. m_top = 0;
  413. for (BLASLONG i=0; i<M/8; i+=1) {
  414. BLASLONG ai=m_top*K;
  415. BLASLONG bi=n_top*K;
  416. _Float16 B0 = B[bi+0];
  417. _Float16 B1 = B[bi+1];
  418. bi += 2;
  419. vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  420. ai += 8;
  421. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  422. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  423. for(BLASLONG k=1; k<K; k++) {
  424. B0 = B[bi+0];
  425. B1 = B[bi+1];
  426. bi += 2;
  427. A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  428. ai += 8;
  429. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  430. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  431. }
  432. BLASLONG ci=n_top*ldc+m_top;
  433. vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
  434. vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
  435. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  436. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  437. ci = n_top * ldc + m_top;
  438. __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
  439. __riscv_vse32_v_f32m2( &C[ci], c1, gvl);
  440. m_top += 8;
  441. }
  442. if( M & 4 ) {
  443. gvl = __riscv_vsetvl_e16m1(4);
  444. BLASLONG ai=m_top*K;
  445. BLASLONG bi=n_top*K;
  446. _Float16 B0 = B[bi+0];
  447. _Float16 B1 = B[bi+1];
  448. bi += 2;
  449. vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  450. ai += 4;
  451. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  452. vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
  453. for(BLASLONG k=1; k < K; ++k) {
  454. B0 = B[bi+0];
  455. B1 = B[bi+1];
  456. bi += 2;
  457. A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  458. ai += 4;
  459. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  460. result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
  461. }
  462. BLASLONG ci = n_top * ldc + m_top;
  463. vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  464. ci += ldc - gvl * 0;
  465. vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  466. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  467. c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
  468. ci= n_top * ldc + m_top;
  469. __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
  470. __riscv_vse32_v_f32m2(&C[ci], c1, gvl);
  471. m_top += 4;
  472. }
  473. if( M & 2 ) {
  474. BLASLONG ai = m_top * K;
  475. BLASLONG bi = n_top * K;
  476. float result0 = 0;
  477. float result1 = 0;
  478. float result2 = 0;
  479. float result3 = 0;
  480. for(BLASLONG k=0; k<K; k++) {
  481. result0+=(float)(A[ai+0]*B[bi+0]);
  482. result1+=(float)(A[ai+1]*B[bi+0]);
  483. result2+=(float)(A[ai+0]*B[bi+1]);
  484. result3+=(float)(A[ai+1]*B[bi+1]);
  485. ai+=2;
  486. bi+=2;
  487. }
  488. BLASLONG ci=n_top*ldc+m_top;
  489. C[ci + 0 * ldc + 0] += alpha * result0;
  490. C[ci + 0 * ldc + 1] += alpha * result1;
  491. C[ci + 1 * ldc + 0] += alpha * result2;
  492. C[ci + 1 * ldc + 1] += alpha * result3;
  493. m_top += 2;
  494. }
  495. if( M & 1 ) {
  496. float result0 = 0;
  497. float result1 = 0;
  498. BLASLONG ai = m_top * K;
  499. BLASLONG bi = n_top * K;
  500. for(BLASLONG k=0; k<K; k++) {
  501. result0+=(float)(A[ai+0]*B[bi+0]);
  502. result1+=(float)(A[ai+0]*B[bi+1]);
  503. ai+=1;
  504. bi+=2;
  505. }
  506. BLASLONG ci = n_top * ldc + m_top;
  507. C[ci + 0 * ldc + 0] += alpha * result0;
  508. C[ci + 1 * ldc + 0] += alpha * result1;
  509. m_top += 1;
  510. }
  511. n_top += 2;
  512. }
  513. // -- tails for N=1
  514. if( N & 1 ) {
  515. gvl = __riscv_vsetvl_e16m1(8);
  516. m_top = 0;
  517. for (BLASLONG i=0; i<M/8; i+=1) {
  518. BLASLONG ai=m_top*K;
  519. BLASLONG bi=n_top*K;
  520. _Float16 B0 = B[bi+0];
  521. bi += 1;
  522. vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  523. ai += 8;
  524. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  525. for(BLASLONG k=1; k<K; k++) {
  526. B0 = B[bi+0];
  527. bi += 1;
  528. A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
  529. ai += 8;
  530. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  531. }
  532. BLASLONG ci=n_top*ldc+m_top;
  533. vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
  534. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  535. ci = n_top * ldc + m_top;
  536. __riscv_vse32_v_f32m2( &C[ci], c0, gvl);
  537. m_top += 8;
  538. }
  539. if( M & 4 ) {
  540. gvl = __riscv_vsetvl_e16m1(4);
  541. BLASLONG ai=m_top*K;
  542. BLASLONG bi=n_top*K;
  543. _Float16 B0 = B[bi+0];
  544. bi += 1;
  545. vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  546. ai += 4;
  547. vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
  548. for(BLASLONG k=1; k < K; ++k) {
  549. B0 = B[bi+0];
  550. bi += 1;
  551. A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
  552. ai += 4;
  553. result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
  554. }
  555. BLASLONG ci = n_top * ldc + m_top;
  556. vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
  557. c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
  558. ci= n_top * ldc + m_top;
  559. __riscv_vse32_v_f32m2(&C[ci], c0, gvl);
  560. m_top += 4;
  561. }
  562. if( M & 2 ) {
  563. BLASLONG ai = m_top * K;
  564. BLASLONG bi = n_top * K;
  565. float result0 = 0;
  566. float result1 = 0;
  567. for(BLASLONG k=0; k<K; k++) {
  568. result0+=(float)(A[ai+0]*B[bi+0]);
  569. result1+=(float)(A[ai+1]*B[bi+0]);
  570. ai+=2;
  571. bi+=1;
  572. }
  573. BLASLONG ci=n_top*ldc+m_top;
  574. C[ci + 0 * ldc + 0] += alpha * result0;
  575. C[ci + 0 * ldc + 1] += alpha * result1;
  576. m_top += 2;
  577. }
  578. if( M & 1 ) {
  579. float result0 = 0;
  580. BLASLONG ai = m_top * K;
  581. BLASLONG bi = n_top * K;
  582. for(BLASLONG k=0; k<K; k++) {
  583. result0+=(float)(A[ai+0]*B[bi+0]);
  584. ai+=1;
  585. bi+=1;
  586. }
  587. BLASLONG ci = n_top * ldc + m_top;
  588. C[ci + 0 * ldc + 0] += alpha * result0;
  589. m_top += 1;
  590. }
  591. n_top += 1;
  592. }
  593. return 0;
  594. }