
sbgemv_t_microk_cooperlake_template.c 161 kB

/***************************************************************************
Copyright (c) 2014, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
#include <immintrin.h>
#include "common.h"
// Include common macros for BF16 based operations with IA intrinsics
#include "bf16_common_macros.h"

#ifndef ZERO_BETA    // Beta is non-zero

#ifndef ONE_BETA     // BETA is not ONE

#define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_ALPHA_BETA
#define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_COMPLETE_RESULT        STORE8_COMPLETE_RESULT_ALPHA_BETA
#define STORE8_MASK_COMPLETE_RESULT   STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_COMPLETE_RESULT        STORE4_COMPLETE_RESULT_ALPHA_BETA
#define STORE4_MASK_COMPLETE_RESULT   STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA

#else                // BETA is ONE

#define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_ALPHA_ONE
#define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_COMPLETE_RESULT        STORE8_COMPLETE_RESULT_ALPHA_ONE
#define STORE8_MASK_COMPLETE_RESULT   STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_COMPLETE_RESULT        STORE4_COMPLETE_RESULT_ALPHA_ONE
#define STORE4_MASK_COMPLETE_RESULT   STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE

#endif

#else                // BETA is zero

#ifndef ONE_ALPHA    // ALPHA is not ONE

#define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_ALPHA
#define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_ALPHA
#define STORE8_COMPLETE_RESULT        STORE8_COMPLETE_RESULT_ALPHA
#define STORE8_MASK_COMPLETE_RESULT   STORE8_MASK_COMPLETE_RESULT_ALPHA
#define STORE4_COMPLETE_RESULT        STORE4_COMPLETE_RESULT_ALPHA
#define STORE4_MASK_COMPLETE_RESULT   STORE4_MASK_COMPLETE_RESULT_ALPHA

#else                // ALPHA is ONE

#define STORE16_COMPLETE_RESULT       STORE16_COMPLETE_RESULT_DIRECT
#define STORE16_MASK_COMPLETE_RESULT  STORE16_MASK_COMPLETE_RESULT_DIRECT
#define STORE8_COMPLETE_RESULT        STORE8_COMPLETE_RESULT_DIRECT
#define STORE8_MASK_COMPLETE_RESULT   STORE8_MASK_COMPLETE_RESULT_DIRECT
#define STORE4_COMPLETE_RESULT        STORE4_COMPLETE_RESULT_DIRECT
#define STORE4_MASK_COMPLETE_RESULT   STORE4_MASK_COMPLETE_RESULT_DIRECT

#endif
#endif
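
/*
 * Note on the macro block above: every kernel in this template finishes by
 * writing its float accumulators through STORE{16,8,4}[_MASK]_COMPLETE_RESULT.
 * The #if ladder binds those names to one of four variants defined in
 * bf16_common_macros.h, so the same kernel bodies below get compiled once per
 * alpha/beta case (judging by the variant names, roughly):
 *   beta != 0, beta != 1  ->  *_ALPHA_BETA  (y = alpha*acc + beta*y)
 *   beta == 1             ->  *_ALPHA_ONE   (y = alpha*acc + y)
 *   beta == 0, alpha != 1 ->  *_ALPHA       (y = alpha*acc)
 *   beta == 0, alpha == 1 ->  *_DIRECT      (y = acc)
 */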
// 32 rows parallel processing BF16 GEMV kernel for n=1 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x1_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x1_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x1_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x1(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);

    __m512i matrixArray_0, matrixArray_1, matrixArray_2;
    __m512i xArray;
    __m512  result_0, result_1;

#ifndef ONE_ALPHA
    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
#ifndef ONE_BETA
    __m512  BETAVECTOR  = _mm512_set1_ps(beta);
#endif
#endif

    __m512i load_idx_lo = _mm512_set_epi16(0, 15, 0, 14, 0, 13, 0, 12, 0, 11, 0, 10, 0, 9, 0, 8,\
                                           0,  7, 0,  6, 0,  5, 0,  4, 0,  3, 0,  2, 0, 1, 0, 0);
    __m512i M512_EPI16_16 = _mm512_set1_epi16(16);
    __m512i load_idx_hi   = _mm512_add_epi16(load_idx_lo, M512_EPI16_16);

    unsigned int interleve_mask_value = ((unsigned int) 0x55555555);
    __mmask32 interleave_mask = *((__mmask32*) &interleve_mask_value);

    xArray = _mm512_set1_epi16((short) x[0]);
    xArray = _mm512_mask_blend_epi16(interleave_mask, _mm512_setzero_si512(), xArray);
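
    /*
     * Layout note (n=1): _mm512_dpbf16_ps accumulates dot products of adjacent
     * bf16 pairs, so x[0] is kept only in the even 16-bit lanes here (the
     * 0x55555555 blend zeroes the odd lanes).  The load_idx_lo/load_idx_hi
     * permutes below place one matrix row in each even lane, so every 32-bit
     * pair contributes exactly a[row]*x[0] to its float lane of the result.
     */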
    if (tag_m_32x > 0) {
        for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)]);                         // Load 32 rows with n=1
            matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0);    // Expand the low 16 elements
            matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0);    // Expand the high 16 elements

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
        }
    }

    BLASLONG tail_num = m - tag_m_32x;
    if (tail_num > 16) {
        result_0 = _mm512_setzero_ps();
        result_1 = _mm512_setzero_ps();

        unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-tail_num));
        __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);

        matrixArray_0 = _mm512_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);        // Load 32 rows with n=1
        matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0);        // Expand the low 16 elements
        matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0);        // Expand the high 16 elements

        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
        result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);

        unsigned short store_mask_value = (((unsigned short)0xffff) >> (32-tail_num));
        __mmask16 store_mask = *((__mmask16*) &store_mask_value);

        STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
        STORE16_MASK_COMPLETE_RESULT(result_1, y+tag_m_32x+16, store_mask)
    } else if (tail_num > 8) {
        __m256 result256_0 = _mm256_setzero_ps();
        __m256 result256_1 = _mm256_setzero_ps();

        __m256i load_idx_lo256 = _mm512_castsi512_si256(load_idx_lo);
        __m256i load_idx_hi256 = _mm512_extracti32x8_epi32(load_idx_lo, 0x1);
        __m256i xArray256 = _mm512_castsi512_si256(xArray);

        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num));
        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

        __m256i matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);        // Load 16 rows with n=1
        __m256i matrixArray256_1 = _mm256_permutexvar_epi16(load_idx_lo256, matrixArray256_0);  // Expand the low 8 elements
        __m256i matrixArray256_2 = _mm256_permutexvar_epi16(load_idx_hi256, matrixArray256_0);  // Expand the high 8 elements

        result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_1, (__m256bh) xArray256);
        result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_2, (__m256bh) xArray256);

        unsigned char store_mask_value = (((unsigned char)0xff) >> (16-tail_num));
        __mmask8 store_mask = *((__mmask8*) &store_mask_value);

        STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
        STORE8_MASK_COMPLETE_RESULT(result256_1, y+tag_m_32x+8, store_mask)
    } else {
        __m128 result128_0 = _mm_setzero_ps();
        __m128 result128_1 = _mm_setzero_ps();

        __m128i load_idx_lo128 = _mm_set_epi16(0, 3, 0, 2, 0, 1, 0, 0);
        __m128i M128_EPI16_4 = _mm_set1_epi16(4);
        __m128i load_idx_hi128 = _mm_add_epi16(load_idx_lo128, M128_EPI16_4);
        __m128i xArray128 = _mm512_castsi512_si128(xArray);

        unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
        __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);

        __m128i matrixArray128_0 = _mm_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]);           // Load 8 rows with n=1
        __m128i matrixArray128_1 = _mm_permutexvar_epi16(load_idx_lo128, matrixArray128_0);     // Expand the low 4 elements
        __m128i matrixArray128_2 = _mm_permutexvar_epi16(load_idx_hi128, matrixArray128_0);     // Expand the high 4 elements

        result128_0 = _mm_dpbf16_ps(result128_0, (__m128bh) matrixArray128_1, (__m128bh) xArray128);
        result128_1 = _mm_dpbf16_ps(result128_1, (__m128bh) matrixArray128_2, (__m128bh) xArray128);

        if (tail_num > 4) {
            unsigned char store_mask_value = (((unsigned char)0xf) >> (8-tail_num));
            __mmask8 store_mask = *((__mmask8*) &store_mask_value);
            STORE4_COMPLETE_RESULT(result128_0, y+tag_m_32x)
            STORE4_MASK_COMPLETE_RESULT(result128_1, y+tag_m_32x+4, store_mask)
        } else {
            unsigned char store_mask_value = (((unsigned char)0xf) >> (4-tail_num));
            __mmask8 store_mask = *((__mmask8*) &store_mask_value);
            STORE4_MASK_COMPLETE_RESULT(result128_0, y+tag_m_32x, store_mask)
        }
    }

    return 0;
}
// 32 rows parallel processing BF16 GEMV kernel for n=2 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x2_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x2_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x2_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);

    __m512i matrixArray_0, matrixArray_1;
    __m512i xArray;
    __m512  result_0, result_1;

#ifndef ONE_ALPHA
    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
    __m512  BETAVECTOR  = _mm512_set1_ps(beta);
#endif

    unsigned char load_mask_value = (((unsigned char)0xff) >> 6);
    __mmask8 load_mask = *((__mmask8*) &load_mask_value);
    xArray = _mm512_broadcastd_epi32(_mm_maskz_loadu_epi16(load_mask, x));
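
    /*
     * With n=2 each row is exactly one bf16 pair (32 bits), so no shuffling is
     * needed: x0|x1 is broadcast to all sixteen pairs and a single dpbf16 over a
     * plain 512-bit row load produces one float result per row.
     */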
    if (tag_m_32x > 0) {
        for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*2]);       // Load 16 rows as n=2
            matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+16)*2]);    // Load 16 rows as n=2

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
        }
    }

    if (m - tag_m_32x >= 16) {
        result_0 = _mm512_setzero_ps();

        matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_32x)*2]);       // Load 16 rows with n=2
        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);

        STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)

        tag_m_32x += 16;
    }

    BLASLONG tail_num = m - tag_m_32x;
    if (tail_num > 8) {
        result_0 = _mm512_setzero_ps();

        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(m&15)));
        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

        matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]);    // Load 16 rows with n=2
        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);

        STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_32x, tail_mask)
    } else if (tail_num == 8) {
        __m256 result256 = _mm256_setzero_ps();

        __m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]);            // Load 8 rows with n=2
        __m256i xArray256 = _mm512_castsi512_si256(xArray);
        result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);

        STORE8_COMPLETE_RESULT(result256, y+tag_m_32x)
    } else {
        __m256 result256 = _mm256_setzero_ps();

        unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-(m&7)));
        __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);

        __m256i matrixArray256 = _mm256_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]);    // Load 8 rows with n=2
        __m256i xArray256 = _mm512_castsi512_si256(xArray);
        result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);

        STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_32x, tail_mask)
    }

    return 0;
}
// 32 rows parallel processing BF16 GEMV kernel for n=3 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_32x3_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_32x3_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_32x3_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_32x = m & (~31);

    __m512 result_0, result_1;

#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
    __m512 BETAVECTOR  = _mm512_set1_ps(beta);
#endif

    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i xTmp = _mm_maskz_loadu_epi16(x_load_mask, x);                        // x0|x1|x2|0|0|0|0|0|
    __m512i xArray_0 = _mm512_broadcastd_epi32(xTmp);                            // x0|x1|x0|x1|...|x0|x1|
    __m512i xArray_1 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1));    // x2| 0|x2| 0|...|x2| 0|

    __m512i load_idx_base;
    __m512i M512_EPI16_2, M512_EPI16_8, M512_EPI16_16;
    M512_EPI16_2  = _mm512_set1_epi16(2);
    M512_EPI16_8  = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
    M512_EPI16_8  = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
    M512_EPI16_16 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
    load_idx_base = _mm512_set_epi16(46, 45, 43, 42, 40, 39, 37, 36, 34, 33, 31, 30, 28, 27, 25, 24,
                                     22, 21, 19, 18, 16, 15, 13, 12, 10,  9,  7,  6,  4,  3,  1,  0);
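
    /*
     * With n=3 the rows are not pair-aligned, so each group of 32 rows (96 bf16
     * values) is read as three consecutive 512-bit loads and repacked with
     * permutex2var: load_idx_base gathers the x0|x1 pair of every row, and the
     * indices offset by 2 gather the x2 element (its pair partner is a stray
     * value that gets multiplied by the zero lane of xArray_1).
     */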
    if (tag_m_32x > 0) {
        __m512i load_idx01_1st, load_idx01_2nd, load_idx2_1st, load_idx2_2nd;
        __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6;

        unsigned int idx_blend_mask_value = ((unsigned int)0x80000000);
        __mmask32 idx_blend_mask = *((__mmask32*) &idx_blend_mask_value);

        load_idx01_1st = load_idx_base;
        load_idx01_2nd = _mm512_add_epi16(load_idx01_1st, M512_EPI16_16);
        load_idx2_1st  = _mm512_add_epi16(load_idx01_1st, M512_EPI16_2);
        load_idx2_2nd  = _mm512_add_epi16(load_idx01_2nd, M512_EPI16_2);
        load_idx2_2nd  = _mm512_mask_blend_epi16(idx_blend_mask, load_idx2_2nd, _mm512_setzero_si512());

        for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*3]);             // Load 10 rows with n=3 plus 2 element
            matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+10)*3 + 2)]);    // Load 10 rows with n=3 plus 2 element
            matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+21)*3 + 1)]);    // Load 10 rows with n=3 plus 2 element

            matrixArray_3 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_1st, matrixArray_1);  // Select the first 2 elements for each row
            matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_1, load_idx01_2nd, matrixArray_2);  // Select the first 2 elements for each row
            matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_0, load_idx2_1st,  matrixArray_1);  // Select the third element for each row
            matrixArray_6 = _mm512_permutex2var_epi16(matrixArray_1, load_idx2_2nd,  matrixArray_2);  // Select the third element for each row

            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_3, (__m512bh) xArray_0);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_5, (__m512bh) xArray_1);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_4, (__m512bh) xArray_0);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_6, (__m512bh) xArray_1);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
        }
    }

    if (tag_m_32x != m) {
        __m256i load256_idx01_1st, load256_idx01_2nd, load256_idx2_1st, load256_idx2_2nd;
        __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6;
        __m256  result256_0, result256_1;

        unsigned short idx256_blend_mask_value = ((unsigned short)0x8000);
        __mmask16 idx256_blend_mask = *((__mmask16*) &idx256_blend_mask_value);

        load256_idx01_1st = _mm512_castsi512_si256(load_idx_base);
        load256_idx01_2nd = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_8));
        load256_idx2_1st  = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_2));
        load256_idx2_2nd  = _mm256_add_epi16(load256_idx01_2nd, _mm512_castsi512_si256(M512_EPI16_2));
        load256_idx2_2nd  = _mm256_mask_blend_epi16(idx256_blend_mask, load256_idx2_2nd, _mm256_setzero_si256());

        if (m - tag_m_32x > 15) {
            result256_0 = _mm256_setzero_ps();
            result256_1 = _mm256_setzero_ps();

            matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]);             // Load 5 rows with n=3 plus 1 element
            matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]);     // Load 5 rows with n=3 plus 1 element
            matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]);    // Load 5 rows with n=3 plus 1 element

            matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);  // Select the first 2 elements for each row
            matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);  // Select the first 2 elements for each row
            matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st,  matrixArray256_1);  // Select the third element for each row
            matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd,  matrixArray256_2);  // Select the third element for each row

            result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
            result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
            result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
            result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));

            STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
            STORE8_COMPLETE_RESULT(result256_1, y+tag_m_32x+8)

            tag_m_32x += 16;
        }

        if (tag_m_32x != m) {
            result256_0 = _mm256_setzero_ps();
            result256_1 = _mm256_setzero_ps();
            BLASLONG tail_num = m-tag_m_32x;

            if (tail_num > 10) {
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

                matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]);                            // Load 5 rows with n=3 plus 1 element
                matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]);                    // Load 5 rows with n=3 plus 1 element
                matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]);  // Load m-tag_m_32x-10 rows

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);  // Select the first 2 elements for each row
                matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);  // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st,  matrixArray256_1);  // Select the third element for each row
                matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd,  matrixArray256_2);  // Select the third element for each row

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
            } else if (tail_num > 5) {
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

                matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]);                          // Load 5 rows with n=3 plus 1 element
                matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]);   // Load m-tag_m_32x-5 rows
                matrixArray256_2 = _mm256_setzero_si256();

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);  // Select the first 2 elements for each row
                matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2);  // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st,  matrixArray256_1);  // Select the third element for each row
                matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd,  matrixArray256_2);  // Select the third element for each row

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
            } else {
                unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num*3)));
                __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

                matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)*3]);    // Load m-tag_m_32x rows
                matrixArray256_1 = _mm256_setzero_si256();

                matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1);  // Select the first 2 elements for each row
                matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st,  matrixArray256_1);  // Select the third element for each row

                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
                result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
            }

            unsigned short store_tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num)));
            __mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value);
            __m512 result512 = _mm512_insertf32x8(_mm512_castps256_ps512(result256_0), result256_1, 0x1);
            STORE16_MASK_COMPLETE_RESULT(result512, y+tag_m_32x, store_tail_mask)
        }
    }

    return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=4 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x4_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x4_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x4_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_16x = m & (~15);

    __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
    __m512i xArray_01, xArray_23, xArray_remix;
    __m512  result;

#ifndef ONE_ALPHA
    __m512  ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
    __m512  BETAVECTOR  = _mm512_set1_ps(beta);
#endif

    __m512i M512_EPI32_1   = _mm512_set1_epi32(1);
    __m512i idx_base_0     = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i idx_base_1     = _mm512_add_epi32(idx_base_0, M512_EPI32_1);
    __m512i idx_base_remix = _mm512_inserti32x8(idx_base_0, _mm512_castsi512_si256(idx_base_1), 0x1);

    unsigned char x_load_mask_value = (((unsigned char)0xf) >> 2);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i xTmp = _mm_maskz_loadu_epi32(x_load_mask, x);                     // |x0|x1|x2|x3|0|0|0|0|
    xArray_01 = _mm512_broadcastd_epi32(xTmp);                                // |x0|x1|x0|x1|...|x0|x1|
    xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1));        // |x2|x3|x2|x3|...|x2|x3|

    unsigned short blend_mask_value = ((unsigned short)0xff00);
    __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
    xArray_remix = _mm512_mask_blend_epi32(blend_mask, xArray_01, xArray_23); // |x0|x1|x0|x1|x0|x1|x0|x1|...|x2|x3|x2|x3|x2|x3|x2|x3|
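
    /*
     * With n=4 each row is two bf16 pairs.  idx_base_0/idx_base_1 deinterleave
     * two 512-bit row loads into a "pair 0" (x0|x1) and a "pair 1" (x2|x3)
     * register, which are dotted against the broadcast x vectors;
     * idx_base_remix/xArray_remix handle the 8-row tail the same way within a
     * single register, with the two 256-bit halves added afterwards.
     */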
    if (tag_m_16x > 0) {
        for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
            result = _mm512_setzero_ps();

            matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*4]);      // Load 8 rows with n=4
            matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+8)*4]);    // Load 8 rows with n=4

            matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_0, matrixArray_1);  // |a0|a1|...|h0|h1|i0|i1|...|p0|p1|
            matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_1, matrixArray_1);  // |a2|a3|...|h2|h3|i2|i3|...|p2|p3|

            result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_01);
            result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_3, (__m512bh) xArray_23);

            STORE16_COMPLETE_RESULT(result, y+idx_m)
        }
    }

    if (m - tag_m_16x > 7) {
        result = _mm512_setzero_ps();

        matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*4]);                      // Load 8 rows with n=4
        matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0);    // a0|a1|...|h0|h1|a2|a3|...|h2|h3|

        result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
        __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));

        STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)

        tag_m_16x += 8;
    }

    BLASLONG tail_num = m-tag_m_16x;
    if (tail_num != 0) {
        result = _mm512_setzero_ps();

        unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num*2));
        __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);

        matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_16x)*4]);     // Load 8 rows with n=4
        matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0);    // a0|a1|...|h0|h1|a2|a3|...|h2|h3|

        result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
        __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));

        unsigned char store_tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
        __mmask8 store_tail_mask = *((__mmask8*) &store_tail_mask_value);
        STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, store_tail_mask)
    }

    return 0;
}
// 30 rows parallel processing BF16 GEMV kernel for n=5 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_30x5_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_30x5_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_30x5_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_30x = m - (m%30);

    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 3);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x);     // x0|x1|x2|x3|x4|0|0|0|

#ifndef ONE_ALPHA
    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
#endif
#ifndef ZERO_BETA
    __m512 BETAVECTOR  = _mm512_set1_ps(beta);
#endif

    __m512  result_0, result_1;
    __m512i xArray_01 = _mm512_broadcastd_epi32(x128);                             // x0|x1|x0|x1|...|x0|x1|
    __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1));     // x2|x3|x2|x3|...|x2|x3|
    __m512i xArray_4  = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2));     // x4| 0|x4| 0|...|x4| 0|
    __m512i M512_EPI16_2 = _mm512_set1_epi16(2);

    __m512i load_idx01_stage1_1st = _mm512_set_epi16( 0,  0,  0,  0,  0,  0,  0,  0, 58, 57, 53, 52, 48, 47, 43, 42,
                                                     38, 37, 33, 32, 26, 25, 21, 20, 16, 15, 11, 10,  6,  5,  1,  0);
    __m512i load_idx01_stage1_2nd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x39);
    __m512i load_idx01_stage1_3rd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x4f);

    __m512i load_idx23_stage1_1st = _mm512_add_epi16(load_idx01_stage1_1st, M512_EPI16_2);
    __m512i load_idx23_stage1_2nd = _mm512_add_epi16(load_idx01_stage1_2nd, M512_EPI16_2);
    __m512i load_idx23_stage1_3rd = _mm512_add_epi16(load_idx01_stage1_3rd, M512_EPI16_2);

    __m512i load_idx4_stage1_1st  = _mm512_add_epi16(load_idx23_stage1_1st, M512_EPI16_2);
    __m512i load_idx4_stage1_2nd  = _mm512_add_epi16(load_idx23_stage1_2nd, M512_EPI16_2);
    __m512i load_idx4_stage1_3rd  = _mm512_add_epi16(load_idx23_stage1_3rd, M512_EPI16_2);

    __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
    __m512i matrixArray_stage1_0, matrixArray_stage1_1, matrixArray_stage1_2;
    __m512i matrixArray_stage2_0, matrixArray_stage2_1;

    unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
    __mmask32 load_mask = *((__mmask32*) &load_mask_value);
    unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
    __mmask16 store_mask = *((__mmask16*) &store_mask_value);
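
    /*
     * With n=5, 30 rows (150 bf16 values) are handled per iteration from five
     * masked loads of 6 rows each.  A two-stage shuffle (permutex2var/permutexvar,
     * then a dword blend) regroups the x0|x1, x2|x3 and x4|pad pairs of all 30
     * rows into two registers per pair, which are accumulated against the
     * broadcast x pairs; store_mask keeps only the low 14 lanes of result_1.
     */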
    if (tag_m_30x > 0) {
        unsigned short blend_mask_value_0 = ((unsigned short)0xf000);
        __mmask16 blend_mask_0 = *((__mmask16*) &blend_mask_value_0);
        unsigned short blend_mask_value_1 = ((unsigned short)0x3f00);
        __mmask16 blend_mask_1 = *((__mmask16*) &blend_mask_value_1);

        for (BLASLONG idx_m = 0; idx_m < tag_m_30x; idx_m+=30) {
            result_0 = _mm512_setzero_ps();
            result_1 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]);         // Load 6 rows with n=5
            matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]);     // Load 6 rows with n=5
            matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+12)*5)]);    // Load 6 rows with n=5
            matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+18)*5)]);    // Load 6 rows with n=5
            matrixArray_4 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+24)*5)]);    // Load 6 rows with n=5

            // Process the 0|1 elements
            // Stage 1: Select the 0|1 elements for each row
            matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
            matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx01_stage1_2nd, matrixArray_3);
            matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx01_stage1_3rd, matrixArray_4);
            // Stage 2: Reorder and compress all the 0|1 elements
            matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
            matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
            // Calculate the result of the 0|1 elements
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_01);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_01);

            // Process the 2|3 elements
            // Stage 1: Select the 2|3 elements for each row
            matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
            matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx23_stage1_2nd, matrixArray_3);
            matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx23_stage1_3rd, matrixArray_4);
            // Stage 2: Reorder and compress all the 2|3 elements
            matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
            matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
            // Calculate the result of the 2|3 elements and accumulate the result of the 0|1 elements
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_23);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_23);

            // Process the 4 elements
            // Stage 1: Select the 4 elements for each row
            matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
            matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx4_stage1_2nd, matrixArray_3);
            matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx4_stage1_3rd, matrixArray_4);
            // Stage 2: Reorder and compress all the 4 elements
            matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
            matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
            // Calculate the result of the 4 elements and accumulate the result of the 0|1 and 2|3 elements
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_4);
            result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_4);

            STORE16_COMPLETE_RESULT(result_0, y+idx_m)
            STORE16_MASK_COMPLETE_RESULT(result_1, y+idx_m+16, store_mask)
        }
    }

    if (m - tag_m_30x > 11) {
        BLASLONG tag_m_12x = m - ((m-tag_m_30x)%12);
        for (BLASLONG idx_m = tag_m_30x; idx_m < tag_m_12x; idx_m+=12) {
            unsigned short store_less_mask_value = (((unsigned short)0xffff) >> 4);
            __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
            result_0 = _mm512_setzero_ps();

            matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]);         // Load 6 rows with n=5
            matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]);     // Load 6 rows with n=5

            // Interleave the elements
            matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
            matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
            matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st,  matrixArray_1);

            // Calculate and accumulate the result
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
            result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);

            STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_less_mask)

            tag_m_30x += 12;
        }
    }

    BLASLONG tail_num = m - tag_m_30x;
    if (tail_num > 6) {
        unsigned short store_less_mask_value = (((unsigned short)0xffff) >> (4+(12-tail_num)));
        __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
        unsigned int load_less_mask_value = (((unsigned int)0xffffffff) >> (2+(12-tail_num)*5));
        __mmask32 load_less_mask = *((__mmask32*) &load_less_mask_value);
        result_0 = _mm512_setzero_ps();

        matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask,      &a[(tag_m_30x)*5]);        // Load 6 rows with n=5
        matrixArray_1 = _mm512_maskz_loadu_epi16(load_less_mask, &a[((tag_m_30x+6)*5)]);    // Load x rows with n=5

        // Interleave the elements
        matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
        matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
        matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st,  matrixArray_1);

        // Calculate and accumulate the result
        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
        result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);

        STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_30x, store_less_mask)
    } else {
        __m128i matrixArray128;
        __m128  result128, tmp128;
        for (BLASLONG i = tag_m_30x; i < m; i++) {
            result128 = _mm_setzero_ps();
            matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*5]);    // Load 1 row with n=5
            result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
            tmp128 = _mm_shuffle_ps(result128, result128, 14);
            result128 = _mm_add_ps(result128, tmp128);
            tmp128 = _mm_shuffle_ps(result128, result128, 1);
            result128 = _mm_add_ps(result128, tmp128);
#ifndef ZERO_BETA
#ifndef ONE_BETA
            y[i] = alpha * result128[0] + beta * y[i];
#else
            y[i] = alpha * result128[0] + y[i];
#endif
#else
#ifndef ONE_ALPHA
            y[i] = result128[0] * alpha;
#else
            y[i] = result128[0];
#endif
#endif
        }
    }

    return 0;
}
// 16 rows parallel processing BF16 GEMV kernel for n=6 && lda ineffective scenario
#ifndef ZERO_BETA
#ifndef ONE_BETA
static int sbgemv_kernel_16x6_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#else
static int sbgemv_kernel_16x6_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
#endif
#else
#ifndef ONE_ALPHA
static int sbgemv_kernel_16x6_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#else
static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
#endif
#endif
{
    BLASLONG tag_m_16x = m & (~15);

    unsigned char x_load_mask_value = (((unsigned char)0xff) >> 2);
    __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
    __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x);    // x0|x1|x2|x3|x4|x5|0|0|
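
    /*
     * With n=6 each row is three bf16 pairs.  The dword index tables set up below
     * gather the (x0|x1), (x2|x3) and (x4|x5) pairs of 16 rows from three
     * consecutive 512-bit loads (permutex2var over the first two loads, then a
     * masked permutexvar to patch in the rows held by the third), and three
     * dpbf16 accumulations yield one float per row.
     */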
  641. if (tag_m_16x > 0) {
  642. __m512 result_0;
  643. #ifndef ONE_ALPHA
  644. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  645. #endif
  646. #ifndef ZERO_BETA
  647. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  648. #endif
  649. __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
  650. __m512i load_idx01_1st = _mm512_set_epi32( 0, 0, 0, 0, 0, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0);
  651. __m512i load_idx01_2nd = _mm512_set_epi32(13, 10, 7, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
  652. __m512i load_idx23_1st = _mm512_add_epi32(load_idx01_1st, M512_EPI32_1);
  653. __m512i load_idx23_2nd = _mm512_add_epi32(load_idx01_2nd, M512_EPI32_1);
  654. __m512i load_idx45_1st = _mm512_add_epi32(load_idx23_1st, M512_EPI32_1);
  655. __m512i load_idx45_2nd = _mm512_add_epi32(load_idx23_2nd, M512_EPI32_1);
  656. unsigned short blend_mask_value = ((unsigned short)0x0400);
  657. __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
  658. // Set the 11th element to be 0 as invalid index for a 512 bit epi32 register
  659. load_idx45_1st = _mm512_mask_blend_epi32(blend_mask, load_idx45_1st, load_idx01_2nd);
  660. // Set the 11th element to be 0 as 0 is the correct index
  661. load_idx45_2nd = _mm512_mask_blend_epi32(blend_mask, load_idx45_2nd, load_idx01_2nd);
  662. __m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
  663. __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
  664. __m512i xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4|x5|x4|x5|...|x4|x5|
  665. unsigned short permute_mask01_uint = (((unsigned short)0xf800));
  666. __mmask16 permute_mask01 = *((__mmask16*) &permute_mask01_uint);
  667. unsigned short permute_mask45_uint = (((unsigned short)0xfc00));
  668. __mmask16 permute_mask45 = *((__mmask16*) &permute_mask45_uint);
  669. __m512i matrixArray_0, matrixArray_1, matrixArray_2;
  670. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2;
  671. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  672. result_0 = _mm512_setzero_ps();
  673. matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*6]); // Load 5 rows with n=6 plus 2 element
  674. matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+5)*6 + 2)]); // Load 5 rows with n=6 plus 2 element
  675. matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+10)*6 + 4)]); // Load 5 rows with n=6 plus 2 element
  676. // Stage 1: interleave for the a..k elements
  677. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
  678. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
  679. matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
  680. // Stage 2: interleave for the l..p elements and remix together
  681. matrixArray_stage_0 = _mm512_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
  682. matrixArray_stage_1 = _mm512_mask_permutexvar_epi32(matrixArray_stage_1, permute_mask01, load_idx23_2nd, matrixArray_2);
  683. matrixArray_stage_2 = _mm512_mask_permutexvar_epi32(matrixArray_stage_2, permute_mask45, load_idx45_2nd, matrixArray_2);
  684. // Calculate the result of the 0|1 elements
  685. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_01);
  686. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_23);
  687. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_45);
  688. STORE16_COMPLETE_RESULT(result_0, y+idx_m)
  689. }
  690. if (m - tag_m_16x > 7) {
  691. __m256i M256_EPI32_1 = _mm512_castsi512_si256(M512_EPI32_1);
  692. __m256i load_idx01_1st = _mm256_set_epi32( 0, 0, 15, 12, 9, 6, 3, 0);
  693. __m256i load_idx01_2nd = _mm256_set_epi32( 5, 2, 0, 0, 0, 0, 0, 0);
  694. __m256i load_idx23_1st = _mm256_add_epi32(load_idx01_1st, M256_EPI32_1);
  695. __m256i load_idx23_2nd = _mm256_add_epi32(load_idx01_2nd, M256_EPI32_1);
  696. unsigned char blend_mask_value = ((unsigned char)0x20);
  697. __mmask8 blend_mask = *((__mmask8*) &blend_mask_value);
698. // Set the 6th element to 0: the incremented index would be out of range for the 256-bit epi32 permute
699. load_idx23_1st = _mm256_mask_blend_epi32(blend_mask, load_idx23_1st, load_idx01_2nd);
700. // Set the 6th element to 0, since 0 is the index actually needed here
701. load_idx23_2nd = _mm256_mask_blend_epi32(blend_mask, load_idx23_2nd, load_idx01_2nd);
  702. __m256i load_idx45_1st = _mm256_add_epi32(load_idx23_1st, M256_EPI32_1);
  703. __m256i load_idx45_2nd = _mm256_add_epi32(load_idx23_2nd, M256_EPI32_1);
  704. unsigned char permute_mask01_uint = (((unsigned char)0xc0));
  705. __mmask8 permute_mask01 = *((__mmask8*) &permute_mask01_uint);
  706. unsigned char permute_mask45_uint = (((unsigned char)0xe0));
  707. __mmask8 permute_mask45 = *((__mmask8*) &permute_mask45_uint);
  708. __m256i matrixArray_0, matrixArray_1, matrixArray_2;
  709. __m256i matrixArray_stage_0;
  710. __m256 result256_0;
  711. result256_0 = _mm256_setzero_ps();
712. matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 elements
713. matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 elements
714. matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 elements
  715. // Process the 0|1 elements
  716. // Select the 0|1 elements for each row
  717. matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
  718. matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
  719. // Calculate the result of the 0|1 elements
  720. result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_01));
  721. // Process the 2|3 elements
  722. // Select the 2|3 elements for each row
  723. matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
  724. matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx23_2nd, matrixArray_2);
725. // Calculate the result of the 2|3 elements
726. result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_23));
727. // Process the 4|5 elements
728. // Select the 4|5 elements for each row
729. matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
730. matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx45_2nd, matrixArray_2);
731. // Calculate the result of the 4|5 elements
  732. result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_45));
  733. STORE8_COMPLETE_RESULT(result256_0, y+tag_m_16x)
  734. tag_m_16x += 8;
  735. }
  736. }
  737. if (tag_m_16x != m) {
  738. __m128i matrixArray128;
  739. __m128 result128, tmp128;
  740. for (BLASLONG i = tag_m_16x; i < m; i++) {
  741. result128 = _mm_setzero_ps();
742. matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*6]); // Load 1 row with n=6
  743. result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
  744. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  745. result128 = _mm_add_ps(result128, tmp128);
  746. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  747. result128 = _mm_add_ps(result128, tmp128);
  748. #ifndef ZERO_BETA
  749. #ifndef ONE_BETA
  750. y[i] = alpha * result128[0] + beta * y[i];
  751. #else
  752. y[i] = alpha * result128[0] + y[i];
  753. #endif
  754. #else
  755. #ifndef ONE_ALPHA
  756. y[i] = result128[0] * alpha;
  757. #else
  758. y[i] = result128[0];
  759. #endif
  760. #endif
  761. }
  762. }
  763. return 0;
  764. }
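/*
 * For reference, each fixed-n kernel in this family computes, per row i of the row-major
 * matrix a (with lda collapsed to n), y[i] = alpha * dot(a[i*n .. i*n+n-1], x), plus beta * y[i]
 * in the beta builds. The sketch below only illustrates that contract and is not part of the
 * optimized path: the guard macro and helper names are ours, and it assumes (as in OpenBLAS)
 * that bfloat16 is a 16-bit container holding the upper half of an IEEE fp32.
 */
#ifdef SBGEMV_T_REFERENCE_SKETCH
#include <stdint.h>
#include <string.h>
static float bf16_to_fp32_sketch(bfloat16 v)
{
    uint32_t bits = ((uint32_t)v) << 16;   // bf16 is the truncated upper half of an fp32
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}
static void sbgemv_t_reference_sketch(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
{
    for (BLASLONG i = 0; i < m; i++) {
        float acc = 0.0f;
        for (BLASLONG j = 0; j < n; j++) {
            acc += bf16_to_fp32_sketch(a[i*n + j]) * bf16_to_fp32_sketch(x[j]);
        }
        y[i] = alpha * acc + beta * y[i];   // beta == 0 / alpha == 1 builds drop the corresponding factor
    }
}
#endif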
  765. // 16 rows parallel processing BF16 GEMV kernel for n=7 && lda ineffective scenario
  766. #ifndef ZERO_BETA
  767. #ifndef ONE_BETA
  768. static int sbgemv_kernel_16x7_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  769. #else
  770. static int sbgemv_kernel_16x7_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  771. #endif
  772. #else
  773. #ifndef ONE_ALPHA
  774. static int sbgemv_kernel_16x7_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  775. #else
  776. static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  777. #endif
  778. #endif
  779. {
  780. BLASLONG tag_m_16x = m & (~15);
  781. unsigned char x_load_mask_value = (((unsigned char)0xff) >> 1);
  782. __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
  783. __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|0|
  784. if (tag_m_16x > 0) {
  785. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
  786. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
  787. __m512i xArray_0123, xArray_4567;
  788. __m512 result_0, result_1, result_2, result_3;
  789. #ifndef ONE_ALPHA
  790. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  791. #endif
  792. #ifndef ZERO_BETA
  793. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  794. #endif
  795. __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
  796. __m512i load_idx_stage1_0 = _mm512_set_epi16(31, 27, 26, 25, 24, 23, 22, 21, 31, 20, 19, 18, 17, 16, 15, 14,
  797. 31, 13, 12, 11, 10, 9, 8, 7, 31, 6, 5, 4, 3, 2, 1, 0);
  798. __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
  799. __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
  800. unsigned short x_blend_mask_value = ((unsigned short)0xff00);
  801. __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
  802. xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
  803. _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
  804. xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
  805. _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
  806. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
  807. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
  808. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  809. result_0 = _mm512_setzero_ps();
  810. result_1 = _mm512_setzero_ps();
  811. matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*7]); // Load 4 rows with n=7
  812. matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+4)*7]); // Load 4 rows with n=7
  813. matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+8)*7]); // Load 4 rows with n=7
  814. matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+12)*7]); // Load 4 rows with n=7
  815. // Stage 1: padding
  816. matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
  817. matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
  818. matrixArray_2 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_2); // |i0|i1|i2|i3|...|j6|j7|k0|k1|k2|k3|...|l6|l7|
  819. matrixArray_3 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_3); // |m0|m1|m2|m3|...|n6|n7|o0|o1|o2|o3|...|p6|p7|
  820. // Stage 2: interleave per 32 bits
  821. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
  822. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
  823. matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
  824. matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
  825. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  826. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
  827. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  828. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
  829. // Stage 3: interleave per 256 bits
  830. result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
  831. result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
  832. result_2 = _mm512_add_ps(result_2, result_3);
  833. STORE16_COMPLETE_RESULT(result_2, y+idx_m)
  834. }
  835. if (m - tag_m_16x > 7) {
  836. result_0 = _mm512_setzero_ps();
  837. matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
  838. matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x+4)*7]); // Load 4 rows with n=7
  839. // Stage 1: padding
  840. matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
  841. matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
  842. // Stage 2: interleave per 32 bits
  843. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
  844. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
  845. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  846. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  847. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
  848. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  849. tag_m_16x += 8;
  850. }
  851. BLASLONG tail_num = m - tag_m_16x;
  852. if (tail_num > 3) {
  853. result_0 = _mm512_setzero_ps();
  854. matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
  855. unsigned int tail_load_mask_value = (((unsigned int)0xffffffff) >> (4+(8-tail_num)*7));
  856. __mmask32 tail_load_mask = *((__mmask32*) &tail_load_mask_value);
857. matrixArray_1 = _mm512_maskz_loadu_epi16(tail_load_mask, &a[(tag_m_16x+4)*7]); // Load the remaining tail_num-4 rows with n=7
  858. // Stage 1: padding
  859. matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
  860. matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
  861. // Stage 2: interleave per 32 bits
  862. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
  863. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
  864. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  865. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  866. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
  867. unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
  868. __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
  869. STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
  870. tag_m_16x = m;
  871. }
  872. }
  873. if (tag_m_16x != m) {
  874. __m128i matrixArray128;
  875. __m128 result128, tmp128;
  876. for (BLASLONG i = tag_m_16x; i < m; i++) {
  877. result128 = _mm_setzero_ps();
878. matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*7]); // Load 1 row with n=7
  879. result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
  880. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  881. result128 = _mm_add_ps(result128, tmp128);
  882. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  883. result128 = _mm_add_ps(result128, tmp128);
  884. #ifndef ZERO_BETA
  885. #ifndef ONE_BETA
  886. y[i] = alpha * result128[0] + beta * y[i];
  887. #else
  888. y[i] = alpha * result128[0] + y[i];
  889. #endif
  890. #else
  891. #ifndef ONE_ALPHA
  892. y[i] = result128[0] * alpha;
  893. #else
  894. y[i] = result128[0];
  895. #endif
  896. #endif
  897. }
  898. }
  899. return 0;
  900. }
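/*
 * The kernels here build their load masks by shifting an all-ones constant and type-punning
 * the result through a pointer cast, e.g. 0xffffffff >> 4 for the 4-row, n=7 epi16 loads above.
 * The helper below is only a restatement of that arithmetic (rows*n valid bf16 lanes per load)
 * using the mask-conversion intrinsic; the helper name and guard macro are ours.
 */
#ifdef SBGEMV_T_MASK_SKETCH
static inline __mmask32 bf16_rows_load_mask_sketch(int rows, int n)
{
    unsigned int lanes = (unsigned int)(rows * n);                       // number of valid bf16 elements
    unsigned int bits  = (lanes >= 32) ? 0xffffffffu : ((1u << lanes) - 1u);
    return _cvtu32_mask32(bits);                                         // e.g. rows=4, n=7 -> 0x0fffffff
}
#endif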
  901. // 16 rows parallel processing BF16 GEMV kernel for n=8 && lda ineffective scenario
  902. #ifndef ZERO_BETA
  903. #ifndef ONE_BETA
  904. static int sbgemv_kernel_16x8_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  905. #else
  906. static int sbgemv_kernel_16x8_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  907. #endif
  908. #else
  909. #ifndef ONE_ALPHA
  910. static int sbgemv_kernel_16x8_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  911. #else
  912. static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  913. #endif
  914. #endif
  915. {
  916. BLASLONG tag_m_16x = m & (~15);
  917. __m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
  918. if (tag_m_16x > 0) {
  919. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
  920. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
  921. __m512i xArray_0123, xArray_4567;
  922. __m512 result_0, result_1, result_2, result_3;
  923. #ifndef ONE_ALPHA
  924. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  925. #endif
  926. #ifndef ZERO_BETA
  927. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  928. #endif
  929. __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
  930. __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
  931. __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
  932. unsigned short x_blend_mask_value = ((unsigned short)0xff00);
  933. __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
  934. xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
  935. _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
  936. xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
  937. _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
  938. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  939. result_0 = _mm512_setzero_ps();
  940. result_1 = _mm512_setzero_ps();
  941. matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*8]); // Load 4 rows with n=8
  942. matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+4)*8]); // Load 4 rows with n=8
  943. matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+8)*8]); // Load 4 rows with n=8
  944. matrixArray_3 = _mm512_loadu_si512(&a[(idx_m+12)*8]); // Load 4 rows with n=8
  945. // Stage 1: interleave per 32 bits
  946. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
  947. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
  948. matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
  949. matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
  950. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  951. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
  952. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  953. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
  954. // Stage 2: interleave per 256 bits
  955. result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
  956. result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
  957. result_2 = _mm512_add_ps(result_2, result_3);
  958. STORE16_COMPLETE_RESULT(result_2, y+idx_m)
  959. }
  960. if (m - tag_m_16x > 7) {
  961. result_0 = _mm512_setzero_ps();
  962. matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
  963. matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
  964. // Stage 1: interleave per 32 bits
  965. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
  966. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
  967. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  968. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  969. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
  970. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  971. tag_m_16x += 8;
  972. }
  973. BLASLONG tail_num = m - tag_m_16x;
  974. if (tail_num > 3) {
  975. result_0 = _mm512_setzero_ps();
  976. matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
  977. unsigned short tail_load_mask_value = (((unsigned int)0xffff) >> ((8-tail_num)*4));
  978. __mmask16 tail_load_mask = *((__mmask16*) &tail_load_mask_value);
979. matrixArray_1 = _mm512_maskz_loadu_epi32(tail_load_mask, &a[(tag_m_16x+4)*8]); // Load the remaining tail_num-4 rows with n=8
  980. // Stage 1: interleave per 32 bits
  981. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
  982. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
  983. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
  984. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
  985. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
  986. unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
  987. __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
  988. STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
  989. tag_m_16x = m;
  990. }
  991. }
  992. if (tag_m_16x != m) {
  993. __m128i matrixArray128;
  994. __m128 result128, tmp128;
  995. for (BLASLONG i = tag_m_16x; i < m; i++) {
  996. result128 = _mm_setzero_ps();
997. matrixArray128 = _mm_loadu_si128(&a[(i)*8]); // Load 1 row with n=8
  998. result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
  999. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  1000. result128 = _mm_add_ps(result128, tmp128);
  1001. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  1002. result128 = _mm_add_ps(result128, tmp128);
  1003. #ifndef ZERO_BETA
  1004. #ifndef ONE_BETA
  1005. y[i] = alpha * result128[0] + beta * y[i];
  1006. #else
  1007. y[i] = alpha * result128[0] + y[i];
  1008. #endif
  1009. #else
  1010. #ifndef ONE_ALPHA
  1011. y[i] = result128[0] * alpha;
  1012. #else
  1013. y[i] = result128[0];
  1014. #endif
  1015. #endif
  1016. }
  1017. }
  1018. return 0;
  1019. }
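/*
 * _mm512_dpbf16_ps treats each 32-bit lane of its bf16 operands as a pair of bf16 values and
 * accumulates their dot product into the matching fp32 lane of the accumulator, which is why
 * the x vector is prepared as repeated |x_even|x_odd| pairs. The scalar model of one lane below
 * is purely illustrative (it ignores the instruction's internal rounding details) and reuses the
 * bf16_to_fp32_sketch helper from the reference sketch after the n=6 kernel, under the same
 * guard macro of our choosing.
 */
#ifdef SBGEMV_T_REFERENCE_SKETCH
static inline float dpbf16_lane_sketch(float acc, const bfloat16 a_pair[2], const bfloat16 x_pair[2])
{
    // acc += a0*x0 + a1*x1, with every bf16 value widened to fp32 first
    return acc + bf16_to_fp32_sketch(a_pair[0]) * bf16_to_fp32_sketch(x_pair[0])
               + bf16_to_fp32_sketch(a_pair[1]) * bf16_to_fp32_sketch(x_pair[1]);
}
#endif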
  1020. // 14 rows parallel processing BF16 GEMV kernel for n=9 && lda ineffective scenario
  1021. #ifndef ZERO_BETA
  1022. #ifndef ONE_BETA
  1023. static int sbgemv_kernel_14x9_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1024. #else
  1025. static int sbgemv_kernel_14x9_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1026. #endif
  1027. #else
  1028. #ifndef ONE_ALPHA
  1029. static int sbgemv_kernel_14x9_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1030. #else
  1031. static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1032. #endif
  1033. #endif
  1034. {
  1035. BLASLONG tag_m_14x = m - (m%14);
  1036. unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7);
  1037. __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
  1038. __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
  1039. __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0|
  1040. if (tag_m_14x > 0) {
  1041. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
  1042. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
  1043. __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
  1044. __m512 result_0, result_1;
  1045. #ifndef ONE_ALPHA
  1046. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1047. #endif
  1048. #ifndef ZERO_BETA
  1049. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1050. #endif
  1051. __m256i M256_EPI16_2 = _mm256_set1_epi16(2);
  1052. __m256i idx_base_0 = _mm256_set_epi16( 0, 0, 55, 54, 46, 45, 37, 36, 28, 27, 19, 18, 10, 9, 1, 0);
  1053. __m256i idx_base_1 = _mm256_add_epi16(idx_base_0, M256_EPI16_2);
  1054. __m256i idx_base_2 = _mm256_add_epi16(idx_base_1, M256_EPI16_2);
  1055. __m256i idx_base_3 = _mm256_add_epi16(idx_base_2, M256_EPI16_2);
  1056. __m256i idx_base_4 = _mm256_add_epi16(idx_base_3, M256_EPI16_2);
  1057. __m512i idx_idx = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 6, 5, 4, 3, 2, 1, 0);
  1058. __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
  1059. __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
  1060. __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
  1061. __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
  1062. __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
  1063. __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 13, 12, 11, 10, 9, 8, 7);
  1064. xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
  1065. xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
  1066. xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
  1067. xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
  1068. xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|0 |x8| 0| ... |x8| 0|
  1069. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 1);
  1070. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
  1071. unsigned short blend_mask_value = ((unsigned short)0x3f80);
  1072. __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
  1073. unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
  1074. __mmask16 store_mask = *((__mmask16*) &store_mask_value);
  1075. for (BLASLONG idx_m = 0; idx_m < tag_m_14x; idx_m+=14) {
  1076. result_0 = _mm512_setzero_ps();
  1077. result_1 = _mm512_setzero_ps();
  1078. matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*9]); // Load 3 rows with n=9 plus 5 elements
  1079. matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+3)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
  1080. matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+7)*9]); // Load 3 rows with n=9 plus 5 elements
  1081. matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
  1082. // Stage 1: interleave per 16 bits
  1083. matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|g0|g1|a2|a3|...|g2|g3|x|x|x|x|
  1084. matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|g4|g5|a6|a7|...|g6|g7|x|x|x|x|
  1085. matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |h2|h3|...|n2|n3|h0|h1|...|n0|n1|x|x|x|x|
  1086. matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |h6|h7|...|n6|n7|h4|h5|...|n4|n5|x|x|x|x|
  1087. matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8| x|...|g8| x| x| x|...| x| x|x|x|x|x|
  1088. matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|h8| x|...|n8| x|x|x|x|x|
  1089. // Stage 2: interleave per 32 bits
  1090. matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|b0|b1|...|h0|h1|i0|i1|j0|j1|...|n0|n1|x|x|x|x|
  1091. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|b2|b3|...|h2|h3|i2|i3|j2|j3|...|n2|n3|x|x|x|x|
  1092. matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|b4|b5|...|h4|h5|i4|i5|j4|j5|...|n4|n5|x|x|x|x|
  1093. matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|b6|b7|...|h6|h7|i6|i7|j6|j7|...|n6|n7|x|x|x|x|
  1094. matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_4, matrixArray_5); // |a8| x|b8| x|...|h8| x|i8| x|j8| x|...|n8| x|x|x|x|x|
  1095. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
  1096. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
  1097. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
  1098. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
  1099. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
  1100. result_0 = _mm512_add_ps(result_0, result_1);
  1101. STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
  1102. }
  1103. }
  1104. if (tag_m_14x != m) {
  1105. __m256i matrixArray256;
  1106. __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
  1107. __m256 result256;
  1108. __m128 result128, tmp128;
  1109. unsigned short load256_mask_value = (((unsigned short)0xffff) >> 7);
  1110. __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
  1111. for (BLASLONG i = tag_m_14x; i < m; i++) {
  1112. result256 = _mm256_setzero_ps();
  1113. matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*9]);
  1114. result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
  1115. result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
  1116. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  1117. result128 = _mm_add_ps(result128, tmp128);
  1118. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  1119. result128 = _mm_add_ps(result128, tmp128);
  1120. #ifndef ZERO_BETA
  1121. #ifndef ONE_BETA
  1122. y[i] = alpha * result128[0] + beta * y[i];
  1123. #else
  1124. y[i] = alpha * result128[0] + y[i];
  1125. #endif
  1126. #else
  1127. #ifndef ONE_ALPHA
  1128. y[i] = result128[0] * alpha;
  1129. #else
  1130. y[i] = result128[0];
  1131. #endif
  1132. #endif
  1133. }
  1134. }
  1135. return 0;
  1136. }
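/*
 * The scalar tails above reduce a __m128 of four partial sums with two shuffle/add steps:
 * immediate 14 (_MM_SHUFFLE(0,0,3,2)) folds the upper pair onto the lower pair, and immediate 1
 * then folds lane 1 onto lane 0, so result128[0] ends up holding v0+v1+v2+v3. The helper below is
 * only an equivalent restatement of that reduction, with a name and guard macro of our choosing.
 */
#ifdef SBGEMV_T_HSUM_SKETCH
static inline float hsum_ps128_sketch(__m128 v)
{
    __m128 t = _mm_add_ps(v, _mm_movehl_ps(v, v));    // lanes 0,1 = v0+v2, v1+v3
    t = _mm_add_ss(t, _mm_shuffle_ps(t, t, 0x1));     // lane 0 = v0+v1+v2+v3
    return _mm_cvtss_f32(t);
}
#endif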
  1137. // 12 rows parallel processing BF16 GEMV kernel for n=10 && lda ineffective scenario
  1138. #ifndef ZERO_BETA
  1139. #ifndef ONE_BETA
  1140. static int sbgemv_kernel_12x10_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1141. #else
  1142. static int sbgemv_kernel_12x10_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1143. #endif
  1144. #else
  1145. #ifndef ONE_ALPHA
  1146. static int sbgemv_kernel_12x10_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1147. #else
  1148. static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1149. #endif
  1150. #endif
  1151. {
  1152. BLASLONG tag_m_12x = m - (m%12);
  1153. unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3);
  1154. __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
  1155. __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7|
  1156. __m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0|
  1157. if (tag_m_12x > 0) {
  1158. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
  1159. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
  1160. __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
  1161. __m512 result_0, result_1;
  1162. #ifndef ONE_ALPHA
  1163. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1164. #endif
  1165. #ifndef ZERO_BETA
  1166. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1167. #endif
  1168. __m256i M256_EPI32_1 = _mm256_set1_epi32(1);
  1169. __m256i idx_base_0 = _mm256_set_epi32( 0, 0, 26, 21, 16, 10, 5, 0);
  1170. __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_1);
  1171. __m256i idx_base_2 = _mm256_add_epi32(idx_base_1, M256_EPI32_1);
  1172. __m256i idx_base_3 = _mm256_add_epi32(idx_base_2, M256_EPI32_1);
  1173. __m256i idx_base_4 = _mm256_add_epi32(idx_base_3, M256_EPI32_1);
  1174. __m512i idx_idx = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 5, 4, 3, 2, 1, 0);
  1175. __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
  1176. __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
  1177. __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
  1178. __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
  1179. __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
  1180. __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 11, 10, 9, 8, 7, 6);
  1181. xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
  1182. xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
  1183. xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
  1184. xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
  1185. xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|x9|x8|x9| ... |x8|x9|
  1186. unsigned short blend_mask_value = ((unsigned short)0x0fc0);
  1187. __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
  1188. unsigned short load_mask_value = (((unsigned short)0xffff) >> 1);
  1189. __mmask16 load_mask = *((__mmask16*) &load_mask_value);
  1190. unsigned short store_mask_value = (((unsigned short)0xffff) >> 4);
  1191. __mmask16 store_mask = *((__mmask16*) &store_mask_value);
  1192. for (BLASLONG idx_m = 0; idx_m < tag_m_12x; idx_m+=12) {
  1193. result_0 = _mm512_setzero_ps();
  1194. result_1 = _mm512_setzero_ps();
  1195. matrixArray_0 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m)*10]); // Load 3 rows with n=10
  1196. matrixArray_1 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+3)*10]); // Load 3 rows with n=10
  1197. matrixArray_2 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+6)*10]); // Load 3 rows with n=10
  1198. matrixArray_3 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+9)*10]); // Load 3 rows with n=10
  1199. // Stage 1: interleave per 32 bits
  1200. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|f0|f1|a2|a3|...|f2|f3|x|x|x|x|x|x|x|x|
  1201. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|f4|f5|a6|a7|...|f6|f7|x|x|x|x|x|x|x|x|
  1202. matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |g2|g3|...|l2|l3|g0|g1|...|l0|l1|x|x|x|x|x|x|x|x|
  1203. matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |g6|g7|...|l6|l7|g4|g5|...|l4|l5|x|x|x|x|x|x|x|x|
  1204. matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8|a9|...|f8|f9| x| x|...| x| x|x|x|x|x|x|x|x|x|
  1205. matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|g8|g9|...|l8|l9|x|x|x|x|x|x|x|x|
1206. // Stage 2: interleave per 32 bits
  1207. matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|l0|l1|x|x|x|x|x|x|x|x|
  1208. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|...|l2|l3|x|x|x|x|x|x|x|x|
  1209. matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|...|l4|l5|x|x|x|x|x|x|x|x|
  1210. matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|...|l6|l7|x|x|x|x|x|x|x|x|
  1211. matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_4, matrixArray_stage_5); // |a8|a9|...|l8|l9|x|x|x|x|x|x|x|x|
  1212. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
  1213. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
  1214. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
  1215. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
  1216. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
  1217. result_0 = _mm512_add_ps(result_0, result_1);
  1218. STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
  1219. }
  1220. }
  1221. if (tag_m_12x != m) {
  1222. __m256i matrixArray256;
  1223. __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
  1224. __m256 result256;
  1225. __m128 result128, tmp128;
  1226. unsigned char load256_mask_value = (((unsigned char)0xff) >> 3);
  1227. __mmask8 load256_mask = *((__mmask8*) &load256_mask_value);
  1228. for (BLASLONG i = tag_m_12x; i < m; i++) {
  1229. result256 = _mm256_setzero_ps();
  1230. matrixArray256 = _mm256_maskz_loadu_epi32(load256_mask, &a[(i)*10]);
  1231. result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
  1232. result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
  1233. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  1234. result128 = _mm_add_ps(result128, tmp128);
  1235. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  1236. result128 = _mm_add_ps(result128, tmp128);
  1237. #ifndef ZERO_BETA
  1238. #ifndef ONE_BETA
  1239. y[i] = alpha * result128[0] + beta * y[i];
  1240. #else
  1241. y[i] = alpha * result128[0] + y[i];
  1242. #endif
  1243. #else
  1244. #ifndef ONE_ALPHA
  1245. y[i] = result128[0] * alpha;
  1246. #else
  1247. y[i] = result128[0];
  1248. #endif
  1249. #endif
  1250. }
  1251. }
  1252. return 0;
  1253. }
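/*
 * Block-count bookkeeping: the 16-row kernels round m down with m & (~15), while the 12-, 14- and
 * 15-row kernels use m - (m % block). Both expressions round m down to a whole number of row
 * blocks; the one-line restatement below uses a helper name of our choosing.
 */
#ifdef SBGEMV_T_BLOCKS_SKETCH
static inline BLASLONG rounded_down_rows_sketch(BLASLONG m, BLASLONG block)
{
    return m - (m % block);   // equals m & ~(block - 1) whenever block is a power of two
}
#endif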
  1254. // 15 rows parallel processing BF16 GEMV kernel for n=11 && lda ineffective scenario
  1255. #ifndef ZERO_BETA
  1256. #ifndef ONE_BETA
  1257. static int sbgemv_kernel_15x11_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1258. #else
  1259. static int sbgemv_kernel_15x11_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1260. #endif
  1261. #else
  1262. #ifndef ONE_ALPHA
  1263. static int sbgemv_kernel_15x11_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1264. #else
  1265. static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1266. #endif
  1267. #endif
  1268. {
  1269. BLASLONG tag_m_15x = m - (m%15);
  1270. unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
  1271. __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
  1272. __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7|
  1273. __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0|
  1274. if (tag_m_15x > 0) {
  1275. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
  1276. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
  1277. __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
  1278. __m512 result_0, result_1;
  1279. #ifndef ONE_ALPHA
  1280. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1281. #endif
  1282. #ifndef ZERO_BETA
  1283. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1284. #endif
  1285. __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
  1286. __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
  1287. __m512i M512_EPI16_2, M512_EPI16_4, M512_EPI16_6, M512_EPI32_5;
  1288. M512_EPI16_2 = _mm512_set1_epi16(2);
  1289. M512_EPI16_4 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
  1290. M512_EPI16_6 = _mm512_add_epi16(M512_EPI16_4, M512_EPI16_2);
  1291. M512_EPI32_5 = _mm512_set1_epi32(5);
  1292. unsigned int BASE_MASK_10_value = ((unsigned int)0x000003ff);
  1293. __mmask32 BASE_MASK_10 = *((__mmask32*) &BASE_MASK_10_value);
  1294. unsigned int BASE_MASK_20_value = ((unsigned int)0x000ffc00);
  1295. __mmask32 BASE_MASK_20 = *((__mmask32*) &BASE_MASK_20_value);
  1296. unsigned int BASE_MASK_30_value = ((unsigned int)0x3ff00000);
  1297. __mmask32 BASE_MASK_30 = *((__mmask32*) &BASE_MASK_30_value);
  1298. idx_stage1_base_0 = _mm512_set_epi16( 0, 0, 49, 48, 38, 37, 27, 26, 16, 15, 5, 4, 47, 46, 36, 35,
  1299. 25, 24, 14, 13, 3, 2, 45, 44, 34, 33, 23, 22, 12, 11, 1, 0);
  1300. idx_stage1_base_1 = _mm512_add_epi16(idx_stage1_base_0, M512_EPI16_6);
  1301. idx_stage1_base_2 = _mm512_mask_add_epi16(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI16_2);
  1302. idx_stage1_base_2 = _mm512_mask_sub_epi16(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI16_2);
  1303. idx_stage1_base_3 = _mm512_add_epi16(idx_stage1_base_2, M512_EPI16_6);
  1304. idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI16_2);
  1305. idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI16_2);
  1306. idx_stage1_base_4 = _mm512_mask_sub_epi16(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI16_4);
  1307. idx_stage1_base_5 = _mm512_add_epi16(idx_stage1_base_4, M512_EPI16_6);
  1308. unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
  1309. __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
  1310. unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
  1311. __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
  1312. idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
  1313. idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  1314. idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
  1315. idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
  1316. idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
  1317. xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
  1318. xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
  1319. xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
  1320. xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
  1321. xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
  1322. xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|0 |x10|0 | ... |x10|0 |
  1323. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 9);
  1324. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
  1325. unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
  1326. __mmask16 store_mask = *((__mmask16*) &store_mask_value);
  1327. for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
  1328. result_0 = _mm512_setzero_ps();
  1329. result_1 = _mm512_setzero_ps();
  1330. matrixArray_0 = _mm512_loadu_si512(&a[idx_m*11]); // Load 2 rows with n=11 plus 10 elements
  1331. matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*11 + 32]); // Load 2 rows with n=11 plus 1 element
  1332. matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*11]); // Load 2 rows with n=11 plus 10 elements
  1333. matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*11 + 32]); // Load 2 rows with n=11 plus 1 element
  1334. matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*11]); // Load 2 rows with n=11 plus 10 elements
  1335. matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*11 + 32]); // Load 2 rows with n=11 plus 1 element
  1336. // Stage 1: interleave per 16 bits
  1337. matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0|a1|...|e0|e1|a2|a3|...|e2|e3|a4 |a5|...|e4 |e5|
  1338. matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6|a7|...|e6|e7|a8|a9|...|e8|e9|a10|x |...|e10|x |
  1339. matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2|f3|...|j2|j3|f0|f1|...|j0|j1|f4 |f5|...|j4 |j5|
  1340. matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8|f9|...|j8|j9|f6|f7|...|j6|j7|f10|x |...|j10|x |
  1341. matrixArray_stage_4 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4|k5|...|o4|o5|k2|k3|...|o2|o3|k0 |k1|...|o0 |o1|
  1342. matrixArray_stage_5 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|x|...|o10|x|k8|k9|...|o8|o9|k6 |k7|...|o6 |o7|
  1343. // Stage 2: interleave per 32 bits
  1344. matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|j0|j1|x|x|x|x|x|x|x|x|x|x|x|x|
  1345. matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6|a7|...|j6|j7|x|x|x|x|x|x|x|x|x|x|x|x|
  1346. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2|a3|...|j2|j3|x|x|x|x|x|x|x|x|x|x|x|x|
  1347. matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4|a5|...|j4|j5|x|x|x|x|x|x|x|x|x|x|x|x|
  1348. matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8|a9|...|j8|j9|x|x|x|x|x|x|x|x|x|x|x|x|
  1349. matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|x|...|j10|x|x|x|x|x|x|x|x|x|x|x|x|x|
  1350. matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
  1351. matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
  1352. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
  1353. matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
  1354. matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
  1355. matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5); // |a10|x|.......................|o10|x|x|x|
  1356. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
  1357. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
  1358. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
  1359. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
  1360. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
  1361. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
  1362. result_0 = _mm512_add_ps(result_0, result_1);
  1363. STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
  1364. }
  1365. }
  1366. if (tag_m_15x != m) {
  1367. __m256i matrixArray256;
  1368. __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
  1369. __m256 result256;
  1370. __m128 result128, tmp128;
  1371. unsigned short load256_mask_value = (((unsigned short)0xffff) >> 5);
  1372. __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
  1373. for (BLASLONG i = tag_m_15x; i < m; i++) {
  1374. result256 = _mm256_setzero_ps();
  1375. matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*11]);
  1376. result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
  1377. result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
  1378. tmp128 = _mm_shuffle_ps(result128, result128, 14);
  1379. result128 = _mm_add_ps(result128, tmp128);
  1380. tmp128 = _mm_shuffle_ps(result128, result128, 1);
  1381. result128 = _mm_add_ps(result128, tmp128);
  1382. #ifndef ZERO_BETA
  1383. #ifndef ONE_BETA
  1384. y[i] = alpha * result128[0] + beta * y[i];
  1385. #else
  1386. y[i] = alpha * result128[0] + y[i];
  1387. #endif
  1388. #else
  1389. #ifndef ONE_ALPHA
  1390. y[i] = result128[0] * alpha;
  1391. #else
  1392. y[i] = result128[0];
  1393. #endif
  1394. #endif
  1395. }
  1396. }
  1397. return 0;
  1398. }
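/*
 * STORE16_MASK_COMPLETE_RESULT and its STORE8/STORE16 siblings are defined outside this section,
 * together with the ALPHAVECTOR/BETAVECTOR values set up above. The stand-in below only
 * illustrates what a masked alpha/beta store of a partial row block (here 15 live lanes) can look
 * like in the general (alpha, beta) build; the function name and exact formulation are ours, not
 * the macro's actual definition.
 */
#ifdef SBGEMV_T_STORE_SKETCH
static inline void masked_axpby_store_sketch(float *dst, __m512 acc, __mmask16 k,
                                             __m512 valpha, __m512 vbeta)
{
    __m512 old = _mm512_maskz_loadu_ps(k, dst);                             // read only the live lanes
    __m512 out = _mm512_fmadd_ps(acc, valpha, _mm512_mul_ps(old, vbeta));   // alpha*acc + beta*y
    _mm512_mask_storeu_ps(dst, k, out);                                     // write only the live lanes
}
#endif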
  1399. // 15 rows parallel processing BF16 GEMV kernel for n=12 && lda ineffective scenario
  1400. #ifndef ZERO_BETA
  1401. #ifndef ONE_BETA
  1402. static int sbgemv_kernel_15x12_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1403. #else
  1404. static int sbgemv_kernel_15x12_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1405. #endif
  1406. #else
  1407. #ifndef ONE_ALPHA
  1408. static int sbgemv_kernel_15x12_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1409. #else
  1410. static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1411. #endif
  1412. #endif
  1413. {
  1414. BLASLONG tag_m_15x = m - (m%15);
  1415. unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4);
  1416. __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
  1417. __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7|
  1418. __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0|
  1419. if (tag_m_15x > 0) {
  1420. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
  1421. __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
  1422. __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
  1423. __m512 result_0, result_1;
  1424. #ifndef ONE_ALPHA
  1425. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1426. #endif
  1427. #ifndef ZERO_BETA
  1428. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1429. #endif
  1430. __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
  1431. __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
  1432. __m512i M512_EPI32_1, M512_EPI32_2, M512_EPI32_3, M512_EPI32_5;
  1433. M512_EPI32_1 = _mm512_set1_epi32(1);
  1434. M512_EPI32_2 = _mm512_add_epi32(M512_EPI32_1, M512_EPI32_1);
  1435. M512_EPI32_3 = _mm512_add_epi32(M512_EPI32_2, M512_EPI32_1);
  1436. M512_EPI32_5 = _mm512_add_epi32(M512_EPI32_3, M512_EPI32_2);
  1437. unsigned short BASE_MASK_10_value = ((unsigned short)0x001f);
  1438. __mmask16 BASE_MASK_10 = *((__mmask16*) &BASE_MASK_10_value);
  1439. unsigned short BASE_MASK_20_value = ((unsigned short)0x03e0);
  1440. __mmask16 BASE_MASK_20 = *((__mmask16*) &BASE_MASK_20_value);
  1441. unsigned short BASE_MASK_30_value = ((unsigned short)0xfc00);
  1442. __mmask16 BASE_MASK_30 = *((__mmask16*) &BASE_MASK_30_value);
  1443. idx_stage1_base_0 = _mm512_set_epi32( 0, 26, 20, 14, 8, 2, 25, 19, 13, 7, 1, 24, 18, 12, 6, 0);
  1444. idx_stage1_base_1 = _mm512_add_epi32(idx_stage1_base_0, M512_EPI32_3);
  1445. idx_stage1_base_2 = _mm512_mask_add_epi32(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI32_1);
  1446. idx_stage1_base_2 = _mm512_mask_sub_epi32(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI32_1);
  1447. idx_stage1_base_3 = _mm512_add_epi32(idx_stage1_base_2, M512_EPI32_3);
  1448. idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI32_1);
  1449. idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI32_1);
  1450. idx_stage1_base_4 = _mm512_mask_sub_epi32(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI32_2);
  1451. idx_stage1_base_5 = _mm512_add_epi32(idx_stage1_base_4, M512_EPI32_3);
  1452. unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
  1453. __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
  1454. unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
  1455. __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
  1456. idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
  1457. idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  1458. idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
  1459. idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
  1460. idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
  1461. xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
  1462. xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
  1463. xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
  1464. xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
  1465. xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
  1466. xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|x11|x10|x11| ... |x10|x11|
  1467. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
  1468. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
  1469. unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
  1470. __mmask16 store_mask = *((__mmask16*) &store_mask_value);
  1471. for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
  1472. result_0 = _mm512_setzero_ps();
  1473. result_1 = _mm512_setzero_ps();
1474. matrixArray_0 = _mm512_loadu_si512(&a[idx_m*12]); // Load 2 rows with n=12 plus 8 elements
1475. matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*12 + 32]); // Load 2 rows with n=12 plus 4 elements
1476. matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*12]); // Load 2 rows with n=12 plus 8 elements
1477. matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*12 + 32]); // Load 2 rows with n=12 plus 4 elements
1478. matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*12]); // Load 2 rows with n=12 plus 8 elements
1479. matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*12 + 32]); // Load 2 rows with n=12 plus 4 elements
1480. // Stage 1: interleave per 32 bits
  1481. matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0 |a1 |...|e0 |e1 |a2|a3|...|e2|e3|a4 |a5 |...|e4 |e5 |
  1482. matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6 |a7 |...|e6 |e7 |a8|a9|...|e8|e9|a10|a11|...|e10|e11|
  1483. matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2 |f3 |...|j2 |j3 |f0|f1|...|j0|j1|f4 |f5 |...|j4 |j5 |
  1484. matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8 |f9 |...|j8 |j9 |f6|f7|...|j6|j7|f10|f11|...|j10|j11|
  1485. matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4 |k5 |...|o4 |o5 |k2|k3|...|o2|o3|k0 |k1 |...|o0 |o1 |
  1486. matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|k11|...|o10|o11|k8|k9|...|o8|o9|k6 |k7 |...|o6 |o7 |
  1487. // Stage 2: interleave per 32 bits
  1488. matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0 |a1 |...|j0 |j1 |x|x|x|x|x|x|x|x|x|x|x|x|
  1489. matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6 |a7 |...|j6 |j7 |x|x|x|x|x|x|x|x|x|x|x|x|
  1490. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2 |a3 |...|j2 |j3 |x|x|x|x|x|x|x|x|x|x|x|x|
  1491. matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4 |a5 |...|j4 |j5 |x|x|x|x|x|x|x|x|x|x|x|x|
  1492. matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8 |a9 |...|j8 |j9 |x|x|x|x|x|x|x|x|x|x|x|x|
  1493. matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|a11|...|j10|j11|x|x|x|x|x|x|x|x|x|x|x|x|
  1494. matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
  1495. matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
  1496. matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
  1497. matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
  1498. matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
1499. matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5); // |a10|a11|.....................|o10|o11|x|x|
  1500. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
  1501. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
  1502. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
  1503. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
  1504. result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
  1505. result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
  1506. result_0 = _mm512_add_ps(result_0, result_1);
  1507. STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
  1508. }
  1509. }
  1510. if (tag_m_15x != m) {
  1511. __m256i matrixArray256;
  1512. __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
  1513. __m256 result256;
  1514. __m128 result128, tmp128;
  1515. unsigned short load256_mask_value = (((unsigned short)0xffff) >> 4);
  1516. __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
  1517. for (BLASLONG i = tag_m_15x; i < m; i++) {
  1518. result256 = _mm256_setzero_ps();
  1519. matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*12]);
  1520. result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
  1521. result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
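// Horizontal reduction of the 4 partial sums in result128: shuffle 0x0e pulls lanes {2,3} down so
// the first add folds the upper pair onto the lower pair, then shuffle 0x01 brings lane 1 down so
// the second add leaves the full dot product in lane 0. The same pattern is reused in the scalar
// tails of the kernels below.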
1522. tmp128 = _mm_shuffle_ps(result128, result128, 0x0e);
1523. result128 = _mm_add_ps(result128, tmp128);
1524. tmp128 = _mm_shuffle_ps(result128, result128, 0x01);
  1525. result128 = _mm_add_ps(result128, tmp128);
  1526. #ifndef ZERO_BETA
  1527. #ifndef ONE_BETA
  1528. y[i] = alpha * result128[0] + beta * y[i];
  1529. #else
  1530. y[i] = alpha * result128[0] + y[i];
  1531. #endif
  1532. #else
  1533. #ifndef ONE_ALPHA
  1534. y[i] = result128[0] * alpha;
  1535. #else
  1536. y[i] = result128[0];
  1537. #endif
  1538. #endif
  1539. }
  1540. }
  1541. return 0;
  1542. }
  1543. // 16 rows parallel processing BF16 GEMV kernel for n=13 && lda ineffective scenario
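// Sketch of the scheme used below (as read from the code): each row has 13 bfloat16 elements, so
// rows are fetched via masked 256-bit loads, packed two rows per 512-bit register, interleaved in
// two steps, and reduced with _mm512_dpbf16_ps against an interleaved copy of x; a final
// permutex2var + add regroups the per-row partial sums so 16 floats (one per row) can be stored.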
  1544. #ifndef ZERO_BETA
  1545. #ifndef ONE_BETA
  1546. static int sbgemv_kernel_16x13_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1547. #else
  1548. static int sbgemv_kernel_16x13_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1549. #endif
  1550. #else
  1551. #ifndef ONE_ALPHA
  1552. static int sbgemv_kernel_16x13_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1553. #else
  1554. static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1555. #endif
  1556. #endif
  1557. {
  1558. BLASLONG tag_m_16x = m & (~15);
  1559. unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 3);
  1560. __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
  1561. __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|0|0|0|
  1562. if (tag_m_16x > 0) {
  1563. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
  1564. matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
  1565. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  1566. __m512 accum512_0, accum512_1;
  1567. __m512 result_0, result_1;
  1568. __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
  1569. #ifndef ONE_ALPHA
  1570. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1571. #endif
  1572. #ifndef ZERO_BETA
  1573. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1574. #endif
  1575. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  1576. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  1577. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
  1578. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6);
  1579. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
1580. // Prepare X with 2-step interleaving
  1581. xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
  1582. BF16_INTERLEAVE_1x32(xArray)
  1583. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  1584. accum512_0 = _mm512_setzero_ps();
  1585. accum512_1 = _mm512_setzero_ps();
  1586. // Load matrix
  1587. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m, 0, x_load_mask)
  1588. matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1589. matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1590. matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1591. matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1592. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m+8, 0, x_load_mask)
  1593. matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1594. matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1595. matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1596. matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1597. // interleave per 256 bits
  1598. BF16_INTERLEAVE256_8x32(matrixArray)
  1599. // 2-step interleave for matrix
  1600. BF16_INTERLEAVE_8x32(matrixArray)
  1601. // Calculate the temp result for a..p[0:15]
  1602. BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
  1603. // Reorder and add up the final result
  1604. result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
  1605. result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
  1606. result_0 = _mm512_add_ps(result_0, result_1);
  1607. STORE16_COMPLETE_RESULT(result_0, y+idx_m)
  1608. }
  1609. if (m - tag_m_16x > 7) {
  1610. __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
  1611. accum512_0 = _mm512_setzero_ps();
  1612. accum512_1 = _mm512_setzero_ps();
  1613. // Load matrix
  1614. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
  1615. matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1616. matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1617. matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1618. matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1619. // interleave per 256 bits
  1620. matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
  1621. matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
  1622. matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
  1623. matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
  1624. // 2-step interleave for matrix
  1625. BF16_INTERLEAVE_4x32(matrixArray)
  1626. // Calculate the temp result for a..h[0:15]
  1627. BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
  1628. accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
  1629. accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
  1630. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  1631. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  1632. tag_m_16x += 8;
  1633. }
  1634. if (m - tag_m_16x > 3) {
  1635. __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
  1636. __m256 accum256_0, accum256_1;
  1637. xArray256_0 = _mm512_castsi512_si256(xArray_0);
  1638. xArray256_1 = _mm512_castsi512_si256(xArray_1);
  1639. xArray256_2 = _mm512_castsi512_si256(xArray_2);
  1640. xArray256_3 = _mm512_castsi512_si256(xArray_3);
  1641. accum256_0 = _mm256_setzero_ps();
  1642. accum256_1 = _mm256_setzero_ps();
  1643. BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
  1644. // 2-step interleave for matrix
  1645. BF16_INTERLEAVE_4x16(matrixArray256)
  1646. // Calculate the temp result for a..d[0:15]
  1647. BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
  1648. accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
  1649. __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  1650. STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
  1651. tag_m_16x += 4;
  1652. }
  1653. }
  1654. if (tag_m_16x != m) {
  1655. __m256i matrixArray256;
  1656. __m256 accum256;
  1657. __m128 accum128, tmp128;
  1658. for (BLASLONG i = tag_m_16x; i < m; i++) {
  1659. accum256 = _mm256_setzero_ps();
1660. matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*13]); // Load 1 row with n=13
  1661. accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
  1662. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  1663. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  1664. accum128 = _mm_add_ps(accum128, tmp128);
  1665. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  1666. accum128 = _mm_add_ps(accum128, tmp128);
  1667. #ifndef ZERO_BETA
  1668. #ifndef ONE_BETA
  1669. y[i] = alpha * accum128[0] + beta * y[i];
  1670. #else
  1671. y[i] = alpha * accum128[0] + y[i];
  1672. #endif
  1673. #else
  1674. #ifndef ONE_ALPHA
  1675. y[i] = accum128[0] * alpha;
  1676. #else
  1677. y[i] = accum128[0];
  1678. #endif
  1679. #endif
  1680. }
  1681. }
  1682. return 0;
  1683. }
  1684. // 16 rows parallel processing BF16 GEMV kernel for n=14 && lda ineffective scenario
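// Same overall scheme as the n=13 kernel above, except that two 14-element rows are fetched per
// masked 512-bit load (28 of 32 lanes) and a pre-stage permute re-aligns the second row to the
// upper 256-bit half before the usual two-step interleave (see the shift_idx note below).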
  1685. #ifndef ZERO_BETA
  1686. #ifndef ONE_BETA
  1687. static int sbgemv_kernel_16x14_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1688. #else
  1689. static int sbgemv_kernel_16x14_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1690. #endif
  1691. #else
  1692. #ifndef ONE_ALPHA
  1693. static int sbgemv_kernel_16x14_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1694. #else
  1695. static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1696. #endif
  1697. #endif
  1698. {
  1699. BLASLONG tag_m_16x = m & (~15);
  1700. unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 2);
  1701. __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
  1702. __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|0|0|
  1703. if (tag_m_16x > 0) {
  1704. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
  1705. matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
  1706. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  1707. __m512 accum512_0, accum512_1;
  1708. __m512 result_0, result_1;
  1709. #ifndef ONE_ALPHA
  1710. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1711. #endif
  1712. #ifndef ZERO_BETA
  1713. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1714. #endif
  1715. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  1716. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  1717. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
  1718. __m512i shift_idx = _mm512_set_epi32(0, 13, 12, 11, 10, 9, 8, 7, 0, 6, 5, 4, 3, 2, 1, 0);
  1719. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
  1720. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
1721. // Prepare X with 2-step interleaving
  1722. xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
  1723. BF16_INTERLEAVE_1x32(xArray)
  1724. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  1725. accum512_0 = _mm512_setzero_ps();
  1726. accum512_1 = _mm512_setzero_ps();
  1727. // Load matrix
  1728. BF16_MATRIX_MASKZ_LOAD_8x32_2(matrixArray, a, 14, idx_m, 0, load_mask)
  1729. // Pre-stage: shift the 2nd vector 1 position right for each register
  1730. BF16_PERMUTE_8x32_2(shift_idx, matrixArray)
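// shift_idx copies dwords 0..6 (the first 14-element row) in place and moves dwords 7..13 (the
// second row) to positions 8..14, i.e. it shifts the second row one dword to the right so each
// 256-bit half starts at a row boundary; positions 7 and 15 just receive dword 0 as filler.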
  1731. // interleave per 256 bits
  1732. BF16_INTERLEAVE256_8x32(matrixArray)
  1733. // 2-step interleave for matrix
  1734. BF16_INTERLEAVE_8x32(matrixArray)
  1735. // Calculate the temp result for a..p[0:15]
  1736. BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
  1737. // Reorder and add up the final result
  1738. result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
  1739. result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
  1740. result_0 = _mm512_add_ps(result_0, result_1);
  1741. STORE16_COMPLETE_RESULT(result_0, y+idx_m)
  1742. }
  1743. if (m - tag_m_16x > 7) {
  1744. __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
  1745. accum512_0 = _mm512_setzero_ps();
  1746. accum512_1 = _mm512_setzero_ps();
  1747. // Load matrix
  1748. BF16_MATRIX_MASKZ_LOAD_4x32_2(matrixArray, a, 14, tag_m_16x, 0, load_mask)
  1749. // Pre-stage: shift the 2nd vector 1 position right for each register
  1750. BF16_PERMUTE_4x32_2(shift_idx, matrixArray)
  1751. // interleave per 256 bits
  1752. BF16_INTERLEAVE256_4x32(matrixArray)
  1753. // 2-step interleave for matrix
  1754. BF16_INTERLEAVE_4x32(matrixArray)
  1755. // Calculate the temp result for a..h[0:15]
  1756. BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
  1757. accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
  1758. accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
  1759. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  1760. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  1761. tag_m_16x += 8;
  1762. }
  1763. if (m - tag_m_16x > 3) {
  1764. __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
  1765. __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
  1766. __m256 accum256_0, accum256_1;
  1767. xArray256_0 = _mm512_castsi512_si256(xArray_0);
  1768. xArray256_1 = _mm512_castsi512_si256(xArray_1);
  1769. xArray256_2 = _mm512_castsi512_si256(xArray_2);
  1770. xArray256_3 = _mm512_castsi512_si256(xArray_3);
  1771. accum256_0 = _mm256_setzero_ps();
  1772. accum256_1 = _mm256_setzero_ps();
  1773. BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 14, tag_m_16x, 0, x_load_mask)
  1774. // 2-step interleave for matrix
  1775. BF16_INTERLEAVE_4x16(matrixArray256)
  1776. // Calculate the temp result for a..d[0:15]
  1777. BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
  1778. accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
  1779. __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  1780. STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
  1781. tag_m_16x += 4;
  1782. }
  1783. }
  1784. if (tag_m_16x != m) {
  1785. __m256i matrixArray256;
  1786. __m256 accum256;
  1787. __m128 accum128, tmp128;
  1788. for (BLASLONG i = tag_m_16x; i < m; i++) {
  1789. accum256 = _mm256_setzero_ps();
1790. matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*14]); // Load 1 row with n=14
  1791. accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
  1792. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  1793. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  1794. accum128 = _mm_add_ps(accum128, tmp128);
  1795. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  1796. accum128 = _mm_add_ps(accum128, tmp128);
  1797. #ifndef ZERO_BETA
  1798. #ifndef ONE_BETA
  1799. y[i] = alpha * accum128[0] + beta * y[i];
  1800. #else
  1801. y[i] = alpha * accum128[0] + y[i];
  1802. #endif
  1803. #else
  1804. #ifndef ONE_ALPHA
  1805. y[i] = accum128[0] * alpha;
  1806. #else
  1807. y[i] = accum128[0];
  1808. #endif
  1809. #endif
  1810. }
  1811. }
  1812. return 0;
  1813. }
  1814. // 16 rows parallel processing BF16 GEMV kernel for n=15 && lda ineffective scenario
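// Structurally identical to the n=13 kernel above; only the row width (15 elements, hence the
// 15-bit x_load_mask) and the matrix offsets change.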
  1815. #ifndef ZERO_BETA
  1816. #ifndef ONE_BETA
  1817. static int sbgemv_kernel_16x15_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1818. #else
  1819. static int sbgemv_kernel_16x15_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1820. #endif
  1821. #else
  1822. #ifndef ONE_ALPHA
  1823. static int sbgemv_kernel_16x15_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1824. #else
  1825. static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1826. #endif
  1827. #endif
  1828. {
  1829. BLASLONG tag_m_16x = m & (~15);
  1830. unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 1);
  1831. __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
  1832. __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|0|
  1833. if (tag_m_16x > 0) {
  1834. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
  1835. matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
  1836. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  1837. __m512 accum512_0, accum512_1;
  1838. __m512 result_0, result_1;
  1839. __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
  1840. #ifndef ONE_ALPHA
  1841. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1842. #endif
  1843. #ifndef ZERO_BETA
  1844. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1845. #endif
  1846. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  1847. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  1848. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
  1849. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
  1850. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
1851. // Prepare X with 2-step interleaving
  1852. xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
  1853. BF16_INTERLEAVE_1x32(xArray)
  1854. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  1855. accum512_0 = _mm512_setzero_ps();
  1856. accum512_1 = _mm512_setzero_ps();
  1857. // Load matrix
  1858. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m, 0, x_load_mask)
  1859. matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1860. matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1861. matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1862. matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1863. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m+8, 0, x_load_mask)
  1864. matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1865. matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1866. matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1867. matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1868. // interleave per 256 bits
  1869. BF16_INTERLEAVE256_8x32(matrixArray)
  1870. // 2-step interleave for matrix
  1871. BF16_INTERLEAVE_8x32(matrixArray)
  1872. // Calculate the temp result for a..p[0:15]
  1873. BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
  1874. // Reorder and add up the final result
  1875. result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
  1876. result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
  1877. result_0 = _mm512_add_ps(result_0, result_1);
  1878. STORE16_COMPLETE_RESULT(result_0, y+idx_m)
  1879. }
  1880. if (m - tag_m_16x > 7) {
  1881. __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
  1882. accum512_0 = _mm512_setzero_ps();
  1883. accum512_1 = _mm512_setzero_ps();
  1884. // Load matrix
  1885. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
  1886. matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
  1887. matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
  1888. matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
  1889. matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
  1890. // interleave per 256 bits
  1891. matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
  1892. matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
  1893. matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
  1894. matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
  1895. // 2-step interleave for matrix
  1896. BF16_INTERLEAVE_4x32(matrixArray)
  1897. // Calculate the temp result for a..h[0:15]
  1898. BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
  1899. accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
  1900. accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
  1901. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  1902. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  1903. tag_m_16x += 8;
  1904. }
  1905. if (m - tag_m_16x > 3) {
  1906. __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
  1907. __m256 accum256_0, accum256_1;
  1908. xArray256_0 = _mm512_castsi512_si256(xArray_0);
  1909. xArray256_1 = _mm512_castsi512_si256(xArray_1);
  1910. xArray256_2 = _mm512_castsi512_si256(xArray_2);
  1911. xArray256_3 = _mm512_castsi512_si256(xArray_3);
  1912. accum256_0 = _mm256_setzero_ps();
  1913. accum256_1 = _mm256_setzero_ps();
  1914. BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
  1915. // 2-step interleave for matrix
  1916. BF16_INTERLEAVE_4x16(matrixArray256)
  1917. // Calculate the temp result for a..d[0:15]
  1918. BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
  1919. accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
  1920. __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  1921. STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
  1922. tag_m_16x += 4;
  1923. }
  1924. }
  1925. if (tag_m_16x != m) {
  1926. __m256i matrixArray256;
  1927. __m256 accum256;
  1928. __m128 accum128, tmp128;
  1929. for (BLASLONG i = tag_m_16x; i < m; i++) {
  1930. accum256 = _mm256_setzero_ps();
1931. matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*15]); // Load 1 row with n=15
  1932. accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
  1933. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  1934. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  1935. accum128 = _mm_add_ps(accum128, tmp128);
  1936. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  1937. accum128 = _mm_add_ps(accum128, tmp128);
  1938. #ifndef ZERO_BETA
  1939. #ifndef ONE_BETA
  1940. y[i] = alpha * accum128[0] + beta * y[i];
  1941. #else
  1942. y[i] = alpha * accum128[0] + y[i];
  1943. #endif
  1944. #else
  1945. #ifndef ONE_ALPHA
  1946. y[i] = accum128[0] * alpha;
  1947. #else
  1948. y[i] = accum128[0];
  1949. #endif
  1950. #endif
  1951. }
  1952. }
  1953. return 0;
  1954. }
  1955. // 16 rows parallel processing BF16 GEMV kernel for n=16 && lda ineffective scenario
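// With n=16 a row is exactly 256 bits, so no masking is needed: each unmasked 512-bit load below
// brings in two complete rows, and the rest of the flow (256-bit interleave, two-step interleave,
// dpbf16, reorder) matches the n=13..15 kernels above.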
  1956. #ifndef ZERO_BETA
  1957. #ifndef ONE_BETA
  1958. static int sbgemv_kernel_16x16_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1959. #else
  1960. static int sbgemv_kernel_16x16_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
  1961. #endif
  1962. #else
  1963. #ifndef ONE_ALPHA
  1964. static int sbgemv_kernel_16x16_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1965. #else
  1966. static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
  1967. #endif
  1968. #endif
  1969. {
  1970. BLASLONG tag_m_16x = m & (~15);
1971. __m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
  1972. if (tag_m_16x > 0) {
  1973. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
  1974. matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
  1975. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  1976. __m512 accum512_0, accum512_1;
  1977. __m512 result_0, result_1;
  1978. #ifndef ONE_ALPHA
  1979. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  1980. #endif
  1981. #ifndef ZERO_BETA
  1982. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  1983. #endif
  1984. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  1985. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  1986. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
1987. // Prepare X with 2-step interleaving
  1988. xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
  1989. BF16_INTERLEAVE_1x32(xArray)
  1990. for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
  1991. accum512_0 = _mm512_setzero_ps();
  1992. accum512_1 = _mm512_setzero_ps();
  1993. matrixArray_8 = _mm512_loadu_si512(&a[(idx_m )*16]); // Load 2 rows with n=16
  1994. matrixArray_9 = _mm512_loadu_si512(&a[(idx_m+2 )*16]); // Load 2 rows with n=16
  1995. matrixArray_10 = _mm512_loadu_si512(&a[(idx_m+4 )*16]); // Load 2 rows with n=16
  1996. matrixArray_11 = _mm512_loadu_si512(&a[(idx_m+6 )*16]); // Load 2 rows with n=16
  1997. matrixArray_12 = _mm512_loadu_si512(&a[(idx_m+8 )*16]); // Load 2 rows with n=16
  1998. matrixArray_13 = _mm512_loadu_si512(&a[(idx_m+10)*16]); // Load 2 rows with n=16
  1999. matrixArray_14 = _mm512_loadu_si512(&a[(idx_m+12)*16]); // Load 2 rows with n=16
  2000. matrixArray_15 = _mm512_loadu_si512(&a[(idx_m+14)*16]); // Load 2 rows with n=16
  2001. // interleave per 256 bits
  2002. BF16_INTERLEAVE256_8x32(matrixArray)
  2003. // 2-step interleave for matrix
  2004. BF16_INTERLEAVE_8x32(matrixArray)
  2005. // Calculate the temp result for a..p[0:15]
  2006. BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
  2007. // Reorder and add up the final result
  2008. result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
  2009. result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
  2010. result_0 = _mm512_add_ps(result_0, result_1);
  2011. STORE16_COMPLETE_RESULT(result_0, y+idx_m)
  2012. }
  2013. if (m - tag_m_16x > 7) {
  2014. __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
  2015. accum512_0 = _mm512_setzero_ps();
  2016. accum512_1 = _mm512_setzero_ps();
  2017. matrixArray_4 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
  2018. matrixArray_5 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
  2019. matrixArray_6 = _mm512_loadu_si512(&a[(tag_m_16x+4 )*16]); // Load 2 rows with n=16
  2020. matrixArray_7 = _mm512_loadu_si512(&a[(tag_m_16x+6 )*16]); // Load 2 rows with n=16
  2021. // interleave per 256 bits
  2022. BF16_INTERLEAVE256_4x32(matrixArray)
  2023. // 2-step interleave for matrix
  2024. BF16_INTERLEAVE_4x32(matrixArray)
  2025. // Calculate the temp result for a..h[0:15]
  2026. BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
  2027. accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
  2028. accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
  2029. __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  2030. STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
  2031. tag_m_16x += 8;
  2032. }
  2033. if (m - tag_m_16x > 3) {
  2034. __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, \
  2035. matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
  2036. __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
  2037. __m256 accum256_0, accum256_1;
  2038. xArray256_0 = _mm512_castsi512_si256(xArray_0);
  2039. xArray256_1 = _mm512_castsi512_si256(xArray_1);
  2040. xArray256_2 = _mm512_castsi512_si256(xArray_2);
  2041. xArray256_3 = _mm512_castsi512_si256(xArray_3);
  2042. accum256_0 = _mm256_setzero_ps();
  2043. accum256_1 = _mm256_setzero_ps();
  2044. matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
  2045. matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
  2046. matrixArray256_0 = _mm512_castsi512_si256(matrixArray_0);
  2047. matrixArray256_1 = _mm512_extracti32x8_epi32(matrixArray_0, 0x1);
  2048. matrixArray256_2 = _mm512_castsi512_si256(matrixArray_1);
  2049. matrixArray256_3 = _mm512_extracti32x8_epi32(matrixArray_1, 0x1);
  2050. // 2-step interleave for matrix
  2051. BF16_INTERLEAVE_4x16(matrixArray256)
  2052. // Calculate the temp result for a..d[0:15]
  2053. BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
  2054. accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
  2055. __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  2056. STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
  2057. tag_m_16x += 4;
  2058. }
  2059. }
  2060. if (tag_m_16x != m) {
  2061. __m256i matrixArray256;
  2062. __m256 accum256;
  2063. __m128 accum128, tmp128;
  2064. for (BLASLONG i = tag_m_16x; i < m; i++) {
  2065. accum256 = _mm256_setzero_ps();
2066. matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 row with n=16
  2067. accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
  2068. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  2069. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2070. accum128 = _mm_add_ps(accum128, tmp128);
  2071. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2072. accum128 = _mm_add_ps(accum128, tmp128);
  2073. #ifndef ZERO_BETA
  2074. #ifndef ONE_BETA
  2075. y[i] = alpha * accum128[0] + beta * y[i];
  2076. #else
  2077. y[i] = alpha * accum128[0] + y[i];
  2078. #endif
  2079. #else
  2080. #ifndef ONE_ALPHA
  2081. y[i] = accum128[0] * alpha;
  2082. #else
  2083. y[i] = accum128[0];
  2084. #endif
  2085. #endif
  2086. }
  2087. }
  2088. return 0;
  2089. }
  2090. // 8 rows parallel processing BF16 GEMV kernel for n>16 && lda effective scenario
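// Here lda is honoured and, judging by the 32-bit row mask below, 16 < n <= 32: one masked
// 512-bit load covers a whole row and x fits in a single zmm, so 8 rows at a time are interleaved
// in two steps and reduced with dpbf16 much like the fixed-n kernels above.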
  2091. #ifndef ZERO_BETA
  2092. #ifndef ONE_BETA
  2093. static int sbgemv_kernel_8x16p_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2094. #else
  2095. static int sbgemv_kernel_8x16p_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2096. #endif
  2097. #else
  2098. #ifndef ONE_ALPHA
  2099. static int sbgemv_kernel_8x16p_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2100. #else
  2101. static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2102. #endif
  2103. #endif
  2104. {
  2105. BLASLONG tag_m_8x = m & (~7);
  2106. unsigned int load_mask_value = (((unsigned int)0xffffffff) >> (32-n));
  2107. __mmask32 load_mask = *((__mmask32*) &load_mask_value);
  2108. __m512i x512 = _mm512_maskz_loadu_epi16(load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|...
  2109. #ifndef ONE_ALPHA
  2110. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  2111. #endif
  2112. #ifndef ZERO_BETA
  2113. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  2114. #endif
  2115. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
  2116. matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
  2117. __m512 accum512_0, accum512_1, accum512_2, accum512_3;
  2118. __m256 accum256;
  2119. __m128 accum128;
  2120. if (tag_m_8x > 0) {
  2121. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  2122. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  2123. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  2124. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
2125. // Prepare X with 2-step interleaving
  2126. xArray_0 = x512;
  2127. BF16_INTERLEAVE_1x32(xArray)
  2128. for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
  2129. accum512_0 = _mm512_setzero_ps();
  2130. accum512_1 = _mm512_setzero_ps();
  2131. // Load 8 rows from matrix
  2132. BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, 0, load_mask)
  2133. // 2-step interleave for matrix
  2134. BF16_INTERLEAVE_8x32(matrixArray)
  2135. // Calculate the temp result for a..h[0:31]
  2136. BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
  2137. // Reorder and add up the final result
  2138. accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
  2139. accum512_3 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
  2140. accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
  2141. accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_2), _mm512_extractf32x8_ps(accum512_2, 1));
  2142. STORE8_COMPLETE_RESULT(accum256, y+idx_m)
  2143. }
  2144. if (m - tag_m_8x > 3) {
  2145. accum512_0 = _mm512_setzero_ps();
  2146. accum512_1 = _mm512_setzero_ps();
  2147. // Load 4 rows from matrix
  2148. BF16_MATRIX_MASKZ_LOAD_4x32(matrixArray, a, lda, tag_m_8x, 0, load_mask)
  2149. // 2-step interleave for matrix
  2150. BF16_INTERLEAVE_4x32(matrixArray)
  2151. // Calculate the temp result for a..d[0:31]
  2152. BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
  2153. accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
  2154. accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  2155. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  2156. STORE4_COMPLETE_RESULT(accum128, y+tag_m_8x)
  2157. tag_m_8x += 4;
  2158. }
  2159. }
  2160. if (tag_m_8x != m) {
  2161. __m128 tmp128;
  2162. for (BLASLONG i = tag_m_8x; i < m; i++) {
  2163. accum512_0 = _mm512_setzero_ps();
2164. matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 row of n elements (masked)
  2165. accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) matrixArray_0, (__m512bh) x512);
  2166. accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  2167. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
  2168. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2169. accum128 = _mm_add_ps(accum128, tmp128);
  2170. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2171. accum128 = _mm_add_ps(accum128, tmp128);
  2172. #ifndef ZERO_BETA
  2173. #ifndef ONE_BETA
  2174. y[i] = alpha * accum128[0] + beta * y[i];
  2175. #else
  2176. y[i] = alpha * accum128[0] + y[i];
  2177. #endif
  2178. #else
  2179. #ifndef ONE_ALPHA
  2180. y[i] = accum128[0] * alpha;
  2181. #else
  2182. y[i] = accum128[0];
  2183. #endif
  2184. #endif
  2185. }
  2186. }
  2187. return 0;
  2188. }
  2189. // 8 rows parallel processing BF16 GEMV kernel for big N && lda effective scenario (process before interleave)
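// For large n each row is processed directly: the inner loops below walk a row in 128-element
// blocks (four 32-element dpbf16 accumulators), then 32-element blocks, then one masked tail.
// For full groups of 8 rows the per-row sums are parked in accum512_bridge[] and reduced together
// with the FP32 interleave/accumulate macros; leftover rows use the scalar horizontal reduction
// at the end.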
  2190. #ifndef ZERO_BETA
  2191. #ifndef ONE_BETA
  2192. static int sbgemv_kernel_1x128_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2193. #else
  2194. static int sbgemv_kernel_1x128_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2195. #endif
  2196. #else
  2197. #ifndef ONE_ALPHA
  2198. static int sbgemv_kernel_1x128_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2199. #else
  2200. static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2201. #endif
  2202. #endif
  2203. {
  2204. BLASLONG tag_m_8x = m & (~7);
  2205. BLASLONG tag_n_32x = n & (~31);
  2206. BLASLONG tag_n_128x = n & (~127);
  2207. __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
  2208. accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
  2209. __m512 accum512_bridge[8];
  2210. __m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3;
  2211. __m256 accum256_0;
  2212. __m128 accum128;
  2213. #ifndef ONE_ALPHA
  2214. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  2215. #endif
  2216. #ifndef ZERO_BETA
  2217. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  2218. #endif
  2219. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
  2220. __m512i xArray_0, xArray_1, xArray_2, xArray_3;
  2221. unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
  2222. __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
  2223. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  2224. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  2225. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
  2226. if (tag_m_8x > 0) {
  2227. for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
  2228. for (int j = idx_m; j < idx_m + 8; j++) {
  2229. accum512_t_0 = _mm512_setzero_ps();
  2230. accum512_t_1 = _mm512_setzero_ps();
  2231. accum512_t_2 = _mm512_setzero_ps();
  2232. accum512_t_3 = _mm512_setzero_ps();
2233. /* Processing the main chunk with 128 elements per round */
  2234. for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
  2235. BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
  2236. BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
  2237. BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
  2238. BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
  2239. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
  2240. BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
  2241. BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
  2242. BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
  2243. BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
  2244. BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
  2245. BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
  2246. BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
  2247. }
2248. /* Processing the remaining <128 chunk with 32 elements per round */
  2249. for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
  2250. BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
  2251. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
  2252. BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
  2253. }
2254. /* Processing the remaining <32 chunk with masked 32-element processing */
  2255. if ((n&31) != 0) {
  2256. BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
  2257. BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
  2258. BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
  2259. }
  2260. /* Accumulate the 4 registers into 1 register */
  2261. accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
  2262. accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
  2263. accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
2264. // Temporarily save the result into a ZMM
  2265. accum512_bridge[j-idx_m] = accum512_t_0;
  2266. }
  2267. FP32_INTERLEAVE_8x16_ARRAY(accum512_bridge)
  2268. FP32_ACCUM2_8x16_ARRAY(accum512_bridge)
  2269. accum512_bridge[1] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_0, accum512_bridge[4]);
  2270. accum512_bridge[2] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_1, accum512_bridge[4]);
  2271. accum512_bridge[1] = _mm512_add_ps(accum512_bridge[1], accum512_bridge[2]);
  2272. accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_bridge[1]), _mm512_extractf32x8_ps(accum512_bridge[1], 1));
  2273. STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
  2274. }
  2275. }
  2276. if (tag_m_8x != m) {
  2277. __m128 tmp128;
  2278. for (BLASLONG j = tag_m_8x; j < m; j++) {
  2279. accum512_t_0 = _mm512_setzero_ps();
  2280. accum512_t_1 = _mm512_setzero_ps();
  2281. accum512_t_2 = _mm512_setzero_ps();
  2282. accum512_t_3 = _mm512_setzero_ps();
2283. /* Processing the main chunk with 128 elements per round */
  2284. for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
  2285. BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
  2286. BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
  2287. BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
  2288. BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
  2289. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
  2290. BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
  2291. BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
  2292. BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
  2293. BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
  2294. BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
  2295. BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
  2296. BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
  2297. }
2298. /* Processing the remaining <128 chunk with 32 elements per round */
  2299. for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
  2300. BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
  2301. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
  2302. BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
  2303. }
2304. /* Processing the remaining <32 chunk with masked 32-element processing */
  2305. if ((n&31) != 0) {
  2306. BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
  2307. BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
  2308. BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
  2309. }
  2310. /* Accumulate the 4 registers into 1 register */
  2311. accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
  2312. accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
  2313. accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
  2314. accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_t_0), _mm512_extractf32x8_ps(accum512_t_0, 1));
  2315. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  2316. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2317. accum128 = _mm_add_ps(accum128, tmp128);
  2318. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2319. accum128 = _mm_add_ps(accum128, tmp128);
  2320. #ifndef ZERO_BETA
  2321. #ifndef ONE_BETA
  2322. y[j] = alpha * accum128[0] + beta * y[j];
  2323. #else
  2324. y[j] = alpha * accum128[0] + y[j];
  2325. #endif
  2326. #else
  2327. #ifndef ONE_ALPHA
  2328. y[j] = accum128[0] * alpha;
  2329. #else
  2330. y[j] = accum128[0];
  2331. #endif
  2332. #endif
  2333. }
  2334. }
  2335. return 0;
  2336. }
  2337. // 8 rows parallel processing BF16 GEMV kernel for n=32 && lda effective scenario (process before interleave)
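// "Process before interleave": each of the 8 rows keeps its own 512-bit accumulator while the
// column loop walks n in 32-element steps plus a masked tail; only after the column loop are the
// 8 accumulators interleaved and folded down to one float per row for the 8-wide store.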
  2338. #ifndef ZERO_BETA
  2339. #ifndef ONE_BETA
  2340. static int sbgemv_kernel_8x32_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2341. #else
  2342. static int sbgemv_kernel_8x32_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2343. #endif
  2344. #else
  2345. #ifndef ONE_ALPHA
  2346. static int sbgemv_kernel_8x32_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2347. #else
  2348. static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2349. #endif
  2350. #endif
  2351. {
  2352. BLASLONG tag_m_8x = m & (~7);
  2353. BLASLONG tag_n_32x = n & (~31);
  2354. __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
  2355. accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
  2356. __m256 accum256_0;
  2357. __m128 accum128;
  2358. #ifndef ONE_ALPHA
  2359. __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
  2360. #endif
  2361. #ifndef ZERO_BETA
  2362. __m512 BETAVECTOR = _mm512_set1_ps(beta);
  2363. #endif
  2364. __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
  2365. __m512i xArray_0;
  2366. unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
  2367. __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
  2368. if (tag_m_8x > 0) {
  2369. __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
  2370. __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
  2371. __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
  2372. for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
  2373. accum512_0 = _mm512_setzero_ps();
  2374. accum512_1 = _mm512_setzero_ps();
  2375. accum512_2 = _mm512_setzero_ps();
  2376. accum512_3 = _mm512_setzero_ps();
  2377. accum512_4 = _mm512_setzero_ps();
  2378. accum512_5 = _mm512_setzero_ps();
  2379. accum512_6 = _mm512_setzero_ps();
  2380. accum512_7 = _mm512_setzero_ps();
  2381. for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
  2382. // Load 8 rows from matrix
  2383. BF16_MATRIX_LOAD_8x32(matrixArray, a, lda, idx_m, idx_n)
  2384. // Load x
  2385. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
  2386. // Calculate the temp result for a..h[0:31]
  2387. BF16_DOT_8x32(accum512, matrixArray, xArray_0)
  2388. }
  2389. if (tag_n_32x != n) { // Go with masked 512
  2390. // Load 8 rows from matrix
  2391. BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, tag_n_32x, tail_mask)
  2392. // Load x
  2393. BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
  2394. // Calculate the temp result for a..h[0:31]
  2395. BF16_DOT_8x32(accum512, matrixArray, xArray_0)
  2396. }
2397. // 2-step interleave for FP32 register array
2398. FP32_INTERLEAVE_8x16(accum512)
2399. // Accumulate the 2 batches of registers into 2 registers (0 and 4)
  2400. FP32_ACCUM2_8x16(accum512)
  2401. accum512_1 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_4);
  2402. accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_4);
  2403. accum512_1 = _mm512_add_ps(accum512_1, accum512_2);
  2404. accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_1), _mm512_extractf32x8_ps(accum512_1, 1));
  2405. STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
  2406. }
  2407. }
  2408. if (tag_m_8x != m) {
  2409. __m128 tmp128;
  2410. for (BLASLONG i = tag_m_8x; i < m; i++) {
  2411. accum512_0 = _mm512_setzero_ps();
  2412. for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
  2413. // Load 32 elements from matrix
  2414. BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, i, idx_n)
  2415. // Load 32 elements from x
  2416. BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
  2417. // Calculate and accumulate the temp result
  2418. BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
  2419. }
  2420. if (tag_n_32x != n) {
  2421. // Load tail elements from matrix
  2422. BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, i, tag_n_32x, tail_mask)
  2423. // Load 32 elements from x
  2424. BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
  2425. // Calculate and accumulate the temp result
  2426. BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
  2427. }
  2428. accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
  2429. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  2430. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2431. accum128 = _mm_add_ps(accum128, tmp128);
  2432. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2433. accum128 = _mm_add_ps(accum128, tmp128);
  2434. #ifndef ZERO_BETA
  2435. #ifndef ONE_BETA
  2436. y[i] = alpha * accum128[0] + beta * y[i];
  2437. #else
  2438. y[i] = alpha * accum128[0] + y[i];
  2439. #endif
  2440. #else
  2441. #ifndef ONE_ALPHA
  2442. y[i] = accum128[0] * alpha;
  2443. #else
  2444. y[i] = accum128[0];
  2445. #endif
  2446. #endif
  2447. }
  2448. }
  2449. return 0;
  2450. }
2451. // 8 rows parallel processing BF16 GEMV kernel for n<=16 && lda effective scenario
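// n fits in one 256-bit register here, so the matrix rows and x stay in ymm: x is loaded once
// (masked when n < 16), 8 rows are loaded per iteration, _mm256_dpbf16_ps produces per-row
// partial sums, and the FP32 interleave/accumulate pair packs the 8 row sums for a single 8-wide
// store; the n == 16 path simply skips the masked loads.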
  2452. #ifndef ZERO_BETA
  2453. #ifndef ONE_BETA
  2454. static int sbgemv_kernel_8x16m_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2455. #else
  2456. static int sbgemv_kernel_8x16m_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
  2457. #endif
  2458. #else
  2459. #ifndef ONE_ALPHA
  2460. static int sbgemv_kernel_8x16m_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2461. #else
  2462. static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
  2463. #endif
  2464. #endif
  2465. {
  2466. BLASLONG tag_m_8x = m & (~7);
  2467. __m256i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
  2468. __m256i xArray256;
2469. // Keep aligned with the other kernels and macro definitions; the high 256 bits are never used
  2470. #ifndef ONE_ALPHA
  2471. __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha));
  2472. #endif
  2473. #ifndef ZERO_BETA
  2474. __m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta));
  2475. #endif
  2476. __m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \
  2477. accum256_8, accum256_9, accum256_10, accum256_11, accum256_12, accum256_13, accum256_14, accum256_15;
  2478. __m256i M256_EPI32_4 = _mm256_set1_epi32(4);
  2479. __m256i idx_base_0 = _mm256_set_epi32(11, 10, 9, 8, 3, 2, 1, 0);
  2480. __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_4);
  2481. unsigned short load_mask_value = (((unsigned short)0xffff) >> (16-n));
  2482. __mmask16 load_mask = *((__mmask16*) &load_mask_value);
  2483. if (n == 16) {
  2484. BF16_VECTOR_LOAD_1x16(xArray256, x, 0)
  2485. } else {
  2486. BF16_VECTOR_MASKZ_LOAD_1x16(xArray256, x, 0, load_mask)
  2487. }
  2488. if (n == 16) {
  2489. for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
  2490. accum256_0 = _mm256_setzero_ps();
  2491. accum256_1 = _mm256_setzero_ps();
  2492. accum256_2 = _mm256_setzero_ps();
  2493. accum256_3 = _mm256_setzero_ps();
  2494. accum256_4 = _mm256_setzero_ps();
  2495. accum256_5 = _mm256_setzero_ps();
  2496. accum256_6 = _mm256_setzero_ps();
  2497. accum256_7 = _mm256_setzero_ps();
  2498. BF16_MATRIX_LOAD_8x16(matrixArray, a, lda, idx_m, 0)
  2499. BF16_DOT_8x16(accum256, matrixArray, xArray256)
2500. // 2-step interleave for FP32 register array
2501. FP32_INTERLEAVE_8x8(accum256)
2502. // Accumulate the 2 batches of registers into 2 registers (0 and 4)
  2503. FP32_ACCUM2_8x8(accum256)
  2504. accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
  2505. accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
  2506. accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
  2507. STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
  2508. }
  2509. if (tag_m_8x != m) {
  2510. __m128 accum128, tmp128;
  2511. for (BLASLONG i = tag_m_8x; i < m; i++) {
  2512. accum256_0 = _mm256_setzero_ps();
2513. matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]); // Load 1 row with n=16
  2514. accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
  2515. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  2516. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2517. accum128 = _mm_add_ps(accum128, tmp128);
  2518. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2519. accum128 = _mm_add_ps(accum128, tmp128);
  2520. y[i] += accum128[0] * alpha;
  2521. }
  2522. }
  2523. } else {
  2524. for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
  2525. accum256_0 = _mm256_setzero_ps();
  2526. accum256_1 = _mm256_setzero_ps();
  2527. accum256_2 = _mm256_setzero_ps();
  2528. accum256_3 = _mm256_setzero_ps();
  2529. accum256_4 = _mm256_setzero_ps();
  2530. accum256_5 = _mm256_setzero_ps();
  2531. accum256_6 = _mm256_setzero_ps();
  2532. accum256_7 = _mm256_setzero_ps();
  2533. BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray, a, lda, idx_m, 0, load_mask)
  2534. BF16_DOT_8x16(accum256, matrixArray, xArray256)
2535. // 2-step interleave for FP32 register array
2536. FP32_INTERLEAVE_8x8(accum256)
2537. // Accumulate the 2 batches of registers into 2 registers (0 and 4)
  2538. FP32_ACCUM2_8x8(accum256)
  2539. accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
  2540. accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
  2541. accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
  2542. STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
  2543. }
  2544. if (tag_m_8x != m) {
  2545. __m128 accum128, tmp128;
  2546. for (BLASLONG i = tag_m_8x; i < m; i++) {
  2547. accum256_0 = _mm256_setzero_ps();
2548. matrixArray_0 = _mm256_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 row of n elements (masked)
  2549. accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
  2550. accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
  2551. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
  2552. accum128 = _mm_add_ps(accum128, tmp128);
  2553. tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
  2554. accum128 = _mm_add_ps(accum128, tmp128);
  2555. #ifndef ZERO_BETA
  2556. #ifndef ONE_BETA
  2557. y[i] = alpha * accum128[0] + beta * y[i];
  2558. #else
  2559. y[i] = alpha * accum128[0] + y[i];
  2560. #endif
  2561. #else
  2562. #ifndef ONE_ALPHA
  2563. y[i] = accum128[0] * alpha;
  2564. #else
  2565. y[i] = accum128[0];
  2566. #endif
  2567. #endif
  2568. }
  2569. }
  2570. }
  2571. return 0;
  2572. }