
common.py 7.4 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
#! /usr/bin/python
# -*- coding: utf-8 -*-

import os

import tensorlayer as tl
from tensorlayer.files import utils
from tensorlayer import logging

_act_dict = {
    "relu": tl.ops.ReLU,
    "relu6": tl.ops.ReLU6,
    "leaky_relu": tl.ops.LeakyReLU,
    "lrelu": tl.ops.LeakyReLU,
    "softplus": tl.ops.Softplus,
    "tanh": tl.ops.Tanh,
    "sigmoid": tl.ops.Sigmoid,
    "softmax": tl.ops.Softmax
}


def str2act(act):
    """Convert an activation name such as 'relu' or 'lrelu0.2' into the corresponding tl.ops activation."""
    if len(act) > 5 and act[0:5] == "lrelu":
        try:
            alpha = float(act[5:])
            return tl.ops.LeakyReLU(alpha=alpha)
        except Exception:
            raise Exception("{} cannot be parsed as a float".format(act[5:]))

    if len(act) > 10 and act[0:10] == "leaky_relu":
        try:
            alpha = float(act[10:])
            return tl.ops.LeakyReLU(alpha=alpha)
        except Exception:
            raise Exception("{} cannot be parsed as a float".format(act[10:]))

    if act not in _act_dict.keys():
        raise Exception("Unsupported act: {}".format(act))
    return _act_dict[act]
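
# Example usage of str2act: a plain name returns the activation class from
# _act_dict, while a name with a trailing float returns a configured LeakyReLU.
# >>> str2act("relu")        # -> tl.ops.ReLU
# >>> str2act("lrelu0.2")    # -> tl.ops.LeakyReLU(alpha=0.2)
# >>> str2act("lreluabc")    # -> raises Exception: "abc cannot be parsed as a float"
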

def _save_weights(net, file_path, format=None):
    """Save the model weights into a file at `file_path` in the given format.
    Use net.load_weights() to restore.

    Parameters
    ----------
    file_path : str
        Filename to which the model weights will be saved.
    format : str or None
        Saved file format.
        Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not supported yet.
        1) If this is set to None, the postfix of file_path will be used to decide the saved format.
        If the postfix is not in ['h5', 'hdf5', 'npz', 'ckpt'], the file will be saved in hdf5 format by default.
        2) 'hdf5' will save the model weight names in a list, and each layer has its weights stored in a group of
        the hdf5 file.
        3) 'npz' will save the model weights sequentially into an npz file.
        4) 'npz_dict' will save the model weights along with their names as a dict into an npz file.
        5) 'ckpt' will save the model weights into a TensorFlow ckpt file.
        Default None.

    Examples
    --------
    1) Save model weights in hdf5 format by default.
    >>> net = vgg16()
    >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
    >>> metric = tl.metric.Accuracy()
    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
    >>> model.save_weights('./model.h5')
    ...
    >>> model.load_weights('./model.h5')

    2) Save model weights in npz/npz_dict format.
    >>> model.save_weights('./model.npz')
    >>> model.save_weights('./model.npz', format='npz_dict')
    """
    if net.all_weights is None or len(net.all_weights) == 0:
        logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
        return

    if format is None:
        postfix = file_path.split('.')[-1]
        if postfix in ['h5', 'hdf5', 'npz', 'ckpt']:
            format = postfix
        else:
            format = 'hdf5'

    if format == 'hdf5' or format == 'h5':
        raise NotImplementedError("hdf5 load/save is not supported now.")
        # utils.save_weights_to_hdf5(file_path, net)
    elif format == 'npz':
        utils.save_npz(net.all_weights, file_path)
    elif format == 'npz_dict':
        utils.save_npz_dict(net.all_weights, file_path)
    elif format == 'ckpt':
        # TODO: enable this when tf save ckpt is enabled
        raise NotImplementedError("ckpt load/save is not supported now.")
    else:
        raise ValueError(
            "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
            "Other formats are not supported now."
        )

def _load_weights(net, file_path, format=None, in_order=True, skip=False):
    """Load model weights from a given file, which should have been saved previously by net.save_weights().

    Parameters
    ----------
    file_path : str
        Filename from which the model weights will be loaded.
    format : str or None
        If not specified (None), the postfix of the file_path will be used to decide its format. If specified,
        the value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not supported yet.
        In addition, it should be the same format that was used when saving the file with net.save_weights().
        Default is None.
    in_order : bool
        Allow loading weights into the model sequentially or by name. Only useful when 'format' is 'hdf5'.
        If 'in_order' is True, weights from the file will be loaded into the model sequentially.
        If 'in_order' is False, weights from the file will be loaded into the model by matching names
        with the weights of the model, which is particularly useful when trying to restore a model in eager (graph)
        mode from a weights file saved in graph (eager) mode.
        Default is True.
    skip : bool
        Allow skipping weights whose names are mismatched between the file and the model. Only useful when 'format'
        is 'hdf5' or 'npz_dict'. If 'skip' is True, the 'in_order' argument will be ignored and loaded weights
        whose names are not found in the model weights (net.all_weights) will be skipped. If 'skip' is False, an
        error will occur when a mismatch is found.
        Default is False.

    Examples
    --------
    1) Load model weights from a hdf5 file.
    >>> net = vgg16()
    >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
    >>> metric = tl.metric.Accuracy()
    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
    >>> model.load_weights('./model_graph.h5', in_order=False, skip=True)  # load weights by name, skipping mismatches
    >>> model.load_weights('./model_eager.h5')  # load sequentially

    2) Load model weights from a npz file.
    >>> model.load_weights('./model.npz')

    3) Load model weights from a npz file, which was previously saved as npz_dict.
    >>> model.load_weights('./model.npz', format='npz_dict')

    Notes
    -----
    1) 'in_order' is only useful when 'format' is 'hdf5'. If you are trying to load a weights file which was
    saved in a different mode, it is recommended to set 'in_order' to True.
    2) 'skip' is useful when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True,
    the 'in_order' argument will be ignored.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError("file {} doesn't exist.".format(file_path))

    if format is None:
        format = file_path.split('.')[-1]

    if format == 'hdf5' or format == 'h5':
        raise NotImplementedError("hdf5 load/save is not supported now.")
        # if skip is True or in_order is False:
        #     # load by weights name
        #     utils.load_hdf5_to_weights(file_path, net, skip)
        # else:
        #     # load in order
        #     utils.load_hdf5_to_weights_in_order(file_path, net)
    elif format == 'npz':
        utils.load_and_assign_npz(file_path, net)
    elif format == 'npz_dict':
        utils.load_and_assign_npz_dict(file_path, net, skip)
    elif format == 'ckpt':
        # TODO: enable this when tf save ckpt is enabled
        raise NotImplementedError("ckpt load/save is not supported now.")
    else:
        raise ValueError(
            "File format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
            "Other formats are not supported now."
        )
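
For the formats that are currently enabled, the round trip through these helpers looks like the sketch below. It is a minimal illustration, assuming `net` is a built TensorLayer network with a populated `net.all_weights` list; the hdf5 and ckpt branches above currently raise NotImplementedError.

    _save_weights(net, './model.npz')                     # postfix 'npz' selects the npz branch
    _save_weights(net, './model.npz', format='npz_dict')  # save a name -> weight mapping instead
    _load_weights(net, './model.npz')                     # restore weights sequentially from npz
    _load_weights(net, './model.npz', format='npz_dict', skip=True)  # restore by name, skipping mismatches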

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. Compatibility with TensorFlow, PyTorch, MindSpore, and Paddle is planned.
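
As a sketch of how a backend is chosen (the exact mechanism and supported values depend on the installed release; the TL_BACKEND environment variable shown here is an assumption based on TensorLayer 3.x conventions), the backend is selected before the library is imported:

    import os

    # Assumed mechanism: pick the computational backend via TL_BACKEND before
    # importing tensorlayer ('tensorflow' here; 'mindspore', 'paddle' and
    # 'torch' are the other planned backends).
    os.environ['TL_BACKEND'] = 'tensorflow'

    import tensorlayer as tl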