| @@ -32,6 +32,7 @@ class MindInsightModules(Enum): | |||||
| PROFILERMGR = 6 | PROFILERMGR = 6 | ||||
| SCRIPTCONVERTER = 7 | SCRIPTCONVERTER = 7 | ||||
| SYSMETRIC = 8 | SYSMETRIC = 8 | ||||
| WIZARD = 9 | |||||
| class GeneralErrors(Enum): | class GeneralErrors(Enum): | ||||
| @@ -84,3 +85,7 @@ class ScriptConverterErrors(Enum): | |||||
class SysmetricErrors(Enum):
    """Enum definition for sysmetric errors."""

    # A DSMI query returned a non-zero status code (error code 1).
    DSMI_QUERYING_NONZERO = 1
class WizardErrors(Enum):
    """Enum definition for mindwizard errors."""
    # Intentionally left without members: concrete wizard error codes are
    # declared in a subclass (WizardErrorCodes) — a Python Enum may only be
    # subclassed while it has no members.
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseDataset module.""" | |||||
class BaseDataset:
    """Abstract base for dataset code generators.

    Subclasses must override :meth:`configure` and :meth:`generate`.

    Attributes:
        name (str): Human-readable generator name.
        settings (dict): Generator options. NOTE: class-level dict is shared
            across instances; subclasses should assign their own.
    """

    name = 'BaseDataset'
    settings = {}

    def configure(self):
        """Configure the dataset options.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Not provide a configure method in the subclass.")

    def generate(self, **options):
        """Generate dataset scripts.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Not provide a generate method in the subclass.")
| @@ -0,0 +1,29 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseNetwork module.""" | |||||
class BaseNetwork:
    """Abstract base for network code generators.

    Subclasses must override :meth:`configure` and :meth:`generate`.

    Attributes:
        name (str): Human-readable generator name.
        settings (dict): Generator options. NOTE: class-level dict is shared
            across instances; subclasses should assign their own.
    """

    name = 'BaseNetwork'
    settings = {}

    def configure(self, settings=None):
        """Configure the network options.

        Args:
            settings (dict): Optional pre-supplied settings.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Not provide a configure method in the subclass.")

    def generate(self, **options):
        """Generate network definition scripts.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Not provide a generate method in the subclass.")
| @@ -0,0 +1,68 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Source file module.""" | |||||
| import os | |||||
| import shutil | |||||
| import stat | |||||
| from pathlib import Path | |||||
| from mindinsight.wizard.common.exceptions import OSPermissionError, TemplateFileError | |||||
class SourceFile:
    """Write one generated source file into a project directory.

    Attributes:
        file_relative_path (str): Destination path relative to the project root.
        template_file_path (str): Path of the template this content came from;
            its permission bits are copied onto the new file.
        content (str): Rendered file content to write.
    """

    file_relative_path = ''
    template_file_path = ''
    content = ''

    @staticmethod
    def _make_dir(directory):
        """Create *directory* (and missing parents) with owner-only rwx mode."""
        permissions = os.R_OK | os.W_OK | os.X_OK
        mode = permissions << 6  # owner rwx only, i.e. 0o700
        os.makedirs(directory, mode=mode, exist_ok=True)
        return directory

    def write(self, project_directory):
        """Write ``content`` under *project_directory* at ``file_relative_path``.

        Args:
            project_directory (str): Root directory of the generated project.

        Raises:
            TemplateFileError: If ``template_file_path`` does not exist.
            OSPermissionError: If the permission bits cannot be set.
        """
        template_file_path = Path(self.template_file_path)
        if not template_file_path.is_file():
            raise TemplateFileError("Template file %s is not exist." % self.template_file_path)

        new_file_path = os.path.join(project_directory, self.file_relative_path)
        self._make_dir(os.path.dirname(new_file_path))
        with open(new_file_path, 'w', encoding='utf-8') as fp:
            fp.write(self.content)

        try:
            # Mirror the template's mode, then guarantee the minimum bits we need.
            shutil.copymode(self.template_file_path, new_file_path)
            self.set_writeable(new_file_path)
            if new_file_path.endswith('.sh'):
                self.set_executable(new_file_path)
        except OSError as err:
            # Chain the original OSError so the root cause stays visible.
            raise OSPermissionError("Notice: Set permission bits failed on %s." % new_file_path) from err

    @staticmethod
    def set_writeable(file_name):
        """Add owner write permission to *file_name* if it is not writable."""
        if not os.access(file_name, os.W_OK):
            file_stat = os.stat(file_name)
            permissions = stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR
            os.chmod(file_name, permissions)

    @staticmethod
    def set_executable(file_name):
        """Add owner execute permission to *file_name* if it is not executable."""
        if not os.access(file_name, os.X_OK):
            file_stat = os.stat(file_name)
            permissions = stat.S_IMODE(file_stat.st_mode) | stat.S_IXUSR
            os.chmod(file_name, permissions)
| @@ -0,0 +1,91 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseNetwork module.""" | |||||
| import os | |||||
| from jinja2 import Template | |||||
| from mindinsight.wizard.base.source_file import SourceFile | |||||
def render_template(template_file_path, context):
    """Render a Jinja2 template file with *context* and return the result.

    Args:
        template_file_path (str): Path to the template file (read as UTF-8).
        context (dict): Variables made available to the template.

    Returns:
        str, the rendered content. Trailing newline is kept and block tags do
        not leave blank lines (trim_blocks/lstrip_blocks).
    """
    with open(template_file_path, encoding='utf-8') as fp:
        content = fp.read()
    template = Template(content, trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True)
    return template.render(context)
class TemplateManager:
    """Collect template files under a base directory and render them.

    Attributes:
        replace_template_suffixes (list): ``(template_suffix, output_suffix)``
            pairs applied to template file names when producing output paths.
    """

    replace_template_suffixes = [('.py-tpl', '.py')]

    def __init__(self, template_base_dir, exclude_dirs=None, exclude_files=None):
        """
        Args:
            template_base_dir (str): Root directory of the templates.
            exclude_dirs (list): Directories to skip, absolute or relative to the base.
            exclude_files (list): Files to skip, absolute or relative to the base.
        """
        self.template_base_dir = template_base_dir
        self.exclude_dirs = self._get_exclude_paths(template_base_dir, exclude_dirs)
        self.exclude_files = self._get_exclude_paths(template_base_dir, exclude_files)

    @staticmethod
    def _get_exclude_paths(base_dir, exclude_paths):
        """Convert exclude paths to absolute paths under *base_dir*."""
        exclude_abs_paths = []
        if exclude_paths is None:
            return exclude_abs_paths
        for exclude_path in exclude_paths:
            if exclude_path.startswith(base_dir):
                exclude_abs_paths.append(exclude_path)
            else:
                exclude_abs_paths.append(os.path.join(base_dir, exclude_path))
        return exclude_abs_paths

    def get_template_files(self):
        """Walk the template directory and return all template file paths."""
        template_files = []
        for root, sub_dirs, files in os.walk(self.template_base_dir):
            # Prune hidden dirs, __pycache__ and excluded dirs in place so
            # os.walk does not descend into them.
            for sub_dir in sub_dirs[:]:
                if sub_dir.startswith('.') or \
                        sub_dir == '__pycache__' or \
                        os.path.join(root, sub_dir) in self.exclude_dirs:
                    sub_dirs.remove(sub_dir)
            for filename in files:
                if os.path.join(root, filename) not in self.exclude_files:
                    template_files.append(os.path.join(root, filename))
        return template_files

    def render(self, **options):
        """Render every template file and return them as SourceFile objects.

        Files whose output name ends with one of ``options['extensions']``
        (default: ``['.py']``) are rendered through Jinja2 with *options* as
        context; everything else is copied verbatim.
        """
        source_files = []
        template_files = self.get_template_files()
        # BUGFIX: the default used to be the string '.py', and tuple('.py')
        # iterates the string into ('.', 'p', 'y'), so any file ending in
        # 'p', 'y' or '.' was rendered. Wrap the default in a list instead.
        extensions = tuple(options.get('extensions', ['.py']))
        for template_file in template_files:
            new_file_path = template_file
            template_file_path = template_file
            for template_suffix, new_file_suffix in self.replace_template_suffixes:
                if new_file_path.endswith(template_suffix):
                    new_file_path = new_file_path[:-len(template_suffix)] + new_file_suffix
                    break
            source_file = SourceFile()
            # Path relative to the base dir (+1 strips the path separator).
            source_file.file_relative_path = new_file_path[len(self.template_base_dir) + 1:]
            source_file.template_file_path = template_file_path
            if new_file_path.endswith(extensions):
                source_file.content = render_template(template_file_path, options)
            else:
                with open(template_file_path, encoding='utf-8') as fp:
                    source_file.content = fp.read()
            source_files.append(source_file)
        return source_files
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """NetworkUtility module.""" | |||||
| from importlib import import_module | |||||
| from mindinsight.wizard.common.exceptions import CommandError | |||||
def find_network_maker_names():
    """Return the names of the built-in network generators."""
    supported_networks = ['lenet', 'alexnet', 'resnet50']
    return supported_networks
def load_network_maker(network_name):
    """Import the module for *network_name* and return its Network instance."""
    module_path = f'mindinsight.wizard.network.{network_name.lower()}'
    network_module = import_module(module_path)
    return network_module.Network()
def load_dataset_maker(dataset_name, **kwargs):
    """Import the module for *dataset_name* and return a Dataset built with *kwargs*."""
    module_path = f'mindinsight.wizard.dataset.{dataset_name.lower()}'
    dataset_module = import_module(module_path)
    return dataset_module.Dataset(**kwargs)
