You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

create_project.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Create project command module."""
  16. import os
  17. import re
  18. import sys
  19. import textwrap
  20. from pathlib import Path
  21. import click
  22. from mindinsight.utils.command import BaseCommand
  23. from mindinsight.wizard.base.utility import find_network_maker_names, load_network_maker, process_prompt_choice
  24. from mindinsight.wizard.common.exceptions import CommandError
  25. from mindinsight.wizard.conf.constants import SUPPORT_MINDSPORE_VERSION, QUESTION_START
  26. class CreateProject(BaseCommand):
  27. """Create project class."""
  28. name = 'createproject'
  29. description = 'create project'
  30. def __init__(self):
  31. self._network_types = find_network_maker_names()
  32. def add_arguments(self, parser):
  33. """
  34. Add arguments to parser.
  35. Args:
  36. parser (ArgumentParser): Specify parser to which arguments are added.
  37. """
  38. parser.add_argument(
  39. 'name',
  40. type=str,
  41. help='Specify the new project name.')
  42. def _make_project_dir(self, project_name):
  43. self._check_project_dir(project_name)
  44. permissions = os.R_OK | os.W_OK | os.X_OK
  45. mode = permissions << 6
  46. project_dir = os.path.join(os.getcwd(), project_name)
  47. os.makedirs(project_dir, mode=mode, exist_ok=True)
  48. return project_dir
  49. @staticmethod
  50. def _check_project_dir(project_name):
  51. """Check project directory whether empty or exist."""
  52. if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name):
  53. raise CommandError("'%s' is not a valid project name. Please input a valid name" % project_name)
  54. project_dir = os.path.join(os.getcwd(), project_name)
  55. if os.path.exists(project_dir):
  56. output_path = Path(project_dir)
  57. if output_path.is_dir():
  58. if os.path.os.listdir(project_dir):
  59. raise CommandError('%s already exists, %s is not empty directory, please try another name.'
  60. % (project_name, project_dir))
  61. else:
  62. CommandError('There is a file in the current directory has the same name as the project %s, '
  63. 'please try another name.' % project_name)
  64. return True
  65. def ask_network(self):
  66. """Ask user question for selecting a network to create."""
  67. network_type_choices = self._network_types[:]
  68. network_type_choices.sort(reverse=False)
  69. prompt_msg = '{}:\n{}\n'.format(
  70. '%sPlease select a network' % QUESTION_START,
  71. '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1))
  72. )
  73. prompt_type = click.IntRange(min=1, max=len(network_type_choices))
  74. choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
  75. confirmation_prompt=False,
  76. value_proc=lambda x: process_prompt_choice(x, prompt_type))
  77. return network_type_choices[choice - 1]
  78. @staticmethod
  79. def echo_notice():
  80. """Echo notice for depending environment."""
  81. click.secho(textwrap.dedent("""
  82. [NOTICE] To ensure the final generated scripts run under specific environment with the following
  83. mindspore : %s
  84. """ % SUPPORT_MINDSPORE_VERSION), fg='red')
  85. def run(self, args):
  86. """Override run method to start."""
  87. project_name = args.get('name')
  88. try:
  89. self._check_project_dir(project_name)
  90. except CommandError as error:
  91. click.secho(error.message, fg='red')
  92. sys.exit(1)
  93. try:
  94. self.echo_notice()
  95. network_maker_name = self.ask_network()
  96. network_maker = load_network_maker(network_maker_name)
  97. network_maker.configure()
  98. except click.exceptions.Abort:
  99. sys.exit(1)
  100. project_dir = self._make_project_dir(project_name)
  101. source_files = network_maker.generate(**args)
  102. for source_file in source_files:
  103. source_file.write(project_dir)
  104. click.secho(f"{project_name} is generated in {project_dir}")