| @@ -18,7 +18,5 @@ import os | |||
| TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates') | |||
| SUPPORT_MINDSPORE_VERSION = '0.7.0' | |||
| SUPPORT_RUN_DRIVER_VERSION = 'C75' | |||
| SUPPORT_CUDA_VERSION = '10.1' | |||
| QUESTION_START = '>>> ' | |||
| @@ -56,22 +56,21 @@ class GenericNetwork(BaseNetwork): | |||
| dict, configuration value to network. | |||
| """ | |||
| if settings: | |||
| config = {'loss': settings['loss'], | |||
| 'optimizer': settings['optimizer'], | |||
| 'dataset': settings['dataset']} | |||
| self.settings.update(config) | |||
| return config | |||
| loss = self.ask_loss_function() | |||
| optimizer = self.ask_optimizer() | |||
| dataset = self.ask_dataset() | |||
| self._dataset_maker = load_dataset_maker(dataset) | |||
| config = dict(settings) | |||
| dataset_name = settings['dataset'] | |||
| self._dataset_maker = load_dataset_maker(dataset_name) | |||
| else: | |||
| loss = self.ask_loss_function() | |||
| optimizer = self.ask_optimizer() | |||
| dataset_name = self.ask_dataset() | |||
| self._dataset_maker = load_dataset_maker(dataset_name) | |||
| dataset_config = self._dataset_maker.configure() | |||
| config = {'loss': loss, | |||
| 'optimizer': optimizer, | |||
| 'dataset': dataset_name} | |||
| config.update(dataset_config) | |||
| self._dataset_maker.set_network(self) | |||
| dataset_config = self._dataset_maker.configure() | |||
| config = {'loss': loss, | |||
| 'optimizer': optimizer, | |||
| 'dataset': dataset} | |||
| config.update(dataset_config) | |||
| self.settings.update(config) | |||
| return config | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test the wizard module.""" | |||
| @@ -0,0 +1,74 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test CreateProject class.""" | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| from unittest.mock import patch | |||
| from mindinsight.wizard.base.source_file import SourceFile | |||
| from mindinsight.wizard.create_project import CreateProject | |||
| from mindinsight.wizard.network.generic_network import GenericNetwork | |||
| from tests.ut.wizard.utils import generate_file | |||
| class TestCreateProject: | |||
| """Test SourceFile""" | |||
| workspace_dir = None | |||
| def setup_method(self): | |||
| """Setup before call test method.""" | |||
| self.workspace_dir = tempfile.mkdtemp() | |||
| def teardown_method(self): | |||
| """Tear down after call test method.""" | |||
| self._remove_dirs() | |||
| self.workspace_dir = None | |||
| def _remove_dirs(self): | |||
| """Recursively delete a directory tree.""" | |||
| if self.workspace_dir and os.path.exists(self.workspace_dir): | |||
| shutil.rmtree(self.workspace_dir) | |||
| @staticmethod | |||
| def _generate_file(file): | |||
| """Create a file and write content.""" | |||
| generate_file(file, "template file.") | |||
| @patch.object(GenericNetwork, 'generate') | |||
| @patch.object(GenericNetwork, 'configure') | |||
| @patch.object(CreateProject, 'ask_network') | |||
| @patch.object(CreateProject, 'echo_notice') | |||
| @patch('os.getcwd') | |||
| def test_run(self, mock_getcwd, mock_echo_notice, mock_ask_network, mock_config, mock_generate): | |||
| """Test run method of CreateProject.""" | |||
| source_file = SourceFile() | |||
| source_file.template_file_path = os.path.join(self.workspace_dir, 'templates', 'train.py-tpl') | |||
| source_file.file_relative_path = 'train.py' | |||
| self._generate_file(source_file.template_file_path) | |||
| # mock os.getcwd method | |||
| mock_getcwd.return_value = self.workspace_dir | |||
| mock_echo_notice.return_value = None | |||
| mock_ask_network.return_value = 'lenet' | |||
| mock_config.return_value = None | |||
| mock_generate.return_value = [source_file] | |||
| project_name = 'test' | |||
| new_project = CreateProject() | |||
| new_project.run({'name': project_name}) | |||
| assert os.path.exists(os.path.join(self.workspace_dir, project_name)) | |||
| assert os.access(os.path.join(self.workspace_dir, project_name, 'train.py'), mode=os.F_OK | os.R_OK | os.W_OK) | |||
| @@ -0,0 +1,71 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test GenericNetwork class.""" | |||
| import os | |||
| import pytest | |||
| from mindinsight.wizard.network import lenet | |||
| class TestGenericNetwork: | |||
| """Test SourceFile""" | |||
| def test_generate_scripts(self): | |||
| """Test network object to generate network scripts""" | |||
| network_inst = lenet.Network() | |||
| network_inst.configure({ | |||
| "loss": "SoftmaxCrossEntropyWithLogits", | |||
| "optimizer": "Momentum", | |||
| "dataset": "mnist"}) | |||
| sources_files = network_inst.generate() | |||
| dataset_source_file = None | |||
| config_source_file = None | |||
| shell_script_dir_files = [] | |||
| out_files = [] | |||
| for sources_file in sources_files: | |||
| if sources_file.file_relative_path == 'src/dataset.py': | |||
| dataset_source_file = sources_file | |||
| elif sources_file.file_relative_path == 'src/config.py': | |||
| config_source_file = sources_file | |||
| elif sources_file.file_relative_path.startswith('scripts'): | |||
| shell_script_dir_files.append(sources_file) | |||
| elif not os.path.dirname(sources_file.file_relative_path): | |||
| out_files.append(sources_file) | |||
| else: | |||
| pass | |||
| assert sources_files | |||
| assert dataset_source_file is not None | |||
| assert config_source_file is not None | |||
| assert shell_script_dir_files | |||
| assert out_files | |||
| def test_config(self): | |||
| """Test network object to config.""" | |||
| network_inst = lenet.Network() | |||
| settings = { | |||
| "loss": "SoftmaxCrossEntropyWithLogits", | |||
| "optimizer": "Momentum", | |||
| "dataset": "mnist"} | |||
| configurations = network_inst.configure(settings) | |||
| assert configurations["dataset"] == settings["dataset"] | |||
| assert configurations["loss"] == settings["loss"] | |||
| assert configurations["optimizer"] == settings["optimizer"] | |||
| settings["dataset"] = "mnist_another" | |||
| with pytest.raises(ModuleNotFoundError) as exec_info: | |||
| network_inst.configure(settings) | |||
| assert exec_info.value.name == f'mindinsight.wizard.dataset.{settings["dataset"]}' | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test SourceFile class.""" | |||
| import os | |||
| import shutil | |||
| import stat | |||
| import tempfile | |||
| import pytest | |||
| from mindinsight.wizard.base.source_file import SourceFile | |||
| from tests.ut.wizard.utils import generate_file | |||
| class TestSourceFile: | |||
| """Test SourceFile""" | |||
| def setup_method(self): | |||
| """Setup before call test method.""" | |||
| self._input_dir = tempfile.mkdtemp() | |||
| self._output_dir = tempfile.mkdtemp() | |||
| def teardown_method(self): | |||
| """Tear down after call test method.""" | |||
| self._remove_dirs() | |||
| self._input_dir = None | |||
| self._output_dir = None | |||
| def _remove_dirs(self): | |||
| """Recursively delete a directory tree.""" | |||
| for temp_dir in [self._input_dir, self._output_dir]: | |||
| if temp_dir and os.path.exists(temp_dir): | |||
| shutil.rmtree(temp_dir) | |||
| @staticmethod | |||
| def _generate_file(file, stat_mode): | |||
| """Create a file and write content.""" | |||
| generate_file(file, "template file.", stat_mode) | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'file_relative_path': 'src/config.py', | |||
| 'template_file_path': 'src/config.py-tpl' | |||
| }, { | |||
| 'file_relative_path': 'src/lenet.py', | |||
| 'template_file_path': 'src/lenet.py-tpl' | |||
| }, { | |||
| 'file_relative_path': 'README.md', | |||
| 'template_file_path': 'README.md-tpl' | |||
| }, { | |||
| 'file_relative_path': 'train.py', | |||
| 'template_file_path': 'train.py-tpl' | |||
| }]) | |||
| def test_write_py(self, params): | |||
| """Test write python script file""" | |||
| source_file = SourceFile() | |||
| source_file.file_relative_path = params['file_relative_path'] | |||
| source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) | |||
| self._generate_file(source_file.template_file_path, stat.S_IRUSR) | |||
| # start write | |||
| source_file.write(self._output_dir) | |||
| output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) | |||
| assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK) | |||
| assert stat.filemode(os.stat(output_file_path).st_mode) == '-rw-------' | |||
| @pytest.mark.parametrize('params', [{ | |||
| 'file_relative_path': 'scripts/run_eval.sh', | |||
| 'template_file_path': 'scripts/run_eval.sh-tpl' | |||
| }, { | |||
| 'file_relative_path': 'run_distribute_train.sh', | |||
| 'template_file_path': 'run_distribute_train.sh-tpl' | |||
| }]) | |||
| def test_write_sh(self, params): | |||
| """Test write shell script file""" | |||
| source_file = SourceFile() | |||
| source_file.file_relative_path = params['file_relative_path'] | |||
| source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) | |||
| self._generate_file(source_file.template_file_path, stat.S_IRUSR) | |||
| # start write | |||
| source_file.write(self._output_dir) | |||
| output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) | |||
| assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK | os.X_OK) | |||
| assert stat.filemode(os.stat(output_file_path).st_mode) == '-rwx------' | |||
| @@ -0,0 +1,188 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test TemplateManager class.""" | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import textwrap | |||
| from mindinsight.wizard.base.templates import TemplateManager | |||
| from tests.ut.wizard.utils import generate_file | |||
| def create_template_files(template_dir): | |||
| """Create network template files.""" | |||
| all_template_files = [] | |||
| train_file = os.path.join(template_dir, 'train.py-tpl') | |||
| generate_file(train_file, | |||
| textwrap.dedent("""\ | |||
| {% if loss=='SoftmaxCrossEntropyWithLogits' %} | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| {% elif loss=='SoftmaxCrossEntropyExpand' %} | |||
| net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) | |||
| {% endif %} | |||
| """)) | |||
| all_template_files.append(train_file) | |||
| os.mkdir(os.path.join(template_dir, 'src')) | |||
| config_file = os.path.join(template_dir, 'src', 'config.py-tpl') | |||
| generate_file(config_file, | |||
| textwrap.dedent("""\ | |||
| { | |||
| 'num_classes': 10, | |||
| {% if optimizer=='Momentum' %} | |||
| 'lr': 0.01, | |||
| "momentum": 0.9, | |||
| {% elif optimizer=='SGD' %} | |||
| 'lr': 0.1, | |||
| {% else %} | |||
| 'lr': 0.001, | |||
| {% endif %} | |||
| 'epoch_size': 1 | |||
| } | |||
| """)) | |||
| all_template_files.append(config_file) | |||
| os.mkdir(os.path.join(template_dir, 'scripts')) | |||
| run_standalone_train_file = os.path.join(template_dir, 'scripts', 'run_standalone_train.sh-tpl') | |||
| generate_file(run_standalone_train_file, | |||
| textwrap.dedent("""\ | |||
| python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & | |||
| """)) | |||
| all_template_files.append(run_standalone_train_file) | |||
| os.mkdir(os.path.join(template_dir, 'dataset')) | |||
| os.mkdir(os.path.join(template_dir, 'dataset', 'mnist')) | |||
| dataset_file = os.path.join(template_dir, 'dataset', 'mnist', 'dataset.py-tpl') | |||
| generate_file(dataset_file, | |||
| textwrap.dedent("""\ | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||
| """)) | |||
| all_template_files.append(dataset_file) | |||
| return all_template_files | |||
| class TestTemplateManager: | |||
| """Test TemplateManager""" | |||
| template_dir = None | |||
| all_template_files = [] | |||
| def setup_method(self): | |||
| """Setup before call test method.""" | |||
| self.template_dir = tempfile.mkdtemp() | |||
| self.all_template_files = create_template_files(self.template_dir) | |||
| def teardown_method(self): | |||
| """Tear down after call test method.""" | |||
| self._remove_dirs() | |||
| self.template_dir = None | |||
| def _remove_dirs(self): | |||
| """Recursively delete a directory tree.""" | |||
| if self.template_dir and os.path.exists(self.template_dir): | |||
| shutil.rmtree(self.template_dir) | |||
| def test_template_files(self): | |||
| """Test get_template_files method.""" | |||
| src_file_num = 1 | |||
| dataset_file_num = 1 | |||
| template_mgr = TemplateManager(self.template_dir) | |||
| all_files = template_mgr.get_template_files() | |||
| assert set(all_files) == set(self.all_template_files) | |||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) | |||
| all_files = template_mgr.get_template_files() | |||
| assert len(all_files) == src_file_num | |||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) | |||
| all_files = template_mgr.get_template_files() | |||
| assert len(all_files) == dataset_file_num | |||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src']) | |||
| all_files = template_mgr.get_template_files() | |||
| assert len(all_files) == len(self.all_template_files) - src_file_num | |||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) | |||
| all_files = template_mgr.get_template_files() | |||
| assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num | |||
| template_mgr = TemplateManager(self.template_dir, | |||
| exclude_dirs=['src', 'dataset'], | |||
| exclude_files=['train.py-tpl']) | |||
| all_files = template_mgr.get_template_files() | |||
| assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num - 1 | |||
| def test_src_render(self): | |||
| """Test render file in src directory.""" | |||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) | |||
| source_files = template_mgr.render(optimizer='Momentum') | |||
| assert source_files[0].content == textwrap.dedent("""\ | |||
| { | |||
| 'num_classes': 10, | |||
| 'lr': 0.01, | |||
| "momentum": 0.9, | |||
| 'epoch_size': 1 | |||
| } | |||
| """) | |||
| source_files = template_mgr.render(optimizer='SGD') | |||
| assert source_files[0].content == textwrap.dedent("""\ | |||
| { | |||
| 'num_classes': 10, | |||
| 'lr': 0.1, | |||
| 'epoch_size': 1 | |||
| } | |||
| """) | |||
| source_files = template_mgr.render() | |||
| assert source_files[0].content == textwrap.dedent("""\ | |||
| { | |||
| 'num_classes': 10, | |||
| 'lr': 0.001, | |||
| 'epoch_size': 1 | |||
| } | |||
| """) | |||
| def test_dataset_render(self): | |||
| """Test render file in dataset directory.""" | |||
| template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) | |||
| source_files = template_mgr.render() | |||
| assert source_files[0].content == textwrap.dedent("""\ | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||
| """) | |||
| assert source_files[0].file_relative_path == 'mnist/dataset.py' | |||
| assert source_files[0].template_file_path == os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl') | |||
| def test_assemble_render(self): | |||
| """Test render assemble files in template directory.""" | |||
| template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) | |||
| source_files = template_mgr.render(loss='SoftmaxCrossEntropyWithLogits') | |||
| unmatched_files = [] | |||
| for source_file in source_files: | |||
| if source_file.template_file_path == os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'): | |||
| assert source_file.content == textwrap.dedent("""\ | |||
| python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & | |||
| """) | |||
| assert source_file.file_relative_path == 'scripts/run_standalone_train.sh' | |||
| elif source_file.template_file_path == os.path.join(self.template_dir, 'train.py-tpl'): | |||
| assert source_file.content == textwrap.dedent("""\ | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| """) | |||
| assert source_file.file_relative_path == 'train.py' | |||
| else: | |||
| unmatched_files.append(source_file) | |||
| assert not unmatched_files | |||
| @@ -0,0 +1,28 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils method.""" | |||
| import os | |||
| import stat | |||
| def generate_file(file, template_content, mode=None): | |||
| """Create a file and write content.""" | |||
| os.makedirs(os.path.dirname(file), mode=stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR, exist_ok=True) | |||
| with open(file, 'w') as fp: | |||
| fp.write(template_content) | |||
| if mode: | |||
| os.chmod(file, mode) | |||
| else: | |||
| os.chmod(file, stat.S_IRUSR) | |||