| @@ -18,7 +18,5 @@ import os | |||||
| TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates') | TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates') | ||||
| SUPPORT_MINDSPORE_VERSION = '0.7.0' | SUPPORT_MINDSPORE_VERSION = '0.7.0' | ||||
| SUPPORT_RUN_DRIVER_VERSION = 'C75' | |||||
| SUPPORT_CUDA_VERSION = '10.1' | |||||
| QUESTION_START = '>>> ' | QUESTION_START = '>>> ' | ||||
| @@ -56,22 +56,21 @@ class GenericNetwork(BaseNetwork): | |||||
| dict, configuration value to network. | dict, configuration value to network. | ||||
| """ | """ | ||||
| if settings: | if settings: | ||||
| config = {'loss': settings['loss'], | |||||
| 'optimizer': settings['optimizer'], | |||||
| 'dataset': settings['dataset']} | |||||
| self.settings.update(config) | |||||
| return config | |||||
| loss = self.ask_loss_function() | |||||
| optimizer = self.ask_optimizer() | |||||
| dataset = self.ask_dataset() | |||||
| self._dataset_maker = load_dataset_maker(dataset) | |||||
| config = dict(settings) | |||||
| dataset_name = settings['dataset'] | |||||
| self._dataset_maker = load_dataset_maker(dataset_name) | |||||
| else: | |||||
| loss = self.ask_loss_function() | |||||
| optimizer = self.ask_optimizer() | |||||
| dataset_name = self.ask_dataset() | |||||
| self._dataset_maker = load_dataset_maker(dataset_name) | |||||
| dataset_config = self._dataset_maker.configure() | |||||
| config = {'loss': loss, | |||||
| 'optimizer': optimizer, | |||||
| 'dataset': dataset_name} | |||||
| config.update(dataset_config) | |||||
| self._dataset_maker.set_network(self) | self._dataset_maker.set_network(self) | ||||
| dataset_config = self._dataset_maker.configure() | |||||
| config = {'loss': loss, | |||||
| 'optimizer': optimizer, | |||||
| 'dataset': dataset} | |||||
| config.update(dataset_config) | |||||
| self.settings.update(config) | self.settings.update(config) | ||||
| return config | return config | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test the wizard module.""" | |||||
| @@ -0,0 +1,74 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test CreateProject class.""" | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| from unittest.mock import patch | |||||
| from mindinsight.wizard.base.source_file import SourceFile | |||||
| from mindinsight.wizard.create_project import CreateProject | |||||
| from mindinsight.wizard.network.generic_network import GenericNetwork | |||||
| from tests.ut.wizard.utils import generate_file | |||||
| class TestCreateProject: | |||||
| """Test SourceFile""" | |||||
| workspace_dir = None | |||||
| def setup_method(self): | |||||
| """Setup before call test method.""" | |||||
| self.workspace_dir = tempfile.mkdtemp() | |||||
| def teardown_method(self): | |||||
| """Tear down after call test method.""" | |||||
| self._remove_dirs() | |||||
| self.workspace_dir = None | |||||
| def _remove_dirs(self): | |||||
| """Recursively delete a directory tree.""" | |||||
| if self.workspace_dir and os.path.exists(self.workspace_dir): | |||||
| shutil.rmtree(self.workspace_dir) | |||||
| @staticmethod | |||||
| def _generate_file(file): | |||||
| """Create a file and write content.""" | |||||
| generate_file(file, "template file.") | |||||
| @patch.object(GenericNetwork, 'generate') | |||||
| @patch.object(GenericNetwork, 'configure') | |||||
| @patch.object(CreateProject, 'ask_network') | |||||
| @patch.object(CreateProject, 'echo_notice') | |||||
| @patch('os.getcwd') | |||||
| def test_run(self, mock_getcwd, mock_echo_notice, mock_ask_network, mock_config, mock_generate): | |||||
| """Test run method of CreateProject.""" | |||||
| source_file = SourceFile() | |||||
| source_file.template_file_path = os.path.join(self.workspace_dir, 'templates', 'train.py-tpl') | |||||
| source_file.file_relative_path = 'train.py' | |||||
| self._generate_file(source_file.template_file_path) | |||||
| # mock os.getcwd method | |||||
| mock_getcwd.return_value = self.workspace_dir | |||||
| mock_echo_notice.return_value = None | |||||
| mock_ask_network.return_value = 'lenet' | |||||
| mock_config.return_value = None | |||||
| mock_generate.return_value = [source_file] | |||||
| project_name = 'test' | |||||
| new_project = CreateProject() | |||||
| new_project.run({'name': project_name}) | |||||
| assert os.path.exists(os.path.join(self.workspace_dir, project_name)) | |||||
| assert os.access(os.path.join(self.workspace_dir, project_name, 'train.py'), mode=os.F_OK | os.R_OK | os.W_OK) | |||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test GenericNetwork class.""" | |||||
| import os | |||||
| import pytest | |||||
| from mindinsight.wizard.network import lenet | |||||
| class TestGenericNetwork: | |||||
| """Test SourceFile""" | |||||
| def test_generate_scripts(self): | |||||
| """Test network object to generate network scripts""" | |||||
| network_inst = lenet.Network() | |||||
| network_inst.configure({ | |||||
| "loss": "SoftmaxCrossEntropyWithLogits", | |||||
| "optimizer": "Momentum", | |||||
| "dataset": "mnist"}) | |||||
| sources_files = network_inst.generate() | |||||
| dataset_source_file = None | |||||
| config_source_file = None | |||||
| shell_script_dir_files = [] | |||||
| out_files = [] | |||||
| for sources_file in sources_files: | |||||
| if sources_file.file_relative_path == 'src/dataset.py': | |||||
| dataset_source_file = sources_file | |||||
| elif sources_file.file_relative_path == 'src/config.py': | |||||
| config_source_file = sources_file | |||||
| elif sources_file.file_relative_path.startswith('scripts'): | |||||
| shell_script_dir_files.append(sources_file) | |||||
| elif not os.path.dirname(sources_file.file_relative_path): | |||||
| out_files.append(sources_file) | |||||
| else: | |||||
| pass | |||||
| assert sources_files | |||||
| assert dataset_source_file is not None | |||||
| assert config_source_file is not None | |||||
| assert shell_script_dir_files | |||||
| assert out_files | |||||
| def test_config(self): | |||||
| """Test network object to config.""" | |||||
| network_inst = lenet.Network() | |||||
| settings = { | |||||
| "loss": "SoftmaxCrossEntropyWithLogits", | |||||
| "optimizer": "Momentum", | |||||
| "dataset": "mnist"} | |||||
| configurations = network_inst.configure(settings) | |||||
| assert configurations["dataset"] == settings["dataset"] | |||||
| assert configurations["loss"] == settings["loss"] | |||||
| assert configurations["optimizer"] == settings["optimizer"] | |||||
| settings["dataset"] = "mnist_another" | |||||
| with pytest.raises(ModuleNotFoundError) as exec_info: | |||||
| network_inst.configure(settings) | |||||
| assert exec_info.value.name == f'mindinsight.wizard.dataset.{settings["dataset"]}' | |||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test SourceFile class.""" | |||||
| import os | |||||
| import shutil | |||||
| import stat | |||||
| import tempfile | |||||
| import pytest | |||||
| from mindinsight.wizard.base.source_file import SourceFile | |||||
| from tests.ut.wizard.utils import generate_file | |||||
| class TestSourceFile: | |||||
| """Test SourceFile""" | |||||
| def setup_method(self): | |||||
| """Setup before call test method.""" | |||||
| self._input_dir = tempfile.mkdtemp() | |||||
| self._output_dir = tempfile.mkdtemp() | |||||
| def teardown_method(self): | |||||
| """Tear down after call test method.""" | |||||
| self._remove_dirs() | |||||
| self._input_dir = None | |||||
| self._output_dir = None | |||||
| def _remove_dirs(self): | |||||
| """Recursively delete a directory tree.""" | |||||
| for temp_dir in [self._input_dir, self._output_dir]: | |||||
| if temp_dir and os.path.exists(temp_dir): | |||||
| shutil.rmtree(temp_dir) | |||||
| @staticmethod | |||||
| def _generate_file(file, stat_mode): | |||||
| """Create a file and write content.""" | |||||
| generate_file(file, "template file.", stat_mode) | |||||
| @pytest.mark.parametrize('params', [{ | |||||
| 'file_relative_path': 'src/config.py', | |||||
| 'template_file_path': 'src/config.py-tpl' | |||||
| }, { | |||||
| 'file_relative_path': 'src/lenet.py', | |||||
| 'template_file_path': 'src/lenet.py-tpl' | |||||
| }, { | |||||
| 'file_relative_path': 'README.md', | |||||
| 'template_file_path': 'README.md-tpl' | |||||
| }, { | |||||
| 'file_relative_path': 'train.py', | |||||
| 'template_file_path': 'train.py-tpl' | |||||
| }]) | |||||
| def test_write_py(self, params): | |||||
| """Test write python script file""" | |||||
| source_file = SourceFile() | |||||
| source_file.file_relative_path = params['file_relative_path'] | |||||
| source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) | |||||
| self._generate_file(source_file.template_file_path, stat.S_IRUSR) | |||||
| # start write | |||||
| source_file.write(self._output_dir) | |||||
| output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) | |||||
| assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK) | |||||
| assert stat.filemode(os.stat(output_file_path).st_mode) == '-rw-------' | |||||
| @pytest.mark.parametrize('params', [{ | |||||
| 'file_relative_path': 'scripts/run_eval.sh', | |||||
| 'template_file_path': 'scripts/run_eval.sh-tpl' | |||||
| }, { | |||||
| 'file_relative_path': 'run_distribute_train.sh', | |||||
| 'template_file_path': 'run_distribute_train.sh-tpl' | |||||
| }]) | |||||
| def test_write_sh(self, params): | |||||
| """Test write shell script file""" | |||||
| source_file = SourceFile() | |||||
| source_file.file_relative_path = params['file_relative_path'] | |||||
| source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) | |||||
| self._generate_file(source_file.template_file_path, stat.S_IRUSR) | |||||
| # start write | |||||
| source_file.write(self._output_dir) | |||||
| output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) | |||||
| assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK | os.X_OK) | |||||
| assert stat.filemode(os.stat(output_file_path).st_mode) == '-rwx------' | |||||
| @@ -0,0 +1,188 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test TemplateManager class.""" | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import textwrap | |||||
| from mindinsight.wizard.base.templates import TemplateManager | |||||
| from tests.ut.wizard.utils import generate_file | |||||
| def create_template_files(template_dir): | |||||
| """Create network template files.""" | |||||
| all_template_files = [] | |||||
| train_file = os.path.join(template_dir, 'train.py-tpl') | |||||
| generate_file(train_file, | |||||
| textwrap.dedent("""\ | |||||
| {% if loss=='SoftmaxCrossEntropyWithLogits' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
| {% elif loss=='SoftmaxCrossEntropyExpand' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) | |||||
| {% endif %} | |||||
| """)) | |||||
| all_template_files.append(train_file) | |||||
| os.mkdir(os.path.join(template_dir, 'src')) | |||||
| config_file = os.path.join(template_dir, 'src', 'config.py-tpl') | |||||
| generate_file(config_file, | |||||
| textwrap.dedent("""\ | |||||
| { | |||||
| 'num_classes': 10, | |||||
| {% if optimizer=='Momentum' %} | |||||
| 'lr': 0.01, | |||||
| "momentum": 0.9, | |||||
| {% elif optimizer=='SGD' %} | |||||
| 'lr': 0.1, | |||||
| {% else %} | |||||
| 'lr': 0.001, | |||||
| {% endif %} | |||||
| 'epoch_size': 1 | |||||
| } | |||||
| """)) | |||||
| all_template_files.append(config_file) | |||||
| os.mkdir(os.path.join(template_dir, 'scripts')) | |||||
| run_standalone_train_file = os.path.join(template_dir, 'scripts', 'run_standalone_train.sh-tpl') | |||||
| generate_file(run_standalone_train_file, | |||||
| textwrap.dedent("""\ | |||||
| python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & | |||||
| """)) | |||||
| all_template_files.append(run_standalone_train_file) | |||||
| os.mkdir(os.path.join(template_dir, 'dataset')) | |||||
| os.mkdir(os.path.join(template_dir, 'dataset', 'mnist')) | |||||
| dataset_file = os.path.join(template_dir, 'dataset', 'mnist', 'dataset.py-tpl') | |||||
| generate_file(dataset_file, | |||||
| textwrap.dedent("""\ | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| """)) | |||||
| all_template_files.append(dataset_file) | |||||
| return all_template_files | |||||
| class TestTemplateManager: | |||||
| """Test TemplateManager""" | |||||
| template_dir = None | |||||
| all_template_files = [] | |||||
| def setup_method(self): | |||||
| """Setup before call test method.""" | |||||
| self.template_dir = tempfile.mkdtemp() | |||||
| self.all_template_files = create_template_files(self.template_dir) | |||||
| def teardown_method(self): | |||||
| """Tear down after call test method.""" | |||||
| self._remove_dirs() | |||||
| self.template_dir = None | |||||
| def _remove_dirs(self): | |||||
| """Recursively delete a directory tree.""" | |||||
| if self.template_dir and os.path.exists(self.template_dir): | |||||
| shutil.rmtree(self.template_dir) | |||||
| def test_template_files(self): | |||||
| """Test get_template_files method.""" | |||||
| src_file_num = 1 | |||||
| dataset_file_num = 1 | |||||
| template_mgr = TemplateManager(self.template_dir) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert set(all_files) == set(self.all_template_files) | |||||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert len(all_files) == src_file_num | |||||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert len(all_files) == dataset_file_num | |||||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src']) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert len(all_files) == len(self.all_template_files) - src_file_num | |||||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num | |||||
| template_mgr = TemplateManager(self.template_dir, | |||||
| exclude_dirs=['src', 'dataset'], | |||||
| exclude_files=['train.py-tpl']) | |||||
| all_files = template_mgr.get_template_files() | |||||
| assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num - 1 | |||||
| def test_src_render(self): | |||||
| """Test render file in src directory.""" | |||||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) | |||||
| source_files = template_mgr.render(optimizer='Momentum') | |||||
| assert source_files[0].content == textwrap.dedent("""\ | |||||
| { | |||||
| 'num_classes': 10, | |||||
| 'lr': 0.01, | |||||
| "momentum": 0.9, | |||||
| 'epoch_size': 1 | |||||
| } | |||||
| """) | |||||
| source_files = template_mgr.render(optimizer='SGD') | |||||
| assert source_files[0].content == textwrap.dedent("""\ | |||||
| { | |||||
| 'num_classes': 10, | |||||
| 'lr': 0.1, | |||||
| 'epoch_size': 1 | |||||
| } | |||||
| """) | |||||
| source_files = template_mgr.render() | |||||
| assert source_files[0].content == textwrap.dedent("""\ | |||||
| { | |||||
| 'num_classes': 10, | |||||
| 'lr': 0.001, | |||||
| 'epoch_size': 1 | |||||
| } | |||||
| """) | |||||
| def test_dataset_render(self): | |||||
| """Test render file in dataset directory.""" | |||||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) | |||||
| source_files = template_mgr.render() | |||||
| assert source_files[0].content == textwrap.dedent("""\ | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| """) | |||||
| assert source_files[0].file_relative_path == 'mnist/dataset.py' | |||||
| assert source_files[0].template_file_path == os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl') | |||||
| def test_assemble_render(self): | |||||
| """Test render assemble files in template directory.""" | |||||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) | |||||
| source_files = template_mgr.render(loss='SoftmaxCrossEntropyWithLogits') | |||||
| unmatched_files = [] | |||||
| for source_file in source_files: | |||||
| if source_file.template_file_path == os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'): | |||||
| assert source_file.content == textwrap.dedent("""\ | |||||
| python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & | |||||
| """) | |||||
| assert source_file.file_relative_path == 'scripts/run_standalone_train.sh' | |||||
| elif source_file.template_file_path == os.path.join(self.template_dir, 'train.py-tpl'): | |||||
| assert source_file.content == textwrap.dedent("""\ | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
| """) | |||||
| assert source_file.file_relative_path == 'train.py' | |||||
| else: | |||||
| unmatched_files.append(source_file) | |||||
| assert not unmatched_files | |||||
| @@ -0,0 +1,28 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Utils method.""" | |||||
| import os | |||||
| import stat | |||||
| def generate_file(file, template_content, mode=None): | |||||
| """Create a file and write content.""" | |||||
| os.makedirs(os.path.dirname(file), mode=stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR, exist_ok=True) | |||||
| with open(file, 'w') as fp: | |||||
| fp.write(template_content) | |||||
| if mode: | |||||
| os.chmod(file, mode) | |||||
| else: | |||||
| os.chmod(file, stat.S_IRUSR) | |||||