You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cgemmt.c 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #include "relapack.h"
  2. static void RELAPACK_cgemmt_rec(const char *, const char *, const char *,
  3. const blasint *, const blasint *, const float *, const float *, const blasint *,
  4. const float *, const blasint *, const float *, float *, const blasint *);
  5. static void RELAPACK_cgemmt_rec2(const char *, const char *, const char *,
  6. const blasint *, const blasint *, const float *, const float *, const blasint *,
  7. const float *, const blasint *, const float *, float *, const blasint *);
  8. /** CGEMMT computes a matrix-matrix product with general matrices but updates
  9. * only the upper or lower triangular part of the result matrix.
  10. *
  11. * This routine performs the same operation as the BLAS routine
  12. * cgemm(transA, transB, n, n, k, alpha, A, ldA, B, ldB, beta, C, ldC)
  13. * but only updates the triangular part of C specified by uplo:
  14. * If (*uplo == 'L'), only the lower triangular part of C is updated,
  15. * otherwise the upper triangular part is updated.
  16. * */
  17. void RELAPACK_cgemmt(
  18. const char *uplo, const char *transA, const char *transB,
  19. const blasint *n, const blasint *k,
  20. const float *alpha, const float *A, const blasint *ldA,
  21. const float *B, const blasint *ldB,
  22. const float *beta, float *C, const blasint *ldC
  23. ) {
  24. #if HAVE_XGEMMT
  25. BLAS(cgemmt)(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
  26. return;
  27. #else
  28. // Check arguments
  29. const blasint lower = LAPACK(lsame)(uplo, "L");
  30. const blasint upper = LAPACK(lsame)(uplo, "U");
  31. const blasint notransA = LAPACK(lsame)(transA, "N");
  32. const blasint tranA = LAPACK(lsame)(transA, "T");
  33. const blasint ctransA = LAPACK(lsame)(transA, "C");
  34. const blasint notransB = LAPACK(lsame)(transB, "N");
  35. const blasint tranB = LAPACK(lsame)(transB, "T");
  36. const blasint ctransB = LAPACK(lsame)(transB, "C");
  37. blasint info = 0;
  38. if (!lower && !upper)
  39. info = 1;
  40. else if (!tranA && !ctransA && !notransA)
  41. info = 2;
  42. else if (!tranB && !ctransB && !notransB)
  43. info = 3;
  44. else if (*n < 0)
  45. info = 4;
  46. else if (*k < 0)
  47. info = 5;
  48. else if (*ldA < MAX(1, notransA ? *n : *k))
  49. info = 8;
  50. else if (*ldB < MAX(1, notransB ? *k : *n))
  51. info = 10;
  52. else if (*ldC < MAX(1, *n))
  53. info = 13;
  54. if (info) {
  55. LAPACK(xerbla)("CGEMMT", &info, strlen("CGEMMT"));
  56. return;
  57. }
  58. // Clean char * arguments
  59. const char cleanuplo = lower ? 'L' : 'U';
  60. const char cleantransA = notransA ? 'N' : (tranA ? 'T' : 'C');
  61. const char cleantransB = notransB ? 'N' : (tranB ? 'T' : 'C');
  62. // Recursive kernel
  63. RELAPACK_cgemmt_rec(&cleanuplo, &cleantransA, &cleantransB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
  64. #endif
  65. }
  66. /** cgemmt's recursive compute kernel */
  67. static void RELAPACK_cgemmt_rec(
  68. const char *uplo, const char *transA, const char *transB,
  69. const blasint *n, const blasint *k,
  70. const float *alpha, const float *A, const blasint *ldA,
  71. const float *B, const blasint *ldB,
  72. const float *beta, float *C, const blasint *ldC
  73. ) {
  74. if (*n <= MAX(CROSSOVER_CGEMMT, 1)) {
  75. // Unblocked
  76. RELAPACK_cgemmt_rec2(uplo, transA, transB, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);
  77. return;
  78. }
  79. // Splitting
  80. const blasint n1 = CREC_SPLIT(*n);
  81. const blasint n2 = *n - n1;
  82. // A_T
  83. // A_B
  84. const float *const A_T = A;
  85. const float *const A_B = A + 2 * ((*transA == 'N') ? n1 : *ldA * n1);
  86. // B_L B_R
  87. const float *const B_L = B;
  88. const float *const B_R = B + 2 * ((*transB == 'N') ? *ldB * n1 : n1);
  89. // C_TL C_TR
  90. // C_BL C_BR
  91. float *const C_TL = C;
  92. float *const C_TR = C + 2 * *ldC * n1;
  93. float *const C_BL = C + 2 * n1;
  94. float *const C_BR = C + 2 * *ldC * n1 + 2 * n1;
  95. // recursion(C_TL)
  96. RELAPACK_cgemmt_rec(uplo, transA, transB, &n1, k, alpha, A_T, ldA, B_L, ldB, beta, C_TL, ldC);
  97. if (*uplo == 'L')
  98. // C_BL = alpha A_B B_L + beta C_BL
  99. BLAS(cgemm)(transA, transB, &n2, &n1, k, alpha, A_B, ldA, B_L, ldB, beta, C_BL, ldC);
  100. else
  101. // C_TR = alpha A_T B_R + beta C_TR
  102. BLAS(cgemm)(transA, transB, &n1, &n2, k, alpha, A_T, ldA, B_R, ldB, beta, C_TR, ldC);
  103. // recursion(C_BR)
  104. RELAPACK_cgemmt_rec(uplo, transA, transB, &n2, k, alpha, A_B, ldA, B_R, ldB, beta, C_BR, ldC);
  105. }
  106. /** cgemmt's unblocked compute kernel */
  107. static void RELAPACK_cgemmt_rec2(
  108. const char *uplo, const char *transA, const char *transB,
  109. const blasint *n, const blasint *k,
  110. const float *alpha, const float *A, const blasint *ldA,
  111. const float *B, const blasint *ldB,
  112. const float *beta, float *C, const blasint *ldC
  113. ) {
  114. const blasint incB = (*transB == 'N') ? 1 : *ldB;
  115. const blasint incC = 1;
  116. blasint i;
  117. for (i = 0; i < *n; i++) {
  118. // A_0
  119. // A_i
  120. const float *const A_0 = A;
  121. const float *const A_i = A + 2 * ((*transA == 'N') ? i : *ldA * i);
  122. // * B_i *
  123. const float *const B_i = B + 2 * ((*transB == 'N') ? *ldB * i : i);
  124. // * C_0i *
  125. // * C_ii *
  126. float *const C_0i = C + 2 * *ldC * i;
  127. float *const C_ii = C + 2 * *ldC * i + 2 * i;
  128. if (*uplo == 'L') {
  129. const blasint nmi = *n - i;
  130. if (*transA == 'N')
  131. BLAS(cgemv)(transA, &nmi, k, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
  132. else
  133. BLAS(cgemv)(transA, k, &nmi, alpha, A_i, ldA, B_i, &incB, beta, C_ii, &incC);
  134. } else {
  135. const blasint ip1 = i + 1;
  136. if (*transA == 'N')
  137. BLAS(cgemv)(transA, &ip1, k, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
  138. else
  139. BLAS(cgemv)(transA, k, &ip1, alpha, A_0, ldA, B_i, &incB, beta, C_0i, &incC);
  140. }
  141. }
  142. }