| @@ -17,6 +17,7 @@ import copy | |||||
| import importlib | import importlib | ||||
| import inspect | import inspect | ||||
| import os | import os | ||||
| import re | |||||
| import stat | import stat | ||||
| from mindinsight.mindconverter.config import ALL_MAPPING | from mindinsight.mindconverter.config import ALL_MAPPING | ||||
| @@ -363,9 +364,19 @@ class Converter: | |||||
| for key, value in mapping.items(): | for key, value in mapping.items(): | ||||
| code = code.replace(key, value) | code = code.replace(key, value) | ||||
| code = 'import mindspore.ops.operations as P\n' + code | |||||
| code = 'import mindspore.nn as nn\n' + code | |||||
| code = 'import mindspore\n' + code | |||||
| source_lines = code.splitlines(keepends=True) | |||||
| valid_line_num = 0 | |||||
| # find the first valid code line of the source | |||||
| for num, source in enumerate(source_lines): | |||||
| if re.search(r'^[a-z]\w+', source): | |||||
| valid_line_num = num | |||||
| break | |||||
| source_lines.insert(valid_line_num, 'import mindspore.ops.operations as P\n') | |||||
| source_lines.insert(valid_line_num, 'import mindspore.nn as nn\n') | |||||
| source_lines.insert(valid_line_num, 'import mindspore\n') | |||||
| code = ''.join(source_lines) | |||||
| self.convert_info += '||[Import Add] Add follow import sentences:\n' | self.convert_info += '||[Import Add] Add follow import sentences:\n' | ||||
| self.convert_info += 'import mindspore.ops.operations as P\n' | self.convert_info += 'import mindspore.ops.operations as P\n' | ||||
| @@ -456,6 +467,6 @@ def main(files_config): | |||||
| module_name = '.'.join(in_file_split) | module_name = '.'.join(in_file_split) | ||||
| convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) | convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) | ||||
| in_module = files_config['in_module'] | |||||
| in_module = files_config.get('in_module') | |||||
| if in_module: | if in_module: | ||||
| convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) | convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) | ||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,53 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """The st config.""" | |||||
| import os | |||||
| import shutil | |||||
| import sys | |||||
| import tempfile | |||||
| import types | |||||
| import pytest | |||||
| OUTPUT_DIR = tempfile.mktemp(prefix='test_mindconverter_output_dir_') | |||||
| sys.modules['torch'] = types.ModuleType('torch') | |||||
| nn = types.ModuleType('torch.nn') | |||||
| sys.modules['torch.nn'] = nn | |||||
| nn.Module = type('Module', (object,), dict()) | |||||
| sys.modules['torch.nn.functional'] = types.ModuleType('torch.nn.functional') | |||||
| @pytest.fixture(scope='session') | |||||
| def create_output_dir(): | |||||
| """Create output directory.""" | |||||
| try: | |||||
| if os.path.exists(OUTPUT_DIR): | |||||
| shutil.rmtree(OUTPUT_DIR) | |||||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||||
| mode = permissions << 6 | |||||
| if not os.path.exists(OUTPUT_DIR): | |||||
| os.mkdir(OUTPUT_DIR, mode=mode) | |||||
| yield | |||||
| finally: | |||||
| if os.path.exists(OUTPUT_DIR): | |||||
| shutil.rmtree(OUTPUT_DIR) | |||||
| @pytest.fixture() | |||||
| def output(): | |||||
| """Get the output directory.""" | |||||
| return OUTPUT_DIR | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test network script of LeNet.""" | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.operations as P | |||||
| # import torch.nn as nn | |||||
| # import torch.nn.functional as F | |||||
| class TestLeNet(nn.Cell): | |||||
| """TestLeNet network.""" | |||||
| def __init__(self): | |||||
| self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, pad_mode='pad', has_bias=True) | |||||
| self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, pad_mode='pad', has_bias=True) | |||||
| self.fc1 = nn.Dense(in_channels=16 * 5 * 5, out_channels=120) | |||||
| self.fc2 = nn.Dense(in_channels=120, out_channels=84) | |||||
| self.fc3 = nn.Dense(in_channels=84, out_channels=10) | |||||
| def construct(self, input_x): | |||||
| """Callback method.""" | |||||
| out = self.forward1(input_x) | |||||
| return out | |||||
| def forward1(self, input_x): | |||||
| """forward1 method.""" | |||||
| out = P.ReLU()(self.conv1(input_x)) | |||||
| out = P.MaxPool(2, None, 'valid')(out) | |||||
| out = P.ReLU()(self.conv2(out)) | |||||
| out = P.MaxPool(2, None, 'valid')(out) | |||||
| out = P.Reshape()(out, (P.Shape()(out)[0], -1,)) | |||||
| out = P.ReLU()(self.fc1(out)) | |||||
| out = P.ReLU()(self.fc2(out)) | |||||
| out = self.fc3(out) | |||||
| return out | |||||
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test network script of LeNet.""" | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| class TestLeNet(nn.Module): | |||||
| """TestLeNet network.""" | |||||
| def __init__(self): | |||||
| self.conv1 = nn.Conv2d(3, 6, 5) | |||||
| self.conv2 = nn.Conv2d(6, 16, 5) | |||||
| self.fc1 = nn.Linear(16 * 5 * 5, 120) | |||||
| self.fc2 = nn.Linear(120, 84) | |||||
| self.fc3 = nn.Linear(84, 10) | |||||
| def forward(self, input_x): | |||||
| """Callback method.""" | |||||
| out = self.forward1(input_x) | |||||
| return out | |||||
| def forward1(self, input_x): | |||||
| """forward1 method.""" | |||||
| out = F.relu(self.conv1(input_x)) | |||||
| out = F.max_pool2d(out, 2) | |||||
| out = F.relu(self.conv2(out)) | |||||
| out = F.max_pool2d(out, 2) | |||||
| out = out.view(out.size(0), -1) | |||||
| out = F.relu(self.fc1(out)) | |||||
| out = F.relu(self.fc2(out)) | |||||
| out = self.fc3(out) | |||||
| return out | |||||
| @@ -0,0 +1,79 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Fuction: | |||||
| Test mindconverter to convert user's PyTorch network script. | |||||
| Usage: | |||||
| pytest tests/st/func/mindconverter | |||||
| """ | |||||
| import difflib | |||||
| import os | |||||
| import sys | |||||
| import pytest | |||||
| from mindinsight.mindconverter.converter import main | |||||
| @pytest.mark.usefixtures('create_output_dir') | |||||
| class TestConverter: | |||||
| """Test Converter module.""" | |||||
| @classmethod | |||||
| def setup_class(cls): | |||||
| """Setup method.""" | |||||
| cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') | |||||
| sys.path.insert(0, cls.script_dir) | |||||
| @classmethod | |||||
| def teardown_class(cls): | |||||
| """Teardown method.""" | |||||
| sys.path.remove(cls.script_dir) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_convert_lenet(self, output): | |||||
| """Test LeNet script of the PyTorch convert to MindSpore script""" | |||||
| script_filename = "lenet_script.py" | |||||
| expect_filename = "lenet_converted.py" | |||||
| files_config = { | |||||
| 'root_path': self.script_dir, | |||||
| 'in_files': [os.path.join(self.script_dir, script_filename)], | |||||
| 'outfile_dir': output, | |||||
| 'report_dir': output | |||||
| } | |||||
| main(files_config) | |||||
| assert os.path.isfile(os.path.join(output, script_filename)) | |||||
| with open(os.path.join(output, script_filename)) as converted_f: | |||||
| converted_source = converted_f.readlines() | |||||
| with open(os.path.join(self.script_dir, expect_filename)) as expect_f: | |||||
| expect_source = expect_f.readlines() | |||||
| diff = difflib.ndiff(converted_source, expect_source) | |||||
| diff_lines = 0 | |||||
| for line in diff: | |||||
| if line.startswith('+'): | |||||
| diff_lines += 1 | |||||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | |||||
| assert converted_ratio >= 80 | |||||
| @@ -35,7 +35,7 @@ class TestForwardCall: | |||||
| self.fc3 = nn.Linear(84, 10) | self.fc3 = nn.Linear(84, 10) | ||||
| def forward(self, x): | def forward(self, x): | ||||
| out = self.forward1(out) | |||||
| out = self.forward1(x) | |||||
| return out | return out | ||||
| def forward1(self, x): | def forward1(self, x): | ||||
| @@ -57,7 +57,7 @@ class TestForwardCall: | |||||
| forward_call = ForwardCall("mock") | forward_call = ForwardCall("mock") | ||||
| forward_call.visit(ast.parse(self.source)) | forward_call.visit(ast.parse(self.source)) | ||||
| expect_calls = ['TestNet.forward1', | |||||
| expect_calls = ['TestNet.forward', | |||||
| 'TestNet.forward1', | 'TestNet.forward1', | ||||
| 'F.relu', | 'F.relu', | ||||
| 'TestNet.conv1', | 'TestNet.conv1', | ||||
| @@ -69,5 +69,7 @@ class TestForwardCall: | |||||
| 'TestNet.fc2', | 'TestNet.fc2', | ||||
| 'TestNet.fc3', | 'TestNet.fc3', | ||||
| ] | ] | ||||
| assert [forward_call.calls].sort() == expect_calls.sort() | |||||
| expect_calls.sort() | |||||
| real_calls = list(forward_call.calls) | |||||
| real_calls.sort() | |||||
| assert real_calls == expect_calls | |||||