From: @moran3 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -131,7 +131,7 @@ Some typical image classification networks have been tested for the Graph mode. | |||||
| > 2. The Dropout operator will be lost after conversion because the inference mode is used to load the PyTorch or TensorFlow model. Manually re-implement is necessary. | > 2. The Dropout operator will be lost after conversion because the inference mode is used to load the PyTorch or TensorFlow model. Manually re-implement is necessary. | ||||
| > 3. The Graph-based mode will be continuously developed and optimized with further updates. | > 3. The Graph-based mode will be continuously developed and optimized with further updates. | ||||
| Supported models list (Models in below table have been tested based on PyTorch 1.4.0(TorchVision 0.5.0) and TensorFlow 1.15.0, X86 Ubuntu released version): | |||||
| Supported models list (Models in below table have been tested based on PyTorch 1.5.0 and TensorFlow 1.15.0, X86 Ubuntu released version): | |||||
| | Supported Model | PyTorch Script | TensorFlow Script | Comment | PyTorch Weights Converted | TensorFlow Weights Converted | | | Supported Model | PyTorch Script | TensorFlow Script | Comment | PyTorch Weights Converted | TensorFlow Weights Converted | | ||||
| | :----: | :----: | :----: | :----: | :----: | :----: | | | :----: | :----: | :----: | :----: | :----: | :----: | | ||||
| @@ -139,7 +139,7 @@ Supported models list (Models in below table have been tested based on PyTorch 1 | |||||
| | ResNet34 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | / | | TESTED | / | | | ResNet34 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | / | | TESTED | / | | ||||
| | ResNet50 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | TESTED | TESTED | | | ResNet50 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | TESTED | TESTED | | ||||
| | ResNet50V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | | ResNet50V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | ||||
| | ResNet101 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | TESTED | TESTED | | |||||
| | ResNet101 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | UNTESTED | TESTED | | |||||
| | ResNet101V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | | ResNet101V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | ||||
| | ResNet152 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | TESTED | TESTED | | | ResNet152 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | TESTED | TESTED | | ||||
| | ResNet152V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | | ResNet152V2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | TESTED | | ||||
| @@ -158,12 +158,12 @@ Supported models list (Models in below table have been tested based on PyTorch 1 | |||||
| | InceptionResNetV2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/inception_resnet_v2.py) | | / | TESTED | | | InceptionResNetV2 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/inception_resnet_v2.py) | | / | TESTED | | ||||
| | MobileNetV1 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet.py) | | / | TESTED | | | MobileNetV1 | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet.py) | | / | TESTED | | ||||
| | MobileNetV2 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) | | TESTED | TESTED | | | MobileNetV2 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) | | TESTED | TESTED | | ||||
| | MNASNet | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mnasnet.py) | / | | TESTED | / | | |||||
| | MNASNet | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mnasnet.py) | / | | mnasnet0_5:TESTED mnasnet0_75:UNTESTED mnasnet1_0:TESTED mnasnet1_3:UNTESTED | / | | |||||
| | SqueezeNet | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/squeezenet.py) | / | | TESTED | / | | | SqueezeNet | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/squeezenet.py) | / | | TESTED | / | | ||||
| | DenseNet121/169/201 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | TESTED | TESTED | | | DenseNet121/169/201 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | TESTED | TESTED | | ||||
| | DenseNet161 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | / | | TESTED | / | | | DenseNet161 | [Link](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | / | | TESTED | / | | ||||
| | NASNetMobile/Large | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | / | TESTED | | | NASNetMobile/Large | / | [Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | / | TESTED | | ||||
| | EfficientNetB0~B7 | [Link](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5Link](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | TESTED | UNTESTED(TF1.5) TESTED(TF2.3) | | |||||
| | EfficientNetB0~B7 | [Link](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.15Link](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3Link](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | TESTED | TESTED(TF1.15) TESTED(TF2.3) | | |||||
| | Unet | [Link](https://github.com/milesial/Pytorch-UNet) | [Link](https://github.com/zhixuhao/unet) | Due to Operator `mindspore.ops.ResizeBilinear` is not implemented on GPU device for now, operator `mindspore.ops.ResizeBilinear` should be replaced by operator `mindspore.ops.ResizeNearestNeighbor`, while running in GPU device | TESTED | TESTED | | | Unet | [Link](https://github.com/milesial/Pytorch-UNet) | [Link](https://github.com/zhixuhao/unet) | Due to Operator `mindspore.ops.ResizeBilinear` is not implemented on GPU device for now, operator `mindspore.ops.ResizeBilinear` should be replaced by operator `mindspore.ops.ResizeNearestNeighbor`, while running in GPU device | TESTED | TESTED | | ||||
| ## Example | ## Example | ||||
| @@ -291,7 +291,7 @@ class Classifier(nn.Cell): | |||||
| ``` | ``` | ||||
| > `--output` and `--report` are optional. MindConverter creates an `output` folder under the current working directory, and outputs generated scripts and conversion reports to it. | |||||
| > `--output` and `--report` are optional. MindConverter creates an `output` folder under the current working directory, and outputs generated scripts, MindSpore checkpoint file, weight map file and conversion reports to it. | |||||
| #### TensorFlow Model Scripts Conversion | #### TensorFlow Model Scripts Conversion | ||||
| @@ -335,7 +335,7 @@ In addition, for operators that are not converted successfully, the input and ou | |||||
| ## Caution | ## Caution | ||||
| 1. PyTorch, TensorFlow are not an explicitly stated dependency libraries in MindInsight. The Graph conversion requires the consistent PyTorch or TensorFlow version as the model is trained. (MindConverter recommends PyTorch 1.4.0) | |||||
| 1. PyTorch, TensorFlow are not an explicitly stated dependency libraries in MindInsight. The Graph conversion requires the consistent PyTorch or TensorFlow version as the model is trained. (For MindConverter, PyTorch 1.5.0 is supported while PyTorch 1.4.x is unsupported; PyTorch 1.6.x and PyTorch 1.7.x are untested.) | |||||
| 2. This script conversion tool relies on operators which supported by MindConverter and MindSpore. Unsupported operators may not be successfully mapped to MindSpore operators. You can manually edit, or implement the mapping based on MindConverter, and contribute to our MindInsight repository. We appreciate your support for the MindSpore community. | 2. This script conversion tool relies on operators which supported by MindConverter and MindSpore. Unsupported operators may not be successfully mapped to MindSpore operators. You can manually edit, or implement the mapping based on MindConverter, and contribute to our MindInsight repository. We appreciate your support for the MindSpore community. | ||||
| 3. MindConverter can only guarantee that the converted model scripts require a minor revision or no revision when the inputs' shape fed to the generated model script are equal to the value of `--shape` (The batch size dimension is not limited). | 3. MindConverter can only guarantee that the converted model scripts require a minor revision or no revision when the inputs' shape fed to the generated model script are equal to the value of `--shape` (The batch size dimension is not limited). | ||||
| 4. MindSpore script, MindSpore checkpoint file and weight map file are saved in the same file folder path. | 4. MindSpore script, MindSpore checkpoint file and weight map file are saved in the same file folder path. | ||||
| @@ -485,7 +485,7 @@ def convert_to_froze_graph(keras_model: tf.python.keras.models.Model, model_name | |||||
| | TreeCreateFailError | Fail to create code hierarchical tree | 2000000 | Mainly caused by usage of `torch.nn.functional.xxx`, `torch.xxx`, `torch.Tensor.xxx` in PyTorch | | | TreeCreateFailError | Fail to create code hierarchical tree | 2000000 | Mainly caused by usage of `torch.nn.functional.xxx`, `torch.xxx`, `torch.Tensor.xxx` in PyTorch | | ||||
| | NodeInputMissingError | Fail to get the input node info | 2000001 | Fail to get input node info | | | NodeInputMissingError | Fail to get the input node info | 2000001 | Fail to get input node info | | ||||
| | TreeNodeInsertError | Fail to insert tree node | 2000002 | Mainly caused by wrong scope name | | | TreeNodeInsertError | Fail to insert tree node | 2000002 | Mainly caused by wrong scope name | | ||||
| | SourceFilesSaveError | Fail to generate or save converted script | 3000000 | Exception caused by 3000001~3000003 | | |||||
| | SourceFilesSaveError | Fail to generate or save converted script | 3000000 | Exception caused by 3000001~3000005 | | |||||
| | NodeInputTypeNotSupportError | Fail to recognize the input type of converted operator | 3000001 | Wrong input type set in mapper | | | NodeInputTypeNotSupportError | Fail to recognize the input type of converted operator | 3000001 | Wrong input type set in mapper | | ||||
| | ScriptGenerationError | Fail to generate converted script | 3000002 | No left space on hard disk; Converted code is not legal; A file with the same name already exists in `--output` | | | ScriptGenerationError | Fail to generate converted script | 3000002 | No left space on hard disk; Converted code is not legal; A file with the same name already exists in `--output` | | ||||
| | ReportGenerationError | Fail to generate converted script | 3000003 | No left space on hard disk; No available operator to be converted;A file with the same name already exists in `--report` | | | ReportGenerationError | Fail to generate converted script | 3000003 | No left space on hard disk; No available operator to be converted;A file with the same name already exists in `--report` | | ||||
| @@ -130,15 +130,15 @@ MindConverter提供两种技术方案,以应对不同脚本迁移场景: | |||||
| > 2. 基于图结构的脚本生成方案,由于要加载PyTorch、TensorFlow模型,会导致转换后网络中Dropout算子丢失,需要用户手动补齐; | > 2. 基于图结构的脚本生成方案,由于要加载PyTorch、TensorFlow模型,会导致转换后网络中Dropout算子丢失,需要用户手动补齐; | ||||
| > 3. 基于图结构的脚本生成方案持续优化中。 | > 3. 基于图结构的脚本生成方案持续优化中。 | ||||
| 支持的模型列表(如下模型已基于x86 Ubuntu发行版,PyTorch 1.4.0(TorchVision 0.5)以及TensorFlow 1.15.0测试通过): | |||||
| 支持的模型列表(如下模型已基于x86 Ubuntu发行版,PyTorch 1.5.0以及TensorFlow 1.15.0测试通过): | |||||
| | 模型 | PyTorch脚本 | TensorFlow脚本 | 备注 | PyTorch权重迁移 | TensorFlow权重迁移 | | |||||
| | :----: | :----: | :----: | :----: | :----: | :----: | | |||||
| | 模型 | PyTorch脚本 | TensorFlow脚本 | 备注 | PyTorch权重迁移 | TensorFlow权重迁移 | | |||||
| | :----: | :-----: | :----: | :----: | :----: | :----: | | |||||
| | ResNet18 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | 暂未测试 | | 已测试 | / | | | ResNet18 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | 暂未测试 | | 已测试 | / | | ||||
| | ResNet34 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | 暂未测试 | | 已测试 | / | | | ResNet34 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | 暂未测试 | | 已测试 | / | | ||||
| | ResNet50 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 已测试 | 已测试 | | | ResNet50 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 已测试 | 已测试 | | ||||
| | ResNet50V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | | ResNet50V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | ||||
| | ResNet101 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 已测试 | 已测试 | | |||||
| | ResNet101 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 未测试 | 已测试 | | |||||
| | ResNet101V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | | ResNet101V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | ||||
| | ResNet152 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 已测试 | 已测试 | | | ResNet152 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/resnet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet.py) | | 已测试 | 已测试 | | ||||
| | ResNet152V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | | ResNet152V2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/resnet_v2.py) | | / | 已测试 | | ||||
| @@ -157,12 +157,12 @@ MindConverter提供两种技术方案,以应对不同脚本迁移场景: | |||||
| | InceptionResNetV2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/inception_resnet_v2.py) | | / | 已测试 | | | InceptionResNetV2 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/inception_resnet_v2.py) | | / | 已测试 | | ||||
| | MobileNetV1 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet.py) | | / | 已测试 | | | MobileNetV1 | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet.py) | | / | 已测试 | | ||||
| | MobileNetV2 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) | | 已测试 | 已测试 | | | MobileNetV2 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mobilenet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/mobilenet_v2.py) | | 已测试 | 已测试 | | ||||
| | MNASNet | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mnasnet.py) | 暂未测试 | | 已测试 | / | | |||||
| | MNASNet | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/mnasnet.py) | 暂未测试 | | mnasnet0_5:已测试 mnasnet0_75:未测试 mnasnet1_0:已测试 mnasnet1_3:未测试 | / | | |||||
| | SqueezeNet | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/squeezenet.py) | 暂未测试 | | 已测试 | / | | | SqueezeNet | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/squeezenet.py) | 暂未测试 | | 已测试 | / | | ||||
| | DenseNet121/169/201 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | 已测试 | 已测试 | | | DenseNet121/169/201 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/densenet.py) | | 已测试 | 已测试 | | ||||
| | DenseNet161 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | 暂未测试 | | 已测试 | / | | | DenseNet161 | [脚本链接](https://github.com/pytorch/vision/blob/v0.5.0/torchvision/models/densenet.py) | 暂未测试 | | 已测试 | / | | ||||
| | NASNetMobile/Large | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | / | 已测试 | | | NASNetMobile/Large | 暂未测试 | [脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/nasnet.py) | | / | 已测试 | | ||||
| | EfficientNetB0~B7 | [脚本链接](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.5脚本链接](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) [TF2.3脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | 已测试 | 未测试(TF1.5) 已测试(TF2.3)| | |||||
| | EfficientNetB0~B7 | [脚本链接](https://github.com/lukemelas/EfficientNet-PyTorch) | [TF1.15<br />脚本链接](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) <br />[TF2.3<br />脚本链接](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py) | | 已测试 | 已测试(TF1.15) 已测试(TF2.3)| | |||||
| | Unet | [脚本链接](https://github.com/milesial/Pytorch-UNet) | [脚本链接](https://github.com/zhixuhao/unet) | 由于算子`mindspore.ops.ResizeBilinear`在GPU上暂未实现,所以当运行在GPU设备上时,算子`mindspore.ops.ResizeBilinear`需要被替换为算子`mindspore.ops.ResizeNearestNeighbor` | 已测试 | 已测试 | | | Unet | [脚本链接](https://github.com/milesial/Pytorch-UNet) | [脚本链接](https://github.com/zhixuhao/unet) | 由于算子`mindspore.ops.ResizeBilinear`在GPU上暂未实现,所以当运行在GPU设备上时,算子`mindspore.ops.ResizeBilinear`需要被替换为算子`mindspore.ops.ResizeNearestNeighbor` | 已测试 | 已测试 | | ||||
| ## 使用示例 | ## 使用示例 | ||||
| @@ -239,8 +239,8 @@ line x:y: [UnConvert] 'operator' didn't convert. ... | |||||
| #### PyTorch模型脚本生成示例 | #### PyTorch模型脚本生成示例 | ||||
| 若用户已将PyTorch模型保存为.pth格式,假设模型绝对路径为`/home/uer/model.pth`,该模型期望的输入shape为(1, 3, 224, 224) | 若用户已将PyTorch模型保存为.pth格式,假设模型绝对路径为`/home/uer/model.pth`,该模型期望的输入shape为(1, 3, 224, 224) | ||||
| ,原PyTorch脚本位于`/home/user/project/model_training`,希望将脚本、权重文件和权重映射表输出至`/home/user/output`,转换报告输出至`/home/user/output/report`, | |||||
| 。则脚本生成命令为: | |||||
| ,原PyTorch脚本位于`/home/user/project/model_training`,希望将脚本、权重文件和权重映射表输出至`/home/user/output`,转换报告输出至`/home/user/output/report` | |||||
| 。<br /> 则脚本生成命令为: | |||||
| ```bash | ```bash | ||||
| mindconverter --model_file /home/user/model.pth --shape 1,3,224,224 \ | mindconverter --model_file /home/user/model.pth --shape 1,3,224,224 \ | ||||
| @@ -299,7 +299,7 @@ class Classifier(nn.Cell): | |||||
| ``` | ``` | ||||
| > 其中`--output`与`--report`参数可省略,若省略,该命令将在当前工作目录(Working directory)下自动创建`output`目录,将生成的脚本、转换报告输出至该目录。 | |||||
| > 其中`--output`与`--report`参数可省略,若省略,该命令将在当前工作目录(Working directory)下自动创建`output`目录,将生成的脚本、权重文件、权重映射表、转换报告输出至该目录。 | |||||
| #### TensorFlow模型脚本生成示例 | #### TensorFlow模型脚本生成示例 | ||||
| @@ -344,7 +344,7 @@ mindconverter --model_file /home/user/xxx/model.onnx --shape 1,3,224,224 \ | |||||
| ## 注意事项 | ## 注意事项 | ||||
| 1. PyTorch、TensorFlow不作为MindInsight明确声明的依赖库。若想使用基于图结构的脚本生成工具,需要用户手动安装与生成PyTorch模型版本一致的PyTorch库(MindConverter推荐使用PyTorch 1.4.0进行脚本生成),或TensorFlow; | |||||
| 1. PyTorch、TensorFlow不作为MindInsight明确声明的依赖库。若想使用基于图结构的脚本生成工具,需要用户手动安装与生成PyTorch模型版本一致的PyTorch库(MindConverter使用PyTorch 1.5.0进行测试,不支持PyTorch 1.4.x; PyTorch 1.6.x、PyTorch 1.7.x未进行测试。),或TensorFlow; | |||||
| 2. 脚本转换工具本质上为算子驱动,对于MindConverter未维护的PyTorch或ONNX算子与MindSpore算子映射,将会出现相应的算子无法转换的问题,对于该类算子,用户可手动修改,或基于MindConverter实现映射关系,向MindInsight仓库贡献。 | 2. 脚本转换工具本质上为算子驱动,对于MindConverter未维护的PyTorch或ONNX算子与MindSpore算子映射,将会出现相应的算子无法转换的问题,对于该类算子,用户可手动修改,或基于MindConverter实现映射关系,向MindInsight仓库贡献。 | ||||
| 3. MindConverter仅保证转换后模型脚本在输入数据尺寸与`--shape`一致的情况下,可达到无需人工修改或少量修改(`--shape`中batch size维度不受限)。 | 3. MindConverter仅保证转换后模型脚本在输入数据尺寸与`--shape`一致的情况下,可达到无需人工修改或少量修改(`--shape`中batch size维度不受限)。 | ||||
| 4. 脚本文件、权重文件和权重映射表输出于同一个目录下。 | 4. 脚本文件、权重文件和权重映射表输出于同一个目录下。 | ||||
| @@ -486,21 +486,21 @@ def convert_to_froze_graph(keras_model: tf.python.keras.models.Model, model_name | |||||
| ### MindConverter错误码速查表 | ### MindConverter错误码速查表 | ||||
| | 异常声明 | 异常描述 | 异常代码 | 常见原因 | | |||||
| | :----------------------------: | :------: | :----------------------------------------------------------- | ----------------------- | | |||||
| | MindConverterException | MindConverter异常基类 | NAN | MindConverter异常基类。 | | |||||
| | BaseConverterError | 未知错误引起的转换失败 | 0000000 | 程序运行中出现未知错误,请打开MindInsight log文件(默认位于`~/mindinsight/log/mindconverter/`目录下)查看具体错误原因。 | | |||||
| | UnKnownModelError | 识别网络模型对应的框架失败 | 0000001 | 通常为用户给定模型文件不符合TensorFlow或PyTorch标准。 | | |||||
| | 异常声明 | 异常描述 | 异常代码 | 常见原因 | | |||||
| | :----------------------------: | :------: | :--------------- | ----------------------- | | |||||
| | MindConverterException | MindConverter异常基类 | NAN | MindConverter异常基类。 | | |||||
| | BaseConverterError | 未知错误引起的转换失败 | 0000000 | 程序运行中出现未知错误,请打开MindInsight log文件(默认位于`~/mindinsight/log/mindconverter/`目录下)查看具体错误原因。 | | |||||
| | UnKnownModelError | 识别网络模型对应的框架失败 | 0000001 | 通常为用户给定模型文件不符合TensorFlow或PyTorch标准。 | | |||||
| | ParamMissingError | 缺少转换所需参数 | 0000002 | 通常为`--shape`, `--input_nodes` , `--output_nodes`缺失导致 | | | ParamMissingError | 缺少转换所需参数 | 0000002 | 通常为`--shape`, `--input_nodes` , `--output_nodes`缺失导致 | | ||||
| | GraphInitFailError | 依据网络模型构建计算图失败 | 1000000 | 由1000001,1000002,1000003导致的计算图无法解析。 | | | GraphInitFailError | 依据网络模型构建计算图失败 | 1000000 | 由1000001,1000002,1000003导致的计算图无法解析。 | | ||||
| | ModelNotSupportError | 解析.pth/.pb文件失败 | 1000001 | 给定的`--input_nodes`, `--output_nodes`与实际模型不符;或模型文件存在问题导致模型无法加载。 | | |||||
| | TfRuntimeError | TensorFlow库执行出错 | 1000002 | TensorFlow启动申请所需资源失败导致无法正常启动,请检查系统资源(进程数、内存、显存占用、CPU占用)是否充足。 | | |||||
| | ModelNotSupportError | 解析.pth/.pb文件失败 | 1000001 | 给定的`--input_nodes`, `--output_nodes`与实际模型不符;<br />或模型文件存在问题导致模型无法加载。 | | |||||
| | TfRuntimeError | TensorFlow库执行出错 | 1000002 | TensorFlow启动申请所需资源失败导致无法正常启动,<br />请检查系统资源(进程数、内存、显存占用、CPU占用)是否充足。 | | |||||
| | ModelLoadingError | 模型加载失败 | 1000003 | 可能由于用户给定网络输入尺寸错误导致模型无法加载。 | | | ModelLoadingError | 模型加载失败 | 1000003 | 可能由于用户给定网络输入尺寸错误导致模型无法加载。 | | ||||
| | RuntimeIntegrityError | 三方依赖库不完整 | 1000004 | MindConverter运行时所需的三方依赖库未安装。 | | | RuntimeIntegrityError | 三方依赖库不完整 | 1000004 | MindConverter运行时所需的三方依赖库未安装。 | | ||||
| | TreeCreateFailError | 依据计算图构建模型树失败 | 2000000 | Tree用于生成最终代码结构,通常由于PyTorch网络中存在`torch.nn.functional.xxx`, `torch.xxx`, `torch.Tensor.xxx`算子导致。 | | |||||
| | TreeCreateFailError | 依据计算图构建模型树失败 | 2000000 | Tree用于生成最终代码结构,<br />通常由于PyTorch网络中存在`torch.nn.functional.xxx`, `torch.xxx`, `torch.Tensor.xxx`算子导致。 | | |||||
| | NodeInputMissingError | 网络节点输入信息丢失 | 2000001 | 节点的输入信息丢失。 | | | NodeInputMissingError | 网络节点输入信息丢失 | 2000001 | 节点的输入信息丢失。 | | ||||
| | TreeNodeInsertError | 树节点构建失败 | 2000002 | 由于scope name错误,无法找到该节点的父节点。 | | | TreeNodeInsertError | 树节点构建失败 | 2000002 | 由于scope name错误,无法找到该节点的父节点。 | | ||||
| | SourceFilesSaveError | 生成和保存转换后的脚本文件失败 | 3000000 | 由3000001,3000002,3000003导致的脚本生成保存失败。 | | |||||
| | SourceFilesSaveError | 生成和保存转换后的脚本文件失败 | 3000000 | 由300000至3000005导致的脚本生成保存失败。 | | |||||
| | NodeInputTypeNotSupportError | 网络节点输入类型未知 | 3000001 | 映射关系中设置节点输入类型错误。 | | | NodeInputTypeNotSupportError | 网络节点输入类型未知 | 3000001 | 映射关系中设置节点输入类型错误。 | | ||||
| | ScriptGenerationError | 转换脚本生成失败 | 3000002 | 空间不足;生成的脚本不符合PEP-8规范;`--output`目录下已有同名文件存在 | | | ScriptGenerationError | 转换脚本生成失败 | 3000002 | 空间不足;生成的脚本不符合PEP-8规范;`--output`目录下已有同名文件存在 | | ||||
| | ReportGenerationError | 转换报告生成失败 | 3000003 | 空间不足;脚本中没有需要转换的算子;`--report`目录下已有同名文件存在。 | | | ReportGenerationError | 转换报告生成失败 | 3000003 | 空间不足;脚本中没有需要转换的算子;`--report`目录下已有同名文件存在。 | | ||||
| @@ -334,13 +334,26 @@ class ModelNotSupportError(GraphInitError): | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| onnxruntime_error = getattr(import_module('onnxruntime.capi'), 'onnxruntime_pybind11_state') | |||||
| except_source = (RuntimeError, | except_source = (RuntimeError, | ||||
| ModuleNotFoundError, | ModuleNotFoundError, | ||||
| ValueError, | ValueError, | ||||
| AssertionError, | AssertionError, | ||||
| TypeError, | TypeError, | ||||
| OSError, | OSError, | ||||
| ZeroDivisionError, cls) | |||||
| ZeroDivisionError, | |||||
| onnxruntime_error.Fail, | |||||
| onnxruntime_error.InvalidArgument, | |||||
| onnxruntime_error.NoSuchFile, | |||||
| onnxruntime_error.NoModel, | |||||
| onnxruntime_error.EngineError, | |||||
| onnxruntime_error.RuntimeException, | |||||
| onnxruntime_error.InvalidProtobuf, | |||||
| onnxruntime_error.ModelLoaded, | |||||
| onnxruntime_error.NotImplemented, | |||||
| onnxruntime_error.InvalidGraph, | |||||
| onnxruntime_error.EPFail, | |||||
| cls) | |||||
| return except_source | return except_source | ||||
| @@ -510,7 +523,7 @@ class GeneratorError(MindConverterException): | |||||
| @classmethod | @classmethod | ||||
| def raise_from(cls): | def raise_from(cls): | ||||
| """Raise from exceptions below.""" | """Raise from exceptions below.""" | ||||
| except_source = (ValueError, TypeError, cls) | |||||
| except_source = (ValueError, TypeError, SyntaxError, cls) | |||||
| return except_source | return except_source | ||||
| @@ -26,8 +26,6 @@ from mindinsight.mindconverter.common.log import logger as log | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ | ||||
| FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION | FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION | ||||
| from mindspore.train.serialization import save_checkpoint | |||||
| def is_converted(operation: str): | def is_converted(operation: str): | ||||
| """ | """ | ||||
| @@ -147,6 +145,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||||
| except (IOError, FileExistsError) as error: | except (IOError, FileExistsError) as error: | ||||
| raise ReportGenerationError(str(error)) | raise ReportGenerationError(str(error)) | ||||
| save_checkpoint = getattr(import_module('mindspore.train.serialization'), 'save_checkpoint') | |||||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | ||||
| try: | try: | ||||
| if os.path.exists(ckpt_file_path): | if os.path.exists(ckpt_file_path): | ||||
| @@ -45,6 +45,8 @@ ONNXRUNTIME_MIN_VER = "1.5.2" | |||||
| ONNXOPTIMIZER_MIN_VER = "0.1.2" | ONNXOPTIMIZER_MIN_VER = "0.1.2" | ||||
| ONNXOPTIMIZER_MAX_VER = "0.1.2" | ONNXOPTIMIZER_MAX_VER = "0.1.2" | ||||
| TORCH_MIN_VER = "1.5.0" | |||||
| @unique | @unique | ||||
| class TemplateKeywords(Enum): | class TemplateKeywords(Enum): | ||||
| @@ -107,6 +109,7 @@ NO_CONVERTED_OPERATORS = [ | |||||
| ] | ] | ||||
| THIRD_PART_VERSION = { | THIRD_PART_VERSION = { | ||||
| "torch": (TORCH_MIN_VER,), | |||||
| "onnx": (ONNX_MIN_VER,), | "onnx": (ONNX_MIN_VER,), | ||||
| "onnxruntime": (ONNXRUNTIME_MIN_VER,), | "onnxruntime": (ONNXRUNTIME_MIN_VER,), | ||||
| "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER), | "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER), | ||||
| @@ -13,7 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Graph based scripts converter workflow.""" | """Graph based scripts converter workflow.""" | ||||
| import multiprocessing as mp | |||||
| import os | import os | ||||
| import re | |||||
| import sys | import sys | ||||
| from importlib import import_module | from importlib import import_module | ||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| @@ -23,7 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | ||||
| save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info | save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ | ||||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER | |||||
| ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER, ONNXOPTIMIZER_MAX_VER, TORCH_MIN_VER | |||||
| from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes | ||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console | ||||
| @@ -55,6 +57,16 @@ def _print_error(err): | |||||
| log_console.error("\n") | log_console.error("\n") | ||||
| def torch_version_satisfied(output_queue): | |||||
| """Check Torch version whether is satisfied.""" | |||||
| satisfied = False | |||||
| pattern = r"\d+\.\d+\.\d+" | |||||
| torch_version = re.findall(pattern, getattr(import_module('torch'), "__version__")) | |||||
| if torch_version: | |||||
| satisfied = lib_version_satisfied(torch_version[0], TORCH_MIN_VER) | |||||
| output_queue.put(satisfied) | |||||
| def torch_installation_validation(func): | def torch_installation_validation(func): | ||||
| """ | """ | ||||
| Validate args of func. | Validate args of func. | ||||
| @@ -70,24 +82,33 @@ def torch_installation_validation(func): | |||||
| output_folder: str, report_folder: str = None): | output_folder: str, report_folder: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| error_info = None | error_info = None | ||||
| torch_version_validation = False | |||||
| if graph_path.endswith('.onnx'): | if graph_path.endswith('.onnx'): | ||||
| if not onnx_satisfied() or not check_common_dependency_integrity(): | if not onnx_satisfied() or not check_common_dependency_integrity(): | ||||
| error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ | error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ | ||||
| f"are required when using graph based scripts converter." | f"are required when using graph based scripts converter." | ||||
| else: | else: | ||||
| if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity("torch"): | |||||
| error_info = f"PyTorch, " \ | |||||
| f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \ | |||||
| f"are required when using graph based scripts converter, and PyTorch version must " \ | |||||
| f"be consisted with model generation runtime." | |||||
| if not find_spec("torch") or not onnx_satisfied() or not check_common_dependency_integrity(): | |||||
| error_info = \ | |||||
| f"{get_third_part_lib_validation_error_info(['torch', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " \ | |||||
| f"are required when using graph based scripts converter, and PyTorch version must " \ | |||||
| f"be consisted with model generation runtime." | |||||
| output_queue = mp.Queue() | |||||
| process = mp.Process(target=torch_version_satisfied, args=(output_queue,)) | |||||
| process.start() | |||||
| torch_version_validation = output_queue.get() | |||||
| process.join() | |||||
| if error_info: | if error_info: | ||||
| _print_error(RuntimeIntegrityError(error_info)) | _print_error(RuntimeIntegrityError(error_info)) | ||||
| sys.exit(0) | sys.exit(0) | ||||
| if not onnx_lib_version_satisfied(): | |||||
| if (not torch_version_validation and not graph_path.endswith('.onnx')) or not onnx_lib_version_satisfied(): | |||||
| lib_check_list = ['onnx', 'onnxruntime', 'onnxoptimizer'] | |||||
| if not graph_path.endswith('.onnx'): | |||||
| lib_check_list.insert(0, 'torch') | |||||
| error = RuntimeIntegrityError( | error = RuntimeIntegrityError( | ||||
| f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " | |||||
| f"{get_third_part_lib_validation_error_info(lib_check_list)} " | |||||
| f"are required when using graph based scripts converter." | f"are required when using graph based scripts converter." | ||||
| ) | ) | ||||
| _print_error(error) | _print_error(error) | ||||
| @@ -157,10 +178,10 @@ def _extract_model_name(model_path): | |||||
| Extract model name from model path. | Extract model name from model path. | ||||
| Args: | Args: | ||||
| model_path(str): Path of Converted model. | |||||
| model_path (str): Path of Converted model. | |||||
| Returns: | Returns: | ||||
| str: Name of Converted model. | |||||
| str, name of Converted model. | |||||
| """ | """ | ||||
| base_path = os.path.basename(model_path) | base_path = os.path.basename(model_path) | ||||
| @@ -245,14 +266,22 @@ def main_graph_base_converter(file_config): | |||||
| raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") | ||||
| if frame_type == FrameworkType.PYTORCH.value: | if frame_type == FrameworkType.PYTORCH.value: | ||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||||
| sample_shape=file_config['shape'], | |||||
| input_nodes=file_config['input_nodes'] if file_config['input_nodes'] | |||||
| else 'input.1', | |||||
| output_nodes=file_config['output_nodes'] if file_config['output_nodes'] | |||||
| else '', | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| if graph_path.endswith('.onnx'): | |||||
| check_params = ['input_nodes', 'output_nodes'] | |||||
| check_params_exist(check_params, file_config) | |||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||||
| sample_shape=file_config['shape'], | |||||
| input_nodes=file_config['input_nodes'], | |||||
| output_nodes=file_config['output_nodes'], | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| else: | |||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||||
| sample_shape=file_config['shape'], | |||||
| input_nodes='input.1', | |||||
| output_nodes='', | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| elif frame_type == FrameworkType.TENSORFLOW.value: | elif frame_type == FrameworkType.TENSORFLOW.value: | ||||
| check_params = ['input_nodes', 'output_nodes'] | check_params = ['input_nodes', 'output_nodes'] | ||||
| check_params_exist(check_params, file_config) | check_params_exist(check_params, file_config) | ||||
| @@ -15,9 +15,9 @@ | |||||
| """Main Generator module.""" | """Main Generator module.""" | ||||
| import copy | import copy | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from importlib import import_module | |||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from mindspore import Tensor | |||||
| from mindinsight.mindconverter.common.exceptions import GeneratorError | from mindinsight.mindconverter.common.exceptions import GeneratorError | ||||
| from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope | ||||
| @@ -493,6 +493,7 @@ class Generator: | |||||
| def generate_checkpoint(self): | def generate_checkpoint(self): | ||||
| """Generate checkpoint.""" | """Generate checkpoint.""" | ||||
| mindspore = import_module('mindspore') | |||||
| trainable_weights_dict = dict() | trainable_weights_dict = dict() | ||||
| weight_map = list() | weight_map = list() | ||||
| for node_name, node_inst in self.node_structs.items(): | for node_name, node_inst in self.node_structs.items(): | ||||
| @@ -507,8 +508,8 @@ class Generator: | |||||
| weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key)) | weight_name = SEPARATOR_BTW_NAME_AND_ID.join((weights_scope_name, weight_key)) | ||||
| else: | else: | ||||
| weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key)) | weight_name = LINK_IN_WEIGHT_NAME.join((weights_scope_name, weight_key)) | ||||
| weight_shape = Tensor(value_data).shape | |||||
| data_type = Tensor(value_data).dtype | |||||
| weight_shape = mindspore.Tensor(value_data).shape | |||||
| data_type = mindspore.Tensor(value_data).dtype | |||||
| trainable_weights_dict[weight_name] = value_data | trainable_weights_dict[weight_name] = value_data | ||||
| onnx_weight_name = onnx_weight_inst[idx].name | onnx_weight_name = onnx_weight_inst[idx].name | ||||
| @@ -534,7 +535,7 @@ class Generator: | |||||
| for weight_name, weight_value in trainable_weights_dict.items(): | for weight_name, weight_value in trainable_weights_dict.items(): | ||||
| obj = { | obj = { | ||||
| 'name': weight_name, | 'name': weight_name, | ||||
| 'data': Tensor(weight_value) | |||||
| 'data': mindspore.Tensor(weight_value) | |||||
| } | } | ||||
| save_obj.append(obj) | save_obj.append(obj) | ||||
| @@ -16,6 +16,7 @@ | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph | ||||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode | from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode | ||||
| @@ -196,12 +197,12 @@ class OnnxGraph(Graph): | |||||
| Returns: | Returns: | ||||
| object, ONNX model. | object, ONNX model. | ||||
| """ | """ | ||||
| tf_input_nodes = kwargs.get('input_nodes') | |||||
| tf_output_nodes = kwargs.get('output_nodes') | |||||
| input_nodes = kwargs.get('input_nodes') | |||||
| output_nodes = kwargs.get('output_nodes') | |||||
| if graph_path.endswith('.pb'): | if graph_path.endswith('.pb'): | ||||
| onnx_model = TFGraphParser.parse(graph_path, | onnx_model = TFGraphParser.parse(graph_path, | ||||
| input_nodes=tf_input_nodes, | |||||
| output_nodes=tf_output_nodes) | |||||
| input_nodes=input_nodes, | |||||
| output_nodes=output_nodes) | |||||
| elif graph_path.endswith('.onnx'): | elif graph_path.endswith('.onnx'): | ||||
| onnx = import_module('onnx') | onnx = import_module('onnx') | ||||
| onnx_model = onnx.load(graph_path) | onnx_model = onnx.load(graph_path) | ||||
| @@ -212,4 +213,7 @@ class OnnxGraph(Graph): | |||||
| else: | else: | ||||
| onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) | ||||
| onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | |||||
| if input_nodes not in onnx_inputs: | |||||
| raise ModelNotSupportError(f"input nodes({input_nodes}) is not in model inputs ({onnx_inputs}).") | |||||
| return onnx_model | return onnx_model | ||||
| @@ -18,6 +18,7 @@ from importlib import import_module | |||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model | ||||
| @@ -92,6 +93,9 @@ class OnnxSimplify: | |||||
| self._constant_nodes = copy.deepcopy(const_nodes) | self._constant_nodes = copy.deepcopy(const_nodes) | ||||
| @ModelNotSupportError.check_except( | |||||
| "Error occurs in loading model, please check your model or runtime environment integrity." | |||||
| ) | |||||
| def _onnx_infer(self, infer_inputs_shape): | def _onnx_infer(self, infer_inputs_shape): | ||||
| """ | """ | ||||
| Run onnx inference to get outputs of constant nodes. | Run onnx inference to get outputs of constant nodes. | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Third party graph parser.""" | """Third party graph parser.""" | ||||
| import multiprocessing as mp | |||||
| import os | import os | ||||
| from importlib import import_module | from importlib import import_module | ||||
| @@ -67,6 +68,30 @@ class PyTorchGraphParser(GraphParser): | |||||
| opset_version (int): Op set version of onnx. | opset_version (int): Op set version of onnx. | ||||
| """ | """ | ||||
| output_queue = mp.Queue() | |||||
| process = mp.Process(target=PyTorchGraphParser._pytorch_graph_to_proto, | |||||
| args=(output_queue, model_path, sample_shape, opset_version)) | |||||
| process.start() | |||||
| proto = output_queue.get() | |||||
| process.join() | |||||
| onnx = import_module('onnx') | |||||
| onnx_model = onnx.load_model_from_string(proto) | |||||
| return onnx_model | |||||
| @staticmethod | |||||
| def _pytorch_graph_to_proto(output_queue, model_path, sample_shape, opset_version): | |||||
| """ | |||||
| Convert pytorch graph to pytorch proto. | |||||
| Args: | |||||
| output_queue (Queue): Output queue from multi-processing. | |||||
| model_path (str): Path to the Pytorch model. | |||||
| sample_shape (tuple): Input shape to generate onnx model. | |||||
| opset_version (int): Op set version of onnx. | |||||
| """ | |||||
| torch = import_module('torch') | torch = import_module('torch') | ||||
| has_cuda = torch.cuda.is_available() | has_cuda = torch.cuda.is_available() | ||||
| if has_cuda: | if has_cuda: | ||||
| @@ -102,7 +127,4 @@ class PyTorchGraphParser(GraphParser): | |||||
| operator_export_type, True, False, dict(), | operator_export_type, True, False, dict(), | ||||
| True, False) | True, False) | ||||
| onnx = import_module('onnx') | |||||
| onnx_model = onnx.load_model_from_string(proto) | |||||
| return onnx_model | |||||
| output_queue.put(proto) | |||||