
config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for Transformer."""
import os
import json
import copy
from typing import List

import mindspore.common.dtype as mstype


def _is_dataset_file(file: str):
    """Judge whether a file name refers to a tfrecord/mindrecord dataset file."""
    return "tfrecord" in file.lower() or "mindrecord" in file.lower()


def _get_files_from_dir(folder: str):
    """Collect all dataset files under a folder."""
    _files = []
    for file in os.listdir(folder):
        if _is_dataset_file(file):
            _files.append(os.path.join(folder, file))
    return _files


def get_source_list(folder: str) -> List:
    """
    Get a dataset file list from a folder or a single file path.

    Args:
        folder (str): Folder containing dataset files, or a single dataset file.

    Returns:
        list, file list.
    """
    _list = []
    if not folder:
        return _list

    if os.path.isdir(folder):
        _list = _get_files_from_dir(folder)
    else:
        if _is_dataset_file(folder):
            _list.append(folder)
    return _list
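
# Illustrative behaviour (hypothetical paths): for a directory, get_source_list
# returns every *.tfrecord / *.mindrecord file inside it, e.g.
#     get_source_list("/data/train") -> ["/data/train/part0.mindrecord", ...]
# and for a single dataset file it returns that path wrapped in a list.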

PARAM_NODES = {"dataset_config",
               "model_config",
               "loss_scale_config",
               "learn_rate_config",
               "checkpoint_options"}


class TransformerConfig:
    """
    Configuration for `Transformer`.

    Args:
        random_seed (int): Random seed.
        batch_size (int): Batch size of input dataset.
        epochs (int): Epoch number.
        dataset_sink_mode (bool): Whether to enable dataset sink mode.
        dataset_sink_step (int): Dataset sink step.
        lr_scheduler (str): Learning-rate scheduler to use; only "ISR" is supported now.
        optimizer (str): Optimizer to use. Default: "adam".
        lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.
        decay_steps (int): Number of learning-rate decay steps.
        poly_lr_scheduler_power (float): Power of the polynomial learning-rate scheduler.
        decay_start_step (int): Step at which learning-rate decay starts.
        warmup_steps (int): Warm-up steps.
        pre_train_dataset (str): Path of pre-training dataset file or folder.
        fine_tune_dataset (str): Path of fine-tune dataset file or folder.
        test_dataset (str): Path of test dataset file or folder.
        valid_dataset (str): Path of validation dataset file or folder.
        ckpt_path (str): Checkpoint save path.
        save_ckpt_steps (int): Interval of saving checkpoints.
        ckpt_prefix (str): Prefix of checkpoint files.
        existed_ckpt (str): Path of an existing checkpoint, if any.
        keep_ckpt_max (int): Maximum number of checkpoint files to keep.
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the vocabulary. Default: 46192.
        hidden_size (int): Size of the embedding and attention hidden dimension. Default: 512.
        num_hidden_layers (int): Number of hidden layers in the Transformer
            encoder/decoder cell. Default: 6.
        num_attention_heads (int): Number of attention heads. Default: 8.
        intermediate_size (int): Size of intermediate layer in the Transformer
            encoder/decoder cell. Default: 4096.
        hidden_act (str): Activation function used in the Transformer encoder/decoder
            cell. Default: "relu".
        init_loss_scale (int): Initial loss scale.
        loss_scale_factor (int): Loss scale factor.
        scale_window (int): Window size of loss scale.
        beam_width (int): Beam width for beam search in inference. Default: 5.
        length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
        label_smoothing (float): Label smoothing setting. Default: 0.1.
        input_mask_from_dataset (bool): Specifies whether to use the input mask loaded
            from the dataset. Default: True.
        save_graphs (bool): Whether to save computation graphs; set to True if
            MindInsight is wanted.
        dtype (mstype): Data type of the input. Default: mstype.float32.
        max_decode_length (int): Maximum decode length for inference. Default: 64.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        attention_dropout_prob (float): The dropout probability for
            multi-head self-attention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 64.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 random_seed=74,
                 batch_size=64, epochs=1,
                 dataset_sink_mode=True, dataset_sink_step=1,
                 lr_scheduler="", optimizer="adam",
                 lr=1e-4, min_lr=1e-6,
                 decay_steps=10000, poly_lr_scheduler_power=1,
                 decay_start_step=-1, warmup_steps=2000,
                 pre_train_dataset: str = None,
                 fine_tune_dataset: str = None,
                 test_dataset: str = None,
                 valid_dataset: str = None,
                 ckpt_path: str = None,
                 save_ckpt_steps=2000,
                 ckpt_prefix="CKPT",
                 existed_ckpt="",
                 keep_ckpt_max=20,
                 seq_length=128,
                 vocab_size=46192,
                 hidden_size=512,
                 num_hidden_layers=6,
                 num_attention_heads=8,
                 intermediate_size=4096,
                 hidden_act="relu",
                 hidden_dropout_prob=0.1,
                 attention_dropout_prob=0.1,
                 max_position_embeddings=64,
                 initializer_range=0.02,
                 init_loss_scale=2 ** 10,
                 loss_scale_factor=2, scale_window=2000,
                 beam_width=5,
                 length_penalty_weight=1.0,
                 label_smoothing=0.1,
                 input_mask_from_dataset=True,
                 save_graphs=False,
                 dtype=mstype.float32,
                 max_decode_length=64):
        self.save_graphs = save_graphs
        self.random_seed = random_seed
        self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
        self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
        self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
        self.test_dataset = get_source_list(test_dataset)  # type: List[str]
        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("`epochs` must be a non-negative int.")
        self.epochs = epochs
        self.dataset_sink_mode = dataset_sink_mode
        self.dataset_sink_step = dataset_sink_step
        self.ckpt_path = ckpt_path
        self.keep_ckpt_max = keep_ckpt_max
        self.save_ckpt_steps = save_ckpt_steps
        self.ckpt_prefix = ckpt_prefix
        self.existed_ckpt = existed_ckpt

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing

        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length

        self.input_mask_from_dataset = input_mask_from_dataset
        self.compute_type = mstype.float16
        self.dtype = dtype

        self.scale_window = scale_window
        self.loss_scale_factor = loss_scale_factor
        self.init_loss_scale = init_loss_scale

        self.optimizer = optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.poly_lr_scheduler_power = poly_lr_scheduler_power
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.warmup_steps = warmup_steps

        self.train_url = ""

    @classmethod
    def from_dict(cls, json_object: dict):
        """Constructs a `TransformerConfig` from a Python dictionary of parameters."""
        _params = {}
        for node in PARAM_NODES:
            for key in json_object[node]:
                _params[key] = json_object[node][key]
        return cls(**_params)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `TransformerConfig` from a json file of parameters."""
        with open(json_file, "r") as reader:
            return cls.from_dict(json.load(reader))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
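
A minimal usage sketch, assuming MindSpore is installed; the JSON path in the comment below is hypothetical, and such a file's top-level keys must match PARAM_NODES:

if __name__ == "__main__":
    # Every parameter has a default, so a config can be built directly from
    # keyword arguments ...
    config = TransformerConfig(batch_size=32, seq_length=64)
    print(config.batch_size, config.hidden_size)
    # ... or loaded from a grouped JSON file whose nested keys are valid
    # TransformerConfig.__init__ keyword arguments (path is hypothetical):
    # config = TransformerConfig.from_json_file("transformer_config.json")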