def process_prompt_choice(value, prompt_type):
    """Convert command value to business value."""
    # Guard clause: a missing value means the user picked nothing valid.
    if value is None:
        raise CommandError("The choice is not exist, please choice again.")
    return prompt_type(value)
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Command module.""" | |||||
| import os | |||||
| import sys | |||||
| import argparse | |||||
| import mindinsight | |||||
| from mindinsight.wizard.create_project import CreateProject | |||||
def cli_entry():
    """Entry point for mindwizard CLI."""
    rwx = os.R_OK | os.W_OK | os.X_OK
    # Mask group/other bits so files created by the wizard are owner-only.
    os.umask(rwx << 3 | rwx)

    arg_parser = argparse.ArgumentParser(
        prog='wizard',
        description=f'MindWizard CLI entry point (version: {mindinsight.__version__})')
    arg_parser.add_argument(
        '--version',
        action='version',
        version=f'%(prog)s ({mindinsight.__version__})')

    project_command = CreateProject()
    project_command.add_arguments(arg_parser)

    cli_args = sys.argv[1:]
    # No arguments (or the literal word 'help') falls back to argparse help.
    if not cli_args or cli_args[0] == 'help':
        cli_args = ['-h']
    parsed_args = arg_parser.parse_args(cli_args)
    project_command.invoke(vars(parsed_args))


if __name__ == '__main__':
    cli_entry()
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define custom exception.""" | |||||
| from enum import unique | |||||
| from mindinsight.utils.constant import WizardErrors | |||||
| from mindinsight.utils.exceptions import MindInsightException | |||||
@unique
class WizardErrorCodes(WizardErrors):
    """Wizard error codes."""

    CODE_SYNTAX_ERROR = 1  # used by CodeSyntaxError (raise site not shown in this file)
    OS_PERMISSION_ERROR = 2  # raised when setting permission bits on a generated file fails
    COMMAND_ERROR = 3  # raised for an invalid CLI prompt choice
    TEMPLATE_FILE_ERROR = 4  # raised when a template file does not exist
class CodeSyntaxError(MindInsightException):
    """Raised when generated code fails a syntax check."""

    def __init__(self, msg):
        super().__init__(WizardErrorCodes.CODE_SYNTAX_ERROR, msg)
class OSPermissionError(MindInsightException):
    """Raised when setting OS permission bits on a generated file fails."""

    def __init__(self, msg):
        super().__init__(WizardErrorCodes.OS_PERMISSION_ERROR, msg)
class CommandError(MindInsightException):
    """Raised for invalid CLI input, e.g. a prompt choice that does not exist."""

    def __init__(self, msg):
        super().__init__(WizardErrorCodes.COMMAND_ERROR, msg)
class TemplateFileError(MindInsightException):
    """Raised when a required template file is missing."""

    def __init__(self, msg):
        super().__init__(WizardErrorCodes.TEMPLATE_FILE_ERROR, msg)
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Constants module for wizard.""" | |||||
import os

# Directory holding the project scaffolding templates, next to this module.
TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')

# Supported MindSpore / run-package driver / CUDA versions
# (presumably what generated projects target — confirm against docs).
SUPPORT_MINDSPORE_VERSION = '0.7.0'
SUPPORT_RUN_DRIVER_VERSION = 'C75'
SUPPORT_CUDA_VERSION = '10.1'

# Prefix string for interactive questions (per the name; usage not shown here).
QUESTION_START = '>>> '
| @@ -0,0 +1,82 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Produce the dataset | |||||
| """ | |||||
| import os | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| from mindspore.common import dtype as mstype | |||||
| from .config import cfg | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target="Ascend"):
    """
    create dataset for train or test

    Args:
        data_path (str): path of the CIFAR-10 data.
        batch_size (int): batch size of the dataset. Default: 32.
        repeat_size (int): repeat times of the dataset. Default: 1.
        do_train (bool): apply training augmentations when True. Default: True.
        target (str): device target, "Ascend", "GPU" or other. Default: "Ascend".

    Returns:
        the processed CIFAR-10 dataset.
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif target == "GPU":
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        # Single-device fallback; rank_id is not needed when device_num == 1.
        device_num = 1
    if device_num == 1:
        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8)
    else:
        # Shard the dataset across devices for distributed training.
        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, num_shards=device_num, shard_id=rank_id)
    rescale = 1.0 / 255.0  # scale pixel values to [0, 1]
    shift = 0.0
    resize_op = CV.Resize((cfg.image_height, cfg.image_width))
    rescale_op = CV.Rescale(rescale, shift)
    # Per-channel normalization; these are commonly used CIFAR-10 statistics.
    normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    if do_train:
        # Training-only augmentations: random 32x32 crop with 4-pixel padding
        # and random horizontal flip.
        random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
        random_horizontal_op = CV.RandomHorizontalFlip()
    channel_swap_op = CV.HWC2CHW()
    typecast_op = C.TypeCast(mstype.int32)
    cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
    if do_train:
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
    # Order matters: resize, then rescale to [0, 1], then normalize, then HWC->CHW.
    cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
    cifar_ds = cifar_ds.shuffle(buffer_size=1000)
    # drop_remainder keeps every batch exactly batch_size.
    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    cifar_ds = cifar_ds.repeat(repeat_size)
    return cifar_ds
| def _get_rank_info(): | |||||
| """ | |||||
| get rank size and rank id | |||||
| """ | |||||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||||
| if rank_size > 1: | |||||
| rank_size = get_group_size() | |||||
| rank_id = get_rank() | |||||
| else: | |||||
| rank_size = 1 | |||||
| rank_id = 0 | |||||
| return rank_size, rank_id | |||||
| @@ -0,0 +1,104 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| create train or eval dataset. | |||||
| """ | |||||
| import os | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| import mindspore.dataset.transforms.c_transforms as C2 | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target="Ascend"):
    """
    create a train or eval imagenet dataset

    Args:
        data_path (str): the path of the dataset.
        batch_size (int): the batch size of dataset. Default: 32.
        repeat_size (int): the repeat times of dataset. Default: 1.
        do_train (bool): whether dataset is used for train or eval. Default: True.
        target (str): the target of device. Default: "Ascend".

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif target == "GPU":
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        # Single-device fallback; rank_id is not needed when device_num == 1.
        device_num = 1
    if device_num == 1:
        ds = de.ImageFolderDatasetV2(data_path, num_parallel_workers=8, shuffle=True)
    else:
        # Shard the dataset across devices for distributed training.
        ds = de.ImageFolderDatasetV2(data_path, num_parallel_workers=8, shuffle=True,
                                     num_shards=device_num, shard_id=rank_id)
    image_size = cfg.image_height
    # Channel statistics pre-scaled by 255 because images are not rescaled to [0, 1].
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    # define map operations
    if do_train:
        # Training: random resized crop plus random flip as augmentation.
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        # Evaluation: deterministic resize + center crop.
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
    # apply batch operations; drop_remainder keeps every batch exactly batch_size
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_size)
    return ds
| def _get_rank_info(): | |||||
| """ | |||||
| get rank size and rank id | |||||
| """ | |||||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||||
| if rank_size > 1: | |||||
| rank_size = get_group_size() | |||||
| rank_id = get_rank() | |||||
| else: | |||||
| rank_size = 1 | |||||
| rank_id = 0 | |||||
| return rank_size, rank_id | |||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Produce the dataset | |||||
| """ | |||||
| import os | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from mindspore.dataset.transforms.vision import Inter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target="Ascend"):
    """
    create dataset for train or test

    Args:
        data_path (str): root path containing 'train' and 'test' subdirectories.
        batch_size (int): batch size of the dataset. Default: 32.
        repeat_size (int): repeat times of the dataset. Default: 1.
        do_train (bool): use the 'train' subdirectory when True, else 'test'. Default: True.
        target (str): device target, "Ascend", "GPU" or other. Default: "Ascend".

    Returns:
        the processed MNIST dataset.
    """
    if do_train:
        data_path = os.path.join(data_path, "train")
    else:
        data_path = os.path.join(data_path, "test")
    if target == 'Ascend':
        device_num, rank_id = _get_rank_info()
    elif target == 'GPU':
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        # Single-device fallback; rank_id is not needed when device_num == 1.
        device_num = 1
    # define dataset
    if device_num == 1:
        mnist_ds = ds.MnistDataset(data_path)
    else:
        # Shard the dataset across devices for distributed training.
        mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)
    resize_height, resize_width = cfg.image_height, cfg.image_width
    rescale = 1.0 / 255.0  # scale pixel values to [0, 1]
    shift = 0.0
    # After the [0, 1] rescale, x * rescale_nml + shift_nml == (x - 0.1307) / 0.3081,
    # i.e. normalization with MNIST mean 0.1307 and std 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081
    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)
    # apply map operations on images; order matters: resize, rescale to [0, 1],
    # normalize, then HWC->CHW.
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op)
    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    # drop_remainder keeps every batch exactly batch_size.
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
| def _get_rank_info(): | |||||
| """ | |||||
| get rank size and rank id | |||||
| """ | |||||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||||
| if rank_size > 1: | |||||
| rank_size = get_group_size() | |||||
| rank_id = get_rank() | |||||
| else: | |||||
| rank_size = 1 | |||||
| rank_id = 0 | |||||
| return rank_size, rank_id | |||||
| @@ -0,0 +1,67 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| ######################## eval alexnet example ######################## | |||||
| eval alexnet according to model file: | |||||
| python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt | |||||
| """ | |||||
| import argparse | |||||
| from src.config import cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.alexnet import AlexNet | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.train import Model | |||||
| from mindspore.nn.metrics import Accuracy | |||||
# Script entry point: parse CLI options, rebuild the network, restore the
# trained checkpoint and report accuracy on the evaluation dataset.
# NOTE: this file is a Jinja2 template; `{% ... %}` lines are resolved at
# code-generation time, before the script is run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
                        path where the trained ckpt file')
    parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'],
                        help='DataSet sink mode is True or False')
    args = parser.parse_args()
    # Run in graph mode on the selected backend.
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    data_path = args.data_path
    # argparse delivers a string; convert to the bool Model.eval expects.
    dataset_sink_mode = args.dataset_sink_mode=='True'
    network = AlexNet(cfg.num_classes)
    # Loss function chosen at template-rendering time.
{% if loss=='SoftmaxCrossEntropyWithLogits' %}
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
{% elif loss=='SoftmaxCrossEntropyExpand' %}
    net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
{% endif %}
    # Optimizer chosen at template-rendering time; it is passed to Model
    # below even though evaluation presumably performs no weight updates.
{% if optimizer=='Lamb' %}
    net_opt = nn.Lamb(network.trainable_params(), learning_rate=cfg.lr)
{% elif optimizer=='Momentum' %}
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum)
{% endif %}
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    print("============== Starting Testing ==============")
    # Restore trained weights into the freshly built network.
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    do_train = False
    ds_eval = create_dataset(data_path=data_path, batch_size=cfg.batch_size, do_train=do_train,
                             target=args.device_target)
    acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
    print("============== {} ==============".format(acc))
| @@ -0,0 +1,87 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Launch 8-device distributed Ascend training. Each device gets its own
# scratch directory train_parallel$i with a copy of the sources.
if [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)   # dataset directory
PATH2=$(get_real_path $2)   # HCCL rank-table file
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi
if [ ! -f $PATH2 ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH2 is not a file"
    exit 1
fi
if [ $# == 3 ]
then
    PATH3=$(get_real_path $3)   # optional pretrained checkpoint
fi
if [ $# == 3 ] && [ ! -f $PATH3 ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH2
export RANK_TABLE_FILE=$PATH2
# SERVER_ID defaults to 0 for a single-server job (an unset variable
# already behaves as 0 in the arithmetic below; make that explicit).
export SERVER_ID=${SERVER_ID:-0}
rank_start=$((DEVICE_NUM * SERVER_ID))

for((i=0; i<DEVICE_NUM; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    if [ $# == 2 ]
    then
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH1 --dataset_sink_mode=False &> log &
    fi
    if [ $# == 3 ]
    then
        # BUG FIX: previously passed --pre_trained=$PATH2 (the HCCL
        # rank-table file); the pretrained checkpoint is $PATH3.
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH1 --pre_trained=$PATH3 --dataset_sink_mode=False &> log &
    fi
    cd ..
done
| @@ -0,0 +1,53 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Launch 4-process data-parallel GPU training through mpirun.
if [ $# -ne 1 ]; then
    echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH]"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  echo "$(realpath -m $PWD/$1)" ;;
    esac
}

PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=4
export RANK_SIZE=4

# Run from a clean scratch directory so logs do not mix with old runs.
work_dir=./train_parallel
rm -rf $work_dir
mkdir $work_dir
cp ../*.py $work_dir
cp *.sh $work_dir
cp -r ../src $work_dir
cd $work_dir || exit
mpirun --allow-run-as-root -n $RANK_SIZE \
python train.py --run_distribute=True \
 --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
| @@ -0,0 +1,65 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Evaluate a trained checkpoint on a single Ascend device.
if [ $# -ne 2 ]; then
    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  echo "$(realpath -m $PWD/$1)" ;;
    esac
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi
if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
# Single-device evaluation.
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Evaluate from a clean scratch directory.
[ -d "eval" ] && rm -rf ./eval
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..
| @@ -0,0 +1,66 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Evaluate a trained checkpoint on a single GPU.
if [ $# -ne 2 ]; then
    echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  echo "$(realpath -m $PWD/$1)" ;;
    esac
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi
if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
# Single-device evaluation.
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Evaluate from a clean scratch directory.
[ -d "eval" ] && rm -rf ./eval
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..
| @@ -0,0 +1,76 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Train on a single Ascend device; optionally resume from a checkpoint.
if [ $# -ne 1 ] && [ $# -ne 2 ]; then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  echo "$(realpath -m $PWD/$1)" ;;
    esac
}

PATH1=$(get_real_path $1)
if [ $# -eq 2 ]; then
    PATH2=$(get_real_path $2)
fi
if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi
if [ $# -eq 2 ] && [ ! -f $PATH2 ]; then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
# Single-device run.
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Train from a clean scratch directory.
[ -d "train" ] && rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# -eq 1 ]; then
    python train.py --dataset_path=$PATH1 --dataset_sink_mode=False &> log &
else
    python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 --dataset_sink_mode=False &> log &
fi
cd ..
| @@ -0,0 +1,59 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Train on a single GPU.
if [ $# -ne 1 ]; then
    echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH]"
    exit 1
fi

# Resolve an argument to an absolute path.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *)  echo "$(realpath -m $PWD/$1)" ;;
    esac
}

PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
# Single-device run.
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Train from a clean scratch directory.
[ -d "train" ] && rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
python train.py --device_target="GPU" --dataset_path=$PATH1 &> log &
cd ..
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Alexnet.""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from mindspore.ops import operations as P | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"):
    """Build a bias-free 2D convolution with truncated-normal weight init."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight_variable(),
        has_bias=False,
        pad_mode=pad_mode,
    )
