You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dsygst.c 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. #include "relapack.h"
  2. #if XSYGST_ALLOW_MALLOC
  3. #include "stdlib.h"
  4. #endif
  5. static void RELAPACK_dsygst_rec(const blasint *, const char *, const blasint *,
  6. double *, const blasint *, const double *, const blasint *,
  7. double *, const blasint *, blasint *);
  8. /** DSYGST reduces a real symmetric-definite generalized eigenproblem to standard form.
  9. *
  10. * This routine is functionally equivalent to LAPACK's dsygst.
  11. * For details on its interface, see
  12. * http://www.netlib.org/lapack/explore-html/dc/d04/dsygst_8f.html
  13. * */
  14. void RELAPACK_dsygst(
  15. const blasint *itype, const char *uplo, const blasint *n,
  16. double *A, const blasint *ldA, const double *B, const blasint *ldB,
  17. blasint *info
  18. ) {
  19. // Check arguments
  20. const blasint lower = LAPACK(lsame)(uplo, "L");
  21. const blasint upper = LAPACK(lsame)(uplo, "U");
  22. *info = 0;
  23. if (*itype < 1 || *itype > 3)
  24. *info = -1;
  25. else if (!lower && !upper)
  26. *info = -2;
  27. else if (*n < 0)
  28. *info = -3;
  29. else if (*ldA < MAX(1, *n))
  30. *info = -5;
  31. else if (*ldB < MAX(1, *n))
  32. *info = -7;
  33. if (*info) {
  34. const blasint minfo = -*info;
  35. LAPACK(xerbla)("DSYGST", &minfo, strlen("DSYGST"));
  36. return;
  37. }
  38. // Clean char * arguments
  39. const char cleanuplo = lower ? 'L' : 'U';
  40. // Allocate work space
  41. double *Work = NULL;
  42. blasint lWork = 0;
  43. #if XSYGST_ALLOW_MALLOC
  44. const blasint n1 = DREC_SPLIT(*n);
  45. lWork = abs( n1 * (*n - n1) );
  46. Work = malloc(lWork * sizeof(double));
  47. if (!Work)
  48. lWork = 0;
  49. #endif
  50. // recursive kernel
  51. RELAPACK_dsygst_rec(itype, &cleanuplo, n, A, ldA, B, ldB, Work, &lWork, info);
  52. // Free work space
  53. #if XSYGST_ALLOW_MALLOC
  54. if (Work)
  55. free(Work);
  56. #endif
  57. }
  58. /** dsygst's recursive compute kernel */
  59. static void RELAPACK_dsygst_rec(
  60. const blasint *itype, const char *uplo, const blasint *n,
  61. double *A, const blasint *ldA, const double *B, const blasint *ldB,
  62. double *Work, const blasint *lWork, blasint *info
  63. ) {
  64. if (*n <= MAX(CROSSOVER_SSYGST, 1)) {
  65. // Unblocked
  66. LAPACK(dsygs2)(itype, uplo, n, A, ldA, B, ldB, info);
  67. return;
  68. }
  69. // Constants
  70. const double ZERO[] = { 0. };
  71. const double ONE[] = { 1. };
  72. const double MONE[] = { -1. };
  73. const double HALF[] = { .5 };
  74. const double MHALF[] = { -.5 };
  75. const blasint iONE[] = { 1 };
  76. // Loop iterator
  77. blasint i;
  78. // Splitting
  79. const blasint n1 = DREC_SPLIT(*n);
  80. const blasint n2 = *n - n1;
  81. // A_TL A_TR
  82. // A_BL A_BR
  83. double *const A_TL = A;
  84. double *const A_TR = A + *ldA * n1;
  85. double *const A_BL = A + n1;
  86. double *const A_BR = A + *ldA * n1 + n1;
  87. // B_TL B_TR
  88. // B_BL B_BR
  89. const double *const B_TL = B;
  90. const double *const B_TR = B + *ldB * n1;
  91. const double *const B_BL = B + n1;
  92. const double *const B_BR = B + *ldB * n1 + n1;
  93. // recursion(A_TL, B_TL)
  94. RELAPACK_dsygst_rec(itype, uplo, &n1, A_TL, ldA, B_TL, ldB, Work, lWork, info);
  95. if (*itype == 1)
  96. if (*uplo == 'L') {
  97. // A_BL = A_BL / B_TL'
  98. BLAS(dtrsm)("R", "L", "T", "N", &n2, &n1, ONE, B_TL, ldB, A_BL, ldA);
  99. if (*lWork > n2 * n1) {
  100. // T = -1/2 * B_BL * A_TL
  101. BLAS(dsymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ZERO, Work, &n2);
  102. // A_BL = A_BL + T
  103. for (i = 0; i < n1; i++)
  104. BLAS(daxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
  105. } else
  106. // A_BL = A_BL - 1/2 B_BL * A_TL
  107. BLAS(dsymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ONE, A_BL, ldA);
  108. // A_BR = A_BR - A_BL * B_BL' - B_BL * A_BL'
  109. BLAS(dsyr2k)("L", "N", &n2, &n1, MONE, A_BL, ldA, B_BL, ldB, ONE, A_BR, ldA);
  110. if (*lWork > n2 * n1)
  111. // A_BL = A_BL + T
  112. for (i = 0; i < n1; i++)
  113. BLAS(daxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
  114. else
  115. // A_BL = A_BL - 1/2 B_BL * A_TL
  116. BLAS(dsymm)("R", "L", &n2, &n1, MHALF, A_TL, ldA, B_BL, ldB, ONE, A_BL, ldA);
  117. // A_BL = B_BR \ A_BL
  118. BLAS(dtrsm)("L", "L", "N", "N", &n2, &n1, ONE, B_BR, ldB, A_BL, ldA);
  119. } else {
  120. // A_TR = B_TL' \ A_TR
  121. BLAS(dtrsm)("L", "U", "T", "N", &n1, &n2, ONE, B_TL, ldB, A_TR, ldA);
  122. if (*lWork > n2 * n1) {
  123. // T = -1/2 * A_TL * B_TR
  124. BLAS(dsymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ZERO, Work, &n1);
  125. // A_TR = A_BL + T
  126. for (i = 0; i < n2; i++)
  127. BLAS(daxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
  128. } else
  129. // A_TR = A_TR - 1/2 A_TL * B_TR
  130. BLAS(dsymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ONE, A_TR, ldA);
  131. // A_BR = A_BR - A_TR' * B_TR - B_TR' * A_TR
  132. BLAS(dsyr2k)("U", "T", &n2, &n1, MONE, A_TR, ldA, B_TR, ldB, ONE, A_BR, ldA);
  133. if (*lWork > n2 * n1)
  134. // A_TR = A_BL + T
  135. for (i = 0; i < n2; i++)
  136. BLAS(daxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
  137. else
  138. // A_TR = A_TR - 1/2 A_TL * B_TR
  139. BLAS(dsymm)("L", "U", &n1, &n2, MHALF, A_TL, ldA, B_TR, ldB, ONE, A_TR, ldA);
  140. // A_TR = A_TR / B_BR
  141. BLAS(dtrsm)("R", "U", "N", "N", &n1, &n2, ONE, B_BR, ldB, A_TR, ldA);
  142. }
  143. else
  144. if (*uplo == 'L') {
  145. // A_BL = A_BL * B_TL
  146. BLAS(dtrmm)("R", "L", "N", "N", &n2, &n1, ONE, B_TL, ldB, A_BL, ldA);
  147. if (*lWork > n2 * n1) {
  148. // T = 1/2 * A_BR * B_BL
  149. BLAS(dsymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ZERO, Work, &n2);
  150. // A_BL = A_BL + T
  151. for (i = 0; i < n1; i++)
  152. BLAS(daxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
  153. } else
  154. // A_BL = A_BL + 1/2 A_BR * B_BL
  155. BLAS(dsymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ONE, A_BL, ldA);
  156. // A_TL = A_TL + A_BL' * B_BL + B_BL' * A_BL
  157. BLAS(dsyr2k)("L", "T", &n1, &n2, ONE, A_BL, ldA, B_BL, ldB, ONE, A_TL, ldA);
  158. if (*lWork > n2 * n1)
  159. // A_BL = A_BL + T
  160. for (i = 0; i < n1; i++)
  161. BLAS(daxpy)(&n2, ONE, Work + n2 * i, iONE, A_BL + *ldA * i, iONE);
  162. else
  163. // A_BL = A_BL + 1/2 A_BR * B_BL
  164. BLAS(dsymm)("L", "L", &n2, &n1, HALF, A_BR, ldA, B_BL, ldB, ONE, A_BL, ldA);
  165. // A_BL = B_BR * A_BL
  166. BLAS(dtrmm)("L", "L", "T", "N", &n2, &n1, ONE, B_BR, ldB, A_BL, ldA);
  167. } else {
  168. // A_TR = B_TL * A_TR
  169. BLAS(dtrmm)("L", "U", "N", "N", &n1, &n2, ONE, B_TL, ldB, A_TR, ldA);
  170. if (*lWork > n2 * n1) {
  171. // T = 1/2 * B_TR * A_BR
  172. BLAS(dsymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ZERO, Work, &n1);
  173. // A_TR = A_TR + T
  174. for (i = 0; i < n2; i++)
  175. BLAS(daxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
  176. } else
  177. // A_TR = A_TR + 1/2 B_TR A_BR
  178. BLAS(dsymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ONE, A_TR, ldA);
  179. // A_TL = A_TL + A_TR * B_TR' + B_TR * A_TR'
  180. BLAS(dsyr2k)("U", "N", &n1, &n2, ONE, A_TR, ldA, B_TR, ldB, ONE, A_TL, ldA);
  181. if (*lWork > n2 * n1)
  182. // A_TR = A_TR + T
  183. for (i = 0; i < n2; i++)
  184. BLAS(daxpy)(&n1, ONE, Work + n1 * i, iONE, A_TR + *ldA * i, iONE);
  185. else
  186. // A_TR = A_TR + 1/2 B_TR * A_BR
  187. BLAS(dsymm)("R", "U", &n1, &n2, HALF, A_BR, ldA, B_TR, ldB, ONE, A_TR, ldA);
  188. // A_TR = A_TR * B_BR
  189. BLAS(dtrmm)("R", "U", "T", "N", &n1, &n2, ONE, B_BR, ldB, A_TR, ldA);
  190. }
  191. // recursion(A_BR, B_BR)
  192. RELAPACK_dsygst_rec(itype, uplo, &n2, A_BR, ldA, B_BR, ldB, Work, lWork, info);
  193. }