!790 Update description of mindconverter.

Merge pull request !790 from 刘崇鸣/update_mc_readme
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3470477dc4
2 changed files with 204 additions and 57 deletions
  1. +103
    -29
      mindinsight/mindconverter/README.md
  2. +101
    -28
      mindinsight/mindconverter/README_CN.md

+ 103
- 29
mindinsight/mindconverter/README.md

@@ -8,10 +8,14 @@
- [Overview](#overview)
- [Installation](#installation)
- [Usage](#usage)
- [PyTorch Model Scripts Migration](#PyTorch-model-scripts-migration)
- [TensorFlow Model Scripts Migration](#TensorFlow-model-scripts-migration)
- [Scenario](#scenario)
- [Example](#example)
- [AST-Based Conversion](#ast-based-conversion)
- [Graph-Based Conversion](#graph-based-conversion)
- [PyTorch Model Scripts Conversion](#PyTorch-Model-Scripts-Conversion)
- [TensorFlow Model Scripts Conversion](#TensorFlow-Model-Scripts-Conversion)
- [Caution](#caution)
- [Unsupported situation of AST mode](#unsupported-situation-of-ast-mode)
- [Situation1](#situation1)
@@ -21,7 +25,7 @@

## Overview

MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore. Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report.
MindConverter is a migration tool to transform the model scripts from PyTorch or TensorFlow to Mindspore. Users can migrate their PyTorch or TensorFlow models to Mindspore rapidly with minor changes according to the conversion report.


## Installation
@@ -36,22 +40,31 @@ MindConverter currently only provides command-line interface. Here is the manual
```bash
usage: mindconverter [-h] [--version] [--in_file IN_FILE]
[--model_file MODEL_FILE] [--shape SHAPE]
[--input_node INPUT_NODE] [--output_node OUTPUT_NODE]
[--output OUTPUT] [--report REPORT]
[--project_path PROJECT_PATH]

optional arguments:
-h, --help show this help message and exit
--version show program version number and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file to use AST schema to do
script conversation.
--model_file MODEL_FILE
PyTorch .pth model file path to use graph based schema
to do script generation. When `--in_file` and
`--model_file` are both provided, use AST schema as
default.
PyTorch .pth or Tensorflow .pb model file path to use
graph based schema to do script generation. When
`--in_file` and `--model_file` are both provided, use
AST schema as default.
--shape SHAPE Optional, expected input tensor shape of
`--model_file`. It is required when use graph based
schema. Usage: --shape 3,244,244
`--model_file`. It's required when use graph based
schema. Usage: --shape 1,3,244,244
--input_node INPUT_NODE
Optional, input node(s) name of `--model_file`. It's
required when use Tensorflow model. Usage:
--input_node input_1:0,input_2:0
--output_node OUTPUT_NODE
Optional, output node(s) name of `--model_file`. It's
required when use Tensorflow model. Usage:
--output_node output_1:0,output_2:0
--output OUTPUT Optional, specify path for converted script file
directory. Default output directory is `output` folder
in the current working directory.
@@ -65,14 +78,16 @@ optional arguments:

```

**MindConverter provides two modes:**
### PyTorch Model Scripts Migration

**MindConverter provides two modes for PyTorch:**

1. **Abstract Syntax Tree (AST) based conversion**: Using the `--in_file` argument enables the AST mode.
2. **Computational Graph basedconversion**:Use `--model_file` and `--shape` arguments will enable the Graph mode.
2. **Computational Graph based conversion**: Using the `--model_file` and `--shape` arguments enables the Graph mode.

> The AST mode is enabled if both `--in_file` and `--model_file` are specified.

For the Grapa mode, `--shape` is mandatory.
For the Graph mode, `--shape` is mandatory.

For the AST mode, `--shape` is ignored.
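The precedence rules above can be sketched as a small function. This is only an illustration of the documented behavior, not MindConverter's own code: `select_mode` is a hypothetical helper, and the error raised when neither file is given is an assumption.

```python
from typing import Optional


def select_mode(in_file: Optional[str] = None,
                model_file: Optional[str] = None,
                shape: Optional[str] = None) -> str:
    """Return "ast" or "graph" following the documented precedence."""
    if in_file:
        # AST mode wins whenever --in_file is given; --shape is ignored.
        return "ast"
    if model_file:
        if not shape:
            # The Graph mode cannot run without an input shape.
            raise ValueError("--shape is mandatory for the Graph mode")
        return "graph"
    # Assumption: at least one of the two files has to be provided.
    raise ValueError("either --in_file or --model_file must be provided")
```

For instance, `select_mode(in_file="model.py", model_file="model.pth")` returns `"ast"`, matching the note that the AST mode takes priority when both arguments are present.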

@@ -82,9 +97,13 @@ Please note that your original PyTorch project is included in the module search

> Assume the project is located at `/home/user/project/model_training`, users can use this command to add the project to `PYTHONPATH` : `export PYTHONPATH=/home/user/project/model_training:$PYTHONPATH`

> MindConverter needs the original PyTourch scripts because of the reverse serialization.
> MindConverter needs the original PyTorch scripts because of the reverse serialization.
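Before running the Graph mode, it can help to confirm that the original project is actually visible on the module search path. A minimal sketch using only the standard library follows; `is_importable` is a hypothetical helper and `model_training` in the usage note is a placeholder module name, neither is part of MindConverter.

```python
import importlib.util
import sys


def is_importable(module_name: str, project_path: str = "") -> bool:
    """Check whether `module_name` can be located by the import system.

    If `project_path` is given, it is added to the front of sys.path first,
    which mirrors what exporting PYTHONPATH would achieve for this process.
    """
    if project_path and project_path not in sys.path:
        sys.path.insert(0, project_path)
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package of a dotted name is missing or invalid.
        return False
```

Calling `is_importable("model_training", "/home/user/project")` would report whether a package of that name under `/home/user/project` can be found; if it returns `False`, setting `PYTHONPATH` or passing `--project_path` is the next step.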

### TensorFlow Model Scripts Migration

**MindConverter provides computational graph based conversion for TensorFlow**: The conversion runs when `--model_file`, `--shape`, `--input_node` and `--output_node` are all given.

> AST mode is not supported for TensorFlow, only computational graph based mode is available.
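The values passed to `--shape`, `--input_node` and `--output_node` are comma-separated strings. The sketch below shows one way such values could be split and checked before invoking the tool; `parse_shape` and `parse_nodes` are hypothetical helpers, and the positive-dimension check is an assumption rather than documented MindConverter behavior.

```python
from typing import List


def parse_shape(shape: str) -> List[int]:
    """Turn a comma-separated shape such as "1,224,224,3" into a list of ints."""
    dims = [int(dim) for dim in shape.split(",")]
    if any(dim <= 0 for dim in dims):
        # Assumption: every dimension must be a positive integer.
        raise ValueError(f"every dimension must be positive, got {dims}")
    return dims


def parse_nodes(nodes: str) -> List[str]:
    """Split a comma-separated node list into names, dropping empty entries."""
    return [name.strip() for name in nodes.split(",") if name.strip()]
```

With the usage strings from the help text, `parse_shape("1,224,224,3")` yields `[1, 224, 224, 3]` and `parse_nodes("input_1:0,input_2:0")` yields `["input_1:0", "input_2:0"]`.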

## Scenario

@@ -93,29 +112,32 @@ MindConverter provides two modes for different migration demands.
1. Keep original scripts' structures, including variables, functions, and libraries.
2. Keep extra modifications as few as possible, or no modifications are required after conversion.

The AST mode is recommended for the first demand. It parses and analyzes PyTorch scripts, then replace them with the MindSpore AST to generate codes. Theoretically, The AST mode supports any model script. However, the conversion may differ due to the coding style of original scripts.
The AST mode is recommended for the first demand (the AST mode is only supported for PyTorch). It parses and analyzes PyTorch scripts, then replaces them with the MindSpore AST to generate code. In theory, the AST mode supports any model script. However, the conversion result may differ depending on the coding style of the original scripts.

For the second demand, the Graph mode is recommended. As the computational graph is a standard descriptive language, it is not affected by user's coding style. This mode may have more operators converted as long as these operators are supported by MindConverter.

Some typical image classification networks such as ResNet and VGG have been tested for the Graph mode. Note that:

> 1. Currently, the Graph mode does not support models with multiple inputs. Only models with a single input and single output are supported.
> 2. The Dropout operator will be lost after conversion because the inference mode is used to load the PyTorch model. Manually re-implement is necessary.
> 2. The Dropout operator will be lost after conversion because the inference mode is used to load the PyTorch or TensorFlow model. It must be re-implemented manually.
> 3. The Graph-based mode will be continuously developed and optimized with further updates.

Supported models list:

Supported Model | PyTorch Script |
| :----: | :----:|
| ResNet18 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet34 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet50 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet101 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| VGG11/11BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG13/13BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG16/16BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG19/19BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| AlexNet | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/alexnet.py) |
| Supported Model | PyTorch Script | TensorFlow Script |
| :----: | :----: | :----: |
| ResNet18 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | / |
| ResNet34 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | / |
| ResNet50 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| ResNet101 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| ResNet152 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| VGG11/11BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | / |
| VGG13/13BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | / |
| VGG16/16BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/vgg16.py) |
| VGG19/19BN | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/vgg19.py) |
| AlexNet | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/alexnet.py) | / |
| GoogLeNet | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/googlenet.py) | / |
| MobileNetV2 | [link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) |

## Example

@@ -154,10 +176,12 @@ For non-transformed operators, suggestions are provided in the report. For insta

### Graph-Based Conversion

Assume the PyTorch model (.pth file) is located at `/home/user/model.pth`, with input shape (3, 224, 224) and the original PyTorch script is at `/home/user/project/model_training`. Output the transformed MindSpore script to `/home/user/output`, with the conversion report to `/home/user/output/report`. Use the following command:
#### PyTorch Model Scripts Conversion

Assume the PyTorch model (.pth file) is located at `/home/user/model.pth`, with input shape (1, 3, 224, 224) and the original PyTorch script is at `/home/user/project/model_training`. Output the transformed MindSpore script to `/home/user/output`, with the conversion report to `/home/user/output/report`. Use the following command:

```bash
mindconverter --model_file /home/user/model.pth --shape 3,224,224 \
mindconverter --model_file /home/user/model.pth --shape 1,3,224,224 \
--output /home/user/output \
--report /home/user/output/report \
--project_path /home/user/project/model_training
@@ -165,7 +189,7 @@ mindconverter --model_file /home/user/model.pth --shape 3,224,224 \

The Graph mode has the same conversion report as the AST mode. However, the line number and column number refer to the transformed scripts since no original scripts are used in the process.
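Because report entries point at positions in the generated script, it can be handy to extract the line and column numbers programmatically. The sketch below assumes entries follow the `line x:y: [UnConvert] 'operator' didn't convert. ...` pattern; `parse_report_line` and `ReportEntry` are hypothetical names, and real reports may contain lines this pattern does not cover.

```python
import re
from typing import NamedTuple, Optional

# Assumed entry layout: "line <row>:<col>: [<Tag>] <message>"
REPORT_LINE = re.compile(r"^line (\d+):(\d+): \[(\w+)\] (.*)$")


class ReportEntry(NamedTuple):
    line: int
    column: int
    tag: str
    message: str


def parse_report_line(text: str) -> Optional[ReportEntry]:
    """Parse one report line; return None if it does not match the pattern."""
    match = REPORT_LINE.match(text.strip())
    if match is None:
        return None
    line, column, tag, message = match.groups()
    return ReportEntry(int(line), int(column), tag, message)
```

In the Graph mode, the resulting `line` and `column` values should be looked up in the generated MindSpore script rather than in the original PyTorch source.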

In addition, input and output Tensor shape of unconverted operators shows explicitly (`input_shape` and `output_shape`) as comments in converted scripts to help further manual modifications. Here is an example of the `Reshape` operator (Not supported in current version):
In addition, the input and output tensor shapes of unconverted operators are shown explicitly (as `input_shape` and `output_shape`) in comments in the converted scripts to help with further manual modifications. <a name="manual_modify">Here is an example of the `Reshape` operator (Not supported in current version)</a>:

```python
class Classifier(nn.Cell):
@@ -211,9 +235,59 @@ class Classifier(nn.Cell):
> Note: `--output` and `--report` are optional. MindConverter creates an `output` folder under the current working directory, and outputs generated scripts and conversion reports to it.


#### TensorFlow Model Scripts Conversion


To migrate a TensorFlow model script, first export the TensorFlow model to the `.pb` format and obtain the names of the model's input and output nodes. The following method can be used to export the model and obtain the node names:

```python
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    saved_path = "/home/user/xxx"
    with graph.as_default():
        # Strip training-only nodes, fold variables into constants, then write a binary .pb file.
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, saved_path, "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0)  # Important: switch Keras to inference mode before the model is built.

base_model = InceptionV3()
session = tf.keras.backend.get_session()

INPUT_NODE = base_model.inputs[0].op.name # Get input node name of TensorFlow.
OUTPUT_NODE = base_model.outputs[0].op.name # Get output node name of TensorFlow.
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
print(f"Input node name: {INPUT_NODE}, output node name: {OUTPUT_NODE}")

```

After the above code is executed, the model is saved to `/home/user/xxx/frozen_model.pb`. `INPUT_NODE` holds the input node name and `OUTPUT_NODE` holds the output node name.
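Note that the script prints operation names (taken from `.op.name`), while the usage examples for `--input_node` and `--output_node` use names with an output index such as `input_1:0`. In TensorFlow, a tensor name is the operation name followed by `:` and the output index. A minimal sketch of that conversion, assuming the first output is wanted; `to_tensor_name` is a hypothetical helper, and whether MindConverter strictly requires the suffix is only inferred from the usage examples:

```python
def to_tensor_name(op_name: str, output_index: int = 0) -> str:
    """Append the output index to an operation name, e.g. "input_1" -> "input_1:0"."""
    if ":" in op_name:
        # Already a tensor name; return it unchanged.
        return op_name
    return f"{op_name}:{output_index}"
```

For example, `to_tensor_name("predictions/Softmax")` gives `predictions/Softmax:0`.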

Suppose the input node name is `input_1:0`, the output node name is `predictions/Softmax:0`, and the input shape of the model is `1,224,224,3`. The following command can then be used to generate the script:
```shell script
mindconverter --model_file /home/user/xxx/frozen_model.pb --shape 1,224,224,3 \
--input_node input_1:0 \
--output_node predictions/Softmax:0 \
--output /home/user/output \
--report /home/user/output/report
```

After the command is executed, the MindSpore script and the report file can be found in the corresponding directories.


The conversion report produced by the graph-based generation scheme has the same format as that of the AST scheme. However, because the graph-based scheme is generative, the original PyTorch script is not referenced during conversion, so the line and column numbers in the report refer to the generated script.


In addition, for operators that are not converted successfully, the input and output tensor shapes of the node are marked in the code (as `input_shape` and `output_shape`) to ease manual modification. For an example, please refer to [PyTorch Model Scripts Conversion](#manual_modify).


## Caution

1. PyTorch is not an explicitly stated dependency library in MindInsight. The Graph conversion requires the consistent PyTorch version as the model is trained. (MindConverter recommends PyTorch 1.4.0 or 1.6.0)
1. PyTorch, TensorFlow, and TF2ONNX are not explicitly declared dependencies of MindInsight. The Graph conversion requires the same PyTorch or TensorFlow version as the one used to train the model. (MindConverter recommends PyTorch 1.4.0 or 1.6.0)
2. This script conversion tool relies on operators that are supported by both MindConverter and MindSpore. Unsupported operators may not be mapped to MindSpore operators successfully. You can edit them manually, or implement the mapping based on MindConverter and contribute it to our MindInsight repository. We appreciate your support for the MindSpore community.




+ 101
- 28
mindinsight/mindconverter/README_CN.md

@@ -8,10 +8,14 @@
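- [Overview](#概述)
- [Installation](#安装)
- [Command-line Usage](#命令行用法)
- [PyTorch Model Scripts Migration](#PyTorch模型脚本迁移)
- [TensorFlow Model Scripts Migration](#TensorFlow模型脚本迁移)
- [Scenario](#使用场景)
- [Example](#使用示例)
- [AST-Based Script Conversion Example](#基于ast的脚本转换示例)
- [Graph-Based Script Generation Example](#基于图结构的脚本生成示例)
- [PyTorch Model Script Generation Example](#PyTorch模型脚本生成示例)
- [TensorFlow Model Script Generation Example](#TensorFlow模型脚本生成示例)
- [Caution](#注意事项)
- [Unsupported Situations of the AST Scheme](#ast方案不支持场景)
- [Situation 1](#场景1)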
@@ -21,7 +25,7 @@

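## Overview

MindConverter is a tool for converting PyTorch scripts to MindSpore scripts. With the information in the conversion report, users only need to make minor changes to the converted scripts to quickly migrate PyTorch models to MindSpore.
MindConverter is a tool for converting PyTorch and TensorFlow scripts to MindSpore scripts. With the information in the conversion report, users only need to make minor changes to the converted scripts to quickly migrate PyTorch and TensorFlow model scripts to MindSpore.

## Installation

@@ -32,22 +36,31 @@ MindConverter is a tool for converting PyTorch scripts to MindSpore scripts.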
```buildoutcfg
usage: mindconverter [-h] [--version] [--in_file IN_FILE]
[--model_file MODEL_FILE] [--shape SHAPE]
[--input_node INPUT_NODE] [--output_node OUTPUT_NODE]
[--output OUTPUT] [--report REPORT]
[--project_path PROJECT_PATH]

optional arguments:
-h, --help show this help message and exit
--version show program version number and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file to use AST schema to do
script conversation.
--model_file MODEL_FILE
PyTorch .pth model file path to use graph based schema
to do script generation. When `--in_file` and
`--model_file` are both provided, use AST schema as
default.
PyTorch .pth or Tensorflow .pb model file path to use
graph based schema to do script generation. When
`--in_file` and `--model_file` are both provided, use
AST schema as default.
--shape SHAPE Optional, expected input tensor shape of
`--model_file`. It is required when use graph based
schema. Usage: --shape 3,244,244
`--model_file`. It's required when use graph based
schema. Usage: --shape 1,3,244,244
--input_node INPUT_NODE
Optional, input node(s) name of `--model_file`. It's
required when use Tensorflow model. Usage:
--input_node input_1:0,input_2:0
--output_node OUTPUT_NODE
Optional, output node(s) name of `--model_file`. It's
required when use Tensorflow model. Usage:
--output_node output_1:0,output_2:0
--output OUTPUT Optional, specify path for converted script file
directory. Default output directory is `output` folder
in the current working directory.
@@ -61,7 +74,10 @@ optional arguments:

```

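**MindConverter provides two model script migration schemes:**

### PyTorch Model Scripts Migration

**MindConverter provides two PyTorch model script migration schemes:**

1. **Abstract syntax tree (AST) based script conversion**: specifying `--in_file` selects the AST-based script conversion scheme;
2. **Graph-based script generation**: specifying `--model_file` and `--shape` selects the graph-based script generation scheme.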
@@ -74,40 +90,48 @@ optional arguments:

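In addition, when using the graph-based script generation scheme, make sure the original PyTorch project is on the Python package search path. You can check this by opening an interactive Python session from the CLI and trying to import it. If it is not, pass the project path manually with `--project_path` so that MindConverter can reference the original PyTorch scripts.


> Assume the user's project directory is `/home/user/project/model_training`. The project can be added to the package search path manually with: `export PYTHONPATH=/home/user/project/model_training:$PYTHONPATH`

> MindConverter needs to reference the original PyTorch scripts here because deserializing a PyTorch model references them.

### TensorFlow Model Scripts Migration

**MindConverter provides a graph-based script generation scheme**: specify `--model_file`, `--shape`, `--input_node` and `--output_node` to migrate the script.

> The AST scheme does not support TensorFlow model script migration; only the graph-based scheme is available for TensorFlow.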

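## Scenario

MindConverter provides two technical schemes for different script migration scenarios:
1. The user wants the migrated script to keep the structure of the original PyTorch script (variable, function, and class names stay consistent with the original script);
1. The user wants the migrated script to keep the structure of the original script (variable, function, and class names stay consistent with the original script);
2. The user wants a high conversion rate, so that the migrated model script runs with as few modifications as possible, or none at all.

For the first scenario, the AST-based scheme is recommended. It parses and edits the abstract syntax tree of the original PyTorch script, replaces it with a MindSpore abstract syntax tree, and then generates code from that tree. In theory, the AST scheme supports migrating any model script, but syntax tree parsing is affected by the coding style of the original script, so different scripts of the same model may end up with different conversion rates.
For the first scenario, the AST-based scheme is recommended (the AST scheme only supports PyTorch script conversion). It parses and edits the abstract syntax tree of the original PyTorch script, replaces it with a MindSpore abstract syntax tree, and then generates code from that tree. In theory, the AST scheme supports migrating any model script, but syntax tree parsing is affected by the coding style of the original script, so different scripts of the same model may end up with different conversion rates.

For the second scenario, the graph-based script generation scheme is recommended. As a standard model description language, the computational graph removes the instability in conversion rate caused by varied coding styles. When the operators are supported, this scheme provides a better conversion rate than the AST scheme.

The graph-based script conversion scheme has been tested on typical image classification networks (ResNet, VGG).

> 1. The graph-based script generation scheme currently supports only single-input, single-output models; multi-input models are not supported yet;
> 2. The graph-based script generation scheme loads the PyTorch model in inference mode, so the Dropout operators are lost in the converted network and must be added back manually;
> 2. The graph-based script generation scheme loads the PyTorch or TensorFlow model, so the Dropout operators are lost in the converted network and must be added back manually;
> 3. The graph-based script generation scheme is under continuous optimization.

Supported networks: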

| Supported Network | PyTorch Script |
| :----: | :----:|
| ResNet18 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet34 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet50 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| ResNet101 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) |
| VGG11/11BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG13/13BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG16/16BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| VGG19/19BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) |
| AlexNet | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/alexnet.py) |
| Supported Network | PyTorch Script | TensorFlow Script |
| :----: | :----: | :----: |
| ResNet18 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | Not tested yet |
| ResNet34 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | Not tested yet |
| ResNet50 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| ResNet101 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| ResNet152 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) |
| VGG11/11BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | Not tested yet |
| VGG13/13BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | Not tested yet |
| VGG16/16BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/vgg16.py) |
| VGG19/19BN | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/vgg.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/vgg19.py) |
| AlexNet | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/alexnet.py) | Not tested yet |
| GoogLeNet | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/googlenet.py) | Not tested yet |
| MobileNetV2 | [script link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [script link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) |

## Example
### AST-Based Script Conversion Example
@@ -142,10 +166,11 @@ line x:y: [UnConvert] 'operator' didn't convert. ...

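### Graph-Based Script Generation Example

If the PyTorch model has been saved in the .pth format, assume the absolute path of the model is `/home/user/model.pth`, the expected input sample shape is (3, 224, 224), and the original PyTorch script is located at `/home/user/project/model_training`. To output the script to `/home/user/output` and the conversion report to `/home/user/output/report`, the script generation command is:
#### PyTorch Model Script Generation Example
If the PyTorch model has been saved in the .pth format, assume the absolute path of the model is `/home/user/model.pth`, the expected input shape is (1, 3, 224, 224), and the original PyTorch script is located at `/home/user/project/model_training`. To output the script to `/home/user/output` and the conversion report to `/home/user/output/report`, the script generation command is: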

```bash
mindconverter --model_file /home/user/model.pth --shape 3,224,224 \
mindconverter --model_file /home/user/model.pth --shape 1,3,224,224 \
--output /home/user/output \
--report /home/user/output/report \
--project_path /home/user/project/model_training
@@ -157,7 +182,7 @@ mindconverter --model_file /home/user/model.pth --shape 3,224,224 \
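The graph-based script generation scheme produces a conversion report in the same format as the AST scheme. However, because the graph-based scheme is generative and does not refer to the original PyTorch script during conversion, the line and column numbers in the report refer to the generated script.


In addition, for operators that are not converted successfully, the input and output Tensor shapes of the node are marked in the code (as `input_shape` and `output_shape`) to ease manual modification. Taking the Reshape operator as an example (Reshape is not supported yet), the following code is generated:
In addition, for operators that are not converted successfully, the input and output Tensor shapes of the node are marked in the code (as `input_shape` and `output_shape`) to ease manual modification. Taking the Reshape operator as an example (Reshape is not supported yet), <a name="manual_modify">the following code is generated</a>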

```python
class Classifier(nn.Cell):
@@ -204,9 +229,57 @@ class Classifier(nn.Cell):
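> Note: the `--output` and `--report` arguments can be omitted. If omitted, the command automatically creates an `output` directory under the current working directory and writes the generated scripts and the conversion report there.


#### TensorFlow Model Script Generation Example

To migrate a TensorFlow model script, first export the TensorFlow model to the pb format and obtain the names of the model's input and output nodes. The following method can be used to export the model and obtain the node names: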
```python
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    saved_path = "/home/user/xxx"
    with graph.as_default():
        # Strip training-only nodes, fold variables into constants, then write a binary .pb file.
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, saved_path, "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0)  # Important: switch Keras to inference mode before the model is built.

base_model = InceptionV3()
session = tf.keras.backend.get_session()

INPUT_NODE = base_model.inputs[0].op.name # Get input node name of TensorFlow.
OUTPUT_NODE = base_model.outputs[0].op.name # Get output node name of TensorFlow.
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
print(f"Input node name: {INPUT_NODE}, output node name: {OUTPUT_NODE}")

```

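After the above code is executed, the model is saved to `/home/user/xxx/frozen_model.pb`. `INPUT_NODE` is the input node name and `OUTPUT_NODE` is the output node name.

Assume the input node name is `input_1:0`, the output node name is `predictions/Softmax:0`, and the input sample shape of the model is `1,224,224,3`. The following command can then be used to generate the script: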
```shell script
mindconverter --model_file /home/user/xxx/frozen_model.pb --shape 1,224,224,3 \
--input_node input_1:0 \
--output_node predictions/Softmax:0 \
--output /home/user/output \
--report /home/user/output/report
```

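After the command is executed, the MindSpore code file and the conversion report are generated in the corresponding directories.


The graph-based script generation scheme produces a conversion report in the same format as the AST scheme. However, because the graph-based scheme is generative and does not refer to the original PyTorch script during conversion, the line and column numbers in the report refer to the generated script.


In addition, for operators that are not converted successfully, the input and output Tensor shapes of the node are marked in the code (as `input_shape` and `output_shape`) to ease manual modification. For an example, see [PyTorch Model Script Generation Example](#manual_modify).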


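## Caution

1. PyTorch is not an explicitly declared dependency of MindInsight. To use the graph-based script generation tool, users must manually install the PyTorch library whose version matches the one used to produce the PyTorch model (MindConverter recommends PyTorch 1.4.0 or PyTorch 1.6.0 for script generation);
1. PyTorch, TensorFlow, and TF2ONNX are not explicitly declared dependencies of MindInsight. To use the graph-based script generation tool, users must manually install the PyTorch library whose version matches the one used to produce the PyTorch model (MindConverter recommends PyTorch 1.4.0 or PyTorch 1.6.0 for script generation), or TensorFlow
2. The script conversion tool is essentially operator-driven. For PyTorch or ONNX operators whose mapping to MindSpore operators is not maintained by MindConverter, the corresponding operators cannot be converted. Users can modify such operators manually, or implement the mapping based on MindConverter and contribute it to the MindInsight repository.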



