- /***************************************************************************
- Copyright (c) 2014, The OpenBLAS Project
- All rights reserved.
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions are
- met:
- 1. Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- 2. Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- 3. Neither the name of the OpenBLAS project nor the names of
- its contributors may be used to endorse or promote products
- derived from this software without specific prior written permission.
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
- USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *****************************************************************************/
- #include <immintrin.h>
- #include "common.h"
- // Include common macros for BF16 based operations with IA intrinsics
- #include "bf16_common_macros.h"
-
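- // The STORE*_COMPLETE_RESULT macros are resolved at compile time from the
- // ZERO_BETA / ONE_BETA / ONE_ALPHA flags, so each template instantiation performs
- // exactly one of the four y-updates without runtime branching:
- //   beta != 0, beta != 1 : y = alpha * acc + beta * y   (*_ALPHA_BETA)
- //   beta == 1            : y = alpha * acc + y          (*_ALPHA_ONE)
- //   beta == 0, alpha != 1: y = alpha * acc              (*_ALPHA)
- //   beta == 0, alpha == 1: y = acc                      (*_DIRECT)
- // Conceptually, the ALPHA_BETA variant expands to something like
- //   acc = _mm512_mul_ps(acc, ALPHAVECTOR);
- //   acc = _mm512_fmadd_ps(BETAVECTOR, _mm512_loadu_ps(y_ptr), acc);
- //   _mm512_storeu_ps(y_ptr, acc);
- // see bf16_common_macros.h for the actual definitions.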
- #ifndef ZERO_BETA // Beta is non-zero
-
- #ifndef ONE_BETA // BETA is not ONE
-
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_BETA
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_BETA
- #define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_BETA
- #define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_BETA
- #define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_BETA
- #define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_BETA
-
- #else // BETA is ONE
-
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE
- #define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA_ONE
- #define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA_ONE
- #define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA_ONE
- #define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA_ONE
-
- #endif
-
- #else // BETA is zero
-
- #ifndef ONE_ALPHA // ALPHA is not ONE
-
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA
- #define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_ALPHA
- #define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_ALPHA
- #define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_ALPHA
- #define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_ALPHA
-
- #else // ALPHA is ONE
-
- #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_DIRECT
- #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_DIRECT
- #define STORE8_COMPLETE_RESULT STORE8_COMPLETE_RESULT_DIRECT
- #define STORE8_MASK_COMPLETE_RESULT STORE8_MASK_COMPLETE_RESULT_DIRECT
- #define STORE4_COMPLETE_RESULT STORE4_COMPLETE_RESULT_DIRECT
- #define STORE4_MASK_COMPLETE_RESULT STORE4_MASK_COMPLETE_RESULT_DIRECT
-
- #endif
-
- #endif
-
-
- // 32 rows parallel processing BF16 GEMV kernel for n=1 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_32x1_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_32x1_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_32x1_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_32x1(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_32x = m & (~31);
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2;
- __m512i xArray;
- __m512 result_0, result_1;
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
- #endif
-
- __m512i load_idx_lo = _mm512_set_epi16(0, 15, 0, 14, 0, 13, 0, 12, 0, 11, 0, 10, 0, 9, 0, 8,\
- 0, 7, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0);
- __m512i M512_EPI16_16 = _mm512_set1_epi16(16);
- __m512i load_idx_hi = _mm512_add_epi16(load_idx_lo, M512_EPI16_16);
-
- unsigned int interleave_mask_value = ((unsigned int) 0x55555555);
- __mmask32 interleave_mask = *((__mmask32*) &interleave_mask_value);
-
- xArray = _mm512_set1_epi16((short) x[0]);
- xArray = _mm512_mask_blend_epi16(interleave_mask, _mm512_setzero_si512(), xArray);
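- // The expansion below (load_idx_lo/hi) turns the 32 row values into (a_i, a_0)
- // pairs, while xArray holds x0 only in its even lanes; every dpbf16 lane therefore
- // computes a_i*x0 + a_0*0, so the placeholder element contributes nothing.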
-
- if (tag_m_32x > 0) {
- for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)]); // Load 32 rows with n=1
- matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0); // Expand the low 16 elements
- matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0); // Expand the high 16 elements
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);
-
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
- }
- }
-
- BLASLONG tail_num = m - tag_m_32x;
- if (tail_num > 16) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-tail_num));
- __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
- matrixArray_0 = _mm512_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load the remaining 17~31 rows with n=1
- matrixArray_1 = _mm512_permutexvar_epi16(load_idx_lo, matrixArray_0); // Expand the low 16 elements
- matrixArray_2 = _mm512_permutexvar_epi16(load_idx_hi, matrixArray_0); // Expand the high 16 elements
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_1, (__m512bh) xArray);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_2, (__m512bh) xArray);
-
- unsigned short store_mask_value = (((unsigned short)0xffff) >> (32-tail_num));
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
- STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
- STORE16_MASK_COMPLETE_RESULT(result_1, y+tag_m_32x+16, store_mask)
- } else if (tail_num > 8) {
- __m256 result256_0 = _mm256_setzero_ps();
- __m256 result256_1 = _mm256_setzero_ps();
-
- __m256i load_idx_lo256 = _mm512_castsi512_si256(load_idx_lo);
- __m256i load_idx_hi256 = _mm512_extracti32x8_epi32(load_idx_lo, 0x1);
- __m256i xArray256 = _mm512_castsi512_si256(xArray);
-
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- __m256i matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load the remaining 9~16 rows with n=1
- __m256i matrixArray256_1 = _mm256_permutexvar_epi16(load_idx_lo256, matrixArray256_0); // Expand the low 8 elements
- __m256i matrixArray256_2 = _mm256_permutexvar_epi16(load_idx_hi256, matrixArray256_0); // Expand the high 8 elements
-
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_1, (__m256bh) xArray256);
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_2, (__m256bh) xArray256);
-
- unsigned char store_mask_value = (((unsigned char)0xff) >> (16-tail_num));
- __mmask8 store_mask = *((__mmask8*) &store_mask_value);
- STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
- STORE8_MASK_COMPLETE_RESULT(result256_1, y+tag_m_32x+8, store_mask)
- } else {
- __m128 result128_0 = _mm_setzero_ps();
- __m128 result128_1 = _mm_setzero_ps();
-
- __m128i load_idx_lo128 = _mm_set_epi16(0, 3, 0, 2, 0, 1, 0, 0);
- __m128i M128_EPI16_4 = _mm_set1_epi16(4);
- __m128i load_idx_hi128 = _mm_add_epi16(load_idx_lo128, M128_EPI16_4);
-
- __m128i xArray128 = _mm512_castsi512_si128(xArray);
-
- unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
- __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
- __m128i matrixArray128_0 = _mm_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)]); // Load the remaining rows (at most 8) with n=1
- __m128i matrixArray128_1 = _mm_permutexvar_epi16(load_idx_lo128, matrixArray128_0); // Expand the low 4 elements
- __m128i matrixArray128_2 = _mm_permutexvar_epi16(load_idx_hi128, matrixArray128_0); // Expand the high 4 elements
-
- result128_0 = _mm_dpbf16_ps(result128_0, (__m128bh) matrixArray128_1, (__m128bh) xArray128);
- result128_1 = _mm_dpbf16_ps(result128_1, (__m128bh) matrixArray128_2, (__m128bh) xArray128);
-
- if (tail_num > 4) {
- unsigned char store_mask_value = (((unsigned char)0xf) >> (8-tail_num));
- __mmask8 store_mask = *((__mmask8*) &store_mask_value);
- STORE4_COMPLETE_RESULT(result128_0, y+tag_m_32x)
- STORE4_MASK_COMPLETE_RESULT(result128_1, y+tag_m_32x+4, store_mask)
- } else {
- unsigned char store_mask_value = (((unsigned char)0xf) >> (4-tail_num));
- __mmask8 store_mask = *((__mmask8*) &store_mask_value);
- STORE4_MASK_COMPLETE_RESULT(result128_0, y+tag_m_32x, store_mask)
- }
- }
-
- return 0;
- }
-
- // 32 rows parallel processing BF16 GEMV kernel for n=2 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_32x2_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_32x2_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_32x2_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_32x = m & (~31);
-
- __m512i matrixArray_0, matrixArray_1;
- __m512i xArray;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- unsigned char load_mask_value = (((unsigned char)0xff) >> 6);
- __mmask8 load_mask = *((__mmask8*) &load_mask_value);
- xArray = _mm512_broadcastd_epi32(_mm_maskz_loadu_epi16(load_mask, x));
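- // x0|x1 is broadcast as a single 32-bit lane, so each dpbf16 fp32 lane accumulates
- // the complete 2-element dot product a_i0*x0 + a_i1*x1 for one row.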
-
- if (tag_m_32x > 0) {
- for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*2]); // Load 16 rows as n=2
- matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+16)*2]); // Load 16 rows as n=2
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray);
-
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
- }
- }
-
- if (m - tag_m_32x >= 16) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_32x)*2]); // Load 16 rows with n=2
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
-
- STORE16_COMPLETE_RESULT(result_0, y+tag_m_32x)
-
- tag_m_32x += 16;
- }
-
- BLASLONG tail_num = m - tag_m_32x;
- if (tail_num > 8) {
- result_0 = _mm512_setzero_ps();
-
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(m&15)));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]); // Load the remaining 9~15 rows with n=2
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_32x, tail_mask)
- } else if (tail_num == 8) {
- __m256 result256 = _mm256_setzero_ps();
-
- __m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2
- __m256i xArray256 = _mm512_castsi512_si256(xArray);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);
-
- STORE8_COMPLETE_RESULT(result256, y+tag_m_32x)
- } else {
- __m256 result256 = _mm256_setzero_ps();
-
- unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-(m&7)));
- __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
- __m256i matrixArray256 = _mm256_maskz_loadu_epi32(tail_mask, &a[(tag_m_32x)*2]); // Load the remaining rows (fewer than 8) with n=2
- __m256i xArray256 = _mm512_castsi512_si256(xArray);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256);
-
- STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_32x, tail_mask)
- }
-
- return 0;
- }
-
- // 32 rows parallel processing BF16 GEMV kernel for n=3 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_32x3_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_32x3_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_32x3_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_32x = m & (~31);
-
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i xTmp = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|0|0|0|0|0|
- __m512i xArray_0 = _mm512_broadcastd_epi32(xTmp); // x0|x1|x0|x1|...|x0|x1|
- __m512i xArray_1 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1)); // x2| 0|x2| 0|...|x2| 0|
-
- __m512i load_idx_base;
- __m512i M512_EPI16_2, M512_EPI16_8, M512_EPI16_16;
- M512_EPI16_2 = _mm512_set1_epi16(2);
- M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
- M512_EPI16_8 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
- M512_EPI16_16 = _mm512_add_epi16(M512_EPI16_8, M512_EPI16_8);
- load_idx_base = _mm512_set_epi16(46, 45, 43, 42, 40, 39, 37, 36, 34, 33, 31, 30, 28, 27, 25, 24,
- 22, 21, 19, 18, 16, 15, 13, 12, 10, 9, 7, 6, 4, 3, 1, 0);
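- // Rows of 3 bf16 values straddle register boundaries, so 32 rows need three loads.
- // load_idx_base gathers the (0,1) pair of 16 consecutive rows out of two adjacent
- // registers (indices >= 32 select from the second permutex2var operand); adding 2
- // shifts the gather to the third element of each row, whose odd-lane companion is
- // cancelled by the zero half of xArray_1.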
-
- if (tag_m_32x > 0) {
- __m512i load_idx01_1st, load_idx01_2nd, load_idx2_1st, load_idx2_2nd;
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6;
-
- unsigned int idx_blend_mask_value = ((unsigned int)0x80000000);
- __mmask32 idx_blend_mask = *((__mmask32*) &idx_blend_mask_value);
-
- load_idx01_1st = load_idx_base;
- load_idx01_2nd = _mm512_add_epi16(load_idx01_1st, M512_EPI16_16);
- load_idx2_1st = _mm512_add_epi16(load_idx01_1st, M512_EPI16_2);
- load_idx2_2nd = _mm512_add_epi16(load_idx01_2nd, M512_EPI16_2);
- load_idx2_2nd = _mm512_mask_blend_epi16(idx_blend_mask, load_idx2_2nd, _mm512_setzero_si512());
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*3]); // Load 10 rows with n=3 plus 2 elements
- matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+10)*3 + 2)]); // Load 10 rows with n=3 plus 2 elements
- matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+21)*3 + 1)]); // Load 10 rows with n=3 plus 2 elements
-
- matrixArray_3 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_1st, matrixArray_1); // Select the first 2 elements for each row
- matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_1, load_idx01_2nd, matrixArray_2); // Select the first 2 elements for each row
- matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_0, load_idx2_1st, matrixArray_1); // Select the third element for each row
- matrixArray_6 = _mm512_permutex2var_epi16(matrixArray_1, load_idx2_2nd, matrixArray_2); // Select the third element for each row
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_3, (__m512bh) xArray_0);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_5, (__m512bh) xArray_1);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_4, (__m512bh) xArray_0);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_6, (__m512bh) xArray_1);
-
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- STORE16_COMPLETE_RESULT(result_1, y+idx_m+16)
- }
- }
-
- if (tag_m_32x != m) {
- __m256i load256_idx01_1st, load256_idx01_2nd, load256_idx2_1st, load256_idx2_2nd;
- __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6;
- __m256 result256_0, result256_1;
-
- unsigned short idx256_blend_mask_value = ((unsigned short)0x8000);
- __mmask16 idx256_blend_mask = *((__mmask16*) &idx256_blend_mask_value);
-
- load256_idx01_1st = _mm512_castsi512_si256(load_idx_base);
- load256_idx01_2nd = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_8));
- load256_idx2_1st = _mm256_add_epi16(load256_idx01_1st, _mm512_castsi512_si256(M512_EPI16_2));
- load256_idx2_2nd = _mm256_add_epi16(load256_idx01_2nd, _mm512_castsi512_si256(M512_EPI16_2));
- load256_idx2_2nd = _mm256_mask_blend_epi16(idx256_blend_mask, load256_idx2_2nd, _mm256_setzero_si256());
-
- if (m - tag_m_32x > 15) {
- result256_0 = _mm256_setzero_ps();
- result256_1 = _mm256_setzero_ps();
-
- matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
- matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
- matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element
-
- matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
- matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
- matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
- matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
-
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
-
- STORE8_COMPLETE_RESULT(result256_0, y+tag_m_32x)
- STORE8_COMPLETE_RESULT(result256_1, y+tag_m_32x+8)
-
- tag_m_32x += 16;
- }
-
- if (tag_m_32x != m) {
- result256_0 = _mm256_setzero_ps();
- result256_1 = _mm256_setzero_ps();
- BLASLONG tail_num = m-tag_m_32x;
-
- if (tail_num > 10) {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1)));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
- matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element
- matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows
-
- matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
- matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
- matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
- matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
-
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
- } else if (tail_num > 5) {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2)));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element
- matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows
- matrixArray256_2 = _mm256_setzero_si256();
-
- matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
- matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row
- matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
- matrixArray256_6 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx2_2nd, matrixArray256_2); // Select the third element for each row
-
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_4, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_1 = _mm256_dpbf16_ps(result256_1, (__m256bh) matrixArray256_6, (__m256bh) _mm512_castsi512_si256(xArray_1));
- } else {
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num*3)));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- matrixArray256_0 = _mm256_maskz_loadu_epi16(tail_mask, &a[(tag_m_32x)*3]); // Load m-tag_m_32x rows
- matrixArray256_1 = _mm256_setzero_si256();
-
- matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row
- matrixArray256_5 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx2_1st, matrixArray256_1); // Select the third element for each row
-
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_3, (__m256bh) _mm512_castsi512_si256(xArray_0));
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray256_5, (__m256bh) _mm512_castsi512_si256(xArray_1));
- }
-
- unsigned short store_tail_mask_value = (((unsigned short)0xffff) >> (16-(tail_num)));
- __mmask16 store_tail_mask = *((__mmask16*) &store_tail_mask_value);
- __m512 result512 = _mm512_insertf32x8(_mm512_castps256_ps512(result256_0), result256_1, 0x1);
- STORE16_MASK_COMPLETE_RESULT(result512, y+tag_m_32x, store_tail_mask)
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=4 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x4_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x4_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x4_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
- __m512i xArray_01, xArray_23, xArray_remix;
- __m512 result;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
- __m512i idx_base_0 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_1);
- __m512i idx_base_remix = _mm512_inserti32x8(idx_base_0, _mm512_castsi512_si256(idx_base_1), 0x1);
-
- unsigned char x_load_mask_value = (((unsigned char)0xf) >> 2);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i xTmp = _mm_maskz_loadu_epi32(x_load_mask, x); // |x0|x1|x2|x3|0|0|0|0|
- xArray_01 = _mm512_broadcastd_epi32(xTmp); // |x0|x1|x0|x1|...|x0|x1|
- xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(xTmp, 0x1)); // |x2|x3|x2|x3|...|x2|x3|
- unsigned short blend_mask_value = ((unsigned short)0xff00);
- __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
- xArray_remix = _mm512_mask_blend_epi32(blend_mask, xArray_01, xArray_23); // |x0|x1|x0|x1|x0|x1|x0|x1|...|x2|x3|x2|x3|x2|x3|x2|x3|
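- // With n=4 each row is one 64-bit pair of dwords: idx_base_0 gathers the even dwords
- // (the x0|x1 halves of 16 rows) and idx_base_1 the odd dwords (the x2|x3 halves), so
- // two dpbf16 calls per iteration complete the 4-element dot product of every row.
- // idx_base_remix/xArray_remix pack both halves of an 8-row tail into one register.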
-
- if (tag_m_16x > 0) {
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- result = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*4]); // Load 8 rows with n=4
- matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+8)*4]); // Load 8 rows with n=4
-
- matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_0, matrixArray_1); // |a0|a1|...|h0|h1|i0|i1|...|p0|p1|
- matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_0, idx_base_1, matrixArray_1); // |a2|a3|...|h2|h3|i2|i3|...|p2|p3|
-
- result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_01);
- result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_3, (__m512bh) xArray_23);
-
- STORE16_COMPLETE_RESULT(result, y+idx_m)
- }
- }
-
- if (m - tag_m_16x > 7) {
- result = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*4]); // Load 8 rows with n=4
- matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0); // a0|a1|...|h0|h1|a2|a3|...|h2|h3|
-
- result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));
-
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- BLASLONG tail_num = m-tag_m_16x;
- if (tail_num != 0) {
- result = _mm512_setzero_ps();
-
- unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-tail_num*2));
- __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
- matrixArray_0 = _mm512_maskz_loadu_epi32(tail_mask, &a[(tag_m_16x)*4]); // Load the remaining 1~7 rows with n=4
- matrixArray_2 = _mm512_permutexvar_epi32(idx_base_remix, matrixArray_0); // a0|a1|...|h0|h1|a2|a3|...|h2|h3|
-
- result = _mm512_dpbf16_ps(result, (__m512bh) matrixArray_2, (__m512bh) xArray_remix);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result), _mm512_extractf32x8_ps(result, 1));
-
- unsigned char store_tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
- __mmask8 store_tail_mask = *((__mmask8*) &store_tail_mask_value);
- STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, store_tail_mask)
- }
-
- return 0;
- }
-
- // 30 rows parallel processing BF16 GEMV kernel for n=5 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_30x5_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_30x5_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_30x5_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_30x = m - (m%30);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 3);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|0|0|0|
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512 result_0, result_1;
- __m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
- __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
- __m512i xArray_4 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4| 0|x4| 0|...|x4| 0|
-
- __m512i M512_EPI16_2 = _mm512_set1_epi16(2);
- __m512i load_idx01_stage1_1st = _mm512_set_epi16( 0, 0, 0, 0, 0, 0, 0, 0, 58, 57, 53, 52, 48, 47, 43, 42,
- 38, 37, 33, 32, 26, 25, 21, 20, 16, 15, 11, 10, 6, 5, 1, 0);
- __m512i load_idx01_stage1_2nd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x39);
- __m512i load_idx01_stage1_3rd = _mm512_shuffle_i32x4(load_idx01_stage1_1st, load_idx01_stage1_1st, 0x4f);
-
- __m512i load_idx23_stage1_1st = _mm512_add_epi16(load_idx01_stage1_1st, M512_EPI16_2);
- __m512i load_idx23_stage1_2nd = _mm512_add_epi16(load_idx01_stage1_2nd, M512_EPI16_2);
- __m512i load_idx23_stage1_3rd = _mm512_add_epi16(load_idx01_stage1_3rd, M512_EPI16_2);
-
- __m512i load_idx4_stage1_1st = _mm512_add_epi16(load_idx23_stage1_1st, M512_EPI16_2);
- __m512i load_idx4_stage1_2nd = _mm512_add_epi16(load_idx23_stage1_2nd, M512_EPI16_2);
- __m512i load_idx4_stage1_3rd = _mm512_add_epi16(load_idx23_stage1_3rd, M512_EPI16_2);
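- // Five bf16 per row do not tile a 512-bit register evenly, so each iteration loads
- // 30 rows as five 6-row registers. The stage-1 index vectors gather the (0,1) and
- // (2,3) pairs and the lone 5th element of each row; the _2nd/_3rd variants are the
- // same pattern rotated by 128-bit lanes so that, after the stage-2 blends, rows
- // 0~15 land in result_0 and rows 16~29 in result_1.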
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
- __m512i matrixArray_stage1_0, matrixArray_stage1_1, matrixArray_stage1_2;
- __m512i matrixArray_stage2_0, matrixArray_stage2_1;
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
- unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
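- // load_mask keeps 30 of 32 bf16 values per load (6 rows x 5 columns); store_mask
- // keeps 14 of 16 floats because result_1 only covers rows 16~29 of each 30-row block.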
-
- if (tag_m_30x > 0) {
- unsigned short blend_mask_value_0 = ((unsigned short)0xf000);
- __mmask16 blend_mask_0 = *((__mmask16*) &blend_mask_value_0);
- unsigned short blend_mask_value_1 = ((unsigned short)0x3f00);
- __mmask16 blend_mask_1 = *((__mmask16*) &blend_mask_value_1);
- for (BLASLONG idx_m = 0; idx_m < tag_m_30x; idx_m+=30) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
- matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+12)*5)]); // Load 6 rows with n=5
- matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+18)*5)]); // Load 6 rows with n=5
- matrixArray_4 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+24)*5)]); // Load 6 rows with n=5
-
- // Process the 0|1 elements
- // Stage 1: Select the 0|1 elements for each row
- matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
- matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx01_stage1_2nd, matrixArray_3);
- matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx01_stage1_3rd, matrixArray_4);
- // Stage 2: Reorder and compress all the 0|1 elements
- matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
- matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
- // Calculate the result of the 0|1 elements
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_01);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_01);
-
- // Process the 2|3 elements
- // Stage 1: Select the 2|3 elements for each row
- matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
- matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx23_stage1_2nd, matrixArray_3);
- matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx23_stage1_3rd, matrixArray_4);
- // Stage 2: Reorder and compress all the 2|3 elements
- matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
- matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
- // Calculate the result of the 2|3 elements and accumulate the result of 0|1 elements
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_23);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_23);
-
- // Process the 4 elements
- // Stage 1: Select the 4 elements for each row
- matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
- matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_2, load_idx4_stage1_2nd, matrixArray_3);
- matrixArray_stage1_2 = _mm512_permutexvar_epi16(load_idx4_stage1_3rd, matrixArray_4);
- // Stage 2: Reorder and compress all the 4 elements
- matrixArray_stage2_0 = _mm512_mask_blend_epi32(blend_mask_0, matrixArray_stage1_0, matrixArray_stage1_1);
- matrixArray_stage2_1 = _mm512_mask_blend_epi32(blend_mask_1, matrixArray_stage1_1, matrixArray_stage1_2);
- // Calculate the result of the 4 elements and accumulate the result of the 0|1 and 2|3 elements
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage2_0, (__m512bh) xArray_4);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage2_1, (__m512bh) xArray_4);
-
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- STORE16_MASK_COMPLETE_RESULT(result_1, y+idx_m+16, store_mask)
- }
- }
-
- if (m - tag_m_30x > 11) {
- BLASLONG tag_m_12x = m - ((m-tag_m_30x)%12);
- for (BLASLONG idx_m = tag_m_30x; idx_m < tag_m_12x; idx_m+=12) {
- unsigned short store_less_mask_value = (((unsigned short)0xffff) >> 4);
- __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*5]); // Load 6 rows with n=5
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[((idx_m+6)*5)]); // Load 6 rows with n=5
-
- // Interleave the elements
- matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
- matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
- matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
- // Calculate and accumulate the result
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_less_mask)
- tag_m_30x += 12;
- }
- }
-
- BLASLONG tail_num = m - tag_m_30x;
- if (tail_num > 6) {
- unsigned short store_less_mask_value = (((unsigned short)0xffff) >> (4+(12-tail_num)));
- __mmask16 store_less_mask = *((__mmask16*) &store_less_mask_value);
- unsigned int load_less_mask_value = (((unsigned int)0xffffffff) >> (2+(12-tail_num)*5));
- __mmask32 load_less_mask = *((__mmask32*) &load_less_mask_value);
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_30x)*5]); // Load 6 rows with n=5
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_less_mask, &a[((tag_m_30x+6)*5)]); // Load the remaining tail rows with n=5
-
- // Interleave the elements
- matrixArray_stage1_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx01_stage1_1st, matrixArray_1);
- matrixArray_stage1_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx23_stage1_1st, matrixArray_1);
- matrixArray_stage1_2 = _mm512_permutex2var_epi16(matrixArray_0, load_idx4_stage1_1st, matrixArray_1);
- // Calculate and accumulate the result
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_0, (__m512bh) xArray_01);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage1_2, (__m512bh) xArray_4);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+tag_m_30x, store_less_mask)
- } else {
- __m128i matrixArray128;
- __m128 result128, tmp128;
- for (BLASLONG i = tag_m_30x; i < m; i++) {
- result128 = _mm_setzero_ps();
- matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*5]); // Load 1 row with n=5
- result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
-
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=6 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x6_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x6_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x6_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 2);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // x0|x1|x2|x3|x4|x5|0|0|
-
- if (tag_m_16x > 0) {
- __m512 result_0;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_1 = _mm512_set1_epi32(1);
- __m512i load_idx01_1st = _mm512_set_epi32( 0, 0, 0, 0, 0, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0);
- __m512i load_idx01_2nd = _mm512_set_epi32(13, 10, 7, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
-
- __m512i load_idx23_1st = _mm512_add_epi32(load_idx01_1st, M512_EPI32_1);
- __m512i load_idx23_2nd = _mm512_add_epi32(load_idx01_2nd, M512_EPI32_1);
-
- __m512i load_idx45_1st = _mm512_add_epi32(load_idx23_1st, M512_EPI32_1);
- __m512i load_idx45_2nd = _mm512_add_epi32(load_idx23_2nd, M512_EPI32_1);
-
- unsigned short blend_mask_value = ((unsigned short)0x0400);
- __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
- // Set the 11th element to be 0 as invalid index for a 512 bit epi32 register
- load_idx45_1st = _mm512_mask_blend_epi32(blend_mask, load_idx45_1st, load_idx01_2nd);
- // Set the 11th element to be 0 as 0 is the correct index
- load_idx45_2nd = _mm512_mask_blend_epi32(blend_mask, load_idx45_2nd, load_idx01_2nd);
-
- __m512i xArray_01 = _mm512_broadcastd_epi32(x128); // x0|x1|x0|x1|...|x0|x1|
- __m512i xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)); // x2|x3|x2|x3|...|x2|x3|
- __m512i xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)); // x4|x5|x4|x5|...|x4|x5|
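- // With n=6 every row spans three dwords, so the (0,1), (2,3) and (4,5) column pairs
- // of 16 rows sit at a stride of 3 dwords across three 512-bit loads. The two-stage
- // permute below first gathers the pairs held in the first two loads, then uses
- // permute_mask01/45 to patch in the rows that spill over into the third load.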
-
- unsigned short permute_mask01_uint = (((unsigned short)0xf800));
- __mmask16 permute_mask01 = *((__mmask16*) &permute_mask01_uint);
- unsigned short permute_mask45_uint = (((unsigned short)0xfc00));
- __mmask16 permute_mask45 = *((__mmask16*) &permute_mask45_uint);
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2;
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*6]); // Load 5 rows with n=6 plus 2 elements
- matrixArray_1 = _mm512_loadu_si512(&a[((idx_m+5)*6 + 2)]); // Load 5 rows with n=6 plus 2 elements
- matrixArray_2 = _mm512_loadu_si512(&a[((idx_m+10)*6 + 4)]); // Load 5 rows with n=6 plus 2 elements
-
- // Stage 1: interleave for the a..k elements
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
- matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
-
- // Stage 2: interleave for the l..p elements and remix together
- matrixArray_stage_0 = _mm512_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
- matrixArray_stage_1 = _mm512_mask_permutexvar_epi32(matrixArray_stage_1, permute_mask01, load_idx23_2nd, matrixArray_2);
- matrixArray_stage_2 = _mm512_mask_permutexvar_epi32(matrixArray_stage_2, permute_mask45, load_idx45_2nd, matrixArray_2);
-
- // Calculate and accumulate the results of the 0|1, 2|3 and 4|5 elements
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_01);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_45);
-
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- __m256i M256_EPI32_1 = _mm512_castsi512_si256(M512_EPI32_1);
- __m256i load_idx01_1st = _mm256_set_epi32( 0, 0, 15, 12, 9, 6, 3, 0);
- __m256i load_idx01_2nd = _mm256_set_epi32( 5, 2, 0, 0, 0, 0, 0, 0);
-
- __m256i load_idx23_1st = _mm256_add_epi32(load_idx01_1st, M256_EPI32_1);
- __m256i load_idx23_2nd = _mm256_add_epi32(load_idx01_2nd, M256_EPI32_1);
- unsigned char blend_mask_value = ((unsigned char)0x20);
- __mmask8 blend_mask = *((__mmask8*) &blend_mask_value);
- // Set the 6th element to be 0 as invalid index for a 256 bit epi32 register
- load_idx23_1st = _mm256_mask_blend_epi32(blend_mask, load_idx23_1st, load_idx01_2nd);
- // Set the 6th element to be 0 as 0 is the correct index
- load_idx23_2nd = _mm256_mask_blend_epi32(blend_mask, load_idx23_2nd, load_idx01_2nd);
-
- __m256i load_idx45_1st = _mm256_add_epi32(load_idx23_1st, M256_EPI32_1);
- __m256i load_idx45_2nd = _mm256_add_epi32(load_idx23_2nd, M256_EPI32_1);
-
- unsigned char permute_mask01_uint = (((unsigned char)0xc0));
- __mmask8 permute_mask01 = *((__mmask8*) &permute_mask01_uint);
- unsigned char permute_mask45_uint = (((unsigned char)0xe0));
- __mmask8 permute_mask45 = *((__mmask8*) &permute_mask45_uint);
-
- __m256i matrixArray_0, matrixArray_1, matrixArray_2;
- __m256i matrixArray_stage_0;
- __m256 result256_0;
-
- result256_0 = _mm256_setzero_ps();
-
- matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 elements
- matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 elements
- matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 elements
-
- // Process the 0|1 elements
- // Select the 0|1 elements for each row
- matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx01_1st, matrixArray_1);
- matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask01, load_idx01_2nd, matrixArray_2);
- // Calculate the result of the 0|1 elements
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_01));
-
- // Process the 2|3 elements
- // Select the 2|3 elements for each row
- matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx23_1st, matrixArray_1);
- matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx23_2nd, matrixArray_2);
- // Calculate the result of the 2|3 elements
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_23));
-
- // Process the 4|5 elements
- // Select the 4|5 elements for each row
- matrixArray_stage_0 = _mm256_permutex2var_epi32(matrixArray_0, load_idx45_1st, matrixArray_1);
- matrixArray_stage_0 = _mm256_mask_permutexvar_epi32(matrixArray_stage_0, permute_mask45, load_idx45_2nd, matrixArray_2);
- // Calculate the result of the 4|5 elements
- result256_0 = _mm256_dpbf16_ps(result256_0, (__m256bh) matrixArray_stage_0, (__m256bh) _mm512_castsi512_si256(xArray_45));
-
- STORE8_COMPLETE_RESULT(result256_0, y+tag_m_16x)
- tag_m_16x += 8;
- }
- }
-
- if (tag_m_16x != m) {
- __m128i matrixArray128;
- __m128 result128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- result128 = _mm_setzero_ps();
- matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*6]); // Load 1 row with n=6
- result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=7 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x7_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x7_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x7_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 1);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128 = _mm_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|0|
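- // Reference (scalar) form of what every variant above computes for row i, with bf16 operands widened to fp32:
- //   sum_i = a[i*7+0]*x[0] + ... + a[i*7+6]*x[6];
- //   y[i]  = alpha * sum_i + beta * y[i]   (alpha/beta terms dropped according to ONE_ALPHA / ZERO_BETA / ONE_BETA)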
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
- __m512i xArray_0123, xArray_4567;
- __m512 result_0, result_1, result_2, result_3;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
- __m512i load_idx_stage1_0 = _mm512_set_epi16(31, 27, 26, 25, 24, 23, 22, 21, 31, 20, 19, 18, 17, 16, 15, 14,
- 31, 13, 12, 11, 10, 9, 8, 7, 31, 6, 5, 4, 3, 2, 1, 0);
- __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
- __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
-
- unsigned short x_blend_mask_value = ((unsigned short)0xff00);
- __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
- xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
- _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
- xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
- _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
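- // xArray_0123: x0|x1 repeated in the low 256 bits, x2|x3 in the high 256 bits (xArray_4567 likewise
- // for x4|x5 / x6|x7), matching the stage-2 matrix layout produced inside the loop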
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
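- // load_mask keeps 28 of 32 bf16 (4 rows of 7); the zeroed elements 28..31 are what index 31 in
- // load_idx_stage1_0 reuses as the 8th (padding) element of each row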
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m)*7]); // Load 4 rows with n=7
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+4)*7]); // Load 4 rows with n=7
- matrixArray_2 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+8)*7]); // Load 4 rows with n=7
- matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+12)*7]); // Load 4 rows with n=7
-
- // Stage 1: padding
- matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
- matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
- matrixArray_2 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_2); // |i0|i1|i2|i3|...|j6|j7|k0|k1|k2|k3|...|l6|l7|
- matrixArray_3 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_3); // |m0|m1|m2|m3|...|n6|n7|o0|o1|o2|o3|...|p6|p7|
-
- // Stage 2: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
- matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
- matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
-
- // Stage 3: interleave per 256 bits
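- // 0x44 gathers the low 256 bits of both accumulators (per-row sums over elements 0,1,4,5) and
- // 0xee the high 256 bits (elements 2,3,6,7); adding them yields the 16 complete row results in order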
- result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
- result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
-
- result_2 = _mm512_add_ps(result_2, result_3);
-
- STORE16_COMPLETE_RESULT(result_2, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x+4)*7]); // Load 4 rows with n=7
-
- // Stage 1: padding
- matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
- matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
-
- // Stage 2: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
-
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
-
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
-
- tag_m_16x += 8;
- }
-
- BLASLONG tail_num = m - tag_m_16x;
- if (tail_num > 3) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(tag_m_16x)*7]); // Load 4 rows with n=7
- unsigned int tail_load_mask_value = (((unsigned int)0xffffffff) >> (4+(8-tail_num)*7));
- __mmask32 tail_load_mask = *((__mmask32*) &tail_load_mask_value);
- matrixArray_1 = _mm512_maskz_loadu_epi16(tail_load_mask, &a[(tag_m_16x+4)*7]); // Load 4 rows with n=7
-
- // Stage 1: padding
- matrixArray_0 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_0); // |a0|a1|a2|a3|...|b6|b7|c0|c1|c2|c3|...|d6|d7|
- matrixArray_1 = _mm512_permutexvar_epi16(load_idx_stage1_0, matrixArray_1); // |e0|e1|e2|e3|...|f6|f7|g0|g1|g2|g3|...|h6|h7|
-
- // Stage 2: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
-
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
-
- unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
- __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
- STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
- tag_m_16x = m;
- }
- }
-
- if (tag_m_16x != m) {
- __m128i matrixArray128;
- __m128 result128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- result128 = _mm_setzero_ps();
- matrixArray128 = _mm_maskz_loadu_epi16(x_load_mask, &a[(i)*7]);             // Load 1 row with n=7
- result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=8 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x8_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x8_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x8_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- __m128i x128 = _mm_loadu_si128((__m128i *)x);               // |x0|x1|x2|x3|x4|x5|x6|x7|
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
- __m512i xArray_0123, xArray_4567;
- __m512 result_0, result_1, result_2, result_3;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_2 = _mm512_set1_epi32(2);
- __m512i load_idx_stage2_0 = _mm512_set_epi32(29, 25, 21, 17, 13, 9, 5, 1, 28, 24, 20, 16, 12, 8, 4, 0);
- __m512i load_idx_stage2_1 = _mm512_add_epi32(load_idx_stage2_0, M512_EPI32_2);
-
- unsigned short x_blend_mask_value = ((unsigned short)0xff00);
- __mmask16 x_blend_mask = *((__mmask16*) &x_blend_mask_value);
- xArray_0123 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(x128), \
- _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x1)));
- xArray_4567 = _mm512_mask_blend_epi32(x_blend_mask, _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x2)), \
- _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128, 0x3)));
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*8]); // Load 4 rows with n=8
- matrixArray_1 = _mm512_loadu_si512(&a[(idx_m+4)*8]); // Load 4 rows with n=8
- matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+8)*8]); // Load 4 rows with n=8
- matrixArray_3 = _mm512_loadu_si512(&a[(idx_m+12)*8]); // Load 4 rows with n=8
-
- // Stage 1: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|...|h0|h1|a2|a3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|...|h4|h5|a6|a7|...|h6|h7|
- matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_0, matrixArray_3); // |i0|i1|...|p0|p1|i2|i3|...|p2|p3|
- matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage2_1, matrixArray_3); // |i4|i5|...|p4|p5|i6|i7|...|p6|p7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_2, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_stage_3, (__m512bh) xArray_4567);
-
- // Stage 2: interleave per 256 bits
- result_2 = _mm512_shuffle_f32x4(result_0, result_1, 0x44);
- result_3 = _mm512_shuffle_f32x4(result_0, result_1, 0xee);
-
- result_2 = _mm512_add_ps(result_2, result_3);
-
- STORE16_COMPLETE_RESULT(result_2, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
- matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
-
- // Stage 1: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
-
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
-
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- BLASLONG tail_num = m - tag_m_16x;
- if (tail_num > 3) {
- result_0 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x)*8]); // Load 4 rows with n=8
- unsigned short tail_load_mask_value = (((unsigned int)0xffff) >> ((8-tail_num)*4));
- __mmask16 tail_load_mask = *((__mmask16*) &tail_load_mask_value);
- matrixArray_1 = _mm512_maskz_loadu_epi32(tail_load_mask, &a[(tag_m_16x+4)*8]); // Load 4 rows with n=8
-
- // Stage 1: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_0, matrixArray_1); // |a0|a1|b0|b1|...|h0|h1|a2|a3|b2|b3|...|h2|h3|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage2_1, matrixArray_1); // |a4|a5|b4|b5|...|h4|h5|a6|a7|b6|b7|...|h6|h7|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_0, (__m512bh) xArray_0123);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_stage_1, (__m512bh) xArray_4567);
-
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(result_0), _mm512_extractf32x8_ps(result_0, 0x1));
-
- unsigned char tail_mask_value = (((unsigned char)0xff) >> (8-tail_num));
- __mmask8 tail_mask = *((__mmask8*) &tail_mask_value);
- STORE8_MASK_COMPLETE_RESULT(result256, y+tag_m_16x, tail_mask)
- tag_m_16x = m;
- }
- }
-
- if (tag_m_16x != m) {
- __m128i matrixArray128;
- __m128 result128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- result128 = _mm_setzero_ps();
- matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]);            // Load 1 row with n=8
- result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128);
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 14 rows parallel processing BF16 GEMV kernel for n=9 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_14x9_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_14x9_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_14x9_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_14x = m - (m%14);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128_0 = _mm_loadu_si128((__m128i *)x);             // |x0|x1|x2|x3|x4|x5|x6|x7|
- __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0|
-
- if (tag_m_14x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3;
- __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m256i M256_EPI16_2 = _mm256_set1_epi16(2);
- __m256i idx_base_0 = _mm256_set_epi16( 0, 0, 55, 54, 46, 45, 37, 36, 28, 27, 19, 18, 10, 9, 1, 0);
- __m256i idx_base_1 = _mm256_add_epi16(idx_base_0, M256_EPI16_2);
- __m256i idx_base_2 = _mm256_add_epi16(idx_base_1, M256_EPI16_2);
- __m256i idx_base_3 = _mm256_add_epi16(idx_base_2, M256_EPI16_2);
- __m256i idx_base_4 = _mm256_add_epi16(idx_base_3, M256_EPI16_2);
- __m512i idx_idx = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 6, 5, 4, 3, 2, 1, 0);
-
- __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
- __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
- __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
- __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
- __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
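- // idx_idx concatenates the first 7 dwords of two base index vectors: load_idx_stage1_0, for example,
- // gathers elements 0|1 of rows a..g into the low half and elements 2|3 into the high half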
- __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 22, 21, 20, 19, 18, 17, 16, 13, 12, 11, 10, 9, 8, 7);
-
- xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
- xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
- xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
- xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
- xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|0 |x8| 0| ... |x8| 0|
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 1);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
- unsigned short blend_mask_value = ((unsigned short)0x3f80);
- __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
- unsigned short store_mask_value = (((unsigned short)0xffff) >> 2);
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
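- // load_mask keeps 31 bf16 (3 rows of 9 plus 4); blend_mask 0x3f80 selects dwords 7..13, where the
- // second group of rows lands; store_mask limits the store to the 14 valid y values per iteration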
- for (BLASLONG idx_m = 0; idx_m < tag_m_14x; idx_m+=14) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(idx_m)*9]); // Load 3 rows with n=9 plus 5 elements
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+3)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
- matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+7)*9]); // Load 3 rows with n=9 plus 5 elements
- matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*9 + 5]); // Load 3 rows with n=9 plus 4 elements
-
- // Stage 1: interleave per 16 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|g0|g1|a2|a3|...|g2|g3|x|x|x|x|
- matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|g4|g5|a6|a7|...|g6|g7|x|x|x|x|
- matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |h2|h3|...|n2|n3|h0|h1|...|n0|n1|x|x|x|x|
- matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |h6|h7|...|n6|n7|h4|h5|...|n4|n5|x|x|x|x|
- matrixArray_4 = _mm512_permutex2var_epi16(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8| x|...|g8| x| x| x|...| x| x|x|x|x|x|
- matrixArray_5 = _mm512_permutex2var_epi16(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|h8| x|...|n8| x|x|x|x|x|
-
- // Stage 2: interleave per 32 bits
- matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|b0|b1|...|h0|h1|i0|i1|j0|j1|...|n0|n1|x|x|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|b2|b3|...|h2|h3|i2|i3|j2|j3|...|n2|n3|x|x|x|x|
- matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|b4|b5|...|h4|h5|i4|i5|j4|j5|...|n4|n5|x|x|x|x|
- matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|b6|b7|...|h6|h7|i6|i7|j6|j7|...|n6|n7|x|x|x|x|
- matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_4, matrixArray_5); // |a8| x|b8| x|...|h8| x|i8| x|j8| x|...|n8| x|x|x|x|x|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
- result_0 = _mm512_add_ps(result_0, result_1);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
- }
- }
-
- if (tag_m_14x != m) {
- __m256i matrixArray256;
- __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
- __m256 result256;
- __m128 result128, tmp128;
- unsigned short load256_mask_value = (((unsigned short)0xffff) >> 7);
- __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
- for (BLASLONG i = tag_m_14x; i < m; i++) {
- result256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*9]);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
- result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 12 rows parallel processing BF16 GEMV kernel for n=10 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_12x10_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_12x10_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_12x10_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_12x = m - (m%12);
-
- unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128_0 = _mm_loadu_si128((__m128i *)x);             // |x0|x1|x2|x3|x4|x5|x6|x7|
- __m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0|
-
- if (tag_m_12x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
- __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m256i M256_EPI32_1 = _mm256_set1_epi32(1);
- __m256i idx_base_0 = _mm256_set_epi32( 0, 0, 26, 21, 16, 10, 5, 0);
- __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_1);
- __m256i idx_base_2 = _mm256_add_epi32(idx_base_1, M256_EPI32_1);
- __m256i idx_base_3 = _mm256_add_epi32(idx_base_2, M256_EPI32_1);
- __m256i idx_base_4 = _mm256_add_epi32(idx_base_3, M256_EPI32_1);
- __m512i idx_idx = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 5, 4, 3, 2, 1, 0);
-
- __m512i load_idx_stage1_0 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_0), idx_idx, _mm512_castsi256_si512(idx_base_1));
- __m512i load_idx_stage1_1 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_2), idx_idx, _mm512_castsi256_si512(idx_base_3));
- __m512i load_idx_stage1_2 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_1), idx_idx, _mm512_castsi256_si512(idx_base_0));
- __m512i load_idx_stage1_3 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_3), idx_idx, _mm512_castsi256_si512(idx_base_2));
- __m512i load_idx_stage1_4 = _mm512_permutex2var_epi32(_mm512_castsi256_si512(idx_base_4), idx_idx, _mm512_castsi256_si512(idx_base_4));
- __m512i load_idx_stage2_0 = _mm512_set_epi32( 0, 0, 0, 0, 21, 20, 19, 18, 17, 16, 11, 10, 9, 8, 7, 6);
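- // Same gather scheme as the n=9 kernel above, but a row is 5 dwords here, so every index and
- // permute works directly at 32-bit granularity (6 rows per pair of loads instead of 7)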
-
- xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0|x1|x0|x1| ... |x0|x1|
- xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2|x3|x2|x3| ... |x2|x3|
- xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4|x5|x4|x5| ... |x4|x5|
- xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6|x7|x6|x7| ... |x6|x7|
- xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8|x9|x8|x9| ... |x8|x9|
-
- unsigned short blend_mask_value = ((unsigned short)0x0fc0);
- __mmask16 blend_mask = *((__mmask16*) &blend_mask_value);
- unsigned short load_mask_value = (((unsigned short)0xffff) >> 1);
- __mmask16 load_mask = *((__mmask16*) &load_mask_value);
- unsigned short store_mask_value = (((unsigned short)0xffff) >> 4);
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
- for (BLASLONG idx_m = 0; idx_m < tag_m_12x; idx_m+=12) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m)*10]); // Load 3 rows with n=10
- matrixArray_1 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+3)*10]); // Load 3 rows with n=10
- matrixArray_2 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+6)*10]); // Load 3 rows with n=10
- matrixArray_3 = _mm512_maskz_loadu_epi32(load_mask, &a[(idx_m+9)*10]); // Load 3 rows with n=10
-
- // Stage 1: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_0, matrixArray_1); // |a0|a1|...|f0|f1|a2|a3|...|f2|f3|x|x|x|x|x|x|x|x|
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_1, matrixArray_1); // |a4|a5|...|f4|f5|a6|a7|...|f6|f7|x|x|x|x|x|x|x|x|
- matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_2, matrixArray_3); // |g2|g3|...|l2|l3|g0|g1|...|l0|l1|x|x|x|x|x|x|x|x|
- matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_3, matrixArray_3); // |g6|g7|...|l6|l7|g4|g5|...|l4|l5|x|x|x|x|x|x|x|x|
- matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_0, load_idx_stage1_4, matrixArray_1); // |a8|a9|...|f8|f9| x| x|...| x| x|x|x|x|x|x|x|x|x|
- matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_2, load_idx_stage1_4, matrixArray_3); // | x| x|...| x| x|g8|g9|...|l8|l9|x|x|x|x|x|x|x|x|
-
- // Stage 2: interleave per 32 bits
- matrixArray_0 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|l0|l1|x|x|x|x|x|x|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, load_idx_stage2_0, matrixArray_stage_2); // |a2|a3|...|l2|l3|x|x|x|x|x|x|x|x|
- matrixArray_2 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_1, matrixArray_stage_3); // |a4|a5|...|l4|l5|x|x|x|x|x|x|x|x|
- matrixArray_3 = _mm512_permutex2var_epi32(matrixArray_stage_1, load_idx_stage2_0, matrixArray_stage_3); // |a6|a7|...|l6|l7|x|x|x|x|x|x|x|x|
- matrixArray_4 = _mm512_mask_blend_epi32(blend_mask, matrixArray_stage_4, matrixArray_stage_5); // |a8|a9|...|l8|l9|x|x|x|x|x|x|x|x|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
- result_0 = _mm512_add_ps(result_0, result_1);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
- }
- }
-
- if (tag_m_12x != m) {
- __m256i matrixArray256;
- __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
- __m256 result256;
- __m128 result128, tmp128;
- unsigned char load256_mask_value = (((unsigned char)0xff) >> 3);
- __mmask8 load256_mask = *((__mmask8*) &load256_mask_value);
- for (BLASLONG i = tag_m_12x; i < m; i++) {
- result256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi32(load256_mask, &a[(i)*10]);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
- result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 15 rows parallel processing BF16 GEMV kernel for n=11 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_15x11_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_15x11_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_15x11_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_15x = m - (m%15);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128_0 = _mm_loadu_si128((__m128i *)x);             // |x0|x1| x2|x3|x4|x5|x6|x7|
- __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0|
-
- if (tag_m_15x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
- __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
- __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
-
- __m512i M512_EPI16_2, M512_EPI16_4, M512_EPI16_6, M512_EPI32_5;
- M512_EPI16_2 = _mm512_set1_epi16(2);
- M512_EPI16_4 = _mm512_add_epi16(M512_EPI16_2, M512_EPI16_2);
- M512_EPI16_6 = _mm512_add_epi16(M512_EPI16_4, M512_EPI16_2);
- M512_EPI32_5 = _mm512_set1_epi32(5);
-
- unsigned int BASE_MASK_10_value = ((unsigned int)0x000003ff);
- __mmask32 BASE_MASK_10 = *((__mmask32*) &BASE_MASK_10_value);
- unsigned int BASE_MASK_20_value = ((unsigned int)0x000ffc00);
- __mmask32 BASE_MASK_20 = *((__mmask32*) &BASE_MASK_20_value);
- unsigned int BASE_MASK_30_value = ((unsigned int)0x3ff00000);
- __mmask32 BASE_MASK_30 = *((__mmask32*) &BASE_MASK_30_value);
-
- idx_stage1_base_0 = _mm512_set_epi16( 0, 0, 49, 48, 38, 37, 27, 26, 16, 15, 5, 4, 47, 46, 36, 35,
- 25, 24, 14, 13, 3, 2, 45, 44, 34, 33, 23, 22, 12, 11, 1, 0);
- idx_stage1_base_1 = _mm512_add_epi16(idx_stage1_base_0, M512_EPI16_6);
-
- idx_stage1_base_2 = _mm512_mask_add_epi16(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI16_2);
- idx_stage1_base_2 = _mm512_mask_sub_epi16(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI16_2);
- idx_stage1_base_3 = _mm512_add_epi16(idx_stage1_base_2, M512_EPI16_6);
-
- idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI16_2);
- idx_stage1_base_4 = _mm512_mask_add_epi16(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI16_2);
- idx_stage1_base_4 = _mm512_mask_sub_epi16(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI16_4);
- idx_stage1_base_5 = _mm512_add_epi16(idx_stage1_base_4, M512_EPI16_6);
-
- unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
- __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
- unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
- __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
- idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
- idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
- idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
- idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
- idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
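- // The masked add/sub chains derive rotated copies of the base gather pattern: each of the three row
- // groups (a..e, f..j, k..o) lands with its element pairs in a different 32-bit slot, which the
- // stage-2 blends/permutes inside the loop fold back into row order (see the lane diagrams below)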
-
- xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
- xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
- xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
- xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
- xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
- xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|0 |x10|0 | ... |x10|0 |
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 9);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
-
- unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[idx_m*11]); // Load 2 rows with n=11 plus 10 elements
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*11 + 32]); // Load 2 rows with n=11 plus 1 element
- matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*11]); // Load 2 rows with n=11 plus 10 elements
- matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*11 + 32]); // Load 2 rows with n=11 plus 1 element
- matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*11]); // Load 2 rows with n=11 plus 10 elements
- matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*11 + 32]); // Load 2 rows with n=11 plus 1 element
-
- // Stage 1: interleave per 16 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0|a1|...|e0|e1|a2|a3|...|e2|e3|a4 |a5|...|e4 |e5|
- matrixArray_stage_1 = _mm512_permutex2var_epi16(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6|a7|...|e6|e7|a8|a9|...|e8|e9|a10|x |...|e10|x |
- matrixArray_stage_2 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2|f3|...|j2|j3|f0|f1|...|j0|j1|f4 |f5|...|j4 |j5|
- matrixArray_stage_3 = _mm512_permutex2var_epi16(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8|f9|...|j8|j9|f6|f7|...|j6|j7|f10|x |...|j10|x |
- matrixArray_stage_4 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4|k5|...|o4|o5|k2|k3|...|o2|o3|k0 |k1|...|o0 |o1|
- matrixArray_stage_5 = _mm512_permutex2var_epi16(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|x|...|o10|x|k8|k9|...|o8|o9|k6 |k7|...|o6 |o7|
-
- // Stage 2: interleave per 32 bits
- matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0|a1|...|j0|j1|x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6|a7|...|j6|j7|x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2|a3|...|j2|j3|x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4|a5|...|j4|j5|x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8|a9|...|j8|j9|x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|x|...|j10|x|x|x|x|x|x|x|x|x|x|x|x|x|
-
- matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
- matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
- matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
- matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
- matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5); // |a10|x|.......................|o10|x|x|x|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
- result_0 = _mm512_add_ps(result_0, result_1);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
- }
- }
-
- if (tag_m_15x != m) {
- __m256i matrixArray256;
- __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
- __m256 result256;
- __m128 result128, tmp128;
- unsigned short load256_mask_value = (((unsigned short)0xffff) >> 5);
- __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
- for (BLASLONG i = tag_m_15x; i < m; i++) {
- result256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*11]);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
- result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 15 rows parallel processing BF16 GEMV kernel for n=12 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_15x12_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_15x12_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_15x12_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_15x = m - (m%15);
-
- unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4);
- __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value);
- __m128i x128_0 = _mm_loadu_si128((__m128i *)x);             // |x0|x1| x2| x3|x4|x5|x6|x7|
- __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0|
-
- if (tag_m_15x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5;
- __m512i matrixArray_stage_0, matrixArray_stage_1, matrixArray_stage_2, matrixArray_stage_3, matrixArray_stage_4, matrixArray_stage_5;
- __m512i xArray_01, xArray_23, xArray_45, xArray_67, xArray_89, xArray_10;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5;
- __m512i idx_stage2_base_0, idx_stage2_base_1, idx_stage2_base_2, idx_stage2_base_3;
-
- __m512i M512_EPI32_1, M512_EPI32_2, M512_EPI32_3, M512_EPI32_5;
- M512_EPI32_1 = _mm512_set1_epi32(1);
- M512_EPI32_2 = _mm512_add_epi32(M512_EPI32_1, M512_EPI32_1);
- M512_EPI32_3 = _mm512_add_epi32(M512_EPI32_2, M512_EPI32_1);
- M512_EPI32_5 = _mm512_add_epi32(M512_EPI32_3, M512_EPI32_2);
-
- unsigned short BASE_MASK_10_value = ((unsigned short)0x001f);
- __mmask16 BASE_MASK_10 = *((__mmask16*) &BASE_MASK_10_value);
- unsigned short BASE_MASK_20_value = ((unsigned short)0x03e0);
- __mmask16 BASE_MASK_20 = *((__mmask16*) &BASE_MASK_20_value);
- unsigned short BASE_MASK_30_value = ((unsigned short)0xfc00);
- __mmask16 BASE_MASK_30 = *((__mmask16*) &BASE_MASK_30_value);
-
- idx_stage1_base_0 = _mm512_set_epi32( 0, 26, 20, 14, 8, 2, 25, 19, 13, 7, 1, 24, 18, 12, 6, 0);
- idx_stage1_base_1 = _mm512_add_epi32(idx_stage1_base_0, M512_EPI32_3);
-
- idx_stage1_base_2 = _mm512_mask_add_epi32(idx_stage1_base_0, BASE_MASK_10, idx_stage1_base_0, M512_EPI32_1);
- idx_stage1_base_2 = _mm512_mask_sub_epi32(idx_stage1_base_2, BASE_MASK_20, idx_stage1_base_0, M512_EPI32_1);
- idx_stage1_base_3 = _mm512_add_epi32(idx_stage1_base_2, M512_EPI32_3);
-
- idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_2, BASE_MASK_10, idx_stage1_base_2, M512_EPI32_1);
- idx_stage1_base_4 = _mm512_mask_add_epi32(idx_stage1_base_4, BASE_MASK_20, idx_stage1_base_2, M512_EPI32_1);
- idx_stage1_base_4 = _mm512_mask_sub_epi32(idx_stage1_base_4, BASE_MASK_30, idx_stage1_base_2, M512_EPI32_2);
- idx_stage1_base_5 = _mm512_add_epi32(idx_stage1_base_4, M512_EPI32_3);
-
- unsigned short idx_stage2_mask_1_value = ((unsigned short)0x03e0);
- __mmask16 idx_stage2_mask_1 = *((__mmask16*) &idx_stage2_mask_1_value);
- unsigned short idx_stage2_mask_2_value = ((unsigned short)0x7c00);
- __mmask16 idx_stage2_mask_2 = *((__mmask16*) &idx_stage2_mask_2_value);
- idx_stage2_base_0 = _mm512_set_epi32( 0, 0, 0, 0, 0, 0, 20, 19, 18, 17, 16, 9, 8, 7, 6, 5);
- idx_stage2_base_1 = _mm512_set_epi32( 0, 25, 24, 23, 22, 21, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
- idx_stage2_base_2 = _mm512_add_epi32(idx_stage2_base_0, M512_EPI32_5);
- idx_stage2_base_2 = _mm512_mask_add_epi32(idx_stage2_base_2, idx_stage2_mask_1, idx_stage2_base_2, M512_EPI32_5);
- idx_stage2_base_3 = _mm512_mask_sub_epi32(idx_stage2_base_1, idx_stage2_mask_2, idx_stage2_base_1, M512_EPI32_5);
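- // Index construction mirrors the n=11 kernel above, one granularity up: whole 32-bit element pairs
- // here play the role the 16-bit elements played there, so the stage-1 permutes are epi32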
-
- xArray_01 = _mm512_broadcastd_epi32(x128_0); // |x0 |x1 |x0 |x1 | ... |x0 |x1 |
- xArray_23 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x1)); // |x2 |x3 |x2 |x3 | ... |x2 |x3 |
- xArray_45 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x2)); // |x4 |x5 |x4 |x5 | ... |x4 |x5 |
- xArray_67 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_0, 0x3)); // |x6 |x7 |x6 |x7 | ... |x6 |x7 |
- xArray_89 = _mm512_broadcastd_epi32(x128_1); // |x8 |x9 |x8 |x9 | ... |x8 |x9 |
- xArray_10 = _mm512_broadcastd_epi32(_mm_shuffle_epi32(x128_1, 0x1)); // |x10|x11|x10|x11| ... |x10|x11|
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
-
- unsigned short store_mask_value = (((unsigned short)0xffff) >> 1);
- __mmask16 store_mask = *((__mmask16*) &store_mask_value);
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_15x; idx_m+=15) {
- result_0 = _mm512_setzero_ps();
- result_1 = _mm512_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[idx_m*12]); // Load 2 rows with n=12 plus 8 elements
- matrixArray_1 = _mm512_maskz_loadu_epi16(load_mask, &a[idx_m*12 + 32]);      // Load 2 rows with n=12 plus 4 elements
- matrixArray_2 = _mm512_loadu_si512(&a[(idx_m+5)*12]); // Load 2 rows with n=12 plus 8 elements
- matrixArray_3 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+5)*12 + 32]);  // Load 2 rows with n=12 plus 4 elements
- matrixArray_4 = _mm512_loadu_si512(&a[(idx_m+10)*12]); // Load 2 rows with n=12 plus 8 elements
- matrixArray_5 = _mm512_maskz_loadu_epi16(load_mask, &a[(idx_m+10)*12 + 32]); // Load 2 rows with n=12 plus 4 elements
-
- // Stage 1: interleave per 32 bits
- matrixArray_stage_0 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_0, matrixArray_1); // |a0 |a1 |...|e0 |e1 |a2|a3|...|e2|e3|a4 |a5 |...|e4 |e5 |
- matrixArray_stage_1 = _mm512_permutex2var_epi32(matrixArray_0, idx_stage1_base_1, matrixArray_1); // |a6 |a7 |...|e6 |e7 |a8|a9|...|e8|e9|a10|a11|...|e10|e11|
- matrixArray_stage_2 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_2, matrixArray_3); // |f2 |f3 |...|j2 |j3 |f0|f1|...|j0|j1|f4 |f5 |...|j4 |j5 |
- matrixArray_stage_3 = _mm512_permutex2var_epi32(matrixArray_2, idx_stage1_base_3, matrixArray_3); // |f8 |f9 |...|j8 |j9 |f6|f7|...|j6|j7|f10|f11|...|j10|j11|
- matrixArray_stage_4 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_4, matrixArray_5); // |k4 |k5 |...|o4 |o5 |k2|k3|...|o2|o3|k0 |k1 |...|o0 |o1 |
- matrixArray_stage_5 = _mm512_permutex2var_epi32(matrixArray_4, idx_stage1_base_5, matrixArray_5); // |k10|k11|...|o10|o11|k8|k9|...|o8|o9|k6 |k7 |...|o6 |o7 |
-
- // Stage 2: interleave per 32 bits
- matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_0, matrixArray_stage_2); // |a0 |a1 |...|j0 |j1 |x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_1, matrixArray_stage_1, matrixArray_stage_3); // |a6 |a7 |...|j6 |j7 |x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_0, matrixArray_stage_2); // |a2 |a3 |...|j2 |j3 |x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_stage_0, idx_stage2_base_2, matrixArray_stage_2); // |a4 |a5 |...|j4 |j5 |x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_0, matrixArray_stage_3); // |a8 |a9 |...|j8 |j9 |x|x|x|x|x|x|x|x|x|x|x|x|
- matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_stage_1, idx_stage2_base_2, matrixArray_stage_3); // |a10|a11|...|j10|j11|x|x|x|x|x|x|x|x|x|x|x|x|
-
- matrixArray_0 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_0, matrixArray_stage_4); // |a0|a1|.......................|o0|o1|x|x|
- matrixArray_3 = _mm512_mask_blend_epi32(idx_stage2_mask_2, matrixArray_3, matrixArray_stage_5); // |a6|a7|.......................|o6|o7|x|x|
- matrixArray_1 = _mm512_permutex2var_epi32(matrixArray_1 , idx_stage2_base_1, matrixArray_stage_4); // |a2|a3|.......................|o2|o3|x|x|
- matrixArray_2 = _mm512_permutex2var_epi32(matrixArray_2 , idx_stage2_base_3, matrixArray_stage_4); // |a4|a5|.......................|o4|o5|x|x|
- matrixArray_4 = _mm512_permutex2var_epi32(matrixArray_4 , idx_stage2_base_1, matrixArray_stage_5); // |a8|a9|.......................|o8|o9|x|x|
- matrixArray_5 = _mm512_permutex2var_epi32(matrixArray_5 , idx_stage2_base_3, matrixArray_stage_5);  // |a10|a11|...................|o10|o11|x|x|
-
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_0, (__m512bh) xArray_01);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_1, (__m512bh) xArray_23);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_2, (__m512bh) xArray_45);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_3, (__m512bh) xArray_67);
- result_0 = _mm512_dpbf16_ps(result_0, (__m512bh) matrixArray_4, (__m512bh) xArray_89);
- result_1 = _mm512_dpbf16_ps(result_1, (__m512bh) matrixArray_5, (__m512bh) xArray_10);
- result_0 = _mm512_add_ps(result_0, result_1);
-
- STORE16_MASK_COMPLETE_RESULT(result_0, y+idx_m, store_mask)
- }
- }
-
- if (tag_m_15x != m) {
- __m256i matrixArray256;
- __m256i x256 = _mm256_insertf128_si256(_mm256_castsi128_si256(x128_0), x128_1, 0x1);
- __m256 result256;
- __m128 result128, tmp128;
- unsigned short load256_mask_value = (((unsigned short)0xffff) >> 4);
- __mmask16 load256_mask = *((__mmask16*) &load256_mask_value);
- for (BLASLONG i = tag_m_15x; i < m; i++) {
- result256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(load256_mask, &a[(i)*12]);
- result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) x256);
- result128 = _mm_add_ps(_mm256_castps256_ps128(result256), _mm256_extractf128_ps(result256, 0x1));
- tmp128 = _mm_shuffle_ps(result128, result128, 14);
- result128 = _mm_add_ps(result128, tmp128);
- tmp128 = _mm_shuffle_ps(result128, result128, 1);
- result128 = _mm_add_ps(result128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * result128[0] + beta * y[i];
- #else
- y[i] = alpha * result128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = result128[0] * alpha;
- #else
- y[i] = result128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
-
- // 16 rows parallel processing BF16 GEMV kernel for n=13 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x13_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x13_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x13_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 3);
- __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
- __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|0|0|0|
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
- matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
- __m512 accum512_0, accum512_1;
- __m512 result_0, result_1;
-
- __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
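- // idx_base_0 selects dwords 0-3 and 8-11 of both accumulators, idx_base_1 the remaining dwords,
- // so the permute+add after the dot products folds the two halves of each row's partial sum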
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
-
- // Prepare X with 2-step interleave way
- xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
- BF16_INTERLEAVE_1x32(xArray)
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m, 0, x_load_mask)
-
- matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, idx_m+8, 0, x_load_mask)
-
- matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_8x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_8x32(matrixArray)
-
- // Calculate the temp result for a..p[0:15]
- BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
-
- // Reorder and add up the final result
- result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
- result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
- result_0 = _mm512_add_ps(result_0, result_1);
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
-
- matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- // interleave per 256 bits
- matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
- matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
- matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
- matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x32(matrixArray)
-
- // Calculate the temp result for a..h[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
-
- accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
- accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- if (m - tag_m_16x > 3) {
- __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
- __m256 accum256_0, accum256_1;
-
- xArray256_0 = _mm512_castsi512_si256(xArray_0);
- xArray256_1 = _mm512_castsi512_si256(xArray_1);
- xArray256_2 = _mm512_castsi512_si256(xArray_2);
- xArray256_3 = _mm512_castsi512_si256(xArray_3);
-
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
-
- BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 13, tag_m_16x, 0, x_load_mask)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x16(matrixArray256)
-
- // Calculate the temp result for a..d[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
-
- accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
- __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
- tag_m_16x += 4;
- }
- }
-
- if (tag_m_16x != m) {
- __m256i matrixArray256;
- __m256 accum256;
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- accum256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*13]); // Load 1 row with n=13
- accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
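- // Horizontal reduction: fold the 256-bit accumulator to 128 bits, then two shuffle+add steps collapse the remaining partial sums so the row's dot product ends up in lane 0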
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=14 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x14_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x14_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x14_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
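- // Each build variant computes y[i] = alpha * dot(A[i,0:14], x[0:14]) (+ beta * y[i] for *_alpha_beta); the ZERO_BETA/ONE_BETA/ONE_ALPHA macros select which scalings are compiled in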
- BLASLONG tag_m_16x = m & (~15);
-
- unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 2);
- __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
- __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|0|0|
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
- matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
- __m512 accum512_0, accum512_1;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
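- // idx_base_0/idx_base_1 gather complementary 4-float groups from the two accumulators; the permute-and-add at the end of each iteration folds the interleaved partial sums back into row order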
- __m512i shift_idx = _mm512_set_epi32(0, 13, 12, 11, 10, 9, 8, 7, 0, 6, 5, 4, 3, 2, 1, 0);
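- // shift_idx moves the second of the two 14-element rows packed in each register up by one dword so it starts at the 256-bit boundary; the leftover lanes line up with the zero padding of x and drop out of the dot products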
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 4);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
-
- // Prepare X using the 2-step interleave scheme
- xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
- BF16_INTERLEAVE_1x32(xArray)
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_8x32_2(matrixArray, a, 14, idx_m, 0, load_mask)
-
- // Pre-stage: shift the 2nd vector 1 position right for each register
- BF16_PERMUTE_8x32_2(shift_idx, matrixArray)
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_8x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_8x32(matrixArray)
-
- // Calculate the temp result for a..p[0:15]
- BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
-
- // Reorder and add up the final result
- result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
- result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
- result_0 = _mm512_add_ps(result_0, result_1);
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_4x32_2(matrixArray, a, 14, tag_m_16x, 0, load_mask)
-
- // Pre-stage: shift the 2nd vector 1 position right for each register
- BF16_PERMUTE_4x32_2(shift_idx, matrixArray)
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_4x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x32(matrixArray)
-
- // Calculate the temp result for a..h[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
-
- accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
- accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- if (m - tag_m_16x > 3) {
- __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
- __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
- __m256 accum256_0, accum256_1;
-
- xArray256_0 = _mm512_castsi512_si256(xArray_0);
- xArray256_1 = _mm512_castsi512_si256(xArray_1);
- xArray256_2 = _mm512_castsi512_si256(xArray_2);
- xArray256_3 = _mm512_castsi512_si256(xArray_3);
-
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
-
- BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 14, tag_m_16x, 0, x_load_mask)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x16(matrixArray256)
-
- // Calculate the temp result for a..d[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
-
- accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
- __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
- tag_m_16x += 4;
- }
- }
-
- if (tag_m_16x != m) {
- __m256i matrixArray256;
- __m256 accum256;
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- accum256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*14]); // Load 1 row with n=14
- accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=15 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x15_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x15_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x15_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- unsigned short x_load_mask_value = (((unsigned short)0xffff) >> 1);
- __mmask16 x_load_mask = *((__mmask16*) &x_load_mask_value);
- __m256i x256 = _mm256_maskz_loadu_epi16(x_load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|0|
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
- matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
- __m512 accum512_0, accum512_1;
- __m512 result_0, result_1;
-
- __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2);
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
-
- // Prepare X using the 2-step interleave scheme
- xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
- BF16_INTERLEAVE_1x32(xArray)
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m, 0, x_load_mask)
-
- matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, idx_m+8, 0, x_load_mask)
-
- matrixArray_12 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_13 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_14 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_15 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_8x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_8x32(matrixArray)
-
- // Calculate the temp result for a..p[0:15]
- BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
-
- // Reorder and add up the final result
- result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
- result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
- result_0 = _mm512_add_ps(result_0, result_1);
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load matrix
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
-
- matrixArray_8 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_0), matrixArray256_1, 0x1);
- matrixArray_9 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_2), matrixArray256_3, 0x1);
- matrixArray_10 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_4), matrixArray256_5, 0x1);
- matrixArray_11 = _mm512_inserti32x8(_mm512_castsi256_si512(matrixArray256_6), matrixArray256_7, 0x1);
-
- // interleave per 256 bits
- matrixArray_0 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0x44);
- matrixArray_1 = _mm512_shuffle_i32x4(matrixArray_8, matrixArray_10, 0xee);
- matrixArray_2 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0x44);
- matrixArray_3 = _mm512_shuffle_i32x4(matrixArray_9, matrixArray_11, 0xee);
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x32(matrixArray)
-
- // Calculate the temp result for a..h[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
-
- accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
- accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- if (m - tag_m_16x > 3) {
- __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
- __m256 accum256_0, accum256_1;
-
- xArray256_0 = _mm512_castsi512_si256(xArray_0);
- xArray256_1 = _mm512_castsi512_si256(xArray_1);
- xArray256_2 = _mm512_castsi512_si256(xArray_2);
- xArray256_3 = _mm512_castsi512_si256(xArray_3);
-
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
-
- BF16_MATRIX_MASKZ_LOAD_4x16(matrixArray256, a, 15, tag_m_16x, 0, x_load_mask)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x16(matrixArray256)
-
- // Calculate the temp result for a..d[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
-
- accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
- __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
- tag_m_16x += 4;
- }
- }
-
- if (tag_m_16x != m) {
- __m256i matrixArray256;
- __m256 accum256;
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- accum256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_maskz_loadu_epi16(x_load_mask, &a[(i)*15]); // Load 1 row with n=15
- accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 16 rows parallel processing BF16 GEMV kernel for n=16 && lda ineffective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_16x16_alpha_beta(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_16x16_alpha_one(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_16x16_alpha(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_16x = m & (~15);
-
- __m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|
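- // n == 16 exactly fills a 256-bit register with bf16 values, so no load mask is needed here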
-
- if (tag_m_16x > 0) {
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
- matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
- __m512 accum512_0, accum512_1;
- __m512 result_0, result_1;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
-
- // Prepare X using the 2-step interleave scheme
- xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1);
- BF16_INTERLEAVE_1x32(xArray)
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_16x; idx_m+=16) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- matrixArray_8 = _mm512_loadu_si512(&a[(idx_m )*16]); // Load 2 rows with n=16
- matrixArray_9 = _mm512_loadu_si512(&a[(idx_m+2 )*16]); // Load 2 rows with n=16
- matrixArray_10 = _mm512_loadu_si512(&a[(idx_m+4 )*16]); // Load 2 rows with n=16
- matrixArray_11 = _mm512_loadu_si512(&a[(idx_m+6 )*16]); // Load 2 rows with n=16
- matrixArray_12 = _mm512_loadu_si512(&a[(idx_m+8 )*16]); // Load 2 rows with n=16
- matrixArray_13 = _mm512_loadu_si512(&a[(idx_m+10)*16]); // Load 2 rows with n=16
- matrixArray_14 = _mm512_loadu_si512(&a[(idx_m+12)*16]); // Load 2 rows with n=16
- matrixArray_15 = _mm512_loadu_si512(&a[(idx_m+14)*16]); // Load 2 rows with n=16
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_8x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_8x32(matrixArray)
-
- // Calculate the temp result for a..p[0:15]
- BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
-
- // Reorder and add up the final result
- result_0 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
- result_1 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
- result_0 = _mm512_add_ps(result_0, result_1);
- STORE16_COMPLETE_RESULT(result_0, y+idx_m)
- }
-
- if (m - tag_m_16x > 7) {
- __m512i permutevar_idx = _mm512_set_epi32(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- matrixArray_4 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
- matrixArray_5 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
- matrixArray_6 = _mm512_loadu_si512(&a[(tag_m_16x+4 )*16]); // Load 2 rows with n=16
- matrixArray_7 = _mm512_loadu_si512(&a[(tag_m_16x+6 )*16]); // Load 2 rows with n=16
-
- // interleave per 256 bits
- BF16_INTERLEAVE256_4x32(matrixArray)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x32(matrixArray)
-
- // Calculate the temp result for a..h[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
-
- accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
- accum512_0 = _mm512_permutexvar_ps(permutevar_idx, accum512_0);
- __m256 result256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- STORE8_COMPLETE_RESULT(result256, y+tag_m_16x)
- tag_m_16x += 8;
- }
-
- if (m - tag_m_16x > 3) {
- __m256i matrixArray256_0, matrixArray256_1, matrixArray256_2, matrixArray256_3, \
- matrixArray256_4, matrixArray256_5, matrixArray256_6, matrixArray256_7;
- __m256i xArray256_0, xArray256_1, xArray256_2, xArray256_3;
- __m256 accum256_0, accum256_1;
-
- xArray256_0 = _mm512_castsi512_si256(xArray_0);
- xArray256_1 = _mm512_castsi512_si256(xArray_1);
- xArray256_2 = _mm512_castsi512_si256(xArray_2);
- xArray256_3 = _mm512_castsi512_si256(xArray_3);
-
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
-
- matrixArray_0 = _mm512_loadu_si512(&a[(tag_m_16x )*16]); // Load 2 rows with n=16
- matrixArray_1 = _mm512_loadu_si512(&a[(tag_m_16x+2 )*16]); // Load 2 rows with n=16
-
- matrixArray256_0 = _mm512_castsi512_si256(matrixArray_0);
- matrixArray256_1 = _mm512_extracti32x8_epi32(matrixArray_0, 0x1);
- matrixArray256_2 = _mm512_castsi512_si256(matrixArray_1);
- matrixArray256_3 = _mm512_extracti32x8_epi32(matrixArray_1, 0x1);
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x16(matrixArray256)
-
- // Calculate the temp result for a..d[0:15]
- BF16_2STEP_INTERLEAVED_DOT_4x16(accum256, matrixArray256, xArray256)
-
- accum256_0 = _mm256_add_ps(accum256_0, accum256_1);
- __m128 result128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- STORE4_COMPLETE_RESULT(result128, y+tag_m_16x)
- tag_m_16x += 4;
- }
- }
-
- if (tag_m_16x != m) {
- __m256i matrixArray256;
- __m256 accum256;
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_16x; i < m; i++) {
- accum256 = _mm256_setzero_ps();
- matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 row with n=16
- accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256);
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 8 rows parallel processing BF16 GEMV kernel for n>16 && lda effective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_8x16p_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_8x16p_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_8x16p_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_8x = m & (~7);
-
- unsigned int load_mask_value = (((unsigned int)0xffffffff) >> (32-n));
- __mmask32 load_mask = *((__mmask32*) &load_mask_value);
- __m512i x512 = _mm512_maskz_loadu_epi16(load_mask, x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15|...
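- // load_mask selects the first n bf16 elements of x and of each matrix row (this kernel presumably serves 16 < n <= 32); the zeroed upper lanes contribute nothing to the dot products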
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \
- matrixArray_8, matrixArray_9, matrixArray_10, matrixArray_11, matrixArray_12, matrixArray_13, matrixArray_14, matrixArray_15;
- __m512 accum512_0, accum512_1, accum512_2, accum512_3;
- __m256 accum256;
- __m128 accum128;
-
- if (tag_m_8x > 0) {
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
-
- // Prepare X using the 2-step interleave scheme
- xArray_0 = x512;
- BF16_INTERLEAVE_1x32(xArray)
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load 8 rows from matrix
- BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, 0, load_mask)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_8x32(matrixArray)
-
- // Calculate the temp result for a..h[0:31]
- BF16_2STEP_INTERLEAVED_DOT_8x32(accum512, matrixArray, xArray)
-
- // Reorder and add up the final result
- accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_1);
- accum512_3 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_1);
- accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
- accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_2), _mm512_extractf32x8_ps(accum512_2, 1));
- STORE8_COMPLETE_RESULT(accum256, y+idx_m)
- }
-
- if (m - tag_m_8x > 3) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
-
- // Load 4 rows from matrix
- BF16_MATRIX_MASKZ_LOAD_4x32(matrixArray, a, lda, tag_m_8x, 0, load_mask)
-
- // 2-step interleave for matrix
- BF16_INTERLEAVE_4x32(matrixArray)
-
- // Calculate the temp result for a..d[0:31]
- BF16_2STEP_INTERLEAVED_DOT_4x32(accum512, matrixArray, xArray)
-
- accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
- accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- STORE4_COMPLETE_RESULT(accum128, y+tag_m_8x)
- tag_m_8x += 4;
- }
- }
-
- if (tag_m_8x != m) {
- __m128 tmp128;
- for (BLASLONG i = tag_m_8x; i < m; i++) {
- accum512_0 = _mm512_setzero_ps();
- matrixArray_0 = _mm512_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 row of n elements
- accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) matrixArray_0, (__m512bh) x512);
- accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 8 rows parallel processing BF16 GEMV kernel for big N && lda effective scenario (process before interleave)
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_1x128_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_1x128_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_1x128_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_8x = m & (~7);
- BLASLONG tag_n_32x = n & (~31);
- BLASLONG tag_n_128x = n & (~127);
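- // tag_n_32x/tag_n_128x round n down to multiples of 32 and 128, splitting the column loop into a 128-wide main loop, a 32-wide cleanup loop, and a final masked pass over the last n & 31 elements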
-
- __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
- accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
- __m512 accum512_bridge[8];
- __m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3;
- __m256 accum256_0;
- __m128 accum128;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3;
- __m512i xArray_0, xArray_1, xArray_2, xArray_3;
-
- unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
- __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
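- // tail_mask covers the last n & 31 columns; it is only consumed when (n & 31) != 0 (the masked loads below are guarded accordingly)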
-
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
-
- if (tag_m_8x > 0) {
- for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
- for (int j = idx_m; j < idx_m + 8; j++) {
- accum512_t_0 = _mm512_setzero_ps();
- accum512_t_1 = _mm512_setzero_ps();
- accum512_t_2 = _mm512_setzero_ps();
- accum512_t_3 = _mm512_setzero_ps();
- /* Process the main chunk, 128 elements per round */
- for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
- BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
- BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
- BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
- BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
-
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
- BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
- BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
- BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
-
- BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
- BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
- BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
- BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
- }
-
- /* Process the remaining <128 elements, 32 per round */
- for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
- BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
- BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
- }
-
- /* Process the remaining <32 elements with one masked 32-element pass */
- if ((n&31) != 0) {
- BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
- BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
- BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
- }
-
- /* Accumulate the 4 registers into 1 register */
- accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
- accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
- accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
-
- // Temporarily save the result into a ZMM register
- accum512_bridge[j-idx_m] = accum512_t_0;
- }
-
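- // Each bridge register holds the 16 partial lane sums of one row; the interleave/accumulate steps below reduce them so a single permute-and-add pass yields all 8 row results at once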
- FP32_INTERLEAVE_8x16_ARRAY(accum512_bridge)
- FP32_ACCUM2_8x16_ARRAY(accum512_bridge)
- accum512_bridge[1] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_0, accum512_bridge[4]);
- accum512_bridge[2] = _mm512_permutex2var_ps(accum512_bridge[0], idx_base_1, accum512_bridge[4]);
- accum512_bridge[1] = _mm512_add_ps(accum512_bridge[1], accum512_bridge[2]);
- accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_bridge[1]), _mm512_extractf32x8_ps(accum512_bridge[1], 1));
- STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
- }
- }
-
- if (tag_m_8x != m) {
- __m128 tmp128;
- for (BLASLONG j = tag_m_8x; j < m; j++) {
- accum512_t_0 = _mm512_setzero_ps();
- accum512_t_1 = _mm512_setzero_ps();
- accum512_t_2 = _mm512_setzero_ps();
- accum512_t_3 = _mm512_setzero_ps();
- /* Process the main chunk, 128 elements per round */
- for (long idx_n = 0; idx_n < tag_n_128x; idx_n += 128) {
- BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n + 0)
- BF16_MATRIX_LOAD_1x32(matrixArray_1, a, lda, j, idx_n + 32)
- BF16_MATRIX_LOAD_1x32(matrixArray_2, a, lda, j, idx_n + 64)
- BF16_MATRIX_LOAD_1x32(matrixArray_3, a, lda, j, idx_n + 96)
-
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n + 0)
- BF16_VECTOR_LOAD_1x32(xArray_1, x, idx_n + 32)
- BF16_VECTOR_LOAD_1x32(xArray_2, x, idx_n + 64)
- BF16_VECTOR_LOAD_1x32(xArray_3, x, idx_n + 96)
-
- BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
- BF16_DOT_1x32(accum512_t_1, matrixArray_1, xArray_1)
- BF16_DOT_1x32(accum512_t_2, matrixArray_2, xArray_2)
- BF16_DOT_1x32(accum512_t_3, matrixArray_3, xArray_3)
- }
-
- /* Process the remaining <128 elements, 32 per round */
- for (long idx_n = tag_n_128x; idx_n < tag_n_32x; idx_n += 32) {
- BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, j, idx_n)
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
- BF16_DOT_1x32(accum512_t_0, matrixArray_0, xArray_0)
- }
-
- /* Process the remaining <32 elements with one masked 32-element pass */
- if ((n&31) != 0) {
- BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, j, tag_n_32x, tail_mask)
- BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
- BF16_DOT_1x32(accum512_t_2, matrixArray_0, xArray_0)
- }
-
- /* Accumulate the 4 registers into 1 register */
- accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_1);
- accum512_t_2 = _mm512_add_ps(accum512_t_2, accum512_t_3);
- accum512_t_0 = _mm512_add_ps(accum512_t_0, accum512_t_2);
-
- accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_t_0), _mm512_extractf32x8_ps(accum512_t_0, 1));
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[j] = alpha * accum128[0] + beta * y[j];
- #else
- y[j] = alpha * accum128[0] + y[j];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[j] = accum128[0] * alpha;
- #else
- y[j] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 8 rows parallel processing BF16 GEMV kernel for n=32 && lda effective scenario (process before interleave)
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_8x32_lda_direct_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_8x32_lda_direct_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_8x32_lda_direct_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_8x = m & (~7);
- BLASLONG tag_n_32x = n & (~31);
-
- __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \
- accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15;
- __m256 accum256_0;
- __m128 accum128;
-
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_set1_ps(beta);
- #endif
-
- __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
- __m512i xArray_0;
-
- unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(n&31)));
- __mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
-
- if (tag_m_8x > 0) {
- __m512i M512_EPI32_4 = _mm512_set1_epi32(4);
- __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0);
- __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4);
-
- for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
- accum512_0 = _mm512_setzero_ps();
- accum512_1 = _mm512_setzero_ps();
- accum512_2 = _mm512_setzero_ps();
- accum512_3 = _mm512_setzero_ps();
- accum512_4 = _mm512_setzero_ps();
- accum512_5 = _mm512_setzero_ps();
- accum512_6 = _mm512_setzero_ps();
- accum512_7 = _mm512_setzero_ps();
-
- for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
- // Load 8 rows from matrix
- BF16_MATRIX_LOAD_8x32(matrixArray, a, lda, idx_m, idx_n)
-
- // Load x
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
-
- // Calculate the temp result for a..h[0:31]
- BF16_DOT_8x32(accum512, matrixArray, xArray_0)
- }
-
- if (tag_n_32x != n) { // Go with masked 512
- // Load 8 rows from matrix
- BF16_MATRIX_MASKZ_LOAD_8x32(matrixArray, a, lda, idx_m, tag_n_32x, tail_mask)
-
- // Load x
- BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
-
- // Calculate the temp result for a..h[0:31]
- BF16_DOT_8x32(accum512, matrixArray, xArray_0)
- }
-
- // 2-step interleave for FP32 register array
- FP32_INTERLEAVE_8x16(accum512)
-
- // Accumulate the 2 batches of registers into 2 registers (0 and 4)
- FP32_ACCUM2_8x16(accum512)
-
- accum512_1 = _mm512_permutex2var_ps(accum512_0, idx_base_0, accum512_4);
- accum512_2 = _mm512_permutex2var_ps(accum512_0, idx_base_1, accum512_4);
- accum512_1 = _mm512_add_ps(accum512_1, accum512_2);
- accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_1), _mm512_extractf32x8_ps(accum512_1, 1));
- STORE8_COMPLETE_RESULT(accum256_0, y+idx_m)
- }
- }
-
- if (tag_m_8x != m) {
- __m128 tmp128;
- for (BLASLONG i = tag_m_8x; i < m; i++) {
- accum512_0 = _mm512_setzero_ps();
- for (BLASLONG idx_n = 0; idx_n < tag_n_32x; idx_n+=32) {
- // Load 32 elements from matrix
- BF16_MATRIX_LOAD_1x32(matrixArray_0, a, lda, i, idx_n)
-
- // Load 32 elements from x
- BF16_VECTOR_LOAD_1x32(xArray_0, x, idx_n)
-
- // Calculate and accumulate the temp result
- BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
- }
-
- if (tag_n_32x != n) {
- // Load tail elements from matrix
- BF16_MATRIX_MASKZ_LOAD_1x32(matrixArray_0, a, lda, i, tag_n_32x, tail_mask)
-
- // Load tail elements from x
- BF16_VECTOR_MASKZ_LOAD_1x32(xArray_0, x, tag_n_32x, tail_mask)
-
- // Calculate and accumulate the temp result
- BF16_DOT_1x32(accum512_0, matrixArray_0, xArray_0)
- }
-
- accum256_0 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
-
- return 0;
- }
-
- // 8 rows parallel processing BF16 GEMV kernel for n<=16 && lda effective scenario
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- static int sbgemv_kernel_8x16m_lda_alpha_beta(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #else
- static int sbgemv_kernel_8x16m_lda_alpha_one(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float beta, float *y)
- #endif
- #else
- #ifndef ONE_ALPHA
- static int sbgemv_kernel_8x16m_lda_alpha(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #else
- static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, float *y)
- #endif
- #endif
- {
- BLASLONG tag_m_8x = m & (~7);
-
- __m256i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
- __m256i xArray256;
-
- // Keep aligned with the other kernels and macro definitions; the high 256 bits are never used
- #ifndef ONE_ALPHA
- __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha));
- #endif
- #ifndef ZERO_BETA
- __m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta));
- #endif
-
- __m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \
- accum256_8, accum256_9, accum256_10, accum256_11, accum256_12, accum256_13, accum256_14, accum256_15;
-
- __m256i M256_EPI32_4 = _mm256_set1_epi32(4);
- __m256i idx_base_0 = _mm256_set_epi32(11, 10, 9, 8, 3, 2, 1, 0);
- __m256i idx_base_1 = _mm256_add_epi32(idx_base_0, M256_EPI32_4);
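- // 256-bit counterpart of the idx_base reorder used in the 512-bit kernels: permute-and-add folds the interleaved partial sums into the 8 per-row totals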
-
- unsigned short load_mask_value = (((unsigned short)0xffff) >> (16-n));
- __mmask16 load_mask = *((__mmask16*) &load_mask_value);
-
- if (n == 16) {
- BF16_VECTOR_LOAD_1x16(xArray256, x, 0)
- } else {
- BF16_VECTOR_MASKZ_LOAD_1x16(xArray256, x, 0, load_mask)
- }
-
- if (n == 16) {
- for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
- accum256_2 = _mm256_setzero_ps();
- accum256_3 = _mm256_setzero_ps();
- accum256_4 = _mm256_setzero_ps();
- accum256_5 = _mm256_setzero_ps();
- accum256_6 = _mm256_setzero_ps();
- accum256_7 = _mm256_setzero_ps();
-
- BF16_MATRIX_LOAD_8x16(matrixArray, a, lda, idx_m, 0)
-
- BF16_DOT_8x16(accum256, matrixArray, xArray256)
-
- // 2-step interleave for FP32 register array
- FP32_INTERLEAVE_8x8(accum256)
-
- // Accumulate the 2 batches of registers into 2 registers (0 and 4)
- FP32_ACCUM2_8x8(accum256)
-
- accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
- accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
- accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
-
- STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
- }
-
- if (tag_m_8x != m) {
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_8x; i < m; i++) {
- accum256_0 = _mm256_setzero_ps();
- matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 row with n=16
- accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
- } else {
- for (BLASLONG idx_m = 0; idx_m < tag_m_8x; idx_m+=8) {
- accum256_0 = _mm256_setzero_ps();
- accum256_1 = _mm256_setzero_ps();
- accum256_2 = _mm256_setzero_ps();
- accum256_3 = _mm256_setzero_ps();
- accum256_4 = _mm256_setzero_ps();
- accum256_5 = _mm256_setzero_ps();
- accum256_6 = _mm256_setzero_ps();
- accum256_7 = _mm256_setzero_ps();
-
- BF16_MATRIX_MASKZ_LOAD_8x16(matrixArray, a, lda, idx_m, 0, load_mask)
-
- BF16_DOT_8x16(accum256, matrixArray, xArray256)
-
- // 2-step interleave for FP32 register array
- FP32_INTERLEAVE_8x8(accum256)
-
- // Accumulate the 2 batches of registers into 2 registers (0 and 4)
- FP32_ACCUM2_8x8(accum256)
-
- accum256_1 = _mm256_permutex2var_ps(accum256_0, idx_base_0, accum256_4);
- accum256_2 = _mm256_permutex2var_ps(accum256_0, idx_base_1, accum256_4);
- accum256_1 = _mm256_add_ps(accum256_1, accum256_2);
-
- STORE8_COMPLETE_RESULT(accum256_1, y+idx_m)
- }
-
- if (tag_m_8x != m) {
- __m128 accum128, tmp128;
- for (BLASLONG i = tag_m_8x; i < m; i++) {
- accum256_0 = _mm256_setzero_ps();
- matrixArray_0 = _mm256_maskz_loadu_epi16(load_mask, &a[(i)*lda]); // Load 1 row of n (<16) elements
- accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256);
- accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1));
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);
- accum128 = _mm_add_ps(accum128, tmp128);
- tmp128 = _mm_shuffle_ps(accum128, accum128, 0x01);
- accum128 = _mm_add_ps(accum128, tmp128);
- #ifndef ZERO_BETA
- #ifndef ONE_BETA
- y[i] = alpha * accum128[0] + beta * y[i];
- #else
- y[i] = alpha * accum128[0] + y[i];
- #endif
- #else
- #ifndef ONE_ALPHA
- y[i] = accum128[0] * alpha;
- #else
- y[i] = accum128[0];
- #endif
- #endif
- }
- }
- }
-
- return 0;
- }
|