Merge pull request !3478 from yangyongjie/r0.6tags/v0.6.0-beta
| @@ -31,7 +31,8 @@ These is an example of training Warpctc with self-generated captcha image datase | |||||
| └──warpct | └──warpct | ||||
| ├── README.md | ├── README.md | ||||
| ├── script | ├── script | ||||
| ├── run_distribute_train.sh # launch distributed training(8 pcs) | |||||
| ├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs) | |||||
| ├── run_distribute_train_for_gpu.sh # launch distributed training in GPU | |||||
| ├── run_eval.sh # launch evaluation | ├── run_eval.sh # launch evaluation | ||||
| ├── run_process_data.sh # launch dataset generation | ├── run_process_data.sh # launch dataset generation | ||||
| └── run_standalone_train.sh # launch standalone training(1 pcs) | └── run_standalone_train.sh # launch standalone training(1 pcs) | ||||
| @@ -75,22 +76,31 @@ Parameters for both training and evaluation can be set in config.py. | |||||
| #### Usage | #### Usage | ||||
| ``` | ``` | ||||
| # distributed training | |||||
| # distributed training in Ascend | |||||
| Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] | Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] | ||||
| # distributed training in GPU | |||||
| Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH] | |||||
| # standalone training | # standalone training | ||||
| Usage: sh run_standalone_train.sh [DATASET_PATH] | |||||
| Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM] | |||||
| ``` | ``` | ||||
| #### Launch | #### Launch | ||||
| ``` | ``` | ||||
| # distribute training example | |||||
| # distributed training example in Ascend | |||||
| sh run_distribute_train.sh rank_table.json ../data/train | sh run_distribute_train.sh rank_table.json ../data/train | ||||
| # standalone training example | |||||
| sh run_standalone_train.sh ../data/train | |||||
| # distributed training example in GPU | |||||
| sh run_distribute_train_for_gpu.sh 8 ../data/train | |||||
| # standalone training example in Ascend | |||||
| sh run_standalone_train.sh ../data/train Ascend | |||||
| # standalone training example in GPU | |||||
| sh run_standalone_train.sh ../data/train GPU | |||||
| ``` | ``` | ||||
| > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | ||||
| @@ -116,14 +126,17 @@ Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809] | |||||
| ``` | ``` | ||||
| # evaluation | # evaluation | ||||
| Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] | |||||
| Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM] | |||||
| ``` | ``` | ||||
| #### Launch | #### Launch | ||||
| ``` | ``` | ||||
| # evaluation example | |||||
| sh run_eval.sh ../data/test warpctc-30-98.ckpt | |||||
| # evaluation example in Ascend | |||||
| sh run_eval.sh ../data/test warpctc-30-98.ckpt Ascend | |||||
| # evaluation example in GPU | |||||
| sh run_eval.sh ../data/test warpctc-30-98.ckpt GPU | |||||
| ``` | ``` | ||||
| > The checkpoint can be produced during the training process. | > The checkpoint can be produced during the training process. | ||||
| @@ -23,10 +23,10 @@ from mindspore import dataset as de | |||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from src.loss import CTCLoss | |||||
| from src.loss import CTCLoss, CTCLossV2 | |||||
| from src.config import config as cf | from src.config import config as cf | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.warpctc import StackedRNN | |||||
| from src.warpctc import StackedRNN, StackedRNNForGPU | |||||
| from src.metric import WarpCTCAccuracy | from src.metric import WarpCTCAccuracy | ||||
| random.seed(1) | random.seed(1) | ||||
| @@ -36,30 +36,38 @@ de.config.set_seed(1) | |||||
| parser = argparse.ArgumentParser(description="Warpctc training") | parser = argparse.ArgumentParser(description="Warpctc training") | ||||
| parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") | parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") | ||||
| parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") | parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") | ||||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||||
| help='Running platform, choose from Ascend, GPU, and default is Ascend.') | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend", | |||||
| save_graphs=False, | |||||
| device_id=device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) | |||||
| if args_opt.platform == 'Ascend': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| max_captcha_digits = cf.max_captcha_digits | max_captcha_digits = cf.max_captcha_digits | ||||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | ||||
| # create dataset | # create dataset | ||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||||
| batch_size=cf.batch_size, | |||||
| device_target=args_opt.platform) | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # define loss | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||||
| # define net | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| if args_opt.platform == 'Ascend': | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, | |||||
| max_label_length=max_captcha_digits, | |||||
| batch_size=cf.batch_size) | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| else: | |||||
| loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size) | |||||
| net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| # load checkpoint | # load checkpoint | ||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | param_dict = load_checkpoint(args_opt.checkpoint_path) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| # define model | # define model | ||||
| model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) | |||||
| model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)}) | |||||
| # start evaluation | # start evaluation | ||||
| res = model.eval(dataset) | |||||
| res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend') | |||||
| print("result:", res, flush=True) | print("result:", res, flush=True) | ||||
| @@ -57,6 +57,6 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do | |||||
| cd ./train_parallel$i || exit | cd ./train_parallel$i || exit | ||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | echo "start training for rank $RANK_ID, device $DEVICE_ID" | ||||
| env >env.log | env >env.log | ||||
| python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & | |||||
| python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 & | |||||
| cd .. | cd .. | ||||
| done | done | ||||
| @@ -0,0 +1,52 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 2 ]; then | |||||
| echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| get_real_path() { | |||||
| if [ "${1:0:1}" == "/" ]; then | |||||
| echo "$1" | |||||
| else | |||||
| echo "$(realpath -m $PWD/$1)" | |||||
| fi | |||||
| } | |||||
| RANK_SIZE=$1 | |||||
| DATASET_PATH=$(get_real_path $2) | |||||
| if [ ! -d $DATASET_PATH ]; then | |||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a directory" | |||||
| exit 1 | |||||
| fi | |||||
| if [ -d "distribute_train" ]; then | |||||
| rm -rf ./distribute_train | |||||
| fi | |||||
| mkdir ./distribute_train | |||||
| cp ../*.py ./distribute_train | |||||
| cp -r ../src ./distribute_train | |||||
| cd ./distribute_train || exit | |||||
| mpirun --allow-run-as-root -n $RANK_SIZE \ | |||||
| python train.py \ | |||||
| --dataset_path=$DATASET_PATH \ | |||||
| --platform=GPU \ | |||||
| --run_distribute > log.txt 2>&1 & | |||||
| cd .. | |||||
| @@ -14,8 +14,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 2 ]; then | |||||
| echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" | |||||
| if [ $# != 3 ]; then | |||||
| echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -29,6 +29,7 @@ get_real_path() { | |||||
| PATH1=$(get_real_path $1) | PATH1=$(get_real_path $1) | ||||
| PATH2=$(get_real_path $2) | PATH2=$(get_real_path $2) | ||||
| PLATFORM=$3 | |||||
| if [ ! -d $PATH1 ]; then | if [ ! -d $PATH1 ]; then | ||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | echo "error: DATASET_PATH=$PATH1 is not a directory" | ||||
| @@ -40,21 +41,44 @@ if [ ! -f $PATH2 ]; then | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_SIZE=$DEVICE_NUM | |||||
| export RANK_ID=0 | |||||
| run_ascend() { | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_SIZE=$DEVICE_NUM | |||||
| export RANK_ID=0 | |||||
| if [ -d "eval" ]; then | |||||
| rm -rf ./eval | |||||
| if [ -d "eval" ]; then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env >env.log | |||||
| echo "start evaluation for device $DEVICE_ID" | |||||
| python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 & | |||||
| cd .. | |||||
| } | |||||
| run_gpu() { | |||||
| if [ -d "eval" ]; then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env >env.log | |||||
| python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 & | |||||
| cd .. | |||||
| } | |||||
| if [ "Ascend" == $PLATFORM ]; then | |||||
| run_ascend $PATH1 $PATH2 | |||||
| elif [ "GPU" == $PLATFORM ]; then | |||||
| run_gpu $PATH1 $PATH2 | |||||
| else | |||||
| echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." | |||||
| fi | fi | ||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp *.sh ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env >env.log | |||||
| echo "start evaluation for device $DEVICE_ID" | |||||
| python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & | |||||
| cd .. | |||||
| @@ -14,8 +14,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 1 ]; then | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" | |||||
| if [ $# != 2 ]; then | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -28,27 +28,44 @@ get_real_path() { | |||||
| } | } | ||||
| PATH1=$(get_real_path $1) | PATH1=$(get_real_path $1) | ||||
| PLATFORM=$2 | |||||
| if [ ! -d $PATH1 ]; then | if [ ! -d $PATH1 ]; then | ||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | echo "error: DATASET_PATH=$PATH1 is not a directory" | ||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | |||||
| export RANK_SIZE=1 | |||||
| run_ascend() { | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | |||||
| export RANK_SIZE=1 | |||||
| echo "start training for device $DEVICE_ID" | |||||
| env >env.log | |||||
| python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 & | |||||
| cd .. | |||||
| } | |||||
| run_gpu() { | |||||
| env >env.log | |||||
| python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 & | |||||
| cd .. | |||||
| } | |||||
| if [ -d "train" ]; then | if [ -d "train" ]; then | ||||
| rm -rf ./train | |||||
| rm -rf ./train | |||||
| fi | fi | ||||
| mkdir ./train | mkdir ./train | ||||
| cp ../*.py ./train | cp ../*.py ./train | ||||
| cp *.sh ./train | |||||
| cp -r ../src ./train | cp -r ../src ./train | ||||
| cd ./train || exit | cd ./train || exit | ||||
| echo "start training for device $DEVICE_ID" | |||||
| env >env.log | |||||
| python train.py --dataset=$PATH1 &>log & | |||||
| cd .. | |||||
| if [ "Ascend" == $PLATFORM ]; then | |||||
| run_ascend $PATH1 | |||||
| elif [ "GPU" == $PLATFORM ]; then | |||||
| run_gpu $PATH1 | |||||
| else | |||||
| echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." | |||||
| fi | |||||
| @@ -24,24 +24,25 @@ from PIL import Image | |||||
| from src.config import config as cf | from src.config import config as cf | ||||
| class _CaptchaDataset(): | |||||
| class _CaptchaDataset: | |||||
| """ | """ | ||||
| create train or evaluation dataset for warpctc | create train or evaluation dataset for warpctc | ||||
| Args: | Args: | ||||
| img_root_dir(str): root path of images | img_root_dir(str): root path of images | ||||
| max_captcha_digits(int): max number of digits in images. | max_captcha_digits(int): max number of digits in images. | ||||
| blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label | |||||
| length is less than max_captcha_digits, the remaining labels are padding with blank. | |||||
| device_target(str): platform of training, support Ascend and GPU. | |||||
| """ | """ | ||||
| def __init__(self, img_root_dir, max_captcha_digits, blank=10): | |||||
| def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'): | |||||
| if not os.path.exists(img_root_dir): | if not os.path.exists(img_root_dir): | ||||
| raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) | raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) | ||||
| self.img_root_dir = img_root_dir | self.img_root_dir = img_root_dir | ||||
| self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] | self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] | ||||
| self.max_captcha_digits = max_captcha_digits | self.max_captcha_digits = max_captcha_digits | ||||
| self.blank = blank | |||||
| self.target = device_target | |||||
| self.blank = 10 if self.target == 'Ascend' else 0 | |||||
| self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names] | |||||
| def __len__(self): | def __len__(self): | ||||
| return len(self.img_names) | return len(self.img_names) | ||||
| @@ -54,27 +55,33 @@ class _CaptchaDataset(): | |||||
| image = np.array(im) | image = np.array(im) | ||||
| label_str = os.path.splitext(img_name)[0] | label_str = os.path.splitext(img_name)[0] | ||||
| label_str = label_str[label_str.find('-') + 1:] | label_str = label_str[label_str.find('-') + 1:] | ||||
| label = [int(i) for i in label_str] | |||||
| label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) | |||||
| if self.target == 'Ascend': | |||||
| label = [int(i) for i in label_str] | |||||
| label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) | |||||
| else: | |||||
| label = [int(i) + 1 for i in label_str] | |||||
| length = len(label) | |||||
| label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) | |||||
| label.append(length) | |||||
| label = np.array(label) | label = np.array(label) | ||||
| return image, label | return image, label | ||||
| def create_dataset(dataset_path, repeat_num=1, batch_size=1): | |||||
| def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): | |||||
| """ | """ | ||||
| create train or evaluation dataset for warpctc | create train or evaluation dataset for warpctc | ||||
| Args: | Args: | ||||
| dataset_path(str): dataset path | dataset_path(str): dataset path | ||||
| repeat_num(int): dataset repetition num, default is 1 | |||||
| batch_size(int): batch size of generated dataset, default is 1 | batch_size(int): batch size of generated dataset, default is 1 | ||||
| num_shards(int): number of devices | |||||
| shard_id(int): rank id | |||||
| device_target(str): platform of training, support Ascend and GPU | |||||
| """ | """ | ||||
| rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1 | |||||
| rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0 | |||||
| dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) | |||||
| ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||||
| ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) | |||||
| dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target) | |||||
| ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id) | |||||
| ds.set_dataset_size(m.ceil(len(dataset) / num_shards)) | |||||
| image_trans = [ | image_trans = [ | ||||
| vc.Rescale(1.0 / 255.0, 0.0), | vc.Rescale(1.0 / 255.0, 0.0), | ||||
| vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), | vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), | ||||
| @@ -87,6 +94,5 @@ def create_dataset(dataset_path, repeat_num=1, batch_size=1): | |||||
| ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) | ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) | ||||
| ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) | ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) | ||||
| ds = ds.batch(batch_size) | |||||
| ds = ds.repeat(repeat_num) | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| return ds | return ds | ||||
| @@ -47,3 +47,25 @@ class CTCLoss(_Loss): | |||||
| labels_values = self.reshape(label, (-1,)) | labels_values = self.reshape(label, (-1,)) | ||||
| loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) | loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) | ||||
| return loss | return loss | ||||
| class CTCLossV2(_Loss): | |||||
| """ | |||||
| CTCLoss definition | |||||
| Args: | |||||
| max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width | |||||
| batch_size(int): batch size of input logits | |||||
| """ | |||||
| def __init__(self, max_sequence_length, batch_size): | |||||
| super(CTCLossV2, self).__init__() | |||||
| self.input_length = Tensor(np.array([max_sequence_length] * batch_size), mstype.int32) | |||||
| self.reshape = P.Reshape() | |||||
| self.ctc_loss = P.CTCLossV2() | |||||
| def construct(self, logit, label): | |||||
| labels_values = label[:, :-1] | |||||
| labels_length = label[:, -1] | |||||
| loss, _ = self.ctc_loss(logit, labels_values, self.input_length, labels_length) | |||||
| return loss | |||||
| @@ -15,19 +15,19 @@ | |||||
| """Metric for accuracy evaluation.""" | """Metric for accuracy evaluation.""" | ||||
| from mindspore import nn | from mindspore import nn | ||||
| BLANK_LABLE = 10 | |||||
| class WarpCTCAccuracy(nn.Metric): | class WarpCTCAccuracy(nn.Metric): | ||||
| """ | """ | ||||
| Define accuracy metric for warpctc network. | Define accuracy metric for warpctc network. | ||||
| """ | """ | ||||
| def __init__(self): | |||||
| def __init__(self, device_target='Ascend'): | |||||
| super(WarpCTCAccuracy).__init__() | super(WarpCTCAccuracy).__init__() | ||||
| self._correct_num = 0 | self._correct_num = 0 | ||||
| self._total_num = 0 | self._total_num = 0 | ||||
| self._count = 0 | self._count = 0 | ||||
| self.device_target = device_target | |||||
| self.blank = 10 if device_target == 'Ascend' else 0 | |||||
| def clear(self): | def clear(self): | ||||
| self._correct_num = 0 | self._correct_num = 0 | ||||
| @@ -39,6 +39,8 @@ class WarpCTCAccuracy(nn.Metric): | |||||
| y_pred = self._convert_data(inputs[0]) | y_pred = self._convert_data(inputs[0]) | ||||
| y = self._convert_data(inputs[1]) | y = self._convert_data(inputs[1]) | ||||
| if self.device_target == 'GPU': | |||||
| y = y[:, :-1] | |||||
| self._count += 1 | self._count += 1 | ||||
| @@ -54,8 +56,7 @@ class WarpCTCAccuracy(nn.Metric): | |||||
| raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') | raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') | ||||
| return self._correct_num / self._total_num | return self._correct_num / self._total_num | ||||
| @staticmethod | |||||
| def _is_eq(pred_lbl, target): | |||||
| def _is_eq(self, pred_lbl, target): | |||||
| """ | """ | ||||
| check whether predict label is equal to target label | check whether predict label is equal to target label | ||||
| """ | """ | ||||
| @@ -63,11 +64,10 @@ class WarpCTCAccuracy(nn.Metric): | |||||
| pred_diff = len(target) - len(pred_lbl) | pred_diff = len(target) - len(pred_lbl) | ||||
| if pred_diff > 0: | if pred_diff > 0: | ||||
| # pad the prediction with the blank label | # pad the prediction with the blank label | ||||
| pred_lbl.extend([BLANK_LABLE] * pred_diff) | |||||
| pred_lbl.extend([self.blank] * pred_diff) | |||||
| return pred_lbl == target | return pred_lbl == target | ||||
| @staticmethod | |||||
| def _get_prediction(y_pred): | |||||
| def _get_prediction(self, y_pred): | |||||
| """ | """ | ||||
| parse predict result to labels | parse predict result to labels | ||||
| """ | """ | ||||
| @@ -78,11 +78,11 @@ class WarpCTCAccuracy(nn.Metric): | |||||
| pred_lbls = [] | pred_lbls = [] | ||||
| for i in range(batch_size): | for i in range(batch_size): | ||||
| idx = indices[:, i] | idx = indices[:, i] | ||||
| last_idx = BLANK_LABLE | |||||
| last_idx = self.blank | |||||
| pred_lbl = [] | pred_lbl = [] | ||||
| for j in range(lens[i]): | for j in range(lens[i]): | ||||
| cur_idx = idx[j] | cur_idx = idx[j] | ||||
| if cur_idx not in [last_idx, BLANK_LABLE]: | |||||
| if cur_idx not in [last_idx, self.blank]: | |||||
| pred_lbl.append(cur_idx) | pred_lbl.append(cur_idx) | ||||
| last_idx = cur_idx | last_idx = cur_idx | ||||
| pred_lbls.append(pred_lbl) | pred_lbls.append(pred_lbl) | ||||
| @@ -88,3 +88,52 @@ class StackedRNN(nn.Cell): | |||||
| output = self.concat((output, h2_after_fc)) | output = self.concat((output, h2_after_fc)) | ||||
| return output | return output | ||||
| class StackedRNNForGPU(nn.Cell): | |||||
| """ | |||||
| Define a stacked RNN network which contains two LSTM layers and one fully-connected layer. | |||||
| Args: | |||||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||||
| captcha images. | |||||
| batch_size(int): batch size of input data, default is 64 | |||||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||||
| num_layer(int): the number of LSTM layers, default is 2 | |||||
| """ | |||||
| def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2): | |||||
| super(StackedRNNForGPU, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.input_size = input_size | |||||
| self.num_classes = 11 | |||||
| self.reshape = P.Reshape() | |||||
| self.cast = P.Cast() | |||||
| k = (1 / hidden_size) ** 0.5 | |||||
| weight_shape = 4 * hidden_size * (input_size + 3 * hidden_size + 4) | |||||
| self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight') | |||||
| self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32)) | |||||
| self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32)) | |||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2) | |||||
| self.lstm.weight = self.weight | |||||
| self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) | |||||
| self.fc_bias = np.random.random(self.num_classes).astype(np.float32) | |||||
| self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), | |||||
| bias_init=Tensor(self.fc_bias)) | |||||
| self.fc.to_float(mstype.float32) | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.concat = P.Concat() | |||||
| self.transpose = P.Transpose() | |||||
| def construct(self, x): | |||||
| x = self.transpose(x, (3, 0, 2, 1)) | |||||
| x = self.reshape(x, (-1, self.batch_size, self.input_size)) | |||||
| output, _ = self.lstm(x, (self.h, self.c)) | |||||
| res = () | |||||
| for i in range(F.shape(x)[0]): | |||||
| res += (self.expand_dims(self.fc(output[i]), 0),) | |||||
| res = self.concat(res) | |||||
| return res | |||||
| @@ -42,7 +42,7 @@ grad_div = C.MultitypeFuncGraph("grad_div") | |||||
| @grad_div.register("Tensor", "Tensor") | @grad_div.register("Tensor", "Tensor") | ||||
| def _grad_div(val, grad): | def _grad_div(val, grad): | ||||
| div = P.Div() | |||||
| div = P.RealDiv() | |||||
| mul = P.Mul() | mul = P.Mul() | ||||
| grad = mul(grad, 10.0) | grad = mul(grad, 10.0) | ||||
| ret = div(grad, val) | ret = div(grad, val) | ||||
| @@ -24,12 +24,12 @@ from mindspore import dataset as de | |||||
| from mindspore.train.model import Model, ParallelMode | from mindspore.train.model import Model, ParallelMode | ||||
| from mindspore.nn.wrap import WithLossCell | from mindspore.nn.wrap import WithLossCell | ||||
| from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint | from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint | ||||
| from mindspore.communication.management import init | |||||
| from mindspore.communication.management import init, get_group_size, get_rank | |||||
| from src.loss import CTCLoss | |||||
| from src.loss import CTCLoss, CTCLossV2 | |||||
| from src.config import config as cf | from src.config import config as cf | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.warpctc import StackedRNN | |||||
| from src.warpctc import StackedRNN, StackedRNNForGPU | |||||
| from src.warpctc_for_train import TrainOneStepCellWithGradClip | from src.warpctc_for_train import TrainOneStepCellWithGradClip | ||||
| from src.lr_schedule import get_lr | from src.lr_schedule import get_lr | ||||
| @@ -38,38 +38,60 @@ np.random.seed(1) | |||||
| de.config.set_seed(1) | de.config.set_seed(1) | ||||
| parser = argparse.ArgumentParser(description="Warpctc training") | parser = argparse.ArgumentParser(description="Warpctc training") | ||||
| parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") | |||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.') | |||||
| parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") | |||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | ||||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||||
| help='Running platform, choose from Ascend, GPU, and default is Ascend.') | |||||
| parser.set_defaults(run_distribute=False) | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend", | |||||
| save_graphs=False, | |||||
| device_id=device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) | |||||
| if args_opt.platform == 'Ascend': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| lr_scale = 1 | |||||
| if args_opt.run_distribute: | if args_opt.run_distribute: | ||||
| if args_opt.platform == 'Ascend': | |||||
| init() | |||||
| lr_scale = 1 | |||||
| device_num = int(os.environ.get("RANK_SIZE")) | |||||
| rank = int(os.environ.get("RANK_ID")) | |||||
| else: | |||||
| init('nccl') | |||||
| lr_scale = 0.5 | |||||
| device_num = get_group_size() | |||||
| rank = get_rank() | |||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, | |||||
| context.set_auto_parallel_context(device_num=device_num, | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| mirror_mean=True) | mirror_mean=True) | ||||
| init() | |||||
| else: | |||||
| device_num = 1 | |||||
| rank = 0 | |||||
| max_captcha_digits = cf.max_captcha_digits | max_captcha_digits = cf.max_captcha_digits | ||||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | ||||
| # create dataset | # create dataset | ||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size, | |||||
| num_shards=device_num, shard_id=rank, device_target=args_opt.platform) | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # define lr | # define lr | ||||
| lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num | |||||
| lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale | |||||
| lr = get_lr(cf.epoch_size, step_size, lr_init) | lr = get_lr(cf.epoch_size, step_size, lr_init) | ||||
| # define loss | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||||
| # define net | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| # define opt | |||||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||||
| if args_opt.platform == 'Ascend': | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, | |||||
| max_label_length=max_captcha_digits, | |||||
| batch_size=cf.batch_size) | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||||
| else: | |||||
| loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size) | |||||
| net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||||
| net = WithLossCell(net, loss) | net = WithLossCell(net, loss) | ||||
| net = TrainOneStepCellWithGradClip(net, opt).set_train() | net = TrainOneStepCellWithGradClip(net, opt).set_train() | ||||
| # define model | # define model | ||||
| @@ -79,6 +101,6 @@ if __name__ == '__main__': | |||||
| if cf.save_checkpoint: | if cf.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, | config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, | ||||
| keep_checkpoint_max=cf.keep_checkpoint_max) | keep_checkpoint_max=cf.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) | |||||
| ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck) | |||||
| callbacks.append(ckpt_cb) | callbacks.append(ckpt_cb) | ||||
| model.train(cf.epoch_size, dataset, callbacks=callbacks) | model.train(cf.epoch_size, dataset, callbacks=callbacks) | ||||