You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

strsm_kernel_8x4_haswell_LT.c 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. #include "common.h"
  2. #include <stdint.h>
  3. #include "strsm_kernel_8x4_haswell_L_common.h"
  /*
   * SOLVE_LT_m1n4 / m1n8 / m1n12: inline-asm text for the solve step of a
   * 1-row block against 4, 8 or 12 columns.  The macros are pasted into the
   * asm statement in COMPUTE, where operand %2 is c_ptr and %3 is c_tmp.
   *
   * Each variant:
   *   1. copies c_ptr into c_tmp ("movq %2,%3"),
   *   2. emits GEMM_SUM_REORDER_1x4, SOLVE_m1n* and SAVE_b_m1n* (not defined
   *      in this file; presumably from strsm_kernel_8x4_haswell_L_common.h),
   *   3. reloads c_tmp from c_ptr, advances c_ptr by 4 bytes (one float) and
   *      emits one save_c_m1n4 per group of 4 columns.
   *
   * NOTE(review): the numeric arguments 4, 5, 6 look like vector-register
   * numbers and the leading 0 like a byte offset -- confirm against the
   * helper definitions.
   */
  4. #define SOLVE_LT_m1n4 \
  5. "movq %2,%3;" GEMM_SUM_REORDER_1x4(4)\
  6. SOLVE_m1n4(0,4) SAVE_b_m1n4(0,4)\
  7. "movq %2,%3; addq $4,%2;" save_c_m1n4(4)
  8. #define SOLVE_LT_m1n8 \
  9. "movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5)\
  10. SOLVE_m1n8(0,4,5) SAVE_b_m1n8(0,4,5)\
  11. "movq %2,%3; addq $4,%2;" save_c_m1n4(4) save_c_m1n4(5)
  12. #define SOLVE_LT_m1n12 \
  13. "movq %2,%3;" GEMM_SUM_REORDER_1x4(4) GEMM_SUM_REORDER_1x4(5) GEMM_SUM_REORDER_1x4(6)\
  14. SOLVE_m1n12(0,4,5,6) SAVE_b_m1n12(0,4,5,6)\
  15. "movq %2,%3; addq $4,%2;" save_c_m1n4(4) save_c_m1n4(5) save_c_m1n4(6)
  /*
   * SOLVE_LT_m2n4 / m2n8 / m2n12: inline-asm text for the solve step of a
   * 2-row block against 4, 8 or 12 columns (operand %2 is c_ptr and %3 is
   * c_tmp in the COMPUTE asm statement).
   *
   * Sequence per variant: copy c_ptr to c_tmp, GEMM_SUM_REORDER_2x4 once per
   * 4-column group, SOLVE_uplo_* with first argument 0, SOLVE_lo_* with first
   * argument 8, SAVE_b_*, then reload c_tmp, advance c_ptr by 8 bytes (two
   * floats) and emit save_c_m2n4 per 4-column group.
   *
   * NOTE(review): the first arguments 0 and 8 appear to be byte offsets of the
   * two rows of a 2x2 float tile of the packed triangular factor -- confirm
   * against the helper macros, which are defined outside this file.
   */
  16. #define SOLVE_LT_m2n4 \
  17. "movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4)\
  18. SOLVE_uplo_m2n4(0,4)\
  19. SOLVE_lo_m2n4(8,4) SAVE_b_m2n4(0,4)\
  20. "movq %2,%3; addq $8,%2;" save_c_m2n4(4)
  21. #define SOLVE_LT_m2n8 \
  22. "movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5)\
  23. SOLVE_uplo_m2n8(0,4,5)\
  24. SOLVE_lo_m2n8(8,4,5) SAVE_b_m2n8(0,4,5)\
  25. "movq %2,%3; addq $8,%2;" save_c_m2n4(4) save_c_m2n4(5)
  26. #define SOLVE_LT_m2n12 \
  27. "movq %2,%3;" GEMM_SUM_REORDER_2x4(4,5,4) GEMM_SUM_REORDER_2x4(6,7,5) GEMM_SUM_REORDER_2x4(8,9,6)\
  28. SOLVE_uplo_m2n12(0,4,5,6)\
  29. SOLVE_lo_m2n12(8,4,5,6) SAVE_b_m2n12(0,4,5,6)\
  30. "movq %2,%3; addq $8,%2;" save_c_m2n4(4) save_c_m2n4(5) save_c_m2n4(6)
  /*
   * SOLVE_LT_m4n4 / m4n8 / m4n12: inline-asm text for the solve step of a
   * 4-row block against 4, 8 or 12 columns (operand %2 is c_ptr and %3 is
   * c_tmp in the COMPUTE asm statement).
   *
   * The 4 rows are handled as two 2-row pairs.  For the first pair the macro
   * emits SOLVE_uplo / SOLVE_lo, each followed by a SUBTRACT_* that targets
   * the second pair, then SAVE_b_* at offset 0.  The second pair is then
   * solved on its own (offsets 40 and 56) and saved with SAVE_b_* at offset 32.
   * Finally c_tmp is reloaded, c_ptr is advanced by 16 bytes (four floats) and
   * save_c_m4n4 is emitted per 4-column group.
   *
   * NOTE(review): the first arguments (0, 8, 16, 24, 40, 56) are consistent
   * with byte offsets into a row-major 4x4 float tile (16 bytes per row) --
   * confirm against the helper macros, which are defined outside this file.
   */
  31. #define SOLVE_LT_m4n4 \
  32. "movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5)\
  33. \
  34. SOLVE_uplo_m2n4(0,4) SUBTRACT_m2n4(8,5)\
  35. SOLVE_lo_m2n4(16,4) SUBTRACT_m2n4(24,5) SAVE_b_m2n4(0,4)\
  36. \
  37. SOLVE_uplo_m2n4(40,5)\
  38. SOLVE_lo_m2n4(56,5) SAVE_b_m2n4(32,5)\
  39. \
  40. "movq %2,%3; addq $16,%2;" save_c_m4n4(4,5)
  41. #define SOLVE_LT_m4n8 \
  42. "movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7)\
  43. \
  44. SOLVE_uplo_m2n8(0,4,6) SUBTRACT_m2n8(8,5,7)\
  45. SOLVE_lo_m2n8(16,4,6) SUBTRACT_m2n8(24,5,7) SAVE_b_m2n8(0,4,6)\
  46. \
  47. SOLVE_uplo_m2n8(40,5,7)\
  48. SOLVE_lo_m2n8(56,5,7) SAVE_b_m2n8(32,5,7)\
  49. \
  50. "movq %2,%3; addq $16,%2;" save_c_m4n4(4,5) save_c_m4n4(6,7)
  51. #define SOLVE_LT_m4n12 \
  52. "movq %2,%3;" GEMM_SUM_REORDER_4x4(4,5,6,7,4,5) GEMM_SUM_REORDER_4x4(8,9,10,11,6,7) GEMM_SUM_REORDER_4x4(12,13,14,15,8,9)\
  53. \
  54. SOLVE_uplo_m2n12(0,4,6,8) SUBTRACT_m2n12(8,5,7,9)\
  55. SOLVE_lo_m2n12(16,4,6,8) SUBTRACT_m2n12(24,5,7,9) SAVE_b_m2n12(0,4,6,8)\
  56. \
  57. SOLVE_uplo_m2n12(40,5,7,9)\
  58. SOLVE_lo_m2n12(56,5,7,9) SAVE_b_m2n12(32,5,7,9)\
  59. \
  60. "movq %2,%3; addq $16,%2;" save_c_m4n4(4,5) save_c_m4n4(6,7) save_c_m4n4(8,9)
  /*
   * SOLVE_LT_m8n4 / m8n8 / m8n12: inline-asm text for the solve step of an
   * 8-row block against 4, 8 or 12 columns (operand %2 is c_ptr and %3 is
   * c_tmp in the COMPUTE asm statement).
   *
   * The 8 rows are handled as four 2-row pairs.  Pair p is solved with
   * SOLVE_uplo / SOLVE_lo, each followed by one SUBTRACT_* per later pair,
   * and then written out with SAVE_b_* at offset 32*p (0, 32, 64, 96).  After
   * the last pair c_tmp is reloaded, c_ptr is advanced by 32 bytes (eight
   * floats) and save_c_m8n4 is emitted per 4-column group.
   *
   * NOTE(review): the first arguments (0..248) are consistent with byte
   * offsets into a row-major 8x8 float tile (32 bytes per row, 8 bytes per
   * column pair), touching only the diagonal and the entries to its right.
   * The meaning of the trailing 63 passed to GEMM_SUM_REORDER_8x4 cannot be
   * determined from this file -- confirm against the helper definitions.
   */
  61. #define SOLVE_LT_m8n4 \
  62. "movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,63)\
  63. \
  64. SOLVE_uplo_m2n4(0,4) SUBTRACT_m2n4(8,5) SUBTRACT_m2n4(16,6) SUBTRACT_m2n4(24,7)\
  65. SOLVE_lo_m2n4(32,4) SUBTRACT_m2n4(40,5) SUBTRACT_m2n4(48,6) SUBTRACT_m2n4(56,7) SAVE_b_m2n4(0,4)\
  66. \
  67. SOLVE_uplo_m2n4(72,5) SUBTRACT_m2n4(80,6) SUBTRACT_m2n4(88,7)\
  68. SOLVE_lo_m2n4(104,5) SUBTRACT_m2n4(112,6) SUBTRACT_m2n4(120,7) SAVE_b_m2n4(32,5)\
  69. \
  70. SOLVE_uplo_m2n4(144,6) SUBTRACT_m2n4(152,7)\
  71. SOLVE_lo_m2n4(176,6) SUBTRACT_m2n4(184,7) SAVE_b_m2n4(64,6)\
  72. \
  73. SOLVE_uplo_m2n4(216,7)\
  74. SOLVE_lo_m2n4(248,7) SAVE_b_m2n4(96,7)\
  75. \
  76. "movq %2,%3; addq $32,%2;" save_c_m8n4(4,5,6,7)
  77. #define SOLVE_LT_m8n8 \
  78. "movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,63) GEMM_SUM_REORDER_8x4(8,9,10,11,63)\
  79. \
  80. SOLVE_uplo_m2n8(0,4,8) SUBTRACT_m2n8(8,5,9) SUBTRACT_m2n8(16,6,10) SUBTRACT_m2n8(24,7,11)\
  81. SOLVE_lo_m2n8(32,4,8) SUBTRACT_m2n8(40,5,9) SUBTRACT_m2n8(48,6,10) SUBTRACT_m2n8(56,7,11) SAVE_b_m2n8(0,4,8)\
  82. \
  83. SOLVE_uplo_m2n8(72,5,9) SUBTRACT_m2n8(80,6,10) SUBTRACT_m2n8(88,7,11)\
  84. SOLVE_lo_m2n8(104,5,9) SUBTRACT_m2n8(112,6,10) SUBTRACT_m2n8(120,7,11) SAVE_b_m2n8(32,5,9)\
  85. \
  86. SOLVE_uplo_m2n8(144,6,10) SUBTRACT_m2n8(152,7,11)\
  87. SOLVE_lo_m2n8(176,6,10) SUBTRACT_m2n8(184,7,11) SAVE_b_m2n8(64,6,10)\
  88. \
  89. SOLVE_uplo_m2n8(216,7,11)\
  90. SOLVE_lo_m2n8(248,7,11) SAVE_b_m2n8(96,7,11)\
  91. \
  92. "movq %2,%3; addq $32,%2;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11)
  93. #define SOLVE_LT_m8n12 \
  94. "movq %2,%3;" GEMM_SUM_REORDER_8x4(4,5,6,7,63) GEMM_SUM_REORDER_8x4(8,9,10,11,63) GEMM_SUM_REORDER_8x4(12,13,14,15,63)\
  95. \
  96. SOLVE_uplo_m2n12(0,4,8,12) SUBTRACT_m2n12(8,5,9,13) SUBTRACT_m2n12(16,6,10,14) SUBTRACT_m2n12(24,7,11,15)\
  97. SOLVE_lo_m2n12(32,4,8,12) SUBTRACT_m2n12(40,5,9,13) SUBTRACT_m2n12(48,6,10,14) SUBTRACT_m2n12(56,7,11,15) SAVE_b_m2n12(0,4,8,12)\
  98. \
  99. SOLVE_uplo_m2n12(72,5,9,13) SUBTRACT_m2n12(80,6,10,14) SUBTRACT_m2n12(88,7,11,15)\
  100. SOLVE_lo_m2n12(104,5,9,13) SUBTRACT_m2n12(112,6,10,14) SUBTRACT_m2n12(120,7,11,15) SAVE_b_m2n12(32,5,9,13)\
  101. \
  102. SOLVE_uplo_m2n12(144,6,10,14) SUBTRACT_m2n12(152,7,11,15)\
  103. SOLVE_lo_m2n12(176,6,10,14) SUBTRACT_m2n12(184,7,11,15) SAVE_b_m2n12(64,6,10,14)\
  104. \
  105. SOLVE_uplo_m2n12(216,7,11,15)\
  106. SOLVE_lo_m2n12(248,7,11,15) SAVE_b_m2n12(96,7,11,15)\
  107. \
  108. "movq %2,%3; addq $32,%2;" save_c_m8n4(4,5,6,7) save_c_m8n4(8,9,10,11) save_c_m8n4(12,13,14,15)
  /*
   * GEMM_LT_SIMPLE(mdim,ndim): inline-asm text for the accumulation loop that
   * precedes the solve of an mdim-row block.  Operands refer to the asm
   * statement in COMPUTE: %0 = a_ptr, %1 = b_ptr, %5 = k_cnt; COMPUTE loads
   * r15 with a_ptr, r14 with b_ptr, r13 with OFF and r12 with K*4.
   *
   * Prologue: a_ptr = r15, then r15 += r12*mdim (mdim is used directly as the
   * leaq scale factor, so it must be 1, 2, 4 or 8); k_cnt = r13, then
   * r13 += mdim; b_ptr = r14; INIT_m<mdim>n<ndim> is emitted.
   *
   * Loop: skipped entirely when k_cnt is zero.  Otherwise each iteration emits
   * GEMM_KERNEL_k1m<mdim>n<ndim>, advances b_ptr by 16 bytes and a_ptr by
   * mdim*4 bytes, and decrements k_cnt until it reaches zero.
   *
   * The numeric labels are built by pasting digits: "1" mdim ndim "1" for the
   * loop head and "1" mdim ndim "2" for the exit.
   */
  109. #define GEMM_LT_SIMPLE(mdim,ndim) \
  110. "movq %%r15,%0; leaq (%%r15,%%r12,"#mdim"),%%r15; movq %%r13,%5; addq $"#mdim",%%r13; movq %%r14,%1;" INIT_m##mdim##n##ndim\
  111. "testq %5,%5; jz 1"#mdim""#ndim"2f;"\
  112. "1"#mdim""#ndim"1:\n\t"\
  113. GEMM_KERNEL_k1m##mdim##n##ndim "addq $16,%1; addq $"#mdim"*4,%0; decq %5; jnz 1"#mdim""#ndim"1b;"\
  114. "1"#mdim""#ndim"2:\n\t"
  115. #define GEMM_LT_m8n4 GEMM_LT_SIMPLE(8,4)
  116. #define GEMM_LT_m8n8 GEMM_LT_SIMPLE(8,8)
  /*
   * GEMM_LT_m8n12: hand-written variant of the accumulation loop for the
   * 8-row, 12-column case.  The prologue matches GEMM_LT_SIMPLE(8,12).
   *
   * Main loop (label 18121): entered only while k_cnt >= 8; it emits eight
   * GEMM_KERNEL_k1m8n12 steps per pass, each advancing a_ptr by 32 bytes and
   * b_ptr by 16 bytes, with "prefetcht0 384(%0)" on every other step, and
   * then subtracts 8 from k_cnt.
   *
   * Tail loop (label 18123): one step per iteration for the remaining
   * k_cnt (0..7) iterations; skipped when k_cnt is zero.
   */
  117. #define GEMM_LT_m8n12 \
  118. "movq %%r15,%0; leaq (%%r15,%%r12,8),%%r15; movq %%r13,%5; addq $8,%%r13; movq %%r14,%1;" INIT_m8n12\
  119. "cmpq $8,%5; jb 18122f;"\
  120. "18121:\n\t"\
  121. GEMM_KERNEL_k1m8n12 "prefetcht0 384(%0); addq $32,%0; addq $16,%1;"\
  122. GEMM_KERNEL_k1m8n12 "addq $32,%0; addq $16,%1;"\
  123. GEMM_KERNEL_k1m8n12 "prefetcht0 384(%0); addq $32,%0; addq $16,%1;"\
  124. GEMM_KERNEL_k1m8n12 "addq $32,%0; addq $16,%1;"\
  125. GEMM_KERNEL_k1m8n12 "prefetcht0 384(%0); addq $32,%0; addq $16,%1;"\
  126. GEMM_KERNEL_k1m8n12 "addq $32,%0; addq $16,%1;"\
  127. GEMM_KERNEL_k1m8n12 "prefetcht0 384(%0); addq $32,%0; addq $16,%1;"\
  128. GEMM_KERNEL_k1m8n12 "addq $32,%0; addq $16,%1;"\
  129. "subq $8,%5; cmpq $8,%5; jnb 18121b;"\
  130. "18122:\n\t"\
  131. "testq %5,%5; jz 18124f;"\
  132. "18123:\n\t"\
  133. GEMM_KERNEL_k1m8n12 "addq $32,%0; addq $16,%1; decq %5; jnz 18123b;"\
  134. "18124:\n\t"
  /* The remaining row/column combinations all use the generic loop. */
  135. #define GEMM_LT_m4n4 GEMM_LT_SIMPLE(4,4)
  136. #define GEMM_LT_m4n8 GEMM_LT_SIMPLE(4,8)
  137. #define GEMM_LT_m4n12 GEMM_LT_SIMPLE(4,12)
  138. #define GEMM_LT_m2n4 GEMM_LT_SIMPLE(2,4)
  139. #define GEMM_LT_m2n8 GEMM_LT_SIMPLE(2,8)
  140. #define GEMM_LT_m2n12 GEMM_LT_SIMPLE(2,12)
  141. #define GEMM_LT_m1n4 GEMM_LT_SIMPLE(1,4)
  142. #define GEMM_LT_m1n8 GEMM_LT_SIMPLE(1,8)
  143. #define GEMM_LT_m1n12 GEMM_LT_SIMPLE(1,12)
  /*
   * COMPUTE(ndim): process one panel of ndim columns (ndim = 4, 8 or 12 in
   * this file) for all M rows with a single inline-asm statement, then step
   * the C-level pointers to the next panel.
   *
   * The macro expands inside CNAME and uses its locals by name: a_ptr, b_ptr,
   * c_ptr, c_tmp, ldc_bytes, k_cnt, K, OFF, one, zero, M and ldc.
   *
   * Asm operands: %0 a_ptr, %1 b_ptr, %2 c_ptr, %3 c_tmp, %4 ldc_bytes,
   * %5 k_cnt (all read/write registers); %6 K, %7 OFF, %8 one[0], %9 zero[0],
   * %10 M (memory inputs).  %4, %8 and %9 are not referenced by the strings
   * visible here; presumably the helper macros from the included header use
   * them.
   *
   * Register setup: r15 = a_ptr, r14 = b_ptr, r13 = OFF, r12 = K << 2,
   * r11 = M (rows still to process).
   *
   * Row loop: while r11 >= 8, run GEMM_LT_m8n<ndim> then SOLVE_LT_m8n<ndim>
   * and subtract 8.  The remainder is handled by testing bits 4, 2 and 1 of
   * r11 and running the matching m4 / m2 / m1 pair once each.  Labels are
   * formed as <ndim>771 .. <ndim>775.
   *
   * Epilogue: r15 and r14 are written back to a_ptr and b_ptr and vzeroupper
   * is issued.  The trailing C statement rewinds a_ptr by M*K elements,
   * advances b_ptr by ndim*K elements and moves c_ptr by ldc*ndim - M.
   */
  144. #define COMPUTE(ndim) {\
  145. __asm__ __volatile__(\
  146. "movq %0,%%r15; movq %1,%%r14; movq %7,%%r13; movq %6,%%r12; salq $2,%%r12; movq %10,%%r11;"\
  147. "cmpq $8,%%r11; jb "#ndim"772f;"\
  148. #ndim"771:\n\t"\
  149. GEMM_LT_m8n##ndim SOLVE_LT_m8n##ndim "subq $8,%%r11; cmpq $8,%%r11; jnb "#ndim"771b;"\
  150. #ndim"772:\n\t"\
  151. "testq $4,%%r11; jz "#ndim"773f;"\
  152. GEMM_LT_m4n##ndim SOLVE_LT_m4n##ndim "subq $4,%%r11;"\
  153. #ndim"773:\n\t"\
  154. "testq $2,%%r11; jz "#ndim"774f;"\
  155. GEMM_LT_m2n##ndim SOLVE_LT_m2n##ndim "subq $2,%%r11;"\
  156. #ndim"774:\n\t"\
  157. "testq $1,%%r11; jz "#ndim"775f;"\
  158. GEMM_LT_m1n##ndim SOLVE_LT_m1n##ndim "subq $1,%%r11;"\
  159. #ndim"775:\n\t"\
  160. "movq %%r15,%0; movq %%r14,%1; vzeroupper;"\
  161. :"+r"(a_ptr),"+r"(b_ptr),"+r"(c_ptr),"+r"(c_tmp),"+r"(ldc_bytes),"+r"(k_cnt):"m"(K),"m"(OFF),"m"(one[0]),"m"(zero[0]),"m"(M)\
  162. :"r11","r12","r13","r14","r15","cc","memory",\
  163. "xmm0","xmm1","xmm2","xmm3","xmm4","xmm5","xmm6","xmm7","xmm8","xmm9","xmm10","xmm11","xmm12","xmm13","xmm14","xmm15");\
  164. a_ptr -= M * K; b_ptr += ndim * K; c_ptr += ldc * ndim - M;\
  165. }
  166. static void solve_LT(BLASLONG m, BLASLONG n, FLOAT *a, FLOAT *b, FLOAT *c, BLASLONG ldc) {
  167. FLOAT a0, b0;
  168. int i, j, k;
  169. for (i=0;i<m;i++) {
  170. a0 = a[i*m+i];
  171. for (j=0;j<n;j++) {
  172. b0 = c[j*ldc+i] * a0;
  173. b[i*n+j] = c[j*ldc+i] = b0;
  174. for (k=i+1;k<m;k++) c[j*ldc+k] -= b0 * a[i*m+k];
  175. }
  176. }
  177. }
  178. static void COMPUTE_EDGE_1_nchunk(BLASLONG m, BLASLONG n, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG k, BLASLONG offset) {
  179. BLASLONG m_count = m, kk = offset; FLOAT *a_ptr = sa, *c_ptr = C;
  180. for(;m_count>7;m_count-=8){
  181. if(kk>0) GEMM_KERNEL_N(8,n,kk,-1.0,a_ptr,sb,c_ptr,ldc);
  182. solve_LT(8,n,a_ptr+kk*8,sb+kk*n,c_ptr,ldc);
  183. kk += 8; a_ptr += k * 8; c_ptr += 8;
  184. }
  185. for(;m_count>3;m_count-=4){
  186. if(kk>0) GEMM_KERNEL_N(4,n,kk,-1.0,a_ptr,sb,c_ptr,ldc);
  187. solve_LT(4,n,a_ptr+kk*4,sb+kk*n,c_ptr,ldc);
  188. kk += 4; a_ptr += k * 4; c_ptr += 4;
  189. }
  190. for(;m_count>1;m_count-=2){
  191. if(kk>0) GEMM_KERNEL_N(2,n,kk,-1.0,a_ptr,sb,c_ptr,ldc);
  192. solve_LT(2,n,a_ptr+kk*2,sb+kk*n,c_ptr,ldc);
  193. kk += 2; a_ptr += k * 2; c_ptr += 2;
  194. }
  195. if(m_count>0){
  196. if(kk>0) GEMM_KERNEL_N(1,n,kk,-1.0,a_ptr,sb,c_ptr,ldc);
  197. solve_LT(1,n,a_ptr+kk*1,sb+kk*n,c_ptr,ldc);
  198. kk += 1; a_ptr += k * 1; c_ptr += 1;
  199. }
  200. }
  201. int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, FLOAT *sa, FLOAT *sb, FLOAT *C, BLASLONG ldc, BLASLONG offset){
  202. float *a_ptr = sa, *b_ptr = sb, *c_ptr = C, *c_tmp = C;
  203. float one[8] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0};
  204. float zero[8] = {0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0};
  205. uint64_t ldc_bytes = (uint64_t)ldc * sizeof(float), K = (uint64_t)k, M = (uint64_t)m, OFF = (uint64_t)offset, k_cnt = 0;
  206. BLASLONG n_count = n;
  207. for(;n_count>11;n_count-=12) COMPUTE(12)
  208. for(;n_count>7;n_count-=8) COMPUTE(8)
  209. for(;n_count>3;n_count-=4) COMPUTE(4)
  210. for(;n_count>1;n_count-=2) { COMPUTE_EDGE_1_nchunk(m,2,a_ptr,b_ptr,c_ptr,ldc,k,offset); b_ptr += 2*k; c_ptr += ldc*2;}
  211. if(n_count>0) COMPUTE_EDGE_1_nchunk(m,1,a_ptr,b_ptr,c_ptr,ldc,k,offset);
  212. return 0;
  213. }