| @@ -29,7 +29,7 @@ class ForwardCall(ast.NodeVisitor): | |||
| self.module_name = os.path.basename(filename).replace('.py', '') | |||
| self.name_stack = [] | |||
| self.forward_stack = [] | |||
| self.calls = [] | |||
| self.calls = set() | |||
| self.process() | |||
| def process(self): | |||
| @@ -68,7 +68,7 @@ class ForwardCall(ast.NodeVisitor): | |||
| self.forward_stack.append(func_name) | |||
| if node.name == 'forward': | |||
| self.calls.append(func_name) | |||
| self.calls.add(func_name) | |||
| self.generic_visit(node) | |||
| @@ -85,12 +85,12 @@ class ForwardCall(ast.NodeVisitor): | |||
| if isinstance(node.func, ast.Name): | |||
| if func_name not in ['super', 'str', 'repr']: | |||
| if self.forward_stack: | |||
| self.calls.append(func_name) | |||
| self.calls.add(func_name) | |||
| self.visit(node.func) | |||
| else: | |||
| if self.forward_stack: | |||
| if 'self' in func_name: | |||
| self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') | |||
| self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') | |||
| else: | |||
| self.calls.append(func_name) | |||
| self.calls.add(func_name) | |||
| self.visit(node.func) | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test config module.""" | |||
| from collections import OrderedDict | |||
| import pytest | |||
| from mindinsight.mindconverter.config import APIPt, REQUIRED | |||
| class TestAPIBase: | |||
| """Test the class of APIPt.""" | |||
| function_name = "func" | |||
| @pytest.mark.parametrize('parameters', ['(out.size(0), -1', '(2, 1, 0)']) | |||
| def test_parse_args_exception(self, parameters): | |||
| """Test parse arguments exception""" | |||
| parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) | |||
| api_parser = APIPt(self.function_name, parameters_spec) | |||
| with pytest.raises(ValueError): | |||
| api_parser.parse_args(api_parser.name, parameters) | |||
| def test_parse_single_arg(self): | |||
| """Test parse one argument""" | |||
| source = '(1)' | |||
| parameters_spec = OrderedDict(in_channels=REQUIRED) | |||
| api_parser = APIPt(self.function_name, parameters_spec) | |||
| parsed_args = api_parser.parse_args(api_parser.name, source) | |||
| assert parsed_args['in_channels'] == '1' | |||
| def test_parse_args(self): | |||
| """Test parse multiple arguments""" | |||
| source = '(1, 2)' | |||
| parameters_spec = OrderedDict(in_channels=REQUIRED, out_channels=REQUIRED) | |||
| api_parser = APIPt(self.function_name, parameters_spec) | |||
| parsed_args = api_parser.parse_args(api_parser.name, source) | |||
| assert parsed_args['in_channels'] == '1' | |||
| assert parsed_args['out_channels'] == '2' | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test forward_call module.""" | |||
| import ast | |||
| import textwrap | |||
| from unittest.mock import patch | |||
| from mindinsight.mindconverter.forward_call import ForwardCall | |||
| class TestForwardCall: | |||
| """Test the class of ForwardCall.""" | |||
| source = textwrap.dedent("""\ | |||
| import a | |||
| import a.nn as nn | |||
| import a.nn.functional as F | |||
| class TestNet: | |||
| def __init__(self): | |||
| self.conv1 = nn.Conv2d(3, 6, 5) | |||
| self.conv2 = nn.Conv2d(6, 16, 5) | |||
| self.fc1 = nn.Linear(16 * 5 * 5, 120) | |||
| self.fc2 = nn.Linear(120, 84) | |||
| self.fc3 = nn.Linear(84, 10) | |||
| def forward(self, x): | |||
| out = self.forward1(out) | |||
| return out | |||
| def forward1(self, x): | |||
| out = F.relu(self.conv1(x)) | |||
| out = F.max_pool2d(out, 2) | |||
| out = F.relu(self.conv2(out)) | |||
| out = F.max_pool2d(out, 2) | |||
| out = out.view(out.size(0), -1) | |||
| out = F.relu(self.fc1(out)) | |||
| out = F.relu(self.fc2(out)) | |||
| out = self.fc3(out) | |||
| return out | |||
| """) | |||
| @patch.object(ForwardCall, 'process') | |||
| def test_process(self, mock_process): | |||
| """Test the function of visit ast tree to find out forward functions.""" | |||
| mock_process.return_value = None | |||
| forward_call = ForwardCall("mock") | |||
| forward_call.visit(ast.parse(self.source)) | |||
| expect_calls = ['TestNet.forward1', | |||
| 'TestNet.forward1', | |||
| 'F.relu', | |||
| 'TestNet.conv1', | |||
| 'F.max_pool2d', | |||
| 'TestNet.conv2', | |||
| 'out.view', | |||
| 'out.size', | |||
| 'TestNet.fc1', | |||
| 'TestNet.fc2', | |||
| 'TestNet.fc3', | |||
| ] | |||
| assert [forward_call.calls].sort() == expect_calls.sort() | |||