def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected layer; weight and bias each get their own
    truncated-normal initializer."""
    return nn.Dense(input_channels, out_channels,
                    weight_variable(), weight_variable())
def weight_variable():
    """Default parameter initializer: truncated normal, sigma=0.02."""
    sigma = 0.02
    return TruncatedNormal(sigma)
class AlexNet(nn.Cell):
    """
    Alexnet

    Five convolutional layers (with ReLU and three max-pool stages)
    followed by three fully connected layers; returns raw class scores
    (no softmax).

    Args:
        num_classes (int): number of output classes. Default: 10.
        channel (int): number of input image channels. Default: 3.
    """
    def __init__(self, num_classes=10, channel=3):
        super(AlexNet, self).__init__()
        # Feature extractor; conv1 downsamples aggressively (11x11, stride 4).
        self.conv1 = conv(channel, 96, 11, stride=4)
        self.conv2 = conv(96, 256, 5, pad_mode="same")
        self.conv3 = conv(256, 384, 3, pad_mode="same")
        self.conv4 = conv(384, 384, 3, pad_mode="same")
        self.conv5 = conv(384, 256, 3, pad_mode="same")
        self.relu = nn.ReLU()
        self.max_pool2d = P.MaxPool(ksize=3, strides=2)
        self.flatten = nn.Flatten()
        # Classifier head. 6*6*256 is the flattened feature size — assumes
        # 227x227 inputs (cfg.image_height/width) — TODO confirm.
        self.fc1 = fc_with_initialize(6*6*256, 4096)
        self.fc2 = fc_with_initialize(4096, 4096)
        self.fc3 = fc_with_initialize(4096, num_classes)

    def construct(self, x):
        """Forward pass: x -> conv stack -> flatten -> FC head -> logits."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting, will be used in train.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
# Training/evaluation configuration. Dataset- and optimizer-specific
# entries are selected at template-rendering time (Jinja2 `{% %}` lines).
cfg = edict({
{% if dataset=='MNIST' %}
    'num_classes': 10,
{% elif dataset=='Cifar10' %}
    'num_classes': 10,
{% elif dataset=='ImageNet' %}
    'num_classes': 1001,
{% endif %}
    'lr': 0.002,                  # base learning rate
{% if optimizer=='Momentum' %}
    "momentum": 0.9,              # momentum coefficient for nn.Momentum
{% endif %}
    'epoch_size': 1,              # number of training epochs
    'batch_size': 32,
    'buffer_size': 1000,
    'image_height': 227,          # input resolution fed to the network
    'image_width': 227,
    'save_checkpoint': True,      # whether to write .ckpt files
    'save_checkpoint_epochs': 5,  # checkpoint every N epochs
    'keep_checkpoint_max': 10,    # cap on retained checkpoint files
    'save_checkpoint_path': './'
})
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """learning rate generator""" | |||||
| import numpy as np | |||||
def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
    """
    generate learning rate array

    Step-decay schedule: the first 80% of all training steps run at
    ``lr_max``; the remaining 20% run at ``lr_max * 0.1``.

    Args:
        current_step(int): current steps of the training
        lr_max(float): max learning rate
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array (one float32 entry per remaining step)
    """
    total_steps = steps_per_epoch * total_epochs
    decay_step = 0.8 * total_steps
    # Vectorized replacement for the original per-step Python loop and its
    # needless single-element `decay_epoch_index` list; produces identical
    # values with O(1) Python-level work.
    steps = np.arange(total_steps)
    lr_each_step = np.where(steps < decay_step, lr_max, lr_max * 0.1).astype(np.float32)
    # Drop steps already completed so a resumed run starts at the right rate.
    return lr_each_step[current_step:]
