You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

load_backend.py 2.3 kB

4 years ago
4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import json
  4. import os
  5. import sys
  6. BACKEND = 'tensorflow'
  7. # BACKEND = 'mindspore'
  8. # BACKEND = 'paddle'
  9. # Check for backend.json files
  10. tl_backend_dir = os.path.expanduser('~')
  11. if not os.access(tl_backend_dir, os.W_OK):
  12. tl_backend_dir = '/tmp'
  13. tl_dir = os.path.join(tl_backend_dir, '.tl')
  14. config = {
  15. 'backend': BACKEND,
  16. }
  17. if not os.path.exists(tl_dir):
  18. path = os.path.join(tl_dir, 'tl_backend.json')
  19. os.makedirs(tl_dir)
  20. with open(path, "w") as f:
  21. json.dump(config, f)
  22. BACKEND = config['backend']
  23. sys.stderr.write("Create the backend configuration file :" + path + '\n')
  24. else:
  25. path = os.path.join(tl_dir, 'tl_backend.json')
  26. with open(path, 'r') as load_f:
  27. load_dict = json.load(load_f)
  28. if load_dict['backend'] is not config['backend']:
  29. BACKEND = config['backend']
  30. else:
  31. BACKEND = load_dict['backend']
  32. # Set backend based on TL_BACKEND.
  33. if 'TL_BACKEND' in os.environ:
  34. backend = os.environ['TL_BACKEND']
  35. if backend:
  36. BACKEND = backend
  37. # import backend functions
  38. if BACKEND == 'tensorflow':
  39. from .tensorflow_backend import *
  40. from .tensorflow_nn import *
  41. import tensorflow as tf
  42. BACKEND_VERSION = tf.__version__
  43. sys.stderr.write('Using TensorFlow backend.\n')
  44. elif BACKEND == 'mindspore':
  45. from .mindspore_backend import *
  46. from .mindspore_nn import *
  47. import mindspore as ms
  48. BACKEND_VERSION = ms.__version__
  49. # set context
  50. import mindspore.context as context
  51. import os
  52. os.environ['DEVICE_ID'] = '0'
  53. context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU'),
  54. # context.set_context(mode=context.GRAPH_MODE, device_target='CPU'),
  55. # enable_task_sink=True, enable_loop_sink=True)
  56. # context.set_context(mode=context.GRAPH_MODE, backend_policy='ms',
  57. # device_target='Ascend', enable_task_sink=True, enable_loop_sink=True)
  58. sys.stderr.write('Using MindSpore backend.\n')
  59. elif BACKEND == 'paddle':
  60. from .paddle_backend import *
  61. from .paddle_nn import *
  62. import paddle as pd
  63. BACKEND_VERSION = pd.__version__
  64. sys.stderr.write('Using Paddle backend.\n')
  65. else:
  66. raise NotImplementedError("This backend is not supported")

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。