You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

spotrf.c 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  #include "relapack.h"
#include <stddef.h>
  2. static void RELAPACK_spotrf_rec(const char *, const blasint *, float *,
  3. const blasint *, blasint *);
  4. /** SPOTRF computes the Cholesky factorization of a real symmetric positive definite matrix A.
  5. *
  6. * This routine is functionally equivalent to LAPACK's spotrf.
  7. * For details on its interface, see
  8. * http://www.netlib.org/lapack/explore-html/d0/da2/spotrf_8f.html
  9. * */
  10. void RELAPACK_spotrf(
  11. const char *uplo, const blasint *n,
  12. float *A, const blasint *ldA,
  13. blasint *info
  14. ) {
  15. // Check arguments
  16. const blasint lower = LAPACK(lsame)(uplo, "L");
  17. const blasint upper = LAPACK(lsame)(uplo, "U");
  18. *info = 0;
  19. if (!lower && !upper)
  20. *info = -1;
  21. else if (*n < 0)
  22. *info = -2;
  23. else if (*ldA < MAX(1, *n))
  24. *info = -4;
  25. if (*info) {
  26. const blasint minfo = -*info;
  27. LAPACK(xerbla)("SPOTRF", &minfo, strlen("SPOTRF"));
  28. return;
  29. }
  30. // Clean char * arguments
  31. const char cleanuplo = lower ? 'L' : 'U';
  32. // Recursive kernel
  33. RELAPACK_spotrf_rec(&cleanuplo, n, A, ldA, info);
  34. }
  35. /** spotrf's recursive compute kernel */
  36. static void RELAPACK_spotrf_rec(
  37. const char *uplo, const blasint *n,
  38. float *A, const blasint *ldA,
  39. blasint *info
  40. ) {
  41. if (*n <= MAX(CROSSOVER_SPOTRF, 1)) {
  42. // Unblocked
  43. LAPACK(spotf2)(uplo, n, A, ldA, info);
  44. return;
  45. }
  46. // Constants
  47. const float ONE[] = { 1. };
  48. const float MONE[] = { -1. };
  49. // Splitting
  50. const blasint n1 = SREC_SPLIT(*n);
  51. const blasint n2 = *n - n1;
  52. // A_TL A_TR
  53. // A_BL A_BR
  54. float *const A_TL = A;
  55. float *const A_TR = A + *ldA * n1;
  56. float *const A_BL = A + n1;
  57. float *const A_BR = A + *ldA * n1 + n1;
  58. // recursion(A_TL)
  59. RELAPACK_spotrf_rec(uplo, &n1, A_TL, ldA, info);
  60. if (*info)
  61. return;
  62. if (*uplo == 'L') {
  63. // A_BL = A_BL / A_TL'
  64. BLAS(strsm)("R", "L", "T", "N", &n2, &n1, ONE, A_TL, ldA, A_BL, ldA);
  65. // A_BR = A_BR - A_BL * A_BL'
  66. BLAS(ssyrk)("L", "N", &n2, &n1, MONE, A_BL, ldA, ONE, A_BR, ldA);
  67. } else {
  68. // A_TR = A_TL' \ A_TR
  69. BLAS(strsm)("L", "U", "T", "N", &n1, &n2, ONE, A_TL, ldA, A_TR, ldA);
  70. // A_BR = A_BR - A_TR' * A_TR
  71. BLAS(ssyrk)("U", "T", &n2, &n1, MONE, A_TR, ldA, ONE, A_BR, ldA);
  72. }
  73. // recursion(A_BR)
  74. RELAPACK_spotrf_rec(uplo, &n2, A_BR, ldA, info);
  75. if (*info)
  76. *info += n1;
  77. }