| @@ -0,0 +1,122 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| ######################## train alexnet example ######################## | |||||
| train alexnet and get network model files(.ckpt) : | |||||
| python train.py --data_path /YourDataPath | |||||
| """ | |||||
| import argparse | |||||
| from src.config import cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.generator_lr import get_lr | |||||
| from src.alexnet import AlexNet | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore import Tensor | |||||
| from mindspore.train import Model | |||||
| from mindspore.nn.metrics import Accuracy | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| import mindspore.common.initializer as weight_init | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description='MindSpore AlexNet Example') | |||||
| parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') | |||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved') | |||||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path') | |||||
| parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'], | |||||
| help='DataSet sink mode is True or False') | |||||
| args = parser.parse_args() | |||||
| target = args.device_target | |||||
| ckpt_save_dir = cfg.save_checkpoint_path | |||||
| dataset_sink_mode = args.dataset_sink_mode=='True' | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=target) | |||||
| if args.run_distribute: | |||||
| if target == "Ascend": | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id, enable_auto_mixed_precision=True) | |||||
| context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) | |||||
| init() | |||||
| # GPU target | |||||
| else: | |||||
| init("nccl") | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| data_path = args.data_path | |||||
| do_train = True | |||||
| ds_train = create_dataset(data_path=data_path, batch_size=cfg.batch_size, do_train=do_train, | |||||
| target=target) | |||||
| step_size = ds_train.get_dataset_size() | |||||
| # define net | |||||
| network = AlexNet(cfg.num_classes) | |||||
| # init weight | |||||
| if args.pre_trained: | |||||
| param_dict = load_checkpoint(args.pre_trained) | |||||
| load_param_into_net(network, param_dict) | |||||
| else: | |||||
| for _, cell in network.cells_and_names(): | |||||
| if isinstance(cell, nn.Conv2d): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| if isinstance(cell, nn.Dense): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | |||||
| cell.weight.default_input.shape, | |||||
| cell.weight.default_input.dtype).to_tensor() | |||||
| {% if loss=='SoftmaxCrossEntropyWithLogits' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
| {% elif loss=='SoftmaxCrossEntropyExpand' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) | |||||
| {% endif %} | |||||
| lr = Tensor(get_lr(0, cfg.lr, cfg.epoch_size, ds_train.get_dataset_size())) | |||||
| {% if optimizer=='Lamb' %} | |||||
| net_opt = nn.Lamb(network.trainable_params(), learning_rate=lr) | |||||
| {% elif optimizer=='Momentum' %} | |||||
| net_opt = nn.Momentum(network.trainable_params(), learning_rate=lr, momentum=cfg.momentum) | |||||
| {% endif %} | |||||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||||
| # define callbacks | |||||
| time_cb = TimeMonitor(data_size=step_size) | |||||
| loss_cb = LossMonitor() | |||||
| cb = [time_cb, loss_cb] | |||||
| if cfg.save_checkpoint: | |||||
| cfg_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_epochs * step_size, | |||||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix="alexnet", directory=ckpt_save_dir, config=cfg_ck) | |||||
| cb += [ckpt_cb] | |||||
| print("============== Starting Training ==============") | |||||
| model.train(cfg.epoch_size, ds_train, callbacks=cb, dataset_sink_mode=dataset_sink_mode) | |||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Produce the dataset | |||||
| """ | |||||
| import os | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from mindspore.dataset.transforms.vision import Inter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target='Ascend'):
    """
    create dataset for train or test

    Builds the MNIST pipeline (resize -> rescale -> normalize -> HWC2CHW,
    then shuffle/batch/repeat), sharded across devices when more than one
    is in use.
    """
    data_path = os.path.join(data_path, "train" if do_train else "test")

    # Work out how many shards to read and which shard is ours.
    if target == 'Ascend':
        device_num, rank_id = _get_rank_info()
    elif target == 'GPU':
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1

    if device_num == 1:
        mnist_ds = ds.MnistDataset(data_path)
    else:
        mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    # Normalization constants.
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Cast labels first, then run the image transforms in pipeline order.
    mnist_ds = mnist_ds.map(input_columns="label", operations=C.TypeCast(mstype.int32))
    image_ops = [
        CV.Resize((cfg.image_height, cfg.image_width), interpolation=Inter.LINEAR),  # Bilinear mode
        CV.Rescale(rescale, shift),
        CV.Rescale(rescale_nml, shift_nml),
        CV.HWC2CHW(),
    ]
    for image_op in image_ops:
        mnist_ds = mnist_ds.map(input_columns="image", operations=image_op)

    # apply DatasetOps
    mnist_ds = mnist_ds.shuffle(buffer_size=10000)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
| def _get_rank_info(): | |||||
| """ | |||||
| get rank size and rank id | |||||
| """ | |||||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||||
| if rank_size > 1: | |||||
| rank_size = get_group_size() | |||||
| rank_id = get_rank() | |||||
| else: | |||||
| rank_size = 1 | |||||
| rank_id = 0 | |||||
| return rank_size, rank_id | |||||
| @@ -0,0 +1,67 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| ######################## eval lenet example ######################## | |||||
| eval lenet according to model file: | |||||
| python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt | |||||
| """ | |||||
| import argparse | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.train import Model | |||||
| from mindspore.nn.metrics import Accuracy | |||||
| from src.dataset import create_dataset | |||||
| from src.config import cfg | |||||
| from src.lenet import LeNet5 | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--dataset_path', type=str, default="./Data", | |||||
| help='path where the dataset is saved') | |||||
| parser.add_argument('--checkpoint_path', type=str, default="", help='if mode is test, must provide\ | |||||
| path where the trained ckpt file') | |||||
| parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') | |||||
| args = parser.parse_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
| network = LeNet5(cfg.num_classes) | |||||
| {% if loss=='SoftmaxCrossEntropyWithLogits' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
| {% elif loss=='SoftmaxCrossEntropyExpand' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) | |||||
| {% endif %} | |||||
| {% if optimizer=='Lamb' %} | |||||
| net_opt = nn.Lamb(network.trainable_params(), learning_rate=cfg.lr) | |||||
| {% elif optimizer=='Momentum' %} | |||||
| net_opt = nn.Momentum(network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||||
| {% endif %} | |||||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||||
| print("============== Starting Testing ==============") | |||||
| param_dict = load_checkpoint(args.checkpoint_path) | |||||
| load_param_into_net(network, param_dict) | |||||
| data_path = args.dataset_path | |||||
| do_train = False | |||||
| ds_eval = create_dataset(data_path=data_path, do_train=do_train, batch_size=cfg.batch_size, | |||||
| target=args.device_target) | |||||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | |||||
| print("============== {} ==============".format(acc)) | |||||
| @@ -0,0 +1,87 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
if [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

# Quote all path expansions so paths containing spaces do not word-split.
PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

if [ ! -f "$PATH1" ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d "$PATH2" ]
then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

if [ $# == 3 ]
then
    PATH3=$(get_real_path "$3")
fi

if [ $# == 3 ] && [ ! -f "$PATH3" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

# Launch one background training process per device, each in its own copy
# of the source tree so logs and checkpoints do not collide.
for((i=0; i<DEVICE_NUM; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    if [ $# == 2 ]
    then
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path="$PATH2" &> log &
    else
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path="$PATH2" --pre_trained="$PATH3" &> log &
    fi
    cd ..
done
| @@ -0,0 +1,72 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Fixed missing space before '&&'; quote path expansions against word-splitting.
if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

PATH1=$(get_real_path "$1")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 2 ]
then
    PATH2=$(get_real_path "$2")
fi

if [ $# == 2 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM

# Run in a scratch copy of the source tree so logs do not pollute the repo.
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit

# mpirun spawns one training process per GPU rank.
if [ $# == 1 ]
then
    mpirun --allow-run-as-root -n $RANK_SIZE \
    python train.py --run_distribute=True \
    --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path="$PATH1" &> log &
fi

if [ $# == 2 ]
then
    mpirun --allow-run-as-root -n $RANK_SIZE \
    python train.py --run_distribute=True \
    --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path="$PATH1" --pre_trained="$PATH2" &> log &
fi
| @@ -0,0 +1,65 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

# Quote all path expansions so paths containing spaces do not word-split.
PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Run evaluation in a fresh scratch copy of the source tree.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path="$PATH1" --checkpoint_path="$PATH2" &> log &
cd ..
| @@ -0,0 +1,66 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
if [ $# != 2 ]
then
    echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

# Quote all path expansions so paths containing spaces do not word-split.
PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Run evaluation in a fresh scratch copy of the source tree.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path="$PATH1" --checkpoint_path="$PATH2" --device_target="GPU" &> log &
cd ..
| @@ -0,0 +1,79 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

# Quote all path expansions so paths containing spaces do not word-split.
PATH1=$(get_real_path "$1")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 2 ]
then
    PATH2=$(get_real_path "$2")
fi

if [ $# == 2 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Run training in a fresh scratch copy of the source tree.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
    python train.py --dataset_path="$PATH1" &> log &
fi

if [ $# == 2 ]
then
    python train.py --dataset_path="$PATH1" --pre_trained="$PATH2" &> log &
fi
cd ..
| @@ -0,0 +1,78 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path argument to an absolute path.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m "$PWD/$1")"
    fi
}

# Quote all path expansions so paths containing spaces do not word-split.
PATH1=$(get_real_path "$1")

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 2 ]
then
    PATH2=$(get_real_path "$2")
fi

if [ $# == 2 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Run training in a fresh scratch copy of the source tree.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
# Diagnostic logging, for consistency with the other launcher scripts.
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
    python train.py --device_target="GPU" --dataset_path="$PATH1" &> log &
fi

if [ $# == 2 ]
then
    python train.py --device_target="GPU" --dataset_path="$PATH1" --pre_trained="$PATH2" &> log &
fi
cd ..
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting, will be used in train.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| cfg = edict({ | |||||
| {% if dataset=='MNIST' %} | |||||
| 'num_classes': 10, | |||||
| {% elif dataset=='Cifar10' %} | |||||
| 'num_classes': 10, | |||||
| {% elif dataset=='ImageNet' %} | |||||
| 'num_classes': 1001, | |||||
| {% endif %} | |||||
| {% if dataset=='Momentum' %} | |||||
| 'lr': 0.01, | |||||
| {% else %} | |||||
| 'lr': 0.001, | |||||
| {% endif %} | |||||
| {% if optimizer=='Momentum' %} | |||||
| "momentum": 0.9, | |||||
| {% endif %} | |||||
| 'epoch_size': 1, | |||||
| 'batch_size': 32, | |||||
| 'buffer_size': 1000, | |||||
| 'image_height': 32, | |||||
| 'image_width': 32, | |||||
| 'save_checkpoint_steps': 1875, | |||||
| 'keep_checkpoint_max': 10, | |||||
| }) | |||||
| @@ -0,0 +1,78 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """LeNet.""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free 2-D convolution layer with truncated-normal weights."""
    return nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
                     weight_init=weight_variable(),
                     has_bias=False,
                     pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer whose weight and bias use truncated-normal init."""
    init_w = weight_variable()
    init_b = weight_variable()
    return nn.Dense(input_channels, out_channels, init_w, init_b)
