@@ -0,0 +1,126 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Analyse result of ocr evaluation."""
import os
import sys
import json
from collections import defaultdict
from io import BytesIO

import lmdb
from PIL import Image

from cnn_ctc.src.model_utils.config import config


def analyse_adv_iii5t_3000(lmdb_path):
    """Analyse result of ocr evaluation."""
    env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
    if not env:
        print('cannot create lmdb from %s' % (lmdb_path))
        sys.exit(0)

    with env.begin(write=False) as txn:
        n_samples = int(txn.get('num-samples'.encode()))
        print(n_samples)
        n_samples = n_samples // config.TEST_BATCH_SIZE * config.TEST_BATCH_SIZE
        result = defaultdict(dict)
        wrong_count = 0
        adv_wrong_count = 0
        ori_correct_adv_wrong_count = 0
        ori_wrong_adv_wrong_count = 0
        if not os.path.exists(os.path.join(lmdb_path, 'adv_wrong_pred')):
            os.mkdir(os.path.join(lmdb_path, 'adv_wrong_pred'))
        if not os.path.exists(os.path.join(lmdb_path, 'ori_correct_adv_wrong_pred')):
            os.mkdir(os.path.join(lmdb_path, 'ori_correct_adv_wrong_pred'))
        if not os.path.exists(os.path.join(lmdb_path, 'ori_wrong_adv_wrong_pred')):
            os.mkdir(os.path.join(lmdb_path, 'ori_wrong_adv_wrong_pred'))
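        # Per-sample records in the lmdb (index is 1-based): 'label-%09d' holds the
        # ground truth, 'pred-%09d' the prediction on the clean image, 'adv_pred-%09d'
        # the prediction on the adversarial image, 'adv_info-%09d' a JSON list whose
        # entries start with the attack method name, and 'adv_image-%09d' the image bytes.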
        for index in range(n_samples):
            index += 1  # lmdb starts with 1
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8').lower()
            pred_key = 'pred-%09d'.encode() % index
            pred = txn.get(pred_key).decode('utf-8')
            if pred != label:
                wrong_count += 1
            adv_pred_key = 'adv_pred-%09d'.encode() % index
            adv_pred = txn.get(adv_pred_key).decode('utf-8')
            adv_info_key = 'adv_info-%09d'.encode() % index
            adv_info = json.loads(txn.get(adv_info_key).decode('utf-8'))
            for info in adv_info:
                if not result[info[0]]:
                    result[info[0]] = defaultdict(int)
                result[info[0]]['count'] += 1
            if adv_pred != label:
                adv_wrong_count += 1
                for info in adv_info:
                    result[info[0]]['wrong_count'] += 1
                # save wrong predicted image
                adv_image = 'adv_image-%09d'.encode() % index
                imgbuf = txn.get(adv_image)
                image = Image.open(BytesIO(imgbuf))
                result_path = os.path.join(lmdb_path, 'adv_wrong_pred', adv_info[0][0])
                if not os.path.exists(result_path):
                    os.mkdir(result_path)
                image.save(os.path.join(result_path, label + '-' + adv_pred + '.png'))
                # origin image is correctly predicted and adv is wrong.
                if pred == label:
                    ori_correct_adv_wrong_count += 1
                    result[info[0]]['ori_correct_adv_wrong_count'] += 1
                    result_path = os.path.join(lmdb_path, 'ori_correct_adv_wrong_pred', adv_info[0][0])
                    if not os.path.exists(result_path):
                        os.mkdir(result_path)
                    image.save(os.path.join(result_path, label + '-' + adv_pred + '.png'))
                # wrong predicted in both origin and adv image.
                else:
                    ori_wrong_adv_wrong_count += 1
                    result[info[0]]['ori_wrong_adv_wrong_count'] += 1
                    result_path = os.path.join(lmdb_path, 'ori_wrong_adv_wrong_pred', adv_info[0][0])
                    if not os.path.exists(result_path):
                        os.mkdir(result_path)
                    image.save(os.path.join(result_path, label + '-' + adv_pred + '.png'))
        print('Number of samples in analyse dataset: ', n_samples)
        print('Accuracy of original dataset: ', 1 - wrong_count / n_samples)
        print('Accuracy of adversarial dataset: ', 1 - adv_wrong_count / n_samples)
        print('Number of samples correctly predicted in original dataset but wrong in adversarial dataset: ',
              ori_correct_adv_wrong_count)
        print('Number of samples both wrong predicted in original and adversarial dataset: ', ori_wrong_adv_wrong_count)
        print('------------------------------------------------------------------------------')
        for key in result.keys():
            print('Method ', key)
            print('Number of perturb samples: {} '.format(result[key]['count']))
            print('Number of wrong predicted: {}'.format(result[key]['wrong_count']))
            print('Number of correctly predicted in origin dataset but wrong in adversarial: {}'.format(
                result[key]['ori_correct_adv_wrong_count']))
            print('Number of both wrong predicted in origin and adversarial dataset: {}'.format(
                result[key]['ori_wrong_adv_wrong_count']))
            print('------------------------------------------------------------------------------')
        return result


if __name__ == '__main__':
    lmdb_data_path = config.ADV_TEST_DATASET_PATH
    analyse_adv_iii5t_3000(lmdb_path=lmdb_data_path)

@@ -0,0 +1,591 @@
# Contents

- [CNNCTC Description](#CNNCTC-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CNNCTC Description](#contents)

This paper makes three major contributions to scene text recognition (STR).
First, we examine the inconsistencies between training and evaluation datasets, and the performance gap that results from these inconsistencies.
Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously unexplored module combinations.
Third, we analyze the module-wise contributions to performance in terms of accuracy, speed, and memory demand, under one consistent set of training and evaluation datasets.
These analyses remove the obstacles to current comparisons and help clarify the performance gains of the existing modules.

[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
# [Model Architecture](#contents)

This is an example of training a CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
# [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. The following sections introduce how to run the scripts using the related datasets below.

The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.

- step 1:

All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).

- step 2:

Uncompress the downloaded file, and rename the MJSynth dataset as MJ, the SynthText dataset as ST and the IIIT dataset as IIIT.

- step 3:

Move the three datasets above into the `cnnctc_data` folder; the structure should be as below:

```text
|--- CNNCTC/
    |--- cnnctc_data/
        |--- ST/
            data.mdb
            lock.mdb
        |--- MJ/
            data.mdb
            lock.mdb
        |--- IIIT/
            data.mdb
            lock.mdb
    ......
```

- step 4:

Preprocess the dataset by running:

```bash
python src/preprocess_dataset.py
```

This takes around 75 minutes.
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/docs/programming_guide/en/master/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training accelerates the computation process, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend automatically handles it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for "reduce precision".
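
As a minimal, hedged sketch (the tiny placeholder network, loss, and optimizer below are illustrative only, not part of this repo), mixed precision can be enabled in MindSpore through the `amp_level` argument of `Model`, which is also how the continue-training example later in this README configures it:

```python
# Minimal mixed-precision sketch. The dense network is a placeholder, not CNNCTC.
import mindspore.nn as nn
from mindspore import Model

net = nn.Dense(16, 10)  # placeholder network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# amp_level="O2" casts the network to FP16 while keeping sensitive computations
# in FP32; keep_batchnorm_fp32=False additionally runs batch norm in FP16.
model = Model(net, loss_fn=loss, optimizer=opt,
              amp_level="O2", keep_batchnorm_fp32=False)
```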
# [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
    - Prepare hardware environment with Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
# [Quick Start](#contents)

- Install dependencies:

```bash
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```

- Configure the dataset paths in `default_config.yaml`, modifying the parameters according to your actual paths:

```default_config.yaml
TRAIN_DATASET_PATH: /home/DataSet/MJ-ST-IIIT/ST-MJ/
TRAIN_DATASET_INDEX_PATH: /home/DataSet/MJ-ST-IIIT/st_mj_fixed_length_index_list.pkl
TEST_DATASET_PATH: /home/DataSet/MJ-ST-IIIT/IIIT5K_3000
```

- Standalone Ascend Training:

```bash
bash scripts/run_standalone_train_ascend.sh $DEVICE_ID $PRETRAINED_CKPT(optional)
# example: bash scripts/run_standalone_train_ascend.sh 0
```

- Standalone GPU Training:

```bash
bash scripts/run_standalone_train_gpu.sh $PRETRAINED_CKPT(optional)
```

- Distributed Ascend Training:

```bash
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT(optional)
# example: bash scripts/run_distribute_train_ascend.sh ~/hccl_8p.json
```

- Distributed GPU Training:

```bash
bash scripts/run_distribute_train_gpu.sh $PRETRAINED_CKPT(optional)
```

- Ascend Evaluation:

```bash
bash scripts/run_eval_ascend.sh $DEVICE_ID $TRAINED_CKPT
# example: bash scripts/run_eval_ascend.sh 0 /home/model/cnnctc/ckpt/CNNCTC-1_8000.ckpt
```

- GPU Evaluation:

```bash
bash scripts/run_eval_gpu.sh $TRAINED_CKPT
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

The entire code structure is as follows:

```text
|--- CNNCTC/
    |---README.md                           // descriptions about cnnctc
    |---README_cn.md                        // descriptions about cnnctc in Chinese
    |---default_config.yaml                 // config file
    |---train.py                            // train scripts
    |---eval.py                             // eval scripts
    |---export.py                           // export scripts
    |---preprocess.py                       // preprocess scripts
    |---postprocess.py                      // postprocess scripts
    |---ascend310_infer                     // application for 310 inference
    |---scripts
        |---run_infer_310.sh                // shell script for infer on ascend310
        |---run_standalone_train_ascend.sh  // shell script for standalone on ascend
        |---run_standalone_train_gpu.sh     // shell script for standalone on gpu
        |---run_distribute_train_ascend.sh  // shell script for distributed on ascend
        |---run_distribute_train_gpu.sh     // shell script for distributed on gpu
        |---run_eval_ascend.sh              // shell script for eval on ascend
    |---src
        |---__init__.py                     // init file
        |---cnn_ctc.py                      // cnn_ctc network
        |---callback.py                     // loss callback file
        |---dataset.py                      // process dataset
        |---util.py                         // routine operation
        |---preprocess_dataset.py           // preprocess dataset
        |---model_utils
            |---config.py                   // parameter config
            |---moxing_adapter.py           // modelarts device configuration
            |---device_adapter.py           // device config
            |---local_adapter.py            // local device config
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `default_config.yaml`; a short sketch of how the scripts consume these values follows this list.

Arguments:

- `--CHARACTER`: Character labels.
- `--NUM_CLASS`: The number of classes, including all character labels and the `<blank>` label for CTCLoss.
- `--HIDDEN_SIZE`: Model hidden size.
- `--FINAL_FEATURE_WIDTH`: The number of features.
- `--IMG_H`: The height of the input image.
- `--IMG_W`: The width of the input image.
- `--TRAIN_DATASET_PATH`: The path to the training dataset.
- `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the sample order.
- `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must ensure the input data has a fixed shape.
- `--TRAIN_DATASET_SIZE`: Training dataset size.
- `--TEST_DATASET_PATH`: The path to the test dataset.
- `--TEST_BATCH_SIZE`: Test batch size.
- `--TRAIN_EPOCHS`: Total training epochs.
- `--CKPT_PATH`: The path to a model checkpoint file; can be used to resume training and for evaluation.
- `--SAVE_PATH`: The path for saving model checkpoint files.
- `--LR`: Learning rate for standalone training.
- `--LR_PARA`: Learning rate for distributed training.
- `--MOMENTUM`: Momentum.
- `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
- `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
- `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
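
The sketch below is an assumption based on how `eval.py` and `export.py` in this repository read these values: `src/model_utils/config.py` exposes the YAML entries as attributes of a shared `config` object.

```python
# Hedged sketch: default_config.yaml entries surface as attributes of the shared
# config object, mirroring the usage in eval.py and export.py of this repository.
from src.cnn_ctc import CNNCTC
from src.model_utils.config import config

# Network hyper-parameters come straight from the YAML file.
net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
print(config.TRAIN_BATCH_SIZE, config.TEST_BATCH_SIZE, config.TRAIN_EPOCHS)
```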
## [Training Process](#contents)

### Training

- Standalone Ascend Training:

```bash
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_standalone_train_ascend.sh 0
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

- Distributed Ascend Training:

```bash
bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_distribute_train_ascend.sh ~/hccl_8p.json
```

For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>.

Results and checkpoints are written to the `./train_parallel_{i}` folder for each device `i` respectively.
The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.

`$RANK_TABLE_FILE` is needed when you run a distributed task on Ascend.
`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.
### Training Result

Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find checkpoint files there, together with results like the following in `loss.log`.

```text
# distribute training result(8p)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.235177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.25798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.229678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.23512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.23149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.2292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.2184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.2184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.21847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.21847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.2184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.21847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.21847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.2184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.21847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.2184715309307257
```
- Running on ModelArts: if you want to train the model on ModelArts, you can refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/).

```python
# Example of distributed training on ModelArts:
# Dataset storage layout
# ├── CNNCTC_Data              # dataset dir
#     ├── train                # train dir
#         ├── ST_MJ            # train dataset dir
#             ├── data.mdb     # data file
#             ├── lock.mdb
#         ├── st_mj_fixed_length_index_list.pkl
#     ├── eval                 # eval dir
#         ├── IIIT5K_3000      # eval dataset dir
#         ├── checkpoint       # checkpoint dir
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify parameters there).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_DATASET_PATH=/cache/data/ST_MJ/"
#        set "TRAIN_DATASET_INDEX_PATH=/cache/data/st_mj_fixed_length_index_list.pkl"
#        set "SAVE_PATH=/cache/train/checkpoint"
#
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (3) Set the code path "/path/cnnctc" on the ModelArts interface.
# (4) Set the model's startup file "train.py" on the ModelArts interface.
# (5) Set the model's data path ".../CNNCTC_Data/train" (choose the CNNCTC_Data/train folder path),
#     the model's output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
# (6) Start training the model.

# Example of model inference on ModelArts
# (1) Upload the trained model to the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_DATASET_PATH=/cache/data/IIIT5K_3000/"
#        set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (4) Set the code path "/path/cnnctc" on the ModelArts interface.
# (5) Set the model's startup file "eval.py" on the ModelArts interface.
# (6) Set the model's data path "../CNNCTC_Data/eval" (choose the CNNCTC_Data/eval folder path),
#     the model's output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
# (7) Start model inference.
```
- Standalone GPU Training:

```bash
bash scripts/run_standalone_train_gpu.sh [PRETRAINED_CKPT(optional)]
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

- Distributed GPU Training:

```bash
bash scripts/run_distribute_train_gpu.sh [PRETRAINED_CKPT(optional)]
```

Results and checkpoints are written to the `./train_parallel` folder, with model checkpoints in `ckpt_{i}` directories.
The log can be found in `./train_parallel/log` and loss values are recorded in `./train_parallel/loss.log`.
## [Evaluation Process](#contents)

### Evaluation

- Ascend Evaluation:

```bash
bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
# example: scripts/run_eval_ascend.sh 0 /home/model/cnnctc/ckpt/CNNCTC-1_8000.ckpt
```

The model will be evaluated on the IIIT dataset; sample results and the overall accuracy will be printed.

- GPU Evaluation:

```bash
bash scripts/run_eval_gpu.sh [TRAINED_CKPT]
```
## [Inference Process](#contents)

### Export MindIR

```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT] --TEST_BATCH_SIZE [BATCH_SIZE]
```

The `ckpt_file` parameter is required.
`EXPORT_FORMAT` must be one of ["AIR", "MINDIR"].
`BATCH_SIZE` can currently only be set to 1.
- Export MindIR on ModelArts

```Modelarts
Export MindIR example on ModelArts
The data storage layout is the same as for training.
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify parameters there).
#     a. set "enable_modelarts=True"
#        set "file_name=cnnctc"
#        set "file_format=MINDIR"
#        set "ckpt_file=/cache/data/checkpoint file name"
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (3) Set the code path "/path/cnnctc" on the ModelArts interface.
# (4) Set the model's startup file "export.py" on the ModelArts interface.
# (5) Set the model's data path ".../CNNCTC_Data/eval/checkpoint" (choose the CNNCTC_Data/eval/checkpoint folder path),
#     the MindIR output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
```
### Infer on Ascend310

Before performing inference, the MindIR file must be exported by the `export.py` script. We only provide an example of inference using a MINDIR model.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```

- `DVPP` is mandatory and must be chosen from ["DVPP", "CPU"]; it is case-insensitive. Note that CNNCTC only supports CPU mode.
- `DEVICE_ID` is optional, with default value 0.
### Result

- Ascend Result

The inference result is saved in the current path; you can find results like the following in the `acc.log` file.

```bash
'Accuracy': 0.8642
```

- GPU Result

The inference result is saved in `./eval/log`; you can find results like the following.

```bash
accuracy: 0.8533
```
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS Euler2.8 |
| Uploaded Date              | 09/28/2020 (month/day/year)                                  |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | MJSynth, SynthText                                           |
| Training Parameters        | epoch=3, batch_size=192                                      |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | CTCLoss                                                      |
| Speed                      | 1pc: 250 ms/step; 8pcs: 260 ms/step                          |
| Total time                 | 1pc: 15 hours; 8pcs: 1.92 hours                              |
| Parameters (M)             | 177                                                          |
| Scripts                    | <https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc> |

| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | GPU (Tesla V100-PCIE); CPU 2.60 GHz, 26 cores; Memory 790G; OS linux-gnu |
| Uploaded Date              | 07/06/2021 (month/day/year)                                  |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | MJSynth, SynthText                                           |
| Training Parameters        | epoch=3, batch_size=192                                      |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | CTCLoss                                                      |
| Speed                      | 1pc: 1180 ms/step; 8pcs: 1180 ms/step                        |
| Total time                 | 1pc: 62.9 hours; 8pcs: 8.67 hours                            |
| Parameters (M)             | 177                                                          |
| Scripts                    | <https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc> |
### Evaluation Performance

| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910; OS Euler2.8     |
| Uploaded Date       | 09/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 192                         |
| Outputs             | Accuracy                    |
| Accuracy            | 85%                         |
| Model for inference | 675M (.ckpt file)           |

### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | CNNCTC                      |
| Resource            | Ascend 310; OS CentOS 3.10  |
| Uploaded Date       | 05/19/2021 (month/day/year) |
| MindSpore Version   | 1.2.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 1                           |
| Outputs             | Accuracy                    |
| Accuracy            | Accuracy=0.8642             |
| Model for inference | 675M (.ckpt file)           |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/docs/programming_guide/en/master/multi_platform_inference.html). The following is a simple example of the steps to follow:

- Running on Ascend

```python
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model

- Running on Ascend

```python
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)

# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_cnnctc", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/models).

@@ -0,0 +1,523 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [CNN+CTC Description](#cnnctc-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Training Result](#training-result)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [Result](#result)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [Usage](#usage)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# CNN+CTC Description

This paper describes three major contributions to scene text recognition (STR).
First, it examines the inconsistencies between training and evaluation datasets and the performance gap these cause.
Second, it introduces a unified four-stage STR framework into which most existing STR models fit.
Using this framework allows an extensive evaluation of previously proposed STR modules and the discovery of previously unexplored module combinations.
Third, it analyzes the module-wise contributions to performance, including accuracy, speed, and memory demand, under a consistent set of training and evaluation datasets.
These analyses remove the obstacles to current comparisons and help in understanding the performance gains of existing modules.

[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
# Model Architecture

Example: training a CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
# Dataset

The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.

- Step 1:

All the datasets have been preprocessed and stored in .lmdb format; click [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt) to download them.

- Step 2:

Uncompress the downloaded file, and rename the MJSynth dataset as MJ, the SynthText dataset as ST and the IIIT dataset as IIIT.

- Step 3:

Move the three datasets above into the `cnnctc_data` folder; the structure is as follows:

```text
|--- CNNCTC/
    |--- cnnctc_data/
        |--- ST/
            data.mdb
            lock.mdb
        |--- MJ/
            data.mdb
            lock.mdb
        |--- IIIT/
            data.mdb
            lock.mdb
    ......
```

- Step 4:

Preprocess the dataset:

```shell
python src/preprocess_dataset.py
```

This takes about 75 minutes.
# Features

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html) training method uses both single-precision and half-precision data to improve the training speed of deep neural networks while maintaining the accuracy achievable with single-precision training. Mixed precision training increases computing speed and reduces memory usage, while supporting training larger models or using larger batch sizes on specific hardware.

Taking FP16 operators as an example, if the input data type is FP32, the MindSpore backend automatically reduces precision to process the data. Users can enable the INFO log and search for "reduce precision" to view operators whose precision has been reduced.
# Environment Requirements

- Hardware (Ascend)
    - Prepare hardware environment with Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start

- Install dependencies:

```shell
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```

- Configure the dataset paths in `default_config.yaml`, modifying the parameters according to your actual paths:

```default_config.yaml
TRAIN_DATASET_PATH: /home/DataSet/MJ-ST-IIIT/ST-MJ/
TRAIN_DATASET_INDEX_PATH: /home/DataSet/MJ-ST-IIIT/st_mj_fixed_length_index_list.pkl
TEST_DATASET_PATH: /home/DataSet/MJ-ST-IIIT/IIIT5K_3000
```

- Standalone training:

```shell
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_standalone_train_ascend.sh 0
```

- Distributed training:

```shell
bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_distribute_train_ascend.sh ~/hccl_8p.json
```

- Evaluation:

```shell
bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
# example: scripts/run_eval_ascend.sh 0 /home/model/cnnctc/ckpt/CNNCTC-1_8000.ckpt
```
# Script Description

## Script and Sample Code

The complete code structure is as follows:

```text
|--- CNNCTC/
    |---README_CN.md           // descriptions about CNN+CTC (Chinese)
    |---README.md              // descriptions about CNN+CTC
    |---train.py               // training script
    |---eval.py                // evaluation script
    |---export.py              // model export script
    |---postprocess.py         // inference postprocessing script
    |---preprocess.py          // inference preprocessing script
    |---ascend310_infer        // application for Ascend310 inference
    |---default_config.yaml    // parameter configuration
    |---scripts
        |---run_standalone_train_ascend.sh  // standalone training shell script for Ascend
        |---run_distribute_train_ascend.sh  // distributed training shell script for Ascend
        |---run_eval_ascend.sh              // evaluation shell script for Ascend
        |---run_infer_310.sh                // shell script for Ascend310 inference
    |---src
        |---__init__.py            // init file
        |---cnn_ctc.py             // cnn_ctc network
        |---callback.py            // loss callback file
        |---dataset.py             // dataset processing
        |---util.py                // routine operations
        |---generate_hccn_file.py  // generate the distributed json file
        |---preprocess_dataset.py  // dataset preprocessing
        |---model_utils
            |---config.py            // parameter parsing
            |---device_adapter.py    // device-related information
            |---local_adapter.py     // local device configuration
            |---moxing_adapter.py    // decorator (mainly used for ModelArts data copying)
```
## Script Parameters

Parameters for both training and evaluation can be set in `default_config.yaml`.

Arguments:

- `--CHARACTER`: Character labels.
- `--NUM_CLASS`: The number of classes, including all character labels and the `<blank>` label for CTCLoss.
- `--HIDDEN_SIZE`: Model hidden size.
- `--FINAL_FEATURE_WIDTH`: The number of features.
- `--IMG_H`: The height of the input image.
- `--IMG_W`: The width of the input image.
- `--TRAIN_DATASET_PATH`: The path to the training dataset.
- `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the sample order.
- `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must ensure the input data has a fixed shape.
- `--TRAIN_DATASET_SIZE`: Training dataset size.
- `--TEST_DATASET_PATH`: The path to the test dataset.
- `--TEST_BATCH_SIZE`: Test batch size.
- `--TRAIN_EPOCHS`: Total training epochs.
- `--CKPT_PATH`: The path to a model checkpoint file; can be used to resume training and for evaluation.
- `--SAVE_PATH`: The path for saving model checkpoint files.
- `--LR`: Learning rate for standalone training.
- `--LR_PARA`: Learning rate for distributed training.
- `--MOMENTUM`: Momentum.
- `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
- `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
- `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
## Training Process

### Training

- Standalone training:

```shell
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_standalone_train_ascend.sh 0
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

- Distributed training:

```shell
bash scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_CKPT(optional)]
# example: bash scripts/run_distribute_train_ascend.sh ~/hccl_8p.json
```

Results and checkpoints are written to the `./train_parallel_{i}` folder for each device `i` respectively.
The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.

`$RANK_TABLE_FILE` is needed when running a distributed task on Ascend.
`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

> Note:
> For reference material about RANK_TABLE_FILE, see this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/distributed_training_ascend.html); for how to obtain the device_ip, see this [link](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools).
### Training Result

Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find checkpoint files together with results like the following in the logs under this path.

```text
# distributed training result (8P)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
```
## Evaluation Process

### Evaluation

- Evaluation:

```shell
bash scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]
# example: scripts/run_eval_ascend.sh 0 /home/model/cnnctc/ckpt/CNNCTC-1_8000.ckpt
```

The model is evaluated on the IIIT dataset, and sample results and the overall accuracy are printed.

- To train and run inference for the model on ModelArts, refer to the ModelArts [official guidance document](https://support.huaweicloud.com/modelarts/) and proceed as follows:

```ModelArts
# Example of distributed training on ModelArts:
# Dataset storage layout
# ├── CNNCTC_Data              # dataset dir
#     ├── train                # train dir
#         ├── ST_MJ            # train dataset dir
#             ├── data.mdb     # data file
#             ├── lock.mdb
#         ├── st_mj_fixed_length_index_list.pkl
#     ├── eval                 # eval dir
#         ├── IIIT5K_3000      # eval dataset dir
#         ├── checkpoint       # checkpoint dir
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify parameters there).
#     a. set "enable_modelarts=True"
#        set "run_distribute=True"
#        set "TRAIN_DATASET_PATH=/cache/data/ST_MJ/"
#        set "TRAIN_DATASET_INDEX_PATH=/cache/data/st_mj_fixed_length_index_list.pkl"
#        set "SAVE_PATH=/cache/train/checkpoint"
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (3) Set the code path "/path/cnnctc" on the ModelArts interface.
# (4) Set the model's startup file "train.py" on the ModelArts interface.
# (5) Set the model's data path ".../CNNCTC_Data/train" (choose the CNNCTC_Data/train folder path),
#     the model's output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
# (6) Start training the model.

# Example of model inference on ModelArts
# (1) Upload the trained model to the corresponding location in the bucket.
# (2) Choose either a or b.
#     a. set "enable_modelarts=True"
#        set "TEST_DATASET_PATH=/cache/data/IIIT5K_3000/"
#        set "CHECKPOINT_PATH=/cache/data/checkpoint/checkpoint file name"
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (4) Set the code path "/path/cnnctc" on the ModelArts interface.
# (5) Set the model's startup file "eval.py" on the ModelArts interface.
# (6) Set the model's data path "../CNNCTC_Data/eval" (choose the CNNCTC_Data/eval folder path),
#     the model's output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
# (7) Start model inference.
```
## Inference Process

### Export MindIR

```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT] --TEST_BATCH_SIZE [BATCH_SIZE]
```

The `ckpt_file` parameter is required.
`EXPORT_FORMAT` must be one of ["AIR", "MINDIR"].
`BATCH_SIZE` currently only supports inference with a batch size of 1.

- Export MindIR on ModelArts

```Modelarts
Export MindIR example on ModelArts
The data storage layout is the same as for ModelArts training.
# (1) Choose either a (modify the yaml file parameters) or b (create a ModelArts training job and modify parameters there).
#     a. set "enable_modelarts=True"
#        set "file_name=cnnctc"
#        set "file_format=MINDIR"
#        set "ckpt_file=/cache/data/checkpoint file name"
#     b. Add "enable_modelarts=True" on the ModelArts interface.
#        Set the parameters required by method a on the ModelArts interface.
#        Note: path parameters do not need to be quoted.
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/".
# (3) Set the code path "/path/cnnctc" on the ModelArts interface.
# (4) Set the model's startup file "export.py" on the ModelArts interface.
# (5) Set the model's data path ".../CNNCTC_Data/eval/checkpoint" (choose the CNNCTC_Data/eval/checkpoint folder path),
#     the MindIR output path "Output file path" and the model's log path "Job log path" on the ModelArts interface.
```
### Infer on Ascend310

Before performing inference, the MindIR file must be exported by the `export.py` script. The following shows an example of performing inference with a MindIR model.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
```

- `DVPP` is mandatory and must be chosen from ["DVPP", "CPU"]; it is case-insensitive. CNNCTC currently only supports inference with CPU operators.
- `DEVICE_ID` is optional, with default value 0.

### Result

The inference result is saved in the current path of script execution; you can find the following accuracy result in `acc.log`.

```bash
'Accuracy': 0.8642
```
# Model Description

## Performance

### Training Performance

| Parameters          | CNNCTC                                          |
| ------------------- | ----------------------------------------------- |
| Model Version       | V1                                              |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G |
| Uploaded Date       | 2020-09-28                                      |
| MindSpore Version   | 1.0.0                                           |
| Dataset             | MJSynth, SynthText                              |
| Training Parameters | epoch=3, batch_size=192                         |
| Optimizer           | RMSProp                                         |
| Loss Function       | CTCLoss                                         |
| Speed               | 1pc: 300 ms/step; 8pcs: 310 ms/step             |
| Total time          | 1pc: 18 hours; 8pcs: 2.3 hours                  |
| Parameters (M)      | 177                                             |
| Scripts             | <https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc> |

### Evaluation Performance

| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910                  |
| Uploaded Date       | 2020-09-28                  |
| MindSpore Version   | 1.0.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 192                         |
| Outputs             | Accuracy                    |
| Accuracy            | 85%                         |
| Model for inference | 675M (.ckpt file)           |

### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | CNNCTC                      |
| Resource            | Ascend 310; OS CentOS 3.10  |
| Uploaded Date       | 2021-05-19                  |
| MindSpore Version   | 1.2.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 1                           |
| Outputs             | Accuracy                    |
| Accuracy            | Accuracy=0.8642             |
| Model for inference | 675M (.ckpt file)           |
## Usage

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, refer to this [link](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/multi_platform_inference.html). The following is a simple example:

- Running on Ascend

```python
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load the unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load the pretrained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model

- Running on Ascend

```python
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)

# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_cnnctc", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/models).

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""
import time
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from src.util import CTCLabelConverter, AverageMeter
from src.dataset import iiit_generator_batch, adv_iiit_generator_batch
from src.cnn_ctc import CNNCTC
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path=".")


def test_dataset_creator(is_adv=False):
    if is_adv:
        ds = GeneratorDataset(adv_iiit_generator_batch(), ['img', 'label_indices', 'text',
                                                           'sequence_length', 'label_str'])
    else:
        ds = GeneratorDataset(iiit_generator_batch, ['img', 'label_indices', 'text',
                                                     'sequence_length', 'label_str'])
    return ds


@moxing_wrapper(pre_process=None)
def test():
    """Eval cnn-ctc model."""
    target = config.device_target
    context.set_context(device_target=target)

    ds = test_dataset_creator(is_adv=config.IS_ADV)
    net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CHECKPOINT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    model_run_time = AverageMeter()
    npu_to_cpu_time = AverageMeter()
    postprocess_time = AverageMeter()

    count = 0
    correct_count = 0
    for data in ds.create_tuple_iterator():
        img, _, text, _, length = data
        img_tensor = Tensor(img, mstype.float32)

        model_run_begin = time.time()
        model_predict = net(img_tensor)
        model_run_end = time.time()
        model_run_time.update(model_run_end - model_run_begin)

        npu_to_cpu_begin = time.time()
        model_predict = np.squeeze(model_predict.asnumpy())
        npu_to_cpu_end = time.time()
        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)

        postprocess_begin = time.time()
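        # Greedy CTC decoding: take the argmax class at every time step; the label
        # converter then collapses repeated characters and removes the blank token
        # to produce the final strings.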
        preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
        preds_index = np.argmax(model_predict, 2)
        preds_index = np.reshape(preds_index, [-1])
        preds_str = converter.decode(preds_index, preds_size)
        postprocess_end = time.time()
        postprocess_time.update(postprocess_end - postprocess_begin)

        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

        if count == 0:
            model_run_time.reset()
            npu_to_cpu_time.reset()
            postprocess_time.reset()
        else:
            print('---------model run time--------', model_run_time.avg)
            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
            print('---------postprocess run time--------', postprocess_time.avg)
        print("Prediction samples: \n", preds_str[:5])
        print("Ground truth: \n", label_str[:5])
        for pred, label in zip(preds_str, label_str):
            if pred == label:
                correct_count += 1
            count += 1
        print(count)
    print('accuracy: ', correct_count / count)


if __name__ == '__main__':
    test()

@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models
   suggested usage: python export.py --file_name cnnctc --file_format MINDIR --ckpt_file [ckpt file path]
"""
import os
import numpy as np
from mindspore import Tensor, context, load_checkpoint, export
import mindspore.common.dtype as mstype
from src.cnn_ctc import CNNCTC
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
    context.set_context(device_id=config.device_id)


def modelarts_pre_process():
    config.file_name = os.path.join(config.output_path, config.file_name)


@moxing_wrapper(pre_process=modelarts_pre_process)
def model_export():
    """Export model."""
    net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
    load_checkpoint(config.ckpt_file, net=net)

    bs = config.TEST_BATCH_SIZE
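    # Dummy NCHW input: only its shape and dtype matter here, since export()
    # traces the graph with this tensor to fix the exported model's input signature.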
    input_data = Tensor(np.zeros([bs, 3, config.IMG_H, config.IMG_W]), mstype.float32)
    export(net, input_data, file_name=config.file_name, file_format=config.file_format)


if __name__ == '__main__':
    model_export()

@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.cnn_ctc import CNNCTC
from src.config import Config_CNNCTC


def cnnctc_net(*args, **kwargs):
    return CNNCTC(*args, **kwargs)


def create_network(name, *args, **kwargs):
    """
    create cnnctc network
    """
    if name == "cnnctc":
        config = Config_CNNCTC
        return cnnctc_net(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH, *args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")

@@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import numpy as np
from src.model_utils.config import config
from src.util import CTCLabelConverter


def calcul_acc(labels, preds):
    return sum(1 for x, y in zip(labels, preds) if x == y) / len(labels)


def get_result(result_path, label_path): | |||||
"""Get result.""" | |||||
converter = CTCLabelConverter(config.CHARACTER) | |||||
files = os.listdir(result_path) | |||||
preds = [] | |||||
labels = [] | |||||
label_dict = {} | |||||
with open(label_path, 'r') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
label_dict[line.split(',')[0]] = line.split(',')[1].replace('\n', '') | |||||
for file in files: | |||||
file_name = file.split('.')[0] | |||||
label = label_dict[file_name] | |||||
labels.append(label) | |||||
new_result_path = os.path.join(result_path, file) | |||||
output = np.fromfile(new_result_path, dtype=np.float32) | |||||
output = np.reshape(output, (config.FINAL_FEATURE_WIDTH, config.NUM_CLASS)) | |||||
model_predict = np.squeeze(output) | |||||
        preds_size = np.array([model_predict.shape[0]])
preds_index = np.argmax(model_predict, axis=1) | |||||
preds_str = converter.decode(preds_index, preds_size) | |||||
preds.append(preds_str[0]) | |||||
acc = calcul_acc(labels, preds) | |||||
print("Total data: {}, accuracy: {}".format(len(labels), acc)) | |||||
if __name__ == '__main__': | |||||
get_result(config.result_path, config.label_path) |
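# A minimal usage sketch (script name assumed to be postprocess.py): each file
# under result_path is assumed to hold one float32 logit map of shape
# (FINAL_FEATURE_WIDTH, NUM_CLASS) produced by 310 inference, and label_path is
# the comma-separated label.txt written by the preprocess script:
#
#   python postprocess.py --result_path ./result_Files --label_path ./label.txt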
@@ -0,0 +1,96 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""post process for 310 inference""" | |||||
import os | |||||
import sys | |||||
import six | |||||
import lmdb | |||||
from PIL import Image | |||||
from src.model_utils.config import config | |||||
from src.util import CTCLabelConverter | |||||
def get_img_from_lmdb(env_, ind): | |||||
"""Get image_from lmdb.""" | |||||
with env_.begin(write=False) as txn_: | |||||
label_key = 'label-%09d'.encode() % ind | |||||
label_ = txn_.get(label_key).decode('utf-8') | |||||
img_key = 'image-%09d'.encode() % ind | |||||
imgbuf = txn_.get(img_key) | |||||
buf = six.BytesIO() | |||||
buf.write(imgbuf) | |||||
buf.seek(0) | |||||
try: | |||||
img_ = Image.open(buf).convert('RGB') # for color image | |||||
except IOError: | |||||
print(f'Corrupted image for {ind}') | |||||
# make dummy image and dummy label for corrupted image. | |||||
img_ = Image.new('RGB', (config.IMG_W, config.IMG_H)) | |||||
label_ = '[dummy_label]' | |||||
label_ = label_.lower() | |||||
return img_, label_ | |||||
if __name__ == '__main__': | |||||
max_len = int((26 + 1) // 2) | |||||
converter = CTCLabelConverter(config.CHARACTER) | |||||
env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
if not env: | |||||
print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH)) | |||||
sys.exit(0) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
filtered_index_list = [] | |||||
for index_ in range(n_samples): | |||||
index_ += 1 # lmdb starts with 1 | |||||
label_key_ = 'label-%09d'.encode() % index_ | |||||
label = txn.get(label_key_).decode('utf-8') | |||||
if len(label) > max_len: | |||||
continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in config.CHARACTER: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: | |||||
continue | |||||
filtered_index_list.append(index_) | |||||
img_ret = [] | |||||
text_ret = [] | |||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') | |||||
i = 0 | |||||
label_dict = {} | |||||
for index in filtered_index_list: | |||||
img, label = get_img_from_lmdb(env, index) | |||||
img_name = os.path.join(config.preprocess_output, str(i) + ".png") | |||||
img.save(img_name) | |||||
label_dict[str(i)] = label | |||||
i += 1 | |||||
with open('./label.txt', 'w') as file: | |||||
for k, v in label_dict.items(): | |||||
file.write(str(k) + ',' + str(v) + '\n') |
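# A minimal usage sketch (script name assumed to be preprocess.py);
# TEST_DATASET_PATH and preprocess_output are read from default_config.yaml,
# and label.txt is written to the working directory:
#
#   python preprocess.py --TEST_DATASET_PATH /data/IIIT5K --preprocess_output ./preprocess_Result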
@@ -0,0 +1,7 @@ | |||||
lmdb | |||||
tqdm | |||||
six | |||||
numpy | |||||
pillow | |||||
pyyaml | |||||
@@ -0,0 +1,50 @@ | |||||
#!/bin/bash | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
if [ $# -ne 2 ] | |||||
then | |||||
echo "Usage: sh scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]" | |||||
exit 1 | |||||
fi | |||||
get_real_path(){ | |||||
if [ "${1:0:1}" == "/" ]; then | |||||
echo "$1" | |||||
else | |||||
echo "$(realpath -m $PWD/$1)" | |||||
fi | |||||
} | |||||
PATH1=$(get_real_path $2) | |||||
echo $PATH1 | |||||
if [ ! -f $PATH1 ] | |||||
then | |||||
echo "error: TRAINED_CKPT=$PATH1 is not a file" | |||||
exit 1 | |||||
fi | |||||
ulimit -u unlimited | |||||
export DEVICE_ID=$1 | |||||
if [ -d "eval" ]; | |||||
then | |||||
rm -rf ./eval | |||||
fi | |||||
mkdir ./eval | |||||
echo "start inferring for device $DEVICE_ID" | |||||
env > ./eval/env.log | |||||
python eval.py --CHECKPOINT_PATH=$PATH1 &> ./eval/log & | |||||
@@ -0,0 +1,50 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
if [ $# -ne 1 ] | |||||
then | |||||
echo "Usage: sh run_eval_gpu.sh [TRAINED_CKPT]" | |||||
exit 1 | |||||
fi | |||||
get_real_path(){ | |||||
if [ "${1:0:1}" == "/" ]; then | |||||
echo "$1" | |||||
else | |||||
echo "$(realpath -m $PWD/$1)" | |||||
fi | |||||
} | |||||
PATH1=$(get_real_path $1) | |||||
echo $PATH1 | |||||
if [ ! -f $PATH1 ] | |||||
then | |||||
echo "error: TRAINED_CKPT=$PATH1 is not a file" | |||||
exit 1 | |||||
fi | |||||
#ulimit -u unlimited | |||||
export DEVICE_ID=0 | |||||
if [ -d "eval" ]; | |||||
then | |||||
rm -rf ./eval | |||||
fi | |||||
mkdir ./eval | |||||
echo "start inferring for device $DEVICE_ID" | |||||
env > ./eval/env.log | |||||
python eval.py --device_target="GPU" --device_id=$DEVICE_ID --CHECKPOINT_PATH=$PATH1 &> ./eval/log & | |||||
@@ -0,0 +1,49 @@ | |||||
#!/bin/bash | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
if [ $# != 1 ] && [ $# != 2 ] | |||||
then | |||||
echo "run as sh scripts/run_standalone_train_ascend.sh DEVICE_ID PRE_TRAINED(options)" | |||||
exit 1 | |||||
fi | |||||
get_real_path(){ | |||||
if [ "${1:0:1}" == "/" ]; then | |||||
echo "$1" | |||||
else | |||||
echo "$(realpath -m $PWD/$1)" | |||||
fi | |||||
} | |||||
PATH1=$(get_real_path $2) | |||||
export DEVICE_ID=$1 | |||||
ulimit -u unlimited | |||||
if [ -d "train" ]; | |||||
then | |||||
rm -rf ./train | |||||
fi | |||||
mkdir ./train | |||||
echo "start training for device $DEVICE_ID" | |||||
env > env.log | |||||
if [ -f $PATH1 ] | |||||
then | |||||
python train.py --PRED_TRAINED=$PATH1 --run_distribute=False &> log & | |||||
else | |||||
python train.py --run_distribute=False &> log & | |||||
fi | |||||
cd .. || exit |
@@ -0,0 +1,42 @@ | |||||
#!/bin/bash | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
get_real_path(){ | |||||
if [ "${1:0:1}" == "/" ]; then | |||||
echo "$1" | |||||
else | |||||
echo "$(realpath -m $PWD/$1)" | |||||
fi | |||||
} | |||||
PATH1=$(get_real_path $1) | |||||
echo $PATH1 | |||||
export DEVICE_NUM=1 | |||||
export RANK_SIZE=1 | |||||
if [ -d "train" ]; | |||||
then | |||||
rm -rf ./train | |||||
fi | |||||
mkdir ./train | |||||
env > ./train/env.log | |||||
if [ -f $PATH1 ] | |||||
then | |||||
python train.py --device_target="GPU" --PRED_TRAINED=$PATH1 --run_distribute=False &> log & | |||||
else | |||||
python train.py --device_target="GPU" --run_distribute=False &> ./train/log & | |||||
fi | |||||
@@ -0,0 +1,15 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""src init file""" |
@@ -0,0 +1,73 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""loss callback""" | |||||
import time | |||||
import numpy as np | |||||
from mindspore.train.callback import Callback | |||||
from .util import AverageMeter | |||||
class LossCallBack(Callback): | |||||
""" | |||||
Monitor the loss in training. | |||||
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, the loss is not printed.
    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
""" | |||||
def __init__(self, per_print_times=1): | |||||
super(LossCallBack, self).__init__() | |||||
if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
raise ValueError("print_step must be int and >= 0.") | |||||
self._per_print_times = per_print_times | |||||
self.loss_avg = AverageMeter() | |||||
self.timer = AverageMeter() | |||||
self.start_time = time.time() | |||||
def step_end(self, run_context): | |||||
"""step end.""" | |||||
cb_params = run_context.original_args() | |||||
loss = np.array(cb_params.net_outputs) | |||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
cur_num = cb_params.cur_step_num | |||||
if cur_step_in_epoch % 2000 == 1: | |||||
self.loss_avg = AverageMeter() | |||||
self.timer = AverageMeter() | |||||
self.start_time = time.time() | |||||
else: | |||||
self.timer.update(time.time() - self.start_time) | |||||
self.start_time = time.time() | |||||
self.loss_avg.update(loss) | |||||
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            log_line = "epoch: %s step: %s, loss is %s, average time per step is %s" % (
                cb_params.cur_epoch_num, cur_step_in_epoch,
                self.loss_avg.avg, self.timer.avg)
            with open("./loss.log", "a+") as loss_file:
                loss_file.write(log_line + "\n")
            print(log_line)
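# A minimal usage sketch; `model` and `dataset` are hypothetical stand-ins for
# the objects built in train.py. The callback appends to ./loss.log every
# per_print_times steps:
#
#   from mindspore import Model
#   loss_cb = LossCallBack(per_print_times=100)
#   model.train(config.TRAIN_EPOCHS, dataset, callbacks=[loss_cb])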
@@ -0,0 +1,389 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""cnn_ctc network define""" | |||||
import mindspore.common.dtype as mstype | |||||
import mindspore.nn as nn | |||||
from mindspore import Tensor, Parameter, ParameterTuple, context | |||||
from mindspore.common.initializer import TruncatedNormal, initializer | |||||
from mindspore.communication.management import get_group_size | |||||
from mindspore.context import ParallelMode | |||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
from mindspore.ops import composite as C | |||||
from mindspore.ops import functional as F | |||||
from mindspore.ops import operations as P | |||||
grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
reciprocal = P.Reciprocal() | |||||
@grad_scale.register("Tensor", "Tensor") | |||||
def tensor_grad_scale(scale, grad): | |||||
return grad * F.cast(reciprocal(scale), F.dtype(grad)) | |||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||||
grad_overflow = P.FloatStatus() | |||||
@_grad_overflow.register("Tensor") | |||||
def _tensor_grad_overflow(grad): | |||||
return grad_overflow(grad) | |||||
GRADIENT_CLIP_MIN = -64000 | |||||
GRADIENT_CLIP_MAX = 64000 | |||||
class ClipGradients(nn.Cell): | |||||
""" | |||||
Clip large gradients, typically generated from overflow. | |||||
""" | |||||
def __init__(self): | |||||
super(ClipGradients, self).__init__() | |||||
self.clip_by_norm = nn.ClipByNorm() | |||||
self.cast = P.Cast() | |||||
self.dtype = P.DType() | |||||
def construct(self, grads, clip_min, clip_max): | |||||
new_grads = () | |||||
for grad in grads: | |||||
dt = self.dtype(grad) | |||||
t = C.clip_by_value(grad, self.cast(F.tuple_to_array((clip_min,)), dt), | |||||
self.cast(F.tuple_to_array((clip_max,)), dt)) | |||||
t = self.cast(t, dt) | |||||
new_grads = new_grads + (t,) | |||||
return new_grads | |||||
class CNNCTCTrainOneStepWithLossScaleCell(nn.Cell): | |||||
""" | |||||
Encapsulation class of CNNCTC network training. | |||||
Used for GPU training in order to manage overflowing gradients. | |||||
Args: | |||||
network (Cell): The training network. Note that loss function should have been added. | |||||
optimizer (Optimizer): Optimizer for updating the weights. | |||||
scale_sense (Cell): Loss scaling value. | |||||
""" | |||||
def __init__(self, network, optimizer, scale_sense): | |||||
super(CNNCTCTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | |||||
self.network = network | |||||
self.optimizer = optimizer | |||||
if isinstance(scale_sense, nn.Cell): | |||||
self.loss_scaling_manager = scale_sense | |||||
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), | |||||
dtype=mstype.float32), name="scale_sense") | |||||
elif isinstance(scale_sense, Tensor): | |||||
if scale_sense.shape == (1,) or scale_sense.shape == (): | |||||
self.scale_sense = Parameter(scale_sense, name='scale_sense') | |||||
else: | |||||
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format( | |||||
scale_sense.shape)) | |||||
else: | |||||
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format( | |||||
type(scale_sense))) | |||||
self.network.set_grad() | |||||
self.weights = ParameterTuple(network.trainable_params()) | |||||
self.grad = C.GradOperation(get_by_list=True, | |||||
sens_param=True) | |||||
self.reducer_flag = False | |||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
if self.parallel_mode not in ParallelMode.MODE_LIST: | |||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode) | |||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
self.reducer_flag = True | |||||
self.grad_reducer = None | |||||
if self.reducer_flag: | |||||
mean = context.get_auto_parallel_context("gradients_mean") | |||||
degree = get_group_size() | |||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
self.clip_gradients = ClipGradients() | |||||
self.cast = P.Cast() | |||||
self.addn = P.AddN() | |||||
self.reshape = P.Reshape() | |||||
self.hyper_map = C.HyperMap() | |||||
self.less_equal = P.LessEqual() | |||||
self.allreduce = P.AllReduce() | |||||
def construct(self, img, label_indices, text, sequence_length): | |||||
"""model construct.""" | |||||
weights = self.weights | |||||
loss = self.network(img, label_indices, text, sequence_length) | |||||
scaling_sens = self.scale_sense | |||||
grads = self.grad(self.network, weights)(img, label_indices, text, sequence_length, | |||||
self.cast(scaling_sens, mstype.float32)) | |||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
grads = self.clip_gradients(grads, GRADIENT_CLIP_MIN, GRADIENT_CLIP_MAX) | |||||
if self.reducer_flag: | |||||
# apply grad reducer on grads | |||||
grads = self.grad_reducer(grads) | |||||
self.optimizer(grads) | |||||
return (loss, scaling_sens) | |||||
class CNNCTC(nn.Cell): | |||||
"""CNNCTC model construct.""" | |||||
def __init__(self, num_class, hidden_size, final_feature_width): | |||||
super(CNNCTC, self).__init__() | |||||
self.num_class = num_class | |||||
self.hidden_size = hidden_size | |||||
self.final_feature_width = final_feature_width | |||||
self.feature_extraction = ResNetFeatureExtractor() | |||||
self.prediction = nn.Dense(self.hidden_size, self.num_class) | |||||
self.transpose = P.Transpose() | |||||
self.reshape = P.Reshape() | |||||
def construct(self, x): | |||||
x = self.feature_extraction(x) | |||||
x = self.transpose(x, (0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] | |||||
x = self.reshape(x, (-1, self.hidden_size)) | |||||
x = self.prediction(x) | |||||
x = self.reshape(x, (-1, self.final_feature_width, self.num_class)) | |||||
return x | |||||
class WithLossCell(nn.Cell): | |||||
"""Add loss cell for network.""" | |||||
def __init__(self, backbone, loss_fn): | |||||
super(WithLossCell, self).__init__(auto_prefix=False) | |||||
self._backbone = backbone | |||||
self._loss_fn = loss_fn | |||||
def construct(self, img, label_indices, text, sequence_length): | |||||
model_predict = self._backbone(img) | |||||
return self._loss_fn(model_predict, label_indices, text, sequence_length) | |||||
@property | |||||
def backbone_network(self): | |||||
return self._backbone | |||||
class CTCLoss(nn.Cell): | |||||
"""Loss of CTC.""" | |||||
def __init__(self): | |||||
super(CTCLoss, self).__init__() | |||||
self.loss = P.CTCLoss(preprocess_collapse_repeated=False, | |||||
ctc_merge_repeated=True, | |||||
ignore_longer_outputs_than_inputs=False) | |||||
self.mean = P.ReduceMean() | |||||
self.transpose = P.Transpose() | |||||
self.reshape = P.Reshape() | |||||
def construct(self, inputs, labels_indices, labels_values, sequence_length): | |||||
inputs = self.transpose(inputs, (1, 0, 2)) | |||||
loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length) | |||||
loss = self.mean(loss) | |||||
return loss | |||||
class ResNetFeatureExtractor(nn.Cell): | |||||
"""Extractor of ResNet feature.""" | |||||
def __init__(self): | |||||
super(ResNetFeatureExtractor, self).__init__() | |||||
self.conv_net = ResNet(3, 512, BasicBlock, [1, 2, 5, 3]) | |||||
def construct(self, feature_map): | |||||
return self.conv_net(feature_map) | |||||
class ResNet(nn.Cell): | |||||
"""Network of ResNet.""" | |||||
def __init__(self, input_channel, output_channel, block, layers): | |||||
super(ResNet, self).__init__() | |||||
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] | |||||
self.inplanes = int(output_channel / 8) | |||||
self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad') | |||||
self.bn0_1 = ms_fused_bn(int(output_channel / 16)) | |||||
self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad') | |||||
self.bn0_2 = ms_fused_bn(self.inplanes) | |||||
self.relu = P.ReLU() | |||||
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid') | |||||
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) | |||||
self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1, | |||||
pad_mode='pad') | |||||
self.bn1 = ms_fused_bn(self.output_channel_block[0]) | |||||
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid') | |||||
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1]) | |||||
self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1, | |||||
pad_mode='pad') | |||||
self.bn2 = ms_fused_bn(self.output_channel_block[1]) | |||||
self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (2, 2))) | |||||
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid') | |||||
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2]) | |||||
self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1, | |||||
pad_mode='pad') | |||||
self.bn3 = ms_fused_bn(self.output_channel_block[2]) | |||||
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3]) | |||||
self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1), | |||||
pad_mode='valid') | |||||
self.bn4_1 = ms_fused_bn(self.output_channel_block[3]) | |||||
self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0, | |||||
pad_mode='valid') | |||||
self.bn4_2 = ms_fused_bn(self.output_channel_block[3]) | |||||
def _make_layer(self, block, planes, blocks, stride=1): | |||||
"""make layer""" | |||||
downsample = None | |||||
if stride != 1 or self.inplanes != planes * block.expansion: | |||||
downsample = nn.SequentialCell( | |||||
[ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride), | |||||
ms_fused_bn(planes * block.expansion)] | |||||
) | |||||
layers = [] | |||||
layers.append(block(self.inplanes, planes, stride, downsample)) | |||||
self.inplanes = planes * block.expansion | |||||
for _ in range(1, blocks): | |||||
layers.append(block(self.inplanes, planes)) | |||||
return nn.SequentialCell(layers) | |||||
def construct(self, x): | |||||
"""model construct""" | |||||
x = self.conv0_1(x) | |||||
x = self.bn0_1(x) | |||||
x = self.relu(x) | |||||
x = self.conv0_2(x) | |||||
x = self.bn0_2(x) | |||||
x = self.relu(x) | |||||
x = self.maxpool1(x) | |||||
x = self.layer1(x) | |||||
x = self.conv1(x) | |||||
x = self.bn1(x) | |||||
x = self.relu(x) | |||||
x = self.maxpool2(x) | |||||
x = self.layer2(x) | |||||
x = self.conv2(x) | |||||
x = self.bn2(x) | |||||
x = self.relu(x) | |||||
x = self.maxpool3(x) | |||||
x = self.layer3(x) | |||||
x = self.conv3(x) | |||||
x = self.bn3(x) | |||||
x = self.relu(x) | |||||
x = self.layer4(x) | |||||
x = self.pad(x) | |||||
x = self.conv4_1(x) | |||||
x = self.bn4_1(x) | |||||
x = self.relu(x) | |||||
x = self.conv4_2(x) | |||||
x = self.bn4_2(x) | |||||
x = self.relu(x) | |||||
return x | |||||
class BasicBlock(nn.Cell): | |||||
"""BasicBlock""" | |||||
expansion = 1 | |||||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||||
super(BasicBlock, self).__init__() | |||||
self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad') | |||||
self.bn1 = ms_fused_bn(planes) | |||||
self.conv2 = ms_conv3x3(planes, planes, stride=stride, padding=1, pad_mode='pad') | |||||
self.bn2 = ms_fused_bn(planes) | |||||
self.relu = P.ReLU() | |||||
self.downsample = downsample | |||||
self.add = P.Add() | |||||
def construct(self, x): | |||||
"""Basic block construct""" | |||||
residual = x | |||||
out = self.conv1(x) | |||||
out = self.bn1(out) | |||||
out = self.relu(out) | |||||
out = self.conv2(out) | |||||
out = self.bn2(out) | |||||
if self.downsample is not None: | |||||
residual = self.downsample(x) | |||||
out = self.add(out, residual) | |||||
out = self.relu(out) | |||||
return out | |||||
def weight_variable(shape, half_precision=False): | |||||
if half_precision: | |||||
return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16) | |||||
return TruncatedNormal(0.02) | |||||
def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): | |||||
"""Get a conv2d layer with 3x3 kernel size.""" | |||||
init_value = weight_variable((out_channels, in_channels, 3, 3)) | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, | |||||
has_bias=has_bias) | |||||
def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): | |||||
"""Get a conv2d layer with 1x1 kernel size.""" | |||||
init_value = weight_variable((out_channels, in_channels, 1, 1)) | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, | |||||
has_bias=has_bias) | |||||
def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): | |||||
"""Get a conv2d layer with 2x2 kernel size.""" | |||||
    init_value = weight_variable((out_channels, in_channels, 2, 2))
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, | |||||
has_bias=has_bias) | |||||
def ms_fused_bn(channels, momentum=0.1): | |||||
"""Get a fused batchnorm""" | |||||
return nn.BatchNorm2d(channels, momentum=momentum) |
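# A shape-check sketch, assuming the default IIIT settings NUM_CLASS=37,
# HIDDEN_SIZE=512, FINAL_FEATURE_WIDTH=26 and a 32x100 RGB input:
#
#   import numpy as np
#   from mindspore import Tensor
#   import mindspore.common.dtype as mstype
#   net = CNNCTC(num_class=37, hidden_size=512, final_feature_width=26)
#   logits = net(Tensor(np.zeros([1, 3, 32, 100]), mstype.float32))
#   # logits.shape == (1, 26, 37): one row of class scores per CTC time step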
@@ -0,0 +1,343 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""cnn_ctc dataset""" | |||||
import sys | |||||
import pickle | |||||
import math | |||||
import six | |||||
import numpy as np | |||||
from PIL import Image | |||||
import lmdb | |||||
from mindspore.communication.management import get_rank, get_group_size | |||||
from src.model_utils.config import config | |||||
from src.util import CTCLabelConverter | |||||
class NormalizePAD: | |||||
"""Normalize pad.""" | |||||
def __init__(self, max_size, pad_type='right'): | |||||
self.max_size = max_size | |||||
self.pad_type = pad_type | |||||
def __call__(self, img): | |||||
# toTensor | |||||
img = np.array(img, dtype=np.float32) | |||||
# normalize | |||||
means = [121.58949, 123.93914, 123.418655] | |||||
stds = [65.70353, 65.142426, 68.61079] | |||||
img = np.subtract(img, means) | |||||
img = np.true_divide(img, stds) | |||||
img = img.transpose([2, 0, 1]) | |||||
        img = img.astype(np.float32)
_, _, w = img.shape | |||||
pad_img = np.zeros(shape=self.max_size, dtype=np.float32) | |||||
pad_img[:, :, :w] = img # right pad | |||||
if self.max_size[2] != w: # add border Pad | |||||
pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w)) | |||||
return pad_img | |||||
class AlignCollate: | |||||
"""Align collate""" | |||||
def __init__(self, img_h=32, img_w=100): | |||||
self.img_h = img_h | |||||
self.img_w = img_w | |||||
def __call__(self, images): | |||||
resized_max_w = self.img_w | |||||
input_channel = 3 | |||||
transform = NormalizePAD((input_channel, self.img_h, resized_max_w)) | |||||
resized_images = [] | |||||
for image in images: | |||||
w, h = image.size | |||||
ratio = w / float(h) | |||||
if math.ceil(self.img_h * ratio) > self.img_w: | |||||
resized_w = self.img_w | |||||
else: | |||||
resized_w = math.ceil(self.img_h * ratio) | |||||
resized_image = image.resize((resized_w, self.img_h), Image.BICUBIC) | |||||
resized_images.append(transform(resized_image)) | |||||
image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0) | |||||
return image_tensors | |||||
def get_img_from_lmdb(env, index, is_adv=False): | |||||
"""get image from lmdb.""" | |||||
with env.begin(write=False) as txn: | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
if is_adv: | |||||
img_key = 'adv_image-%09d'.encode() % index | |||||
else: | |||||
img_key = 'image-%09d'.encode() % index | |||||
imgbuf = txn.get(img_key) | |||||
buf = six.BytesIO() | |||||
buf.write(imgbuf) | |||||
buf.seek(0) | |||||
try: | |||||
img = Image.open(buf).convert('RGB') # for color image | |||||
except IOError: | |||||
print(f'Corrupted image for {index}') | |||||
# make dummy image and dummy label for corrupted image. | |||||
img = Image.new('RGB', (config.IMG_W, config.IMG_H)) | |||||
label = '[dummy_label]' | |||||
label = label.lower() | |||||
return img, label | |||||
class STMJGeneratorBatchFixedLength: | |||||
"""ST_MJ Generator with Batch Fixed Length""" | |||||
def __init__(self): | |||||
self.align_collector = AlignCollate() | |||||
self.converter = CTCLabelConverter(config.CHARACTER) | |||||
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, | |||||
meminit=False) | |||||
if not self.env: | |||||
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH)) | |||||
raise ValueError(config.TRAIN_DATASET_PATH) | |||||
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f: | |||||
self.st_mj_filtered_index_list = pickle.load(f) | |||||
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}') | |||||
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE | |||||
self.batch_size = config.TRAIN_BATCH_SIZE | |||||
def __len__(self): | |||||
return self.dataset_size | |||||
def __getitem__(self, item): | |||||
img_ret = [] | |||||
text_ret = [] | |||||
for i in range(item * self.batch_size, (item + 1) * self.batch_size): | |||||
index = self.st_mj_filtered_index_list[i] | |||||
img, label = get_img_from_lmdb(self.env, index) | |||||
img_ret.append(img) | |||||
text_ret.append(label) | |||||
img_ret = self.align_collector(img_ret) | |||||
text_ret, length = self.converter.encode(text_ret) | |||||
label_indices = [] | |||||
for i, _ in enumerate(length): | |||||
for j in range(length[i]): | |||||
label_indices.append((i, j)) | |||||
label_indices = np.array(label_indices, np.int64) | |||||
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32) | |||||
text_ret = text_ret.astype(np.int32) | |||||
return img_ret, label_indices, text_ret, sequence_length | |||||
class STMJGeneratorBatchFixedLengthPara: | |||||
"""ST_MJ Generator with batch fixed length Para""" | |||||
def __init__(self): | |||||
self.align_collector = AlignCollate() | |||||
self.converter = CTCLabelConverter(config.CHARACTER) | |||||
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, | |||||
meminit=False) | |||||
if not self.env: | |||||
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH)) | |||||
raise ValueError(config.TRAIN_DATASET_PATH) | |||||
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f: | |||||
self.st_mj_filtered_index_list = pickle.load(f) | |||||
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}') | |||||
self.rank_id = get_rank() | |||||
self.rank_size = get_group_size() | |||||
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size | |||||
self.batch_size = config.TRAIN_BATCH_SIZE | |||||
def __len__(self): | |||||
return self.dataset_size | |||||
def __getitem__(self, item): | |||||
img_ret = [] | |||||
text_ret = [] | |||||
rank_item = (item * self.rank_size) + self.rank_id | |||||
for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size): | |||||
index = self.st_mj_filtered_index_list[i] | |||||
img, label = get_img_from_lmdb(self.env, index) | |||||
img_ret.append(img) | |||||
text_ret.append(label) | |||||
img_ret = self.align_collector(img_ret) | |||||
text_ret, length = self.converter.encode(text_ret) | |||||
label_indices = [] | |||||
for i, _ in enumerate(length): | |||||
for j in range(length[i]): | |||||
label_indices.append((i, j)) | |||||
label_indices = np.array(label_indices, np.int64) | |||||
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32) | |||||
text_ret = text_ret.astype(np.int32) | |||||
return img_ret, label_indices, text_ret, sequence_length | |||||
def iiit_generator_batch(): | |||||
"""IIIT dataset generator""" | |||||
max_len = int((26 + 1) // 2) | |||||
align_collector = AlignCollate() | |||||
converter = CTCLabelConverter(config.CHARACTER) | |||||
env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
if not env: | |||||
print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH)) | |||||
sys.exit(0) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
filtered_index_list = [] | |||||
for index in range(n_samples): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
if len(label) > max_len: | |||||
continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in config.CHARACTER: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: | |||||
continue | |||||
filtered_index_list.append(index) | |||||
img_ret = [] | |||||
text_ret = [] | |||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') | |||||
for index in filtered_index_list: | |||||
img, label = get_img_from_lmdb(env, index, config.IS_ADV) | |||||
img_ret.append(img) | |||||
text_ret.append(label) | |||||
if len(img_ret) == config.TEST_BATCH_SIZE: | |||||
img_ret = align_collector(img_ret) | |||||
text_ret, length = converter.encode(text_ret) | |||||
label_indices = [] | |||||
for i, _ in enumerate(length): | |||||
for j in range(length[i]): | |||||
label_indices.append((i, j)) | |||||
label_indices = np.array(label_indices, np.int64) | |||||
sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32) | |||||
text_ret = text_ret.astype(np.int32) | |||||
yield img_ret, label_indices, text_ret, sequence_length, length | |||||
img_ret = [] | |||||
text_ret = [] | |||||
def adv_iiit_generator_batch(): | |||||
"""Perturb IIII dataset generator.""" | |||||
max_len = int((26 + 1) // 2) | |||||
align_collector = AlignCollate() | |||||
converter = CTCLabelConverter(config.CHARACTER) | |||||
env = lmdb.open(config.ADV_TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, | |||||
meminit=False) | |||||
if not env: | |||||
print('cannot create lmdb from %s' % (config.ADV_TEST_DATASET_PATH)) | |||||
sys.exit(0) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
filtered_index_list = [] | |||||
for index in range(n_samples): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
if len(label) > max_len: | |||||
continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in config.CHARACTER: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: | |||||
continue | |||||
filtered_index_list.append(index) | |||||
img_ret = [] | |||||
text_ret = [] | |||||
    print(f'num of samples in adversarial IIIT dataset: {len(filtered_index_list)}')
for index in filtered_index_list: | |||||
img, label = get_img_from_lmdb(env, index, is_adv=True) | |||||
img_ret.append(img) | |||||
text_ret.append(label) | |||||
if len(img_ret) == config.TEST_BATCH_SIZE: | |||||
img_ret = align_collector(img_ret) | |||||
text_ret, length = converter.encode(text_ret) | |||||
label_indices = [] | |||||
for i, _ in enumerate(length): | |||||
for j in range(length[i]): | |||||
label_indices.append((i, j)) | |||||
label_indices = np.array(label_indices, np.int64) | |||||
sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32) | |||||
text_ret = text_ret.astype(np.int32) | |||||
yield img_ret, label_indices, text_ret, sequence_length, length | |||||
img_ret = [] | |||||
text_ret = [] |
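# A minimal wiring sketch for training, mirroring how train.py is expected to
# consume the generator (column names are illustrative):
#
#   import mindspore.dataset as ds
#   source = STMJGeneratorBatchFixedLength()
#   dataset = ds.GeneratorDataset(source, ['img', 'label_indices', 'text', 'sequence_length'])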
@@ -0,0 +1,41 @@ | |||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""lr generator for cnnctc""" | |||||
import math | |||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||||
learning_rate = float(init_lr) + lr_inc * current_step | |||||
return learning_rate | |||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||||
base = float(current_step - warmup_steps) / float(decay_steps) | |||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||||
return learning_rate | |||||
def dynamic_lr(config, steps_per_epoch): | |||||
"""dynamic learning rate generator""" | |||||
base_lr = config.base_lr | |||||
total_steps = steps_per_epoch * config.TRAIN_EPOCHS | |||||
warmup_steps = int(config.warmup_step) | |||||
decay_steps = total_steps - warmup_steps | |||||
lr = [] | |||||
for i in range(total_steps): | |||||
if i < warmup_steps: | |||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | |||||
else: | |||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, decay_steps)) | |||||
return lr |
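# A worked sketch with illustrative values base_lr=0.0005, warmup_step=2000 and
# warmup_ratio=0.0625: the first 2000 steps ramp linearly from base_lr/16 up to
# base_lr, after which the half-cosine decays the rate towards 0 at the final step:
#
#   lrs = dynamic_lr(config, steps_per_epoch=70000)
#   # len(lrs) == 70000 * config.TRAIN_EPOCHS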
@@ -0,0 +1,131 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ==================================================================================== | |||||
"""Parse arguments""" | |||||
import os | |||||
import ast | |||||
import argparse | |||||
from pprint import pprint, pformat | |||||
import yaml | |||||
_config_path = '../../../default_config.yaml' | |||||
class Config: | |||||
""" | |||||
Configuration namespace. Convert dictionary to members | |||||
""" | |||||
def __init__(self, cfg_dict): | |||||
for k, v in cfg_dict.items(): | |||||
if isinstance(v, (list, tuple)): | |||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) | |||||
else: | |||||
setattr(self, k, Config(v) if isinstance(v, dict) else v) | |||||
def __str__(self): | |||||
return pformat(self.__dict__) | |||||
def __repr__(self): | |||||
return self.__str__() | |||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'): | |||||
""" | |||||
Parse command line arguments to the configuration according to the default yaml | |||||
    Args:
        parser: Parent parser
        cfg: Base configuration
        helper: Helper description
        choices: Choice restrictions for arguments
        cfg_path: Path to the default yaml config
""" | |||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]', | |||||
parents=[parser]) | |||||
helper = {} if helper is None else helper | |||||
choices = {} if choices is None else choices | |||||
for item in cfg: | |||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): | |||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path) | |||||
choice = choices[item] if item in choices else None | |||||
if isinstance(cfg[item], bool): | |||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice, | |||||
help=help_description) | |||||
else: | |||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice, | |||||
help=help_description) | |||||
args = parser.parse_args() | |||||
return args | |||||
def parse_yaml(yaml_path): | |||||
""" | |||||
Parse the yaml config file | |||||
Args: | |||||
yaml_path: Path to the yaml config | |||||
""" | |||||
with open(yaml_path, 'r') as fin: | |||||
try: | |||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) | |||||
cfgs = [x for x in cfgs] | |||||
if len(cfgs) == 1: | |||||
cfg_helper = {} | |||||
cfg = cfgs[0] | |||||
cfg_choices = {} | |||||
elif len(cfgs) == 2: | |||||
cfg, cfg_helper = cfgs | |||||
cfg_choices = {} | |||||
elif len(cfgs) == 3: | |||||
cfg, cfg_helper, cfg_choices = cfgs | |||||
else: | |||||
                raise ValueError('At most 3 docs (config, helper description, choices) are supported in the config yaml')
print(cfg_helper) | |||||
        except yaml.YAMLError as err:
            raise ValueError('Failed to parse yaml') from err
return cfg, cfg_helper, cfg_choices | |||||
def merge(args, cfg): | |||||
""" | |||||
Merge the base config from yaml file and command line arguments | |||||
Args: | |||||
args: command line arguments | |||||
cfg: Base configuration | |||||
""" | |||||
args_var = vars(args) | |||||
for item in args_var: | |||||
cfg[item] = args_var[item] | |||||
return cfg | |||||
def get_config(): | |||||
""" | |||||
Get Config according to the yaml file and cli arguments | |||||
""" | |||||
parser = argparse.ArgumentParser(description='default name', add_help=False) | |||||
current_dir = os.path.dirname(os.path.abspath(__file__)) | |||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path), | |||||
help='Config file path') | |||||
path_args, _ = parser.parse_known_args() | |||||
default, helper, choices = parse_yaml(path_args.config_path) | |||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) | |||||
final_config = merge(args, default) | |||||
pprint(final_config) | |||||
print("Please check the above information for the configurations", flush=True) | |||||
return Config(final_config) | |||||
config = get_config() |
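# A usage sketch: every scalar key in default_config.yaml is exposed as a CLI
# flag by parse_cli_to_yaml, so values can be overridden without editing the
# yaml (the flag below is an example key, not a guaranteed schema):
#
#   python train.py --config_path ./default_config.yaml --TRAIN_BATCH_SIZE 192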
@@ -0,0 +1,26 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ==================================================================================== | |||||
"""Device adapter for ModelArts""" | |||||
from .config import config | |||||
if config.enable_modelarts: | |||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||||
else: | |||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||||
__all__ = [ | |||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id' | |||||
] |
@@ -0,0 +1,36 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ==================================================================================== | |||||
"""Local adapter""" | |||||
import os | |||||
def get_device_id(): | |||||
device_id = os.getenv('DEVICE_ID', '0') | |||||
return int(device_id) | |||||
def get_device_num(): | |||||
device_num = os.getenv('RANK_SIZE', '1') | |||||
return int(device_num) | |||||
def get_rank_id(): | |||||
global_rank_id = os.getenv('RANK_ID', '0') | |||||
return int(global_rank_id) | |||||
def get_job_id(): | |||||
return 'Local Job' |
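# A usage sketch: the local adapter reads scheduler-provided environment
# variables and falls back to single-device defaults:
#
#   DEVICE_ID=3 RANK_SIZE=8 python train.py
#   # inside the process: get_device_id() == 3, get_device_num() == 8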
@@ -0,0 +1,124 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ==================================================================================== | |||||
"""Moxing adapter for ModelArts""" | |||||
import os | |||||
import functools | |||||
from mindspore import context | |||||
from .config import config | |||||
_global_syn_count = 0 | |||||
def get_device_id(): | |||||
device_id = os.getenv('DEVICE_ID', '0') | |||||
return int(device_id) | |||||
def get_device_num(): | |||||
device_num = os.getenv('RANK_SIZE', '1') | |||||
return int(device_num) | |||||
def get_rank_id(): | |||||
global_rank_id = os.getenv('RANK_ID', '0') | |||||
return int(global_rank_id) | |||||
def get_job_id(): | |||||
job_id = os.getenv('JOB_ID') | |||||
job_id = job_id if job_id != "" else "default" | |||||
return job_id | |||||
def sync_data(from_path, to_path): | |||||
""" | |||||
    Download data from remote OBS to a local directory if the first url is remote and the second is local.
    Upload data from a local directory to remote OBS otherwise.
""" | |||||
import moxing as mox | |||||
import time | |||||
global _global_syn_count | |||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count) | |||||
_global_syn_count += 1 | |||||
    # Each server contains at most 8 devices
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||||
print('from path: ', from_path) | |||||
print('to path: ', to_path) | |||||
mox.file.copy_parallel(from_path, to_path) | |||||
print('===finished data synchronization===') | |||||
try: | |||||
os.mknod(sync_lock) | |||||
except IOError: | |||||
pass | |||||
print('===save flag===') | |||||
while True: | |||||
if os.path.exists(sync_lock): | |||||
break | |||||
time.sleep(1) | |||||
print('Finish sync data from {} to {}'.format(from_path, to_path)) | |||||
def moxing_wrapper(pre_process=None, post_process=None): | |||||
""" | |||||
Moxing wrapper to download dataset and upload outputs | |||||
""" | |||||
def wrapper(run_func): | |||||
@functools.wraps(run_func) | |||||
def wrapped_func(*args, **kwargs): | |||||
# Download data from data_url | |||||
if config.enable_modelarts: | |||||
if config.data_url: | |||||
sync_data(config.data_url, config.data_path) | |||||
print('Dataset downloaded: ', os.listdir(config.data_path)) | |||||
if config.checkpoint_url: | |||||
if not os.path.exists(config.load_path): | |||||
                        os.makedirs(config.load_path)
print('=' * 20 + 'makedirs') | |||||
if os.path.isdir(config.load_path): | |||||
print('=' * 20 + 'makedirs success') | |||||
else: | |||||
print('=' * 20 + 'makedirs fail') | |||||
sync_data(config.checkpoint_url, config.load_path) | |||||
print('Preload downloaded: ', os.listdir(config.load_path)) | |||||
if config.train_url: | |||||
sync_data(config.train_url, config.output_path) | |||||
print('Workspace downloaded: ', os.listdir(config.output_path)) | |||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) | |||||
config.device_num = get_device_num() | |||||
config.device_id = get_device_id() | |||||
if not os.path.exists(config.output_path): | |||||
os.makedirs(config.output_path) | |||||
if pre_process: | |||||
pre_process() | |||||
run_func(*args, **kwargs) | |||||
# Upload data to train_url | |||||
if config.enable_modelarts: | |||||
if post_process: | |||||
post_process() | |||||
if config.train_url: | |||||
print('Start to copy output directory') | |||||
sync_data(config.output_path, config.train_url) | |||||
return wrapped_func | |||||
return wrapper |
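# A minimal usage sketch; `main` is a hypothetical entry point. On ModelArts
# the wrapper syncs data_url/checkpoint_url down before the run and uploads
# output_path back to train_url afterwards:
#
#   @moxing_wrapper(pre_process=None, post_process=None)
#   def main():
#       ...  # build the dataset and model, then train
#   main()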
@@ -0,0 +1,172 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""preprocess dataset""" | |||||
import random | |||||
import pickle | |||||
import numpy as np | |||||
import lmdb | |||||
from tqdm import tqdm | |||||
def combine_lmdbs(lmdb_paths, lmdb_save_path): | |||||
"""combine lmdb dataset""" | |||||
max_len = int((26 + 1) // 2) | |||||
character = '0123456789abcdefghijklmnopqrstuvwxyz' | |||||
env_save = lmdb.open( | |||||
lmdb_save_path, | |||||
map_size=1099511627776) | |||||
cnt = 0 | |||||
for lmdb_path in lmdb_paths: | |||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
for index in tqdm(range(n_samples)): | |||||
index += 1 # lmdb starts with 1 | |||||
                label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8') | |||||
if len(label) > max_len: | |||||
continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in character: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: | |||||
continue | |||||
img_key = 'image-%09d'.encode() % index | |||||
imgbuf = txn.get(img_key) | |||||
with env_save.begin(write=True) as txn_save: | |||||
cnt += 1 | |||||
label_key_save = 'label-%09d'.encode() % cnt | |||||
label_save = label.encode() | |||||
image_key_save = 'image-%09d'.encode() % cnt | |||||
image_save = imgbuf | |||||
txn_save.put(label_key_save, label_save) | |||||
txn_save.put(image_key_save, image_save) | |||||
n_samples = cnt | |||||
with env_save.begin(write=True) as txn_save: | |||||
txn_save.put('num-samples'.encode(), str(n_samples).encode()) | |||||
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000): | |||||
"""analyze lmdb label""" | |||||
label_length_dict = {} | |||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
for index in tqdm(range(n_samples)): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
label_length = len(label) | |||||
if label_length in label_length_dict: | |||||
label_length_dict[label_length] += 1 | |||||
else: | |||||
label_length_dict[label_length] = 1 | |||||
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True) | |||||
label_length_sum = 0 | |||||
label_num = 0 | |||||
lengths = [] | |||||
p = [] | |||||
for l, num in sorted_label_length: | |||||
label_length_sum += l * num | |||||
label_num += num | |||||
p.append(num) | |||||
lengths.append(l) | |||||
for i, _ in enumerate(p): | |||||
p[i] /= label_num | |||||
average_overall_length = int(label_length_sum / label_num * batch_size) | |||||
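    # Rejection sampling: draw batch_size - 1 label lengths from the empirical
    # distribution, then try to close the batch with one final length so the total
    # equals the fixed target; retry whenever that completing length does not occur
    # in the dataset.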
def get_combinations_of_fix_length(fix_length, items, p, batch_size): | |||||
ret = np.random.choice(items, batch_size - 1, True, p) | |||||
cur_sum = sum(ret) | |||||
ret = list(ret) | |||||
if fix_length - cur_sum in items: | |||||
ret.append(fix_length - cur_sum) | |||||
else: | |||||
return None | |||||
return ret | |||||
result = [] | |||||
while len(result) < num_of_combinations: | |||||
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size) | |||||
if ret is not None: | |||||
result.append(ret) | |||||
return result | |||||
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000): | |||||
"""generate fix shape index list""" | |||||
length_index_dict = {} | |||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
for index in tqdm(range(n_samples)): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
label_length = len(label) | |||||
if label_length in length_index_dict: | |||||
length_index_dict[label_length].append(index) | |||||
else: | |||||
length_index_dict[label_length] = [index] | |||||
ret = [] | |||||
for _ in range(num_of_iters): | |||||
comb = random.choice(combinations) | |||||
for l in comb: | |||||
ret.append(random.choice(length_index_dict[l])) | |||||
with open(pkl_save_path, 'wb') as f: | |||||
pickle.dump(ret, f, -1) | |||||
if __name__ == '__main__': | |||||
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file | |||||
print('Begin to combine multiple lmdb datasets') | |||||
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/', | |||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'], | |||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') | |||||
# step 2: generate the order of input data, guarantee that the input batch shape is fixed | |||||
print('Begin to generate the index order of input data') | |||||
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') | |||||
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination, | |||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl') | |||||
print('Done') |
@@ -0,0 +1,102 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""util file""" | |||||
import numpy as np | |||||
class AverageMeter(): | |||||
"""Computes and stores the average and current value""" | |||||
def __init__(self): | |||||
self.reset() | |||||
def reset(self): | |||||
self.val = 0 | |||||
self.avg = 0 | |||||
self.sum = 0 | |||||
self.count = 0 | |||||
def update(self, val, n=1): | |||||
self.val = val | |||||
self.sum += val * n | |||||
self.count += n | |||||
self.avg = self.sum / self.count | |||||
class CTCLabelConverter(): | |||||
""" Convert between text-label and text-index """ | |||||
def __init__(self, character): | |||||
# character (str): set of the possible characters. | |||||
dict_character = list(character) | |||||
self.dict = {} | |||||
for i, char in enumerate(dict_character): | |||||
self.dict[char] = i | |||||
self.character = dict_character + ['[blank]']  # dummy '[blank]' token for CTCLoss (last index)
self.dict['[blank]'] = len(dict_character) | |||||
def encode(self, text): | |||||
"""convert text-label into text-index. | |||||
input: | |||||
text: text labels of each image. [batch_size] | |||||
output: | |||||
text: concatenated text index for CTCLoss. | |||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] | |||||
length: length of each text. [batch_size] | |||||
""" | |||||
length = [len(s) for s in text] | |||||
text = ''.join(text) | |||||
text = [self.dict[char] for char in text] | |||||
return np.array(text), np.array(length) | |||||
def decode(self, text_index, length): | |||||
""" convert text-index into text-label. """ | |||||
texts = [] | |||||
index = 0 | |||||
for l in length: | |||||
t = text_index[index:index + l] | |||||
char_list = [] | |||||
for i in range(l): | |||||
if t[i] != self.dict['[blank]'] and ( | |||||
not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. | |||||
char_list.append(self.character[t[i]]) | |||||
text = ''.join(char_list) | |||||
texts.append(text) | |||||
index += l | |||||
return texts | |||||
def reverse_encode(self, text_index, length): | |||||
""" convert text-index into text-label. """ | |||||
texts = [] | |||||
index = 0 | |||||
for l in length: | |||||
t = text_index[index:index + l] | |||||
char_list = [] | |||||
for i in range(l): | |||||
if t[i] != self.dict['[blank]']:  # removing blank tokens; repeated characters are kept.
char_list.append(self.character[t[i]]) | |||||
text = ''.join(char_list) | |||||
texts.append(text) | |||||
index += l | |||||
return texts |
@@ -0,0 +1,148 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""cnnctc train""" | |||||
import numpy as np | |||||
import mindspore | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore import context | |||||
from mindspore import Tensor | |||||
from mindspore.common import set_seed | |||||
from mindspore.communication.management import init, get_rank, get_group_size | |||||
from mindspore.dataset import GeneratorDataset | |||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
from mindspore.train.model import Model | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from src.callback import LossCallBack | |||||
from src.cnn_ctc import CNNCTC, CTCLoss, WithLossCell, CNNCTCTrainOneStepWithLossScaleCell | |||||
from src.dataset import STMJGeneratorBatchFixedLength, STMJGeneratorBatchFixedLengthPara | |||||
from src.lr_schedule import dynamic_lr | |||||
from src.model_utils.config import config | |||||
from src.model_utils.device_adapter import get_device_id | |||||
from src.model_utils.moxing_adapter import moxing_wrapper | |||||
set_seed(1) | |||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path=".") | |||||
def dataset_creator(run_distribute): | |||||
"""dataset creator""" | |||||
if run_distribute: | |||||
st_dataset = STMJGeneratorBatchFixedLengthPara() | |||||
else: | |||||
st_dataset = STMJGeneratorBatchFixedLength() | |||||
ds = GeneratorDataset(st_dataset, | |||||
['img', 'label_indices', 'text', 'sequence_length'], | |||||
num_parallel_workers=8) | |||||
return ds | |||||
def modelarts_pre_process(): | |||||
pass | |||||
@moxing_wrapper(pre_process=modelarts_pre_process) | |||||
def train(): | |||||
"""train cnnctc model""" | |||||
target = config.device_target | |||||
context.set_context(device_target=target) | |||||
if target == "Ascend": | |||||
device_id = get_device_id() | |||||
context.set_context(device_id=device_id) | |||||
if config.run_distribute: | |||||
init() | |||||
context.set_auto_parallel_context(parallel_mode="data_parallel") | |||||
ckpt_save_dir = config.SAVE_PATH | |||||
else: | |||||
# GPU target | |||||
device_id = get_device_id() | |||||
context.set_context(device_id=device_id) | |||||
if config.run_distribute: | |||||
init() | |||||
context.set_auto_parallel_context(device_num=get_group_size(), | |||||
parallel_mode="data_parallel", | |||||
gradients_mean=False, | |||||
gradient_fp32_sync=False) | |||||
ckpt_save_dir = config.SAVE_PATH + "ckpt_" + str(get_rank()) + "/" | |||||
print(ckpt_save_dir) | |||||
else: | |||||
ckpt_save_dir = config.SAVE_PATH + "ckpt_standalone/" | |||||
ds = dataset_creator(config.run_distribute) | |||||
net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH) | |||||
net.set_train(True) | |||||
if config.PRED_TRAINED: | |||||
param_dict = load_checkpoint(config.PRED_TRAINED) | |||||
load_param_into_net(net, param_dict) | |||||
print('parameters loaded!') | |||||
else: | |||||
print('train from scratch...') | |||||
criterion = CTCLoss() | |||||
dataset_size = ds.get_dataset_size() | |||||
lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32) | |||||
opt = mindspore.nn.RMSProp(params=net.trainable_params(), | |||||
centered=True, | |||||
learning_rate=lr, | |||||
momentum=config.MOMENTUM, | |||||
loss_scale=config.LOSS_SCALE) | |||||
net = WithLossCell(net, criterion) | |||||
if target == "Ascend": | |||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager( | |||||
config.LOSS_SCALE, False) | |||||
net.set_train(True) | |||||
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2") | |||||
else: | |||||
scaling_sens = Tensor(np.full((1), config.LOSS_SCALE), dtype=mstype.float32) | |||||
net = CNNCTCTrainOneStepWithLossScaleCell(net, opt, scaling_sens) | |||||
net.set_train(True) | |||||
model = Model(net) | |||||
callback = LossCallBack() | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP, | |||||
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM) | |||||
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=ckpt_save_dir) | |||||
if config.run_distribute: | |||||
if device_id == 0: | |||||
model.train(config.TRAIN_EPOCHS, | |||||
ds, | |||||
callbacks=[callback, ckpoint_cb], | |||||
dataset_sink_mode=False) | |||||
else: | |||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False) | |||||
else: | |||||
model.train(config.TRAIN_EPOCHS, | |||||
ds, | |||||
callbacks=[callback, ckpoint_cb], | |||||
dataset_sink_mode=False) | |||||
if __name__ == '__main__': | |||||
train() |
@@ -0,0 +1,76 @@ | |||||
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False | |||||
# url for modelarts | |||||
data_url: "" | |||||
train_url: "" | |||||
checkpoint_url: "" | |||||
# path for local | |||||
data_path: "/cache/data" | |||||
output_path: "/cache/train" | |||||
load_path: "/cache/checkpoint_path" | |||||
device_target: "GPU" | |||||
enable_profiling: False | |||||
# ====================================================================================== | |||||
# Training options | |||||
CHARACTER: "0123456789abcdefghijklmnopqrstuvwxyz" | |||||
# NUM_CLASS = len(CHARACTER) + 1 | |||||
NUM_CLASS: 37 | |||||
HIDDEN_SIZE: 512 | |||||
FINAL_FEATURE_WIDTH: 26 | |||||
# dataset config | |||||
IMG_H: 32 | |||||
IMG_W: 100 | |||||
TRAIN_DATASET_PATH: "/opt/dataset/CNNCTC_data/MJ-ST-IIIT/ST_MJ/" | |||||
TRAIN_DATASET_INDEX_PATH: "/opt/dataset/CNNCTC_data/MJ-ST-IIIT/st_mj_fixed_length_index_list.pkl" | |||||
TRAIN_BATCH_SIZE: 192 | |||||
TRAIN_EPOCHS: 3 | |||||
# training config | |||||
run_distribute: False | |||||
PRED_TRAINED: "" | |||||
SAVE_PATH: "./" | |||||
#LR | |||||
base_lr: 0.0005 | |||||
warmup_step: 2000 | |||||
warmup_ratio: 0.0625 | |||||
MOMENTUM: 0.8 | |||||
LOSS_SCALE: 8096 | |||||
SAVE_CKPT_PER_N_STEP: 2000 | |||||
KEEP_CKPT_MAX_NUM: 5 | |||||
# ====================================================================================== | |||||
# Eval options | |||||
TEST_DATASET_PATH: "/opt/dataset/CNNCTC_data/MJ-ST-IIIT/IIIT5k_3000" | |||||
#TEST_DATASET_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/data" | |||||
TEST_BATCH_SIZE: 256 | |||||
CHECKPOINT_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/cnn_ctc/ckpt_standalone/CNNCTC-3_70000.ckpt" | |||||
ADV_TEST_DATASET_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/data" | |||||
IS_ADV: False | |||||
# export options | |||||
device_id: 0 | |||||
file_name: "cnnctc" | |||||
file_format: "MINDIR" | |||||
ckpt_file: "" | |||||
# 310 infer | |||||
result_path: "" | |||||
label_path: "" | |||||
preprocess_output: "" | |||||
--- | |||||
# Help description for each configuration | |||||
enable_modelarts: "Whether to train on ModelArts. Default: False"
data_url: "Url for modelarts" | |||||
train_url: "Url for modelarts" | |||||
data_path: "The location of input data" | |||||
output_path: "The location of the output file"
device_target: "Device target: GPU or Ascend"
enable_profiling: "Whether to enable profiling while training. Default: False"
file_name: "CNN&CTC output air name" | |||||
file_format: "choices [AIR, MINDIR]" | |||||
ckpt_file: "CNN&CTC ckpt file" |
@@ -0,0 +1,100 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""cnnctc eval""" | |||||
import numpy as np | |||||
import lmdb | |||||
from mindspore import Tensor, context | |||||
import mindspore.common.dtype as mstype | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindspore.dataset import GeneratorDataset | |||||
from cnn_ctc.src.util import CTCLabelConverter | |||||
from cnn_ctc.src.dataset import iiit_generator_batch, adv_iiit_generator_batch | |||||
from cnn_ctc.src.cnn_ctc import CNNCTC | |||||
from cnn_ctc.src.model_utils.config import config | |||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, | |||||
save_graphs_path=".") | |||||
def test_dataset_creator(is_adv=False): | |||||
if is_adv: | |||||
ds = GeneratorDataset(adv_iiit_generator_batch(), ['img', 'label_indices', 'text', | |||||
'sequence_length', 'label_str']) | |||||
else: | |||||
ds = GeneratorDataset(iiit_generator_batch, ['img', 'label_indices', 'text', | |||||
'sequence_length', 'label_str']) | |||||
return ds | |||||
def test(lmdb_save_path): | |||||
"""eval cnnctc model on begin and perturb data.""" | |||||
target = config.device_target | |||||
context.set_context(device_target=target) | |||||
ds = test_dataset_creator(is_adv=config.IS_ADV) | |||||
net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH) | |||||
ckpt_path = config.CHECKPOINT_PATH | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
print('parameters loaded! from: ', ckpt_path) | |||||
converter = CTCLabelConverter(config.CHARACTER) | |||||
count = 0 | |||||
correct_count = 0 | |||||
env_save = lmdb.open(lmdb_save_path, map_size=1099511627776) | |||||
with env_save.begin(write=True) as txn_save: | |||||
for data in ds.create_tuple_iterator(): | |||||
img, _, text, _, length = data | |||||
img_tensor = Tensor(img, mstype.float32) | |||||
model_predict = net(img_tensor) | |||||
model_predict = np.squeeze(model_predict.asnumpy()) | |||||
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE) | |||||
preds_index = np.argmax(model_predict, 2) | |||||
preds_index = np.reshape(preds_index, [-1]) | |||||
preds_str = converter.decode(preds_index, preds_size) | |||||
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy()) | |||||
print("Prediction samples: \n", preds_str[:5]) | |||||
print("Ground truth: \n", label_str[:5]) | |||||
for pred, label in zip(preds_str, label_str): | |||||
if pred == label: | |||||
correct_count += 1 | |||||
count += 1 | |||||
if config.IS_ADV: | |||||
pred_key = 'adv_pred-%09d'.encode() % count | |||||
else: | |||||
pred_key = 'pred-%09d'.encode() % count | |||||
txn_save.put(pred_key, pred.encode()) | |||||
accuracy = correct_count / count | |||||
return accuracy | |||||
if __name__ == '__main__': | |||||
save_path = config.ADV_TEST_DATASET_PATH | |||||
config.IS_ADV = False | |||||
config.TEST_DATASET_PATH = save_path | |||||
ori_acc = test(lmdb_save_path=save_path) | |||||
config.IS_ADV = True | |||||
adv_acc = test(lmdb_save_path=save_path) | |||||
print('Accuracy of benign sample: ', ori_acc) | |||||
print('Accuracy of perturbed sample: ', adv_acc) |
@@ -0,0 +1,139 @@ | |||||
# Copyright 2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""Generated natural robustness samples. """ | |||||
import sys | |||||
import json | |||||
import time | |||||
import lmdb | |||||
from mindspore_serving.client import Client | |||||
from cnn_ctc.src.model_utils.config import config | |||||
config_perturb = [ | |||||
{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, | |||||
{"method": "GaussianBlur", "params": {"ksize": 5}}, | |||||
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, | |||||
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.1}}, | |||||
{"method": "Scale", "params": {"factor_x": 0.8, "factor_y": 0.8}}, | |||||
{"method": "Shear", "params": {"factor": 1.5, "direction": "horizontal"}}, | |||||
{"method": "Rotate", "params": {"angle": 30}}, | |||||
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, | |||||
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], | |||||
"start_point": [100, 150], "scope": 0.3, | |||||
"bright_rate": 0.3, "pattern": "light", "mode": "circle"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], | |||||
"color_end": [0, 0, 0], "start_point": [150, 200], | |||||
"scope": 0.3, "pattern": "light", "mode": "horizontal"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], | |||||
"start_point": [150, 200], "scope": 0.3, | |||||
"pattern": "light", "mode": "vertical"}}, | |||||
{"method": "Curve", "params": {"curves": 0.5, "depth": 3, "mode": "vertical"}}, | |||||
{"method": "Perspective", "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], | |||||
"dst_pos": [[10, 0], [0, 800], [790, 0], [800, 800]]}}, | |||||
] | |||||
def generate_adv_iii5t_3000(lmdb_paths, lmdb_save_path, perturb_config): | |||||
"""generate perturb iii5t_3000""" | |||||
max_len = int((26 + 1) // 2) | |||||
instances = [] | |||||
methods_number = 1 | |||||
outputs_number = 2 | |||||
perturb_config = json.dumps(perturb_config) | |||||
env = lmdb.open(lmdb_paths, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
if not env: | |||||
print('cannot create lmdb from %s' % (lmdb_paths)) | |||||
sys.exit(0) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
filtered_labels = [] | |||||
filtered_index_list = [] | |||||
for index in range(n_samples): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
if len(label) > max_len: continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in config.CHARACTER: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: continue | |||||
filtered_labels.append(label) | |||||
filtered_index_list.append(index) | |||||
img_key = 'image-%09d'.encode() % index | |||||
imgbuf = txn.get(img_key) | |||||
instances.append({"img": imgbuf, 'perturb_config': perturb_config, "methods_number": methods_number, | |||||
"outputs_number": outputs_number}) | |||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') | |||||
client = Client("10.113.216.54:5500", "perturbation", "natural_perturbation") | |||||
start_time = time.time() | |||||
result = client.infer(instances) | |||||
end_time = time.time() | |||||
print('generated natural perturbs images cost: ', end_time - start_time) | |||||
env_save = lmdb.open(lmdb_save_path, map_size=1099511627776) | |||||
txn = env.begin(write=False) | |||||
with env_save.begin(write=True) as txn_save: | |||||
new_index = 1 | |||||
for i, index in enumerate(filtered_index_list): | |||||
try: | |||||
file_names = result[i]['file_names'].split(';') | |||||
except (KeyError, TypeError):
error_msg = result[i] | |||||
raise ValueError(error_msg) | |||||
length = result[i]['file_length'].tolist() | |||||
before = 0 | |||||
label = filtered_labels[i] | |||||
label = label.encode() | |||||
img_key = 'image-%09d'.encode() % index | |||||
ori_img = txn.get(img_key) | |||||
names_dict = result[i]['names_dict'] | |||||
names_dict = json.loads(names_dict) | |||||
for name, leng in zip(file_names, length): | |||||
label_key = 'label-%09d'.encode() % new_index | |||||
txn_save.put(label_key, label) | |||||
img_key = 'image-%09d'.encode() % new_index | |||||
adv_img = result[i]['results'] | |||||
adv_img = adv_img[before:before + leng] | |||||
adv_img_key = 'adv_image-%09d'.encode() % new_index | |||||
txn_save.put(img_key, ori_img) | |||||
txn_save.put(adv_img_key, adv_img) | |||||
adv_info_key = 'adv_info-%09d'.encode() % new_index | |||||
adv_info = json.dumps(names_dict[name]).encode() | |||||
txn_save.put(adv_info_key, adv_info) | |||||
before = before + leng | |||||
new_index += 1 | |||||
txn_save.put("num-samples".encode(), str(new_index - 1).encode()) | |||||
env.close() | |||||
if __name__ == '__main__': | |||||
save_path_lmdb = config.ADV_TEST_DATASET_PATH | |||||
generate_adv_iii5t_3000(config.TEST_DATASET_PATH, save_path_lmdb, config_perturb) |
@@ -0,0 +1,508 @@ | |||||
# Robustness Evaluation of the OCR Model CNN-CTC
## Overview
This tutorial demonstrates a simple robustness evaluation of the OCR model CNN-CTC using the natural-perturbation serving service. We first use the serving service to generate datasets of naturally perturbed samples, and then assess the model's robustness from how CNN-CTC performs on those datasets.
## Requirements
- Hardware
    - Ascend or GPU processor.
- Dependencies
    - [MindSpore](https://www.mindspore.cn/install)
    - MindSpore-Serving = 1.6.0
    - MindArmour
## Script Description
### Code Structure
```bash
|-- natural_robustness
    |-- serving                      # serving service that generates naturally perturbed samples
    |-- ocr_evaluate
        |-- cnn_ctc                  # CNN-CTC model: training, inference, pre- and post-processing
        |-- data                     # experiment and analysis data
        |-- default_config.yaml      # parameter configuration
        |-- generate_adv_samples.py  # generates naturally perturbed samples
        |-- eval_and_save.py         # runs CNN-CTC inference on perturbed samples and saves the results
        |-- analyse.py               # analyses the robustness of the CNN-CTC model
```
### Script Parameters
Training, inference, and robustness-evaluation parameters can all be configured in `default_config.yaml`. Here we focus on the parameters used during evaluation and those the user needs to set; for the remaining parameters, see the [CNN-CTC tutorial](https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc).
Training parameters:
- `--TRAIN_DATASET_PATH`: path of the training dataset.
- `--TRAIN_DATASET_INDEX_PATH`: path of the training-dataset index file that determines the sample order.
- `--SAVE_PATH`: save path for model checkpoint files.
Inference and evaluation parameters:
- `--TEST_DATASET_PATH`: path of the test dataset.
- `--CHECKPOINT_PATH`: path of the checkpoint file.
- `--ADV_TEST_DATASET_PATH`: path of the perturbed-sample dataset.
- `--IS_ADV`: whether to test with perturbed samples (an example launch command is sketched below).
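As a sketch of how these options are typically supplied (assuming the `model_utils` config helper exposes the YAML keys as command-line flags, as is common in MindSpore model repositories; the paths below are placeholders):
```bash
# Hypothetical invocation: flag names mirror the keys in default_config.yaml.
python eval_and_save.py \
    --CHECKPOINT_PATH=/path/to/CNNCTC-3_70000.ckpt \
    --ADV_TEST_DATASET_PATH=/path/to/perturbed/data \
    --IS_ADV=True
```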
### Model and Data
Data processing and model training follow the [CNN-CTC tutorial](https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc). The evaluation task requires the preprocessed dataset and a checkpoint file obtained through that tutorial.
#### Model
The model under evaluation is the OCR model CNN-CTC implemented on MindSpore. It targets the scene text recognition task: a CNN extracts features, and CTC (Connectionist Temporal Classification) predicts the output sequence. See [CNN-CTC](https://gitee.com/mindspore/models/tree/master/official/cv/cnnctc) for details and implementation.
[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, "What is wrong with scene text recognition model comparisons? dataset and model analysis," ArXiv, vol. abs/1904.01906, 2019.
#### Dataset
Training datasets: [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText)
Test dataset: [The IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset)
##### Dataset Processing
- Step 1:
All datasets are preprocessed and stored in .lmdb format; they can be downloaded [**here**](https://gitee.com/link?target=https%3A%2F%2Fdrive.google.com%2Fdrive%2Ffolders%2F192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).
- Step 2:
Unzip the downloaded files and rename the MJSynth dataset to MJ, the SynthText dataset to ST, and the IIIT dataset to IIIT.
- Step 3:
Move the three datasets into the `cnnctc_data` folder with the following structure:
``` | |||||
|--- CNNCTC/ | |||||
|--- cnnctc_data/ | |||||
|--- ST/ | |||||
data.mdb | |||||
lock.mdb | |||||
|--- MJ/ | |||||
data.mdb | |||||
lock.mdb | |||||
|--- IIIT/ | |||||
data.mdb | |||||
lock.mdb | |||||
...... | |||||
``` | |||||
- Step 4:
Preprocess the dataset:
```bash
cd ocr_evaluate/cnn_ctc
python src/preprocess_dataset.py
```
This takes about 75 minutes.
The preprocessed dataset is in .lmdb format, stored as key-value pairs:
| key | value |
| ----------- | ------------------------------------- |
| label-%09d | ground-truth label of the image |
| image-%09d | original image data |
| num-samples | number of samples in the lmdb dataset |
`%09d` is a 9-digit zero-padded number, e.g. label-000000001. A minimal read-back sketch follows the table.
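As a minimal sketch (assuming the lmdb directory produced above; the path matches the one used in `preprocess_dataset.py`), a single sample can be read back like this:
```python
from io import BytesIO

import lmdb
from PIL import Image

# Path assumed from the preprocessing step above.
env = lmdb.open('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ',
                max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    index = 1  # keys in this dataset start at 1
    label = txn.get('label-%09d'.encode() % index).decode('utf-8')
    image = Image.open(BytesIO(txn.get('image-%09d'.encode() % index)))
    print(n_samples, label, image.size)
```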
##### Model Training
Train the CNN-CTC model to obtain the checkpoint file:
```bash
cd ocr_evaluate/cnn_ctc
bash scripts/run_standalone_train_gpu.sh
```
### Generating the evaluation dataset with the natural-perturbation serving service
1. Start the natural-perturbation serving service; see [natural perturbation serving service](https://gitee.com/mindspore/mindarmour/blob/master/examples/natural_robustness/serving/README.md) for details:
```bash
cd serving/server/
python serving_server.py
```
2. Generate the evaluation dataset based on the serving service.
1. In default_config.yaml, configure the path of the original test samples, `TEST_DATASET_PATH`, and the path of the generated perturbed-sample dataset, `ADV_TEST_DATASET_PATH`. For example:
```yaml
TEST_DATASET_PATH: "/opt/dataset/CNNCTC_data/MJ-ST-IIIT/IIIT5k_3000"
ADV_TEST_DATASET_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/data"
```
2. Key code walk-through:
1. Configure the perturbation methods. For the currently available methods and their parameters, see [image transform methods](https://gitee.com/mindspore/mindarmour/tree/master/mindarmour/natural_robustness/transform/image). Below is an example configuration.
```python | |||||
PerturbConfig = [ | |||||
{"method": "Contrast", "params": {"alpha": 1.5, "beta": 0}}, | |||||
{"method": "GaussianBlur", "params": {"ksize": 5}}, | |||||
{"method": "SaltAndPepperNoise", "params": {"factor": 0.05}}, | |||||
{"method": "Translate", "params": {"x_bias": 0.1, "y_bias": -0.1}}, | |||||
{"method": "Scale", "params": {"factor_x": 0.8, "factor_y": 0.8}}, | |||||
{"method": "Shear", "params": {"factor": 1.5, "direction": "horizontal"}}, | |||||
{"method": "Rotate", "params": {"angle": 30}}, | |||||
{"method": "MotionBlur", "params": {"degree": 5, "angle": 45}}, | |||||
{"method": "GradientBlur", "params": {"point": [50, 100], "kernel_num": 3, "center": True}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], "start_point": [100, 150], "scope": 0.3, "bright_rate": 0.3, "pattern": "light", "mode": "circle"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], "start_point": [150, 200], "scope": 0.3, "pattern": "light", "mode": "horizontal"}}, | |||||
{"method": "GradientLuminance", "params": {"color_start": [255, 255, 255], "color_end": [0, 0, 0], "start_point": [150, 200], "scope": 0.3, "pattern": "light", "mode": "vertical"}}, | |||||
{"method": "Curve", "params": {"curves": 0.5, "depth": 3, "mode": "vertical"}}, | |||||
{"method": "Perspective", "params": {"ori_pos": [[0, 0], [0, 800], [800, 0], [800, 800]], "dst_pos": [[10, 0], [0, 800], [790, 0], [800, 800]]}}, | |||||
] | |||||
``` | |||||
2. Prepare the data to be perturbed.
```python | |||||
instances = [] | |||||
methods_number = 1 | |||||
outputs_number = 2 | |||||
perturb_config = json.dumps(perturb_config) | |||||
env = lmdb.open(lmdb_paths, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||||
if not env: | |||||
print('cannot create lmdb from %s' % (lmdb_paths)) | |||||
sys.exit(0) | |||||
with env.begin(write=False) as txn: | |||||
n_samples = int(txn.get('num-samples'.encode())) | |||||
# Filtering | |||||
filtered_labels = [] | |||||
filtered_index_list = [] | |||||
for index in range(n_samples): | |||||
index += 1 # lmdb starts with 1 | |||||
label_key = 'label-%09d'.encode() % index | |||||
label = txn.get(label_key).decode('utf-8') | |||||
if len(label) > max_len: continue | |||||
illegal_sample = False | |||||
for char_item in label.lower(): | |||||
if char_item not in config.CHARACTER: | |||||
illegal_sample = True | |||||
break | |||||
if illegal_sample: continue | |||||
filtered_labels.append(label) | |||||
filtered_index_list.append(index) | |||||
img_key = 'image-%09d'.encode() % index | |||||
imgbuf = txn.get(img_key) | |||||
instances.append({"img": imgbuf, 'perturb_config': perturb_config, "methods_number": methods_number, | |||||
"outputs_number": outputs_number}) | |||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
``` | |||||
3. Request the natural-perturbation serving service and save the returned data:
```python | |||||
client = Client("10.113.216.54:5500", "perturbation", "natural_perturbation") | |||||
start_time = time.time() | |||||
result = client.infer(instances) | |||||
end_time = time.time() | |||||
print('generated natural perturbs images cost: ', end_time - start_time) | |||||
env_save = lmdb.open(lmdb_save_path, map_size=1099511627776) | |||||
txn = env.begin(write=False) | |||||
with env_save.begin(write=True) as txn_save: | |||||
new_index = 1 | |||||
for i, index in enumerate(filtered_index_list): | |||||
try: | |||||
file_names = result[i]['file_names'].split(';') | |||||
except (KeyError, TypeError):
print('index: ', index) | |||||
print(result[i]) | |||||
length = result[i]['file_length'].tolist() | |||||
before = 0 | |||||
label = filtered_labels[i] | |||||
label = label.encode() | |||||
img_key = 'image-%09d'.encode() % index | |||||
ori_img = txn.get(img_key) | |||||
names_dict = result[i]['names_dict'] | |||||
names_dict = json.loads(names_dict) | |||||
for name, leng in zip(file_names, length): | |||||
label_key = 'label-%09d'.encode() % new_index | |||||
txn_save.put(label_key, label) | |||||
img_key = 'image-%09d'.encode() % new_index | |||||
adv_img = result[i]['results'] | |||||
adv_img = adv_img[before:before + leng] | |||||
adv_img_key = 'adv_image-%09d'.encode() % new_index | |||||
txn_save.put(img_key, ori_img) | |||||
txn_save.put(adv_img_key, adv_img) | |||||
adv_info_key = 'adv_info-%09d'.encode() % new_index | |||||
adv_info = json.dumps(names_dict[name]).encode() | |||||
txn_save.put(adv_info_key, adv_info) | |||||
before = before + leng | |||||
new_index += 1 | |||||
txn_save.put("num-samples".encode(), str(new_index - 1).encode())
env.close() | |||||
``` | |||||
3. Run the perturbed-sample generation script:
```bash
python generate_adv_samples.py
```
4. The generated perturbed dataset is in .lmdb format and contains the following entries (a read-back sketch follows the table):
| key | value |
| -------------- | ----------------------------------------------------------- |
| label-%09d | ground-truth label of the image |
| image-%09d | original image data |
| adv_image-%09d | generated perturbed image data |
| adv_info-%09d | perturbation info, including the method and its parameters |
| num-samples | number of samples in the lmdb dataset |
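As a sketch (assuming the generated lmdb above; the path is the example `ADV_TEST_DATASET_PATH`), the perturbation record of a sample can be inspected like this:
```python
import json

import lmdb

# Path assumed from the ADV_TEST_DATASET_PATH example above.
env = lmdb.open('/home/mindarmour/examples/natural_robustness/ocr_evaluate/data',
                max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    index = 1
    # adv_info-%09d holds a JSON record of the applied method(s) and parameters.
    adv_info = json.loads(txn.get('adv_info-%09d'.encode() % index).decode('utf-8'))
    print(adv_info)
```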
### Running CNN-CTC inference on the generated perturbed dataset
1. In default_config.yaml, set the test dataset path `TEST_DATASET_PATH` to the same value as the perturbed-sample dataset path `ADV_TEST_DATASET_PATH`. For example:
```yaml
TEST_DATASET_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/data"
ADV_TEST_DATASET_PATH: "/home/mindarmour/examples/natural_robustness/ocr_evaluate/data"
```
2. Key script walk-through:
1. Load the model and the dataset:
```python | |||||
ds = test_dataset_creator(is_adv=config.IS_ADV) | |||||
net = CNNCTC(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH) | |||||
ckpt_path = config.CHECKPOINT_PATH | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
print('parameters loaded! from: ', ckpt_path) | |||||
``` | |||||
2. Run inference and save the model's predictions for both the original and the perturbed samples:
```python | |||||
env_save = lmdb.open(lmdb_save_path, map_size=1099511627776) | |||||
with env_save.begin(write=True) as txn_save: | |||||
for data in ds.create_tuple_iterator(): | |||||
img, _, text, _, length = data | |||||
img_tensor = Tensor(img, mstype.float32) | |||||
model_predict = net(img_tensor) | |||||
model_predict = np.squeeze(model_predict.asnumpy()) | |||||
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE) | |||||
preds_index = np.argmax(model_predict, 2) | |||||
preds_index = np.reshape(preds_index, [-1]) | |||||
preds_str = converter.decode(preds_index, preds_size) | |||||
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy()) | |||||
print("Prediction samples: \n", preds_str[:5]) | |||||
print("Ground truth: \n", label_str[:5]) | |||||
for pred, label in zip(preds_str, label_str): | |||||
if pred == label: | |||||
correct_count += 1 | |||||
count += 1 | |||||
if config.IS_ADV: | |||||
pred_key = 'adv_pred-%09d'.encode() % count | |||||
else: | |||||
pred_key = 'pred-%09d'.encode() % count | |||||
txn_save.put(pred_key, pred.encode()) | |||||
accuracy = correct_count / count | |||||
``` | |||||
3. Run the eval_and_save.py script:
```bash | |||||
python eval_and_save.py | |||||
``` | |||||
The CNN-CTC model runs inference on the generated naturally perturbed dataset and saves its prediction for every sample under `ADV_TEST_DATASET_PATH`.
New entries added to the dataset:
| Key | Value |
| ------------- | --------------------------------------------- |
| pred-%09d | model prediction for the original image data |
| adv_pred-%09d | model prediction for the perturbed image data |
Model predictions on the benign samples:
```bash | |||||
Prediction samples: | |||||
['private', 'private', 'parking', 'parking', 'salutes'] | |||||
Ground truth: | |||||
['private', 'private', 'parking', 'parking', 'salutes'] | |||||
Prediction samples: | |||||
['venus', 'venus', 'its', 'its', 'the'] | |||||
Ground truth: | |||||
['venus', 'venus', 'its', 'its', 'the'] | |||||
Prediction samples: | |||||
['summer', 'summer', 'joeys', 'joeys', 'think'] | |||||
Ground truth: | |||||
['summer', 'summer', 'joes', 'joes', 'think'] | |||||
... | |||||
``` | |||||
Model predictions on the naturally perturbed samples:
```bash | |||||
Prediction samples: | |||||
['private', 'private', 'parking', 'parking', 'salutes'] | |||||
Ground truth: | |||||
['private', 'private', 'parking', 'parking', 'salutes'] | |||||
Prediction samples: | |||||
['dams', 'vares', 'its', 'its', 'the'] | |||||
Ground truth: | |||||
['venus', 'venus', 'its', 'its', 'the'] | |||||
Prediction samples: | |||||
['sune', 'summer', '', 'joeys', 'think'] | |||||
Ground truth: | |||||
['summer', 'summer', 'joes', 'joes', 'think'] | |||||
... | |||||
``` | |||||
Model accuracy on the original test dataset and on the naturally perturbed dataset:
```bash | |||||
num of samples in IIIT dataset: 5952 | |||||
Accuracy of benign sample: 0.8546195652173914 | |||||
Accuracy of perturbed sample: 0.6126019021739131 | |||||
``` | |||||
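That is, accuracy drops from about 85.5% on the benign samples to about 61.3% on the perturbed samples, a gap of roughly 24 percentage points across the 5952 evaluated images.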
### Robustness analysis
Statistics are computed from how the CNN-CTC model performs on the perturbed dataset. Run the analyse.py script:
```bash
python analyse.py
```
Analysis results:
```bash | |||||
Number of samples in analyse dataset: 5952 | |||||
Accuracy of original dataset: 0.46127717391304346 | |||||
Accuracy of adversarial dataset: 0.6126019021739131 | |||||
Number of samples correctly predicted in original dataset but wrong in adversarial dataset: 832 | |||||
Number of samples both wrong predicted in original and adversarial dataset: 1449 | |||||
------------------------------------------------------------------------------ | |||||
Method Shear | |||||
Number of perturb samples: 442 | |||||
Number of wrong predicted: 351 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 153 | |||||
Number of both wrong predicted in origin and adversarial dataset: 198 | |||||
------------------------------------------------------------------------------ | |||||
Method Contrast | |||||
Number of perturb samples: 387 | |||||
Number of wrong predicted: 57 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 8 | |||||
Number of both wrong predicted in origin and adversarial dataset: 49 | |||||
------------------------------------------------------------------------------ | |||||
Method GaussianBlur | |||||
Number of perturb samples: 436 | |||||
Number of wrong predicted: 181 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 71 | |||||
Number of both wrong predicted in origin and adversarial dataset: 110 | |||||
------------------------------------------------------------------------------ | |||||
Method MotionBlur | |||||
Number of perturb samples: 458 | |||||
Number of wrong predicted: 215 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 92 | |||||
Number of both wrong predicted in origin and adversarial dataset: 123 | |||||
------------------------------------------------------------------------------ | |||||
Method GradientLuminance | |||||
Number of perturb samples: 1243 | |||||
Number of wrong predicted: 154 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 4 | |||||
Number of both wrong predicted in origin and adversarial dataset: 150 | |||||
------------------------------------------------------------------------------ | |||||
Method Rotate | |||||
Number of perturb samples: 405 | |||||
Number of wrong predicted: 298 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 136 | |||||
Number of both wrong predicted in origin and adversarial dataset: 162 | |||||
------------------------------------------------------------------------------ | |||||
Method SaltAndPepperNoise | |||||
Number of perturb samples: 413 | |||||
Number of wrong predicted: 116 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 29 | |||||
Number of both wrong predicted in origin and adversarial dataset: 87 | |||||
------------------------------------------------------------------------------ | |||||
Method Translate | |||||
Number of perturb samples: 419 | |||||
Number of wrong predicted: 159 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 57 | |||||
Number of both wrong predicted in origin and adversarial dataset: 102 | |||||
------------------------------------------------------------------------------ | |||||
Method GradientBlur | |||||
Number of perturb samples: 440 | |||||
Number of wrong predicted: 92 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 26 | |||||
Number of both wrong predicted in origin and adversarial dataset: 66 | |||||
------------------------------------------------------------------------------ | |||||
Method Perspective | |||||
Number of perturb samples: 401 | |||||
Number of wrong predicted: 181 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 75 | |||||
Number of both wrong predicted in origin and adversarial dataset: 106 | |||||
------------------------------------------------------------------------------ | |||||
Method Curve | |||||
Number of perturb samples: 410 | |||||
Number of wrong predicted: 361 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 162 | |||||
Number of both wrong predicted in origin and adversarial dataset: 199 | |||||
------------------------------------------------------------------------------ | |||||
Method Scale | |||||
Number of perturb samples: 434 | |||||
Number of wrong predicted: 116 | |||||
Number of correctly predicted in origin dataset but wrong in adversarial: 19 | |||||
Number of both wrong predicted in origin and adversarial dataset: 97 | |||||
------------------------------------------------------------------------------ | |||||
``` | |||||
The analysis results include:
1. Number of evaluated samples: 5888
2. Accuracy of the CNN-CTC model on the original dataset: 85.4%
3. Accuracy of the CNN-CTC model on the perturbed dataset: 57.2%
4. Number of samples predicted correctly on the original image but incorrectly on the perturbed image: 1736
5. Number of samples predicted incorrectly on both the original and the perturbed image: 782
6. For each perturbation method: the number of samples, the number of perturbed samples predicted incorrectly, the number predicted correctly before but incorrectly after perturbation, and the number predicted incorrectly both before and after.
A high error rate on images perturbed by a given method means the CNN-CTC model is not robust to that method and targeted improvement is recommended: for Rotate, Curve, MotionBlur, and Shear, most perturbed images are predicted incorrectly, and further analysis is advised. A worked example follows.
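For example, from the per-method counts above: Rotate misleads the model on 298 of 405 perturbed samples (about 73.6%), while Contrast does so on only 57 of 387 (about 14.7%), which is why Rotate stands out as a weak spot and Contrast does not.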
Three folders are also generated under `ADV_TEST_DATASET_PATH`:
```
adv_wrong_pred             # dataset of perturbed images misclassified by the model
ori_correct_adv_wrong_pred # dataset of images classified correctly before but misclassified after perturbation
ori_wrong_adv_wrong_pred   # dataset of images misclassified both before and after perturbation
```
Each folder is organized by perturbation method:

Each image is named `<ground truth>-<prediction>.png`, as shown below:

The stored images support further analysis of whether a wrong prediction stems from model quality, from image quality, or from a perturbation method that changes the semantics of the image.
