You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_batch.c 13 kB


  1. /*****************************************************************************
  2. Copyright (c) 2020, The OpenBLAS Project
  3. All rights reserved.
  4. Redistribution and use in source and binary forms, with or without
  5. modification, are permitted provided that the following conditions are
  6. met:
  7. 1. Redistributions of source code must retain the above copyright
  8. notice, this list of conditions and the following disclaimer.
  9. 2. Redistributions in binary form must reproduce the above copyright
  10. notice, this list of conditions and the following disclaimer in
  11. the documentation and/or other materials provided with the
  12. distribution.
  13. 3. Neither the name of the OpenBLAS project nor the names of
  14. its contributors may be used to endorse or promote products
  15. derived from this software without specific prior written
  16. permission.
  17. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  22. DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  23. SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  24. CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  25. OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
  26. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  27. **********************************************************************************/
  28. #include <stdio.h>
  29. #include <stdlib.h>
  30. #include "common.h"
  31. void openblas_warning(int verbose, const char * msg);
  /* Per-precision names: ERROR_NAME is the routine name handed to xerbla and
 * GEMM_BATCH_THREAD is the batch driver invoked once all problems are queued.
 * NOTE: the extended-precision (XDOUBLE) variants get an ERROR_NAME but no
 * GEMM_BATCH_THREAD, so this file cannot be completed for Q/X precision. */
#if !defined(COMPLEX) && defined(XDOUBLE)
#define ERROR_NAME "QGEMM_BATCH "
#elif !defined(COMPLEX) && defined(DOUBLE)
#define ERROR_NAME "DGEMM_BATCH "
#define GEMM_BATCH_THREAD dgemm_batch_thread
#elif !defined(COMPLEX)
#define ERROR_NAME "SGEMM_BATCH "
#define GEMM_BATCH_THREAD sgemm_batch_thread
#elif defined(XDOUBLE)
#define ERROR_NAME "XGEMM_BATCH "
#elif defined(DOUBLE)
#define ERROR_NAME "ZGEMM_BATCH "
#define GEMM_BATCH_THREAD zgemm_batch_thread
#else
#define ERROR_NAME "CGEMM_BATCH "
#define GEMM_BATCH_THREAD cgemm_batch_thread
#endif
  53. static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
  54. GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
  55. GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
  56. GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
  57. GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
  58. };
  /* Small-matrix kernels are only used when SMALL_MATRIX_OPT is enabled and the
 * build is neither GEMM3M nor extended precision (XDOUBLE). */
#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
#define USE_SMALL_MATRIX_OPT 1
#else
#define USE_SMALL_MATRIX_OPT 0
#endif

#if USE_SMALL_MATRIX_OPT

/* Turn entry idx of a small-kernel table into a callable address.  With
 * DYNAMIC_ARCH the entry is used as a byte offset into the runtime `gotoblas`
 * structure and the pointer stored at that offset is returned; otherwise the
 * entry itself already is the address. */
#ifdef DYNAMIC_ARCH
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx]))))
#else
#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx]))
#endif

#ifdef COMPLEX

/* Complex kernel signatures.  The paired FLOAT arguments presumably carry the
 * real and imaginary parts of alpha (and beta) -- TODO confirm against the
 * kernel sources.  The _b0 form has no beta argument. */
typedef int (*zgemm_batch_small_fn)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT,
                                    FLOAT *, BLASLONG, FLOAT, FLOAT, FLOAT *, BLASLONG);
typedef int (*zgemm_batch_small_b0_fn)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT,
                                       FLOAT *, BLASLONG, FLOAT *, BLASLONG);

/* Indexed by (transb << 2) | transa with 0 = N, 1 = T, 2 = R, 3 = C. */
static size_t zgemm_small_kernel[] = {
  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
  GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
  GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
};

/* Same layout, for the kernels selected when beta is zero. */
static size_t zgemm_small_kernel_b0[] = {
  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
  GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
  GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
};