def weight_variable():
    """Return the default TruncatedNormal(0.02) weight/bias initializer."""
    return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
    """
    LeNet-5 convolutional network.

    Args:
        num_class (int): Number of output classes. Default: 10.
        channel (int): Number of input image channels. Default: 1.

    Returns:
        Tensor, output logits of shape (batch, num_class).

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        # Two conv stages: channel -> 6 -> 16 feature maps, 5x5 kernels.
        # NOTE: attribute names are part of the checkpoint format — renaming
        # them would break load_checkpoint on existing .ckpt files.
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Classifier head; 16 * 5 * 5 is the flattened feature size
        # (presumably for 32x32 inputs — confirm against cfg.image_height).
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages, then three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
| @@ -0,0 +1,95 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| ######################## train lenet example ######################## | |||||
| train lenet and get network model files(.ckpt) : | |||||
| python train.py --data_path /YourDataPath | |||||
| """ | |||||
| import os | |||||
| import argparse | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context, ParallelMode | |||||
| from mindspore.communication.management import init, get_rank | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||||
| from mindspore.train import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.nn.metrics import Accuracy | |||||
| from src.config import cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.lenet import LeNet5 | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||||
| parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') | |||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num.') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--dataset_path', type=str, default="./Data", | |||||
| help='path where the dataset is saved') | |||||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pre-trained checkpoint path') | |||||
| parser.add_argument('--dataset_sink', action='store_true', help='enable dataset sink or not') | |||||
| args = parser.parse_args() | |||||
| if args.device_target == "CPU": | |||||
| args.dataset_sink = False | |||||
| ckpt_save_dir = './' | |||||
| if args.run_distribute: | |||||
| if args.device_target == 'Ascend': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id) | |||||
| init() | |||||
| elif args.device_target == "GPU": | |||||
| init("nccl") | |||||
| ckpt_save_dir = os.path.join(ckpt_save_dir, 'ckpt_' + str(get_rank())) | |||||
| else: | |||||
| raise ValueError('Distribute running is no supported on %s' % args.device_target) | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
| data_path = args.dataset_path | |||||
| do_train = True | |||||
| ds_train = create_dataset(data_path=data_path, do_train=do_train, | |||||
| batch_size=cfg.batch_size, target=args.device_target) | |||||
| network = LeNet5(cfg.num_classes) | |||||
| if args.pre_trained: | |||||
| param_dict = load_checkpoint(args.pre_trained) | |||||
| load_param_into_net(network, param_dict) | |||||
| {% if loss=='SoftmaxCrossEntropyWithLogits' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||||
| {% elif loss=='SoftmaxCrossEntropyExpand' %} | |||||
| net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) | |||||
| {% endif %} | |||||
| {% if optimizer=='Lamb' %} | |||||
| net_opt = nn.Lamb(network.trainable_params(), learning_rate=cfg.lr) | |||||
| {% elif optimizer=='Momentum' %} | |||||
| net_opt = nn.Momentum(network.trainable_params(), learning_rate=cfg.lr, momentum=cfg.momentum) | |||||
| {% endif %} | |||||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=ckpt_save_dir, config=config_ck) | |||||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||||
| print("============== Starting Training ==============") | |||||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], | |||||
| dataset_sink_mode=args.dataset_sink) | |||||
| @@ -0,0 +1,87 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Produce the dataset | |||||
| """ | |||||
| import os | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target="Ascend"):
    """
    Create the CIFAR-10 dataset for training or evaluation.

    Args:
        data_path (str): Directory containing the CIFAR-10 binaries.
        batch_size (int): Number of samples per batch. Default: 32.
        repeat_size (int): Number of times the dataset is repeated. Default: 1.
        do_train (bool): If True, apply random-crop / flip augmentation. Default: True.
        target (str): Device target, 'Ascend', 'GPU' or anything else for a
            single-device run. Default: "Ascend".

    Returns:
        Cifar10Dataset, the preprocessed, shuffled and batched dataset.
    """
    # Determine sharding information for distributed training.
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif target == "GPU":
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        # Single-device run: no sharding. rank_id is initialized here too so
        # every branch defines both values (the original left it undefined).
        device_num = 1
        rank_id = 0
    if device_num == 1:
        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8)
    else:
        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, num_shards=device_num, shard_id=rank_id)
    # Scale pixel values from [0, 255] to [0, 1] (duplicate assignments removed).
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_op = CV.Resize((cfg.image_height, cfg.image_width))
    rescale_op = CV.Rescale(rescale, shift)
    # CIFAR-10 per-channel mean / std.
    normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    if do_train:
        # Random 32x32 crop with 4-pixel padding plus horizontal flip.
        random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
        random_horizontal_op = CV.RandomHorizontalFlip()
    channel_swap_op = CV.HWC2CHW()
    typecast_op = C.TypeCast(mstype.int32)
    cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
    if do_train:
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
    cifar_ds = cifar_ds.shuffle(buffer_size=1000)
    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    cifar_ds = cifar_ds.repeat(repeat_size)
    return cifar_ds
def _get_rank_info():
    """Return (rank_size, rank_id) resolved from the distributed environment."""
    env_rank_size = int(os.environ.get("RANK_SIZE", 1))
    if env_rank_size <= 1:
        # Single-device run: one shard, rank zero.
        return 1, 0
    return get_group_size(), get_rank()
| @@ -0,0 +1,104 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| create train or eval dataset. | |||||
| """ | |||||
| import os | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| import mindspore.dataset.transforms.c_transforms as C2 | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target="Ascend"):
    """
    Create a training or evaluation ImageNet dataset.

    Args:
        data_path (str): path of the dataset.
        batch_size (int): batch size of the dataset. Default: 32.
        repeat_size (int): repeat count of the dataset. Default: 1.
        do_train (bool): whether the dataset is used for training. Default: True.
        target (str): device target. Default: "Ascend".

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif target == "GPU":
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1
    # Shard the dataset only when running on more than one device.
    if device_num == 1:
        data_set = de.ImageFolderDatasetV2(data_path, num_parallel_workers=8, shuffle=True)
    else:
        data_set = de.ImageFolderDatasetV2(data_path, num_parallel_workers=8, shuffle=True,
                                           num_shards=device_num, shard_id=rank_id)
    crop_size = cfg.image_height
    # ImageNet channel statistics, scaled to the 0-255 pixel range.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    if do_train:
        transform_ops = [
            C.RandomCropDecodeResize(crop_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        transform_ops = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(crop_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    label_cast = C2.TypeCast(mstype.int32)
    data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_ops)
    data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=label_cast)
    # Batch, then repeat, so every epoch sees whole batches only.
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_size)
    return data_set
def _get_rank_info():
    """Get the device count and the rank id of the current process."""
    rank_size = int(os.environ.get("RANK_SIZE", 1))
    rank_id = 0
    if rank_size > 1:
        # Distributed run: ask the communication layer for the real values.
        rank_size, rank_id = get_group_size(), get_rank()
    return rank_size, rank_id
| @@ -0,0 +1,98 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Produce the dataset | |||||
| """ | |||||
| import os | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from mindspore.dataset.transforms.vision import Inter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from .config import cfg | |||||
def create_dataset(data_path, batch_size=32, repeat_size=1, do_train=True, target='Ascend'):
    """
    Create the MNIST dataset for training or evaluation.

    Args:
        data_path (str): root path of the MNIST dataset ("train"/"test" subdirs).
        batch_size (int): batch size of the dataset. Default: 32.
        repeat_size (int): repeat count of the dataset. Default: 1.
        do_train (bool): use the "train" split when True, "test" otherwise.
        target (str): device target. Default: 'Ascend'.

    Returns:
        Dataset, the pre-processed MNIST dataset.
    """
    data_path = os.path.join(data_path, "train" if do_train else "test")
    if target == 'Ascend':
        device_num, rank_id = _get_rank_info()
    elif target == 'GPU':
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        device_num = 1
    # Shard only for multi-device runs.
    if device_num == 1:
        mnist_ds = ds.MnistDataset(data_path)
    else:
        mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)
    resize_height, resize_width = cfg.image_height, cfg.image_width
    # Scale raw pixels to [0, 1], then normalize with the MNIST mean/std.
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081
    image_ops = [
        CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),  # bilinear
        CV.Rescale(rescale, shift),
        CV.Rescale(rescale_nml, shift_nml),
        CV.HWC2CHW(),
    ]
    mnist_ds = mnist_ds.map(input_columns="label", operations=C.TypeCast(mstype.int32))
    for image_op in image_ops:
        mnist_ds = mnist_ds.map(input_columns="image", operations=image_op)
    mnist_ds = mnist_ds.shuffle(buffer_size=10000)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
def _get_rank_info():
    """Read RANK_SIZE from the environment and resolve (rank_size, rank_id)."""
    if int(os.environ.get("RANK_SIZE", 1)) > 1:
        return get_group_size(), get_rank()
    # Default to a single-device configuration.
    return 1, 0
