Merge pull request !5922 from caojian05/ms_r0.5_vgg16_support_imagenet_on_ascendpull/5922/MERGE
| @@ -1,55 +1,279 @@ | |||
| # VGG16 Example | |||
| # Contents | |||
| ## Description | |||
| - [VGG Description](#vgg-description) | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Features](#features) | |||
| - [Mixed Precision](#mixed-precision) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Quick Start](#quick-start) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Script Parameters](#script-parameters) | |||
| - [Parameter configuration](#parameter-configuration) | |||
| - [Training Process](#training-process) | |||
| - [Training](#training) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Evaluation](#evaluation) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Training Performance](#training-performance) | |||
| - [Evaluation Performance](#evaluation-performance) | |||
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| This example is for VGG16 model training and evaluation. | |||
| ## Requirements | |||
| # [VGG Description](#contents) | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| VGG, a very deep convolutional networks for large-scale image recognition, was proposed in 2014 and won the 1th place in object localization and 2th place in image classification task in ImageNet Large-Scale Visual Recognition Challenge 2014 (ILSVRC14). | |||
| - Download the CIFAR-10 binary version dataset. | |||
| [Paper](): Simonyan K, zisserman A. Very Deep Convolutional Networks for Large-Scale Image Recognition[J]. arXiv preprint arXiv:1409.1556, 2014. | |||
| > Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows: | |||
| > ``` | |||
| > . | |||
| > ├── cifar-10-batches-bin # train dataset | |||
| > └── cifar-10-verify-bin # infer dataset | |||
| > ``` | |||
| # [Model Architecture](#contents) | |||
| VGG 16 network is mainly consisted by several basic modules (including convolution and pooling layer) and three continuous Dense layer. | |||
| here basic modules mainly include basic operation like: **3×3 conv** and **2×2 max pooling**. | |||
| ## Running the Example | |||
| ### Training | |||
| # [Dataset](#contents) | |||
| #### Dataset used: [CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>) | |||
| - CIFAR-10 Dataset size:175M,60,000 32*32 colorful images in 10 classes | |||
| - Train:146M,50,000 images | |||
| - Test:29.3M,10,000 images | |||
| - Data format: binary files | |||
| - Note: Data will be processed in src/dataset.py | |||
| #### Dataset used: [ImageNet2012](http://www.image-net.org/) | |||
| - Dataset size: ~146G, 1.28 million colorful images in 1000 classes | |||
| - Train: 140G, 1,281,167 images | |||
| - Test: 6.4G, 50, 000 images | |||
| - Data format: RGB images | |||
| - Note: Data will be processed in src/dataset.py | |||
| #### Dataset organize way | |||
| CIFAR-10 | |||
| > Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows: | |||
| > ``` | |||
| > . | |||
| > ├── cifar-10-batches-bin # train dataset | |||
| > └── cifar-10-verify-bin # infer dataset | |||
| > ``` | |||
| ImageNet2012 | |||
| > Unzip the ImageNet2012 dataset to any path you want and the folder should include train and eval dataset as follows: | |||
| > | |||
| > ``` | |||
| > . | |||
| > └─dataset | |||
| > ├─ilsvrc # train dataset | |||
| > └─validation_preprocess # evaluate dataset | |||
| > ``` | |||
| # [Features](#contents) | |||
| ## Mixed Precision | |||
| The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. | |||
| For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. | |||
| # [Environment Requirements](#contents) | |||
| - Hardware(Ascend/GPU) | |||
| - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) | |||
| - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) | |||
| # [Quick Start](#contents) | |||
| After installing MindSpore via the official website, you can start training and evaluation as follows: | |||
| - Running on Ascend | |||
| ```python | |||
| # run training example | |||
| python train.py --data_path=[DATA_PATH] --device_id=[DEVICE_ID] > output.train.log 2>&1 & | |||
| # run distributed training example | |||
| sh run_distribute_train.sh [RANL_TABLE_JSON] [DATA_PATH] | |||
| # run evaluation example | |||
| python eval.py --data_path=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 & | |||
| ``` | |||
| python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 & | |||
| For distributed training, a hccl configuration file with JSON format needs to be created in advance. | |||
| Please follow the instructions in the link below: | |||
| https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools | |||
| - Running on GPU | |||
| ``` | |||
| The python command above will run in the background, you can view the results through the file `out.train.log`. | |||
| # run training example | |||
| python train.py --device_target="GPU" --device_id=[DEVICE_ID] --dataset=[DATASET_TYPE] --data_path=[DATA_PATH] > output.train.log 2>&1 & | |||
| After training, you'll get some checkpoint files under the script folder by default. | |||
| # run distributed training example | |||
| sh run_distribute_train_gpu.sh [DATA_PATH] | |||
| You will get the loss value as following: | |||
| # run evaluation example | |||
| python eval.py --device_target="GPU" --device_id=[DEVICE_ID] --dataset=[DATASET_TYPE] --data_path=[DATA_PATH] --pre_trained=[PRE_TRAINED] > output.eval.log 2>&1 & | |||
| ``` | |||
| # grep "loss is " out.train.log | |||
| epoch: 1 step: 781, loss is 2.093086 | |||
| epcoh: 2 step: 781, loss is 1.827582 | |||
| ... | |||
| # [Script Description](#contents) | |||
| ## [Script and Sample Code](#contents) | |||
| ``` | |||
| ├── model_zoo | |||
| ├── README.md // descriptions about all the models | |||
| ├── vgg16 | |||
| ├── README.md // descriptions about googlenet | |||
| ├── scripts | |||
| │ ├── run_distribute_train.sh // shell script for distributed training on Ascend | |||
| │ ├── run_distribute_train_gpu.sh // shell script for distributed training on GPU | |||
| ├── src | |||
| │ ├── utils | |||
| │ │ ├── logging.py // logging format setting | |||
| │ │ ├── sampler.py // create sampler for dataset | |||
| │ │ ├── util.py // util function | |||
| │ │ ├── var_init.py // network parameter init method | |||
| │ ├── config.py // parameter configuration | |||
| │ ├── crossentropy.py // loss caculation | |||
| │ ├── dataset.py // creating dataset | |||
| │ ├── linear_warmup.py // linear leanring rate | |||
| │ ├── warmup_cosine_annealing_lr.py // consine anealing learning rate | |||
| │ ├── warmup_step_lr.py // step or multi step learning rate | |||
| │ ├──vgg.py // vgg architecture | |||
| ├── train.py // training script | |||
| ├── eval.py // evaluation script | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| ### Training | |||
| ``` | |||
| usage: train.py [--device_target TARGET][--data_path DATA_PATH] | |||
| [--dataset DATASET_TYPE][--is_distributed VALUE] | |||
| [--device_id DEVICE_ID][--pre_trained PRE_TRAINED] | |||
| [--ckpt_path CHECKPOINT_PATH][--ckpt_interval INTERVAL_STEP] | |||
| parameters/options: | |||
| --device_target the training backend type, Ascend or GPU, default is Ascend. | |||
| --dataset the dataset type, cifar10 or imagenet2012. | |||
| --is_distributed the way of traing, whether do distribute traing, value can be 0 or 1. | |||
| --data_path the storage path of dataset | |||
| --device_id the device which used to train model. | |||
| --pre_trained the pretrained checkpoint file path. | |||
| --ckpt_path the path to save checkpoint. | |||
| --ckpt_interval the epoch interval for saving checkpoint. | |||
| ``` | |||
| ### Evaluation | |||
| ``` | |||
| python eval.py --data_path=your_data_path --device_id=6 --checkpoint_path=./train_vgg_cifar10-70-781.ckpt > out.eval.log 2>&1 & | |||
| usage: eval.py [--device_target TARGET][--data_path DATA_PATH] | |||
| [--dataset DATASET_TYPE][--pre_trained PRE_TRAINED] | |||
| [--device_id DEVICE_ID] | |||
| parameters/options: | |||
| --device_target the evaluation backend type, Ascend or GPU, default is Ascend. | |||
| --dataset the dataset type, cifar10 or imagenet2012. | |||
| --data_path the storage path of dataset. | |||
| --device_id the device which used to evaluate model. | |||
| --pre_trained the checkpoint file path used to evaluate model. | |||
| ``` | |||
| The above python command will run in the background, you can view the results through the file `out.eval.log`. | |||
| You will get the accuracy as following: | |||
| ## [Parameter configuration](#contents) | |||
| Parameters for both training and evaluation can be set in config.py. | |||
| - config for vgg16, CIFAR-10 dataset | |||
| ``` | |||
| # grep "result: " out.eval.log | |||
| result: {'acc': 0.92} | |||
| "num_classes": 10, # dataset class num | |||
| "lr": 0.01, # learning rate | |||
| "lr_init": 0.01, # initial learning rate | |||
| "lr_max": 0.1, # max learning rate | |||
| "lr_epochs": '30,60,90,120', # lr changing based epochs | |||
| "lr_scheduler": "step", # learning rate mode | |||
| "warmup_epochs": 5, # number of warmup epoch | |||
| "batch_size": 64, # batch size of input tensor | |||
| "max_epoch": 70, # only valid for taining, which is always 1 for inference | |||
| "momentum": 0.9, # momentum | |||
| "weight_decay": 5e-4, # weight decay | |||
| "loss_scale": 1.0, # loss scale | |||
| "label_smooth": 0, # label smooth | |||
| "label_smooth_factor": 0, # label smooth factor | |||
| "buffer_size": 10, # shuffle buffer size | |||
| "image_size": '224,224', # image size | |||
| "pad_mode": 'same', # pad mode for conv2d | |||
| "padding": 0, # padding value for conv2d | |||
| "has_bias": False, # whether has bias in conv2d | |||
| "batch_norm": True, # wether has batch_norm in conv2d | |||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | |||
| "initialize_mode": "XavierUniform", # conv2d init mode | |||
| "has_dropout": True # wether using Dropout layer | |||
| ``` | |||
| - config for vgg16, ImageNet2012 dataset | |||
| ``` | |||
| "num_classes": 1000, # dataset class num | |||
| "lr": 0.01, # learning rate | |||
| "lr_init": 0.01, # initial learning rate | |||
| "lr_max": 0.1, # max learning rate | |||
| "lr_epochs": '30,60,90,120', # lr changing based epochs | |||
| "lr_scheduler": "cosine_annealing", # learning rate mode | |||
| "warmup_epochs": 0, # number of warmup epoch | |||
| "batch_size": 32, # batch size of input tensor | |||
| "max_epoch": 150, # only valid for taining, which is always 1 for inference | |||
| "momentum": 0.9, # momentum | |||
| "weight_decay": 1e-4, # weight decay | |||
| "loss_scale": 1024, # loss scale | |||
| "label_smooth": 1, # label smooth | |||
| "label_smooth_factor": 0.1, # label smooth factor | |||
| "buffer_size": 10, # shuffle buffer size | |||
| "image_size": '224,224', # image size | |||
| "pad_mode": 'pad', # pad mode for conv2d | |||
| "padding": 1, # padding value for conv2d | |||
| "has_bias": True, # whether has bias in conv2d | |||
| "batch_norm": False, # wether has batch_norm in conv2d | |||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | |||
| "initialize_mode": "KaimingNormal", # conv2d init mode | |||
| "has_dropout": True # wether using Dropout layer | |||
| ``` | |||
| ## [Training Process](#contents) | |||
| ### Training | |||
| #### Run vgg16 on Ascend | |||
| - Training using single device(1p), using CIFAR-10 dataset in default | |||
| ``` | |||
| python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 & | |||
| ``` | |||
| The python command above will run in the background, you can view the results through the file `out.train.log`. | |||
| ### Distribute Training | |||
| After training, you'll get some checkpoint files in specified ckpt_path, default in ./output directory. | |||
| You will get the loss value as following: | |||
| ``` | |||
| # grep "loss is " output.train.log | |||
| epoch: 1 step: 781, loss is 2.093086 | |||
| epcoh: 2 step: 781, loss is 1.827582 | |||
| ... | |||
| ``` | |||
| - Distributed Training | |||
| ``` | |||
| sh run_distribute_train.sh rank_table.json your_data_path | |||
| ``` | |||
| @@ -68,40 +292,83 @@ train_parallel1/log:epcoh: 2 step: 97, loss is 1.7133579 | |||
| ``` | |||
| > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | |||
| ## Usage: | |||
| ### Training | |||
| ``` | |||
| usage: train.py [--device_target TARGET][--data_path DATA_PATH] | |||
| [--device_id DEVICE_ID][--pre_trained PRE_TRAINED] | |||
| #### Run vgg16 on GPU | |||
| parameters/options: | |||
| --device_target the training backend type, default is Ascend. | |||
| --data_path the storage path of dataset | |||
| --device_id the device which used to train model. | |||
| --pre_trained the pretrained checkpoint file path. | |||
| - Training using single device(1p) | |||
| ``` | |||
| python train.py --device_target="GPU" --dataset="imagenet2012" --is_distributed=0 --data_path=$DATA_PATH > output.train.log 2>&1 & | |||
| ``` | |||
| - Distributed Training | |||
| ``` | |||
| # distributed training(8p) | |||
| bash scripts/run_distribute_train_gpu.sh /path/ImageNet2012/train" | |||
| ``` | |||
| ## [Evaluation Process](#contents) | |||
| ### Evaluation | |||
| - Do eval as follows, need to specify dataset type as "cifar10" or "imagenet2012" | |||
| ``` | |||
| usage: eval.py [--device_target TARGET][--data_path DATA_PATH] | |||
| [--device_id DEVICE_ID][--checkpoint_path CKPT_PATH] | |||
| # when using cifar10 dataset | |||
| python eval.py --data_path=your_data_path --dataset="cifar10" --device_target="Ascend" --pre_trained=./*-70-781.ckpt > output.eval.log 2>&1 & | |||
| parameters/options: | |||
| --device_target the evaluation backend type, default is Ascend. | |||
| --data_path the storage path of datasetd | |||
| --device_id the device which used to evaluate model. | |||
| --checkpoint_path the checkpoint file path used to evaluate model. | |||
| # when using imagenet2012 dataset | |||
| python eval.py --data_path=your_data_path --dataset="imagenet2012" --device_target="GPU" --pre_trained=./*-150-5004.ckpt > output.eval.log 2>&1 & | |||
| ``` | |||
| ### Distribute Training | |||
| - The above python command will run in the background, you can view the results through the file `output.eval.log`. You will get the accuracy as following: | |||
| ``` | |||
| Usage: sh script/run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH] | |||
| # when using cifar10 dataset | |||
| # grep "result: " output.eval.log | |||
| result: {'acc': 0.92} | |||
| parameters/options: | |||
| MINDSPORE_HCCL_CONFIG_PATH HCCL configuration file path. | |||
| DATA_PATH the storage path of dataset. | |||
| # when using the imagenet2012 dataset | |||
| after allreduce eval: top1_correct=36636, tot=50000, acc=73.27% | |||
| after allreduce eval: top5_correct=45582, tot=50000, acc=91.16% | |||
| ``` | |||
| # [Model Description](#contents) | |||
| ## [Performance](#contents) | |||
| ### Training Performance | |||
| | Parameters | VGG16(Ascend) | VGG16(GPU) | | |||
| | -------------------------- | ---------------------------------------------- |------------------------------------| | |||
| | Model Version | VGG16 | VGG16 | | |||
| | Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |NV SMX2 V100-32G | | |||
| | uploaded Date | 08/20/2020 |08/20/2020 | | |||
| | MindSpore Version | 0.5.0-alpha |0.5.0-alpha | | |||
| | Dataset | CIFAR-10 |ImageNet2012 | | |||
| | Training Parameters | epoch=70, steps=781, batch_size = 64, lr=0.1 |epoch=150, steps=40036, batch_size = 32, lr=0.1 | | |||
| | Optimizer | Momentum |Momentum | | |||
| | Loss Function | SoftmaxCrossEntropy |SoftmaxCrossEntropy | | |||
| | outputs | probability |probability | | |||
| | Loss | 0.01 |1.5~2.0 | | |||
| | Speed | 1pc: 79 ms/step; 8pcs: 104 ms/step |1pc: 81 ms/step; 8pcs 94.4ms/step | | |||
| | Total time | 1pc: 72 mins; 8pcs: 11.8 mins |8pcs: 19.7 hours | | |||
| | Checkpoint for Fine tuning | 1.1G(.ckpt file) |1.1G(.ckpt file) | | |||
| | Scripts |[vgg16](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/vgg16) | | | |||
| ### Evaluation Performance | |||
| | Parameters | VGG16(Ascend) | VGG16(GPU) | |||
| | ------------------- | --------------------------- |--------------------- | |||
| | Model Version | VGG16 | VGG16 | | |||
| | Resource | Ascend 910 | GPU | | |||
| | Uploaded Date | 08/20/2020 | 08/20/2020 | | |||
| | MindSpore Version | 0.5.0-alpha |0.5.0-alpha | | |||
| | Dataset | CIFAR-10, 10,000 images |ImageNet2012, 5000 images | | |||
| | batch_size | 64 | 32 | | |||
| | outputs | probability | probability | | |||
| | Accuracy | 1pc: 93.4% |1pc: 73.0%; | | |||
| # [Description of Random Situation](#contents) | |||
| In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -12,42 +12,201 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ##############test vgg16 example on cifar10################# | |||
| python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID | |||
| """ | |||
| """Eval""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import datetime | |||
| import glob | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore import Tensor, context | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import cifar_cfg as cfg | |||
| from src.dataset import vgg_create_dataset | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common import dtype as mstype | |||
| from src.utils.logging import get_logger | |||
| from src.vgg import vgg16 | |||
| from src.dataset import vgg_create_dataset | |||
| from src.dataset import classification_dataset | |||
| class ParameterReduce(nn.Cell): | |||
| """ParameterReduce""" | |||
| def __init__(self): | |||
| super(ParameterReduce, self).__init__() | |||
| self.cast = P.Cast() | |||
| self.reduce = P.AllReduce() | |||
| def construct(self, x): | |||
| one = self.cast(F.scalar_to_array(1.0), mstype.float32) | |||
| out = x * one | |||
| ret = self.reduce(out) | |||
| return ret | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description='Cifar10 classification') | |||
| def parse_args(cloud_args=None): | |||
| """parse_args""" | |||
| parser = argparse.ArgumentParser('mindspore classification test') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') | |||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.') | |||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | |||
| # dataset related | |||
| parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") | |||
| parser.add_argument('--data_path', type=str, default='', help='eval data dir') | |||
| parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') | |||
| # network related | |||
| parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') | |||
| parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. ' | |||
| 'If it is a direction, it will test all ckpt') | |||
| # logging related | |||
| parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log') | |||
| parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args_opt = parser.parse_args() | |||
| args_opt = merge_args(args_opt, cloud_args) | |||
| if args_opt.dataset == "cifar10": | |||
| from src.config import cifar_cfg as cfg | |||
| else: | |||
| from src.config import imagenet_cfg as cfg | |||
| args_opt.image_size = cfg.image_size | |||
| args_opt.num_classes = cfg.num_classes | |||
| args_opt.per_batch_size = cfg.batch_size | |||
| args_opt.momentum = cfg.momentum | |||
| args_opt.weight_decay = cfg.weight_decay | |||
| args_opt.buffer_size = cfg.buffer_size | |||
| args_opt.pad_mode = cfg.pad_mode | |||
| args_opt.padding = cfg.padding | |||
| args_opt.has_bias = cfg.has_bias | |||
| args_opt.batch_norm = cfg.batch_norm | |||
| args_opt.initialize_mode = cfg.initialize_mode | |||
| args_opt.has_dropout = cfg.has_dropout | |||
| args_opt.image_size = list(map(int, args_opt.image_size.split(','))) | |||
| return args_opt | |||
| def get_top5_acc(top5_arg, gt_class): | |||
| sub_count = 0 | |||
| for top5, gt in zip(top5_arg, gt_class): | |||
| if gt in top5: | |||
| sub_count += 1 | |||
| return sub_count | |||
| def merge_args(args, cloud_args): | |||
| """merge_args""" | |||
| args_dict = vars(args) | |||
| if isinstance(cloud_args, dict): | |||
| for key in cloud_args.keys(): | |||
| val = cloud_args[key] | |||
| if key in args_dict and val: | |||
| arg_type = type(args_dict[key]) | |||
| if arg_type is not type(None): | |||
| val = arg_type(val) | |||
| args_dict[key] = val | |||
| return args | |||
| def test(cloud_args=None): | |||
| """test""" | |||
| args = parse_args(cloud_args) | |||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | |||
| device_target=args.device_target, save_graphs=False) | |||
| if os.getenv('DEVICE_ID', "not_set").isdigit(): | |||
| context.set_context(device_id=int(os.getenv('DEVICE_ID'))) | |||
| args.outputs_dir = os.path.join(args.log_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| args.logger = get_logger(args.outputs_dir, args.rank) | |||
| args.logger.save_args(args) | |||
| if args.dataset == "cifar10": | |||
| net = vgg16(num_classes=args.num_classes, args=args) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum, | |||
| weight_decay=args.weight_decay) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| param_dict = load_checkpoint(args.pre_trained) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False) | |||
| res = model.eval(dataset) | |||
| print("result: ", res) | |||
| else: | |||
| # network | |||
| args.logger.important_info('start create network') | |||
| if os.path.isdir(args.pre_trained): | |||
| models = list(glob.glob(os.path.join(args.pre_trained, '*.ckpt'))) | |||
| print(models) | |||
| if args.graph_ckpt: | |||
| f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) | |||
| else: | |||
| f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) | |||
| args.models = sorted(models, key=f) | |||
| else: | |||
| args.models = [args.pre_trained,] | |||
| for model in args.models: | |||
| dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, mode='eval') | |||
| eval_dataloader = dataset.create_tuple_iterator() | |||
| network = vgg16(args.num_classes, args, phase="test") | |||
| # pre_trained | |||
| load_param_into_net(network, load_checkpoint(model)) | |||
| network.add_flags_recursive(fp16=True) | |||
| img_tot = 0 | |||
| top1_correct = 0 | |||
| top5_correct = 0 | |||
| network.set_train(False) | |||
| t_end = time.time() | |||
| it = 0 | |||
| for data, gt_classes in eval_dataloader: | |||
| output = network(Tensor(data, mstype.float32)) | |||
| output = output.asnumpy() | |||
| top1_output = np.argmax(output, (-1)) | |||
| top5_output = np.argsort(output)[:, -5:] | |||
| t1_correct = np.equal(top1_output, gt_classes).sum() | |||
| top1_correct += t1_correct | |||
| top5_correct += get_top5_acc(top5_output, gt_classes) | |||
| img_tot += args.per_batch_size | |||
| if args.rank == 0 and it == 0: | |||
| t_end = time.time() | |||
| it = 1 | |||
| if args.rank == 0: | |||
| time_used = time.time() - t_end | |||
| fps = (img_tot - args.per_batch_size) * args.group_size / time_used | |||
| args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) | |||
| results = [[top1_correct], [top5_correct], [img_tot]] | |||
| args.logger.info('before results={}'.format(results)) | |||
| results = np.array(results) | |||
| args.logger.info('after results={}'.format(results)) | |||
| top1_correct = results[0, 0] | |||
| top5_correct = results[1, 0] | |||
| img_tot = results[2, 0] | |||
| acc1 = 100.0 * top1_correct / img_tot | |||
| acc5 = 100.0 * top5_correct / img_tot | |||
| args.logger.info('after allreduce eval: top1_correct={}, tot={},' | |||
| 'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1)) | |||
| args.logger.info('after allreduce eval: top5_correct={}, tot={},' | |||
| 'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5)) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||
| context.set_context(device_id=args_opt.device_id) | |||
| net = vgg16(num_classes=cfg.num_classes) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | |||
| weight_decay=cfg.weight_decay) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| dataset = vgg_create_dataset(args_opt.data_path, 1, False) | |||
| res = model.eval(dataset) | |||
| print("result: ", res) | |||
| if __name__ == "__main__": | |||
| test() | |||
| @@ -14,15 +14,15 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 2 ] | |||
| if [ $# != 2 ] && [ $# != 3 ] | |||
| then | |||
| echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]" | |||
| echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]" | |||
| exit 1 | |||
| fi | |||
| if [ ! -f $1 ] | |||
| then | |||
| echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" | |||
| echo "error: RANK_TABLE_FILE=$1 is not a file" | |||
| exit 1 | |||
| fi | |||
| @@ -32,9 +32,22 @@ then | |||
| exit 1 | |||
| fi | |||
| dataset_type='cifar10' | |||
| if [ $# == 3 ] | |||
| then | |||
| if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ] | |||
| then | |||
| echo "error: the selected dataset is neither cifar10 nor imagenet2012" | |||
| exit 1 | |||
| fi | |||
| dataset_type=$3 | |||
| fi | |||
| export DEVICE_NUM=8 | |||
| export RANK_SIZE=8 | |||
| export MINDSPORE_HCCL_CONFIG_PATH=$1 | |||
| export RANK_TABLE_FILE=$1 | |||
| for((i=0;i<RANK_SIZE;i++)) | |||
| do | |||
| @@ -45,8 +58,8 @@ do | |||
| cp *.py ./train_parallel$i | |||
| cp -r src ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" | |||
| env > env.log | |||
| python train.py --data_path=$2 --device_id=$i &> log & | |||
| python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log & | |||
| cd .. | |||
| done | |||
| done | |||
| @@ -0,0 +1,29 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_distribute_train_gpu.sh DATA_PATH" | |||
| echo "for example: bash run_distribute_train_gpu.sh /path/ImageNet2012/train" | |||
| echo "==============================================================================================================" | |||
| DATA_PATH=$1 | |||
| mpirun -n 8 python train.py \ | |||
| --device_target="GPU" \ | |||
| --dataset="imagenet2012" \ | |||
| --is_distributed=1 \ | |||
| --data_path=$DATA_PATH > output.train.log 2>&1 & | |||
| @@ -0,0 +1,32 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_eval.sh DATA_PATH DATASET_TYPE DEVICE_TYPE CHECKPOINT_PATH" | |||
| echo "for example: bash run_eval.sh /path/ImageNet2012/train cifar10 Ascend /path/a.ckpt " | |||
| echo "==============================================================================================================" | |||
| DATA_PATH=&1 | |||
| DATASET_TYPE=$2 | |||
| DEVICE_TYPE=$3 | |||
| CHECKPOINT_PATH=$4 | |||
| python eval.py \ | |||
| --data_path=$DATA_PATH \ | |||
| --dataset=$DATASET_TYPE \ | |||
| --device_target=$DEVICE_TYPE \ | |||
| --pre_trained=$CHECKPOINT_PATH > output.eval.log 2>&1 & | |||
| @@ -13,21 +13,60 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting, will be used in main.py | |||
| network config setting, will be used in train.py and eval.py | |||
| """ | |||
| from easydict import EasyDict as edict | |||
| # config for vgg16, cifar10 | |||
| cifar_cfg = edict({ | |||
| 'num_classes': 10, | |||
| 'lr_init': 0.01, | |||
| 'lr_max': 0.1, | |||
| 'warmup_epochs': 5, | |||
| 'batch_size': 64, | |||
| 'epoch_size': 70, | |||
| 'momentum': 0.9, | |||
| 'weight_decay': 5e-4, | |||
| 'buffer_size': 10, | |||
| 'image_height': 224, | |||
| 'image_width': 224, | |||
| 'keep_checkpoint_max': 10 | |||
| "num_classes": 10, | |||
| "lr": 0.01, | |||
| "lr_init": 0.01, | |||
| "lr_max": 0.1, | |||
| "lr_epochs": '30,60,90,120', | |||
| "lr_scheduler": "step", | |||
| "warmup_epochs": 5, | |||
| "batch_size": 64, | |||
| "max_epoch": 70, | |||
| "momentum": 0.9, | |||
| "weight_decay": 5e-4, | |||
| "loss_scale": 1.0, | |||
| "label_smooth": 0, | |||
| "label_smooth_factor": 0, | |||
| "buffer_size": 10, | |||
| "image_size": '224,224', | |||
| "pad_mode": 'same', | |||
| "padding": 0, | |||
| "has_bias": False, | |||
| "batch_norm": True, | |||
| "keep_checkpoint_max": 10, | |||
| "initialize_mode": "XavierUniform", | |||
| "has_dropout": False | |||
| }) | |||
| # config for vgg16, imagenet2012 | |||
| imagenet_cfg = edict({ | |||
| "num_classes": 1000, | |||
| "lr": 0.01, | |||
| "lr_init": 0.01, | |||
| "lr_max": 0.1, | |||
| "lr_epochs": '30,60,90,120', | |||
| "lr_scheduler": 'cosine_annealing', | |||
| "warmup_epochs": 0, | |||
| "batch_size": 32, | |||
| "max_epoch": 150, | |||
| "momentum": 0.9, | |||
| "weight_decay": 1e-4, | |||
| "loss_scale": 1024, | |||
| "label_smooth": 1, | |||
| "label_smooth_factor": 0.1, | |||
| "buffer_size": 10, | |||
| "image_size": '224,224', | |||
| "pad_mode": 'pad', | |||
| "padding": 1, | |||
| "has_bias": False, | |||
| "batch_norm": False, | |||
| "keep_checkpoint_max": 10, | |||
| "initialize_mode": "XavierUnifor", | |||
| "has_dropout": True | |||
| }) | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """define loss function for network""" | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.nn as nn | |||
| class CrossEntropy(_Loss): | |||
| """the redefined loss function with SoftmaxCrossEntropyWithLogits""" | |||
| def __init__(self, smooth_factor=0., num_classes=1001): | |||
| super(CrossEntropy, self).__init__() | |||
| self.onehot = P.OneHot() | |||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||
| self.mean = P.ReduceMean(False) | |||
| def construct(self, logit, label): | |||
| one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) | |||
| loss = self.ce(logit, one_hot_label) | |||
| loss = self.mean(loss, 0) | |||
| return loss | |||
| @@ -13,37 +13,35 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Data operations, will be used in train.py and eval.py | |||
| dataset processing. | |||
| """ | |||
| import os | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from .config import cifar_cfg as cfg | |||
| from PIL import Image, ImageFile | |||
| from src.utils.sampler import DistributedSampler | |||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |||
| def vgg_create_dataset(data_home, repeat_num=1, training=True): | |||
| def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True): | |||
| """Data operations.""" | |||
| ds.config.set_seed(1) | |||
| de.config.set_seed(1) | |||
| data_dir = os.path.join(data_home, "cifar-10-batches-bin") | |||
| if not training: | |||
| data_dir = os.path.join(data_home, "cifar-10-verify-bin") | |||
| rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None | |||
| rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None | |||
| data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) | |||
| data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) | |||
| resize_height = cfg.image_height | |||
| resize_width = cfg.image_width | |||
| rescale = 1.0 / 255.0 | |||
| shift = 0.0 | |||
| # define map operations | |||
| random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT | |||
| random_horizontal_op = vision.RandomHorizontalFlip() | |||
| resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR | |||
| resize_op = vision.Resize(image_size) # interpolation default BILINEAR | |||
| rescale_op = vision.Rescale(rescale, shift) | |||
| normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) | |||
| changeswap_op = vision.HWC2CHW() | |||
| @@ -66,6 +64,134 @@ def vgg_create_dataset(data_home, repeat_num=1, training=True): | |||
| data_set = data_set.shuffle(buffer_size=10) | |||
| # apply batch operations | |||
| data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) | |||
| data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) | |||
| return data_set | |||
| def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1, | |||
| mode='train', | |||
| input_mode='folder', | |||
| root='', | |||
| num_parallel_workers=None, | |||
| shuffle=None, | |||
| sampler=None, | |||
| repeat_num=1, | |||
| class_indexing=None, | |||
| drop_remainder=True, | |||
| transform=None, | |||
| target_transform=None): | |||
| """ | |||
| A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt". | |||
| If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images | |||
| are written into a textfile. | |||
| Args: | |||
| data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"". | |||
| Or path of the textfile that contains every image's path of the dataset. | |||
| image_size (str): Size of the input images. | |||
| per_batch_size (int): the batch size of evey step during training. | |||
| rank (int): The shard ID within num_shards (default=None). | |||
| group_size (int): Number of shards that the dataset should be divided | |||
| into (default=None). | |||
| mode (str): "train" or others. Default: " train". | |||
| input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder". | |||
| root (str): the images path for "input_mode="txt"". Default: " ". | |||
| num_parallel_workers (int): Number of workers to read the data. Default: None. | |||
| shuffle (bool): Whether or not to perform shuffle on the dataset | |||
| (default=None, performs shuffle). | |||
| sampler (Sampler): Object used to choose samples from the dataset. Default: None. | |||
| repeat_num (int): the num of repeat dataset. | |||
| class_indexing (dict): A str-to-int mapping from folder name to index | |||
| (default=None, the folder names will be sorted | |||
| alphabetically and each class will be given a | |||
| unique index starting from 0). | |||
| Examples: | |||
| >>> from mindvision.common.datasets.classification import classification_dataset | |||
| >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images | |||
| >>> dataset_dir = "/path/to/imagefolder_directory" | |||
| >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], | |||
| >>> per_batch_size=64, rank=0, group_size=4) | |||
| >>> # Path of the textfile that contains every image's path of the dataset. | |||
| >>> dataset_dir = "/path/to/dataset/images/train.txt" | |||
| >>> images_dir = "/path/to/dataset/images" | |||
| >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244], | |||
| >>> per_batch_size=64, rank=0, group_size=4, | |||
| >>> input_mode="txt", root=images_dir) | |||
| """ | |||
| mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] | |||
| std = [0.229 * 255, 0.224 * 255, 0.225 * 255] | |||
| if transform is None: | |||
| if mode == 'train': | |||
| transform_img = [ | |||
| vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)), | |||
| vision.RandomHorizontalFlip(prob=0.5), | |||
| vision.Normalize(mean=mean, std=std), | |||
| vision.HWC2CHW() | |||
| ] | |||
| else: | |||
| transform_img = [ | |||
| vision.Decode(), | |||
| vision.Resize((256, 256)), | |||
| vision.CenterCrop(image_size), | |||
| vision.Normalize(mean=mean, std=std), | |||
| vision.HWC2CHW() | |||
| ] | |||
| else: | |||
| transform_img = transform | |||
| if target_transform is None: | |||
| transform_label = [C.TypeCast(mstype.int32)] | |||
| else: | |||
| transform_label = target_transform | |||
| if input_mode == 'folder': | |||
| de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers, | |||
| shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, | |||
| num_shards=group_size, shard_id=rank) | |||
| else: | |||
| dataset = TxtDataset(root, data_dir) | |||
| sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) | |||
| de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) | |||
| de_dataset.set_dataset_size(len(sampler)) | |||
| de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) | |||
| de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) | |||
| columns_to_project = ["image", "label"] | |||
| de_dataset = de_dataset.project(columns=columns_to_project) | |||
| de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) | |||
| de_dataset = de_dataset.repeat(repeat_num) | |||
| return de_dataset | |||
| class TxtDataset: | |||
| """ | |||
| create txt dataset. | |||
| Args: | |||
| Returns: | |||
| de_dataset. | |||
| """ | |||
| def __init__(self, root, txt_name): | |||
| super(TxtDataset, self).__init__() | |||
| self.imgs = [] | |||
| self.labels = [] | |||
| fin = open(txt_name, "r") | |||
| for line in fin: | |||
| img_name, label = line.strip().split(' ') | |||
| self.imgs.append(os.path.join(root, img_name)) | |||
| self.labels.append(int(label)) | |||
| fin.close() | |||
| def __getitem__(self, index): | |||
| img = Image.open(self.imgs[index]).convert('RGB') | |||
| return img, self.labels[index] | |||
| def __len__(self): | |||
| return len(self.imgs) | |||
| @@ -0,0 +1,23 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| linear warm up learning rate. | |||
| """ | |||
| def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): | |||
| lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||
| lr = float(init_lr) + lr_inc * current_step | |||
| return lr | |||
| @@ -0,0 +1,82 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| get logger. | |||
| """ | |||
| import logging | |||
| import os | |||
| import sys | |||
| from datetime import datetime | |||
| class LOGGER(logging.Logger): | |||
| """ | |||
| set up logging file. | |||
| Args: | |||
| logger_name (string): logger name. | |||
| log_dir (string): path of logger. | |||
| Returns: | |||
| string, logger path | |||
| """ | |||
| def __init__(self, logger_name, rank=0): | |||
| super(LOGGER, self).__init__(logger_name) | |||
| if rank % 8 == 0: | |||
| console = logging.StreamHandler(sys.stdout) | |||
| console.setLevel(logging.INFO) | |||
| formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') | |||
| console.setFormatter(formatter) | |||
| self.addHandler(console) | |||
| def setup_logging_file(self, log_dir, rank=0): | |||
| """set up log file""" | |||
| self.rank = rank | |||
| if not os.path.exists(log_dir): | |||
| os.makedirs(log_dir, exist_ok=True) | |||
| log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) | |||
| self.log_fn = os.path.join(log_dir, log_name) | |||
| fh = logging.FileHandler(self.log_fn) | |||
| fh.setLevel(logging.INFO) | |||
| formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') | |||
| fh.setFormatter(formatter) | |||
| self.addHandler(fh) | |||
| def info(self, msg, *args, **kwargs): | |||
| if self.isEnabledFor(logging.INFO): | |||
| self._log(logging.INFO, msg, args, **kwargs) | |||
| def save_args(self, args): | |||
| self.info('Args:') | |||
| args_dict = vars(args) | |||
| for key in args_dict.keys(): | |||
| self.info('--> %s: %s', key, args_dict[key]) | |||
| self.info('') | |||
| def important_info(self, msg, *args, **kwargs): | |||
| if self.isEnabledFor(logging.INFO) and self.rank == 0: | |||
| line_width = 2 | |||
| important_msg = '\n' | |||
| important_msg += ('*'*70 + '\n')*line_width | |||
| important_msg += ('*'*line_width + '\n')*2 | |||
| important_msg += '*'*line_width + ' '*8 + msg + '\n' | |||
| important_msg += ('*'*line_width + '\n')*2 | |||
| important_msg += ('*'*70 + '\n')*line_width | |||
| self.info(important_msg, *args, **kwargs) | |||
| def get_logger(path, rank): | |||
| logger = LOGGER("mindversion", rank) | |||
| logger.setup_logging_file(path, rank) | |||
| return logger | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| choose samples from the dataset | |||
| """ | |||
| import math | |||
| import numpy as np | |||
| class DistributedSampler(): | |||
| """ | |||
| sampling the dataset. | |||
| Args: | |||
| Returns: | |||
| num_samples, number of samples. | |||
| """ | |||
| def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): | |||
| self.dataset = dataset | |||
| self.rank = rank | |||
| self.group_size = group_size | |||
| self.dataset_length = len(self.dataset) | |||
| self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size)) | |||
| self.total_size = self.num_samples * self.group_size | |||
| self.shuffle = shuffle | |||
| self.seed = seed | |||
| def __iter__(self): | |||
| if self.shuffle: | |||
| self.seed = (self.seed + 1) & 0xffffffff | |||
| np.random.seed(self.seed) | |||
| indices = np.random.permutation(self.dataset_length).tolist() | |||
| else: | |||
| indices = list(range(len(self.dataset_length))) | |||
| indices += indices[:(self.total_size - len(indices))] | |||
| indices = indices[self.rank::self.group_size] | |||
| return iter(indices) | |||
| def __len__(self): | |||
| return self.num_samples | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Util class or function.""" | |||
| def get_param_groups(network): | |||
| """Param groups for optimizer.""" | |||
| decay_params = [] | |||
| no_decay_params = [] | |||
| for x in network.trainable_params(): | |||
| parameter_name = x.name | |||
| if parameter_name.endswith('.bias'): | |||
| # all bias not using weight decay | |||
| no_decay_params.append(x) | |||
| elif parameter_name.endswith('.gamma'): | |||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||
| no_decay_params.append(x) | |||
| elif parameter_name.endswith('.beta'): | |||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||
| no_decay_params.append(x) | |||
| else: | |||
| decay_params.append(x) | |||
| return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] | |||
| @@ -0,0 +1,214 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Initialize. | |||
| """ | |||
| import math | |||
| from functools import reduce | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.common import initializer as init | |||
| def _calculate_gain(nonlinearity, param=None): | |||
| r""" | |||
| Return the recommended gain value for the given nonlinearity function. | |||
| The values are as follows: | |||
| ================= ==================================================== | |||
| nonlinearity gain | |||
| ================= ==================================================== | |||
| Linear / Identity :math:`1` | |||
| Conv{1,2,3}D :math:`1` | |||
| Sigmoid :math:`1` | |||
| Tanh :math:`\frac{5}{3}` | |||
| ReLU :math:`\sqrt{2}` | |||
| Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` | |||
| ================= ==================================================== | |||
| Args: | |||
| nonlinearity: the non-linear function | |||
| param: optional parameter for the non-linear function | |||
| Examples: | |||
| >>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 | |||
| """ | |||
| linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] | |||
| if nonlinearity in linear_fns or nonlinearity == 'sigmoid': | |||
| return 1 | |||
| if nonlinearity == 'tanh': | |||
| return 5.0 / 3 | |||
| if nonlinearity == 'relu': | |||
| return math.sqrt(2.0) | |||
| if nonlinearity == 'leaky_relu': | |||
| if param is None: | |||
| negative_slope = 0.01 | |||
| elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): | |||
| negative_slope = param | |||
| else: | |||
| raise ValueError("negative_slope {} not a valid number".format(param)) | |||
| return math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||
| raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||
| def _assignment(arr, num): | |||
| """Assign the value of `num` to `arr`.""" | |||
| if arr.shape == (): | |||
| arr = arr.reshape((1)) | |||
| arr[:] = num | |||
| arr = arr.reshape(()) | |||
| else: | |||
| if isinstance(num, np.ndarray): | |||
| arr[:] = num[:] | |||
| else: | |||
| arr[:] = num | |||
| return arr | |||
| def _calculate_in_and_out(arr): | |||
| """ | |||
| Calculate n_in and n_out. | |||
| Args: | |||
| arr (Array): Input array. | |||
| Returns: | |||
| Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. | |||
| """ | |||
| dim = len(arr.shape) | |||
| if dim < 2: | |||
| raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.") | |||
| n_in = arr.shape[1] | |||
| n_out = arr.shape[0] | |||
| if dim > 2: | |||
| counter = reduce(lambda x, y: x * y, arr.shape[2:]) | |||
| n_in *= counter | |||
| n_out *= counter | |||
| return n_in, n_out | |||
| def _select_fan(array, mode): | |||
| mode = mode.lower() | |||
| valid_modes = ['fan_in', 'fan_out'] | |||
| if mode not in valid_modes: | |||
| raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) | |||
| fan_in, fan_out = _calculate_in_and_out(array) | |||
| return fan_in if mode == 'fan_in' else fan_out | |||
| class KaimingInit(init.Initializer): | |||
| r""" | |||
| Base Class. Initialize the array with He kaiming algorithm. | |||
| Args: | |||
| a: the negative slope of the rectifier used after this layer (only | |||
| used with ``'leaky_relu'``) | |||
| mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` | |||
| preserves the magnitude of the variance of the weights in the | |||
| forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the | |||
| backwards pass. | |||
| nonlinearity: the non-linear function, recommended to use only with | |||
| ``'relu'`` or ``'leaky_relu'`` (default). | |||
| """ | |||
| def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): | |||
| super(KaimingInit, self).__init__() | |||
| self.mode = mode | |||
| self.gain = _calculate_gain(nonlinearity, a) | |||
| def _initialize(self, arr): | |||
| pass | |||
| class KaimingUniform(KaimingInit): | |||
| r""" | |||
| Initialize the array with He kaiming uniform algorithm. The resulting tensor will | |||
| have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||
| .. math:: | |||
| \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} | |||
| Input: | |||
| arr (Array): The array to be assigned. | |||
| Returns: | |||
| Array, assigned array. | |||
| Examples: | |||
| >>> w = np.empty(3, 5) | |||
| >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu') | |||
| """ | |||
| def _initialize(self, arr): | |||
| fan = _select_fan(arr, self.mode) | |||
| bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) | |||
| np.random.seed(0) | |||
| data = np.random.uniform(-bound, bound, arr.shape) | |||
| _assignment(arr, data) | |||
| class KaimingNormal(KaimingInit): | |||
| r""" | |||
| Initialize the array with He kaiming normal algorithm. The resulting tensor will | |||
| have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where | |||
| .. math:: | |||
| \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} | |||
| Input: | |||
| arr (Array): The array to be assigned. | |||
| Returns: | |||
| Array, assigned array. | |||
| Examples: | |||
| >>> w = np.empty(3, 5) | |||
| >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu') | |||
| """ | |||
| def _initialize(self, arr): | |||
| fan = _select_fan(arr, self.mode) | |||
| std = self.gain / math.sqrt(fan) | |||
| np.random.seed(0) | |||
| data = np.random.normal(0, std, arr.shape) | |||
| _assignment(arr, data) | |||
| def default_recurisive_init(custom_cell): | |||
| """default_recurisive_init""" | |||
| for _, cell in custom_cell.cells_and_names(): | |||
| if isinstance(cell, nn.Conv2d): | |||
| cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
| cell.weight.shape, | |||
| cell.weight.dtype) | |||
| if cell.bias is not None: | |||
| fan_in, _ = _calculate_in_and_out(cell.weight) | |||
| bound = 1 / math.sqrt(fan_in) | |||
| np.random.seed(0) | |||
| cell.bias.default_input = init.initializer(init.Uniform(bound), | |||
| cell.bias.shape, | |||
| cell.bias.dtype) | |||
| elif isinstance(cell, nn.Dense): | |||
| cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), | |||
| cell.weight.shape, | |||
| cell.weight.dtype) | |||
| if cell.bias is not None: | |||
| fan_in, _ = _calculate_in_and_out(cell.weight) | |||
| bound = 1 / math.sqrt(fan_in) | |||
| np.random.seed(0) | |||
| cell.bias.default_input = init.initializer(init.Uniform(bound), | |||
| cell.bias.shape, | |||
| cell.bias.dtype) | |||
| elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||
| pass | |||
| @@ -12,12 +12,18 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """VGG.""" | |||
| """ | |||
| Image classifiation. | |||
| """ | |||
| import math | |||
| import mindspore.nn as nn | |||
| from mindspore.common.initializer import initializer | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common import initializer as init | |||
| from mindspore.common.initializer import initializer | |||
| from .utils.var_init import default_recurisive_init, KaimingNormal | |||
| def _make_layer(base, batch_norm): | |||
| def _make_layer(base, args, batch_norm): | |||
| """Make stage network of VGG.""" | |||
| layers = [] | |||
| in_channels = 3 | |||
| @@ -25,13 +31,17 @@ def _make_layer(base, batch_norm): | |||
| if v == 'M': | |||
| layers += [nn.MaxPool2d(kernel_size=2, stride=2)] | |||
| else: | |||
| weight_shape = (v, in_channels, 3, 3) | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||
| weight = 'ones' | |||
| if args.initialize_mode == "XavierUniform": | |||
| weight_shape = (v, in_channels, 3, 3) | |||
| weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() | |||
| conv2d = nn.Conv2d(in_channels=in_channels, | |||
| out_channels=v, | |||
| kernel_size=3, | |||
| padding=0, | |||
| pad_mode='same', | |||
| padding=args.padding, | |||
| pad_mode=args.pad_mode, | |||
| has_bias=args.has_bias, | |||
| weight_init=weight) | |||
| if batch_norm: | |||
| layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] | |||
| @@ -59,17 +69,25 @@ class Vgg(nn.Cell): | |||
| >>> num_classes=1000, batch_norm=False, batch_size=1) | |||
| """ | |||
| def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): | |||
| def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): | |||
| super(Vgg, self).__init__() | |||
| _ = batch_size | |||
| self.layers = _make_layer(base, batch_norm=batch_norm) | |||
| self.layers = _make_layer(base, args, batch_norm=batch_norm) | |||
| self.flatten = nn.Flatten() | |||
| dropout_ratio = 0.5 | |||
| if not args.has_dropout or phase == "test": | |||
| dropout_ratio = 1.0 | |||
| self.classifier = nn.SequentialCell([ | |||
| nn.Dense(512 * 7 * 7, 4096), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout_ratio), | |||
| nn.Dense(4096, 4096), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout_ratio), | |||
| nn.Dense(4096, num_classes)]) | |||
| if args.initialize_mode == "KaimingNormal": | |||
| default_recurisive_init(self) | |||
| self.custom_init_weight() | |||
| def construct(self, x): | |||
| x = self.layers(x) | |||
| @@ -77,6 +95,25 @@ class Vgg(nn.Cell): | |||
| x = self.classifier(x) | |||
| return x | |||
| def custom_init_weight(self): | |||
| """ | |||
| Init the weight of Conv2d and Dense in the net. | |||
| """ | |||
| for _, cell in self.cells_and_names(): | |||
| if isinstance(cell, nn.Conv2d): | |||
| cell.weight.default_input = init.initializer( | |||
| KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'), | |||
| cell.weight.shape, cell.weight.dtype) | |||
| if cell.bias is not None: | |||
| cell.bias.default_input = init.initializer( | |||
| 'zeros', cell.bias.shape, cell.bias.dtype) | |||
| elif isinstance(cell, nn.Dense): | |||
| cell.weight.default_input = init.initializer( | |||
| init.Normal(0.01), cell.weight.shape, cell.weight.dtype) | |||
| if cell.bias is not None: | |||
| cell.bias.default_input = init.initializer( | |||
| 'zeros', cell.bias.shape, cell.bias.dtype) | |||
| cfg = { | |||
| '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], | |||
| @@ -86,19 +123,24 @@ cfg = { | |||
| } | |||
| def vgg16(num_classes=1000): | |||
| def vgg16(num_classes=1000, args=None, phase="train"): | |||
| """ | |||
| Get Vgg16 neural network with batch normalization. | |||
| Args: | |||
| num_classes (int): Class numbers. Default: 1000. | |||
| args(namespace): param for net init. | |||
| phase(str): train or test mode. | |||
| Returns: | |||
| Cell, cell instance of Vgg16 neural network with batch normalization. | |||
| Examples: | |||
| >>> vgg16(num_classes=1000) | |||
| >>> vgg16(num_classes=1000, args=args) | |||
| """ | |||
| net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) | |||
| if args is None: | |||
| from .config import cifar_cfg | |||
| args = cifar_cfg | |||
| net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) | |||
| return net | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| warm up cosine annealing learning rate. | |||
| """ | |||
| import math | |||
| import numpy as np | |||
| from .linear_warmup import linear_warmup_lr | |||
| def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): | |||
| """warm up cosine annealing learning rate.""" | |||
| base_lr = lr | |||
| warmup_init_lr = 0 | |||
| total_steps = int(max_epoch * steps_per_epoch) | |||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||
| lr_each_step = [] | |||
| for i in range(total_steps): | |||
| last_epoch = i // steps_per_epoch | |||
| if i < warmup_steps: | |||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||
| else: | |||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 | |||
| lr_each_step.append(lr) | |||
| return np.array(lr_each_step).astype(np.float32) | |||
| @@ -0,0 +1,84 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| warm up step learning rate. | |||
| """ | |||
| from collections import Counter | |||
| import numpy as np | |||
| from .linear_warmup import linear_warmup_lr | |||
| def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): | |||
| """Set learning rate.""" | |||
| lr_each_step = [] | |||
| total_steps = steps_per_epoch * total_epochs | |||
| warmup_steps = steps_per_epoch * warmup_epochs | |||
| if warmup_steps != 0: | |||
| inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) | |||
| else: | |||
| inc_each_step = 0 | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr_value = float(lr_init) + inc_each_step * float(i) | |||
| else: | |||
| base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) | |||
| lr_value = float(lr_max) * base * base | |||
| if lr_value < 0.0: | |||
| lr_value = 0.0 | |||
| lr_each_step.append(lr_value) | |||
| current_step = global_step | |||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||
| learning_rate = lr_each_step[current_step:] | |||
| return learning_rate | |||
| def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): | |||
| """warmup_step_lr""" | |||
| base_lr = lr | |||
| warmup_init_lr = 0 | |||
| total_steps = int(max_epoch * steps_per_epoch) | |||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||
| milestones = lr_epochs | |||
| milestones_steps = [] | |||
| for milestone in milestones: | |||
| milestones_step = milestone * steps_per_epoch | |||
| milestones_steps.append(milestones_step) | |||
| lr_each_step = [] | |||
| lr = base_lr | |||
| milestones_steps_counter = Counter(milestones_steps) | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||
| else: | |||
| lr = lr * gamma**milestones_steps_counter[i] | |||
| lr_each_step.append(lr) | |||
| return np.array(lr_each_step).astype(np.float32) | |||
| def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): | |||
| return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) | |||
| def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): | |||
| lr_epochs = [] | |||
| for i in range(1, max_epoch): | |||
| if i % epoch_size == 0: | |||
| lr_epochs.append(i) | |||
| return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) | |||
| @@ -17,6 +17,7 @@ | |||
| python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID | |||
| """ | |||
| import argparse | |||
| import datetime | |||
| import os | |||
| import random | |||
| @@ -25,83 +26,213 @@ import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.communication.management import init | |||
| from mindspore import ParallelMode | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.model import Model, ParallelMode | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_param_into_net, load_checkpoint | |||
| from src.config import cifar_cfg as cfg | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| from src.dataset import vgg_create_dataset | |||
| from src.dataset import classification_dataset | |||
| from src.crossentropy import CrossEntropy | |||
| from src.warmup_step_lr import warmup_step_lr | |||
| from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr | |||
| from src.warmup_step_lr import lr_steps | |||
| from src.utils.logging import get_logger | |||
| from src.utils.util import get_param_groups | |||
| from src.vgg import vgg16 | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): | |||
| """Set learning rate.""" | |||
| lr_each_step = [] | |||
| total_steps = steps_per_epoch * total_epochs | |||
| warmup_steps = steps_per_epoch * warmup_epochs | |||
| if warmup_steps != 0: | |||
| inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) | |||
| else: | |||
| inc_each_step = 0 | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr_value = float(lr_init) + inc_each_step * float(i) | |||
| else: | |||
| base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) | |||
| lr_value = float(lr_max) * base * base | |||
| if lr_value < 0.0: | |||
| lr_value = 0.0 | |||
| lr_each_step.append(lr_value) | |||
| def parse_args(cloud_args=None): | |||
| """parameters""" | |||
| parser = argparse.ArgumentParser('mindspore classification training') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)') | |||
| current_step = global_step | |||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||
| learning_rate = lr_each_step[current_step:] | |||
| # dataset related | |||
| parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") | |||
| parser.add_argument('--data_path', type=str, default='', help='train data dir') | |||
| return learning_rate | |||
| # network related | |||
| parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load') | |||
| parser.add_argument('--lr_gamma', type=float, default=0.1, | |||
| help='decrease lr by a factor of exponential lr_scheduler') | |||
| parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') | |||
| parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler') | |||
| # logging and checkpoint related | |||
| parser.add_argument('--log_interval', type=int, default=100, help='logging interval') | |||
| parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') | |||
| parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval') | |||
| parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank') | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description='Cifar10 classification') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved') | |||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | |||
| parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.') | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') | |||
| parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args_opt = parser.parse_args() | |||
| args_opt = merge_args(args_opt, cloud_args) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||
| context.set_context(device_id=args_opt.device_id) | |||
| if args_opt.dataset == "cifar10": | |||
| from src.config import cifar_cfg as cfg | |||
| else: | |||
| from src.config import imagenet_cfg as cfg | |||
| args_opt.label_smooth = cfg.label_smooth | |||
| args_opt.label_smooth_factor = cfg.label_smooth_factor | |||
| args_opt.lr_scheduler = cfg.lr_scheduler | |||
| args_opt.loss_scale = cfg.loss_scale | |||
| args_opt.max_epoch = cfg.max_epoch | |||
| args_opt.warmup_epochs = cfg.warmup_epochs | |||
| args_opt.lr = cfg.lr | |||
| args_opt.lr_init = cfg.lr_init | |||
| args_opt.lr_max = cfg.lr_max | |||
| args_opt.momentum = cfg.momentum | |||
| args_opt.weight_decay = cfg.weight_decay | |||
| args_opt.per_batch_size = cfg.batch_size | |||
| args_opt.num_classes = cfg.num_classes | |||
| args_opt.buffer_size = cfg.buffer_size | |||
| args_opt.ckpt_save_max = cfg.keep_checkpoint_max | |||
| args_opt.pad_mode = cfg.pad_mode | |||
| args_opt.padding = cfg.padding | |||
| args_opt.has_bias = cfg.has_bias | |||
| args_opt.batch_norm = cfg.batch_norm | |||
| args_opt.initialize_mode = cfg.initialize_mode | |||
| args_opt.has_dropout = cfg.has_dropout | |||
| args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(','))) | |||
| args_opt.image_size = list(map(int, cfg.image_size.split(','))) | |||
| return args_opt | |||
| def merge_args(args_opt, cloud_args): | |||
| """dictionary""" | |||
| args_dict = vars(args_opt) | |||
| if isinstance(cloud_args, dict): | |||
| for key_arg in cloud_args.keys(): | |||
| val = cloud_args[key_arg] | |||
| if key_arg in args_dict and val: | |||
| arg_type = type(args_dict[key_arg]) | |||
| if arg_type is not None: | |||
| val = arg_type(val) | |||
| args_dict[key_arg] = val | |||
| return args_opt | |||
| if __name__ == '__main__': | |||
| args = parse_args() | |||
| device_num = int(os.environ.get("DEVICE_NUM", 1)) | |||
| if device_num > 1: | |||
| if args.is_distributed: | |||
| if args.device_target == "Ascend": | |||
| init() | |||
| context.set_context(device_id=args.device_id) | |||
| elif args.device_target == "GPU": | |||
| init("nccl") | |||
| args.rank = get_rank() | |||
| args.group_size = get_group_size() | |||
| device_num = args.group_size | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| init() | |||
| parameter_broadcast=True, mirror_mean=True) | |||
| else: | |||
| context.set_context(device_id=args.device_id) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| # select for master rank save ckpt or all rank save, compatible for model parallel | |||
| args.rank_save_ckpt_flag = 0 | |||
| if args.is_save_on_master: | |||
| if args.rank == 0: | |||
| args.rank_save_ckpt_flag = 1 | |||
| else: | |||
| args.rank_save_ckpt_flag = 1 | |||
| # logger | |||
| args.outputs_dir = os.path.join(args.ckpt_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| args.logger = get_logger(args.outputs_dir, args.rank) | |||
| if args.dataset == "cifar10": | |||
| dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size, | |||
| repeat_num=args.max_epoch) | |||
| else: | |||
| dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, | |||
| args.rank, args.group_size, repeat_num=args.max_epoch) | |||
| dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size) | |||
| batch_num = dataset.get_dataset_size() | |||
| args.steps_per_epoch = dataset.get_dataset_size() | |||
| args.logger.save_args(args) | |||
| # network | |||
| args.logger.important_info('start create network') | |||
| # get network and init | |||
| network = vgg16(args.num_classes, args) | |||
| net = vgg16(num_classes=cfg.num_classes) | |||
| # pre_trained | |||
| if args_opt.pre_trained: | |||
| load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) | |||
| lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, | |||
| total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | |||
| weight_decay=cfg.weight_decay) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| if args.pre_trained: | |||
| load_param_into_net(network, load_checkpoint(args.pre_trained)) | |||
| # lr scheduler | |||
| if args.lr_scheduler == 'exponential': | |||
| lr = warmup_step_lr(args.lr, | |||
| args.lr_epochs, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| gamma=args.lr_gamma, | |||
| ) | |||
| elif args.lr_scheduler == 'cosine_annealing': | |||
| lr = warmup_cosine_annealing_lr(args.lr, | |||
| args.steps_per_epoch, | |||
| args.warmup_epochs, | |||
| args.max_epoch, | |||
| args.T_max, | |||
| args.eta_min) | |||
| elif args.lr_scheduler == 'step': | |||
| lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs, | |||
| total_epochs=args.max_epoch, steps_per_epoch=batch_num) | |||
| else: | |||
| raise NotImplementedError(args.lr_scheduler) | |||
| # optimizer | |||
| opt = Momentum(params=get_param_groups(network), | |||
| learning_rate=Tensor(lr), | |||
| momentum=args.momentum, | |||
| weight_decay=args.weight_decay, | |||
| loss_scale=args.loss_scale) | |||
| if args.dataset == "cifar10": | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||
| else: | |||
| if not args.label_smooth: | |||
| args.label_smooth_factor = 0.0 | |||
| loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) | |||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | |||
| model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2") | |||
| # define callbacks | |||
| time_cb = TimeMonitor(data_size=batch_num) | |||
| ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck) | |||
| loss_cb = LossMonitor() | |||
| model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | |||
| print("train success") | |||
| loss_cb = LossMonitor(per_print_times=batch_num) | |||
| callbacks = [time_cb, loss_cb] | |||
| if args.rank_save_ckpt_flag: | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch, | |||
| keep_checkpoint_max=args.ckpt_save_max) | |||
| ckpt_cb = ModelCheckpoint(config=ckpt_config, | |||
| directory=args.outputs_dir, | |||
| prefix='{}'.format(args.rank)) | |||
| callbacks.append(ckpt_cb) | |||
| model.train(args.max_epoch, dataset, callbacks=callbacks) | |||