You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fix_checkpoint_file.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Fix weight names in CheckPoint file when user edits converted MindSpore scripts."""
  16. import os
  17. import re
  18. import argparse
  19. import sys
  20. import ast
  21. from ast import NodeTransformer, ClassDef, Assign, FunctionDef
  22. from importlib import import_module
  23. from mindspore import load_checkpoint, Tensor, save_checkpoint, load_param_into_net
  24. from mindinsight.mindconverter.common.log import logger_console
  25. class FixCheckPointGenerator(NodeTransformer):
  26. """Fix weight names in CheckPoint file."""
  27. def __init__(self, source_script_path, target_script_path):
  28. self._source_script_name = os.path.basename(source_script_path)
  29. self._target_script_name = os.path.basename(target_script_path)
  30. self._source_variable_mapper = self._extract(source_script_path)
  31. self._target_variable_mapper = self._extract(target_script_path)
  32. self._fixed_mapper = self._generator()
  33. def _extract(self, script_path):
  34. """Extract info from AST Tree."""
  35. with open(script_path, 'r') as rf:
  36. tree = ast.parse(rf.read())
  37. variable_mapper = dict()
  38. for block in tree.body:
  39. if not isinstance(block, ClassDef):
  40. continue
  41. module_name = block.name
  42. valid_body = [block_pick for block_pick in block.body if isinstance(block_pick, FunctionDef)]
  43. init_body = valid_body[0]
  44. variable_name = self._extract_init(init_body)
  45. if not variable_mapper.get(module_name):
  46. variable_mapper[module_name] = variable_name
  47. else:
  48. variable_mapper[module_name].extend(variable_name)
  49. return variable_mapper
  50. @staticmethod
  51. def _extract_init(init_body):
  52. """Extract init information."""
  53. variable_names = list()
  54. for block in init_body.body:
  55. if not isinstance(block, Assign):
  56. continue
  57. variable_name = block.targets[0].attr
  58. variable_names.append(variable_name)
  59. return variable_names
  60. def _check_data(self, data_1, data_2):
  61. """Check the shape of two inputs."""
  62. if len(data_1) != len(data_2):
  63. logger_console.error(
  64. f"The construct of {self._source_script_name} and that of {self._target_script_name} ars not matched.")
  65. exit(0)
  66. def _generator(self):
  67. """Generator fixed_mapper."""
  68. main_module_name = list(self._target_variable_mapper.keys())[-1].lower()
  69. fixed_variable_mapper = dict()
  70. fixed_module_mapper = dict()
  71. self._check_data(self._source_variable_mapper, self._target_variable_mapper)
  72. for source_module_name, target_module_name in zip(self._source_variable_mapper, self._target_variable_mapper):
  73. fixed_variable_mapper[target_module_name.lower()] = self._fixed_variable_mapper_generator(
  74. self._source_variable_mapper[source_module_name], self._target_variable_mapper[target_module_name])
  75. if source_module_name != target_module_name:
  76. fixed_module_mapper[source_module_name.lower()] = target_module_name.lower()
  77. return {
  78. 'main_module_name': main_module_name,
  79. 'fixed_variable_mapper': fixed_variable_mapper,
  80. 'fixed_module_mapper': fixed_module_mapper
  81. }
  82. def _fixed_variable_mapper_generator(self, source_variable_names, target_variable_names):
  83. """Generate fixed_variable_mapper."""
  84. self._check_data(source_variable_names, target_variable_names)
  85. fixed_variable_mapper = dict()
  86. for source_variable_name, target_variable_name in zip(source_variable_names, target_variable_names):
  87. if source_variable_name != target_variable_name:
  88. fixed_variable_mapper[source_variable_name] = target_variable_name
  89. return fixed_variable_mapper
  90. def fix_ckpt(self, ckpt_path, new_ckpt_path):
  91. """Fix checkpoint file."""
  92. param_dict = load_checkpoint(ckpt_path)
  93. main_module_name = self._fixed_mapper['main_module_name']
  94. fixed_variable_dict = self._fixed_mapper['fixed_variable_mapper']
  95. fixed_module_dict = self._fixed_mapper['fixed_module_mapper']
  96. save_obj = list()
  97. for weight_name, weight_value in param_dict.items():
  98. weight_name_scopes = weight_name.split('.')
  99. weight_name_scopes.insert(0, main_module_name)
  100. for idx, w in enumerate(weight_name_scopes[:-1]):
  101. for fixed_variable_module, fixed_variable_name_mapper in fixed_variable_dict.items():
  102. if re.match(fixed_variable_module, fixed_module_dict.get('_'.join(w.split('_')[:-1]), w)):
  103. weight_name = weight_name.replace(
  104. weight_name_scopes[idx + 1],
  105. fixed_variable_name_mapper.get(weight_name_scopes[idx + 1], weight_name_scopes[idx + 1]))
  106. obj = {
  107. 'name': weight_name,
  108. 'data': Tensor(weight_value)
  109. }
  110. save_obj.append(obj)
  111. save_checkpoint(save_obj, new_ckpt_path)
  112. logger_console.info(f'Saved new checkpoint file to {new_ckpt_path}.')
  113. def source_checker(py_path, ckpt_path):
  114. """Check source model script and source checkpoint file."""
  115. sys.path.append(os.path.dirname(py_path))
  116. model = getattr(import_module(os.path.basename(py_path).replace('.py', '')), 'Model')()
  117. param_dict = load_checkpoint(ckpt_path)
  118. not_load_name = load_param_into_net(model, param_dict)
  119. return not bool(not_load_name)
  120. def file_existed_checker(parser_in, in_file, action_type):
  121. """Check file exists or not."""
  122. out_file = os.path.realpath(in_file)
  123. if not os.path.exists(out_file):
  124. if action_type == 'in':
  125. parser_in.error(f"{out_file} does NOT exist, check it.")
  126. elif not os.path.exists(os.path.dirname(out_file)):
  127. os.makedirs(os.path.dirname(out_file))
  128. return out_file
  129. def file_validation_checker(parser_in, in_file, expected_type):
  130. """Check file is valid or not."""
  131. if not in_file.endswith(expected_type):
  132. parser_in.error(f"'xxx{expected_type}' is expected, but gotten {os.path.basename(in_file)}.")
  133. class ScriptAction(argparse.Action):
  134. """Script action class definition."""
  135. def __call__(self, parser_in, namespace, values, option_string=None):
  136. """
  137. Inherited __call__ method from argparse.Action.
  138. Args:
  139. parser_in (ArgumentParser): Passed-in argument parser.
  140. namespace (Namespace): Namespace object to hold arguments.
  141. values (str): Argument values.
  142. option_string (str): Optional string for specific argument name. Default: None.
  143. """
  144. out_file = file_existed_checker(parser_in, values, "in")
  145. file_validation_checker(parser_in, out_file, ".py")
  146. setattr(namespace, self.dest, out_file)
  147. class InCheckPointAction(argparse.Action):
  148. """In CheckPoint action class definition."""
  149. def __call__(self, parser_in, namespace, values, option_string=None):
  150. """
  151. Inherited __call__ method from argparse.Action.
  152. Args:
  153. parser_in (ArgumentParser): Passed-in argument parser.
  154. namespace (Namespace): Namespace object to hold arguments.
  155. values (str): Argument values.
  156. option_string (str): Optional string for specific argument name. Default: None.
  157. """
  158. out_file = file_existed_checker(parser_in, values, "in")
  159. file_validation_checker(parser_in, out_file, ".ckpt")
  160. setattr(namespace, self.dest, out_file)
  161. class OutCheckPointAction(argparse.Action):
  162. """Out CheckPoint action class definition."""
  163. def __call__(self, parser_in, namespace, values, option_string=None):
  164. """
  165. Inherited __call__ method from argparse.Action.
  166. Args:
  167. parser_in (ArgumentParser): Passed-in argument parser.
  168. namespace (Namespace): Namespace object to hold arguments.
  169. values (str): Argument values.
  170. option_string (str): Optional string for specific argument name. Default: None.
  171. """
  172. out_file = file_existed_checker(parser_in, values, "out")
  173. file_validation_checker(parser_in, out_file, ".ckpt")
  174. setattr(namespace, self.dest, out_file)
  175. parser = argparse.ArgumentParser(description="Fix weight names in CheckPoint file.")
  176. parser.add_argument(
  177. "source_py_file",
  178. action=ScriptAction,
  179. help="source model script file")
  180. parser.add_argument(
  181. "fixed_py_file",
  182. action=ScriptAction,
  183. help="fixed model script file")
  184. parser.add_argument(
  185. "source_ckpt_file",
  186. action=InCheckPointAction,
  187. help="source checkpoint file")
  188. parser.add_argument(
  189. "fixed_ckpt_file",
  190. action=OutCheckPointAction,
  191. help="fixed checkpoint file")
  192. if __name__ == '__main__':
  193. argv = sys.argv[1:]
  194. if not argv:
  195. argv = ['-h']
  196. args = parser.parse_args(argv)
  197. else:
  198. args = parser.parse_args()
  199. source_py_file = args.source_py_file
  200. fixed_py_file = args.fixed_py_file
  201. source_ckpt_file = args.source_ckpt_file
  202. fixed_ckpt_file = args.fixed_ckpt_file
  203. if not source_checker(source_py_file, source_ckpt_file):
  204. logger_console.error("source checkpoint file is not inconsistent with source model script.")
  205. exit(0)
  206. fix_checkpoint_generator = FixCheckPointGenerator(source_py_file, fixed_py_file)
  207. fix_checkpoint_generator.fix_ckpt(source_ckpt_file, fixed_ckpt_file)