
googlenet support imagenet dataset on Ascend

pull/5925/head
caojian05 CaoJian 5 years ago
parent
commit
e4a0b0db34
15 changed files with 926 additions and 248 deletions
  1. model_zoo/googlenet/README.md (+327, -176)
  2. model_zoo/googlenet/eval.py (+44, -10)
  3. model_zoo/googlenet/export.py (+19, -6)
  4. model_zoo/googlenet/mindspore_hub_conf.py (+25, -0)
  5. model_zoo/googlenet/scripts/run_eval_gpu.sh (+43, -0)
  6. model_zoo/googlenet/scripts/run_train.sh (+24, -9)
  7. model_zoo/googlenet/scripts/run_train_gpu.sh (+51, -0)
  8. model_zoo/googlenet/src/config.py (+39, -2)
  9. model_zoo/googlenet/src/dataset.py (+88, -11)
  10. model_zoo/googlenet/src/googlenet.py (+11, -12)
  11. model_zoo/googlenet/src/lr_scheduler/__init__.py (+0, -0)
  12. model_zoo/googlenet/src/lr_scheduler/linear_warmup.py (+20, -0)
  13. model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py (+39, -0)
  14. model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py (+59, -0)
  15. model_zoo/googlenet/train.py (+137, -22)

model_zoo/googlenet/README.md (+327, -176)

@@ -36,14 +36,8 @@ GoogleNet, a 22 layers deep network, was proposed in 2014 and won the first plac
# [Model Architecture](#contents)
The overall network architecture of GoogleNet is shown below:
![](https://miro.medium.com/max/3780/1*ZFPOSAted10TPd3hBQU8iQ.png)
Specifically, the GoogleNet contains numerous inception modules, which are connected together to go deeper. In general, an inception module with dimensionality reduction consists of **1×1 conv**, **3×3 conv**, **5×5 conv**, and **3×3 max pooling**, which are done altogether for the previous input, and stack together again at output.
![](https://miro.medium.com/max/1108/1*sezFsYW1MyM9YOMa1q909A.png)
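Because the branches of an inception module are stacked together at the output, the output channel count is simply the sum of the branch widths. The sketch below illustrates that bookkeeping; the helper name `inception_output_channels` is hypothetical, and the branch widths are the "inception (3a)" values from the GoogLeNet paper, not values read from `src/googlenet.py`.

```python
# Illustrative only: channel bookkeeping for one inception module with
# dimensionality reduction. The helper name is hypothetical and the widths
# are assumed from the GoogLeNet paper's "inception (3a)" block.

def inception_output_channels(c1x1, c3x3, c5x5, pool_proj):
    """Branches are concatenated along the channel axis, so widths add up."""
    return c1x1 + c3x3 + c5x5 + pool_proj

# 1x1 branch: 64; 3x3 branch: 128 (after a 96-channel 1x1 reduction);
# 5x5 branch: 32 (after a 16-channel 1x1 reduction); pooling branch: 32.
out_channels = inception_output_channels(64, 128, 32, 32)
print(out_channels)  # 256
```

The 1×1 reductions in front of the 3×3 and 5×5 convolutions shrink the input to those branches only; they do not appear in the sum.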
# [Dataset](#contents)
@@ -52,10 +46,9 @@ Dataset used: [CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
- Dataset size: 175M, 60,000 32*32 colorful images in 10 classes
- Train: 146M, 50,000 images
- Test: 29.3M, 10,000 images
- Test: 29M, 10,000 images
- Data format: binary files
- Note: Data will be processed in dataset.py
- Note: Data will be processed in src/dataset.py
# [Features](#contents)
@@ -72,7 +65,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
- Hardware(Ascend/GPU)
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
@@ -83,16 +76,45 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
After installing MindSpore via the official website, you can start training and evaluation as follows:
```python
# run training example
python train.py > train.log 2>&1 &
- running on Ascend
# run distributed training example
sh scripts/run_train.sh rank_table.json
```python
# run training example
python train.py > train.log 2>&1 &
# run distributed training example
sh scripts/run_train.sh rank_table.json
# run evaluation example
python eval.py > eval.log 2>&1 &
OR
sh run_eval.sh
```
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
Please follow the instructions in the link below:
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
- running on GPU
For running on GPU, please change `device_target` from `Ascend` to `GPU` in configuration file src/config.py
```python
# run training example
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
# run distributed training example
sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
# run evaluation example
python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_eval_gpu.sh [CHECKPOINT_PATH]
```
# run evaluation example
python eval.py > eval.log 2>&1 & OR sh run_eval.sh
```
@@ -106,109 +128,168 @@ python eval.py > eval.log 2>&1 & OR sh run_eval.sh
├── googlenet
├── README.md // descriptions about googlenet
├── scripts
│ ├──run_train.sh // shell script for distributed
│ ├──run_eval.sh // shell script for evaluation
│ ├──run_train.sh // shell script for distributed on Ascend
│ ├──run_train_gpu.sh // shell script for distributed on GPU
│ ├──run_eval.sh // shell script for evaluation on Ascend
│ ├──run_eval_gpu.sh // shell script for evaluation on GPU
├── src
│ ├──dataset.py // creating dataset
│ ├──googlenet.py // googlenet architecture
│ ├──config.py // parameter configuration
├── train.py // training script
├── eval.py // evaluation script
├── export.py // export checkpoint files into geir/onnx
├── export.py // export checkpoint files into air/onnx
```
## [Script Parameters](#contents)
```python
Major parameters in train.py and config.py are:
--data_path: The absolute full path to the train and evaluation datasets.
--epoch_size: Total training epochs.
--batch_size: Training batch size.
--lr_init: Initial learning rate.
--num_classes: The number of classes in the training set.
--weight_decay: Weight decay value.
--image_height: Image height used as input to the model.
--image_width: Image width used as input to the model.
--pre_trained: Whether training from scratch or training based on the
pre-trained model.Optional values are True, False.
--device_target: Device where the code will be implemented. Optional values
are "Ascend", "GPU".
--device_id: Device ID used to train or evaluate the dataset. Ignore it
when you use run_train.sh for distributed training.
--checkpoint_path: The absolute full path to the checkpoint file saved
after training.
--onnx_filename: File name of the onnx model used in export.py.
--geir_filename: File name of the geir model used in export.py.
```
Parameters for both training and evaluation can be set in config.py
- config for GoogleNet, CIFAR-10 dataset
```python
'pre_trained': 'False' # whether training based on the pre-trained model
'num_classes': 10 # the number of classes in the dataset
'lr_init': 0.1 # initial learning rate
'batch_size': 128 # training batch size
'epoch_size': 125 # total training epochs
'momentum': 0.9 # momentum
'weight_decay': 5e-4 # weight decay value
'buffer_size': 10 # buffer size
'image_height': 224 # image height used as input to the model
'image_width': 224 # image width used as input to the model
'data_path': './cifar10' # absolute full path to the train and evaluation datasets
'device_target': 'Ascend' # device running the program
'device_id': 4 # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoint
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt' # the absolute full path to save the checkpoint file
'onnx_filename': 'googlenet.onnx' # file name of the onnx model used in export.py
'geir_filename': 'googlenet.geir' # file name of the geir model used in export.py
```
## [Training Process](#contents)
### Training
```
python train.py > train.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `train.log`.
- running on Ascend
```
python train.py > train.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `train.log`.
After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
```
# grep "loss is " train.log
epoch: 1 step: 390, loss is 1.4842823
epoch: 2 step: 390, loss is 1.0897788
...
```
The model checkpoint will be saved in the current directory.
- running on GPU
```
export CUDA_VISIBLE_DEVICES=0
python train.py > train.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `train.log`.
After training, you'll get some checkpoint files under the folder `./ckpt_0/` by default.
After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows:
```
# grep "loss is " train.log
epoch: 1 step: 390, loss is 1.4842823
epoch: 2 step: 390, loss is 1.0897788
...
```
The model checkpoint will be saved in the current directory.
### Distributed Training
```
sh scripts/run_train.sh rank_table.json
```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows:
```
# grep "result: " train_parallel*/log
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
train_parallel0/log:epoch: 2 step: 48, loss is 1.4023874
...
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
train_parallel1/log:epoch: 2 step: 48, loss is 1.3729336
...
...
```
- running on Ascend
```
sh scripts/run_train.sh rank_table.json
```
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log`. The loss value will be achieved as follows:
```
# grep "result: " train_parallel*/log
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
train_parallel0/log:epoch: 2 step: 48, loss is 1.4023874
...
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
train_parallel1/log:epoch: 2 step: 48, loss is 1.3729336
...
...
```
- running on GPU
```
sh scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
```
The above shell script will run distribute training in the background. You can view the results through the file `train/train.log`.
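The step counts in the Ascend logs (390 per epoch on one device, 48 per epoch on eight) are consistent with simple batch arithmetic. The following is a minimal sketch, assuming the 50,000 CIFAR-10 training images are split evenly across devices and that a trailing partial batch is dropped; `steps_per_epoch` is a hypothetical helper, not a function from this repository.

```python
# Sketch under stated assumptions: 50,000 training images (Dataset section),
# batch_size 128 (config), an even split across devices, and the final
# partial batch dropped. The helper name is hypothetical.

def steps_per_epoch(num_images, batch_size, num_devices=1):
    images_per_device = num_images // num_devices
    return images_per_device // batch_size

single = steps_per_epoch(50_000, 128)
eight = steps_per_epoch(50_000, 128, num_devices=8)
print(single, eight)  # 390 48
```

If the dataset pipeline kept the partial batch instead, the counts would be 391 and 49, which would not match the logs.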
## [Evaluation Process](#contents)
### Evaluation
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt".
```
python eval.py > eval.log 2>&1 &
OR
sh scripts/run_eval.sh
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " eval.log
accuracy: {'acc': 0.934}
```
Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " dist.eval.log
accuracy: {'acc': 0.9217}
```
- evaluation on CIFAR-10 dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train_googlenet_cifar10-125_390.ckpt".
```
python eval.py > eval.log 2>&1 &
OR
sh scripts/run_eval.sh
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " eval.log
accuracy: {'acc': 0.934}
```
Note that for evaluation after distributed training, please set the checkpoint_path to be the last saved checkpoint file such as "username/googlenet/train_parallel0/train_googlenet_cifar10-125_48.ckpt". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " dist.eval.log
accuracy: {'acc': 0.9217}
```
- evaluation on CIFAR-10 dataset when running on GPU
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/googlenet/train/ckpt_0/train_googlenet_cifar10-125_390.ckpt".
```
python eval.py --checkpoint_path=[CHECKPOINT_PATH] > eval.log 2>&1 &
```
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " eval.log
accuracy: {'acc': 0.930}
```
OR,
```
sh scripts/run_eval_gpu.sh [CHECKPOINT_PATH]
```
The above python command will run in the background. You can view the results through the file "eval/eval.log". The accuracy of the test dataset will be as follows:
```
# grep "accuracy: " eval/eval.log
accuracy: {'acc': 0.930}
```
# [Model Description](#contents)
@@ -216,100 +297,170 @@ accuracy: {'acc': 0.9217}
### Evaluation Performance
| Parameters | GoogleNet |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | Inception V1 |
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
| uploaded Date | 06/09/2020 (month/day/year) |
| MindSpore Version | 0.3.0-alpha |
| Dataset | CIFAR-10 |
| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 |
| Optimizer | SGD |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Loss | 0.0016 |
| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step |
| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins |
| Parameters (M) | 6.8 |
| Checkpoint for Fine tuning | 43.07M (.ckpt file) |
| Model for inference | 21.50M (.onnx file), 21.60M(.geir file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/googlenet |
| Parameters | Ascend | GPU |
| -------------------------- | ----------------------------------------------------------- | ---------------------- |
| Model Version | Inception V1 | Inception V1 |
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G | NV SMX2 V100-32G |
| uploaded Date | 08/31/2020 (month/day/year) | 08/20/2020 (month/day/year) |
| MindSpore Version | 0.7.0-alpha | 0.6.0-alpha |
| Dataset | CIFAR-10 | CIFAR-10 |
| Training Parameters | epoch=125, steps=390, batch_size = 128, lr=0.1 | epoch=125, steps=390, batch_size=128, lr=0.1 |
| Optimizer | SGD | SGD |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Loss | 0.0016 | 0.0016 |
| Speed | 1pc: 79 ms/step; 8pcs: 82 ms/step | 1pc: 150 ms/step; 8pcs: 164 ms/step |
| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins | 1pc: 126.87 mins; 8pcs: 21.65 mins |
| Parameters (M) | 13.0 | 13.0 |
| Checkpoint for Fine tuning | 43.07M (.ckpt file) | 43.07M (.ckpt file) |
| Model for inference | 21.50M (.onnx file), 21.60M(.air file) | |
| Scripts | [googlenet script](https://gitee.com/mindspore/mindspore/tree/r0.7/model_zoo/official/cv/googlenet) | [googlenet script](https://gitee.com/mindspore/mindspore/tree/r0.6/model_zoo/official/cv/googlenet) |
### Inference Performance
| Parameters | GoogleNet |
| ------------------- | --------------------------- |
| Model Version | Inception V1 |
| Resource | Ascend 910 |
| Uploaded Date | 06/09/2020 (month/day/year) |
| MindSpore Version | 0.3.0-alpha |
| Dataset | CIFAR-10, 10,000 images |
| batch_size | 128 |
| outputs | probability |
| Accuracy | 1pc: 93.4%; 8pcs: 92.17% |
| Model for inference | 21.50M (.onnx file) |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | Inception V1 | Inception V1 |
| Resource | Ascend 910 | GPU |
| Uploaded Date | 08/31/2020 (month/day/year) | 08/20/2020 (month/day/year) |
| MindSpore Version | 0.7.0-alpha | 0.6.0-alpha |
| Dataset | CIFAR-10, 10,000 images | CIFAR-10, 10,000 images |
| batch_size | 128 | 128 |
| outputs | probability | probability |
| Accuracy | 1pc: 93.4%; 8pcs: 92.17% | 1pc: 93%, 8pcs: 92.89% |
| Model for inference | 21.50M (.onnx file) | |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
```
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
- Running on Ascend
```
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
- Running on GPU:
```
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean',
is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Load pre-trained model
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model
```
# Load dataset
dataset = create_dataset(cfg.data_path, cfg.epoch_size)
batch_num = dataset.get_dataset_size()
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
config=config_ck)
loss_cb = LossMonitor()
# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
- running on Ascend
```
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
config=config_ck)
loss_cb = LossMonitor()
# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
- running on GPU
```
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
# Define model
net = GoogleNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./ckpt_" + str(get_rank()) + "/",
config=config_ck)
loss_cb = LossMonitor()
# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
### Transfer Learning
To be added.


model_zoo/googlenet/eval.py (+44, -10)

@@ -16,30 +16,64 @@
##############test googlenet example on cifar10#################
python eval.py
"""
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net


from src.config import cifar_cfg as cfg
from src.dataset import create_dataset
from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet

from src.googlenet import GoogleNet




parser = argparse.ArgumentParser(description='googlenet')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()

if __name__ == '__main__':

if args_opt.dataset_name == 'cifar10':
cfg = cifar_cfg
dataset = create_dataset_cifar10(cfg.data_path, 1, False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

elif args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = GoogleNet(num_classes=cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

else:
raise ValueError("dataset is not support.")

device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)
if device_target == "Ascend":
context.set_context(device_id=cfg.device_id)


net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
if args_opt.checkpoint_path is not None:
param_dict = load_checkpoint(args_opt.checkpoint_path)
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
else:
param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(cfg.checkpoint_path))


param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
dataset = create_dataset(cfg.data_path, 1, False)
acc = model.eval(dataset)
print("accuracy: ", acc)
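The new `--checkpoint_path` argument in eval.py overrides the checkpoint path from the config only when it is supplied. The sketch below isolates that fallback so it can be run without MindSpore; `resolve_checkpoint` is a hypothetical helper, `cfg_checkpoint_path` stands in for `cfg.checkpoint_path`, and `/tmp/custom.ckpt` is a made-up path.

```python
import argparse

def resolve_checkpoint(cli_path, cfg_path):
    """Prefer the command-line path; fall back to the config value."""
    return cli_path if cli_path is not None else cfg_path

parser = argparse.ArgumentParser(description='googlenet')
parser.add_argument('--checkpoint_path', type=str, default=None,
                    help='Checkpoint file path')

# Stand-in for cfg.checkpoint_path (value taken from the README config example).
cfg_checkpoint_path = './train_googlenet_cifar10-125_390.ckpt'

# Parse explicit argument lists instead of sys.argv so the sketch is repeatable.
with_flag = parser.parse_args(['--checkpoint_path', '/tmp/custom.ckpt'])
without_flag = parser.parse_args([])

print(resolve_checkpoint(with_flag.checkpoint_path, cfg_checkpoint_path))
print(resolve_checkpoint(without_flag.checkpoint_path, cfg_checkpoint_path))
```

Keeping `default=None` on the argument is what makes the fallback detectable; a non-None default would always shadow the config value.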

model_zoo/googlenet/export.py (+19, -6)

@@ -13,24 +13,37 @@
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into geir and onnx models#################
##############export checkpoint file into air and onnx models#################
python export.py
"""
import argparse
import numpy as np


import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export


from src.config import cifar_cfg as cfg
from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet



if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
args_opt = parser.parse_args()

if args_opt.dataset_name == 'cifar10':
cfg = cifar_cfg
elif args_opt.dataset_name == 'imagenet':
cfg = imagenet_cfg
else:
raise ValueError("dataset is not support.")

net = GoogleNet(num_classes=cfg.num_classes)

assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)


input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32))
export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
export(net, input_arr, file_name=cfg.geir_filename, file_format="GEIR")
export(net, input_arr, file_name=cfg.air_filename, file_format="AIR")

model_zoo/googlenet/mindspore_hub_conf.py (+25, -0)

@@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.googlenet import GoogleNet

def googlenet(*args, **kwargs):
return GoogleNet(*args, **kwargs)


def create_network(name, *args, **kwargs):
if name == "googlenet":
return googlenet(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
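The hub config above is a small name-to-constructor dispatch. The sketch below shows how a caller might exercise it; the `GoogleNet` class here is a placeholder so the example runs without MindSpore, and the `resnet50` name is just an arbitrary unsupported value.

```python
# Placeholder for src.googlenet.GoogleNet so the sketch runs without MindSpore;
# the real class is a MindSpore network, not this stand-in.
class GoogleNet:
    def __init__(self, num_classes):
        self.num_classes = num_classes

def googlenet(*args, **kwargs):
    return GoogleNet(*args, **kwargs)

def create_network(name, *args, **kwargs):
    if name == "googlenet":
        return googlenet(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")

# A supported name forwards its arguments to the constructor.
net = create_network("googlenet", num_classes=10)
print(net.num_classes)  # 10

# Any other name raises NotImplementedError.
try:
    create_network("resnet50")
except NotImplementedError as err:
    print(err)  # resnet50 is not implemented in the repo
```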

model_zoo/googlenet/scripts/run_eval_gpu.sh (+43, -0)

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

ulimit -u unlimited

if [ $# != 1 ]
then
echo "GPU: sh run_eval_gpu.sh [CHECKPOINT_PATH]"
exit 1
fi

# check checkpoint file
if [ ! -f $1 ]
then
echo "error: CHECKPOINT_PATH=$1 is not a file"
exit 1
fi

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0

if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

python3 ${BASEPATH}/../eval.py --checkpoint_path=$1 > ./eval.log 2>&1 &

+ 24
- 9
model_zoo/googlenet/scripts/run_train.sh View File

@@ -14,36 +14,51 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================


if [ $# != 1 ]
if [ $# != 1 ] && [ $# != 2 ]
then then
echo "Usage: sh run_train.sh [MINDSPORE_HCCL_CONFIG_PATH]"
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]"
exit 1 exit 1
fi fi


if [ ! -f $1 ] if [ ! -f $1 ]
then then
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1 exit 1
fi fi



dataset_type='cifar10'
if [ $# == 2 ]
then
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
dataset_type=$2
fi


ulimit -u unlimited ulimit -u unlimited
export DEVICE_NUM=8 export DEVICE_NUM=8
export RANK_SIZE=8 export RANK_SIZE=8
MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1)
export MINDSPORE_HCCL_CONFIG_PATH
echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}"
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"


export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++)) for((i=0; i<${DEVICE_NUM}; i++))
do do
export DEVICE_ID=$i export DEVICE_ID=$i
export RANK_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i rm -rf ./train_parallel$i
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_parallel$i ||exit cd ./train_parallel$i ||exit
env > env.log env > env.log
python train.py --device_id=$i > log 2>&1 &
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 &
cd .. cd ..
done done
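
The rank assignment in the loop above can be sketched in Python. This is only an illustration of the arithmetic `RANK_ID = DEVICE_NUM * SERVER_ID + i`; the helper name `rank_ids` is hypothetical and not part of the commit.

```python
def rank_ids(device_num, server_id):
    """Return the global rank of each local device on one server."""
    # Mirrors rank_start=$((DEVICE_NUM * SERVER_ID)) in run_train.sh.
    rank_start = device_num * server_id
    return [rank_start + device_id for device_id in range(device_num)]


# With SERVER_ID=0, as exported by the script, ranks equal device ids.
first_server = rank_ids(8, 0)
# A second server (SERVER_ID=1, assumed here for illustration) would continue at 8.
second_server = rank_ids(8, 1)
print(first_server, second_server)
```

With the script's fixed `SERVER_ID=0`, the new expression yields the same ranks as the old `RANK_ID=$i`; the offset only matters if `SERVER_ID` is changed for a multi-server run.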

+ 51
- 0
model_zoo/googlenet/scripts/run_train_gpu.sh View File

@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -lt 2 ]
then
echo "Usage:\n \
sh run_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)]\n \
"
exit 1
fi

if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit

export CUDA_VISIBLE_DEVICES="$2"

if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root \
python3 ${BASEPATH}/../train.py > train.log 2>&1 &
else
python3 ${BASEPATH}/../train.py > train.log 2>&1 &
fi

+ 39
- 2
model_zoo/googlenet/src/config.py View File

@@ -18,6 +18,7 @@ network config setting, will be used in main.py
from easydict import EasyDict as edict from easydict import EasyDict as edict


cifar_cfg = edict({ cifar_cfg = edict({
'name': 'cifar10',
'pre_trained': False, 'pre_trained': False,
'num_classes': 10, 'num_classes': 10,
'lr_init': 0.1, 'lr_init': 0.1,
@@ -30,9 +31,45 @@ cifar_cfg = edict({
'image_width': 224, 'image_width': 224,
'data_path': './cifar10', 'data_path': './cifar10',
'device_target': 'Ascend', 'device_target': 'Ascend',
'device_id': 4,
'device_id': 0,
'keep_checkpoint_max': 10, 'keep_checkpoint_max': 10,
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt', 'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
'onnx_filename': 'googlenet.onnx', 'onnx_filename': 'googlenet.onnx',
'geir_filename': 'googlenet.geir'
'air_filename': 'googlenet.air'
})

imagenet_cfg = edict({
'name': 'imagenet',
'pre_trained': False,
'num_classes': 1000,
'lr_init': 0.1,
'batch_size': 256,
'epoch_size': 300,
'momentum': 0.9,
'weight_decay': 1e-4,
'buffer_size': None, # not used for imagenet
'image_height': 224,
'image_width': 224,
'data_path': './ImageNet_Original/train/',
'val_data_path': './ImageNet_Original/val/',
'device_target': 'Ascend',
'device_id': 0,
'keep_checkpoint_max': 10,
'checkpoint_path': None,
'onnx_filename': 'googlenet.onnx',
'air_filename': 'googlenet.air',

# optimizer and lr related
'lr_scheduler': 'exponential',
'lr_epochs': [70, 140, 210, 280],
'lr_gamma': 0.3,
'eta_min': 0.0,
'T_max': 150,
'warmup_epochs': 0,

# loss related
'is_dynamic_loss_scale': 0,
'loss_scale': 1024,
'label_smooth_factor': 0.1,
'use_label_smooth': True,
}) })

+ 88
- 11
model_zoo/googlenet/src/dataset.py View File

@@ -21,27 +21,30 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
from src.config import cifar_cfg as cfg
from src.config import cifar_cfg, imagenet_cfg




def create_dataset(data_home, repeat_num=1, training=True):
def create_dataset_cifar10(data_home, repeat_num=1, training=True):
"""Data operations.""" """Data operations."""
ds.config.set_seed(1) ds.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin") data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training: if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin") data_dir = os.path.join(data_home, "cifar-10-verify-bin")


rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
rank_size, rank_id = _get_rank_info()
if training:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)


resize_height = cfg.image_height
resize_width = cfg.image_width
resize_height = cifar_cfg.image_height
resize_width = cifar_cfg.image_width


# define map operations # define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip() random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = vision.HWC2CHW() changeswap_op = vision.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
@@ -49,19 +52,93 @@ def create_dataset(data_home, repeat_num=1, training=True):
c_trans = [] c_trans = []
if training: if training:
c_trans = [random_crop_op, random_horizontal_op] c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, normalize_op, changeswap_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]


# apply map operations on images # apply map operations on images
data_set = data_set.map(input_columns="label", operations=type_cast_op) data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans) data_set = data_set.map(input_columns="image", operations=c_trans)


# apply batch operations
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)

# apply repeat operations # apply repeat operations
data_set = data_set.repeat(repeat_num) data_set = data_set.repeat(repeat_num)


# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
return data_set


def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=None):
"""
Create a train or eval ImageNet2012 dataset for GoogleNet.

Args:
dataset_path (str): path to the dataset directory.
repeat_num (int): number of times the dataset is repeated. Default: 1
training (bool): whether the dataset is used for training or evaluation. Default: True
num_parallel_workers (int): number of workers used to read the data. Default: None
shuffle (bool): whether to shuffle the dataset. Default: None

Returns:
dataset
"""

device_num, rank_id = _get_rank_info()

if device_num == 1:
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
else:
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
num_shards=device_num, shard_id=rank_id)

assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
image_size = imagenet_cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

# define map operations
if training:
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision.RandomHorizontalFlip(prob=0.5),
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = [
vision.Decode(),
vision.Resize(256),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]

transform_label = [C.TypeCast(mstype.int32)]

data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)


# apply batch operations # apply batch operations
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)

# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)


return data_set return data_set


def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))

if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None

return rank_size, rank_id

+ 11
- 12
model_zoo/googlenet/src/googlenet.py View File

@@ -63,7 +63,7 @@ class Inception(nn.Cell):
Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)]) Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1), self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)]) Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same")
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same")
self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1) self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
self.concat = P.Concat(axis=1) self.concat = P.Concat(axis=1)


@@ -71,9 +71,8 @@ class Inception(nn.Cell):
branch1 = self.b1(x) branch1 = self.b1(x)
branch2 = self.b2(x) branch2 = self.b2(x)
branch3 = self.b3(x) branch3 = self.b3(x)
cell, argmax = self.maxpool(x)
cell = self.maxpool(x)
branch4 = self.b4(cell) branch4 = self.b4(cell)
_ = argmax
return self.concat((branch1, branch2, branch3, branch4)) return self.concat((branch1, branch2, branch3, branch4))




@@ -85,22 +84,22 @@ class GoogleNet(nn.Cell):
def __init__(self, num_classes): def __init__(self, num_classes):
super(GoogleNet, self).__init__() super(GoogleNet, self).__init__()
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")


self.conv2 = Conv2dBlock(64, 64, kernel_size=1) self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0) self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")


self.block3a = Inception(192, 64, 96, 128, 16, 32, 32) self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.block3b = Inception(256, 128, 128, 192, 32, 96, 64) self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")


self.block4a = Inception(480, 192, 96, 208, 16, 48, 64) self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.block4b = Inception(512, 160, 112, 224, 24, 64, 64) self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.block4c = Inception(512, 128, 128, 256, 24, 64, 64) self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.block4d = Inception(512, 112, 144, 288, 32, 64, 64) self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.block4e = Inception(528, 256, 160, 320, 32, 128, 128) self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same")
self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")


self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
@@ -113,23 +112,24 @@ class GoogleNet(nn.Cell):




def construct(self, x): def construct(self, x):
"""construct"""
x = self.conv1(x) x = self.conv1(x)
x, argmax = self.maxpool1(x)
x = self.maxpool1(x)


x = self.conv2(x) x = self.conv2(x)
x = self.conv3(x) x = self.conv3(x)
x, argmax = self.maxpool2(x)
x = self.maxpool2(x)


x = self.block3a(x) x = self.block3a(x)
x = self.block3b(x) x = self.block3b(x)
x, argmax = self.maxpool3(x)
x = self.maxpool3(x)


x = self.block4a(x) x = self.block4a(x)
x = self.block4b(x) x = self.block4b(x)
x = self.block4c(x) x = self.block4c(x)
x = self.block4d(x) x = self.block4d(x)
x = self.block4e(x) x = self.block4e(x)
x, argmax = self.maxpool4(x)
x = self.maxpool4(x)


x = self.block5a(x) x = self.block5a(x)
x = self.block5b(x) x = self.block5b(x)
@@ -138,5 +138,4 @@ class GoogleNet(nn.Cell):
x = self.flatten(x) x = self.flatten(x)
x = self.classifier(x) x = self.classifier(x)


_ = argmax
return x return x

+ 0
- 0
model_zoo/googlenet/src/lr_scheduler/__init__.py View File


+ 20
- 0
model_zoo/googlenet/src/lr_scheduler/linear_warmup.py View File

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr
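
As a worked illustration of the warmup formula, the snippet below copies `linear_warmup_lr` into a standalone script and evaluates it for a 5-step warmup. The values (5 steps, target rate 0.1, start rate 0.0) are assumptions chosen for the example, not settings from this commit.

```python
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Linearly interpolate from init_lr to base_lr over warmup_steps."""
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    return float(init_lr) + lr_inc * current_step


# Illustrative values only: 5 warmup steps, ramping from 0.0 up to 0.1.
# The schedulers call the helper with i + 1, so steps run from 1 to warmup_steps.
warmup = [linear_warmup_lr(step, 5, 0.1, 0.0) for step in range(1, 6)]
print(warmup)
```

The helper divides by `warmup_steps`, so it must not be called with `warmup_steps=0`. The schedulers in this commit only call it when `i < warmup_steps`, which cannot hold when there is no warmup.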

+ 39
- 0
model_zoo/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """ warmup cosine annealing lr"""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)
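
The post-warmup branch can be checked by hand with a small standalone sketch. The helper name `cosine_lr` is hypothetical; the arguments (`base_lr=0.1`, `T_max=150`, `eta_min=0.0`) are taken from `imagenet_cfg` in this commit, and the sampled epochs are arbitrary.

```python
import math


def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Per-epoch cosine annealing, same formula as the else-branch above."""
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max)) / 2


# Sample a few epochs of a 300-epoch run with T_max=150.
samples = {epoch: cosine_lr(epoch, 0.1, 150) for epoch in (0, 75, 150, 225, 299)}
for epoch, value in samples.items():
    print(epoch, round(value, 5))
```

Because the formula is not clamped at `T_max`, a run longer than `T_max` epochs sees the rate rise again after reaching `eta_min`. With `epoch_size=300` and `T_max=150`, the rate would return to nearly its initial value by the last epoch if `cosine_annealing` were selected instead of the default `exponential` scheduler.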

+ 59
- 0
model_zoo/googlenet/src/lr_scheduler/warmup_step_lr.py View File

@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr


def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """warmup step lr"""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma ** milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """lr"""
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)


def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """lr"""
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
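
The milestone logic can be reproduced without NumPy or warmup in a few lines. This is a simplified sketch, not the committed function: `step_schedule` is a hypothetical name, `steps_per_epoch=2` is an arbitrary small value for readability, and the remaining arguments (`lr_init=0.1`, `lr_epochs=[70, 140, 210, 280]`, `lr_gamma=0.3`, `epoch_size=300`) come from `imagenet_cfg`.

```python
from collections import Counter


def step_schedule(base_lr, lr_epochs, steps_per_epoch, max_epoch, gamma):
    """Multiply the rate by gamma at the first step of each milestone epoch."""
    # Counter returns 0 for non-milestone steps, so gamma ** 0 leaves lr unchanged.
    milestone_steps = Counter(epoch * steps_per_epoch for epoch in lr_epochs)
    lr = base_lr
    schedule = []
    for step in range(max_epoch * steps_per_epoch):
        lr = lr * gamma ** milestone_steps[step]
        schedule.append(lr)
    return schedule


steps_per_epoch = 2  # assumed value, only to keep the example small
schedule = step_schedule(0.1, [70, 140, 210, 280], steps_per_epoch, 300, 0.3)
per_epoch = schedule[::steps_per_epoch]
print([round(per_epoch[e], 5) for e in (0, 69, 70, 140, 210, 280, 299)])
```

Using a `Counter` rather than a set means that a milestone listed twice would apply `gamma` twice at that step.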

+ 137
- 22
model_zoo/googlenet/train.py View File

@@ -25,21 +25,23 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model
from mindspore import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net


from src.config import cifar_cfg as cfg
from src.dataset import create_dataset
from src.config import cifar_cfg, imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.googlenet import GoogleNet from src.googlenet import GoogleNet


random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)




def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
"""Set learning rate.""" """Set learning rate."""
lr_each_step = [] lr_each_step = []
total_steps = steps_per_epoch * total_epochs total_steps = steps_per_epoch * total_epochs
@@ -60,25 +62,79 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
return learning_rate return learning_rate




def lr_steps_imagenet(_cfg, steps_per_epoch):
"""lr step for imagenet"""
from src.lr_scheduler.warmup_step_lr import warmup_step_lr
from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
if _cfg.lr_scheduler == 'exponential':
_lr = warmup_step_lr(_cfg.lr_init,
_cfg.lr_epochs,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
gamma=_cfg.lr_gamma,
)
elif _cfg.lr_scheduler == 'cosine_annealing':
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
_cfg.T_max,
_cfg.eta_min)
else:
raise NotImplementedError(_cfg.lr_scheduler)

return _lr


if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Cifar10 classification')
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args() args_opt = parser.parse_args()


context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
if args_opt.dataset_name == "cifar10":
cfg = cifar_cfg
elif args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
else: else:
context.set_context(device_id=cfg.device_id)
raise ValueError("Unsupported dataset.")

# set context
device_target = cfg.device_target


context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
device_num = int(os.environ.get("DEVICE_NUM", 1)) device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)

if device_target == "Ascend":
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)

if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
init()
elif device_target == "GPU":
init() init()


dataset = create_dataset(cfg.data_path, cfg.epoch_size)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
else:
raise ValueError("Unsupported platform.")

if args_opt.dataset_name == "cifar10":
dataset = create_dataset_cifar10(cfg.data_path, cfg.epoch_size)
elif args_opt.dataset_name == "imagenet":
dataset = create_dataset_imagenet(cfg.data_path, cfg.epoch_size)
else:
raise ValueError("Unsupported dataset.")

batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()


net = GoogleNet(num_classes=cfg.num_classes) net = GoogleNet(num_classes=cfg.num_classes)
@@ -86,16 +142,75 @@ if __name__ == '__main__':
if cfg.pre_trained: if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path) param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

loss_scale_manager = None
if args_opt.dataset_name == 'cifar10':
lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)

elif args_opt.dataset_name == 'imagenet':
lr = lr_steps_imagenet(cfg, batch_num)


def get_param_groups(network):
""" get param groups """
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# biases are excluded from weight decay
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# BN scale (gamma) is excluded from weight decay
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# BN shift (beta) is excluded from weight decay
# print('no decay:{}'.format(parameter_name))
no_decay_params.append(x)
else:
decay_params.append(x)

return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]


if cfg.is_dynamic_loss_scale:
cfg.loss_scale = 1

opt = Momentum(params=get_param_groups(net),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)

if device_target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
ckpt_save_dir = "./"
else: # GPU
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"


config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num) time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor() loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success") print("train success")
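
The grouping done by `get_param_groups` in the ImageNet branch can be tried without MindSpore. In this sketch, `Param` is a hypothetical stand-in for a MindSpore parameter (only its `name` attribute is modeled), `split_by_decay` is an illustrative name, and the parameter names are invented for the example.

```python
from collections import namedtuple

# Hypothetical stand-in: a real MindSpore Parameter carries data as well.
Param = namedtuple("Param", ["name"])


def split_by_decay(params):
    """Put biases and BN gamma/beta in a group with weight_decay=0.0."""
    no_decay_suffixes = (".bias", ".gamma", ".beta")
    decay_params, no_decay_params = [], []
    for param in params:
        if param.name.endswith(no_decay_suffixes):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    # Same return shape as get_param_groups: the second group inherits
    # the optimizer-level weight_decay.
    return [{"params": no_decay_params, "weight_decay": 0.0}, {"params": decay_params}]


# Invented parameter names, for illustration only.
params = [Param("conv1.conv.weight"), Param("conv1.conv.bias"),
          Param("conv1.bn.gamma"), Param("conv1.bn.beta"), Param("classifier.weight")]
groups = split_by_decay(params)
print([p.name for p in groups[0]["params"]])
print([p.name for p in groups[1]["params"]])
```

Passing a tuple to `str.endswith` collapses the three `elif` branches into one test while keeping the same grouping.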
