| @@ -0,0 +1,132 @@ | |||||
| # DeepFM Description | |||||
This is an example of training DeepFM on the Criteo dataset in MindSpore.
| [Paper](https://arxiv.org/pdf/1703.04247.pdf) Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He | |||||
| # Model architecture | |||||
The overall network architecture of DeepFM is shown below:
| [Link](https://arxiv.org/pdf/1703.04247.pdf) | |||||
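For quick reference (this summary is not in the original README), the prediction combines a factorization-machine component and a deep component; in the paper's notation:
```latex
\hat{y} = \mathrm{sigmoid}\bigl(y_{FM} + y_{DNN}\bigr), \qquad
y_{FM} = \langle w, x \rangle
       + \sum_{j_1=1}^{d} \sum_{j_2=j_1+1}^{d} \langle V_{j_1}, V_{j_2} \rangle\, x_{j_1} x_{j_2}
```
In `src/deepfm.py` the pairwise term is computed with the usual identity `0.5 * ((sum_i v_i x_i)^2 - sum_i (v_i x_i)^2)`.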
| # Requirements | |||||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||||
- Download the Criteo dataset for training and evaluation. Convert the dataset to MindRecord, TFRecord, or H5 format (the formats handled by `src/dataset.py`) and move the files to a specified path; a minimal loading sketch follows this list.
| - For more information, please check the resources below: | |||||
| - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) | |||||
| - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) | |||||
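The snippet below is an illustrative sketch (not part of the original instructions) of loading a converted dataset with the helpers in `src/dataset.py`; the dataset directory is a placeholder path.
```python
# Illustrative only: load a converted Criteo dataset with the src/dataset.py helpers.
from src.config import DataConfig
from src.dataset import create_dataset, DataType

data_config = DataConfig()
ds_train = create_dataset("/path/to/criteo",            # placeholder path
                          train_mode=True,
                          epochs=1,
                          batch_size=data_config.batch_size,
                          data_type=DataType(data_config.data_format))
print("batches per epoch:", ds_train.get_dataset_size())
```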
| # Script description | |||||
| ## Script and sample code | |||||
```
| ├── deepfm | |||||
| ├── README.md | |||||
| ├── scripts | |||||
│   ├──run_distribute_train.sh
│   ├──run_standalone_train.sh
| │ ├──run_eval.sh | |||||
| ├── src | |||||
| │ ├──config.py | |||||
| │ ├──dataset.py | |||||
| │ ├──callback.py | |||||
| │ ├──deepfm.py | |||||
| ├── train.py | |||||
| ├── eval.py | |||||
| ``` | |||||
| ## Training process | |||||
| ### Usage | |||||
- sh scripts/run_distribute_train.sh [DEVICE_NUM] [DATASET_PATH] [MINDSPORE_HCCL_CONFIG_PATH]
- sh scripts/run_standalone_train.sh [DEVICE_ID] [DATASET_PATH]
- python train.py --dataset_path [DATASET_PATH]
| ### Launch | |||||
| ``` | |||||
| # distribute training example | |||||
| sh scripts/run_distribute_train.sh 8 /opt/dataset/criteo /opt/mindspore_hccl_file.json | |||||
| # standalone training example | |||||
| sh scripts/run_standalone_train.sh 0 /opt/dataset/criteo | |||||
| or | |||||
| python train.py --dataset_path /opt/dataset/criteo > output.log 2>&1 & | |||||
| ``` | |||||
| ### Result | |||||
Training results will be stored in the example path.
By default, checkpoints are saved under `./checkpoint`, the training log is redirected to `./output.log`, the loss log to `./loss.log`, and the eval log to `./auc.log`.
| ## Eval process | |||||
| ### Usage | |||||
| - sh run_eval.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH] | |||||
| ### Launch | |||||
| ``` | |||||
| # infer example | |||||
| sh scripts/run_eval.sh 0 ~/criteo/eval/ ~/train/deepfm-15_41257.ckpt | |||||
| ``` | |||||
> The checkpoint is produced during the training process.
| ### Result | |||||
Inference results will be stored in the example path; you can find results like the following in `auc.log`.
| ``` | |||||
| 2020-05-27 20:51:35 AUC: 0.80577889065281, eval time: 35.55999s. | |||||
| ``` | |||||
| # Model description | |||||
| ## Performance | |||||
| ### Training Performance | |||||
| | Parameters | DeepFM | | |||||
| | -------------------------- | ------------------------------------------------------| | |||||
| | Model Version | | | |||||
| | Resource | Ascend 910, cpu:2.60GHz 96cores, memory:1.5T | | |||||
| | uploaded Date | 05/27/2020 | | |||||
| | MindSpore Version | 0.2.0 | | |||||
| | Dataset | Criteo | | |||||
| | Training Parameters | src/config.py | | |||||
| | Optimizer | Adam | | |||||
| Loss Function              | SigmoidCrossEntropyWithLogits |
| | outputs | | | |||||
| | Loss | 0.4234 | | |||||
| | Accuracy | AUC[0.8055] | | |||||
| | Total time | 91 min | | |||||
| | Params (M) | | | |||||
| | Checkpoint for Fine tuning | | | |||||
| | Model for inference | | | |||||
### Inference Performance
| | Parameters | | | | |||||
| | -------------------------- | ----------------------------- | ------------------------- | | |||||
| | Model Version | | | | |||||
| | Resource | Ascend 910 | Ascend 310 | | |||||
| | uploaded Date | 05/27/2020 | 05/27/2020 | | |||||
| | MindSpore Version | 0.2.0 | 0.2.0 | | |||||
| | Dataset | Criteo | | | |||||
| | batch_size | 1000 | | | |||||
| | outputs | | | | |||||
| | Accuracy | AUC[0.8055] | | | |||||
| | Speed | | | | |||||
| | Total time | 35.559s | | | |||||
| | Model for inference | | | | |||||
| # ModelZoo Homepage | |||||
| [Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo) | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
#     http://www.apache.org/licenses/LICENSE-2.0
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,66 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train_criteo.""" | |||||
| import os | |||||
| import sys | |||||
| import time | |||||
| import argparse | |||||
| from mindspore import context | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.deepfm import ModelBuilder, AUCMetric | |||||
| from src.config import DataConfig, ModelConfig, TrainConfig | |||||
| from src.dataset import create_dataset | |||||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||||
| parser = argparse.ArgumentParser(description='CTR Prediction') | |||||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||||
| args_opt, _ = parser.parse_known_args() | |||||
device_id = int(os.getenv('DEVICE_ID', '0'))
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) | |||||
| def add_write(file_path, print_str): | |||||
| with open(file_path, 'a+', encoding='utf-8') as file_out: | |||||
| file_out.write(print_str + '\n') | |||||
| if __name__ == '__main__': | |||||
| data_config = DataConfig() | |||||
| model_config = ModelConfig() | |||||
| train_config = TrainConfig() | |||||
| ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, | |||||
| epochs=1, batch_size=train_config.batch_size) | |||||
| model_builder = ModelBuilder(ModelConfig, TrainConfig) | |||||
| train_net, eval_net = model_builder.get_train_eval_net() | |||||
| train_net.set_train() | |||||
| eval_net.set_train(False) | |||||
| auc_metric = AUCMetric() | |||||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||||
| load_param_into_net(eval_net, param_dict) | |||||
| start = time.time() | |||||
| res = model.eval(ds_eval) | |||||
| eval_time = time.time() - start | |||||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
| out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.' | |||||
| print(out_str) | |||||
| add_write('./auc.log', str(out_str)) | |||||
| @@ -0,0 +1,44 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "Please run the script as: " | |||||
| echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PAHT" | |||||
| echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json" | |||||
| echo "After running the script, the network runs in the background, The log will be generated in logx/output.log" | |||||
| export RANK_SIZE=$1 | |||||
| DATA_URL=$2 | |||||
export MINDSPORE_HCCL_CONFIG_PATH=$3
| for ((i=0; i<RANK_SIZE;i++)) | |||||
| do | |||||
| export DEVICE_ID=$i | |||||
| export RANK_ID=$i | |||||
| rm -rf log$i | |||||
| mkdir ./log$i | |||||
| cp *.py ./log$i | |||||
| cp -r src ./log$i | |||||
| cd ./log$i || exit | |||||
| echo "start training for rank $i, device $DEVICE_ID" | |||||
| env > env.log | |||||
| python -u train.py \ | |||||
| --dataset_path=$DATA_URL \ | |||||
| --ckpt_path="checkpoint" \ | |||||
| --eval_file_name='auc.log' \ | |||||
| --loss_file_name='loss.log' \ | |||||
| --do_eval=True > output.log 2>&1 & | |||||
| cd ../ | |||||
| done | |||||
| @@ -0,0 +1,32 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "Please run the script as: " | |||||
| echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH" | |||||
| echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path" | |||||
| echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log" | |||||
| export DEVICE_ID=$1 | |||||
| DATA_URL=$2 | |||||
| CHECKPOINT_PATH=$3 | |||||
| mkdir -p ms_log | |||||
| CUR_DIR=`pwd` | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python -u eval.py \ | |||||
| --dataset_path=$DATA_URL \ | |||||
| --checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 & | |||||
| @@ -0,0 +1,34 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "Please run the script as: " | |||||
| echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH" | |||||
| echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path" | |||||
| echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log" | |||||
| export DEVICE_ID=$1 | |||||
| DATA_URL=$2 | |||||
| mkdir -p ms_log | |||||
| CUR_DIR=`pwd` | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python -u train.py \ | |||||
| --dataset_path=$DATA_URL \ | |||||
| --ckpt_path="checkpoint" \ | |||||
| --eval_file_name='auc.log' \ | |||||
| --loss_file_name='loss.log' \ | |||||
| --do_eval=True > ms_log/output.log 2>&1 & | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
#     http://www.apache.org/licenses/LICENSE-2.0
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,107 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
#     http://www.apache.org/licenses/LICENSE-2.0
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Defined callback for DeepFM. | |||||
| """ | |||||
| import time | |||||
| from mindspore.train.callback import Callback | |||||
| def add_write(file_path, out_str): | |||||
| with open(file_path, 'a+', encoding='utf-8') as file_out: | |||||
| file_out.write(out_str + '\n') | |||||
| class EvalCallBack(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| If the loss is NAN or INF terminating training. | |||||
| Note | |||||
| If per_print_times is 0 do not print loss. | |||||
| """ | |||||
| def __init__(self, model, eval_dataset, auc_metric, eval_file_path): | |||||
| super(EvalCallBack, self).__init__() | |||||
| self.model = model | |||||
| self.eval_dataset = eval_dataset | |||||
| self.aucMetric = auc_metric | |||||
| self.aucMetric.clear() | |||||
| self.eval_file_path = eval_file_path | |||||
| def epoch_end(self, run_context): | |||||
| start_time = time.time() | |||||
| out = self.model.eval(self.eval_dataset) | |||||
| eval_time = int(time.time() - start_time) | |||||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
| out_str = "{} EvalCallBack metric{}; eval_time{}s".format( | |||||
| time_str, out.values(), eval_time) | |||||
| print(out_str) | |||||
| add_write(self.eval_file_path, out_str) | |||||
| class LossCallBack(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| If the loss is NAN or INF terminating training. | |||||
| Note | |||||
| If per_print_times is 0 do not print loss. | |||||
| Args | |||||
| loss_file_path (str) The file absolute path, to save as loss_file; | |||||
| per_print_times (int) Print loss every times. Default 1. | |||||
| """ | |||||
| def __init__(self, loss_file_path, per_print_times=1): | |||||
| super(LossCallBack, self).__init__() | |||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
| raise ValueError("print_step must be int and >= 0.") | |||||
| self.loss_file_path = loss_file_path | |||||
| self._per_print_times = per_print_times | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| loss = cb_params.net_outputs.asnumpy() | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
| cur_num = cb_params.cur_step_num | |||||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0: | |||||
| with open(self.loss_file_path, "a+") as loss_file: | |||||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
| loss_file.write("{} epoch: {} step: {}, loss is {}\n".format( | |||||
| time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss)) | |||||
| print("epoch: {} step: {}, loss is {}\n".format( | |||||
| cb_params.cur_epoch_num, cur_step_in_epoch, loss)) | |||||
| class TimeMonitor(Callback): | |||||
| """ | |||||
| Time monitor for calculating cost of each epoch. | |||||
| Args | |||||
| data_size (int) step size of an epoch. | |||||
| """ | |||||
| def __init__(self, data_size): | |||||
| super(TimeMonitor, self).__init__() | |||||
| self.data_size = data_size | |||||
| def epoch_begin(self, run_context): | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / self.data_size | |||||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | |||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| print(f"step time {step_mseconds}", flush=True) | |||||
| @@ -0,0 +1,62 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
Network config settings, used in train.py and eval.py.
| """ | |||||
| class DataConfig: | |||||
| """ | |||||
| Define parameters of dataset. | |||||
| """ | |||||
| data_vocab_size = 184965 | |||||
| train_num_of_parts = 21 | |||||
| test_num_of_parts = 3 | |||||
| batch_size = 1000 | |||||
| data_field_size = 39 | |||||
| # dataset format, 1: mindrecord, 2: tfrecord, 3: h5 | |||||
| data_format = 2 | |||||
| class ModelConfig: | |||||
| """ | |||||
| Define parameters of model. | |||||
| """ | |||||
| batch_size = DataConfig.batch_size | |||||
| data_field_size = DataConfig.data_field_size | |||||
| data_vocab_size = DataConfig.data_vocab_size | |||||
| data_emb_dim = 80 | |||||
| deep_layer_args = [[400, 400, 512], "relu"] | |||||
| init_args = [-0.01, 0.01] | |||||
| weight_bias_init = ['normal', 'normal'] | |||||
| keep_prob = 0.9 | |||||
| class TrainConfig: | |||||
| """ | |||||
| Define parameters of training. | |||||
| """ | |||||
| batch_size = DataConfig.batch_size | |||||
| l2_coef = 1e-6 | |||||
| learning_rate = 1e-5 | |||||
| epsilon = 1e-8 | |||||
| loss_scale = 1024.0 | |||||
| train_epochs = 15 | |||||
| save_checkpoint = True | |||||
| ckpt_file_name_prefix = "deepfm" | |||||
| save_checkpoint_steps = 1 | |||||
| keep_checkpoint_max = 15 | |||||
| eval_callback = True | |||||
| loss_callback = True | |||||
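# Illustrative usage (mirrors train.py, not part of the original file):
#   from src.dataset import create_dataset, DataType
#   train_config = TrainConfig()
#   ds_train = create_dataset(dataset_path, train_mode=True,
#                             epochs=train_config.train_epochs,
#                             batch_size=train_config.batch_size,
#                             data_type=DataType(DataConfig.data_format))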
| @@ -0,0 +1,299 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Create train or eval dataset. | |||||
| """ | |||||
| import os | |||||
| import math | |||||
| from enum import Enum | |||||
| import pandas as pd | |||||
| import numpy as np | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.common.dtype as mstype | |||||
| from .config import DataConfig | |||||
| class DataType(Enum): | |||||
| """ | |||||
| Enumerate supported dataset format. | |||||
| """ | |||||
| MINDRECORD = 1 | |||||
| TFRECORD = 2 | |||||
| H5 = 3 | |||||
| class H5Dataset(): | |||||
| """ | |||||
| Create dataset with H5 format. | |||||
| Args: | |||||
| data_path (str): Dataset directory. | |||||
| train_mode (bool): Whether dataset is used for train or eval (default=True). | |||||
| train_num_of_parts (int): The number of train data file (default=21). | |||||
| test_num_of_parts (int): The number of test data file (default=3). | |||||
| """ | |||||
| max_length = 39 | |||||
| def __init__(self, data_path, train_mode=True, | |||||
| train_num_of_parts=DataConfig.train_num_of_parts, | |||||
| test_num_of_parts=DataConfig.test_num_of_parts): | |||||
| self._hdf_data_dir = data_path | |||||
| self._is_training = train_mode | |||||
| if self._is_training: | |||||
| self._file_prefix = 'train' | |||||
| self._num_of_parts = train_num_of_parts | |||||
| else: | |||||
| self._file_prefix = 'test' | |||||
| self._num_of_parts = test_num_of_parts | |||||
| self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts) | |||||
| print("data_size: {}".format(self.data_size)) | |||||
| def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts): | |||||
| size = 0 | |||||
| for part in range(num_of_parts): | |||||
| _y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5')) | |||||
| size += _y.shape[0] | |||||
| return size | |||||
| def _iterate_hdf_files_(self, num_of_parts=None, | |||||
| shuffle_block=False): | |||||
| """ | |||||
| iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts | |||||
| from the beginning, thus the data stream will never stop | |||||
| :param train_mode: True or false,false is eval_mode, | |||||
| this file iterator will go through the train set | |||||
| :param num_of_parts: number of files | |||||
| :param shuffle_block: shuffle block files at every round | |||||
| :return: input_hdf_file_name, output_hdf_file_name, finish_flag | |||||
| """ | |||||
| parts = np.arange(num_of_parts) | |||||
| while True: | |||||
| if shuffle_block: | |||||
| for _ in range(int(shuffle_block)): | |||||
| np.random.shuffle(parts) | |||||
| for i, p in enumerate(parts): | |||||
| yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \ | |||||
| os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \ | |||||
| i + 1 == len(parts) | |||||
| def _generator(self, X, y, batch_size, shuffle=True): | |||||
| """ | |||||
| should be accessed only in private | |||||
| :param X: | |||||
| :param y: | |||||
| :param batch_size: | |||||
| :param shuffle: | |||||
| :return: | |||||
| """ | |||||
| number_of_batches = np.ceil(1. * X.shape[0] / batch_size) | |||||
| counter = 0 | |||||
| finished = False | |||||
| sample_index = np.arange(X.shape[0]) | |||||
| if shuffle: | |||||
| for _ in range(int(shuffle)): | |||||
| np.random.shuffle(sample_index) | |||||
| assert X.shape[0] > 0 | |||||
| while True: | |||||
| batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)] | |||||
| X_batch = X[batch_index] | |||||
| y_batch = y[batch_index] | |||||
| counter += 1 | |||||
| yield X_batch, y_batch, finished | |||||
| if counter == number_of_batches: | |||||
| counter = 0 | |||||
| finished = True | |||||
| def batch_generator(self, batch_size=1000, | |||||
| random_sample=False, shuffle_block=False): | |||||
| """ | |||||
| :param train_mode: True or false,false is eval_mode, | |||||
| :param batch_size | |||||
| :param num_of_parts: number of files | |||||
| :param random_sample: if True, will shuffle | |||||
| :param shuffle_block: shuffle file blocks at every round | |||||
| :return: | |||||
| """ | |||||
| for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts, | |||||
| shuffle_block): | |||||
| start = stop = None | |||||
| X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values | |||||
| y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values | |||||
| data_gen = self._generator(X_all, y_all, batch_size, | |||||
| shuffle=random_sample) | |||||
| finished = False | |||||
| while not finished: | |||||
| X, y, finished = data_gen.__next__() | |||||
| X_id = X[:, 0:self.max_length] | |||||
| X_va = X[:, self.max_length:] | |||||
| yield np.array(X_id.astype(dtype=np.int32)), \ | |||||
| np.array(X_va.astype(dtype=np.float32)), \ | |||||
| np.array(y.astype(dtype=np.float32)) | |||||
| def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000): | |||||
| """ | |||||
| Get dataset with h5 format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for train or eval (default=True).
| epochs (int): Dataset epoch size (default=1). | |||||
| batch_size (int): Dataset batch size (default=1000) | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| data_para = {'batch_size': batch_size} | |||||
| if train_mode: | |||||
| data_para['random_sample'] = True | |||||
| data_para['shuffle_block'] = True | |||||
| h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode) | |||||
| numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size) | |||||
| def _iter_h5_data(): | |||||
| train_eval_gen = h5_dataset.batch_generator(**data_para) | |||||
| for _ in range(0, numbers_of_batch, 1): | |||||
| yield train_eval_gen.__next__() | |||||
| ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"]) | |||||
| ds.set_dataset_size(numbers_of_batch) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| line_per_sample=1000, rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset with mindrecord format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of samples packed into each record; batch_size must be divisible by it (default=1000).
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord' | |||||
| file_suffix_name = '00' if train_mode else '0' | |||||
| shuffle = train_mode | |||||
| if rank_size is not None and rank_id is not None: | |||||
| ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name), | |||||
| columns_list=['feat_ids', 'feat_vals', 'label'], | |||||
| num_shards=rank_size, shard_id=rank_id, shuffle=shuffle, | |||||
| num_parallel_workers=8) | |||||
| else: | |||||
| ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name), | |||||
| columns_list=['feat_ids', 'feat_vals', 'label'], | |||||
| shuffle=shuffle, num_parallel_workers=8) | |||||
| ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) | |||||
| ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39), | |||||
| np.array(y).flatten().reshape(batch_size, 39), | |||||
| np.array(z).flatten().reshape(batch_size, 1))), | |||||
| input_columns=['feat_ids', 'feat_vals', 'label'], | |||||
| columns_order=['feat_ids', 'feat_vals', 'label'], | |||||
| num_parallel_workers=8) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| line_per_sample=1000, rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset with tfrecord format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
line_per_sample (int): The number of samples packed into each record; batch_size must be divisible by it (default=1000).
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| dataset_files = [] | |||||
file_prefix_name = 'train' if train_mode else 'test'
| shuffle = train_mode | |||||
| for (dir_path, _, filenames) in os.walk(directory): | |||||
| for filename in filenames: | |||||
if file_prefix_name in filename and 'tfrecord' in filename:
| dataset_files.append(os.path.join(dir_path, filename)) | |||||
| schema = de.Schema() | |||||
| schema.add_column('feat_ids', de_type=mstype.int32) | |||||
| schema.add_column('feat_vals', de_type=mstype.float32) | |||||
| schema.add_column('label', de_type=mstype.float32) | |||||
| if rank_size is not None and rank_id is not None: | |||||
| ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, | |||||
| schema=schema, num_parallel_workers=8, | |||||
| num_shards=rank_size, shard_id=rank_id, | |||||
| shard_equal_rows=True) | |||||
| else: | |||||
| ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, | |||||
| schema=schema, num_parallel_workers=8) | |||||
| ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) | |||||
| ds = ds.map(operations=(lambda x, y, z: ( | |||||
| np.array(x).flatten().reshape(batch_size, 39), | |||||
| np.array(y).flatten().reshape(batch_size, 39), | |||||
| np.array(z).flatten().reshape(batch_size, 1))), | |||||
| input_columns=['feat_ids', 'feat_vals', 'label'], | |||||
| columns_order=['feat_ids', 'feat_vals', 'label'], | |||||
| num_parallel_workers=8) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| data_type=DataType.TFRECORD, line_per_sample=1000, | |||||
| rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for train or eval (default=True).
epochs (int): Dataset epoch size (default=1).
batch_size (int): Dataset batch size (default=1000).
data_type (DataType): The type of dataset, one of H5, TFRECORD, MINDRECORD (default=TFRECORD).
line_per_sample (int): The number of samples packed into each record; batch_size must be divisible by it (default=1000).
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| if data_type == DataType.MINDRECORD: | |||||
| return _get_mindrecord_dataset(directory, train_mode, epochs, | |||||
| batch_size, line_per_sample, | |||||
| rank_size, rank_id) | |||||
| if data_type == DataType.TFRECORD: | |||||
| return _get_tf_dataset(directory, train_mode, epochs, batch_size, | |||||
| line_per_sample, rank_size=rank_size, rank_id=rank_id) | |||||
| if rank_size is not None and rank_size > 1: | |||||
raise ValueError('The H5 dataset does not support distributed training; please use the MindRecord dataset.')
| return _get_h5_dataset(directory, train_mode, epochs, batch_size) | |||||
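# Illustrative (not part of the original file): sharded loading for distributed
# training, assuming MindRecord files and an 8-device setup.
#   ds = create_dataset('/path/to/criteo', train_mode=True, epochs=15,
#                       batch_size=1000, data_type=DataType.MINDRECORD,
#                       rank_size=8, rank_id=int(os.getenv('RANK_ID', '0')))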
| @@ -0,0 +1,370 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test_training """ | |||||
| import os | |||||
| import numpy as np | |||||
| from sklearn.metrics import roc_auc_score | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.nn import Dropout | |||||
| from mindspore.nn.optim import Adam | |||||
| from mindspore.nn.metrics import Metric | |||||
| from mindspore import nn, ParameterTuple, Parameter | |||||
| from mindspore.common.initializer import Uniform, initializer, Normal | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from .callback import EvalCallBack, LossCallBack | |||||
| np_type = np.float32 | |||||
| ms_type = mstype.float32 | |||||
| class AUCMetric(Metric): | |||||
| """AUC metric for DeepFM model.""" | |||||
| def __init__(self): | |||||
| super(AUCMetric, self).__init__() | |||||
| self.pred_probs = [] | |||||
| self.true_labels = [] | |||||
| def clear(self): | |||||
| """Clear the internal evaluation result.""" | |||||
| self.pred_probs = [] | |||||
| self.true_labels = [] | |||||
| def update(self, *inputs): | |||||
| batch_predict = inputs[1].asnumpy() | |||||
| batch_label = inputs[2].asnumpy() | |||||
| self.pred_probs.extend(batch_predict.flatten().tolist()) | |||||
| self.true_labels.extend(batch_label.flatten().tolist()) | |||||
| def eval(self): | |||||
| if len(self.true_labels) != len(self.pred_probs): | |||||
| raise RuntimeError('true_labels.size() is not equal to pred_probs.size()') | |||||
| auc = roc_auc_score(self.true_labels, self.pred_probs) | |||||
| return auc | |||||
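# Illustrative use of the Metric protocol (this is what Model.eval drives internally):
#   metric = AUCMetric()
#   metric.clear()
#   metric.update(logits, pred_probs, labels)   # once per batch, outputs of PredictWithSigmoid
#   auc = metric.eval()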
| def init_method(method, shape, name, max_val=0.01): | |||||
| """ | |||||
| The method of init parameters. | |||||
| Args: | |||||
| method (str): The method uses to initialize parameter. | |||||
| shape (list): The shape of parameter. | |||||
| name (str): The name of parameter. | |||||
| max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter. | |||||
| Returns: | |||||
| Parameter. | |||||
| """ | |||||
| if method in ['random', 'uniform']: | |||||
| params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name) | |||||
| elif method == "one": | |||||
| params = Parameter(initializer("ones", shape, ms_type), name=name) | |||||
| elif method == 'zero': | |||||
| params = Parameter(initializer("zeros", shape, ms_type), name=name) | |||||
| elif method == "normal": | |||||
| params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name) | |||||
| return params | |||||
| def init_var_dict(init_args, values): | |||||
| """ | |||||
| Init parameter. | |||||
| Args: | |||||
| init_args (list): Define max and min value of parameters. | |||||
| values (list): Define name, shape and init method of parameters. | |||||
| Returns: | |||||
dict, a dict of Parameter.
| """ | |||||
| var_map = {} | |||||
| _, _max_val = init_args | |||||
| for key, shape, init_flag in values: | |||||
| if key not in var_map.keys(): | |||||
| if init_flag in ['random', 'uniform']: | |||||
| var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key) | |||||
| elif init_flag == "one": | |||||
| var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key) | |||||
| elif init_flag == "zero": | |||||
| var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key) | |||||
| elif init_flag == 'normal': | |||||
| var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key) | |||||
| return var_map | |||||
| class DenseLayer(nn.Cell): | |||||
| """ | |||||
| Dense Layer for Deep Layer of DeepFM Model; | |||||
| Containing: activation, matmul, bias_add; | |||||
| Args: | |||||
input_dim (int): the shape of weight at axis 0;
output_dim (int): the shape of weight at axis 1, and the shape of bias;
| weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal"; | |||||
| act_str (str): activation function method, "relu", "sigmoid", "tanh"; | |||||
| keep_prob (float): Dropout Layer keep_prob_rate; | |||||
| scale_coef (float): input scale coefficient; | |||||
| """ | |||||
| def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0): | |||||
| super(DenseLayer, self).__init__() | |||||
| weight_init, bias_init = weight_bias_init | |||||
| self.weight = init_method(weight_init, [input_dim, output_dim], name="weight") | |||||
| self.bias = init_method(bias_init, [output_dim], name="bias") | |||||
| self.act_func = self._init_activation(act_str) | |||||
| self.matmul = P.MatMul(transpose_b=False) | |||||
| self.bias_add = P.BiasAdd() | |||||
| self.cast = P.Cast() | |||||
| self.dropout = Dropout(keep_prob=keep_prob) | |||||
| self.mul = P.Mul() | |||||
| self.realDiv = P.RealDiv() | |||||
| self.scale_coef = scale_coef | |||||
| def _init_activation(self, act_str): | |||||
| act_str = act_str.lower() | |||||
| if act_str == "relu": | |||||
| act_func = P.ReLU() | |||||
| elif act_str == "sigmoid": | |||||
| act_func = P.Sigmoid() | |||||
| elif act_str == "tanh": | |||||
| act_func = P.Tanh() | |||||
| return act_func | |||||
| def construct(self, x): | |||||
| x = self.act_func(x) | |||||
| if self.training: | |||||
| x = self.dropout(x) | |||||
| x = self.mul(x, self.scale_coef) | |||||
| x = self.cast(x, mstype.float16) | |||||
| weight = self.cast(self.weight, mstype.float16) | |||||
| wx = self.matmul(x, weight) | |||||
| wx = self.cast(wx, mstype.float32) | |||||
| wx = self.realDiv(wx, self.scale_coef) | |||||
| output = self.bias_add(wx, self.bias) | |||||
| return output | |||||
| class DeepFMModel(nn.Cell): | |||||
| """ | |||||
| From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction" | |||||
| Args: | |||||
batch_size (int): sample number per step in training; (int, batch_size=128)
field_size (int): input field number, also called id_feature number; (int, field_size=39)
| vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000) | |||||
| emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100) | |||||
| deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator; | |||||
| (int, deep_layer_args=[[100, 100, 100], "relu"]) | |||||
| init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds]) | |||||
| weight_bias_init (list): weight, bias init method for deep layers; | |||||
| (list[str], weight_bias_init=['random', 'zero']) | |||||
| keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8) | |||||
| """ | |||||
| def __init__(self, config): | |||||
| super(DeepFMModel, self).__init__() | |||||
| self.batch_size = config.batch_size | |||||
| self.field_size = config.data_field_size | |||||
| self.vocab_size = config.data_vocab_size | |||||
| self.emb_dim = config.data_emb_dim | |||||
| self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args | |||||
| self.init_args = config.init_args | |||||
| self.weight_bias_init = config.weight_bias_init | |||||
| self.keep_prob = config.keep_prob | |||||
| init_acts = [('W_l2', [self.vocab_size, 1], 'normal'), | |||||
| ('V_l2', [self.vocab_size, self.emb_dim], 'normal'), | |||||
| ('b', [1], 'normal')] | |||||
| var_map = init_var_dict(self.init_args, init_acts) | |||||
| self.fm_w = var_map["W_l2"] | |||||
| self.fm_b = var_map["b"] | |||||
| self.embedding_table = var_map["V_l2"] | |||||
| # Deep Layers | |||||
| self.deep_input_dims = self.field_size * self.emb_dim + 1 | |||||
| self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1] | |||||
| self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| # FM, linear Layers | |||||
| self.Gatherv2 = P.GatherV2() | |||||
| self.Mul = P.Mul() | |||||
| self.ReduceSum = P.ReduceSum(keep_dims=False) | |||||
| self.Reshape = P.Reshape() | |||||
| self.Square = P.Square() | |||||
| self.Shape = P.Shape() | |||||
| self.Tile = P.Tile() | |||||
| self.Concat = P.Concat(axis=1) | |||||
| self.Cast = P.Cast() | |||||
| def construct(self, id_hldr, wt_hldr): | |||||
| """ | |||||
| Args: | |||||
| id_hldr: batch ids; [bs, field_size] | |||||
| wt_hldr: batch weights; [bs, field_size] | |||||
| """ | |||||
| mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | |||||
| # Linear layer | |||||
| fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0) | |||||
| wx = self.Mul(fm_id_weight, mask) | |||||
| linear_out = self.ReduceSum(wx, 1) | |||||
| # FM layer | |||||
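# The second-order FM term is computed with the standard identity
# sum_{i<j} <v_i, v_j> x_i x_j = 0.5 * ((sum_i v_i x_i)^2 - sum_i (v_i x_i)^2),
# using Square/ReduceSum below.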
| fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0) | |||||
| vx = self.Mul(fm_id_embs, mask) | |||||
| v1 = self.ReduceSum(vx, 1) | |||||
| v1 = self.Square(v1) | |||||
| v2 = self.Square(vx) | |||||
| v2 = self.ReduceSum(v2, 1) | |||||
| fm_out = 0.5 * self.ReduceSum(v1 - v2, 1) | |||||
| fm_out = self.Reshape(fm_out, (-1, 1)) | |||||
| # Deep layer | |||||
| b = self.Reshape(self.fm_b, (1, 1)) | |||||
| b = self.Tile(b, (self.batch_size, 1)) | |||||
| deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim)) | |||||
| deep_in = self.Concat((deep_in, b)) | |||||
| deep_in = self.dense_layer_1(deep_in) | |||||
| deep_in = self.dense_layer_2(deep_in) | |||||
| deep_in = self.dense_layer_3(deep_in) | |||||
| deep_out = self.dense_layer_4(deep_in) | |||||
| out = linear_out + fm_out + deep_out | |||||
| return out, fm_id_weight, fm_id_embs | |||||
| class NetWithLossClass(nn.Cell): | |||||
| """ | |||||
| NetWithLossClass definition. | |||||
| """ | |||||
| def __init__(self, network, l2_coef=1e-6): | |||||
| super(NetWithLossClass, self).__init__(auto_prefix=False) | |||||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||||
| self.network = network | |||||
| self.l2_coef = l2_coef | |||||
| self.Square = P.Square() | |||||
| self.ReduceMean_false = P.ReduceMean(keep_dims=False) | |||||
| self.ReduceSum_false = P.ReduceSum(keep_dims=False) | |||||
| def construct(self, batch_ids, batch_wts, label): | |||||
| predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts) | |||||
| log_loss = self.loss(predict, label) | |||||
| mean_log_loss = self.ReduceMean_false(log_loss) | |||||
| l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight)) | |||||
| l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs)) | |||||
| l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5 | |||||
| loss = mean_log_loss + l2_loss_all | |||||
| return loss | |||||
| class TrainStepWrap(nn.Cell): | |||||
| """ | |||||
| TrainStepWrap definition | |||||
| """ | |||||
| def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0): | |||||
| super(TrainStepWrap, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_train() | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.sens = loss_scale | |||||
| def construct(self, batch_ids, batch_wts, label): | |||||
| weights = self.weights | |||||
| loss = self.network(batch_ids, batch_wts, label) | |||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)  # sens tensor filled with the loss_scale value
| grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| class PredictWithSigmoid(nn.Cell): | |||||
| """ | |||||
| Eval model with sigmoid. | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(PredictWithSigmoid, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.sigmoid = P.Sigmoid() | |||||
| def construct(self, batch_ids, batch_wts, labels): | |||||
logits, _, _ = self.network(batch_ids, batch_wts)
| pred_probs = self.sigmoid(logits) | |||||
| return logits, pred_probs, labels | |||||
| class ModelBuilder: | |||||
| """ | |||||
| Model builder for DeepFM. | |||||
| Args: | |||||
| model_config (ModelConfig): Model configuration. | |||||
| train_config (TrainConfig): Train configuration. | |||||
| """ | |||||
| def __init__(self, model_config, train_config): | |||||
| self.model_config = model_config | |||||
| self.train_config = train_config | |||||
| def get_callback_list(self, model=None, eval_dataset=None): | |||||
| """ | |||||
| Get callbacks which contains checkpoint callback, eval callback and loss callback. | |||||
| Args: | |||||
| model (Cell): The network is added callback (default=None). | |||||
| eval_dataset (Dataset): Dataset for eval (default=None). | |||||
| """ | |||||
| callback_list = [] | |||||
| if self.train_config.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps, | |||||
| keep_checkpoint_max=self.train_config.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix, | |||||
| directory=self.train_config.output_path, | |||||
| config=config_ck) | |||||
| callback_list.append(ckpt_cb) | |||||
| if self.train_config.eval_callback: | |||||
| if model is None: | |||||
| raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format( | |||||
| self.train_config.eval_callback, model)) | |||||
| if eval_dataset is None: | |||||
| raise RuntimeError("train_config.eval_callback is {}; get_callback_list() " | |||||
| "args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset)) | |||||
| auc_metric = AUCMetric() | |||||
| eval_callback = EvalCallBack(model, eval_dataset, auc_metric, | |||||
| eval_file_path=os.path.join(self.train_config.output_path, | |||||
| self.train_config.eval_file_name)) | |||||
| callback_list.append(eval_callback) | |||||
| if self.train_config.loss_callback: | |||||
| loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path, | |||||
| self.train_config.loss_file_name)) | |||||
| callback_list.append(loss_callback) | |||||
| if callback_list: | |||||
| return callback_list | |||||
| return None | |||||
| def get_train_eval_net(self): | |||||
| deepfm_net = DeepFMModel(self.model_config) | |||||
| loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef) | |||||
| train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate, | |||||
| eps=self.train_config.epsilon, | |||||
| loss_scale=self.train_config.loss_scale) | |||||
| eval_net = PredictWithSigmoid(deepfm_net) | |||||
| return train_net, eval_net | |||||
| @@ -0,0 +1,91 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train_criteo.""" | |||||
| import os | |||||
| import sys | |||||
| import argparse | |||||
| from mindspore import context, ParallelMode | |||||
| from mindspore.communication.management import init | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||||
| from src.deepfm import ModelBuilder, AUCMetric | |||||
| from src.config import DataConfig, ModelConfig, TrainConfig | |||||
| from src.dataset import create_dataset, DataType | |||||
| from src.callback import EvalCallBack, LossCallBack | |||||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||||
| parser = argparse.ArgumentParser(description='CTR Prediction') | |||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||||
| parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path') | |||||
| parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path') | |||||
| parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path') | |||||
| parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.') | |||||
| args_opt, _ = parser.parse_known_args() | |||||
device_id = int(os.getenv('DEVICE_ID', '0'))
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) | |||||
| if __name__ == '__main__': | |||||
| data_config = DataConfig() | |||||
| model_config = ModelConfig() | |||||
| train_config = TrainConfig() | |||||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||||
| if rank_size > 1: | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) | |||||
| init() | |||||
| rank_id = int(os.environ.get('RANK_ID')) | |||||
| else: | |||||
| rank_size = None | |||||
| rank_id = None | |||||
| ds_train = create_dataset(args_opt.dataset_path, | |||||
| train_mode=True, | |||||
| epochs=train_config.train_epochs, | |||||
| batch_size=train_config.batch_size, | |||||
| data_type=DataType(data_config.data_format), | |||||
| rank_size=rank_size, | |||||
| rank_id=rank_id) | |||||
| model_builder = ModelBuilder(ModelConfig, TrainConfig) | |||||
| train_net, eval_net = model_builder.get_train_eval_net() | |||||
| auc_metric = AUCMetric() | |||||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | |||||
| time_callback = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||||
| loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name) | |||||
| callback_list = [time_callback, loss_callback] | |||||
| if train_config.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, | |||||
| keep_checkpoint_max=train_config.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, | |||||
| directory=args_opt.ckpt_path, | |||||
| config=config_ck) | |||||
| callback_list.append(ckpt_cb) | |||||
| if args_opt.do_eval: | |||||
| ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, | |||||
| epochs=train_config.train_epochs, | |||||
| batch_size=train_config.batch_size, | |||||
| data_type=DataType(data_config.data_format)) | |||||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, | |||||
| eval_file_path=args_opt.eval_file_name) | |||||
| callback_list.append(eval_callback) | |||||
| model.train(train_config.train_epochs, ds_train, callbacks=callback_list) | |||||