| @@ -0,0 +1,85 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train resnet.""" | |||||
| import os | |||||
| import random | |||||
| import argparse | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore import dataset as de | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.crossentropy import CrossEntropy | |||||
# NOTE: this file is a Jinja2 wizard template -- the {% ... %} lines are
# rendered before the final eval.py is written out.
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--dataset_sink_mode', type=str, default='True', choices = ['True', 'False'],
                    help='DataSet sink mode is True or False')
args_opt = parser.parse_args()
# Fix all random seeds so evaluation results are reproducible.
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
# Project imports are deferred until after seeding.
from src.resnet50 import resnet50 as resnet
from src.config import cfg
from src.dataset import create_dataset
if __name__ == '__main__':
    target = args_opt.device_target
    # init context
    # NOTE(review): int(os.getenv('DEVICE_ID')) raises TypeError when DEVICE_ID
    # is unset -- the launch scripts are expected to export it; confirm.
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False, device_id=device_id)
    dataset_sink_mode = args_opt.dataset_sink_mode=='True'
    # create dataset (no augmentation: do_train=False)
    dataset = create_dataset(data_path=args_opt.dataset_path, do_train=False, batch_size=cfg.batch_size, target=target)
    step_size = dataset.get_dataset_size()
    # define net
    net = resnet(class_num=cfg.num_classes)
    # load checkpoint weights into the network and freeze it for inference
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define loss, model (loss choice is selected at template-render time)
    {% if dataset=='ImageNet' %}
    if not cfg.use_label_smooth:
        cfg.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
    {% else %}
    {% if loss=='SoftmaxCrossEntropyWithLogits' %}
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    {% elif loss=='SoftmaxCrossEntropyExpand' %}
    loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
    {% endif %}
    {% endif %}
    # define model
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    # eval model
    res = model.eval(dataset, dataset_sink_mode=dataset_sink_mode)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
| @@ -0,0 +1,88 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Distributed Ascend training launcher: spawns one training process per device.
if [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

# Validate arguments before launching anything.
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH2 is not a file"
    exit 1
fi

if [ $# == 3 ]
then
    PATH3=$(get_real_path $3)
fi

if [ $# == 3 ] && [ ! -f $PATH3 ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH2
export RANK_TABLE_FILE=$PATH2
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

for((i=0; i<DEVICE_NUM; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    # Each rank trains from its own clean working copy.
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    if [ $# == 2 ]
    then
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH1 --dataset_sink_mode=False &> log &
    fi
    if [ $# == 3 ]
    then
        # BUGFIX: pass the pretrained checkpoint ($PATH3), not the HCCL config ($PATH2).
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH1 --pre_trained=$PATH3 --dataset_sink_mode=False &> log &
    fi
    cd ..
done
| @@ -0,0 +1,53 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Distributed GPU training launcher: runs train.py under mpirun, one process per device.
if [ $# != 1 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH]"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=4
export RANK_SIZE=4

# Train from a clean working copy so logs/checkpoints don't mix between runs.
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
mpirun --allow-run-as-root -n $RANK_SIZE \
    python train.py --run_distribute=True \
    --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
| @@ -0,0 +1,65 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Evaluate a trained checkpoint on a single Ascend device.
if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

# Validate arguments before launching.
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Evaluate from a clean ./eval working copy.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..
| @@ -0,0 +1,66 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Evaluate a trained checkpoint on a single GPU device.
if [ $# != 2 ]
then
    echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

# Validate arguments before launching.
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Evaluate from a clean ./eval working copy.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..
| @@ -0,0 +1,76 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Single-device Ascend training launcher, with an optional pretrained checkpoint.
if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
    PATH2=$(get_real_path $2)
fi

# Validate arguments before launching.
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ $# == 2 ] && [ ! -f $PATH2 ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Train from a clean ./train working copy.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
    python train.py --dataset_path=$PATH1 --dataset_sink_mode=False &> log &
fi
if [ $# == 2 ]
then
    python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 --dataset_sink_mode=False &> log &
fi
cd ..
| @@ -0,0 +1,59 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Single-device GPU training launcher.
if [ $# != 1 ]
then
    echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH]"
    exit 1
fi

# Resolve a possibly-relative path to an absolute one.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Train from a clean ./train working copy.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
python train.py --device_target="GPU" --dataset_path=$PATH1 &> log &
cd ..
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting, will be used in train.py and eval.py | |||||
| """ | |||||
| from easydict import EasyDict as ed | |||||
# NOTE: this file is a Jinja2 wizard template -- the {% ... %} lines select
# dataset/optimizer-specific values when the project is generated.
cfg = ed({
    {% if dataset=='MNIST' %}
    'num_classes': 10,
    {% elif dataset=='Cifar10' %}
    'num_classes': 10,
    {% elif dataset=='ImageNet' %}
    'num_classes': 1001,
    {% endif %}
    "batch_size": 32,
    "loss_scale": 1024,
    {% if optimizer=='Momentum' %}
    "momentum": 0.9,
    {% endif %}
    # input image size fed to the network
    "image_height": 224,
    "image_width": 224,
    "weight_decay": 1e-4,
    "epoch_size": 1,
    "pretrain_epoch_size": 1,
    # checkpoint-saving policy
    "save_checkpoint": True,
    "save_checkpoint_epochs": 5,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    {% if dataset=='ImageNet' %}
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    {% else %}
    "warmup_epochs": 5,
    "lr_decay_mode": "poly",
    {% endif %}
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    # learning-rate schedule parameters (see src/lr_generator.py)
    "lr": 0.01,
    "lr_init": 0.01,
    "lr_end": 0.00001,
    "lr_max": 0.1
})
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network""" | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| import mindspore.nn as nn | |||||
class CrossEntropy(_Loss):
    """Softmax cross-entropy loss with label smoothing over one-hot labels."""

    def __init__(self, smooth_factor=0., num_classes=1001):
        super(CrossEntropy, self).__init__()
        # Smoothed one-hot encoding: the true class gets 1 - smooth_factor,
        # every other class shares the remaining mass equally.
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits()
        self.reduce_mean = P.ReduceMean(False)

    def construct(self, logit, label):
        smoothed_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        per_sample_loss = self.cross_entropy(logit, smoothed_label)
        return self.reduce_mean(per_sample_loss, 0)
| @@ -0,0 +1,116 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """learning rate generator""" | |||||
| import math | |||||
| import numpy as np | |||||
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Generate the per-step learning-rate schedule.

    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly or default

    Returns:
        np.array, learning rate array (float32, one entry per training step)
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    if lr_decay_mode == 'steps':
        # Piecewise-constant decay at 30% / 60% / 80% of training.
        milestones = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for step in range(total_steps):
            if step < milestones[0]:
                lr_each_step.append(lr_max)
            elif step < milestones[1]:
                lr_each_step.append(lr_max * 0.1)
            elif step < milestones[2]:
                lr_each_step.append(lr_max * 0.01)
            else:
                lr_each_step.append(lr_max * 0.001)
    elif lr_decay_mode == 'poly':
        # Linear warmup followed by quadratic polynomial decay.
        inc_per_step = 0 if warmup_steps == 0 else (float(lr_max) - float(lr_init)) / float(warmup_steps)
        for step in range(total_steps):
            if step < warmup_steps:
                lr = float(lr_init) + inc_per_step * float(step)
            else:
                frac = (1.0 - (float(step) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * frac * frac
            if lr < 0.0:
                lr = 0.0
            lr_each_step.append(lr)
    else:
        # Default: linear warmup then linear decay from lr_max to lr_end.
        for step in range(total_steps):
            if step < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * step / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (step - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Linearly interpolate from *init_lr* towards *base_lr* over *warmup_steps* steps."""
    step_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    return float(init_lr) + step_inc * current_step


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
    """
    Generate a learning rate schedule: linear warmup followed by a
    linearly-damped cosine term.

    Args:
        lr (float): Base (peak) learning rate.
        steps_per_epoch (int): Number of steps in one epoch.
        warmup_epochs (int): Number of warmup epochs.
        max_epoch (int): Total number of training epochs. Default: 120.
        global_step (int): Index of the first step to return. Default: 0.

    Returns:
        np.ndarray, float32 learning rate for every remaining step.
    """
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    decay_steps = total_steps - warmup_steps

    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            value = linear_warmup_lr(step + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            # Linear damping times a slow cosine (period scaled by 2 * 0.47),
            # plus a tiny floor so the rate never reaches exactly zero.
            linear_damp = (total_steps - step) / decay_steps
            cosine_term = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * step / decay_steps))
            value = base_lr * (linear_damp * cosine_term + 0.00001)
        schedule.append(value)

    # Drop the steps that were already consumed before this call.
    return np.array(schedule).astype(np.float32)[global_step:]
| @@ -0,0 +1,262 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ResNet.""" | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | |||||
def _weight_variable(shape, factor=0.01):
    """Return a Tensor of the given shape filled with small gaussian noise."""
    values = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(values)
def _conv3x3(in_channel, out_channel, stride=1):
    """3x3 convolution, 'same' padding, random-normal weight init."""
    init = _weight_variable((out_channel, in_channel, 3, 3))
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=init)


def _conv1x1(in_channel, out_channel, stride=1):
    """1x1 convolution, 'same' padding, random-normal weight init."""
    init = _weight_variable((out_channel, in_channel, 1, 1))
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=init)


def _conv7x7(in_channel, out_channel, stride=1):
    """7x7 convolution, 'same' padding, random-normal weight init (stem conv)."""
    init = _weight_variable((out_channel, in_channel, 7, 7))
    return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride,
                     padding=0, pad_mode='same', weight_init=init)
def _bn(channel):
    """BatchNorm2d with the standard gamma_init=1 initialisation."""
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1,
                          beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    """BatchNorm2d with gamma_init=0 — used after the last conv of a residual branch."""
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=0,
                          beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel):
    """Fully-connected layer with random-normal weights and zero bias."""
    init = _weight_variable((out_channel, in_channel))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=init, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Bottleneck layout: 1x1 reduce -> 3x3 -> 1x1 expand, each followed by batch
    norm, with a ReLU after the residual addition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    # Bottleneck ratio: the internal width is out_channel // expansion.
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1)
        self.bn1 = _bn(channel)

        # The 3x3 conv carries the stride, so spatial downsampling (if any)
        # happens in the middle of the bottleneck.
        self.conv2 = _conv3x3(channel, channel, stride=stride)
        self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1)
        # Final BN uses gamma_init=0 (see _bn_last).
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        # A projection shortcut is needed whenever the identity cannot be
        # added element-wise (channel count or spatial size mismatch).
        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                        _bn(out_channel)])
        self.add = P.TensorAdd()

    def construct(self, x):
        """Forward pass: bottleneck branch plus (possibly projected) identity."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        # ResNet here is hard-wired to 4 stages; the channel/stride lists must
        # match. NOTE(review): `strides` length is not validated — confirm
        # callers always pass 4 entries.
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        # Stem: 7x7/2 conv (expects 3-channel input) + BN + ReLU + 3x3/2 maxpool.
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # Four residual stages.
        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        # Head: global average pool over the spatial axes, then a dense classifier.
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        The first block applies the stride (and any channel change); the rest
        keep stride 1 with out_channel in and out.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Forward pass: stem, four residual stages, pooled classifier head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Average over axes (2, 3) — the spatial dimensions.
        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    # Standard ResNet-50 stage configuration: block counts, channels, strides.
    layer_nums = [3, 4, 6, 3]
    in_channels = [64, 256, 512, 1024]
    out_channels = [256, 512, 1024, 2048]
    strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, layer_nums, in_channels, out_channels, strides, class_num)
