You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemmkernel_2x2.c 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /***************************************************************************
  2. * Copyright (c) 2025, The OpenBLAS Project
  3. * All rights reserved.
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions are
  6. * met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in
  11. * the documentation and/or other materials provided with the
  12. * distribution.
  13. * 3. Neither the name of the OpenBLAS project nor the names of
  14. * its contributors may be used to endorse or promote products
  15. * derived from this software without specific prior written permission.
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
  20. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. * POSSIBILITY OF SUCH DAMAGE.
  27. * *****************************************************************************/
  28. #include "common.h"
  29. #include "bf16_macros.h"
  30. int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
  31. #ifdef TRMMKERNEL
  32. ,BLASLONG offset
  33. #endif
  34. )
  35. {
  36. BLASLONG i,j,k;
  37. FLOAT *C0,*C1;
  38. IFLOAT *ptrba,*ptrbb;
  39. #ifdef BGEMM
  40. float res0,res1,res2,res3;
  41. #else
  42. FLOAT res0,res1,res2,res3;
  43. #endif
  44. IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
  45. for (j=0; j<bn/2; j+=1)
  46. {
  47. C0 = C;
  48. C1 = C0+ldc;
  49. ptrba = ba;
  50. for (i=0; i<bm/2; i+=1)
  51. {
  52. ptrbb = bb;
  53. res0 = 0;
  54. res1 = 0;
  55. res2 = 0;
  56. res3 = 0;
  57. for (k=0; k<bk/4; k+=1)
  58. {
  59. load0 = ptrba[2*0+0];
  60. load1 = ptrbb[2*0+0];
  61. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  62. load2 = ptrba[2*0+1];
  63. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  64. load3 = ptrbb[2*0+1];
  65. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  66. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  67. load4 = ptrba[2*1+0];
  68. load5 = ptrbb[2*1+0];
  69. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  70. load6 = ptrba[2*1+1];
  71. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  72. load7 = ptrbb[2*1+1];
  73. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  74. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  75. load0 = ptrba[2*2+0];
  76. load1 = ptrbb[2*2+0];
  77. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  78. load2 = ptrba[2*2+1];
  79. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  80. load3 = ptrbb[2*2+1];
  81. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  82. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  83. load4 = ptrba[2*3+0];
  84. load5 = ptrbb[2*3+0];
  85. res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
  86. load6 = ptrba[2*3+1];
  87. res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
  88. load7 = ptrbb[2*3+1];
  89. res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
  90. res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
  91. ptrba = ptrba+8;
  92. ptrbb = ptrbb+8;
  93. }
  94. for (k=0; k<(bk&3); k+=1)
  95. {
  96. load0 = ptrba[2*0+0];
  97. load1 = ptrbb[2*0+0];
  98. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  99. load2 = ptrba[2*0+1];
  100. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  101. load3 = ptrbb[2*0+1];
  102. res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
  103. res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
  104. ptrba = ptrba+2;
  105. ptrbb = ptrbb+2;
  106. }
  107. res0 = res0*ALPHA;
  108. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  109. res1 = res1*ALPHA;
  110. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  111. res2 = res2*ALPHA;
  112. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2);
  113. res3 = res3*ALPHA;
  114. C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3);
  115. C0 = C0+2;
  116. C1 = C1+2;
  117. }
  118. for (i=0; i<(bm&1); i+=1)
  119. {
  120. ptrbb = bb;
  121. res0 = 0;
  122. res1 = 0;
  123. for (k=0; k<bk; k+=1)
  124. {
  125. load0 = ptrba[0+0];
  126. load1 = ptrbb[2*0+0];
  127. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  128. load2 = ptrbb[2*0+1];
  129. res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
  130. ptrba = ptrba+1;
  131. ptrbb = ptrbb+2;
  132. }
  133. res0 = res0*ALPHA;
  134. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  135. res1 = res1*ALPHA;
  136. C1[0] = F32TOBF16(BF16TOF32(C1[0])+res1);
  137. C0 = C0+1;
  138. C1 = C1+1;
  139. }
  140. k = (bk<<1);
  141. bb = bb+k;
  142. i = (ldc<<1);
  143. C = C+i;
  144. }
  145. for (j=0; j<(bn&1); j+=1)
  146. {
  147. C0 = C;
  148. ptrba = ba;
  149. for (i=0; i<bm/2; i+=1)
  150. {
  151. ptrbb = bb;
  152. res0 = 0;
  153. res1 = 0;
  154. for (k=0; k<bk; k+=1)
  155. {
  156. load0 = ptrba[2*0+0];
  157. load1 = ptrbb[0+0];
  158. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  159. load2 = ptrba[2*0+1];
  160. res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
  161. ptrba = ptrba+2;
  162. ptrbb = ptrbb+1;
  163. }
  164. res0 = res0*ALPHA;
  165. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  166. res1 = res1*ALPHA;
  167. C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1);
  168. C0 = C0+2;
  169. }
  170. for (i=0; i<(bm&1); i+=1)
  171. {
  172. ptrbb = bb;
  173. res0 = 0;
  174. for (k=0; k<bk; k+=1)
  175. {
  176. load0 = ptrba[0+0];
  177. load1 = ptrbb[0+0];
  178. res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
  179. ptrba = ptrba+1;
  180. ptrbb = ptrbb+1;
  181. }
  182. res0 = res0*ALPHA;
  183. C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0);
  184. C0 = C0+1;
  185. }
  186. k = (bk<<0);
  187. bb = bb+k;
  188. C = C+ldc;
  189. }
  190. return 0;
  191. }