From e4a0b0db34d712e47b65c5214ec1e8045ed08ff4 Mon Sep 17 00:00:00 2001 From: caojian05 Date: Tue, 8 Sep 2020 23:35:56 +0800 Subject: [PATCH] googlenet support imagenet dataset on Ascend --- model_zoo/googlenet/README.md | 503 ++++++++++++------ model_zoo/googlenet/eval.py | 54 +- model_zoo/googlenet/export.py | 25 +- model_zoo/googlenet/mindspore_hub_conf.py | 25 + model_zoo/googlenet/scripts/run_eval_gpu.sh | 43 ++ model_zoo/googlenet/scripts/run_train.sh | 33 +- model_zoo/googlenet/scripts/run_train_gpu.sh | 51 ++ model_zoo/googlenet/src/config.py | 41 +- model_zoo/googlenet/src/dataset.py | 99 +++- model_zoo/googlenet/src/googlenet.py | 23 +- .../googlenet/src/lr_scheduler/__init__.py | 0 .../src/lr_scheduler/linear_warmup.py | 20 + .../warmup_cosine_annealing_lr.py | 39 ++ .../src/lr_scheduler/warmup_step_lr.py | 59 ++ model_zoo/googlenet/train.py | 159 +++++- 15 files changed, 926 insertions(+), 248 deletions(-) create mode 100644 model_zoo/googlenet/mindspore_hub_conf.py create mode 100644 model_zoo/googlenet/scripts/run_eval_gpu.sh create mode 100644 model_zoo/googlenet/scripts/run_train_gpu.sh create mode 100644 model_zoo/googlenet/src/lr_scheduler/__init__.py create mode 100644 model_zoo/googlenet/src/lr_scheduler/linear_warmup.py create mode 100644 model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py create mode 100644 model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py diff --git a/model_zoo/googlenet/README.md b/model_zoo/googlenet/README.md index 92cdd8af43..014dc78211 100644 --- a/model_zoo/googlenet/README.md +++ b/model_zoo/googlenet/README.md @@ -36,14 +36,8 @@ GoogleNet, a 22 layers deep network, was proposed in 2014 and won the first plac # [Model Architecture](#contents) -The overall network architecture of GoogleNet is shown below: - -![](https://miro.medium.com/max/3780/1*ZFPOSAted10TPd3hBQU8iQ.png) - Specifically, the GoogleNet contains numerous inception modules, which are connected together to go deeper. In general, an inception module with dimensionality reduction consists of **1×1 conv**, **3×3 conv**, **5×5 conv**, and **3×3 max pooling**, which are done altogether for the previous input, and stack together again at output. -![](https://miro.medium.com/max/1108/1*sezFsYW1MyM9YOMa1q909A.png) - # [Dataset](#contents) @@ -52,10 +46,9 @@ Dataset used: [CIFAR-10]() - Dataset size:175M,60,000 32*32 colorful images in 10 classes - Train:146M,50,000 images - - Test:29.3M,10,000 images + - Test:29M,10,000 images - Data format:binary files - - Note:Data will be processed in dataset.py - + - Note:Data will be processed in src/dataset.py # [Features](#contents) @@ -72,7 +65,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil - Hardware(Ascend/GPU) - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Framework - - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) + - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) @@ -83,16 +76,45 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil After installing MindSpore via the official website, you can start training and evaluation as follows: -```python -# run training example -python train.py > train.log 2>&1 & +- runing on Ascend -# run distributed training example -sh scripts/run_train.sh rank_table.json + ```python + # run training example + python train.py > train.log 2>&1 & + + # run distributed training example + sh scripts/run_train.sh rank_table.json + + # run evaluation example + python eval.py > eval.log 2>&1 & + OR + sh run_eval.sh + ``` + + For distributed training, a hccl configuration file with JSON format needs to be created in advance. + + Please follow the instructions in the link below: + + https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools. + +- running on GPU + + For running on GPU, please change `device_target` from `Ascend` to `GPU` in configuration file src/config.py + + ```python + # run training example + export CUDA_VISIBLE_DEVICES=0 + python train.py > train.log 2>&1 & + + # run distributed training example + sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7 + + # run evaluation example + python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 & + OR + sh run_eval_gpu.sh [CHECKPOINT_PATH] + ``` -# run evaluation example -python eval.py > eval.log 2>&1 & OR sh run_eval.sh -``` @@ -106,109 +128,168 @@ python eval.py > eval.log 2>&1 & OR sh run_eval.sh ├── googlenet ├── README.md // descriptions about googlenet ├── scripts - │ ├──run_train.sh // shell script for distributed - │ ├──run_eval.sh // shell script for evaluation + │ ├──run_train.sh // shell script for distributed on Ascend + │ ├──run_train_gpu.sh // shell script for distributed on GPU + │ ├──run_eval.sh // shell script for evaluation on Ascend + │ ├──run_eval_gpu.sh // shell script for evaluation on GPU ├── src │ ├──dataset.py // creating dataset │ ├──googlenet.py // googlenet architecture │ ├──config.py // parameter configuration ├── train.py // training script ├── eval.py // evaluation script - ├── export.py // export checkpoint files into geir/onnx + ├── export.py // export checkpoint files into air/onnx ``` ## [Script Parameters](#contents) -```python -Major parameters in train.py and config.py are: - ---data_path: The absolute full path to the train and evaluation datasets. ---epoch_size: Total training epochs. ---batch_size: Training batch size. ---lr_init: Initial learning rate. ---num_classes: The number of classes in the training set. ---weight_decay: Weight decay value. ---image_height: Image height used as input to the model. ---image_width: Image width used as input the model. ---pre_trained: Whether training from scratch or training based on the - pre-trained model.Optional values are True, False. ---device_target: Device where the code will be implemented. Optional values - are "Ascend", "GPU". ---device_id: Device ID used to train or evaluate the dataset. Ignore it - when you use run_train.sh for distributed training. ---checkpoint_path: The absolute full path to the checkpoint file saved - after training. ---onnx_filename: File name of the onnx model used in export.py. ---geir_filename: File name of the geir model used in export.py. -``` +Parameters for both training and evaluation can be set in config.py + +- config for GoogleNet, CIFAR-10 dataset + + ```python + 'pre_trained': 'False' # whether training based on the pre-trained model + 'nump_classes': 10 # the number of classes in the dataset + 'lr_init': 0.1 # initial learning rate + 'batch_size': 128 # training batch size + 'epoch_size': 125 # total training epochs + 'momentum': 0.9 # momentum + 'weight_decay': 5e-4 # weight decay value + 'buffer_size': 10 # buffer size + 'image_height': 224 # image height used as input to the model + 'image_width': 224 # image width used as input to the model + 'data_path': './cifar10' # absolute full path to the train and evaluation datasets + 'device_target': 'Ascend' # device running the program + 'device_id': 4 # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training + 'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoint + 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt' # the absolute full path to save the checkpoint file + 'onnx_filename': 'googlenet.onnx' # file name of the onnx model used in export.py + 'geir_filename': 'googlenet.geir' # file name of the geir model used in export.py + ``` ## [Training Process](#contents) ### Training -``` -python train.py > train.log 2>&1 & -``` - -The python command above will run in the background, you can view the results through the file `train.log`. +- running on Ascend + + ``` + python train.py > train.log 2>&1 & + ``` + + The python command above will run in the background, you can view the results through the file `train.log`. + + After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows: + + ``` + # grep "loss is " train.log + epoch: 1 step: 390, loss is 1.4842823 + epcoh: 2 step: 390, loss is 1.0897788 + ... + ``` + + The model checkpoint will be saved in the current directory. + +- running on GPU + + ``` + export CUDA_VISIBLE_DEVICES=0 + python train.py > train.log 2>&1 & + ``` + + The python command above will run in the background, you can view the results through the file `train.log`. + + After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default. -After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows: - -``` -# grep "loss is " train.log -epoch: 1 step: 390, loss is 1.4842823 -epcoh: 2 step: 390, loss is 1.0897788 -... -``` - -The model checkpoint will be saved in the current directory. ### Distributed Training -``` -sh scripts/run_train.sh rank_table.json -``` - -The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows: - -``` -# grep "result: " train_parallel*/log -train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931 -train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874 -... -train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025 -train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336 -... -... -``` +- running on Ascend + + ``` + sh scripts/run_train.sh rank_table.json + ``` + + The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows: + + ``` + # grep "result: " train_parallel*/log + train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931 + train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874 + ... + train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025 + train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336 + ... + ... + ``` + +- running on GPU + + ``` + sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7 + ``` + + The above shell script will run distribute training in the background. You can view the results through the file `train/train.log`. ## [Evaluation Process](#contents) ### Evaluation -Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt". - -``` -python eval.py > eval.log 2>&1 & -OR -sh scripts/run_eval.sh -``` - -The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: - -``` -# grep "accuracy: " eval.log -accuracy: {'acc': 0.934} -``` - -Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows: - -``` -# grep "accuracy: " dist.eval.log -accuracy: {'acc': 0.9217} -``` +- evaluation on CIFAR-10 dataset when running on Ascend + + Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt". + + ``` + python eval.py > eval.log 2>&1 & + OR + sh scripts/run_eval.sh + ``` + + The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: + + ``` + # grep "accuracy: " eval.log + accuracy: {'acc': 0.934} + ``` + + Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows: + + ``` + # grep "accuracy: " dist.eval.log + accuracy: {'acc': 0.9217} + ``` + +- evaluation on CIFAR-10 dataset when running on GPU + + Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train/ckpt_0/train_googlenet_cifar10-125_390.ckpt". + + ``` + python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 & + ``` + + The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: + + ``` + # grep "accuracy: " eval.log + accuracy: {'acc': 0.930} + ``` + + OR, + + ``` + sh scripts/run_eval_gpu.sh [CHECKPOINT_PATH] + ``` + + The above python command will run in the background. You can view the results through the file "eval/eval.log". The accuracy of the test dataset will be as follows: + + ``` + # grep "accuracy: " eval/eval.log + accuracy: {'acc': 0.930} + ``` + + # [Model Description](#contents) @@ -216,100 +297,170 @@ accuracy: {'acc': 0.9217} ### Evaluation Performance -| Parameters | GoogleNet | -| -------------------------- | ----------------------------------------------------------- | -| Model Version | Inception V1 | -| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G | -| uploaded Date | 06/09/2020 (month/day/year) | -| MindSpore Version | 0.3.0-alpha | -| Dataset | CIFAR-10 | -| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 | -| Optimizer | SGD | -| Loss Function | Softmax Cross Entropy | -| outputs | probability | -| Loss | 0.0016 | -| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step | -| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins | -| Parameters (M) | 6.8 | -| Checkpoint for Fine tuning | 43.07M (.ckpt file) | -| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) | -| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/googlenet | +| Parameters | Ascend | GPU | +| -------------------------- | ----------------------------------------------------------- | ---------------------- | +| Model Version | Inception V1 | Inception V1 | +| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G | NV SMX2 V100-32G | +| uploaded Date | 08/31/2020 (month/day/year) | 08/20/2020 (month/day/year) | +| MindSpore Version | 0.7.0-alpha | 0.6.0-alpha | +| Dataset | CIFAR-10 | CIFAR-10 | +| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 | epoch=125, steps=390, batch_size=128, lr=0.1 | +| Optimizer | SGD | SGD | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | +| outputs | probability | probobility | +| Loss | 0.0016 | 0.0016 | +| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step | 1pc: 150 ms/step; 8pcs: 164 ms/step | +| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins | 1pc: 126.87 mins; 8pcs: 21.65 mins | +| Parameters (M) | 13.0 | 13.0 | +| Checkpoint for Fine tuning | 43.07M (.ckpt file) | 43.07M (.ckpt file) | +| Model for inference | 21.50M (.onnx file), 21.60M(.air file) | | +| Scripts | [googlenet script](https://gitee.com/mindspore/mindspore/tree/r0.7/model_zoo/official/cv/googlenet) | [googlenet script](https://gitee.com/mindspore/mindspore/tree/r0.6/model_zoo/official/cv/googlenet) | ### Inference Performance -| Parameters | GoogleNet | -| ------------------- | --------------------------- | -| Model Version | Inception V1 | -| Resource | Ascend 910 | -| Uploaded Date | 06/09/2020 (month/day/year) | -| MindSpore Version | 0.3.0-alpha | -| Dataset | CIFAR-10, 10,000 images | -| batch_size | 128 | -| outputs | probability | -| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | -| Model for inference | 21.50M (.onnx file) | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --------------------------- | +| Model Version | Inception V1 | Inception V1 | +| Resource | Ascend 910 | GPU | +| Uploaded Date | 08/31/2020 (month/day/year) | 08/20/2020 (month/day/year) | +| MindSpore Version | 0.7.0-alpha | 0.6.0-alpha | +| Dataset | CIFAR-10, 10,000 images | CIFAR-10, 10,000 images | +| batch_size | 128 | 128 | +| outputs | probability | probability | +| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | 1pc: 93%, 8pcs: 92.89% | +| Model for inference | 21.50M (.onnx file) | | ## [How to use](#contents) ### Inference If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example: -``` -# Load unseen dataset for inference -dataset = dataset.create_dataset(cfg.data_path, 1, False) - -# Define model -net = GoogleNet(num_classes=cfg.num_classes) -opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, - cfg.momentum, weight_decay=cfg.weight_decay) -loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', - is_grad=False) -model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) - -# Load pre-trained model -param_dict = load_checkpoint(cfg.checkpoint_path) -load_param_into_net(net, param_dict) -net.set_train(False) - -# Make predictions on the unseen dataset -acc = model.eval(dataset) -print("accuracy: ", acc) -``` +- Running on Ascend + + ``` + # Set context + context.set_context(mode=context.GRAPH_HOME, device_target=cfg.device_target) + context.set_context(device_id=cfg.device_id) + + # Load unseen dataset for inference + dataset = dataset.create_dataset(cfg.data_path, 1, False) + + # Define model + net = GoogleNet(num_classes=cfg.num_classes) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, + cfg.momentum, weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', + is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + # Load pre-trained model + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + # Make predictions on the unseen dataset + acc = model.eval(dataset) + print("accuracy: ", acc) + ``` + +- Running on GPU: + + ``` + # Set context + context.set_context(mode=context.GRAPH_HOME, device_target="GPU") + + # Load unseen dataset for inference + dataset = dataset.create_dataset(cfg.data_path, 1, False) + + # Define model + net = GoogleNet(num_classes=cfg.num_classes) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, + cfg.momentum, weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', + is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + # Load pre-trained model + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + + # Make predictions on the unseen dataset + acc = model.eval(dataset) + print("accuracy: ", acc) + + ``` ### Continue Training on the Pretrained Model -``` -# Load dataset -dataset = create_dataset(cfg.data_path, cfg.epoch_size) -batch_num = dataset.get_dataset_size() - -# Define model -net = GoogleNet(num_classes=cfg.num_classes) -# Continue training if set pre_trained to be True -if cfg.pre_trained: - param_dict = load_checkpoint(cfg.checkpoint_path) - load_param_into_net(net, param_dict) -lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, - steps_per_epoch=batch_num) -opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), - Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) -loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) -model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) - -# Set callbacks -config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, - keep_checkpoint_max=cfg.keep_checkpoint_max) -time_cb = TimeMonitor(data_size=batch_num) -ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", - config=config_ck) -loss_cb = LossMonitor() - -# Start training -model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) -print("train success") -``` +- running on Ascend + + ``` + # Load dataset + dataset = create_dataset(cfg.data_path, 1) + batch_num = dataset.get_dataset_size() + + # Define model + net = GoogleNet(num_classes=cfg.num_classes) + # Continue training if set pre_trained to be True + if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, + steps_per_epoch=batch_num) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), + Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + + # Set callbacks + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, + keep_checkpoint_max=cfg.keep_checkpoint_max) + time_cb = TimeMonitor(data_size=batch_num) + ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", + config=config_ck) + loss_cb = LossMonitor() + + # Start training + model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + print("train success") + ``` + +- running on GPU + + ``` + # Load dataset + dataset = create_dataset(cfg.data_path, 1) + batch_num = dataset.get_dataset_size() + + # Define model + net = GoogleNet(num_classes=cfg.num_classes) + # Continue training if set pre_trained to be True + if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, + steps_per_epoch=batch_num) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), + Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None) + + # Set callbacks + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, + keep_checkpoint_max=cfg.keep_checkpoint_max) + time_cb = TimeMonitor(data_size=batch_num) + ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./ckpt_" + str(get_rank()) + "/", + config=config_ck) + loss_cb = LossMonitor() + + # Start training + model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + print("train success") + ``` ### Transfer Learning To be added. diff --git a/model_zoo/googlenet/eval.py b/model_zoo/googlenet/eval.py index fc469879e7..650e8def7f 100644 --- a/model_zoo/googlenet/eval.py +++ b/model_zoo/googlenet/eval.py @@ -16,30 +16,64 @@ ##############test googlenet example on cifar10################# python eval.py """ +import argparse + import mindspore.nn as nn from mindspore import context from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cifar_cfg as cfg -from src.dataset import create_dataset +from src.config import cifar_cfg, imagenet_cfg +from src.dataset import create_dataset_cifar10, create_dataset_imagenet + from src.googlenet import GoogleNet +parser = argparse.ArgumentParser(description='googlenet') +parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], + help='dataset name.') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +args_opt = parser.parse_args() + if __name__ == '__main__': + + if args_opt.dataset_name == 'cifar10': + cfg = cifar_cfg + dataset = create_dataset_cifar10(cfg.data_path, 1, False) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + net = GoogleNet(num_classes=cfg.num_classes) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, + weight_decay=cfg.weight_decay) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + elif args_opt.dataset_name == "imagenet": + cfg = imagenet_cfg + dataset = create_dataset_imagenet(cfg.val_data_path, 1, False) + if not cfg.use_label_smooth: + cfg.label_smooth_factor = 0.0 + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) + net = GoogleNet(num_classes=cfg.num_classes) + model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) + + else: + raise ValueError("dataset is not support.") + + device_target = cfg.device_target context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) - context.set_context(device_id=cfg.device_id) + if device_target == "Ascend": + context.set_context(device_id=cfg.device_id) - net = GoogleNet(num_classes=cfg.num_classes) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + if args_opt.checkpoint_path is not None: + param_dict = load_checkpoint(args_opt.checkpoint_path) + print("load checkpoint from [{}].".format(args_opt.checkpoint_path)) + else: + param_dict = load_checkpoint(cfg.checkpoint_path) + print("load checkpoint from [{}].".format(cfg.checkpoint_path)) - param_dict = load_checkpoint(cfg.checkpoint_path) load_param_into_net(net, param_dict) net.set_train(False) - dataset = create_dataset(cfg.data_path, 1, False) + acc = model.eval(dataset) print("accuracy: ", acc) diff --git a/model_zoo/googlenet/export.py b/model_zoo/googlenet/export.py index d1a6de9b8d..2282cf07ea 100644 --- a/model_zoo/googlenet/export.py +++ b/model_zoo/googlenet/export.py @@ -13,24 +13,37 @@ # limitations under the License. # ============================================================================ """ -##############export checkpoint file into geir and onnx models################# +##############export checkpoint file into air and onnx models################# python export.py """ +import argparse import numpy as np -import mindspore as ms from mindspore import Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from src.config import cifar_cfg as cfg +from src.config import cifar_cfg, imagenet_cfg from src.googlenet import GoogleNet - if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Classification') + parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], + help='dataset name.') + args_opt = parser.parse_args() + + if args_opt.dataset_name == 'cifar10': + cfg = cifar_cfg + elif args_opt.dataset_name == 'imagenet': + cfg = imagenet_cfg + else: + raise ValueError("dataset is not support.") + net = GoogleNet(num_classes=cfg.num_classes) + + assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None." param_dict = load_checkpoint(cfg.checkpoint_path) load_param_into_net(net, param_dict) - input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32) + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)) export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX") - export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR") + export(net, input_arr, file_name=cfg.air_filename, file_format="AIR") diff --git a/model_zoo/googlenet/mindspore_hub_conf.py b/model_zoo/googlenet/mindspore_hub_conf.py new file mode 100644 index 0000000000..600838c470 --- /dev/null +++ b/model_zoo/googlenet/mindspore_hub_conf.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.googlenet import GoogleNet + +def googlenet(*args, **kwargs): + return GoogleNet(*args, **kwargs) + + +def create_network(name, *args, **kwargs): + if name == "googlenet": + return googlenet(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/googlenet/scripts/run_eval_gpu.sh b/model_zoo/googlenet/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000..b2e2a38737 --- /dev/null +++ b/model_zoo/googlenet/scripts/run_eval_gpu.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +ulimit -u unlimited + +if [ $# != 1 ] +then + echo "GPU: sh run_eval_gpu.sh [CHECKPOINT_PATH]" +exit 1 +fi + +# check checkpoint file +if [ ! -f $1 ] +then + echo "error: CHECKPOINT_PATH=$1 is not a file" +exit 1 +fi + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +export DEVICE_ID=0 + +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +python3 ${BASEPATH}/../eval.py --checkpoint_path=$1 > ./eval.log 2>&1 & diff --git a/model_zoo/googlenet/scripts/run_train.sh b/model_zoo/googlenet/scripts/run_train.sh index c21c2f04b6..1823e6ecc2 100644 --- a/model_zoo/googlenet/scripts/run_train.sh +++ b/model_zoo/googlenet/scripts/run_train.sh @@ -14,36 +14,51 @@ # limitations under the License. # ============================================================================ -if [ $# != 1 ] +if [ $# != 1 ] && [ $# != 2 ] then - echo "Usage: sh run_train.sh [MINDSPORE_HCCL_CONFIG_PATH]" + echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]" exit 1 fi if [ ! -f $1 ] then - echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" + echo "error: RANK_TABLE_FILE=$1 is not a file" exit 1 fi + +dataset_type='cifar10' +if [ $# == 2 ] +then + if [ $2 != "cifar10" ] && [ $2 != "imagenet" ] + then + echo "error: the selected dataset is neither cifar10 nor imagenet" + exit 1 + fi + dataset_type=$2 +fi + + ulimit -u unlimited export DEVICE_NUM=8 export RANK_SIZE=8 -MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1) -export MINDSPORE_HCCL_CONFIG_PATH -echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}" +RANK_TABLE_FILE=$(realpath $1) +export RANK_TABLE_FILE +echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}" +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) for((i=0; i<${DEVICE_NUM}; i++)) do export DEVICE_ID=$i - export RANK_ID=$i + export RANK_ID=$((rank_start + i)) rm -rf ./train_parallel$i mkdir ./train_parallel$i cp -r ./src ./train_parallel$i cp ./train.py ./train_parallel$i - echo "start training for rank $RANK_ID, device $DEVICE_ID" + echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" cd ./train_parallel$i ||exit env > env.log - python train.py --device_id=$i > log 2>&1 & + python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 & cd .. done diff --git a/model_zoo/googlenet/scripts/run_train_gpu.sh b/model_zoo/googlenet/scripts/run_train_gpu.sh new file mode 100644 index 0000000000..b30160238d --- /dev/null +++ b/model_zoo/googlenet/scripts/run_train_gpu.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -lt 2 ] +then + echo "Usage:\n \ + sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\n \ + " +exit 1 +fi + +if [ $1 -lt 1 ] && [ $1 -gt 8 ] +then + echo "error: DEVICE_NUM=$1 is not in (1-8)" +exit 1 +fi + +export DEVICE_NUM=$1 +export RANK_SIZE=$1 + +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export PYTHONPATH=${BASEPATH}:$PYTHONPATH +if [ -d "../train" ]; +then + rm -rf ../train +fi +mkdir ../train +cd ../train || exit + +export CUDA_VISIBLE_DEVICES="$2" + +if [ $1 -gt 1 ] +then + mpirun -n $1 --allow-run-as-root \ + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +else + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +fi diff --git a/model_zoo/googlenet/src/config.py b/model_zoo/googlenet/src/config.py index 5f803ad325..2989e2ba69 100644 --- a/model_zoo/googlenet/src/config.py +++ b/model_zoo/googlenet/src/config.py @@ -18,6 +18,7 @@ network config setting, will be used in main.py from easydict import EasyDict as edict cifar_cfg = edict({ + 'name': 'cifar10', 'pre_trained': False, 'num_classes': 10, 'lr_init': 0.1, @@ -30,9 +31,45 @@ cifar_cfg = edict({ 'image_width': 224, 'data_path': './cifar10', 'device_target': 'Ascend', - 'device_id': 4, + 'device_id': 0, 'keep_checkpoint_max': 10, 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt', 'onnx_filename': 'googlenet.onnx', - 'geir_filename': 'googlenet.geir' + 'air_filename': 'googlenet.air' +}) + +imagenet_cfg = edict({ + 'name': 'imagenet', + 'pre_trained': False, + 'num_classes': 1000, + 'lr_init': 0.1, + 'batch_size': 256, + 'epoch_size': 300, + 'momentum': 0.9, + 'weight_decay': 1e-4, + 'buffer_size': None, # invalid parameter + 'image_height': 224, + 'image_width': 224, + 'data_path': './ImageNet_Original/train/', + 'val_data_path': './ImageNet_Original/val/', + 'device_target': 'Ascend', + 'device_id': 0, + 'keep_checkpoint_max': 10, + 'checkpoint_path': None, + 'onnx_filename': 'googlenet.onnx', + 'air_filename': 'googlenet.air', + + # optimizer and lr related + 'lr_scheduler': 'exponential', + 'lr_epochs': [70, 140, 210, 280], + 'lr_gamma': 0.3, + 'eta_min': 0.0, + 'T_max': 150, + 'warmup_epochs': 0, + + # loss related + 'is_dynamic_loss_scale': 0, + 'loss_scale': 1024, + 'label_smooth_factor': 0.1, + 'use_label_smooth': True, }) diff --git a/model_zoo/googlenet/src/dataset.py b/model_zoo/googlenet/src/dataset.py index a1cbc2cdab..4fd9802f12 100644 --- a/model_zoo/googlenet/src/dataset.py +++ b/model_zoo/googlenet/src/dataset.py @@ -21,27 +21,30 @@ import mindspore.common.dtype as mstype import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as vision -from src.config import cifar_cfg as cfg +from src.config import cifar_cfg, imagenet_cfg -def create_dataset(data_home, repeat_num=1, training=True): +def create_dataset_cifar10(data_home, repeat_num=1, training=True): """Data operations.""" ds.config.set_seed(1) data_dir = os.path.join(data_home, "cifar-10-batches-bin") if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin") - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None - data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) + rank_size, rank_id = _get_rank_info() + if training: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True) + else: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) - resize_height = cfg.image_height - resize_width = cfg.image_width + resize_height = cifar_cfg.image_height + resize_width = cifar_cfg.image_width # define map operations random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + rescale_op = vision.Rescale(1.0 / 255.0, 0.0) normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) changeswap_op = vision.HWC2CHW() type_cast_op = C.TypeCast(mstype.int32) @@ -49,19 +52,93 @@ def create_dataset(data_home, repeat_num=1, training=True): c_trans = [] if training: c_trans = [random_crop_op, random_horizontal_op] - c_trans += [resize_op, normalize_op, changeswap_op] + c_trans += [resize_op, rescale_op, normalize_op, changeswap_op] # apply map operations on images data_set = data_set.map(input_columns="label", operations=type_cast_op) data_set = data_set.map(input_columns="image", operations=c_trans) + # apply batch operations + data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True) + # apply repeat operations data_set = data_set.repeat(repeat_num) - # apply shuffle operations - data_set = data_set.shuffle(buffer_size=10) + return data_set + + +def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, + num_parallel_workers=None, shuffle=None): + """ + create a train or eval imagenet2012 dataset for resnet50 + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + + Returns: + dataset + """ + + device_num, rank_id = _get_rank_info() + + if device_num == 1: + data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle) + else: + data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle, + num_shards=device_num, shard_id=rank_id) + + assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width" + image_size = imagenet_cfg.image_height + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if training: + transform_img = [ + vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + vision.RandomHorizontalFlip(prob=0.5), + vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = [ + vision.Decode(), + vision.Resize(256), + vision.CenterCrop(image_size), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + + transform_label = [C.TypeCast(mstype.int32)] + + data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img) + data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label) # apply batch operations - data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) + data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True) + + # apply dataset repeat operation + data_set = data_set.repeat(repeat_num) return data_set + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + from mindspore.communication.management import get_rank, get_group_size + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id diff --git a/model_zoo/googlenet/src/googlenet.py b/model_zoo/googlenet/src/googlenet.py index 701b3aeb5a..78695f2d6c 100644 --- a/model_zoo/googlenet/src/googlenet.py +++ b/model_zoo/googlenet/src/googlenet.py @@ -63,7 +63,7 @@ class Inception(nn.Cell): Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)]) self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1), Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)]) - self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same") + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same") self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1) self.concat = P.Concat(axis=1) @@ -71,9 +71,8 @@ class Inception(nn.Cell): branch1 = self.b1(x) branch2 = self.b2(x) branch3 = self.b3(x) - cell, argmax = self.maxpool(x) + cell = self.maxpool(x) branch4 = self.b4(cell) - _ = argmax return self.concat((branch1, branch2, branch3, branch4)) @@ -85,22 +84,22 @@ class GoogleNet(nn.Cell): def __init__(self, num_classes): super(GoogleNet, self).__init__() self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) - self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.conv2 = Conv2dBlock(64, 64, kernel_size=1) self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0) - self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.block3a = Inception(192, 64, 96, 128, 16, 32, 32) self.block3b = Inception(256, 128, 128, 192, 32, 96, 64) - self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.block4a = Inception(480, 192, 96, 208, 16, 48, 64) self.block4b = Inception(512, 160, 112, 224, 24, 64, 64) self.block4c = Inception(512, 128, 128, 256, 24, 64, 64) self.block4d = Inception(512, 112, 144, 288, 32, 64, 64) self.block4e = Inception(528, 256, 160, 320, 32, 128, 128) - self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same") + self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) @@ -113,23 +112,24 @@ class GoogleNet(nn.Cell): def construct(self, x): + """construct""" x = self.conv1(x) - x, argmax = self.maxpool1(x) + x = self.maxpool1(x) x = self.conv2(x) x = self.conv3(x) - x, argmax = self.maxpool2(x) + x = self.maxpool2(x) x = self.block3a(x) x = self.block3b(x) - x, argmax = self.maxpool3(x) + x = self.maxpool3(x) x = self.block4a(x) x = self.block4b(x) x = self.block4c(x) x = self.block4d(x) x = self.block4e(x) - x, argmax = self.maxpool4(x) + x = self.maxpool4(x) x = self.block5a(x) x = self.block5b(x) @@ -138,5 +138,4 @@ class GoogleNet(nn.Cell): x = self.flatten(x) x = self.classifier(x) - _ = argmax return x diff --git a/model_zoo/googlenet/src/lr_scheduler/__init__.py b/model_zoo/googlenet/src/lr_scheduler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/googlenet/src/lr_scheduler/linear_warmup.py b/model_zoo/googlenet/src/lr_scheduler/linear_warmup.py new file mode 100644 index 0000000000..78e7b85f6d --- /dev/null +++ b/model_zoo/googlenet/src/lr_scheduler/linear_warmup.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr diff --git a/model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py b/model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py new file mode 100644 index 0000000000..349270b6e1 --- /dev/null +++ b/model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +import math +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """ warmup cosine annealing lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py b/model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py new file mode 100644 index 0000000000..df78f17d79 --- /dev/null +++ b/model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +from collections import Counter +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """warmup step lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma ** milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + """lr""" + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + """lr""" + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) diff --git a/model_zoo/googlenet/train.py b/model_zoo/googlenet/train.py index 0129176510..2c1aaddfbe 100644 --- a/model_zoo/googlenet/train.py +++ b/model_zoo/googlenet/train.py @@ -25,21 +25,23 @@ import numpy as np import mindspore.nn as nn from mindspore import Tensor from mindspore import context -from mindspore.communication.management import init +from mindspore.communication.management import init, get_rank from mindspore.nn.optim.momentum import Momentum from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.model import Model, ParallelMode +from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager +from mindspore.train.model import Model +from mindspore import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cifar_cfg as cfg -from src.dataset import create_dataset +from src.config import cifar_cfg, imagenet_cfg +from src.dataset import create_dataset_cifar10, create_dataset_imagenet from src.googlenet import GoogleNet random.seed(1) np.random.seed(1) -def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): +def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): """Set learning rate.""" lr_each_step = [] total_steps = steps_per_epoch * total_epochs @@ -60,25 +62,79 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): return learning_rate +def lr_steps_imagenet(_cfg, steps_per_epoch): + """lr step for imagenet""" + from src.lr_scheduler.warmup_step_lr import warmup_step_lr + from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr + if _cfg.lr_scheduler == 'exponential': + _lr = warmup_step_lr(_cfg.lr_init, + _cfg.lr_epochs, + steps_per_epoch, + _cfg.warmup_epochs, + _cfg.epoch_size, + gamma=_cfg.lr_gamma, + ) + elif _cfg.lr_scheduler == 'cosine_annealing': + _lr = warmup_cosine_annealing_lr(_cfg.lr_init, + steps_per_epoch, + _cfg.warmup_epochs, + _cfg.epoch_size, + _cfg.T_max, + _cfg.eta_min) + else: + raise NotImplementedError(_cfg.lr_scheduler) + + return _lr + + if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Cifar10 classification') + parser = argparse.ArgumentParser(description='Classification') + parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'], + help='dataset name.') parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') args_opt = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) - if args_opt.device_id is not None: - context.set_context(device_id=args_opt.device_id) + if args_opt.dataset_name == "cifar10": + cfg = cifar_cfg + elif args_opt.dataset_name == "imagenet": + cfg = imagenet_cfg else: - context.set_context(device_id=cfg.device_id) + raise ValueError("Unsupport dataset.") + + # set context + device_target = cfg.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) device_num = int(os.environ.get("DEVICE_NUM", 1)) - if device_num > 1: - context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True) + + if device_target == "Ascend": + if args_opt.device_id is not None: + context.set_context(device_id=args_opt.device_id) + else: + context.set_context(device_id=cfg.device_id) + + if device_num > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + init() + elif device_target == "GPU": init() - dataset = create_dataset(cfg.data_path, cfg.epoch_size) + if device_num > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + else: + raise ValueError("Unsupported platform.") + + if args_opt.dataset_name == "cifar10": + dataset = create_dataset_cifar10(cfg.data_path, cfg.epoch_size) + elif args_opt.dataset_name == "imagenet": + dataset = create_dataset_imagenet(cfg.data_path, cfg.epoch_size) + else: + raise ValueError("Unsupport dataset.") + batch_num = dataset.get_dataset_size() net = GoogleNet(num_classes=cfg.num_classes) @@ -86,16 +142,75 @@ if __name__ == '__main__': if cfg.pre_trained: param_dict = load_checkpoint(cfg.checkpoint_path) load_param_into_net(net, param_dict) - lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, - weight_decay=cfg.weight_decay) - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, - amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) + + loss_scale_manager = None + if args_opt.dataset_name == 'cifar10': + lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), + learning_rate=Tensor(lr), + momentum=cfg.momentum, + weight_decay=cfg.weight_decay) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) + + elif args_opt.dataset_name == 'imagenet': + lr = lr_steps_imagenet(cfg, batch_num) + + + def get_param_groups(network): + """ get param groups """ + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + # print('no decay:{}'.format(parameter_name)) + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + # print('no decay:{}'.format(parameter_name)) + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + # print('no decay:{}'.format(parameter_name)) + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + + if cfg.is_dynamic_loss_scale: + cfg.loss_scale = 1 + + opt = Momentum(params=get_param_groups(net), + learning_rate=Tensor(lr), + momentum=cfg.momentum, + weight_decay=cfg.weight_decay, + loss_scale=cfg.loss_scale) + if not cfg.use_label_smooth: + cfg.label_smooth_factor = 0.0 + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) + + if cfg.is_dynamic_loss_scale == 1: + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) + else: + loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) + + if device_target == "Ascend": + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager) + ckpt_save_dir = "./" + else: # GPU + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager) + ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) time_cb = TimeMonitor(data_size=batch_num) - ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", config=config_ck) + ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir, + config=config_ck) loss_cb = LossMonitor() model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) print("train success")