| @@ -0,0 +1,164 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train resnet.""" | |||||
| import os | |||||
| import random | |||||
| import argparse | |||||
| import numpy as np | |||||
| from mindspore import context | |||||
| from mindspore import Tensor | |||||
| from mindspore import dataset as de | |||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||||
| from mindspore.train.model import Model, ParallelMode | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| import mindspore.nn as nn | |||||
| import mindspore.common.initializer as weight_init | |||||
| from src.lr_generator import get_lr, warmup_cosine_annealing_lr | |||||
| from src.crossentropy import CrossEntropy | |||||
def _str2bool(value):
    """Parse a boolean CLI value; accepts 'true'/'1'/'yes' (case-insensitive)."""
    return str(value).lower() in ('true', '1', 'yes')


parser = argparse.ArgumentParser(description='Image classification')
# BUGFIX: `type=bool` is broken with argparse — bool('False') is True because
# any non-empty string is truthy — so booleans are parsed explicitly.
parser.add_argument('--run_distribute', type=_str2bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target: "Ascend", "GPU", "CPU"')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--dataset_sink_mode', type=str, default='True', choices=['True', 'False'],
                    help='DataSet sink mode is True or False')
args_opt = parser.parse_args()

# Fix every random source so runs are reproducible.
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

from src.resnet50 import resnet50 as resnet
from src.config import cfg
from src.dataset import create_dataset
if __name__ == '__main__':
    target = args_opt.device_target
    ckpt_save_dir = cfg.save_checkpoint_path
    # CLI flag arrives as the string 'True'/'False'.
    dataset_sink_mode = args_opt.dataset_sink_mode=='True'

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
    if args_opt.run_distribute:
        if target == "Ascend":
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              mirror_mean=True)
            auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
            init()
        # GPU target
        else:
            init("nccl")
            context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                              mirror_mean=True)
            # Each rank writes checkpoints to its own subdirectory.
            ckpt_save_dir = cfg.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

    # create dataset
    dataset = create_dataset(data_path=args_opt.dataset_path, do_train=True, batch_size=cfg.batch_size, target=target)
    step_size = dataset.get_dataset_size()

    # define net
    net = resnet(class_num=cfg.num_classes)

    # init weight: load a pretrained checkpoint if given, otherwise initialize
    # conv and dense weights explicitly.
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(net, param_dict)
    else:
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
                                                                    cell.weight.default_input.shape,
                                                                    cell.weight.default_input.dtype).to_tensor()
            if isinstance(cell, nn.Dense):
                cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
                                                                    cell.weight.default_input.shape,
                                                                    cell.weight.default_input.dtype).to_tensor()

    # init lr
    {% if dataset=='Cifar10' %}
    lr = get_lr(lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max,
                warmup_epochs=cfg.warmup_epochs, total_epochs=cfg.epoch_size, steps_per_epoch=step_size,
                lr_decay_mode='poly')
    {% else %}
    # NOTE(review): get_lr has no explicit 'cosine' branch, so this falls
    # through to its default linear schedule — confirm intended.
    lr = get_lr(lr_init=cfg.lr_init, lr_end=0.0, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
                total_epochs=cfg.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
    {% endif %}
    lr = Tensor(lr)

    # define opt
    {% if optimizer=='Lamb' %}
    opt = nn.Lamb(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr,
                  weight_decay=cfg.weight_decay)
    {% elif optimizer=='Momentum' %}
    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr, momentum=cfg.momentum,
                      weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)
    {% endif %}

    # define loss, model
    if target == "Ascend":
        {% if dataset=='ImageNet' %}
        if not cfg.use_label_smooth:
            cfg.label_smooth_factor = 0.0
        loss = CrossEntropy(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
        {% else %}
        {% if loss=='SoftmaxCrossEntropyWithLogits' %}
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        {% elif loss=='SoftmaxCrossEntropyExpand' %}
        loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
        {% endif %}
        {% endif %}
        # Fixed loss scaling plus O2 mixed precision with fp16 batchnorm.
        loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=False)
    else:
        # GPU target
        {% if loss=='SoftmaxCrossEntropyWithLogits' %}
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean')
        {% elif loss=='SoftmaxCrossEntropyExpand' %}
        loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
        {% endif %}
        # Optimizer is rebuilt without loss scaling on GPU.
        {% if optimizer=='Lamb' %}
        opt = nn.Lamb(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr)
        {% elif optimizer=='Momentum' %}
        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr, momentum=cfg.momentum)
        {% endif %}
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if cfg.save_checkpoint:
        cfg_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_epochs * step_size,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=cfg_ck)
        cb += [ckpt_cb]

    # train model
    model.train(cfg.epoch_size, dataset, callbacks=cb, dataset_sink_mode=dataset_sink_mode)
| @@ -0,0 +1,119 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Create project command module.""" | |||||
| import os | |||||
| import re | |||||
| import sys | |||||
| import textwrap | |||||
| from pathlib import Path | |||||
| import click | |||||
| from mindinsight.utils.command import BaseCommand | |||||
| from mindinsight.wizard.base.utility import find_network_maker_names, load_network_maker, process_prompt_choice | |||||
| from mindinsight.wizard.common.exceptions import CommandError | |||||
| from mindinsight.wizard.conf.constants import SUPPORT_MINDSPORE_VERSION, QUESTION_START | |||||
class CreateProject(BaseCommand):
    """Create project command: interactively scaffold a new network project."""
    name = 'createproject'
    description = 'create project'

    def __init__(self):
        self._network_types = find_network_maker_names()

    def add_arguments(self, parser):
        """
        Add arguments to parser.

        Args:
            parser (ArgumentParser): Specify parser to which arguments are added.
        """
        parser.add_argument(
            'name',
            type=str,
            help='Specify the new project name.')

    def _make_project_dir(self, project_name):
        """Validate the project name, then create and return its directory."""
        self._check_project_dir(project_name)
        permissions = os.R_OK | os.W_OK | os.X_OK
        # Shift into the owner bits: directory is rwx for the owner only (0o700).
        mode = permissions << 6
        project_dir = os.path.join(os.getcwd(), project_name)
        os.makedirs(project_dir, mode=mode, exist_ok=True)
        return project_dir

    @staticmethod
    def _check_project_dir(project_name):
        """Check project directory whether empty or exist.

        Raises:
            CommandError: If the name is invalid, the directory exists and is
                non-empty, or a non-directory file with the same name exists.
        """
        if not re.search('^[A-Za-z0-9][A-Za-z0-9._-]*$', project_name):
            raise CommandError("'%s' is not a valid project name. Please input a valid name" % project_name)

        project_dir = os.path.join(os.getcwd(), project_name)
        if os.path.exists(project_dir):
            output_path = Path(project_dir)
            if output_path.is_dir():
                # BUGFIX: was os.path.os.listdir — worked only by accident
                # through os.path's internal `os` import.
                if os.listdir(project_dir):
                    raise CommandError('%s already exists, %s is not empty directory, please try another name.'
                                       % (project_name, project_dir))
            else:
                # BUGFIX: the error object was constructed but never raised, so
                # a clash with an existing regular file was silently ignored.
                raise CommandError('There is a file in the current directory has the same name as the project %s, '
                                   'please try another name.' % project_name)
        return True

    def ask_network(self):
        """Ask user question for selecting a network to create."""
        network_type_choices = self._network_types[:]
        network_type_choices.sort(reverse=False)
        prompt_msg = '{}:\n{}\n'.format(
            '%sPlease select a network' % QUESTION_START,
            '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(network_type_choices, start=1))
        )
        prompt_type = click.IntRange(min=1, max=len(network_type_choices))
        # value_proc validates/converts the raw input before IntRange sees it.
        choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
                              confirmation_prompt=False,
                              value_proc=lambda x: process_prompt_choice(x, prompt_type))
        return network_type_choices[choice - 1]

    @staticmethod
    def echo_notice():
        """Echo notice for depending environment."""
        click.secho(textwrap.dedent("""
            [NOTICE] To ensure the final generated scripts run under specific environment with the following
            mindspore : %s
        """ % SUPPORT_MINDSPORE_VERSION), fg='red')

    def run(self, args):
        """Override run method to start."""
        project_name = args.get('name')
        try:
            self._check_project_dir(project_name)
        except CommandError as error:
            click.secho(error.message, fg='red')
            sys.exit(1)

        try:
            self.echo_notice()
            network_maker_name = self.ask_network()
            network_maker = load_network_maker(network_maker_name)
            network_maker.configure()
        except click.exceptions.Abort:
            # User pressed Ctrl+C / aborted a prompt.
            sys.exit(1)

        project_dir = self._make_project_dir(project_name)
        source_files = network_maker.generate(**args)
        for source_file in source_files:
            source_file.write(project_dir)
        click.secho(f"{project_name} is generated in {project_dir}")
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseDataset module.""" | |||||
| import os | |||||
| from mindinsight.wizard.base.dataset import BaseDataset | |||||
| from mindinsight.wizard.base.templates import TemplateManager | |||||
| from mindinsight.wizard.conf.constants import TEMPLATES_BASE_DIR | |||||
class Dataset(BaseDataset):
    """Cifar10 dataset code generator."""
    name = 'Cifar10'

    def __init__(self):
        super(Dataset, self).__init__()
        self._network = None
        self.template_manager = None

    def set_network(self, network_maker):
        """Bind the network maker and point the template manager at the
        network-specific Cifar10 dataset templates."""
        self._network = network_maker
        template_dir = os.path.join(
            TEMPLATES_BASE_DIR, 'network', network_maker.name.lower(),
            'dataset', self.name.lower())
        self.template_manager = TemplateManager(template_dir)

    def configure(self):
        """Configure the network options."""
        return self.settings

    def generate(self, **options):
        """Render every dataset template with the given options."""
        return self.template_manager.render(**options)
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseDataset module.""" | |||||
| import os | |||||
| from mindinsight.wizard.base.dataset import BaseDataset | |||||
| from mindinsight.wizard.base.templates import TemplateManager | |||||
| from mindinsight.wizard.conf.constants import TEMPLATES_BASE_DIR | |||||
class Dataset(BaseDataset):
    """ImageNet dataset code generator."""
    name = 'ImageNet'

    def __init__(self):
        super(Dataset, self).__init__()
        self._network = None
        self.template_manager = None

    def set_network(self, network_maker):
        """Bind the network maker and point the template manager at the
        network-specific ImageNet dataset templates."""
        self._network = network_maker
        template_dir = os.path.join(
            TEMPLATES_BASE_DIR, 'network', network_maker.name.lower(),
            'dataset', self.name.lower())
        self.template_manager = TemplateManager(template_dir)

    def configure(self):
        """Configure the network options."""
        return self.settings

    def generate(self, **options):
        """Render every dataset template with the given options."""
        return self.template_manager.render(**options)
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BaseDataset module.""" | |||||
| import os | |||||
| from mindinsight.wizard.base.dataset import BaseDataset | |||||
| from mindinsight.wizard.base.templates import TemplateManager | |||||
| from mindinsight.wizard.conf.constants import TEMPLATES_BASE_DIR | |||||
class Dataset(BaseDataset):
    """MNIST dataset code generator."""
    name = 'MNIST'

    def __init__(self):
        super(Dataset, self).__init__()
        self._network = None
        self.template_manager = None

    def set_network(self, network_maker):
        """Bind the network maker and point the template manager at the
        network-specific MNIST dataset templates."""
        self._network = network_maker
        template_dir = os.path.join(
            TEMPLATES_BASE_DIR, 'network', network_maker.name.lower(),
            'dataset', self.name.lower())
        self.template_manager = TemplateManager(template_dir)

    def configure(self):
        """Configure the network options."""
        return self.settings

    def generate(self, **options):
        """Render every dataset template with the given options."""
        return self.template_manager.render(**options)
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,18 @@ | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """alexnet module.""" | |||||
| from mindinsight.wizard.network.generic_network import GenericNetwork | |||||
class Network(GenericNetwork):
    """AlexNet network code generator."""

    # Network identifier; GenericNetwork uses name.lower() to locate the
    # template directory for this network.
    name = 'alexnet'
    # Choices offered during interactive configuration (see GenericNetwork.configure).
    supported_datasets = ['Cifar10', 'ImageNet']
    supported_loss_functions = ['SoftmaxCrossEntropyWithLogits', 'SoftmaxCrossEntropyExpand']
    supported_optimizers = ['Momentum', 'Lamb']
| @@ -0,0 +1,143 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GenericNetwork module.""" | |||||
| import os | |||||
| import click | |||||
| from mindinsight.wizard.base.network import BaseNetwork | |||||
| from mindinsight.wizard.base.templates import TemplateManager | |||||
| from mindinsight.wizard.base.utility import process_prompt_choice, load_dataset_maker | |||||
| from mindinsight.wizard.conf.constants import TEMPLATES_BASE_DIR | |||||
| from mindinsight.wizard.conf.constants import QUESTION_START | |||||
class GenericNetwork(BaseNetwork):
    """Generic network code generator.

    Subclasses declare ``name`` plus the supported datasets, loss functions
    and optimizers; this base class drives interactive (or settings-driven)
    configuration and renders the network, dataset and assemble templates.
    """

    name = 'GenericNetwork'
    supported_datasets = []
    supported_loss_functions = []
    supported_optimizers = []

    def __init__(self):
        # Keep base-class initialization (e.g. the `settings` dict updated in
        # configure()) — the original skipped this call.
        super(GenericNetwork, self).__init__()
        self._dataset_maker = None
        template_dir = os.path.join(TEMPLATES_BASE_DIR, 'network', self.name.lower())
        # Renders templates under <network>/src (network definition scripts).
        self.network_template_manager = TemplateManager(os.path.join(template_dir, 'src'))
        # Renders everything else (train.py, eval.py, ...), excluding src/ and dataset/.
        self.common_template_manager = TemplateManager(template_dir, ['src', 'dataset'])

    def configure(self, settings=None):
        """
        Configure the network options.

        If settings is not None, use the given settings to configure the
        network instead of prompting interactively.

        Args:
            settings (dict): Settings to configure, format is {'options': value}.
                Example:
                    {
                        "loss": "SoftmaxCrossEntropyWithLogits",
                        "optimizer": "Momentum",
                        "dataset": "Cifar10"
                    }

        Returns:
            dict, configuration value to network.
        """
        if settings:
            config = {'loss': settings['loss'],
                      'optimizer': settings['optimizer'],
                      'dataset': settings['dataset']}
            # Bug fix: the dataset maker must also be loaded on this path,
            # otherwise a later generate() call dereferences a None
            # self._dataset_maker and raises AttributeError.
            self._dataset_maker = load_dataset_maker(config['dataset'])
            self._dataset_maker.set_network(self)
            self.settings.update(config)
            return config

        loss = self.ask_loss_function()
        optimizer = self.ask_optimizer()
        dataset = self.ask_dataset()

        self._dataset_maker = load_dataset_maker(dataset)
        self._dataset_maker.set_network(self)
        dataset_config = self._dataset_maker.configure()

        config = {'loss': loss,
                  'optimizer': optimizer,
                  'dataset': dataset}
        config.update(dataset_config)
        self.settings.update(config)
        return config

    @staticmethod
    def ask_choice(prompt_head, content_list, default_value=None):
        """Prompt the user to pick one item from content_list; return the chosen item.

        Choices are displayed sorted and numbered from 1. default_value (or the
        first item when None) is preselected.
        """
        if default_value is None:
            default_choice = 1  # start from 1 in prompt message.
            default_value = content_list[default_choice - 1]

        # Sort a copy so the caller's list is not reordered.
        choice_contents = content_list[:]
        choice_contents.sort(reverse=False)
        default_choice = choice_contents.index(default_value) + 1  # start from 1 in prompt message.

        prompt_msg = '{}:\n{}\n'.format(
            prompt_head,
            '\n'.join(f'{idx: >4}: {choice}' for idx, choice in enumerate(choice_contents, start=1))
        )
        prompt_type = click.IntRange(min=1, max=len(choice_contents))
        choice = click.prompt(prompt_msg, type=prompt_type, hide_input=False, show_choices=False,
                              confirmation_prompt=False, default=default_choice,
                              value_proc=lambda x: process_prompt_choice(x, prompt_type))
        return choice_contents[choice - 1]

    def ask_loss_function(self):
        """Select loss function by user."""
        return self.ask_choice('%sPlease select a loss function' % QUESTION_START, self.supported_loss_functions)

    def ask_optimizer(self):
        """Select optimizer by user."""
        return self.ask_choice('%sPlease select an optimizer' % QUESTION_START, self.supported_optimizers)

    def ask_dataset(self):
        """Select dataset by user."""
        return self.ask_choice('%sPlease select a dataset' % QUESTION_START, self.supported_datasets)

    def generate(self, **options):
        """Generate network definition scripts.

        Returns the rendered network, dataset and assemble source files; the
        network and dataset files are relocated under the src/ directory.
        """
        context = self.get_generate_context(**options)
        network_source_files = self.network_template_manager.render(**context)
        for source_file in network_source_files:
            source_file.file_relative_path = os.path.join('src', source_file.file_relative_path)

        dataset_source_files = self._dataset_maker.generate(**options)
        for source_file in dataset_source_files:
            source_file.file_relative_path = os.path.join('src', source_file.file_relative_path)

        assemble_files = self._assemble(**options)
        source_files = network_source_files + dataset_source_files + assemble_files
        return source_files

    def get_generate_context(self, **options):
        """Get detailed info based on settings to network files."""
        context = dict(options)
        context.update(self.settings)
        return context

    def get_assemble_context(self, **options):
        """Get detailed info based on settings to assemble files."""
        context = dict(options)
        context.update(self.settings)
        return context

    def _assemble(self, **options):
        """Generate train.py & eval.py & assemble scripts."""
        assemble_files = []
        context = self.get_assemble_context(**options)
        common_source_files = self.common_template_manager.render(**context)
        assemble_files.extend(common_source_files)
        return assemble_files
