You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

sbgemm_microk_cooperlake_template.c 94 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
66116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835
  1. #include "bf16_common_macros.h"
  2. #include <immintrin.h>
  3. #define BF16_BLOCK_STEP_N 8
  4. #define BF16_BLOCK_THRES_K 1024
  5. #define BF16_BLOCK_THRES_M 32
  6. #define BF16_BLOCK_THRES_N 1024
  7. #define A(i,j) A[(i)*lda+(j)]
  8. #define B(i,j) B[(i)*ldb+(j)]
  9. #define C(i,j) C[(i)*ldc+(j)]
  10. #define ONE 1.e0f
  11. #define ZERO 0.e0f
  12. #define SHUFFLE_MAGIC_NO (const int) 0x39
  13. #undef STORE16_COMPLETE_RESULT
  14. #undef STORE16_MASK_COMPLETE_RESULT
  15. #undef SBGEMM_BLOCK_KERNEL_NN_32x8xK
  16. #undef SBGEMM_BLOCK_KERNEL_NN_16x8xK
  17. #undef SBGEMM_BLOCK_KERNEL_NN_32xNx32
  18. #undef SBGEMM_BLOCK_KERNEL_NN_16xNx32
  19. #undef SBGEMM_BLOCK_KERNEL_NT_32x8xK
  20. #undef SBGEMM_BLOCK_KERNEL_NT_16x8xK
  21. #undef SBGEMM_BLOCK_KERNEL_NT_32xNxK
  22. #undef SBGEMM_BLOCK_KERNEL_NT_16xNxK
  23. #undef SBGEMM_BLOCK_KERNEL_TN_32x8xK
  24. #undef SBGEMM_BLOCK_KERNEL_TN_16x8xK
  25. #undef SBGEMM_BLOCK_KERNEL_TN_32xNx32
  26. #undef SBGEMM_BLOCK_KERNEL_TN_16xNx32
  27. #undef SBGEMM_BLOCK_KERNEL_TT_32x8xK
  28. #undef SBGEMM_BLOCK_KERNEL_TT_16x8xK
  29. #undef SBGEMM_BLOCK_KERNEL_TT_32xNxK
  30. #undef SBGEMM_BLOCK_KERNEL_TT_16xNxK
  31. #undef SBGEMM_BLOCKING_KERNEL_NN
  32. #undef SBGEMM_BLOCKING_KERNEL_NT
  33. #undef SBGEMM_BLOCKING_KERNEL_TN
  34. #undef SBGEMM_BLOCKING_KERNEL_TT
  35. #ifndef ONE_ALPHA // ALPHA is not ONE
  36. #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
  37. #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
  38. #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_alpha
  39. #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_alpha
  40. #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_alpha
  41. #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_alpha
  42. #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK
  43. #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK
  44. #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_alpha
  45. #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_alpha
  46. #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_alpha
  47. #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_alpha
  48. #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_alpha
  49. #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_alpha
  50. #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK
  51. #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK
  52. #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_alpha
  53. #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_alpha
  54. #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_alpha
  55. #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_alpha
  56. #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_alpha
  57. #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_alpha
  58. #else // ALPHA is ONE
  59. #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE
  60. #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE
  61. #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_one
  62. #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_one
  63. #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_one
  64. #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_one
  65. #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK
  66. #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK
  67. #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_one
  68. #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_one
  69. #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_one
  70. #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_one
  71. #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_one
  72. #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_one
  73. #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK
  74. #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK
  75. #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_one
  76. #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_one
  77. #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_one
  78. #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_one
  79. #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_one
  80. #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_one
  81. #endif
  82. extern bfloat16 * block_A;
  83. extern bfloat16 * block_B;
  84. /* --------------------------------------------- NN kernels ------------------------------------------ */
  85. // SBGEMM Kernel for 16<M<=32, N=8, K can be any number, but the processing will take 32 as a base
  86. #ifndef ONE_ALPHA // ALPHA is not ONE
  87. void sbgemm_block_kernel_nn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  88. #else // ALPHA is ONE
  89. void sbgemm_block_kernel_nn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  90. #endif
  91. {
  92. bfloat16 * A_addr = A;
  93. bfloat16 * B_addr = B;
  94. float * C_addr = C;
  95. #ifndef ONE_ALPHA
  96. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  97. #endif
  98. __m512i arrayA_512_0, arrayA_512_1;
  99. __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
  100. __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7,
  101. result_512_8, result_512_9, result_512_10, result_512_11, result_512_12, result_512_13, result_512_14, result_512_15;
  102. __m512 result_512_tmp_0, result_512_tmp_1, result_512_tmp_2, result_512_tmp_3;
  103. __m512i M512_EPI32_8 = _mm512_set1_epi32(8);
  104. __m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
  105. __m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
  106. result_512_0 = _mm512_setzero_ps();
  107. result_512_1 = _mm512_setzero_ps();
  108. result_512_2 = _mm512_setzero_ps();
  109. result_512_3 = _mm512_setzero_ps();
  110. result_512_4 = _mm512_setzero_ps();
  111. result_512_5 = _mm512_setzero_ps();
  112. result_512_6 = _mm512_setzero_ps();
  113. result_512_7 = _mm512_setzero_ps();
  114. result_512_8 = _mm512_setzero_ps();
  115. result_512_9 = _mm512_setzero_ps();
  116. result_512_10 = _mm512_setzero_ps();
  117. result_512_11 = _mm512_setzero_ps();
  118. result_512_12 = _mm512_setzero_ps();
  119. result_512_13 = _mm512_setzero_ps();
  120. result_512_14 = _mm512_setzero_ps();
  121. result_512_15 = _mm512_setzero_ps();
  122. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
  123. // Each two rows are a group for 32-pair bf16 elements
  124. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  125. arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
  126. _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
  127. _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
  128. _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
  129. _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
  130. _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
  131. _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
  132. _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
  133. _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
  134. result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
  135. result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
  136. result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
  137. result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
  138. result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
  139. result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
  140. result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
  141. result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
  142. result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
  143. result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
  144. result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
  145. result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
  146. result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
  147. result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
  148. result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
  149. result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
  150. // Load B with unroll 8
  151. B_addr += 16;
  152. // Load A with unroll 64
  153. A_addr += 64;
  154. }
  155. if (m != 32) {
  156. unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
  157. __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
  158. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
  159. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
  160. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
  161. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
  162. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
  163. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16), tail_mask)
  164. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
  165. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16), tail_mask)
  166. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
  167. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
  168. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
  169. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
  170. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
  171. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16), tail_mask)
  172. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
  173. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16), tail_mask)
  174. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
  175. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
  176. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
  177. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
  178. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
  179. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16), tail_mask)
  180. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
  181. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16), tail_mask)
  182. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
  183. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
  184. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
  185. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
  186. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
  187. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16), tail_mask)
  188. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
  189. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16), tail_mask)
  190. } else {
  191. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base0, result_512_8);
  192. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_0, shuffle_idx_base1, result_512_8);
  193. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base0, result_512_9);
  194. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_1, shuffle_idx_base1, result_512_9);
  195. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr))
  196. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + 16))
  197. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*1))
  198. STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*1 + 16))
  199. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base0, result_512_10);
  200. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_2, shuffle_idx_base1, result_512_10);
  201. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base0, result_512_11);
  202. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_3, shuffle_idx_base1, result_512_11);
  203. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*2))
  204. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*2 + 16))
  205. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*3))
  206. STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*3 + 16))
  207. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base0, result_512_12);
  208. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_4, shuffle_idx_base1, result_512_12);
  209. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base0, result_512_13);
  210. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_5, shuffle_idx_base1, result_512_13);
  211. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*4))
  212. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*4 + 16))
  213. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*5))
  214. STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*5 + 16))
  215. result_512_tmp_0 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base0, result_512_14);
  216. result_512_tmp_1 = _mm512_permutex2var_ps(result_512_6, shuffle_idx_base1, result_512_14);
  217. result_512_tmp_2 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base0, result_512_15);
  218. result_512_tmp_3 = _mm512_permutex2var_ps(result_512_7, shuffle_idx_base1, result_512_15);
  219. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*6))
  220. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*6 + 16))
  221. STORE16_COMPLETE_RESULT(result_512_tmp_2, (C_addr + ldc*7))
  222. STORE16_COMPLETE_RESULT(result_512_tmp_3, (C_addr + ldc*7 + 16))
  223. }
  224. }
  225. // SBGEMM Kernel for M<=16, N=8, K can be any number
  226. #ifndef ONE_ALPHA // ALPHA is not ONE
  227. void sbgemm_block_kernel_nn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  228. #else // ALPHA is ONE
  229. void sbgemm_block_kernel_nn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  230. #endif
  231. {
  232. bfloat16 * A_addr = A;
  233. bfloat16 * B_addr = B;
  234. float * C_addr = C;
  235. #ifndef ONE_ALPHA
  236. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  237. #endif
  238. __m512i arrayA_512_0;
  239. __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
  240. __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7;
  241. result_512_0 = _mm512_setzero_ps();
  242. result_512_1 = _mm512_setzero_ps();
  243. result_512_2 = _mm512_setzero_ps();
  244. result_512_3 = _mm512_setzero_ps();
  245. result_512_4 = _mm512_setzero_ps();
  246. result_512_5 = _mm512_setzero_ps();
  247. result_512_6 = _mm512_setzero_ps();
  248. result_512_7 = _mm512_setzero_ps();
  249. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
  250. // Each two rows are a group for 32-pair bf16 elements
  251. // Load two rows into a 512 register
  252. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  253. _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
  254. _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
  255. _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
  256. _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
  257. _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
  258. _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
  259. _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
  260. _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
  261. result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
  262. result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
  263. result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
  264. result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
  265. result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
  266. result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
  267. result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
  268. result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
  269. // Load B with unroll 8
  270. B_addr += 16;
  271. // Load A with unroll 16
  272. A_addr += 32;
  273. }
  274. if (m != 16) {
  275. unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
  276. result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
  277. result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
  278. result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
  279. result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
  280. STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask)
  281. STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1), tail_mask)
  282. STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask)
  283. STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask)
  284. result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
  285. result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
  286. result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
  287. result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
  288. STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask)
  289. STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask)
  290. STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask)
  291. STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask)
  292. } else {
  293. result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8);
  294. result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8);
  295. result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8);
  296. result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8);
  297. STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
  298. STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1))
  299. STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
  300. STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
  301. result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8);
  302. result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8);
  303. result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8);
  304. result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8);
  305. STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
  306. STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
  307. STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
  308. STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
  309. }
  310. }
  311. // SBGEMM Kernel for 16<M<=32, N<8, K can be any number, but the processing will take 32 as a base
  312. #ifndef ONE_ALPHA // ALPHA is not ONE
  313. void sbgemm_block_kernel_nn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  314. #else // ALPHA is ONE
  315. void sbgemm_block_kernel_nn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  316. #endif
  317. {
  318. bfloat16 * A_addr = A;
  319. bfloat16 * B_addr = B;
  320. float * C_addr = C;
  321. BLASLONG tag_k_32x = k & (~31);
  322. #ifndef ONE_ALPHA
  323. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  324. #endif
  325. __m512i arrayA_512[2];
  326. __m512i arrayB_512[8];
  327. __m512 result_512[16];
  328. __m512 result_512_tmp_0, result_512_tmp_1;
  329. __m512i M512_EPI32_8 = _mm512_set1_epi32(8);
  330. __m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
  331. __m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
  332. for (int i = 0; i < 15; i += 2) {
  333. result_512[i] = _mm512_setzero_ps();
  334. result_512[i+1] = _mm512_setzero_ps();
  335. }
  336. for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
  337. // Load B with unroll n
  338. for (int i = 0; i < n; i ++) {
  339. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  340. B_addr += 32;
  341. }
  342. for (BLASLONG idx = 0; idx < 32;) {
  343. // Each two rows are a group for 32-pair bf16 elements
  344. arrayA_512[0] = _mm512_loadu_si512(A_addr);
  345. arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
  346. A_addr += 64;
  347. for (int i = 0; i < n; i++) {
  348. result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  349. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  350. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  351. }
  352. idx += 2;
  353. // Every 4 loops we need to switch to next 128 bits of arrayB registers
  354. if ((idx & (~7)) == idx) {
  355. for (int i = 0; i < n; i++) {
  356. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  357. }
  358. }
  359. }
  360. }
  361. if (tag_k_32x != k) {
  362. // Load B with unroll n
  363. for (int i = 0; i < n; i ++) {
  364. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  365. B_addr += 32;
  366. }
  367. BLASLONG width = k - tag_k_32x;
  368. for (BLASLONG idx = 0; idx < width;) {
  369. // Each two rows are a group for 32-pair bf16 elements
  370. arrayA_512[0] = _mm512_loadu_si512(A_addr);
  371. arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
  372. A_addr += 64;
  373. for (int i = 0; i < n; i++) {
  374. result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  375. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  376. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  377. }
  378. idx += 2;
  379. // Every 4 loops we need to switch to next 128 bits of arrayB registers
  380. if ((idx & (~7)) == idx) {
  381. for (int i = 0; i < n; i++) {
  382. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  383. }
  384. }
  385. }
  386. }
  387. if (m != 32) {
  388. unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
  389. for (int i = 0; i < n; i++) {
  390. result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
  391. result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
  392. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
  393. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
  394. }
  395. } else {
  396. for (int i = 0; i < n; i++) {
  397. result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
  398. result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
  399. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
  400. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
  401. }
  402. }
  403. }
  404. // SBGEMM Kernel for 16<=M, N<8, K can be any number, but the processing will take 32 as a base
  405. #ifndef ONE_ALPHA // ALPHA is not ONE
  406. void sbgemm_block_kernel_nn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  407. #else // ALPHA is ONE
  408. void sbgemm_block_kernel_nn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  409. #endif
  410. {
  411. bfloat16 * A_addr = A;
  412. bfloat16 * B_addr = B;
  413. float * C_addr = C;
  414. BLASLONG tag_k_32x = k & (~31);
  415. #ifndef ONE_ALPHA
  416. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  417. #endif
  418. __m512i arrayA_512;
  419. __m512i arrayB_512[8];
  420. __m512 result_512[8];
  421. for (int i = 0; i < 8; i += 2) {
  422. result_512[i] = _mm512_setzero_ps();
  423. result_512[i+1] = _mm512_setzero_ps();
  424. }
  425. for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
  426. // Load B with unroll n
  427. for (int i = 0; i < n; i++) {
  428. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  429. B_addr += 32;
  430. }
  431. for (BLASLONG idx = 0; idx < 32;) {
  432. // Each two rows are a group for 32-pair bf16 elements
  433. // Load two rows into a 512 register
  434. arrayA_512 = _mm512_loadu_si512(A_addr);
  435. A_addr += 32;
  436. for (int i = 0; i < n; i ++) {
  437. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  438. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  439. }
  440. idx += 2;
  441. // Every 4 loops we need to switch to next 128 bits of arrayB registers
  442. if ((idx & (~7)) == idx) {
  443. for (int i = 0; i < n; i++) {
  444. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  445. }
  446. }
  447. }
  448. }
  449. if (tag_k_32x != k) {
  450. // Load B with unroll n
  451. for (int i = 0; i < n; i++) {
  452. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  453. B_addr += 32;
  454. }
  455. BLASLONG width = k - tag_k_32x;
  456. for (BLASLONG idx = 0; idx < width;) {
  457. // Each two rows are a group for 32-pair bf16 elements
  458. // Load two rows into a 512 register
  459. arrayA_512 = _mm512_loadu_si512(A_addr);
  460. A_addr += 32;
  461. for (int i = 0; i < n; i++) {
  462. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  463. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  464. }
  465. idx += 2;
  466. // Every 4 loops we need to switch to next 128 bits of arrayB registers
  467. if ((idx & (~7)) == idx) {
  468. for (int i = 0; i < n; i++) {
  469. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  470. }
  471. }
  472. }
  473. }
  474. if (m != 16) {
  475. unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
  476. for (int i = 0; i < n; i++) {
  477. result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
  478. STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
  479. }
  480. } else {
  481. for (int i = 0; i < n; i++) {
  482. result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
  483. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  484. }
  485. }
  486. }
  487. #ifndef ONE_ALPHA // ALPHA is not ONE
  488. void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  489. #else // ALPHA is ONE
  490. void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  491. #endif
  492. {
  493. BLASLONG m_step, n_step, k_step, k_step_round32;
  494. BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
  495. BLASLONG n_from, n_to;
  496. BLASLONG tag_n_Nx;
  497. n_from = 0;
  498. n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
  499. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  500. k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
  501. k_step_round32 = k_step & (~31);
  502. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  503. if (M >= BF16_BLOCK_THRES_M) {
  504. while (n_from < N) {
  505. for (BLASLONG idx_k = 0; idx_k < K;) {
  506. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  507. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
  508. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  509. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  510. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  511. SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  512. }
  513. if (tag_n_Nx != n_to) {
  514. n_step = n_to - tag_n_Nx;
  515. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  516. SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  517. }
  518. for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
  519. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
  520. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  521. SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
  522. }
  523. if (tag_n_Nx != n_to) {
  524. n_step = n_to - tag_n_Nx;
  525. SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
  526. }
  527. }
  528. if (tag_m_Nx != M) {
  529. m_step = M - tag_m_Nx;
  530. if (m_step > 16) {
  531. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
  532. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  533. SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  534. }
  535. if (tag_n_Nx != n_to) {
  536. n_step = n_to - tag_n_Nx;
  537. SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  538. }
  539. } else {
  540. COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
  541. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  542. SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  543. }
  544. if (tag_n_Nx != n_to) {
  545. n_step = n_to - tag_n_Nx;
  546. SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  547. }
  548. }
  549. }
  550. idx_k += k_step;
  551. k_step = K - idx_k;
  552. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  553. k_step_round32 = k_step & (~31);
  554. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  555. }
  556. n_from = n_to;
  557. n_to += BF16_BLOCK_THRES_N;
  558. n_to = (n_to > N) ? N : n_to;
  559. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  560. }
  561. } else {
  562. m_step = M;
  563. if (m_step > 16) {
  564. while (n_from < N) {
  565. for (BLASLONG idx_k = 0; idx_k < K;) {
  566. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  567. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
  568. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  569. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  570. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  571. SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  572. }
  573. if (tag_n_Nx != n_to) {
  574. n_step = n_to - tag_n_Nx;
  575. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  576. SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  577. }
  578. idx_k += k_step;
  579. k_step = K - idx_k;
  580. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  581. k_step_round32 = k_step & (~31);
  582. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  583. }
  584. n_from = n_to;
  585. n_to += BF16_BLOCK_THRES_N;
  586. n_to = (n_to > N) ? N : n_to;
  587. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  588. }
  589. } else {
  590. while (n_from < N) {
  591. for (BLASLONG idx_k = 0; idx_k < K;) {
  592. COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
  593. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  594. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  595. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  596. SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  597. }
  598. if (tag_n_Nx != n_to) {
  599. n_step = n_to - tag_n_Nx;
  600. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  601. SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  602. }
  603. idx_k += k_step;
  604. k_step = K - idx_k;
  605. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  606. k_step_round32 = k_step & (~31);
  607. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  608. }
  609. n_from = n_to;
  610. n_to += BF16_BLOCK_THRES_N;
  611. n_to = (n_to > N) ? N : n_to;
  612. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  613. }
  614. }
  615. }
  616. }
  617. /* ----------------------------------------- End of NN kernels --------------------------------------- */
  618. /* --------------------------------------------- NT kernels ------------------------------------------ */
  619. // SBGEMM Kernel for 16<M<=32, N<8, K can be any number
  620. #ifndef ONE_ALPHA // ALPHA is not ONE
  621. void sbgemm_block_kernel_nt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  622. #else // ALPHA is ONE
  623. void sbgemm_block_kernel_nt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  624. #endif
  625. {
  626. bfloat16 * A_addr = A;
  627. bfloat16 * B_addr = B;
  628. float * C_addr = C;
  629. #ifndef ONE_ALPHA
  630. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  631. #endif
  632. __m512i arrayA_512_0, arrayA_512_1;
  633. __m512i arrayB_512[8];
  634. __m512 result_512[16];
  635. __m512 result_512_tmp_0, result_512_tmp_1;
  636. __m512i M512_EPI32_8 = _mm512_set1_epi32(8);
  637. __m512i shuffle_idx_base0 = _mm512_set_epi32(23, 22, 21, 20, 7, 6, 5, 4, 19, 18, 17, 16, 3, 2, 1, 0);
  638. __m512i shuffle_idx_base1 = _mm512_add_epi32(shuffle_idx_base0, M512_EPI32_8);
  639. result_512[0] = _mm512_setzero_ps();
  640. result_512[1] = _mm512_setzero_ps();
  641. result_512[2] = _mm512_setzero_ps();
  642. result_512[3] = _mm512_setzero_ps();
  643. result_512[4] = _mm512_setzero_ps();
  644. result_512[5] = _mm512_setzero_ps();
  645. result_512[6] = _mm512_setzero_ps();
  646. result_512[7] = _mm512_setzero_ps();
  647. result_512[8] = _mm512_setzero_ps();
  648. result_512[9] = _mm512_setzero_ps();
  649. result_512[10] = _mm512_setzero_ps();
  650. result_512[11] = _mm512_setzero_ps();
  651. result_512[12] = _mm512_setzero_ps();
  652. result_512[13] = _mm512_setzero_ps();
  653. result_512[14] = _mm512_setzero_ps();
  654. result_512[15] = _mm512_setzero_ps();
  655. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
  656. // Each two rows are a group for 32-pair bf16 elements
  657. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  658. arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
  659. A_addr += 64;
  660. for (int i = 0; i < n; i ++) {
  661. _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
  662. }
  663. B_addr += 16;
  664. for (int i = 0; i < n; i ++) {
  665. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
  666. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
  667. }
  668. }
  669. if (m != 32) {
  670. unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
  671. for (int i = 0; i < n; i ++) {
  672. result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
  673. result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
  674. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
  675. STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask)
  676. }
  677. } else {
  678. for (int i = 0; i < n; i ++) {
  679. result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]);
  680. result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]);
  681. STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i))
  682. STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16))
  683. }
  684. }
  685. }
  686. // SBGEMM Kernel for M<=16, N<8, K can be any number
  687. #ifndef ONE_ALPHA // ALPHA is not ONE
  688. void sbgemm_block_kernel_nt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  689. #else // ALPHA is ONE
  690. void sbgemm_block_kernel_nt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  691. #endif
  692. {
  693. bfloat16 * A_addr = A;
  694. bfloat16 * B_addr = B;
  695. float * C_addr = C;
  696. #ifndef ONE_ALPHA
  697. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  698. #endif
  699. __m512i arrayA_512_0;
  700. __m512i arrayB_512[8];
  701. __m512 result_512[8];
  702. result_512[0] = _mm512_setzero_ps();
  703. result_512[1] = _mm512_setzero_ps();
  704. result_512[2] = _mm512_setzero_ps();
  705. result_512[3] = _mm512_setzero_ps();
  706. result_512[4] = _mm512_setzero_ps();
  707. result_512[5] = _mm512_setzero_ps();
  708. result_512[6] = _mm512_setzero_ps();
  709. result_512[7] = _mm512_setzero_ps();
  710. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
  711. // Each two rows are a group for 16-pair bf16 elements
  712. // Load two rows into a 512 register
  713. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  714. A_addr += 32;
  715. for (int i = 0; i < n; i ++) {
  716. _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
  717. }
  718. B_addr += 16;
  719. for (int i = 0; i < n; i ++) {
  720. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
  721. }
  722. }
  723. if (m != 16) {
  724. unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
  725. for (int i = 0; i < n; i++) {
  726. result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
  727. STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
  728. }
  729. } else {
  730. for (int i = 0; i < n; i++) {
  731. result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8);
  732. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  733. }
  734. }
  735. }
  736. #ifndef ONE_ALPHA // ALPHA is not ONE
  737. void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  738. #else // ALPHA is ONE
  739. void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  740. #endif
  741. {
  742. BLASLONG m_step, n_step, k_step, k_step_round32;
  743. BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
  744. BLASLONG n_from, n_to;
  745. BLASLONG tag_n_Nx;
  746. n_from = 0;
  747. n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
  748. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  749. k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
  750. k_step_round32 = k_step & (~31);
  751. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  752. if (M >= BF16_BLOCK_THRES_M) {
  753. while (n_from < N) {
  754. for (BLASLONG idx_k = 0; idx_k < K;) {
  755. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  756. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A);
  757. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  758. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  759. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  760. SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  761. }
  762. if (tag_n_Nx != n_to) {
  763. n_step = n_to - tag_n_Nx;
  764. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  765. SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  766. }
  767. for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
  768. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A);
  769. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  770. SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
  771. }
  772. if (tag_n_Nx != n_to) {
  773. n_step = n_to - tag_n_Nx;
  774. SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
  775. }
  776. }
  777. if (tag_m_Nx != M) {
  778. m_step = M - tag_m_Nx;
  779. if (m_step > 16) {
  780. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
  781. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  782. SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  783. }
  784. if (tag_n_Nx != n_to) {
  785. n_step = n_to - tag_n_Nx;
  786. SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  787. }
  788. } else {
  789. COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A);
  790. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  791. SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  792. }
  793. if (tag_n_Nx != n_to) {
  794. n_step = n_to - tag_n_Nx;
  795. SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  796. }
  797. }
  798. }
  799. idx_k += k_step;
  800. k_step = K - idx_k;
  801. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  802. k_step_round32 = k_step & (~31);
  803. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  804. }
  805. n_from = n_to;
  806. n_to += BF16_BLOCK_THRES_N;
  807. n_to = (n_to > N) ? N : n_to;
  808. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  809. }
  810. } else {
  811. m_step = M;
  812. if (m_step > 16) {
  813. while (n_from < N) {
  814. for (BLASLONG idx_k = 0; idx_k < K;) {
  815. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  816. COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A);
  817. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  818. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  819. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  820. SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  821. }
  822. if (tag_n_Nx != n_to) {
  823. n_step = n_to - tag_n_Nx;
  824. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  825. SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  826. }
  827. idx_k += k_step;
  828. k_step = K - idx_k;
  829. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  830. k_step_round32 = k_step & (~31);
  831. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  832. }
  833. n_from = n_to;
  834. n_to += BF16_BLOCK_THRES_N;
  835. n_to = (n_to > N) ? N : n_to;
  836. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  837. }
  838. } else {
  839. while (n_from < N) {
  840. for (BLASLONG idx_k = 0; idx_k < K;) {
  841. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  842. COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A);
  843. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  844. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  845. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  846. SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  847. }
  848. if (tag_n_Nx != n_to) {
  849. n_step = n_to - tag_n_Nx;
  850. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  851. SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  852. }
  853. idx_k += k_step;
  854. k_step = K - idx_k;
  855. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  856. k_step_round32 = k_step & (~31);
  857. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  858. }
  859. n_from = n_to;
  860. n_to += BF16_BLOCK_THRES_N;
  861. n_to = (n_to > N) ? N : n_to;
  862. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  863. }
  864. }
  865. }
  866. }
  867. /* ----------------------------------------- End of NT kernels --------------------------------------- */
  868. /* --------------------------------------------- TN kernels ------------------------------------------ */
  869. // SBGEMM Kernel for 16<M<=32, N=8, K=Any number
  870. #ifndef ONE_ALPHA // ALPHA is not ONE
  871. void sbgemm_block_kernel_tn_32x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  872. #else // ALPHA is ONE
  873. void sbgemm_block_kernel_tn_32x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  874. #endif
  875. {
  876. bfloat16 * A_addr = A;
  877. bfloat16 * B_addr = B;
  878. float * C_addr = C;
  879. #ifndef ONE_ALPHA
  880. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  881. #endif
  882. __m512i arrayA_512_0, arrayA_512_1;
  883. __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
  884. __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7,
  885. result_512_8, result_512_9, result_512_10, result_512_11, result_512_12, result_512_13, result_512_14, result_512_15;
  886. result_512_0 = _mm512_setzero_ps();
  887. result_512_1 = _mm512_setzero_ps();
  888. result_512_2 = _mm512_setzero_ps();
  889. result_512_3 = _mm512_setzero_ps();
  890. result_512_4 = _mm512_setzero_ps();
  891. result_512_5 = _mm512_setzero_ps();
  892. result_512_6 = _mm512_setzero_ps();
  893. result_512_7 = _mm512_setzero_ps();
  894. result_512_8 = _mm512_setzero_ps();
  895. result_512_9 = _mm512_setzero_ps();
  896. result_512_10 = _mm512_setzero_ps();
  897. result_512_11 = _mm512_setzero_ps();
  898. result_512_12 = _mm512_setzero_ps();
  899. result_512_13 = _mm512_setzero_ps();
  900. result_512_14 = _mm512_setzero_ps();
  901. result_512_15 = _mm512_setzero_ps();
  902. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
903. // Load 32 pairs of BF16 elements from A (32 rows, two k-values each)
  904. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  905. arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
  906. // Load 8 rows of B
  907. _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
  908. _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
  909. _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
  910. _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
  911. _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
  912. _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
  913. _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
  914. _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
  915. result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
  916. result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
  917. result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
  918. result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
  919. result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
  920. result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
  921. result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
  922. result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
  923. result_512_8 = _mm512_dpbf16_ps(result_512_8, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_0);
  924. result_512_9 = _mm512_dpbf16_ps(result_512_9, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_1);
  925. result_512_10 = _mm512_dpbf16_ps(result_512_10, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_2);
  926. result_512_11 = _mm512_dpbf16_ps(result_512_11, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_3);
  927. result_512_12 = _mm512_dpbf16_ps(result_512_12, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_4);
  928. result_512_13 = _mm512_dpbf16_ps(result_512_13, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_5);
  929. result_512_14 = _mm512_dpbf16_ps(result_512_14, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_6);
  930. result_512_15 = _mm512_dpbf16_ps(result_512_15, (__m512bh) arrayA_512_1, (__m512bh) arrayB_512_7);
931. // Advance B by 16 BF16 values (8 n-values x 2 k-values per iteration)
932. B_addr += 16;
933. // Advance A by 64 BF16 values (32 m-values x 2 k-values per iteration)
934. A_addr += 64;
  935. }
  936. if (m != 32) {
  937. unsigned short tail_mask_value = (((unsigned short)0xffff) >> (32-m));
  938. __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
  939. STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
  940. STORE16_MASK_COMPLETE_RESULT(result_512_8, (C_addr + 16), tail_mask)
  941. STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
  942. STORE16_MASK_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16), tail_mask)
  943. STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
  944. STORE16_MASK_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16), tail_mask)
  945. STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
  946. STORE16_MASK_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16), tail_mask)
  947. STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
  948. STORE16_MASK_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16), tail_mask)
  949. STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
  950. STORE16_MASK_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16), tail_mask)
  951. STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
  952. STORE16_MASK_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16), tail_mask)
  953. STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
  954. STORE16_MASK_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16), tail_mask)
  955. } else {
  956. STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
  957. STORE16_COMPLETE_RESULT(result_512_8, (C_addr + 16))
  958. STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
  959. STORE16_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16))
  960. STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
  961. STORE16_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16))
  962. STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
  963. STORE16_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16))
  964. STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
  965. STORE16_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16))
  966. STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
  967. STORE16_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16))
  968. STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
  969. STORE16_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16))
  970. STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
  971. STORE16_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16))
  972. }
  973. }
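/*
 * Illustrative sketch (editor's addition): the accumulation above is built on
 * _mm512_dpbf16_ps, which treats each 512-bit register as 16 lanes of (even, odd) BF16
 * pairs and, per lane, adds the pair dot product into a float accumulator.  A scalar model
 * of that per-lane behaviour, with plain float standing in for bfloat16, is sketched below;
 * the function name is hypothetical.
 */
#ifndef SBGEMM_SKETCH_DPBF16_MODEL
#define SBGEMM_SKETCH_DPBF16_MODEL
static inline void sbgemm_sketch_dpbf16_lane_model(float result[16], const float a[32], const float b[32])
{
    for (int i = 0; i < 16; i++) {
        /* lane i accumulates the dot product of one BF16 pair from A with one pair from B */
        result[i] += a[2*i] * b[2*i] + a[2*i + 1] * b[2*i + 1];
    }
}
#endif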
974. // SBGEMM Kernel for M<=16, N=8, K=Any number
  975. #ifndef ONE_ALPHA // ALPHA is not ONE
  976. void sbgemm_block_kernel_tn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  977. #else // ALPHA is ONE
  978. void sbgemm_block_kernel_tn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  979. #endif
  980. {
  981. bfloat16 * A_addr = A;
  982. bfloat16 * B_addr = B;
  983. float * C_addr = C;
  984. #ifndef ONE_ALPHA
  985. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  986. #endif
  987. __m512i arrayA_512_0;
  988. __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7;
  989. __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7;
  990. result_512_0 = _mm512_setzero_ps();
  991. result_512_1 = _mm512_setzero_ps();
  992. result_512_2 = _mm512_setzero_ps();
  993. result_512_3 = _mm512_setzero_ps();
  994. result_512_4 = _mm512_setzero_ps();
  995. result_512_5 = _mm512_setzero_ps();
  996. result_512_6 = _mm512_setzero_ps();
  997. result_512_7 = _mm512_setzero_ps();
  998. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
999. // Load 16 pairs of BF16 elements from A (16 rows, two k-values each)
  1000. arrayA_512_0 = _mm512_loadu_si512(A_addr + 0);
  1001. // Load 8 rows of B
  1002. _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0);
  1003. _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1);
  1004. _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2);
  1005. _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3);
  1006. _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4);
  1007. _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5);
  1008. _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6);
  1009. _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7);
  1010. result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0);
  1011. result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1);
  1012. result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2);
  1013. result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3);
  1014. result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4);
  1015. result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5);
  1016. result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6);
  1017. result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7);
1018. // Advance B by 16 BF16 values (8 n-values x 2 k-values per iteration)
1019. B_addr += 16;
1020. // Advance A by 32 BF16 values (16 m-values x 2 k-values per iteration)
1021. A_addr += 32;
  1022. }
  1023. if (m != 16) {
  1024. unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m));
  1025. __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
  1026. STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask)
  1027. STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc), tail_mask)
  1028. STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask)
  1029. STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask)
  1030. STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask)
  1031. STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask)
  1032. STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask)
  1033. STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask)
  1034. } else {
  1035. STORE16_COMPLETE_RESULT(result_512_0, (C_addr))
  1036. STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc))
  1037. STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2))
  1038. STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3))
  1039. STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4))
  1040. STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5))
  1041. STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6))
  1042. STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7))
  1043. }
  1044. }
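/*
 * Illustrative sketch (editor's addition): the masked stores above build their AVX-512
 * write mask by shifting 0xffff right by (16 - m), which leaves exactly the low m bits set,
 * so only the first m of the 16 lanes are written back to C.  A hypothetical helper making
 * that property explicit:
 */
#ifndef SBGEMM_SKETCH_TAIL_MASK_16
#define SBGEMM_SKETCH_TAIL_MASK_16
static inline unsigned short sbgemm_sketch_tail_mask_16(int m)
{
    /* valid for 1 <= m <= 16; e.g. m == 5 yields 0x001f (five low bits set) */
    return (unsigned short)(0xffffU >> (16 - m));
}
#endif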
1045. // SBGEMM Kernel for 16<M<=32, N<8, K=Any number (processed in chunks of 32)
  1046. #ifndef ONE_ALPHA // ALPHA is not ONE
  1047. void sbgemm_block_kernel_tn_32xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1048. #else // ALPHA is ONE
  1049. void sbgemm_block_kernel_tn_32xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1050. #endif
  1051. {
  1052. bfloat16 * A_addr = A;
  1053. bfloat16 * B_addr = B;
  1054. float * C_addr = C;
  1055. BLASLONG tag_k_32x = k & (~31);
  1056. #ifndef ONE_ALPHA
  1057. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1058. #endif
  1059. __m512i arrayA_512[2];
  1060. __m512i arrayB_512[8];
  1061. __m512 result_512[16];
1062. for (int i = 0; i < 15; i++) { // n < 8 here, so only accumulators 0..6 and 8..14 are ever used; index 15 is never referenced
  1063. result_512[i] = _mm512_setzero_ps();
  1064. }
  1065. for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
  1066. // Load B with unroll n
  1067. for (int i = 0; i < n; i ++) {
  1068. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  1069. B_addr += 32;
  1070. }
  1071. for (BLASLONG idx = 0; idx < 32;) {
1072. // Every two k-rows form one group of 32 BF16 pairs
  1073. arrayA_512[0] = _mm512_loadu_si512(A_addr);
  1074. arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
  1075. A_addr += 64;
  1076. for (int i = 0; i < n; i++) {
  1077. result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1078. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1079. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  1080. }
  1081. idx += 2;
1082. // Every 4 iterations, move on to the next 128-bit lane of the arrayB registers
  1083. if ((idx & (~7)) == idx) {
  1084. for (int i = 0; i < n; i++) {
  1085. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  1086. }
  1087. }
  1088. }
  1089. }
  1090. if (tag_k_32x != k) {
  1091. // Load B with unroll n
  1092. for (int i = 0; i < n; i ++) {
  1093. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  1094. B_addr += 32;
  1095. }
  1096. BLASLONG width = k - tag_k_32x;
  1097. for (BLASLONG idx = 0; idx < width;) {
1098. // Every two k-rows form one group of 32 BF16 pairs
  1099. arrayA_512[0] = _mm512_loadu_si512(A_addr);
  1100. arrayA_512[1] = _mm512_loadu_si512(A_addr + 32);
  1101. A_addr += 64;
  1102. for (int i = 0; i < n; i++) {
  1103. result_512[i] = _mm512_dpbf16_ps(result_512[i] , (__m512bh) arrayA_512[0], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1104. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512[1], (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1105. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  1106. }
  1107. idx += 2;
1108. // Every 4 iterations, move on to the next 128-bit lane of the arrayB registers
  1109. if ((idx & (~7)) == idx) {
  1110. for (int i = 0; i < n; i++) {
  1111. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  1112. }
  1113. }
  1114. }
  1115. }
  1116. if (m != 32) {
  1117. unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
  1118. for (int i = 0; i < n; i++) {
  1119. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1120. STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
  1121. }
  1122. } else {
  1123. for (int i = 0; i < n; i++) {
  1124. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1125. STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
  1126. }
  1127. }
  1128. }
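/*
 * Illustrative sketch (editor's addition): in the kernel above each arrayB_512[i] holds 32
 * consecutive BF16 values of one packed B row, i.e. 16 (k, k+1) pairs.  Each iteration
 * broadcasts the current pair, the dword shuffle steps to the next pair inside the active
 * 128-bit lane, and every 4 iterations the lane shuffle exposes the next 128 bits, so,
 * assuming the SHUFFLE_MAGIC constants rotate as described, the pairs are simply consumed
 * in order.  The scalar walk that the shuffles emulate is sketched below (hypothetical name,
 * float standing in for bfloat16).
 */
#ifndef SBGEMM_SKETCH_B_ROW_WALK
#define SBGEMM_SKETCH_B_ROW_WALK
static inline void sbgemm_sketch_b_row_pair_walk(const float b_row[32], float out_pairs[16][2])
{
    for (int step = 0; step < 16; step++) {
        out_pairs[step][0] = b_row[2*step];       /* k-th value consumed at this step     */
        out_pairs[step][1] = b_row[2*step + 1];   /* (k+1)-th value consumed at this step */
    }
}
#endif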
1129. // SBGEMM Kernel for M<=16, N<8, K=Any number (processed in chunks of 32)
  1130. #ifndef ONE_ALPHA // ALPHA is not ONE
  1131. void sbgemm_block_kernel_tn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1132. #else // ALPHA is ONE
  1133. void sbgemm_block_kernel_tn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1134. #endif
  1135. {
  1136. bfloat16 * A_addr = A;
  1137. bfloat16 * B_addr = B;
  1138. float * C_addr = C;
  1139. BLASLONG tag_k_32x = k & (~31);
  1140. #ifndef ONE_ALPHA
  1141. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1142. #endif
  1143. __m512i arrayA_512;
  1144. __m512i arrayB_512[8];
  1145. __m512 result_512[8];
  1146. for (int i = 0; i < 8; i++) {
  1147. result_512[i] = _mm512_setzero_ps();
  1148. }
  1149. for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
  1150. // Load B with unroll n
  1151. for (int i = 0; i < n; i ++) {
  1152. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  1153. B_addr += 32;
  1154. }
  1155. for (BLASLONG idx = 0; idx < 32;) {
1156. // Every two k-rows form one group of 16 BF16 pairs
  1157. arrayA_512 = _mm512_loadu_si512(A_addr);
  1158. A_addr += 32;
  1159. for (int i = 0; i < n; i++) {
  1160. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1161. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  1162. }
  1163. idx += 2;
1164. // Every 4 iterations, move on to the next 128-bit lane of the arrayB registers
  1165. if ((idx & (~7)) == idx) {
  1166. for (int i = 0; i < n; i++) {
  1167. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  1168. }
  1169. }
  1170. }
  1171. }
  1172. if (tag_k_32x != k) {
  1173. // Load B with unroll n
  1174. for (int i = 0; i < n; i ++) {
  1175. arrayB_512[i] = _mm512_loadu_si512(B_addr);
  1176. B_addr += 32;
  1177. }
  1178. BLASLONG width = k - tag_k_32x;
  1179. for (BLASLONG idx = 0; idx < width;) {
1180. // Every two k-rows form one group of 16 BF16 pairs
  1181. arrayA_512 = _mm512_loadu_si512(A_addr);
  1182. A_addr += 32;
  1183. for (int i = 0; i < n; i++) {
  1184. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i])));
  1185. arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO);
  1186. }
  1187. idx += 2;
1188. // Every 4 iterations, move on to the next 128-bit lane of the arrayB registers
  1189. if ((idx & (~7)) == idx) {
  1190. for (int i = 0; i < n; i++) {
  1191. arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO);
  1192. }
  1193. }
  1194. }
  1195. }
  1196. if (m != 16) {
  1197. unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
  1198. for (int i = 0; i < n; i++) {
  1199. STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
  1200. }
  1201. } else {
  1202. for (int i = 0; i < n; i++) {
  1203. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1204. }
  1205. }
  1206. }
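/*
 * Illustrative sketch (editor's addition): both Nx32 kernels above split K into a body that
 * is a multiple of 32 (tag_k_32x = k & ~31) and a remainder of width k - tag_k_32x that
 * reuses the same inner loop.  Stripped of the vector work, the control flow reduces to the
 * skeleton below (hypothetical name, plain long for BLASLONG).
 */
#ifndef SBGEMM_SKETCH_K_SPLIT
#define SBGEMM_SKETCH_K_SPLIT
static inline void sbgemm_sketch_k_body_and_tail(long k)
{
    long tag_k_32x = k & (~31L);                  /* largest multiple of 32 not above k */
    for (long idx_k = 0; idx_k < tag_k_32x; idx_k += 32) {
        /* full 32-deep block */
    }
    if (tag_k_32x != k) {
        long width = k - tag_k_32x;               /* 1..31 remaining k-values */
        (void)width;                              /* tail handled with the same loop body */
    }
}
#endif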
  1207. #ifndef ONE_ALPHA // ALPHA is not ONE
  1208. void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  1209. #else // ALPHA is ONE
  1210. void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  1211. #endif
  1212. {
  1213. BLASLONG m_step, n_step, k_step, k_step_round32;
  1214. BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
  1215. BLASLONG n_from, n_to;
  1216. BLASLONG tag_n_Nx;
  1217. n_from = 0;
  1218. n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
  1219. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1220. k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
  1221. k_step_round32 = k_step & (~31);
  1222. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1223. if (M >= BF16_BLOCK_THRES_M) {
  1224. while (n_from < N) {
  1225. for (BLASLONG idx_k = 0; idx_k < K;) {
  1226. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1227. COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
  1228. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1229. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1230. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1231. SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); // TODO how to process m
  1232. }
  1233. if (tag_n_Nx != n_to) {
  1234. n_step = n_to - tag_n_Nx;
  1235. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1236. SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1237. }
  1238. for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
  1239. COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
  1240. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1241. SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
  1242. }
  1243. if (tag_n_Nx != n_to) {
  1244. n_step = n_to - tag_n_Nx;
  1245. SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
  1246. }
  1247. }
  1248. if (tag_m_Nx != M) {
  1249. m_step = M - tag_m_Nx;
  1250. if (m_step > 16) {
  1251. COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
  1252. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1253. SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  1254. }
  1255. if (tag_n_Nx != n_to) {
  1256. n_step = n_to - tag_n_Nx;
  1257. SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  1258. }
  1259. } else {
  1260. COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
  1261. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1262. SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  1263. }
  1264. if (tag_n_Nx != n_to) {
  1265. n_step = n_to - tag_n_Nx;
  1266. SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  1267. }
  1268. }
  1269. }
  1270. idx_k += k_step;
  1271. k_step = K - idx_k;
  1272. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1273. k_step_round32 = k_step & (~31);
  1274. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1275. }
  1276. n_from = n_to;
  1277. n_to += BF16_BLOCK_THRES_N;
  1278. n_to = (n_to > N) ? N : n_to;
  1279. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1280. }
  1281. } else {
  1282. m_step = M;
  1283. if (m_step > 16) {
  1284. while (n_from < N) {
  1285. for (BLASLONG idx_k = 0; idx_k < K;) {
  1286. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1287. COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
  1288. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1289. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1290. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1291. SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  1292. }
  1293. if (tag_n_Nx != n_to) {
  1294. n_step = n_to - tag_n_Nx;
  1295. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1296. SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1297. }
  1298. idx_k += k_step;
  1299. k_step = K - idx_k;
  1300. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1301. k_step_round32 = k_step & (~31);
  1302. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1303. }
  1304. n_from = n_to;
  1305. n_to += BF16_BLOCK_THRES_N;
  1306. n_to = (n_to > N) ? N : n_to;
  1307. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1308. }
  1309. } else {
  1310. while (n_from < N) {
  1311. for (BLASLONG idx_k = 0; idx_k < K;) {
  1312. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1313. COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
  1314. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1315. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1316. COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1317. SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  1318. }
  1319. if (tag_n_Nx != n_to) {
  1320. n_step = n_to - tag_n_Nx;
  1321. COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1322. SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1323. }
  1324. idx_k += k_step;
  1325. k_step = K - idx_k;
  1326. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1327. k_step_round32 = k_step & (~31);
  1328. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1329. }
  1330. n_from = n_to;
  1331. n_to += BF16_BLOCK_THRES_N;
  1332. n_to = (n_to > N) ? N : n_to;
  1333. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1334. }
  1335. }
  1336. }
  1337. }
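/*
 * Illustrative sketch (editor's addition): every blocking driver in this file walks N in
 * panels of BF16_BLOCK_THRES_N columns and marks, with tag_n_Nx, the end of the part of the
 * current panel that the full-width (BF16_BLOCK_STEP_N) kernels can cover; the ragged right
 * edge is then handled by the Nx kernels.  With the packing and kernel calls stripped out,
 * and with the block sizes passed in as stand-in parameters (step_n assumed to be a power of
 * two, as it is here), the panel walk is:
 */
#ifndef SBGEMM_SKETCH_N_PANEL_WALK
#define SBGEMM_SKETCH_N_PANEL_WALK
static inline void sbgemm_sketch_n_panel_walk(long n_total, long thres_n, long step_n)
{
    long n_from = 0;
    long n_to = (thres_n > n_total) ? n_total : thres_n;
    long tag_n_Nx = n_to & (~(step_n - 1));       /* end of the full-width body of this panel */
    while (n_from < n_total) {
        /* ... columns [n_from, tag_n_Nx) in steps of step_n, then the tail [tag_n_Nx, n_to) ... */
        n_from = n_to;
        n_to += thres_n;
        n_to = (n_to > n_total) ? n_total : n_to;
        tag_n_Nx = n_to & (~(step_n - 1));
    }
}
#endif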
  1338. /* ----------------------------------------- End of TN kernels --------------------------------------- */
  1339. /* --------------------------------------------- TT kernels ------------------------------------------ */
  1340. // SBGEMM Kernel for 16<M<=32, N<8, K can be any number
  1341. #ifndef ONE_ALPHA // ALPHA is not ONE
  1342. void sbgemm_block_kernel_tt_32xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1343. #else // ALPHA is ONE
  1344. void sbgemm_block_kernel_tt_32xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1345. #endif
  1346. {
  1347. bfloat16 * A_addr = A;
  1348. bfloat16 * B_addr = B;
  1349. float * C_addr = C;
  1350. #ifndef ONE_ALPHA
  1351. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1352. #endif
  1353. __m512i arrayA_512_0, arrayA_512_1;
  1354. __m512i arrayB_512[8];
  1355. __m512 result_512[16];
  1356. result_512[0] = _mm512_setzero_ps();
  1357. result_512[1] = _mm512_setzero_ps();
  1358. result_512[2] = _mm512_setzero_ps();
  1359. result_512[3] = _mm512_setzero_ps();
  1360. result_512[4] = _mm512_setzero_ps();
  1361. result_512[5] = _mm512_setzero_ps();
  1362. result_512[6] = _mm512_setzero_ps();
  1363. result_512[7] = _mm512_setzero_ps();
  1364. result_512[8] = _mm512_setzero_ps();
  1365. result_512[9] = _mm512_setzero_ps();
  1366. result_512[10] = _mm512_setzero_ps();
  1367. result_512[11] = _mm512_setzero_ps();
  1368. result_512[12] = _mm512_setzero_ps();
  1369. result_512[13] = _mm512_setzero_ps();
  1370. result_512[14] = _mm512_setzero_ps();
  1371. result_512[15] = _mm512_setzero_ps();
  1372. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
1373. // Every two k-rows form one group of 32 BF16 pairs
  1374. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  1375. arrayA_512_1 = _mm512_loadu_si512(A_addr + 32);
  1376. A_addr += 64;
  1377. for (int i = 0; i < n; i ++) {
  1378. _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
  1379. }
  1380. B_addr += 16;
  1381. for (int i = 0; i < n; i ++) {
  1382. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
  1383. result_512[i+8] = _mm512_dpbf16_ps(result_512[i+8], (__m512bh) arrayA_512_1, (__m512bh) arrayB_512[i]);
  1384. }
  1385. }
  1386. if (m != 32) {
  1387. unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m));
  1388. for (int i = 0; i < n; i ++) {
  1389. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1390. STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask)
  1391. }
  1392. } else {
  1393. for (int i = 0; i < n; i ++) {
  1394. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1395. STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16))
  1396. }
  1397. }
  1398. }
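/*
 * Illustrative sketch (editor's addition): for 16 < m <= 32 the 32-wide kernels above always
 * store the low 16 m-lanes unmasked and mask only the upper half; 0xffff >> (32 - m) leaves
 * exactly m - 16 low bits set, matching lanes 16..m-1 of the block.  A hypothetical helper:
 */
#ifndef SBGEMM_SKETCH_UPPER_TAIL_MASK
#define SBGEMM_SKETCH_UPPER_TAIL_MASK
static inline unsigned short sbgemm_sketch_upper_half_tail_mask(int m)
{
    /* valid for 17 <= m <= 32; e.g. m == 20 yields 0x000f (four lanes: m-indices 16..19) */
    return (unsigned short)(0xffffU >> (32 - m));
}
#endif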
  1399. // SBGEMM Kernel for M<=16, N<8, K can be any number
  1400. #ifndef ONE_ALPHA // ALPHA is not ONE
  1401. void sbgemm_block_kernel_tt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1402. #else // ALPHA is ONE
  1403. void sbgemm_block_kernel_tt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc)
  1404. #endif
  1405. {
  1406. bfloat16 * A_addr = A;
  1407. bfloat16 * B_addr = B;
  1408. float * C_addr = C;
  1409. #ifndef ONE_ALPHA
  1410. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1411. #endif
  1412. __m512i arrayA_512_0;
  1413. __m512i arrayB_512[8];
  1414. __m512 result_512[8];
  1415. result_512[0] = _mm512_setzero_ps();
  1416. result_512[1] = _mm512_setzero_ps();
  1417. result_512[2] = _mm512_setzero_ps();
  1418. result_512[3] = _mm512_setzero_ps();
  1419. result_512[4] = _mm512_setzero_ps();
  1420. result_512[5] = _mm512_setzero_ps();
  1421. result_512[6] = _mm512_setzero_ps();
  1422. result_512[7] = _mm512_setzero_ps();
  1423. for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) {
1424. // Every two k-rows form one group of 16 BF16 pairs
1425. // Load both rows into a single 512-bit register
  1426. arrayA_512_0 = _mm512_loadu_si512(A_addr);
  1427. A_addr += 32;
  1428. for (int i = 0; i < n; i ++) {
  1429. _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]);
  1430. }
  1431. B_addr += 16;
  1432. for (int i = 0; i < n; i ++) {
  1433. result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]);
  1434. }
  1435. }
  1436. if (m != 16) {
  1437. unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m));
  1438. for (int i = 0; i < n; i++) {
  1439. STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask)
  1440. }
  1441. } else {
  1442. for (int i = 0; i < n; i++) {
  1443. STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i))
  1444. }
  1445. }
  1446. }
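/*
 * Illustrative sketch (editor's addition): the STORE16_COMPLETE_RESULT /
 * STORE16_MASK_COMPLETE_RESULT macros are defined earlier in this file, outside this excerpt.
 * Assuming they scale the accumulators by alpha unless ONE_ALPHA is defined before writing
 * 16 lanes to C (whether they also combine with the previous contents of C is determined by
 * those macro definitions, not shown here), a scalar model of the alpha handling is:
 */
#ifndef SBGEMM_SKETCH_STORE16_MODEL
#define SBGEMM_SKETCH_STORE16_MODEL
static inline void sbgemm_sketch_store16_model(float *c, const float result[16], float alpha, int alpha_is_one)
{
    for (int i = 0; i < 16; i++) {
        /* under the stated assumption: write result as-is (alpha == 1) or alpha * result */
        c[i] = alpha_is_one ? result[i] : alpha * result[i];
    }
}
#endif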
  1447. #ifndef ONE_ALPHA // ALPHA is not ONE
  1448. void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  1449. #else // ALPHA is ONE
  1450. void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B)
  1451. #endif
  1452. {
  1453. BLASLONG m_step, n_step, k_step, k_step_round32;
  1454. BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1));
  1455. BLASLONG n_from, n_to;
  1456. BLASLONG tag_n_Nx;
  1457. n_from = 0;
  1458. n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N;
  1459. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1460. k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K;
  1461. k_step_round32 = k_step & (~31);
  1462. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1463. if (M >= BF16_BLOCK_THRES_M) {
  1464. while (n_from < N) {
  1465. for (BLASLONG idx_k = 0; idx_k < K;) {
  1466. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1467. COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A);
  1468. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1469. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1470. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1471. SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  1472. }
  1473. if (tag_n_Nx != n_to) {
  1474. n_step = n_to - tag_n_Nx;
  1475. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1476. SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1477. }
  1478. for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) {
  1479. COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A);
  1480. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1481. SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc);
  1482. }
  1483. if (tag_n_Nx != n_to) {
  1484. n_step = n_to - tag_n_Nx;
  1485. SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc);
  1486. }
  1487. }
  1488. if (tag_m_Nx != M) {
  1489. m_step = M - tag_m_Nx;
  1490. if (m_step > 16) {
  1491. COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
  1492. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1493. SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  1494. }
  1495. if (tag_n_Nx != n_to) {
  1496. n_step = n_to - tag_n_Nx;
  1497. SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  1498. }
  1499. } else {
  1500. COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A);
  1501. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1502. SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc);
  1503. }
  1504. if (tag_n_Nx != n_to) {
  1505. n_step = n_to - tag_n_Nx;
  1506. SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc);
  1507. }
  1508. }
  1509. }
  1510. idx_k += k_step;
  1511. k_step = K - idx_k;
  1512. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1513. k_step_round32 = k_step & (~31);
  1514. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1515. }
  1516. n_from = n_to;
  1517. n_to += BF16_BLOCK_THRES_N;
  1518. n_to = (n_to > N) ? N : n_to;
  1519. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1520. }
  1521. } else {
  1522. m_step = M;
  1523. if (m_step > 16) {
  1524. while (n_from < N) {
  1525. for (BLASLONG idx_k = 0; idx_k < K;) {
  1526. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1527. COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A);
  1528. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1529. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1530. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1531. SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  1532. }
  1533. if (tag_n_Nx != n_to) {
  1534. n_step = n_to - tag_n_Nx;
  1535. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1536. SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1537. }
  1538. idx_k += k_step;
  1539. k_step = K - idx_k;
  1540. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1541. k_step_round32 = k_step & (~31);
  1542. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1543. }
  1544. n_from = n_to;
  1545. n_to += BF16_BLOCK_THRES_N;
  1546. n_to = (n_to > N) ? N : n_to;
  1547. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1548. }
  1549. } else {
  1550. while (n_from < N) {
  1551. for (BLASLONG idx_k = 0; idx_k < K;) {
  1552. // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ...
  1553. COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A);
  1554. for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) {
  1555. // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ...
  1556. COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32);
  1557. SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc);
  1558. }
  1559. if (tag_n_Nx != n_to) {
  1560. n_step = n_to - tag_n_Nx;
  1561. COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32);
  1562. SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc);
  1563. }
  1564. idx_k += k_step;
  1565. k_step = K - idx_k;
  1566. k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step;
  1567. k_step_round32 = k_step & (~31);
  1568. k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32;
  1569. }
  1570. n_from = n_to;
  1571. n_to += BF16_BLOCK_THRES_N;
  1572. n_to = (n_to > N) ? N : n_to;
  1573. tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1));
  1574. }
  1575. }
  1576. }
  1577. }
  1578. /* ----------------------------------------- End of TT kernels --------------------------------------- */
  1579. /*
  1580. #ifndef ONE_ALPHA // ALPHA is not ONE
  1581. void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
  1582. OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
  1583. #else // ALPHA is ONE
  1584. void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
  1585. OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc)
  1586. #endif
  1587. {
  1588. if (Order == CblasColMajor) {
  1589. if (TransA == CblasNoTrans) {
  1590. if (TransB == CblasNoTrans) {
  1591. SBGEMM_BLOCKING_KERNEL_NN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
  1592. } else if (TransB == CblasTrans) {
  1593. SBGEMM_BLOCKING_KERNEL_NT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
  1594. }
  1595. } else {
  1596. if (TransB == CblasNoTrans) {
  1597. SBGEMM_BLOCKING_KERNEL_TN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
  1598. } else if (TransB == CblasTrans) {
  1599. SBGEMM_BLOCKING_KERNEL_TT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B);
  1600. }
  1601. }
  1602. } else {
  1603. if (TransA == CblasNoTrans) {
  1604. if (TransB == CblasNoTrans) {
  1605. SBGEMM_BLOCKING_KERNEL_NN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
  1606. } else if (TransB == CblasTrans) {
  1607. SBGEMM_BLOCKING_KERNEL_TN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
  1608. }
  1609. } else {
  1610. if (TransB == CblasNoTrans) {
  1611. SBGEMM_BLOCKING_KERNEL_NT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
  1612. } else if (TransB == CblasTrans) {
  1613. SBGEMM_BLOCKING_KERNEL_TT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B);
  1614. }
  1615. }
  1616. }
  1617. }
  1618. */
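/*
 * Illustrative sketch (editor's addition): the commented-out dispatcher above forwards
 * row-major calls to the column-major blocking kernels with A and B (and M and N) exchanged,
 * because a row-major C is the column-major storage of C^T and (op(A) * op(B))^T =
 * op(B)^T * op(A)^T.  The hypothetical scalar check below computes a small product both ways;
 * names and sizes are invented for illustration only.
 */
#ifndef SBGEMM_SKETCH_ROWMAJOR_SWAP
#define SBGEMM_SKETCH_ROWMAJOR_SWAP
static inline int sbgemm_sketch_rowmajor_swap_identity(void)
{
    enum { SM = 2, SN = 3, SK = 4 };
    float Amat[SM][SK], Bmat[SK][SN], Cdirect[SM][SN], Cswapped[SM][SN];
    for (int i = 0; i < SM; i++) for (int p = 0; p < SK; p++) Amat[i][p] = (float)(i + p);
    for (int p = 0; p < SK; p++) for (int j = 0; j < SN; j++) Bmat[p][j] = (float)(p - j);
    /* direct row-major product: C[i][j] = sum_p A[i][p] * B[p][j] */
    for (int i = 0; i < SM; i++)
        for (int j = 0; j < SN; j++) {
            Cdirect[i][j] = 0.0f;
            for (int p = 0; p < SK; p++) Cdirect[i][j] += Amat[i][p] * Bmat[p][j];
        }
    /* swapped form: compute C^T[j][i] = sum_p B^T[j][p] * A^T[p][i], i.e. B becomes the
       first operand and A the second, exactly as in the dispatcher's row-major branch */
    for (int j = 0; j < SN; j++)
        for (int i = 0; i < SM; i++) {
            float acc = 0.0f;
            for (int p = 0; p < SK; p++) acc += Bmat[p][j] * Amat[i][p];
            Cswapped[i][j] = acc;
        }
    int same = 1;
    for (int i = 0; i < SM; i++) for (int j = 0; j < SN; j++) same &= (Cdirect[i][j] == Cswapped[i][j]);
    return same; /* always 1: both orderings produce the same C */
}
#endif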