You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bench_blas.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import pytest
  2. import numpy as np
  3. from openblas_wrap import (
  4. # level 1
  5. dnrm2, ddot, daxpy,
  6. # level 3
  7. dgemm, dsyrk,
  8. # lapack
  9. dgesv, # linalg.solve
  10. dgesdd, dgesdd_lwork, # linalg.svd
  11. dsyev, dsyev_lwork, # linalg.eigh
  12. )
  13. # ### BLAS level 1 ###
  14. # dnrm2
  15. dnrm2_sizes = [100, 1000]
  16. def run_dnrm2(n, x, incx):
  17. res = dnrm2(x, n, incx=incx)
  18. return res
  19. @pytest.mark.parametrize('n', dnrm2_sizes)
  20. def test_nrm2(benchmark, n):
  21. rndm = np.random.RandomState(1234)
  22. x = np.array(rndm.uniform(size=(n,)), dtype=float)
  23. result = benchmark(run_dnrm2, n, x, 1)
  24. # ddot
  25. ddot_sizes = [100, 1000]
  26. def run_ddot(x, y,):
  27. res = ddot(x, y)
  28. return res
  29. @pytest.mark.parametrize('n', ddot_sizes)
  30. def test_dot(benchmark, n):
  31. rndm = np.random.RandomState(1234)
  32. x = np.array(rndm.uniform(size=(n,)), dtype=float)
  33. y = np.array(rndm.uniform(size=(n,)), dtype=float)
  34. result = benchmark(run_ddot, x, y)
  35. # daxpy
  36. daxpy_sizes = [100, 1000]
  37. def run_daxpy(x, y,):
  38. res = daxpy(x, y, a=2.0)
  39. return res
  40. @pytest.mark.parametrize('n', daxpy_sizes)
  41. def test_daxpy(benchmark, n):
  42. rndm = np.random.RandomState(1234)
  43. x = np.array(rndm.uniform(size=(n,)), dtype=float)
  44. y = np.array(rndm.uniform(size=(n,)), dtype=float)
  45. result = benchmark(run_daxpy, x, y)
  46. # ### BLAS level 3 ###
  47. # dgemm
  48. gemm_sizes = [100, 1000]
  49. def run_gemm(a, b, c):
  50. alpha = 1.0
  51. res = dgemm(alpha, a, b, c=c, overwrite_c=True)
  52. return res
  53. @pytest.mark.parametrize('n', gemm_sizes)
  54. def test_gemm(benchmark, n):
  55. rndm = np.random.RandomState(1234)
  56. a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
  57. b = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
  58. c = np.empty((n, n), dtype=float, order='F')
  59. result = benchmark(run_gemm, a, b, c)
  60. assert result is c
  61. # dsyrk
  62. syrk_sizes = [100, 1000]
  63. def run_syrk(a, c):
  64. res = dsyrk(1.0, a, c=c, overwrite_c=True)
  65. return res
  66. @pytest.mark.parametrize('n', syrk_sizes)
  67. def test_syrk(benchmark, n):
  68. rndm = np.random.RandomState(1234)
  69. a = np.array(rndm.uniform(size=(n, n)), dtype=float, order='F')
  70. c = np.empty((n, n), dtype=float, order='F')
  71. result = benchmark(run_syrk, a, c)
  72. assert result is c
  73. # ### LAPACK ###
  74. # linalg.solve
  75. gesv_sizes = [100, 1000]
  76. def run_gesv(a, b):
  77. res = dgesv(a, b, overwrite_a=True, overwrite_b=True)
  78. return res
  79. @pytest.mark.parametrize('n', gesv_sizes)
  80. def test_gesv(benchmark, n):
  81. rndm = np.random.RandomState(1234)
  82. a = (np.array(rndm.uniform(size=(n, n)), dtype=float, order='F') +
  83. np.eye(n, order='F'))
  84. b = np.array(rndm.uniform(size=(n, 1)), order='F')
  85. lu, piv, x, info = benchmark(run_gesv, a, b)
  86. assert lu is a
  87. assert x is b
  88. assert info == 0
  89. # linalg.svd
  90. gesdd_sizes = [(100, 5), (1000, 222)]
  91. def run_gesdd(a, lwork):
  92. res = dgesdd(a, lwork=lwork, full_matrices=False, overwrite_a=False)
  93. return res
  94. @pytest.mark.parametrize('mn', gesdd_sizes)
  95. def test_gesdd(benchmark, mn):
  96. m, n = mn
  97. rndm = np.random.RandomState(1234)
  98. a = np.array(rndm.uniform(size=(m, n)), dtype=float, order='F')
  99. lwork, info = dgesdd_lwork(m, n)
  100. lwork = int(lwork)
  101. assert info == 0
  102. u, s, vt, info = benchmark(run_gesdd, a, lwork)
  103. assert info == 0
  104. np.testing.assert_allclose(u @ np.diag(s) @ vt, a, atol=1e-13)
  105. # linalg.eigh
  106. syev_sizes = [50, 200]
  107. def run_syev(a, lwork):
  108. res = dsyev(a, lwork=lwork, overwrite_a=True)
  109. return res
  110. @pytest.mark.parametrize('n', syev_sizes)
  111. def test_syev(benchmark, n):
  112. rndm = np.random.RandomState(1234)
  113. a = rndm.uniform(size=(n, n))
  114. a = np.asarray(a + a.T, dtype=float, order='F')
  115. a_ = a.copy()
  116. lwork, info = dsyev_lwork(n)
  117. lwork = int(lwork)
  118. assert info == 0
  119. w, v, info = benchmark(run_syev, a, lwork)
  120. assert info == 0
  121. assert a is v # overwrite_a=True