|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Create project command module."""
- import os
- import re
- import sys
- import textwrap
- from pathlib import Path
-
- import click
-
- from mindinsight.utils.command import BaseCommand
- from mindinsight.wizard.base.utility import find_network_maker_names, load_network_maker, process_prompt_choice
- from mindinsight.wizard.common.exceptions import CommandError
- from mindinsight.wizard.conf.constants import SUPPORT_MINDSPORE_VERSION, QUESTION_START
-
-
class CreateProject(BaseCommand):
    """
    Create a new project directory from an interactively selected network template.

    The user picks one of the available network makers; its generated source
    files are written into ``<cwd>/<name>``.
    """
    name = 'createproject'
    description = 'create project'

    def __init__(self):
        # Names of the available network makers, looked up once per command instance.
        self._network_types = find_network_maker_names()

    def add_arguments(self, parser):
        """
        Add arguments to parser.

        Args:
            parser (ArgumentParser): Specify parser to which arguments are added.
        """
        parser.add_argument(
            'name',
            type=str,
            help='Specify the new project name.')

    def _make_project_dir(self, project_name):
        """
        Create the project directory under the current working directory.

        The target is re-validated here because run() performs its first check
        before the interactive prompts, and the directory may have changed since.

        Args:
            project_name (str): Name of the project directory to create.

        Returns:
            str, absolute path of the project directory.

        Raises:
            CommandError: If the name is invalid or the target path is not usable.
        """
        self._check_project_dir(project_name)
        # R|W|X = 0o7, shifted into the owner bits -> 0o700 (still subject to umask).
        mode = (os.R_OK | os.W_OK | os.X_OK) << 6
        project_dir = os.path.join(os.getcwd(), project_name)
        # exist_ok: an existing *empty* directory is allowed by _check_project_dir.
        os.makedirs(project_dir, mode=mode, exist_ok=True)
        return project_dir

    @staticmethod
    def _check_project_dir(project_name):
        """
        Check that the project name is valid and its target path can be used.

        The target ``<cwd>/<project_name>`` is usable when it does not exist yet
        or when it is an empty directory.

        Args:
            project_name (str): Requested project name.

        Returns:
            bool, True when the check passes.

        Raises:
            CommandError: If the name is missing or invalid, if the target is a
                non-empty directory, or if a non-directory entry with that name exists.
        """
        # fullmatch (unlike '^...$' with search) also rejects a trailing newline;
        # the isinstance guard turns a missing name into a CommandError instead of a TypeError.
        if not isinstance(project_name, str) or \
                not re.fullmatch(r'[A-Za-z0-9][A-Za-z0-9._-]*', project_name):
            raise CommandError("'%s' is not a valid project name. Please input a valid name" % project_name)
        project_dir = os.path.join(os.getcwd(), project_name)
        project_path = Path(project_dir)
        if project_path.is_dir():
            if any(project_path.iterdir()):
                raise CommandError('%s already exists, %s is not empty directory, please try another name.'
                                   % (project_name, project_dir))
        elif project_path.exists():
            # A file (not a directory) occupies the name; makedirs would fail on it later.
            raise CommandError('There is a file in the current directory has the same name as the project %s, '
                               'please try another name.' % project_name)
        return True

    def ask_network(self):
        """
        Ask the user to select one of the available networks.

        Returns:
            str, the selected network maker name.

        Raises:
            click.exceptions.Abort: Propagated from click.prompt when the user aborts.
        """
        network_type_choices = sorted(self._network_types)
        prompt_msg = '{}:\n{}\n'.format(
            '%sPlease select a network' % QUESTION_START,
            '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1))
        )
        # Choices are shown 1-based; IntRange restricts input to the listed numbers.
        prompt_type = click.IntRange(min=1, max=len(network_type_choices))
        choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
                              confirmation_prompt=False,
                              value_proc=lambda x: process_prompt_choice(x, prompt_type))
        return network_type_choices[choice - 1]

    @staticmethod
    def echo_notice():
        """Print, in red, the MindSpore version the generated scripts are intended for."""
        click.secho(textwrap.dedent("""
        [NOTICE] To ensure the final generated scripts run under specific environment with the following

        mindspore : %s
        """ % SUPPORT_MINDSPORE_VERSION), fg='red')

    def run(self, args):
        """
        Override run method to start.

        Validates the project name, prompts for a network, then writes the
        generated source files into the new project directory.

        Args:
            args (Mapping): Parsed command arguments; ``args['name']`` is the
                project name and all entries are forwarded to the network maker's
                ``generate``.
        """
        project_name = args.get('name')
        # Validate before prompting so the user does not configure a project that cannot be created.
        try:
            self._check_project_dir(project_name)
        except CommandError as error:
            click.secho(error.message, fg='red')
            sys.exit(1)

        try:
            self.echo_notice()
            network_maker_name = self.ask_network()
            network_maker = load_network_maker(network_maker_name)
            network_maker.configure()
        except click.exceptions.Abort:
            sys.exit(1)

        # _make_project_dir re-validates; report a failure the same way as the first check.
        try:
            project_dir = self._make_project_dir(project_name)
        except CommandError as error:
            click.secho(error.message, fg='red')
            sys.exit(1)

        source_files = network_maker.generate(**args)
        for source_file in source_files:
            source_file.write(project_dir)

        click.secho(f"{project_name} is generated in {project_dir}")
|