You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemmkernel_2x2.c 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  #include "common.h"
#include <stdint.h>
#include <string.h>
  /*
 * Operand accessor BF16TOF32(x).
 *
 * With BFLOAT16 and BFLOAT16CONVERSION both defined, it widens a bfloat16
 * value to float via bfloat16tof32(). Otherwise it expands to its argument
 * unchanged.
 */
#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
static float
bfloat16tof32 (bfloat16 f16)
{
    /*
     * A bfloat16 is the upper half of an IEEE-754 binary32 pattern, so shift
     * its 16 bits into the top of a 32-bit word (low mantissa bits zero).
     *
     * An integer shift does not depend on byte order, so no __BYTE_ORDER__
     * test is needed. Such a test would pick the big-endian path on compilers
     * that predefine neither macro, because both expand to 0 inside #if.
     *
     * memcpy reinterprets the bits without violating strict aliasing.
     */
    uint32_t bits = (uint32_t)f16 << 16;
    float result;

    memcpy(&result, &bits, sizeof result);
    return result;
}
#define BF16TOF32(x) (bfloat16tof32(x))
#else
/* Parenthesized so the identity expansion binds like a function call. */
#define BF16TOF32(x) (x)
#endif
  19. int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
  20. #ifdef TRMMKERNEL
  21. ,BLASLONG offset
  22. #endif
  23. )
  24. {
  25. BLASLONG i,j,k;
  26. FLOAT *C0,*C1;
  27. IFLOAT *ptrba,*ptrbb;
  28. FLOAT res0,res1,res2,res3;
  29. IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
  30. for (j=0; j<bn/2; j+=1)
  31. {
  32. C0 = C;
  33. C1 = C0+ldc;
  34. ptrba = ba;
  35. for (i=0; i<bm/2; i+=1)
  36. {
  37. ptrbb = bb;
  38. res0 = 0;
  39. res1 = 0;
  40. res2 = 0;
  41. res3 = 0;
  42. for (k=0; k<bk/4; k+=1)
  43. {
  44. load0 = ptrba[2*0+0];
  45. load1 = ptrbb[2*0+0];
  46. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  47. load2 = ptrba[2*0+1];
  48. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  49. load3 = ptrbb[2*0+1];
  50. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  51. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  52. load4 = ptrba[2*1+0];
  53. load5 = ptrbb[2*1+0];
  54. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  55. load6 = ptrba[2*1+1];
  56. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  57. load7 = ptrbb[2*1+1];
  58. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  59. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  60. load0 = ptrba[2*2+0];
  61. load1 = ptrbb[2*2+0];
  62. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  63. load2 = ptrba[2*2+1];
  64. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  65. load3 = ptrbb[2*2+1];
  66. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  67. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  68. load4 = ptrba[2*3+0];
  69. load5 = ptrbb[2*3+0];
  70. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  71. load6 = ptrba[2*3+1];
  72. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  73. load7 = ptrbb[2*3+1];
  74. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  75. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  76. ptrba = ptrba+8;
  77. ptrbb = ptrbb+8;
  78. }
  79. for (k=0; k<(bk&3); k+=1)
  80. {
  81. load0 = ptrba[2*0+0];
  82. load1 = ptrbb[2*0+0];
  83. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  84. load2 = ptrba[2*0+1];
  85. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  86. load3 = ptrbb[2*0+1];
  87. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  88. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  89. ptrba = ptrba+2;
  90. ptrbb = ptrbb+2;
  91. }
  92. res0 = res0*alpha;
  93. C0[0] = C0[0]+res0;
  94. res1 = res1*alpha;
  95. C0[1] = C0[1]+res1;
  96. res2 = res2*alpha;
  97. C1[0] = C1[0]+res2;
  98. res3 = res3*alpha;
  99. C1[1] = C1[1]+res3;
  100. C0 = C0+2;
  101. C1 = C1+2;
  102. }
  103. for (i=0; i<(bm&1); i+=1)
  104. {
  105. ptrbb = bb;
  106. res0 = 0;
  107. res1 = 0;
  108. for (k=0; k<bk; k+=1)
  109. {
  110. load0 = ptrba[0+0];
  111. load1 = ptrbb[2*0+0];
  112. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  113. load2 = ptrbb[2*0+1];
  114. res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
  115. ptrba = ptrba+1;
  116. ptrbb = ptrbb+2;
  117. }
  118. res0 = res0*alpha;
  119. C0[0] = C0[0]+res0;
  120. res1 = res1*alpha;
  121. C1[0] = C1[0]+res1;
  122. C0 = C0+1;
  123. C1 = C1+1;
  124. }
  125. k = (bk<<1);
  126. bb = bb+k;
  127. i = (ldc<<1);
  128. C = C+i;
  129. }
  130. for (j=0; j<(bn&1); j+=1)
  131. {
  132. C0 = C;
  133. ptrba = ba;
  134. for (i=0; i<bm/2; i+=1)
  135. {
  136. ptrbb = bb;
  137. res0 = 0;
  138. res1 = 0;
  139. for (k=0; k<bk; k+=1)
  140. {
  141. load0 = ptrba[2*0+0];
  142. load1 = ptrbb[0+0];
  143. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  144. load2 = ptrba[2*0+1];
  145. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  146. ptrba = ptrba+2;
  147. ptrbb = ptrbb+1;
  148. }
  149. res0 = res0*alpha;
  150. C0[0] = C0[0]+res0;
  151. res1 = res1*alpha;
  152. C0[1] = C0[1]+res1;
  153. C0 = C0+2;
  154. }
  155. for (i=0; i<(bm&1); i+=1)
  156. {
  157. ptrbb = bb;
  158. res0 = 0;
  159. for (k=0; k<bk; k+=1)
  160. {
  161. load0 = ptrba[0+0];
  162. load1 = ptrbb[0+0];
  163. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  164. ptrba = ptrba+1;
  165. ptrbb = ptrbb+1;
  166. }
  167. res0 = res0*alpha;
  168. C0[0] = C0[0]+res0;
  169. C0 = C0+1;
  170. }
  171. k = (bk<<0);
  172. bb = bb+k;
  173. C = C+ldc;
  174. }
  175. return 0;
  176. }