You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bench_blas.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import pytest
  2. import numpy as np
  3. import openblas_wrap as ow
  4. dtype_map = {
  5. 's': np.float32,
  6. 'd': np.float64,
  7. 'c': np.complex64,
  8. 'z': np.complex128,
  9. 'dz': np.complex128,
  10. }
  11. # ### BLAS level 1 ###
  12. # dnrm2
  13. dnrm2_sizes = [100, 1000]
  14. def run_dnrm2(n, x, incx, func):
  15. res = func(x, n, incx=incx)
  16. return res
  17. @pytest.mark.parametrize('variant', ['d', 'dz'])
  18. @pytest.mark.parametrize('n', dnrm2_sizes)
  19. def test_nrm2(benchmark, n, variant):
  20. rndm = np.random.RandomState(1234)
  21. dtyp = dtype_map[variant]
  22. x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  23. nrm2 = ow.get_func('nrm2', variant)
  24. result = benchmark(run_dnrm2, n, x, 1, nrm2)
  25. # ddot
  26. ddot_sizes = [100, 1000]
  27. def run_ddot(x, y, func):
  28. res = func(x, y)
  29. return res
  30. @pytest.mark.parametrize('n', ddot_sizes)
  31. def test_dot(benchmark, n):
  32. rndm = np.random.RandomState(1234)
  33. x = np.array(rndm.uniform(size=(n,)), dtype=float)
  34. y = np.array(rndm.uniform(size=(n,)), dtype=float)
  35. dot = ow.get_func('dot', 'd')
  36. result = benchmark(run_ddot, x, y, dot)
  37. # daxpy
  38. daxpy_sizes = [100, 1000]
  39. def run_daxpy(x, y, func):
  40. res = func(x, y, a=2.0)
  41. return res
  42. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  43. @pytest.mark.parametrize('n', daxpy_sizes)
  44. def test_daxpy(benchmark, n, variant):
  45. rndm = np.random.RandomState(1234)
  46. dtyp = dtype_map[variant]
  47. x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  48. y = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  49. axpy = ow.get_func('axpy', variant)
  50. result = benchmark(run_daxpy, x, y, axpy)
  51. # ### BLAS level 2 ###
  52. gemv_sizes = [100, 1000]
  53. def run_gemv(a, x, y, func):
  54. res = func(1.0, a, x, y=y, overwrite_y=True)
  55. return res
  56. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  57. @pytest.mark.parametrize('n', gemv_sizes)
  58. def test_dgemv(benchmark, n, variant):
  59. rndm = np.random.RandomState(1234)
  60. dtyp = dtype_map[variant]
  61. x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  62. y = np.empty(n, dtype=dtyp)
  63. a = np.array(rndm.uniform(size=(n,n)), dtype=dtyp)
  64. x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  65. y = np.zeros(n, dtype=dtyp)
  66. gemv = ow.get_func('gemv', variant)
  67. result = benchmark(run_gemv, a, x, y, gemv)
  68. assert result is y
  69. # dgbmv
  70. dgbmv_sizes = [100, 1000]
  71. def run_gbmv(m, n, kl, ku, a, x, y, func):
  72. res = func(m, n, kl, ku, 1.0, a, x, y=y, overwrite_y=True)
  73. return res
  74. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  75. @pytest.mark.parametrize('n', dgbmv_sizes)
  76. @pytest.mark.parametrize('kl', [1])
  77. def test_dgbmv(benchmark, n, kl, variant):
  78. rndm = np.random.RandomState(1234)
  79. dtyp = dtype_map[variant]
  80. x = np.array(rndm.uniform(size=(n,)), dtype=dtyp)
  81. y = np.empty(n, dtype=dtyp)
  82. m = n
  83. a = rndm.uniform(size=(2*kl + 1, n))
  84. a = np.array(a, dtype=dtyp, order='F')
  85. gbmv = ow.get_func('gbmv', variant)
  86. result = benchmark(run_gbmv, m, n, kl, kl, a, x, y, gbmv)
  87. assert result is y
  88. # ### BLAS level 3 ###
  89. # dgemm
  90. gemm_sizes = [100, 1000]
  91. def run_gemm(a, b, c, func):
  92. alpha = 1.0
  93. res = func(alpha, a, b, c=c, overwrite_c=True)
  94. return res
  95. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  96. @pytest.mark.parametrize('n', gemm_sizes)
  97. def test_gemm(benchmark, n, variant):
  98. rndm = np.random.RandomState(1234)
  99. dtyp = dtype_map[variant]
  100. a = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
  101. b = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
  102. c = np.empty((n, n), dtype=dtyp, order='F')
  103. gemm = ow.get_func('gemm', variant)
  104. result = benchmark(run_gemm, a, b, c, gemm)
  105. assert result is c
  106. # dsyrk
  107. syrk_sizes = [100, 1000]
  108. def run_syrk(a, c, func):
  109. res = func(1.0, a, c=c, overwrite_c=True)
  110. return res
  111. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  112. @pytest.mark.parametrize('n', syrk_sizes)
  113. def test_syrk(benchmark, n, variant):
  114. rndm = np.random.RandomState(1234)
  115. dtyp = dtype_map[variant]
  116. a = np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F')
  117. c = np.empty((n, n), dtype=dtyp, order='F')
  118. syrk = ow.get_func('syrk', variant)
  119. result = benchmark(run_syrk, a, c, syrk)
  120. assert result is c
  121. # ### LAPACK ###
  122. # linalg.solve
  123. gesv_sizes = [100, 1000]
  124. def run_gesv(a, b, func):
  125. res = func(a, b, overwrite_a=True, overwrite_b=True)
  126. return res
  127. @pytest.mark.parametrize('variant', ['s', 'd', 'c', 'z'])
  128. @pytest.mark.parametrize('n', gesv_sizes)
  129. def test_gesv(benchmark, n, variant):
  130. rndm = np.random.RandomState(1234)
  131. dtyp = dtype_map[variant]
  132. a = (np.array(rndm.uniform(size=(n, n)), dtype=dtyp, order='F') +
  133. np.eye(n, dtype=dtyp, order='F'))
  134. b = np.array(rndm.uniform(size=(n, 1)), dtype=dtyp, order='F')
  135. gesv = ow.get_func('gesv', variant)
  136. lu, piv, x, info = benchmark(run_gesv, a, b, gesv)
  137. assert lu is a
  138. assert x is b
  139. assert info == 0
  140. # linalg.svd
  141. gesdd_sizes = [(100, 5), (1000, 222)]
  142. def run_gesdd(a, lwork, func):
  143. res = func(a, lwork=lwork, full_matrices=False, overwrite_a=False)
  144. return res
  145. @pytest.mark.parametrize('variant', ['s', 'd'])
  146. @pytest.mark.parametrize('mn', gesdd_sizes)
  147. def test_gesdd(benchmark, mn, variant):
  148. m, n = mn
  149. rndm = np.random.RandomState(1234)
  150. dtyp = dtype_map[variant]
  151. a = np.array(rndm.uniform(size=(m, n)), dtype=dtyp, order='F')
  152. gesdd_lwork = ow.get_func('gesdd_lwork', variant)
  153. lwork, info = gesdd_lwork(m, n)
  154. lwork = int(lwork)
  155. assert info == 0
  156. gesdd = ow.get_func('gesdd', variant)
  157. u, s, vt, info = benchmark(run_gesdd, a, lwork, gesdd)
  158. assert info == 0
  159. atol = {'s': 1e-5, 'd': 1e-13}
  160. np.testing.assert_allclose(u @ np.diag(s) @ vt, a, atol=atol[variant])
  161. # linalg.eigh
  162. syev_sizes = [50, 200]
  163. def run_syev(a, lwork, func):
  164. res = func(a, lwork=lwork, overwrite_a=True)
  165. return res
  166. @pytest.mark.parametrize('variant', ['s', 'd'])
  167. @pytest.mark.parametrize('n', syev_sizes)
  168. def test_syev(benchmark, n, variant):
  169. rndm = np.random.RandomState(1234)
  170. dtyp = dtype_map[variant]
  171. a = rndm.uniform(size=(n, n))
  172. a = np.asarray(a + a.T, dtype=dtyp, order='F')
  173. a_ = a.copy()
  174. dsyev_lwork = ow.get_func('syev_lwork', variant)
  175. lwork, info = dsyev_lwork(n)
  176. lwork = int(lwork)
  177. assert info == 0
  178. syev = ow.get_func('syev', variant)
  179. w, v, info = benchmark(run_syev, a, lwork, syev)
  180. assert info == 0
  181. assert a is v # overwrite_a=True