#define ZGEMM_SMALL_KERNEL(idx)    ((zgemm_batch_small_fn) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx)))
#define ZGEMM_SMALL_KERNEL_B0(idx) ((zgemm_batch_small_b0_fn) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx)))

#else /* !COMPLEX */

/* Real kernel signatures; the _b0 form has no beta argument. */
typedef int (*gemm_batch_small_fn)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT,
                                   IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG);
typedef int (*gemm_batch_small_b0_fn)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT,
                                      IFLOAT *, BLASLONG, FLOAT *, BLASLONG);

/* Indexed by (transb << 2) | transa with 0 = N and 1 = T; the conjugating
 * slots (2, 3, 6, 7) are not meaningful for real data and hold 0. */
static size_t gemm_small_kernel[] = {
  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0,
  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0,
};

/* Same layout, for the kernels selected when beta is zero. */
static size_t gemm_small_kernel_b0[] = {
  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0,
  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0,
};

#define GEMM_SMALL_KERNEL(idx)    ((gemm_batch_small_fn) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx)))
#define GEMM_SMALL_KERNEL_B0(idx) ((gemm_batch_small_b0_fn) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx)))

#endif /* COMPLEX */

#endif /* USE_SMALL_MATRIX_OPT */
  98. void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array,
  99. blasint * m_array, blasint * n_array, blasint * k_array,
  100. #ifndef COMPLEX
  101. FLOAT * alpha_array,
  102. FLOAT ** a_array, blasint * lda_array,
  103. FLOAT ** b_array, blasint * ldb_array,
  104. FLOAT * beta_array,
  105. FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
  106. #else
  107. void * valpha_array,
  108. void ** va_array, blasint * lda_array,
  109. void ** vb_array, blasint * ldb_array,
  110. void * vbeta_array,
  111. void ** vc_array, blasint * ldc_array, blasint group_count, blasint * group_size) {
  112. FLOAT * alpha_array=(FLOAT *)valpha_array;
  113. FLOAT * beta_array=(FLOAT *)vbeta_array;
  114. FLOAT ** a_array=(FLOAT**)va_array;
  115. FLOAT ** b_array=(FLOAT**)vb_array;
  116. FLOAT ** c_array=(FLOAT**)vc_array;
  117. #endif
  118. blas_arg_t * args_array=NULL;
  119. int mode=0, group_mode=0;
  120. blasint total_num=0;
  121. blasint i=0, j=0, matrix_idx=0, count=0;
  122. int group_transa, group_transb;
  123. BLASLONG group_nrowa, group_nrowb;
  124. blasint info;
  125. void * group_alpha, * group_beta;
  126. BLASLONG group_m, group_n, group_k;
  127. BLASLONG group_lda, group_ldb, group_ldc;
  128. void * group_routine=NULL;
  129. #ifdef SMALL_MATRIX_OPT
  130. void * group_small_matrix_opt_routine=NULL;
  131. #endif
  132. #if defined (SMP) || defined(SMALL_MATRIX_OPT)
  133. double MNK;
  134. #endif
  135. PRINT_DEBUG_CNAME;
  136. for(i=0; i<group_count; i++){
  137. total_num+=group_size[i];
  138. }
  139. args_array=(blas_arg_t *)malloc(total_num * sizeof(blas_arg_t));
  140. if(args_array == NULL){
  141. openblas_warning(0, "memory alloc failed!\n");
  142. return;
  143. }
  144. #ifdef SMP
  145. #ifndef COMPLEX
  146. #ifdef XDOUBLE
  147. mode = BLAS_XDOUBLE | BLAS_REAL;
  148. #elif defined(DOUBLE)
  149. mode = BLAS_DOUBLE | BLAS_REAL;
  150. #else
  151. mode = BLAS_SINGLE | BLAS_REAL;
  152. #endif
  153. #else
  154. #ifdef XDOUBLE
  155. mode = BLAS_XDOUBLE | BLAS_COMPLEX;
  156. #elif defined(DOUBLE)
  157. mode = BLAS_DOUBLE | BLAS_COMPLEX;
  158. #else
  159. mode = BLAS_SINGLE | BLAS_COMPLEX;
  160. #endif
  161. #endif
  162. #endif
  163. for(i=0; i<group_count; matrix_idx+=group_size[i], i++){
  164. group_alpha = (void *)&alpha_array[i * COMPSIZE];
  165. group_beta = (void *)&beta_array[i * COMPSIZE];
  166. group_m = group_n = group_k = 0;
  167. group_lda = group_ldb = group_ldc = 0;
  168. group_transa = -1;
  169. group_transb = -1;
  170. info = 0;
  171. if (order == CblasColMajor) {
  172. group_m = m_array[i];
  173. group_n = n_array[i];
  174. group_k = k_array[i];
  175. group_lda = lda_array[i];
  176. group_ldb = ldb_array[i];
  177. group_ldc = ldc_array[i];
  178. if (transa_array[i] == CblasNoTrans) group_transa = 0;
  179. if (transa_array[i] == CblasTrans) group_transa = 1;
  180. #ifndef COMPLEX
  181. if (transa_array[i] == CblasConjNoTrans) group_transa = 0;
  182. if (transa_array[i] == CblasConjTrans) group_transa = 1;
  183. #else
  184. if (transa_array[i] == CblasConjNoTrans) group_transa = 2;
  185. if (transa_array[i] == CblasConjTrans) group_transa = 3;
  186. #endif
  187. if (transb_array[i] == CblasNoTrans) group_transb = 0;
  188. if (transb_array[i] == CblasTrans) group_transb = 1;
  189. #ifndef COMPLEX
  190. if (transb_array[i] == CblasConjNoTrans) group_transb = 0;
  191. if (transb_array[i] == CblasConjTrans) group_transb = 1;
  192. #else
  193. if (transb_array[i] == CblasConjNoTrans) group_transb = 2;
  194. if (transb_array[i] == CblasConjTrans) group_transb = 3;
  195. #endif
  196. group_nrowa = group_m;
  197. if (group_transa & 1) group_nrowa = group_k;
  198. group_nrowb = group_k;
  199. if (group_transb & 1) group_nrowb = group_n;
  200. info=-1;
  201. if (group_ldc < group_m) info = 13;
  202. if (group_ldb < group_nrowb) info = 10;
  203. if (group_lda < group_nrowa) info = 8;
  204. if (group_k < 0) info = 5;
  205. if (group_n < 0) info = 4;
  206. if (group_m < 0) info = 3;
  207. if (group_transb < 0) info = 2;
  208. if (group_transa < 0) info = 1;
  209. }else if (order == CblasRowMajor) {
  210. group_m = n_array[i];
  211. group_n = m_array[i];
  212. group_k = k_array[i];
  213. group_lda = ldb_array[i];
  214. group_ldb = lda_array[i];
  215. group_ldc = ldc_array[i];
  216. if (transb_array[i] == CblasNoTrans) group_transa = 0;
  217. if (transb_array[i] == CblasTrans) group_transa = 1;
  218. #ifndef COMPLEX
  219. if (transb_array[i] == CblasConjNoTrans) group_transa = 0;
  220. if (transb_array[i] == CblasConjTrans) group_transa = 1;
  221. #else
  222. if (transb_array[i] == CblasConjNoTrans) group_transa = 2;
  223. if (transb_array[i] == CblasConjTrans) group_transa = 3;
  224. #endif
  225. if (transa_array[i] == CblasNoTrans) group_transb = 0;
  226. if (transa_array[i] == CblasTrans) group_transb = 1;
  227. #ifndef COMPLEX
  228. if (transa_array[i] == CblasConjNoTrans) group_transb = 0;
  229. if (transa_array[i] == CblasConjTrans) group_transb = 1;
  230. #else
  231. if (transa_array[i] == CblasConjNoTrans) group_transb = 2;
  232. if (transa_array[i] == CblasConjTrans) group_transb = 3;
  233. #endif
  234. group_nrowa = group_m;
  235. if (group_transa & 1) group_nrowa = group_k;
  236. group_nrowb = group_k;
  237. if (group_transb & 1) group_nrowb = group_n;
  238. info=-1;
  239. if (group_ldc < group_m) info = 13;
  240. if (group_ldb < group_nrowb) info = 10;
  241. if (group_lda < group_nrowa) info = 8;
  242. if (group_k < 0) info = 5;
  243. if (group_n < 0) info = 4;
  244. if (group_m < 0) info = 3;
  245. if (group_transb < 0) info = 2;
  246. if (group_transa < 0) info = 1;
  247. }
  248. if (info >= 0) {
  249. BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
  250. free(args_array);
  251. return;
  252. }
  253. if (group_m == 0 || group_n == 0) continue;
  254. group_mode=mode;
  255. #if defined(SMP) || defined(SMALL_MATRIX_OPT)
  256. MNK = (double) group_m * (double) group_n * (double) group_k;
  257. #endif
  258. #ifdef SMALL_MATRIX_OPT
  259. if (MNK <= 100.0*100.0*100.0){
  260. group_routine=NULL;
  261. #if !defined(COMPLEX)
  262. if(*(FLOAT *)(group_beta) == 0.0){
  263. group_mode=mode | BLAS_SMALL_B0_OPT;
  264. group_small_matrix_opt_routine=(void *)(gemm_small_kernel_b0[(group_transb<<2)|group_transa]);
  265. }else{
  266. group_mode=mode | BLAS_SMALL_OPT;
  267. group_small_matrix_opt_routine=(void *)(gemm_small_kernel[(group_transb<<2)|group_transa]);
  268. }
  269. #else
  270. if(((FLOAT *)(group_beta))[0] == 0.0 && ((FLOAT *)(group_beta))[1] == 0.0){
  271. group_mode=mode | BLAS_SMALL_B0_OPT;
  272. group_small_matrix_opt_routine=(void *)(zgemm_small_kernel_b0[(group_transb<<2)|group_transa]);
  273. }else{
  274. group_mode=mode | BLAS_SMALL_OPT;
  275. group_small_matrix_opt_routine=(void *)(zgemm_small_kernel[(group_transb<<2)|group_transa]);
  276. }
  277. #endif
  278. }else{
  279. #endif
  280. group_routine=(void*)(gemm[(group_transb<<2)|group_transa]);
  281. #ifdef SMALL_MATRIX_OPT
  282. }
  283. #endif
  284. for(j=0; j<group_size[i]; j++){
  285. args_array[count].m=group_m;
  286. args_array[count].n=group_n;
  287. args_array[count].k=group_k;
  288. args_array[count].lda=group_lda;
  289. args_array[count].ldb=group_ldb;
  290. args_array[count].ldc=group_ldc;
  291. args_array[count].alpha=group_alpha;
  292. args_array[count].beta=group_beta;
  293. if (order == CblasColMajor) {
  294. args_array[count].a=(a_array[matrix_idx+j]);
  295. args_array[count].b=(b_array[matrix_idx+j]);
  296. }else if(order == CblasRowMajor){
  297. args_array[count].a=(b_array[matrix_idx+j]);
  298. args_array[count].b=(a_array[matrix_idx+j]);
  299. }
  300. args_array[count].c=(c_array[matrix_idx+j]);
  301. args_array[count].routine_mode=group_mode;
  302. args_array[count].routine=group_routine;
  303. #ifdef SMALL_MATRIX_OPT
  304. if (!group_routine)
  305. args_array[count].routine=group_small_matrix_opt_routine;
  306. #endif
  307. count++;
  308. }
  309. }
  310. if(count>0){
  311. GEMM_BATCH_THREAD(args_array,count);
  312. }
  313. free(args_array);
  314. }