Browse Source

add tool to fix weight name in ckpt file

pull/1272/head
moran 4 years ago
parent
commit
1fca5c95c3
3 changed files with 405 additions and 0 deletions
  1. +67
    -0
      mindinsight/mindconverter/tools/README.md
  2. +67
    -0
      mindinsight/mindconverter/tools/README_CN.md
  3. +271
    -0
      mindinsight/mindconverter/tools/fix_checkpoint_file.py

+ 67
- 0
mindinsight/mindconverter/tools/README.md View File

@@ -0,0 +1,67 @@
# Tool Tutorial

[查看中文](./README_CN.md)

<!-- TOC -->

- [Tool Tutorial](#Tool-Tutorial)
- [Fix CheckPoint file Tool](#Fix-CheckPoint-file-Tool)
- [Overview](#Overview)
- [Usage](#Usage)
- [Example](#Example)
- [Limitation](#Limitation)

<!-- /TOC -->

## Fix CheckPoint file Tool

### Overview

Requirements

```text
MindSpore>=1.2.0
```

User may need to change module name or variable name in order to improve the readability of converted MindSpore script. However, these changes will result in failing to load checkpoint because of unmatched weight name. The [Fix CheckPoint file Tool](./fix_checkpoint_file.py) is used to fix weight names in checkpoint file.

### Usage

```bash
usage: fix_checkpoint_file.py [-h]
source_py_file fixed_py_file source_ckpt_file
fixed_ckpt_file

Fix weight name in CheckPoint file.

positional arguments:
source_py_file source model script file
fixed_py_file fixed model script file
source_ckpt_file source_checkpoint file
fixed_ckpt_file fixed_checkpoint file

optional arguments:
-h, --help show this hekp message and exit
```

### Example

Assuming that the source model script file is `xxx/model.py`, the fixed model script file is `xxx/fixed_model.py`, the source checkpoint file is `xxx/model.ckpt` and the new checkpoint file is `xxx/fixed_model.ckpt`.

The command is that:

```bash
python -m mindinsight.mindconverter.tools.fix_checkpoit_file xxx/model.py xxx/fixed_model.py xxx/model.ckpt xxx/fixed_model.ckpt
```

If generation is successful, the result below would be shown:

```text
Saved new checkpoint file to xxx/fixed_model.ckpt.
```

### Limitation

1. Only MindSpore Script and CheckPoint file generated by MindConverter using graph-based conversion (using `--model_file`) are supported in the tool.
2. The situation that only variable name or class name in scripts has changed are supported, while the one that model structure or script structure (add or delete operator) has changed is unsupported.
3. MindSpore is required by the tool, so make sure MindSpore installed correctly.

+ 67
- 0
mindinsight/mindconverter/tools/README_CN.md View File

@@ -0,0 +1,67 @@
# 工具使用教程

[Switch to English version](./README.md)

<!-- TOC -->

- [工具使用教程](#工具使用教程)
- [权重名修正工具](#权重名修正工具)
- [概述](#概述)
- [使用方法](#使用方法)
- [使用示例](#使用示例)
- [约束限制](#约束限制)

<!-- /TOC -->

## 权重名修正工具

### 概述

所需的依赖

```text
MindSpore>=1.2
```

对于由MindConverter转换生成的MindSpore脚本,用户可能会依照自己的需求对脚本中的类名、变量名进行修改,以增加网络脚本的可读性。但是这些修改有可能会导致加载权重信息时,因为无法找到对应的权重名称,而加载失败。该[权重名修正工具](./fix_checkpoint_file.py)用于将脚本中修改后的内容统一修改到权重文件中,生成可以用于新脚本的权重文件。

### 使用方法

```bash
usage: fix_checkpoint_file.py [-h]
source_py_file fixed_py_file source_ckpt_file
fixed_ckpt_file

Fix weight name in CheckPoint file.

positional arguments:
source_py_file source model script file
fixed_py_file fixed model script file
source_ckpt_file source_checkpoint file
fixed_ckpt_file fixed_checkpoint file

optional arguments:
-h, --help show this hekp message and exit
```

### 使用示例

假设原始网络脚本为`xxx/model.py`,修改后的网络脚本为`xxx/fixed_model.py`,原始权重文件为`xxx/model.ckpt`,生成的新权重文件为`xxx/fixed_model.ckpt`。

则运行命令为:

```bash
python -m mindinsight.mindconverter.tools.fix_checkpoit_file xxx/model.py xxx/fixed_model.py xxx/model.ckpt xxx/fixed_model.ckpt
```

如果显示结果如下,则说明转换完成:

```text
Saved new checkpoint file to xxx/fixed_model.ckpt.
```

### 约束限制

1. 该工具仅适用于:在MindConverter的图模式下(通过--model_file迁移)迁移生成的MindSpore网络脚本和权重文件。
2. 该工具仅适用于修改变量名,类名的场景中,不适用于修改网络结构、脚本结构(新增或删除算子定义)的应用场景。
3. 该工具依赖MindSpore,需要确保正确安装MindSpore。

+ 271
- 0
mindinsight/mindconverter/tools/fix_checkpoint_file.py View File

@@ -0,0 +1,271 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fix weight names in CheckPoint file when user edits converted MindSpore scripts."""
import os
import re
import argparse
import sys
import ast
from ast import NodeTransformer, ClassDef, Assign, FunctionDef
from importlib import import_module

from mindspore import load_checkpoint, Tensor, save_checkpoint, load_param_into_net
from mindinsight.mindconverter.common.log import logger_console


class FixCheckPointGenerator(NodeTransformer):
"""Fix weight names in CheckPoint file."""

def __init__(self, source_script_path, target_script_path):
self._source_script_name = os.path.basename(source_script_path)
self._target_script_name = os.path.basename(target_script_path)

self._source_variable_mapper = self._extract(source_script_path)
self._target_variable_mapper = self._extract(target_script_path)

self._fixed_mapper = self._generator()

def _extract(self, script_path):
"""Extract info from AST Tree."""

with open(script_path, 'r') as rf:
tree = ast.parse(rf.read())

variable_mapper = dict()
for block in tree.body:
if not isinstance(block, ClassDef):
continue
module_name = block.name
valid_body = [block_pick for block_pick in block.body if isinstance(block_pick, FunctionDef)]
init_body = valid_body[0]

variable_name = self._extract_init(init_body)

if not variable_mapper.get(module_name):
variable_mapper[module_name] = variable_name
else:
variable_mapper[module_name].extend(variable_name)
return variable_mapper

@staticmethod
def _extract_init(init_body):
"""Extract init information."""

variable_names = list()
for block in init_body.body:
if not isinstance(block, Assign):
continue
variable_name = block.targets[0].attr
variable_names.append(variable_name)
return variable_names

def _check_data(self, data_1, data_2):
"""Check the shape of two inputs."""

if len(data_1) != len(data_2):
logger_console.error(
f"The construct of {self._source_script_name} and that of {self._target_script_name} ars not matched.")
exit(0)

def _generator(self):
"""Generator fixed_mapper."""

main_module_name = list(self._target_variable_mapper.keys())[-1].lower()
fixed_variable_mapper = dict()
fixed_module_mapper = dict()

self._check_data(self._source_variable_mapper, self._target_variable_mapper)

for source_module_name, target_module_name in zip(self._source_variable_mapper, self._target_variable_mapper):
fixed_variable_mapper[target_module_name.lower()] = self._fixed_variable_mapper_generator(
self._source_variable_mapper[source_module_name], self._target_variable_mapper[target_module_name])

if source_module_name != target_module_name:
fixed_module_mapper[source_module_name.lower()] = target_module_name.lower()

return {
'main_module_name': main_module_name,
'fixed_variable_mapper': fixed_variable_mapper,
'fixed_module_mapper': fixed_module_mapper
}

def _fixed_variable_mapper_generator(self, source_variable_names, target_variable_names):
"""Generate fixed_variable_mapper."""

self._check_data(source_variable_names, target_variable_names)

fixed_variable_mapper = dict()
for source_variable_name, target_variable_name in zip(source_variable_names, target_variable_names):
if source_variable_name != target_variable_name:
fixed_variable_mapper[source_variable_name] = target_variable_name
return fixed_variable_mapper

def fix_ckpt(self, ckpt_path, new_ckpt_path):
"""Fix checkpoint file."""

param_dict = load_checkpoint(ckpt_path)

main_module_name = self._fixed_mapper['main_module_name']
fixed_variable_dict = self._fixed_mapper['fixed_variable_mapper']
fixed_module_dict = self._fixed_mapper['fixed_module_mapper']

save_obj = list()
for weight_name, weight_value in param_dict.items():
weight_name_scopes = weight_name.split('.')
weight_name_scopes.insert(0, main_module_name)
for idx, w in enumerate(weight_name_scopes[:-1]):
for fixed_variable_module, fixed_variable_name_mapper in fixed_variable_dict.items():
if re.match(fixed_variable_module, fixed_module_dict.get('_'.join(w.split('_')[:-1]), w)):
weight_name = weight_name.replace(
weight_name_scopes[idx + 1],
fixed_variable_name_mapper.get(weight_name_scopes[idx + 1], weight_name_scopes[idx + 1]))

obj = {
'name': weight_name,
'data': Tensor(weight_value)
}
save_obj.append(obj)

save_checkpoint(save_obj, new_ckpt_path)
logger_console.info(f'Saved new checkpoint file to {new_ckpt_path}.')


def source_checker(py_path, ckpt_path):
"""Check source model script and source checkpoint file."""

sys.path.append(os.path.dirname(py_path))
model = getattr(import_module(os.path.basename(py_path).replace('.py', '')), 'Model')()
param_dict = load_checkpoint(ckpt_path)
not_load_name = load_param_into_net(model, param_dict)
return not bool(not_load_name)


def file_existed_checker(parser_in, in_file, action_type):
"""Check file exists or not."""

out_file = os.path.realpath(in_file)
if not os.path.exists(out_file):
if action_type == 'in':
parser_in.error(f"{out_file} does NOT exist, check it.")
elif not os.path.exists(os.path.dirname(out_file)):
os.makedirs(os.path.dirname(out_file))
return out_file


def file_validation_checker(parser_in, in_file, expected_type):
"""Check file is valid or not."""

if not in_file.endswith(expected_type):
parser_in.error(f"'xxx{expected_type}' is expected, but gotten {os.path.basename(in_file)}.")


class ScriptAction(argparse.Action):
"""Script action class definition."""

def __call__(self, parser_in, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (str): Argument values.
option_string (str): Optional string for specific argument name. Default: None.
"""

out_file = file_existed_checker(parser_in, values, "in")
file_validation_checker(parser_in, out_file, ".py")

setattr(namespace, self.dest, out_file)


class InCheckPointAction(argparse.Action):
"""In CheckPoint action class definition."""

def __call__(self, parser_in, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (str): Argument values.
option_string (str): Optional string for specific argument name. Default: None.
"""

out_file = file_existed_checker(parser_in, values, "in")
file_validation_checker(parser_in, out_file, ".ckpt")

setattr(namespace, self.dest, out_file)


class OutCheckPointAction(argparse.Action):
"""Out CheckPoint action class definition."""

def __call__(self, parser_in, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (str): Argument values.
option_string (str): Optional string for specific argument name. Default: None.
"""

out_file = file_existed_checker(parser_in, values, "out")
file_validation_checker(parser_in, out_file, ".ckpt")

setattr(namespace, self.dest, out_file)


parser = argparse.ArgumentParser(description="Fix weight names in CheckPoint file.")
parser.add_argument(
"source_py_file",
action=ScriptAction,
help="source model script file")
parser.add_argument(
"fixed_py_file",
action=ScriptAction,
help="fixed model script file")
parser.add_argument(
"source_ckpt_file",
action=InCheckPointAction,
help="source checkpoint file")
parser.add_argument(
"fixed_ckpt_file",
action=OutCheckPointAction,
help="fixed checkpoint file")

if __name__ == '__main__':

argv = sys.argv[1:]
if not argv:
argv = ['-h']
args = parser.parse_args(argv)
else:
args = parser.parse_args()

source_py_file = args.source_py_file
fixed_py_file = args.fixed_py_file
source_ckpt_file = args.source_ckpt_file
fixed_ckpt_file = args.fixed_ckpt_file

if not source_checker(source_py_file, source_ckpt_file):
logger_console.error("source checkpoint file is not inconsistent with source model script.")
exit(0)

fix_checkpoint_generator = FixCheckPointGenerator(source_py_file, fixed_py_file)
fix_checkpoint_generator.fix_ckpt(source_ckpt_file, fixed_ckpt_file)

Loading…
Cancel
Save