|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767 |
-
- #include "common.h"
- #include <riscv_vector.h>
-
- int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
- {
- BLASLONG gvl = 0;
- BLASLONG m_top = 0;
- BLASLONG n_top = 0;
-
- // -- MAIN PASS
- for (BLASLONG j=0; j<N/8; j+=1) {
- m_top = 0;
- BLASLONG gvl = __riscv_vsetvl_e16m1(8);
-
- for (BLASLONG i=0; i<M/8; i+=1) {
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
-
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- _Float16 B2 = B[bi+2];
- _Float16 B3 = B[bi+3];
- _Float16 B4 = B[bi+4];
- _Float16 B5 = B[bi+5];
- _Float16 B6 = B[bi+6];
- _Float16 B7 = B[bi+7];
- bi += 8;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
- vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
- vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
- vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
- vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
- vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
- vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
-
- for(BLASLONG k=1; k<K; k++) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- B2 = B[bi+2];
- B3 = B[bi+3];
- B4 = B[bi+4];
- B5 = B[bi+5];
- B6 = B[bi+6];
- B7 = B[bi+7];
- bi += 8;
-
- A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
- result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
- result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
- result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
- result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
- result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c4 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c5 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c6 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
- vfloat32m2_t c7 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc-gvl*0;
-
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
- c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
- c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
- c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
- c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
- c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
- c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);
-
- ci = n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c3, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c4, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c5, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c6, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c7, gvl); ci += ldc-gvl*0;
- m_top += 8;
- }
-
- // -- tails for main pass --
-
- if( M & 4 ) {
- gvl = __riscv_vsetvl_e16m1(4);
-
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- _Float16 B2 = B[bi+2];
- _Float16 B3 = B[bi+3];
- _Float16 B4 = B[bi+4];
- _Float16 B5 = B[bi+5];
- _Float16 B6 = B[bi+6];
- _Float16 B7 = B[bi+7];
- bi += 8;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
- vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
- vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
- vfloat32m2_t result4 = __riscv_vfwmul_vf_f32m2( A0, B4, gvl);
- vfloat32m2_t result5 = __riscv_vfwmul_vf_f32m2( A0, B5, gvl);
- vfloat32m2_t result6 = __riscv_vfwmul_vf_f32m2( A0, B6, gvl);
- vfloat32m2_t result7 = __riscv_vfwmul_vf_f32m2( A0, B7, gvl);
-
- for(BLASLONG k=1; k < K; ++k) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- B2 = B[bi+2];
- B3 = B[bi+3];
- B4 = B[bi+4];
- B5 = B[bi+5];
- B6 = B[bi+6];
- B7 = B[bi+7];
- bi += 8;
-
- A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
- result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
- result4 = __riscv_vfwmacc_vf_f32m2(result4, B4, A0, gvl);
- result5 = __riscv_vfwmacc_vf_f32m2(result5, B5, A0, gvl);
- result6 = __riscv_vfwmacc_vf_f32m2(result6, B6, A0, gvl);
- result7 = __riscv_vfwmacc_vf_f32m2(result7, B7, A0, gvl);
- }
-
- BLASLONG ci = n_top * ldc + m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c4 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c5 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c6 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c7 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
- c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
- c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
- c4 = __riscv_vfmacc_vf_f32m2(c4, alpha, result4, gvl);
- c5 = __riscv_vfmacc_vf_f32m2(c5, alpha, result5, gvl);
- c6 = __riscv_vfmacc_vf_f32m2(c6, alpha, result6, gvl);
- c7 = __riscv_vfmacc_vf_f32m2(c7, alpha, result7, gvl);
-
- ci= n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c3, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c4, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c5, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c6, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c7, gvl);
-
- m_top += 4;
- }
-
-
- if( M & 2 ) {
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- float result0 = 0;
- float result1 = 0;
- float result2 = 0;
- float result3 = 0;
- float result4 = 0;
- float result5 = 0;
- float result6 = 0;
- float result7 = 0;
- float result8 = 0;
- float result9 = 0;
- float result10 = 0;
- float result11 = 0;
- float result12 = 0;
- float result13 = 0;
- float result14 = 0;
- float result15 = 0;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+1]*B[bi+0]);
- result2+=(float)(A[ai+0]*B[bi+1]);
- result3+=(float)(A[ai+1]*B[bi+1]);
- result4+=(float)(A[ai+0]*B[bi+2]);
- result5+=(float)(A[ai+1]*B[bi+2]);
- result6+=(float)(A[ai+0]*B[bi+3]);
- result7+=(float)(A[ai+1]*B[bi+3]);
- result8+=(float)(A[ai+0]*B[bi+4]);
- result9+=(float)(A[ai+1]*B[bi+4]);
- result10+=(float)(A[ai+0]*B[bi+5]);
- result11+=(float)(A[ai+1]*B[bi+5]);
- result12+=(float)(A[ai+0]*B[bi+6]);
- result13+=(float)(A[ai+1]*B[bi+6]);
- result14+=(float)(A[ai+0]*B[bi+7]);
- result15+=(float)(A[ai+1]*B[bi+7]);
- ai+=2;
- bi+=8;
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 0 * ldc + 1] += alpha * result1;
- C[ci + 1 * ldc + 0] += alpha * result2;
- C[ci + 1 * ldc + 1] += alpha * result3;
- C[ci + 2 * ldc + 0] += alpha * result4;
- C[ci + 2 * ldc + 1] += alpha * result5;
- C[ci + 3 * ldc + 0] += alpha * result6;
- C[ci + 3 * ldc + 1] += alpha * result7;
- C[ci + 4 * ldc + 0] += alpha * result8;
- C[ci + 4 * ldc + 1] += alpha * result9;
- C[ci + 5 * ldc + 0] += alpha * result10;
- C[ci + 5 * ldc + 1] += alpha * result11;
- C[ci + 6 * ldc + 0] += alpha * result12;
- C[ci + 6 * ldc + 1] += alpha * result13;
- C[ci + 7 * ldc + 0] += alpha * result14;
- C[ci + 7 * ldc + 1] += alpha * result15;
-
- m_top+=2;
- }
-
-
- if( M & 1 ) {
-
- float result0 = 0;
- float result1 = 0;
- float result2 = 0;
- float result3 = 0;
- float result4 = 0;
- float result5 = 0;
- float result6 = 0;
- float result7 = 0;
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+0]*B[bi+1]);
- result2+=(float)(A[ai+0]*B[bi+2]);
- result3+=(float)(A[ai+0]*B[bi+3]);
- result4+=(float)(A[ai+0]*B[bi+4]);
- result5+=(float)(A[ai+0]*B[bi+5]);
- result6+=(float)(A[ai+0]*B[bi+6]);
- result7+=(float)(A[ai+0]*B[bi+7]);
- ai+=1;
- bi+=8;
- }
-
- BLASLONG ci = n_top * ldc + m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 1 * ldc + 0] += alpha * result1;
- C[ci + 2 * ldc + 0] += alpha * result2;
- C[ci + 3 * ldc + 0] += alpha * result3;
- C[ci + 4 * ldc + 0] += alpha * result4;
- C[ci + 5 * ldc + 0] += alpha * result5;
- C[ci + 6 * ldc + 0] += alpha * result6;
- C[ci + 7 * ldc + 0] += alpha * result7;
- m_top+=1;
- }
-
- n_top += 8;
- }
-
- // -- tails for N=4
- if( N & 4 ) {
- gvl = __riscv_vsetvl_e16m1(8);
- m_top = 0;
-
- for (BLASLONG i=0; i<M/8; i+=1) {
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
-
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- _Float16 B2 = B[bi+2];
- _Float16 B3 = B[bi+3];
- bi += 4;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
- vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
- vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
-
- for(BLASLONG k=1; k<K; k++) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- B2 = B[bi+2];
- B3 = B[bi+3];
- bi += 4;
-
- A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
- result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
- vfloat32m2_t c2 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
- vfloat32m2_t c3 = __riscv_vle32_v_f32m2( &C[ci], gvl);
-
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
- c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
- c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
-
- ci = n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c1, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c2, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c3, gvl);
- m_top += 8;
- }
-
- if( M & 4 ) {
- gvl = __riscv_vsetvl_e16m1(4);
-
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- _Float16 B2 = B[bi+2];
- _Float16 B3 = B[bi+3];
- bi += 4;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
- vfloat32m2_t result2 = __riscv_vfwmul_vf_f32m2( A0, B2, gvl);
- vfloat32m2_t result3 = __riscv_vfwmul_vf_f32m2( A0, B3, gvl);
-
- for(BLASLONG k=1; k < K; ++k) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- B2 = B[bi+2];
- B3 = B[bi+3];
- bi += 4;
-
- A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- result2 = __riscv_vfwmacc_vf_f32m2(result2, B2, A0, gvl);
- result3 = __riscv_vfwmacc_vf_f32m2(result3, B3, A0, gvl);
- }
-
- BLASLONG ci = n_top * ldc + m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c2 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c3 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
- c2 = __riscv_vfmacc_vf_f32m2(c2, alpha, result2, gvl);
- c3 = __riscv_vfmacc_vf_f32m2(c3, alpha, result3, gvl);
-
- ci= n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c1, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c2, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c3, gvl);
-
- m_top += 4;
- }
-
-
- if( M & 2 ) {
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- float result0 = 0;
- float result1 = 0;
- float result2 = 0;
- float result3 = 0;
- float result4 = 0;
- float result5 = 0;
- float result6 = 0;
- float result7 = 0;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+1]*B[bi+0]);
- result2+=(float)(A[ai+0]*B[bi+1]);
- result3+=(float)(A[ai+1]*B[bi+1]);
- result4+=(float)(A[ai+0]*B[bi+2]);
- result5+=(float)(A[ai+1]*B[bi+2]);
- result6+=(float)(A[ai+0]*B[bi+3]);
- result7+=(float)(A[ai+1]*B[bi+3]);
- ai+=2;
- bi+=4;
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 0 * ldc + 1] += alpha * result1;
- C[ci + 1 * ldc + 0] += alpha * result2;
- C[ci + 1 * ldc + 1] += alpha * result3;
- C[ci + 2 * ldc + 0] += alpha * result4;
- C[ci + 2 * ldc + 1] += alpha * result5;
- C[ci + 3 * ldc + 0] += alpha * result6;
- C[ci + 3 * ldc + 1] += alpha * result7;
-
- m_top += 2;
- }
-
-
- if( M & 1 ) {
-
- float result0 = 0;
- float result1 = 0;
- float result2 = 0;
- float result3 = 0;
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+0]*B[bi+1]);
- result2+=(float)(A[ai+0]*B[bi+2]);
- result3+=(float)(A[ai+0]*B[bi+3]);
- ai+=1;
- bi+=4;
- }
-
- BLASLONG ci = n_top * ldc + m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 1 * ldc + 0] += alpha * result1;
- C[ci + 2 * ldc + 0] += alpha * result2;
- C[ci + 3 * ldc + 0] += alpha * result3;
- m_top += 1;
- }
-
- n_top += 4;
- }
-
-
-
- // -- tails for N=2
- if( N & 2 ) {
- gvl = __riscv_vsetvl_e16m1(8);
- m_top = 0;
-
- for (BLASLONG i=0; i<M/8; i+=1) {
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
-
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- bi += 2;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
-
- for(BLASLONG k=1; k<K; k++) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- bi += 2;
-
- A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl); ci += ldc - gvl * 0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2( &C[ci], gvl);
-
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
-
- ci = n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2( &C[ci], c0, gvl); ci += ldc-gvl*0;
- __riscv_vse32_v_f32m2( &C[ci], c1, gvl);
- m_top += 8;
- }
-
- if( M & 4 ) {
- gvl = __riscv_vsetvl_e16m1(4);
-
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
- _Float16 B0 = B[bi+0];
- _Float16 B1 = B[bi+1];
- bi += 2;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
- vfloat32m2_t result1 = __riscv_vfwmul_vf_f32m2( A0, B1, gvl);
-
- for(BLASLONG k=1; k < K; ++k) {
- B0 = B[bi+0];
- B1 = B[bi+1];
- bi += 2;
-
- A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- result1 = __riscv_vfwmacc_vf_f32m2(result1, B1, A0, gvl);
- }
-
- BLASLONG ci = n_top * ldc + m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- ci += ldc - gvl * 0;
- vfloat32m2_t c1 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
- c1 = __riscv_vfmacc_vf_f32m2(c1, alpha, result1, gvl);
-
- ci= n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2(&C[ci], c0, gvl); ci += ldc - gvl * 0;
- __riscv_vse32_v_f32m2(&C[ci], c1, gvl);
-
- m_top += 4;
- }
-
-
- if( M & 2 ) {
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- float result0 = 0;
- float result1 = 0;
- float result2 = 0;
- float result3 = 0;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+1]*B[bi+0]);
- result2+=(float)(A[ai+0]*B[bi+1]);
- result3+=(float)(A[ai+1]*B[bi+1]);
- ai+=2;
- bi+=2;
- }
-
- BLASLONG ci=n_top*ldc+m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 0 * ldc + 1] += alpha * result1;
- C[ci + 1 * ldc + 0] += alpha * result2;
- C[ci + 1 * ldc + 1] += alpha * result3;
-
- m_top += 2;
- }
-
-
- if( M & 1 ) {
-
- float result0 = 0;
- float result1 = 0;
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+0]*B[bi+1]);
- ai+=1;
- bi+=2;
- }
-
- BLASLONG ci = n_top * ldc + m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 1 * ldc + 0] += alpha * result1;
- m_top += 1;
- }
-
- n_top += 2;
- }
-
-
-
- // -- tails for N=1
- if( N & 1 ) {
- gvl = __riscv_vsetvl_e16m1(8);
- m_top = 0;
-
- for (BLASLONG i=0; i<M/8; i+=1) {
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
-
- _Float16 B0 = B[bi+0];
- bi += 1;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
-
- for(BLASLONG k=1; k<K; k++) {
- B0 = B[bi+0];
- bi += 1;
-
- A0 = __riscv_vle16_v_f16m1( &A[ai+0*gvl], gvl );
- ai += 8;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2( &C[ci], gvl);
-
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
-
- ci = n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2( &C[ci], c0, gvl);
- m_top += 8;
- }
-
- if( M & 4 ) {
- gvl = __riscv_vsetvl_e16m1(4);
-
- BLASLONG ai=m_top*K;
- BLASLONG bi=n_top*K;
- _Float16 B0 = B[bi+0];
- bi += 1;
-
- vfloat16m1_t A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- vfloat32m2_t result0 = __riscv_vfwmul_vf_f32m2( A0, B0, gvl);
-
- for(BLASLONG k=1; k < K; ++k) {
- B0 = B[bi+0];
- bi += 1;
-
- A0 = __riscv_vle16_v_f16m1(&A[ai + 0 * gvl], gvl);
- ai += 4;
-
- result0 = __riscv_vfwmacc_vf_f32m2(result0, B0, A0, gvl);
- }
-
- BLASLONG ci = n_top * ldc + m_top;
-
- vfloat32m2_t c0 = __riscv_vle32_v_f32m2(&C[ci], gvl);
- c0 = __riscv_vfmacc_vf_f32m2(c0, alpha, result0, gvl);
-
- ci= n_top * ldc + m_top;
-
- __riscv_vse32_v_f32m2(&C[ci], c0, gvl);
- m_top += 4;
- }
-
-
- if( M & 2 ) {
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- float result0 = 0;
- float result1 = 0;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- result1+=(float)(A[ai+1]*B[bi+0]);
- ai+=2;
- bi+=1;
- }
-
-
- BLASLONG ci=n_top*ldc+m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- C[ci + 0 * ldc + 1] += alpha * result1;
-
- m_top += 2;
- }
-
-
- if( M & 1 ) {
-
- float result0 = 0;
-
- BLASLONG ai = m_top * K;
- BLASLONG bi = n_top * K;
-
- for(BLASLONG k=0; k<K; k++) {
- result0+=(float)(A[ai+0]*B[bi+0]);
- ai+=1;
- bi+=1;
- }
-
- BLASLONG ci = n_top * ldc + m_top;
- C[ci + 0 * ldc + 0] += alpha * result0;
- m_top += 1;
- }
-
- n_top += 1;
- }
-
-
- return 0;
-
- }
|