Browse Source

optimize mindwizard prompt hints

tags/0.7.0-beta
liangyongxiong 5 years ago
parent
commit
8e3f0d4cf4
8 changed files with 23 additions and 19 deletions
  1. +2
    -2
      RELEASE.md
  2. +1
    -1
      mindinsight/wizard/cli.py
  3. +1
    -1
      mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl
  4. +1
    -1
      mindinsight/wizard/conf/templates/network/lenet/eval.py-tpl
  5. +1
    -4
      mindinsight/wizard/conf/templates/network/lenet/train.py-tpl
  6. +1
    -1
      mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl
  7. +14
    -9
      mindinsight/wizard/create_project.py
  8. +2
    -0
      mindinsight/wizard/network/generic_network.py

+ 2
- 2
RELEASE.md View File

@@ -9,8 +9,8 @@
* Web UI supports language internationalization, including both Chinese and English. * Web UI supports language internationalization, including both Chinese and English.


## Bugfixes ## Bugfixes
* Optimize UI page initialization to handle timeout requests. [!503](https://gitee.com/mindspore/mindinsight/pulls/503)
* Fix the line break problem when the profiling file number is too long. [532](https://gitee.com/mindspore/mindinsight/pulls/532)
* Optimize UI page initialization to handle timeout requests. ([!503](https://gitee.com/mindspore/mindinsight/pulls/503))
* Fix the line break problem when the profiling file number is too long. ([!532](https://gitee.com/mindspore/mindinsight/pulls/532))


## Thanks to our Contributors ## Thanks to our Contributors
Thanks goes to these wonderful people: Thanks goes to these wonderful people:


+ 1
- 1
mindinsight/wizard/cli.py View File

@@ -28,7 +28,7 @@ def cli_entry():
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)


parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='wizard',
prog='mindwizard',
description='MindWizard CLI entry point (version: {})'.format(mindinsight.__version__)) description='MindWizard CLI entry point (version: {})'.format(mindinsight.__version__))


parser.add_argument( parser.add_argument(


+ 1
- 1
mindinsight/wizard/conf/templates/network/alexnet/train.py-tpl View File

@@ -41,7 +41,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example') parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num') parser.add_argument('--device_num', type=int, default=1, help='Device num')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./", help='path where the dataset is saved') parser.add_argument('--dataset_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path')


+ 1
- 1
mindinsight/wizard/conf/templates/network/lenet/eval.py-tpl View File

@@ -33,7 +33,7 @@ from src.lenet import LeNet5


if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./Data", parser.add_argument('--dataset_path', type=str, default="./Data",
help='path where the dataset is saved') help='path where the dataset is saved')


+ 1
- 4
mindinsight/wizard/conf/templates/network/lenet/train.py-tpl View File

@@ -36,7 +36,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_path', type=str, default="./Data", parser.add_argument('--dataset_path', type=str, default="./Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@@ -45,9 +45,6 @@ if __name__ == "__main__":


args = parser.parse_args() args = parser.parse_args()


if args.device_target == "CPU":
args.dataset_sink = False

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ckpt_save_dir = './' ckpt_save_dir = './'
if args.run_distribute: if args.run_distribute:


+ 1
- 1
mindinsight/wizard/conf/templates/network/resnet50/train.py-tpl View File

@@ -36,7 +36,7 @@ parser.add_argument('--run_distribute', type=bool, default=False, help='Run dist
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')


parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target: "Ascend", "GPU", "CPU"')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target: "Ascend", "GPU"')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'], parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'],
help='DataSet sink mode is True or False') help='DataSet sink mode is True or False')


+ 14
- 9
mindinsight/wizard/create_project.py View File

@@ -59,7 +59,8 @@ class CreateProject(BaseCommand):
def _check_project_dir(project_name): def _check_project_dir(project_name):
"""Check project directory whether empty or exist.""" """Check project directory whether empty or exist."""
if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name): if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name):
raise CommandError("'%s' is not a valid project name. Please input a valid name" % project_name)
raise CommandError("'%s' is not a valid project name. Please input a valid name matching "
"regex ^[A-Za-z0-9][A-Za-z0-9._-]*$" % project_name)
project_dir = os.path.join(os.getcwd(), project_name) project_dir = os.path.join(os.getcwd(), project_name)
if os.path.exists(project_dir): if os.path.exists(project_dir):
output_path = Path(project_dir) output_path = Path(project_dir)
@@ -81,19 +82,23 @@ class CreateProject(BaseCommand):
'\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1)) '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1))
) )
prompt_type = click.IntRange(min=1, max=len(network_type_choices)) prompt_type = click.IntRange(min=1, max=len(network_type_choices))
choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
confirmation_prompt=False,
value_proc=lambda x: process_prompt_choice(x, prompt_type))
choice = 0
while not choice:
choice = click.prompt(prompt_msg, default=0, type=prompt_type,
hide_input=False, show_choices=False,
confirmation_prompt=False, show_default=False,
value_proc=lambda x: process_prompt_choice(x, prompt_type))
if not choice:
click.secho(textwrap.dedent("Network is required."), fg='red')

return network_type_choices[choice - 1] return network_type_choices[choice - 1]


@staticmethod @staticmethod
def echo_notice(): def echo_notice():
"""Echo notice for depending environment.""" """Echo notice for depending environment."""
click.secho(textwrap.dedent("""
[NOTICE] To ensure the final generated scripts run under specific environment with the following
mindspore : %s
""" % SUPPORT_MINDSPORE_VERSION), fg='red')
click.secho(textwrap.dedent(
"[NOTICE] The final generated scripts should be run under environment "
"where mindspore==%s and related device drivers are installed. " % SUPPORT_MINDSPORE_VERSION), fg='yellow')


def run(self, args): def run(self, args):
"""Override run method to start.""" """Override run method to start."""


+ 2
- 0
mindinsight/wizard/network/generic_network.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""GenericNetwork module.""" """GenericNetwork module."""
import os import os
import textwrap


import click import click


@@ -93,6 +94,7 @@ class GenericNetwork(BaseNetwork):
choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False, choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
confirmation_prompt=False, default=default_choice, confirmation_prompt=False, default=default_choice,
value_proc=lambda x: process_prompt_choice(x, prompt_type)) value_proc=lambda x: process_prompt_choice(x, prompt_type))
click.secho(textwrap.dedent("Your choice is %s." % choice_contents[choice - 1]), fg='yellow')
return choice_contents[choice - 1] return choice_contents[choice - 1]


def ask_loss_function(self): def ask_loss_function(self):


Loading…
Cancel
Save