| @@ -0,0 +1,24 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """lenet module.""" | |||||
| from mindinsight.wizard.network.generic_network import GenericNetwork | |||||
class Network(GenericNetwork):
    """LeNet network code generator."""

    # Network identifier; GenericNetwork uses name.lower() to locate the
    # template directory for this network.
    name = 'lenet'
    # Choices offered during interactive configuration (see GenericNetwork.configure).
    supported_datasets = ['MNIST']
    supported_loss_functions = ['SoftmaxCrossEntropyWithLogits', 'SoftmaxCrossEntropyExpand']
    supported_optimizers = ['Momentum', 'Lamb']
| @@ -0,0 +1,18 @@ | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
"""resnet50 module."""
| from mindinsight.wizard.network.generic_network import GenericNetwork | |||||
class Network(GenericNetwork):
    """ResNet-50 network code generator."""

    # Network identifier; GenericNetwork uses name.lower() to locate the
    # template directory for this network.
    name = 'resnet50'
    # Choices offered during interactive configuration (see GenericNetwork.configure).
    supported_datasets = ['Cifar10', 'ImageNet']
    supported_loss_functions = ['SoftmaxCrossEntropyWithLogits', 'SoftmaxCrossEntropyExpand']
    supported_optimizers = ['Momentum', 'Lamb']
| @@ -207,6 +207,7 @@ if __name__ == '__main__': | |||||
| 'console_scripts': [ | 'console_scripts': [ | ||||
| 'mindinsight=mindinsight.utils.command:main', | 'mindinsight=mindinsight.utils.command:main', | ||||
| 'mindconverter=mindinsight.mindconverter.cli:cli_entry', | 'mindconverter=mindinsight.mindconverter.cli:cli_entry', | ||||
| 'mindwizard=mindinsight.wizard.cli:cli_entry', | |||||
| ], | ], | ||||
| }, | }, | ||||
| python_requires='>=3.7', | python_requires='>=3.7', | ||||