You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

gemmkernel_2x2.c 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
#include "common.h"
#include <stdint.h>
#include <string.h>
  2. #if defined(BFLOAT16) && defined(BFLOAT16CONVERSION)
  3. static float
  4. bfloat16tof32 (bfloat16 f16)
  5. {
  6. float result = 0;
  7. unsigned short* q = (unsigned short*)(&result);
  8. #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  9. q[0] = f16;
  10. #else
  11. q[1] = f16;
  12. #endif
  13. return result;
  14. }
  15. static bfloat16 f32tobfloat16(float f32) {
  16. unsigned short *q = (unsigned short *)(&f32);
  17. #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  18. return q[0];
  19. #else
  20. return q[1];
  21. #endif
  22. }
  23. #ifdef BGEMM
  24. #define ALPHA bfloat16tof32(alpha)
  25. #define BF16TOF32(x) (bfloat16tof32(x))
  26. #define F32TOBF16(x) (f32tobfloat16(x))
  27. #else
  28. #define ALPHA alpha
  29. #define BF16TOF32(x) (bfloat16tof32(x))
  30. #define F32TOBF16(x) x
  31. #endif
  32. #else
  33. #define ALPHA alpha
  34. #define BF16TOF32(x) x
  35. #define F32TOBF16(x) x
  36. #endif
  37. int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
  38. #ifdef TRMMKERNEL
  39. ,BLASLONG offset
  40. #endif
  41. )
  42. {
  43. BLASLONG i,j,k;
  44. FLOAT *C0,*C1;
  45. IFLOAT *ptrba,*ptrbb;
  46. #ifdef BGEMM
  47. float res0,res1,res2,res3;
  48. #else
  49. FLOAT res0,res1,res2,res3;
  50. #endif
  51. IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
  52. for (j=0; j<bn/2; j+=1)
  53. {
  54. C0 = C;
  55. C1 = C0+ldc;
  56. ptrba = ba;
  57. for (i=0; i<bm/2; i+=1)
  58. {
  59. ptrbb = bb;
  60. res0 = 0;
  61. res1 = 0;
  62. res2 = 0;
  63. res3 = 0;
  64. for (k=0; k<bk/4; k+=1)
  65. {
  66. load0 = ptrba[2*0+0];
  67. load1 = ptrbb[2*0+0];
  68. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  69. load2 = ptrba[2*0+1];
  70. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  71. load3 = ptrbb[2*0+1];
  72. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  73. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  74. load4 = ptrba[2*1+0];
  75. load5 = ptrbb[2*1+0];
  76. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  77. load6 = ptrba[2*1+1];
  78. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  79. load7 = ptrbb[2*1+1];
  80. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  81. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  82. load0 = ptrba[2*2+0];
  83. load1 = ptrbb[2*2+0];
  84. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  85. load2 = ptrba[2*2+1];
  86. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  87. load3 = ptrbb[2*2+1];
  88. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  89. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  90. load4 = ptrba[2*3+0];
  91. load5 = ptrbb[2*3+0];
  92. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  93. load6 = ptrba[2*3+1];
  94. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  95. load7 = ptrbb[2*3+1];
  96. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  97. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  98. ptrba = ptrba+8;
  99. ptrbb = ptrbb+8;
  100. }
  101. for (k=0; k<(bk&3); k+=1)
  102. {
  103. load0 = ptrba[2*0+0];
  104. load1 = ptrbb[2*0+0];
  105. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  106. load2 = ptrba[2*0+1];
  107. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  108. load3 = ptrbb[2*0+1];
  109. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  110. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  111. ptrba = ptrba+2;
  112. ptrbb = ptrbb+2;
  113. }
  114. res0 = res0*ALPHA;
  115. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  116. res1 = res1*ALPHA;
  117. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  118. res2 = res2*ALPHA;
  119. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2);
  120. res3 = res3*ALPHA;
  121. C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3);
  122. C0 = C0+2;
  123. C1 = C1+2;
  124. }
  125. for (i=0; i<(bm&1); i+=1)
  126. {
  127. ptrbb = bb;
  128. res0 = 0;
  129. res1 = 0;
  130. for (k=0; k<bk; k+=1)
  131. {
  132. load0 = ptrba[0+0];
  133. load1 = ptrbb[2*0+0];
  134. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  135. load2 = ptrbb[2*0+1];
  136. res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
  137. ptrba = ptrba+1;
  138. ptrbb = ptrbb+2;
  139. }
  140. res0 = res0*ALPHA;
  141. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  142. res1 = res1*ALPHA;
  143. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1);
  144. C0 = C0+1;
  145. C1 = C1+1;
  146. }
  147. k = (bk<<1);
  148. bb = bb+k;
  149. i = (ldc<<1);
  150. C = C+i;
  151. }
  152. for (j=0; j<(bn&1); j+=1)
  153. {
  154. C0 = C;
  155. ptrba = ba;
  156. for (i=0; i<bm/2; i+=1)
  157. {
  158. ptrbb = bb;
  159. res0 = 0;
  160. res1 = 0;
  161. for (k=0; k<bk; k+=1)
  162. {
  163. load0 = ptrba[2*0+0];
  164. load1 = ptrbb[0+0];
  165. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  166. load2 = ptrba[2*0+1];
  167. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  168. ptrba = ptrba+2;
  169. ptrbb = ptrbb+1;
  170. }
  171. res0 = res0*ALPHA;
  172. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  173. res1 = res1*ALPHA;
  174. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  175. C0 = C0+2;
  176. }
  177. for (i=0; i<(bm&1); i+=1)
  178. {
  179. ptrbb = bb;
  180. res0 = 0;
  181. for (k=0; k<bk; k+=1)
  182. {
  183. load0 = ptrba[0+0];
  184. load1 = ptrbb[0+0];
  185. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  186. ptrba = ptrba+1;
  187. ptrbb = ptrbb+1;
  188. }
  189. res0 = res0*ALPHA;
  190. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  191. C0 = C0+1;
  192. }
  193. k = (bk<<0);
  194. bb = bb+k;
  195. C = C+ldc;
  196. }
  197. return 0;
  198. }