| @@ -17,6 +17,7 @@ import copy | |||
| import importlib | |||
| import inspect | |||
| import os | |||
| import re | |||
| import stat | |||
| from mindinsight.mindconverter.config import ALL_MAPPING | |||
| @@ -363,9 +364,19 @@ class Converter: | |||
| for key, value in mapping.items(): | |||
| code = code.replace(key, value) | |||
| code = 'import mindspore.ops.operations as P\n' + code | |||
| code = 'import mindspore.nn as nn\n' + code | |||
| code = 'import mindspore\n' + code | |||
| source_lines = code.splitlines(keepends=True) | |||
| valid_line_num = 0 | |||
| # find the first valid code line of the source | |||
| for num, source in enumerate(source_lines): | |||
| if re.search(r'^[a-z]\w+', source): | |||
| valid_line_num = num | |||
| break | |||
| source_lines.insert(valid_line_num, 'import mindspore.ops.operations as P\n') | |||
| source_lines.insert(valid_line_num, 'import mindspore.nn as nn\n') | |||
| source_lines.insert(valid_line_num, 'import mindspore\n') | |||
| code = ''.join(source_lines) | |||
| self.convert_info += '||[Import Add] Add follow import sentences:\n' | |||
| self.convert_info += 'import mindspore.ops.operations as P\n' | |||
| @@ -456,6 +467,6 @@ def main(files_config): | |||
| module_name = '.'.join(in_file_split) | |||
| convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) | |||
| in_module = files_config['in_module'] | |||
| in_module = files_config.get('in_module') | |||
| if in_module: | |||
| convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """The st config.""" | |||
| import os | |||
| import shutil | |||
| import sys | |||
| import tempfile | |||
| import types | |||
| import pytest | |||
| OUTPUT_DIR = tempfile.mktemp(prefix='test_mindconverter_output_dir_') | |||
| sys.modules['torch'] = types.ModuleType('torch') | |||
| nn = types.ModuleType('torch.nn') | |||
| sys.modules['torch.nn'] = nn | |||
| nn.Module = type('Module', (object,), dict()) | |||
| sys.modules['torch.nn.functional'] = types.ModuleType('torch.nn.functional') | |||
| @pytest.fixture(scope='session') | |||
| def create_output_dir(): | |||
| """Create output directory.""" | |||
| try: | |||
| if os.path.exists(OUTPUT_DIR): | |||
| shutil.rmtree(OUTPUT_DIR) | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| mode = permissions << 6 | |||
| if not os.path.exists(OUTPUT_DIR): | |||
| os.mkdir(OUTPUT_DIR, mode=mode) | |||
| yield | |||
| finally: | |||
| if os.path.exists(OUTPUT_DIR): | |||
| shutil.rmtree(OUTPUT_DIR) | |||
| @pytest.fixture() | |||
| def output(): | |||
| """Get the output directory.""" | |||
| return OUTPUT_DIR | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test network script of LeNet.""" | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as P | |||
| # import torch.nn as nn | |||
| # import torch.nn.functional as F | |||
| class TestLeNet(nn.Cell): | |||
| """TestLeNet network.""" | |||
| def __init__(self): | |||
| self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, pad_mode='pad', has_bias=True) | |||
| self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, pad_mode='pad', has_bias=True) | |||
| self.fc1 = nn.Dense(in_channels=16 * 5 * 5, out_channels=120) | |||
| self.fc2 = nn.Dense(in_channels=120, out_channels=84) | |||
| self.fc3 = nn.Dense(in_channels=84, out_channels=10) | |||
| def construct(self, input_x): | |||
| """Callback method.""" | |||
| out = self.forward1(input_x) | |||
| return out | |||
| def forward1(self, input_x): | |||
| """forward1 method.""" | |||
| out = P.ReLU()(self.conv1(input_x)) | |||
| out = P.MaxPool(2, None, 'valid')(out) | |||
| out = P.ReLU()(self.conv2(out)) | |||
| out = P.MaxPool(2, None, 'valid')(out) | |||
| out = P.Reshape()(out, (P.Shape()(out)[0], -1,)) | |||
| out = P.ReLU()(self.fc1(out)) | |||
| out = P.ReLU()(self.fc2(out)) | |||
| out = self.fc3(out) | |||
| return out | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test network script of LeNet.""" | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class TestLeNet(nn.Module): | |||
| """TestLeNet network.""" | |||
| def __init__(self): | |||
| self.conv1 = nn.Conv2d(3, 6, 5) | |||
| self.conv2 = nn.Conv2d(6, 16, 5) | |||
| self.fc1 = nn.Linear(16 * 5 * 5, 120) | |||
| self.fc2 = nn.Linear(120, 84) | |||
| self.fc3 = nn.Linear(84, 10) | |||
| def forward(self, input_x): | |||
| """Callback method.""" | |||
| out = self.forward1(input_x) | |||
| return out | |||
| def forward1(self, input_x): | |||
| """forward1 method.""" | |||
| out = F.relu(self.conv1(input_x)) | |||
| out = F.max_pool2d(out, 2) | |||
| out = F.relu(self.conv2(out)) | |||
| out = F.max_pool2d(out, 2) | |||
| out = out.view(out.size(0), -1) | |||
| out = F.relu(self.fc1(out)) | |||
| out = F.relu(self.fc2(out)) | |||
| out = self.fc3(out) | |||
| return out | |||
| @@ -0,0 +1,79 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Fuction: | |||
| Test mindconverter to convert user's PyTorch network script. | |||
| Usage: | |||
| pytest tests/st/func/mindconverter | |||
| """ | |||
| import difflib | |||
| import os | |||
| import sys | |||
| import pytest | |||
| from mindinsight.mindconverter.converter import main | |||
| @pytest.mark.usefixtures('create_output_dir') | |||
| class TestConverter: | |||
| """Test Converter module.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Setup method.""" | |||
| cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') | |||
| sys.path.insert(0, cls.script_dir) | |||
| @classmethod | |||
| def teardown_class(cls): | |||
| """Teardown method.""" | |||
| sys.path.remove(cls.script_dir) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_convert_lenet(self, output): | |||
| """Test LeNet script of the PyTorch convert to MindSpore script""" | |||
| script_filename = "lenet_script.py" | |||
| expect_filename = "lenet_converted.py" | |||
| files_config = { | |||
| 'root_path': self.script_dir, | |||
| 'in_files': [os.path.join(self.script_dir, script_filename)], | |||
| 'outfile_dir': output, | |||
| 'report_dir': output | |||
| } | |||
| main(files_config) | |||
| assert os.path.isfile(os.path.join(output, script_filename)) | |||
| with open(os.path.join(output, script_filename)) as converted_f: | |||
| converted_source = converted_f.readlines() | |||
| with open(os.path.join(self.script_dir, expect_filename)) as expect_f: | |||
| expect_source = expect_f.readlines() | |||
| diff = difflib.ndiff(converted_source, expect_source) | |||
| diff_lines = 0 | |||
| for line in diff: | |||
| if line.startswith('+'): | |||
| diff_lines += 1 | |||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | |||
| assert converted_ratio >= 80 | |||
| @@ -35,7 +35,7 @@ class TestForwardCall: | |||
| self.fc3 = nn.Linear(84, 10) | |||
| def forward(self, x): | |||
| out = self.forward1(out) | |||
| out = self.forward1(x) | |||
| return out | |||
| def forward1(self, x): | |||
| @@ -57,7 +57,7 @@ class TestForwardCall: | |||
| forward_call = ForwardCall("mock") | |||
| forward_call.visit(ast.parse(self.source)) | |||
| expect_calls = ['TestNet.forward1', | |||
| expect_calls = ['TestNet.forward', | |||
| 'TestNet.forward1', | |||
| 'F.relu', | |||
| 'TestNet.conv1', | |||
| @@ -69,5 +69,7 @@ class TestForwardCall: | |||
| 'TestNet.fc2', | |||
| 'TestNet.fc3', | |||
| ] | |||
| assert [forward_call.calls].sort() == expect_calls.sort() | |||
| expect_calls.sort() | |||
| real_calls = list(forward_call.calls) | |||
| real_calls.sort() | |||
| assert real_calls == expect_calls | |||