You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

chetrf_rook.c 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #include "relapack.h"
  2. #if XSYTRF_ALLOW_MALLOC
  3. #include <stdlib.h>
  4. #endif
  5. static void RELAPACK_chetrf_rook_rec(const char *, const blasint *, const blasint *, blasint *,
  6. float *, const blasint *, blasint *, float *, const blasint *, blasint *);
  7. /** CHETRF_ROOK computes the factorization of a complex Hermitian indefinite matrix using the bounded Bunch-Kaufman ("rook") diagonal pivoting method.
  8. *
  9. * This routine is functionally equivalent to LAPACK's chetrf_rook.
  10. * For details on its interface, see
  11. * http://www.netlib.org/lapack/explore-html/d0/d5e/chetrf__rook_8f.html
  12. * */
  13. void RELAPACK_chetrf_rook(
  14. const char *uplo, const blasint *n,
  15. float *A, const blasint *ldA, blasint *ipiv,
  16. float *Work, const blasint *lWork, blasint *info
  17. ) {
  18. // Required work size
  19. const blasint cleanlWork = *n * (*n / 2);
  20. blasint minlWork = cleanlWork;
  21. #if XSYTRF_ALLOW_MALLOC
  22. minlWork = 1;
  23. #endif
  24. // Check arguments
  25. const blasint lower = LAPACK(lsame)(uplo, "L");
  26. const blasint upper = LAPACK(lsame)(uplo, "U");
  27. *info = 0;
  28. if (!lower && !upper)
  29. *info = -1;
  30. else if (*n < 0)
  31. *info = -2;
  32. else if (*ldA < MAX(1, *n))
  33. *info = -4;
  34. else if ((*lWork < 1 || *lWork < minlWork) && *lWork != -1)
  35. *info = -7;
  36. else if (*lWork == -1) {
  37. // Work size query
  38. *Work = cleanlWork;
  39. return;
  40. }
  41. // Ensure Work size
  42. float *cleanWork = Work;
  43. #if XSYTRF_ALLOW_MALLOC
  44. if (!*info && *lWork < cleanlWork) {
  45. cleanWork = malloc(cleanlWork * 2 * sizeof(float));
  46. if (!cleanWork)
  47. *info = -7;
  48. }
  49. #endif
  50. if (*info) {
  51. const blasint minfo = -*info;
  52. LAPACK(xerbla)("CHETRF_ROOK", &minfo, strlen("CHETRF_ROOK"));
  53. return;
  54. }
  55. // Clean char * arguments
  56. const char cleanuplo = lower ? 'L' : 'U';
  57. // Dummy argument
  58. blasint nout;
  59. // Recursive kernel
  60. RELAPACK_chetrf_rook_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
  61. #if XSYTRF_ALLOW_MALLOC
  62. if (cleanWork != Work)
  63. free(cleanWork);
  64. #endif
  65. }
  66. /** chetrf_rook's recursive compute kernel */
  67. static void RELAPACK_chetrf_rook_rec(
  68. const char *uplo, const blasint *n_full, const blasint *n, blasint *n_out,
  69. float *A, const blasint *ldA, blasint *ipiv,
  70. float *Work, const blasint *ldWork, blasint *info
  71. ) {
  72. // top recursion level?
  73. const blasint top = *n_full == *n;
  74. if (*n <= MAX(CROSSOVER_CHETRF, 3)) {
  75. // Unblocked
  76. if (top) {
  77. LAPACK(chetf2)(uplo, n, A, ldA, ipiv, info);
  78. *n_out = *n;
  79. } else
  80. RELAPACK_chetrf_rook_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
  81. return;
  82. }
  83. blasint info1, info2;
  84. // Constants
  85. const float ONE[] = { 1., 0. };
  86. const float MONE[] = { -1., 0. };
  87. const blasint iONE[] = { 1 };
  88. const blasint n_rest = *n_full - *n;
  89. if (*uplo == 'L') {
  90. // Splitting (setup)
  91. blasint n1 = CREC_SPLIT(*n);
  92. blasint n2 = *n - n1;
  93. // Work_L *
  94. float *const Work_L = Work;
  95. // recursion(A_L)
  96. blasint n1_out;
  97. RELAPACK_chetrf_rook_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
  98. n1 = n1_out;
  99. // Splitting (continued)
  100. n2 = *n - n1;
  101. const blasint n_full2 = *n_full - n1;
  102. // * *
  103. // A_BL A_BR
  104. // A_BL_B A_BR_B
  105. float *const A_BL = A + 2 * n1;
  106. float *const A_BR = A + 2 * *ldA * n1 + 2 * n1;
  107. float *const A_BL_B = A + 2 * *n;
  108. float *const A_BR_B = A + 2 * *ldA * n1 + 2 * *n;
  109. // * *
  110. // Work_BL Work_BR
  111. // * *
  112. // (top recursion level: use Work as Work_BR)
  113. float *const Work_BL = Work + 2 * n1;
  114. float *const Work_BR = top ? Work : Work + 2 * *ldWork * n1 + 2 * n1;
  115. const blasint ldWork_BR = top ? n2 : *ldWork;
  116. // ipiv_T
  117. // ipiv_B
  118. blasint *const ipiv_B = ipiv + n1;
  119. // A_BR = A_BR - A_BL Work_BL'
  120. RELAPACK_cgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
  121. BLAS(cgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
  122. // recursion(A_BR)
  123. blasint n2_out;
  124. RELAPACK_chetrf_rook_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
  125. if (n2_out != n2) {
  126. // undo 1 column of updates
  127. const blasint n_restp1 = n_rest + 1;
  128. // last column of A_BR
  129. float *const A_BR_r = A_BR + 2 * *ldA * n2_out + 2 * n2_out;
  130. // last row of A_BL
  131. float *const A_BL_b = A_BL + 2 * n2_out;
  132. // last row of Work_BL
  133. float *const Work_BL_b = Work_BL + 2 * n2_out;
  134. // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
  135. BLAS(cgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
  136. }
  137. n2 = n2_out;
  138. // shift pivots
  139. blasint i;
  140. for (i = 0; i < n2; i++)
  141. if (ipiv_B[i] > 0)
  142. ipiv_B[i] += n1;
  143. else
  144. ipiv_B[i] -= n1;
  145. *info = info1 || info2;
  146. *n_out = n1 + n2;
  147. } else {
  148. // Splitting (setup)
  149. blasint n2 = CREC_SPLIT(*n);
  150. blasint n1 = *n - n2;
  151. // * Work_R
  152. // (top recursion level: use Work as Work_R)
  153. float *const Work_R = top ? Work : Work + 2 * *ldWork * n1;
  154. // recursion(A_R)
  155. blasint n2_out;
  156. RELAPACK_chetrf_rook_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
  157. const blasint n2_diff = n2 - n2_out;
  158. n2 = n2_out;
  159. // Splitting (continued)
  160. n1 = *n - n2;
  161. const blasint n_full1 = *n_full - n2;
  162. // * A_TL_T A_TR_T
  163. // * A_TL A_TR
  164. // * * *
  165. float *const A_TL_T = A + 2 * *ldA * n_rest;
  166. float *const A_TR_T = A + 2 * *ldA * (n_rest + n1);
  167. float *const A_TL = A + 2 * *ldA * n_rest + 2 * n_rest;
  168. float *const A_TR = A + 2 * *ldA * (n_rest + n1) + 2 * n_rest;
  169. // Work_L *
  170. // * Work_TR
  171. // * *
  172. // (top recursion level: Work_R was Work)
  173. float *const Work_L = Work;
  174. float *const Work_TR = Work + 2 * *ldWork * (top ? n2_diff : n1) + 2 * n_rest;
  175. const blasint ldWork_L = top ? n1 : *ldWork;
  176. // A_TL = A_TL - A_TR Work_TR'
  177. RELAPACK_cgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
  178. BLAS(cgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
  179. // recursion(A_TL)
  180. blasint n1_out;
  181. RELAPACK_chetrf_rook_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
  182. if (n1_out != n1) {
  183. // undo 1 column of updates
  184. const blasint n_restp1 = n_rest + 1;
  185. // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
  186. BLAS(cgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
  187. }
  188. n1 = n1_out;
  189. *info = info2 || info1;
  190. *n_out = n1 + n2;
  191. }
  192. }