| @@ -1,86 +0,0 @@ | |||||
| #!/usr/bin/env python3 | |||||
| # -*- coding: utf-8 -*- | |||||
| """ | |||||
| Created on Wed Dec 19 15:31:01 2018 | |||||
| A script to set the thread number of OpenBLAS (if used). | |||||
| Some modules (such as NumPy, SciPy, sklearn) that use OpenBLAS run their | |||||
| computations in parallel automatically. This conflicts with other | |||||
| parallelization modules such as multiprocessing.Pool and can greatly increase | |||||
| the computing time. By setting the thread count to 1, OpenBLAS is forced to | |||||
| use a single thread/CPU, so this conflict can be avoided. | |||||
| -e.g: | |||||
| with num_threads(8): | |||||
| np.dot(x, y) | |||||
| @author: ali_m | |||||
| @Reference: `ali_m's answer <https://stackoverflow.com/a/29582987>`__, 2018.12 | |||||
| """ | |||||
| import contextlib | |||||
| import ctypes | |||||
| from ctypes.util import find_library | |||||
| import os | |||||
# Candidate locations, in priority order: a hand-compiled OpenBLAS library is
# preferred over the version in /usr/lib/ from the Ubuntu repos.
#
# find_library() returns None when nothing is found. That entry has to be
# dropped: on POSIX, LoadLibrary(None) does not raise OSError but returns a
# handle to the running process, which would be mistaken for a loaded
# OpenBLAS and hide the "not found" condition below.
try_paths = [
    path
    for path in (
        '/opt/OpenBLAS/lib/libopenblas.so',
        '/lib/libopenblas.so',
        '/usr/lib/libopenblas.so.0',
        find_library('openblas'),
    )
    if path is not None
]

# Handle to the first candidate that loads; stays None if none of them do.
openblas_lib = None
for libpath in try_paths:
    try:
        openblas_lib = ctypes.cdll.LoadLibrary(libpath)
    except OSError:
        # Missing or not loadable at this location; try the next candidate.
        continue
    break

if openblas_lib is None:
    # OSError (EnvironmentError is its alias in Python 3) takes
    # (errno, strerror) in that order; 2 is ENOENT.
    raise OSError(2, 'Could not locate an OpenBLAS shared library')
def set_num_threads(n):
    """Tell the OpenBLAS server how many threads to use from now on.

    ``n`` is converted with ``int()`` and handed to the library's
    ``openblas_set_num_threads`` entry point.
    """
    requested = int(n)
    openblas_lib.openblas_set_num_threads(requested)
# At the time of writing these symbols were very new:
# https://github.com/xianyi/OpenBLAS/commit/65a847c
# so the loaded library may not export openblas_get_num_threads. Looking the
# symbol up is enough to find out; there is no need to call it.
if hasattr(openblas_lib, 'openblas_get_num_threads'):
    def get_num_threads():
        """Get the current number of threads used by the OpenBLAS server."""
        return openblas_lib.openblas_get_num_threads()
else:
    def get_num_threads():
        """Fallback for libraries without ``openblas_get_num_threads``.

        The thread count cannot be queried, so this always returns -1.
        """
        return -1
def get_num_procs():
    """Get the number of CPUs usable by the current process.

    Where ``os.sched_getaffinity`` exists, the result is the size of the
    calling process's CPU affinity set. On platforms without it, the
    result is ``os.cpu_count()``.

    Returns:
        The CPU count as an ``int``, or -1 if it cannot be determined.
    """
    if hasattr(os, 'sched_getaffinity'):
        # 0 means "the calling process".
        return len(os.sched_getaffinity(0))
    # os.cpu_count() may return None; keep -1 as the "unknown" value.
    return os.cpu_count() or -1
@contextlib.contextmanager
def num_threads(n):
    """Context manager that runs its body with ``n`` OpenBLAS threads.

    The value reported by ``get_num_threads()`` on entry is passed back to
    ``set_num_threads()`` on exit, whether or not the body raises.

    Example usage:

        print("Before: {}".format(get_num_threads()))
        with num_threads(n):
            print("In thread context: {}".format(get_num_threads()))
        print("After: {}".format(get_num_threads()))
    """
    # NOTE(review): when get_num_threads() is the -1 fallback, -1 is what
    # gets passed to set_num_threads() on exit; confirm that is acceptable.
    previous = get_num_threads()
    set_num_threads(n)
    try:
        yield
    finally:
        set_num_threads(previous)