You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

zgemm_kernel_generic.c 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #include "common.h"
  /* For debugging and unit tests: a drop-in replacement for the zgemm/cgemm/ztrmm/ctrmm
   * kernels that supports arbitrary combinations of unroll values (partial tiles at the
   * edges are handled by repeatedly halving the unroll factor until the tile fits).
   */
  /* Triangular kernels whose LEFT and TRANSA settings disagree consume each
   * packed panel starting from a per-tile offset (see the BACKWARDS branch in
   * CNAME, which skips the first `off` k-steps). */
  #if defined(TRMMKERNEL) && (defined(LEFT) != defined(TRANSA))
  #define BACKWARDS
  #endif

  /* Tile dimensions used by CNAME: the double-precision complex (Z) unroll
   * factors when DOUBLE is defined, otherwise the single-precision (C) ones. */
  #if defined(DOUBLE)
  #define UNROLL_M ZGEMM_DEFAULT_UNROLL_M
  #define UNROLL_N ZGEMM_DEFAULT_UNROLL_N
  #else
  #define UNROLL_M CGEMM_DEFAULT_UNROLL_M
  #define UNROLL_N CGEMM_DEFAULT_UNROLL_N
  #endif
  17. int CNAME(BLASLONG M,BLASLONG N,BLASLONG K,FLOAT alphar,FLOAT alphai,FLOAT* A,FLOAT* B,FLOAT* C,BLASLONG ldc
  18. #ifdef TRMMKERNEL
  19. ,BLASLONG offset
  20. #endif
  21. )
  22. {
  23. FLOAT res[UNROLL_M*UNROLL_N*2];
  24. #if defined(NN) || defined(NT) || defined(TN) || defined(TT)
  25. FLOAT sign[4] = { 1, -1, 1, 1};
  26. #endif
  27. #if defined(NR) || defined(NC) || defined(TR) || defined(TC)
  28. FLOAT sign[4] = { 1, 1, 1, -1};
  29. #endif
  30. #if defined(RN) || defined(RT) || defined(CN) || defined(CT)
  31. FLOAT sign[4] = { 1, 1, -1, 1};
  32. #endif
  33. #if defined(RR) || defined(RC) || defined(CR) || defined(CC)
  34. FLOAT sign[4] = { 1, -1, -1, -1};
  35. #endif
  36. BLASLONG n_packing = UNROLL_N;
  37. BLASLONG n_top = 0;
  38. while(n_top < N)
  39. {
  40. while( n_top+n_packing > N )
  41. n_packing >>= 1;
  42. BLASLONG m_packing = UNROLL_M;
  43. BLASLONG m_top = 0;
  44. while (m_top < M)
  45. {
  46. while( m_top+m_packing > M )
  47. m_packing >>= 1;
  48. BLASLONG ai = K*m_top*2;
  49. BLASLONG bi = K*n_top*2;
  50. BLASLONG pass_K = K;
  51. #ifdef TRMMKERNEL
  52. #ifdef LEFT
  53. BLASLONG off = offset + m_top;
  54. #else
  55. BLASLONG off = -offset + n_top;
  56. #endif
  57. #ifdef BACKWARDS
  58. ai += off * m_packing*2;
  59. bi += off * n_packing*2;
  60. pass_K -= off;
  61. #else
  62. #ifdef LEFT
  63. pass_K = off + m_packing;
  64. #else
  65. pass_K = off + n_packing;
  66. #endif
  67. #endif
  68. #endif
  69. memset( res, 0, UNROLL_M*UNROLL_N*2*sizeof(FLOAT) );
  70. for (BLASLONG k=0; k<pass_K; k+=1)
  71. {
  72. for( BLASLONG ki = 0; ki < n_packing; ++ki )
  73. {
  74. FLOAT B0 = B[bi+ki*2+0];
  75. FLOAT B1 = B[bi+ki*2+1];
  76. for( BLASLONG kj = 0; kj < m_packing; ++kj )
  77. {
  78. FLOAT A0 = A[ai+kj*2+0];
  79. FLOAT A1 = A[ai+kj*2+1];
  80. res[(ki*UNROLL_M+kj)*2+0] += sign[0]*A0*B0 +sign[1]*A1*B1;
  81. res[(ki*UNROLL_M+kj)*2+1] += sign[2]*A1*B0 +sign[3]*A0*B1;
  82. }
  83. }
  84. ai += m_packing*2;
  85. bi += n_packing*2;
  86. }
  87. BLASLONG cofs = ldc * n_top + m_top;
  88. for( BLASLONG ki = 0; ki < n_packing; ++ki )
  89. {
  90. for( BLASLONG kj = 0; kj < m_packing; ++kj )
  91. {
  92. #ifdef TRMMKERNEL
  93. FLOAT Cr = 0;
  94. FLOAT Ci = 0;
  95. #else
  96. FLOAT Cr = C[(cofs+ki*ldc+kj)*2+0];
  97. FLOAT Ci = C[(cofs+ki*ldc+kj)*2+1];
  98. #endif
  99. Cr += res[(ki*UNROLL_M+kj)*2+0]*alphar;
  100. Cr += -res[(ki*UNROLL_M+kj)*2+1]*alphai;
  101. Ci += res[(ki*UNROLL_M+kj)*2+1]*alphar;
  102. Ci += res[(ki*UNROLL_M+kj)*2+0]*alphai;
  103. C[(cofs+ki*ldc+kj)*2+0] = Cr;
  104. C[(cofs+ki*ldc+kj)*2+1] = Ci;
  105. }
  106. }
  107. m_top += m_packing;
  108. }
  109. n_top += n_packing;
  110. }
  111. return 0;
  112. }