# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test CreateProject class."""
import os
import shutil
import tempfile
from unittest.mock import patch

from mindinsight.wizard.base.source_file import SourceFile
from mindinsight.wizard.create_project import CreateProject
from mindinsight.wizard.network.generic_network import GenericNetwork
from tests.ut.wizard.utils import generate_file


class TestCreateProject:
    """Test CreateProject."""

    # Per-test temporary directory: created in setup_method, removed in teardown_method.
    workspace_dir = None

    def setup_method(self):
        """Create a fresh temporary workspace before each test method."""
        self.workspace_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Remove the temporary workspace after each test method."""
        self._remove_dirs()
        self.workspace_dir = None

    def _remove_dirs(self):
        """Recursively delete the workspace directory tree if it still exists."""
        if self.workspace_dir and os.path.exists(self.workspace_dir):
            shutil.rmtree(self.workspace_dir)

    @staticmethod
    def _generate_file(file):
        """Create ``file`` with placeholder content using the shared ``generate_file`` helper."""
        generate_file(file, "template file.")

    # ``patch`` decorators are applied bottom-up, so the mocks arrive in reverse
    # order: ``os.getcwd`` is the first mock argument, ``GenericNetwork.generate`` the last.
    @patch.object(GenericNetwork, 'generate')
    @patch.object(GenericNetwork, 'configure')
    @patch.object(CreateProject, 'ask_network')
    @patch.object(CreateProject, 'echo_notice')
    @patch('os.getcwd')
    def test_run(self, mock_getcwd, mock_echo_notice, mock_ask_network, mock_config, mock_generate):
        """Test that CreateProject.run creates the project directory and the generated file.

        Network selection, configuration and file generation are mocked, so
        ``GenericNetwork.generate`` returns a single ``SourceFile`` whose template
        lives inside the temporary workspace.
        """
        source_file = SourceFile()
        source_file.template_file_path = os.path.join(self.workspace_dir, 'templates', 'train.py-tpl')
        source_file.file_relative_path = 'train.py'
        self._generate_file(source_file.template_file_path)

        # Point the "current working directory" at the temporary workspace; the
        # assertions below expect the new project to appear there.
        mock_getcwd.return_value = self.workspace_dir
        mock_echo_notice.return_value = None
        mock_ask_network.return_value = 'lenet'
        mock_config.return_value = None
        mock_generate.return_value = [source_file]

        project_name = 'test'
        new_project = CreateProject()
        new_project.run({'name': project_name})

        project_dir = os.path.join(self.workspace_dir, project_name)
        # The project must be a directory, since train.py is expected inside it.
        assert os.path.isdir(project_dir)
        assert os.access(os.path.join(project_dir, 'train.py'), mode=os.F_OK | os.R_OK | os.W_OK)