Merge pull request !3146 from yao_yf/modezoo_widedeep_run_clusters (tag: v0.6.0-beta)
@@ -97,7 +97,7 @@ class ReshapeInfo : public OperatorInfo {
   TensorLayout output_layout_;
   bool input_layout_set_flag_;
   bool output_layout_set_flag_;
-  bool is_generating_costs_;
+  bool is_generating_costs_ = false;
   bool is_skip_ = false;
   std::string pre_operator_name_;
   std::string next_operator_name_;
@@ -16,7 +16,7 @@ Arguments:
 * `--data_path`: Dataset storage path (Default: ./criteo_data/).
 
 ## Dataset
-The Criteo datasets are used for model training and evaluation.
+The commonly used benchmark datasets are used for model training and evaluation.
 
 ## Running Code
@@ -63,6 +63,7 @@ Arguments:
 * `--ckpt_path`: The location of the checkpoint file.
 * `--eval_file_name`: Eval output file.
 * `--loss_file_name`: Loss output file.
+* `--dataset_type`: tfrecord/mindrecord/h5.
 
 To train the model on a single device, run the following command:
 ```
@@ -84,6 +85,7 @@ Arguments:
 * `--ckpt_path`: The location of the checkpoint file.
 * `--eval_file_name`: Eval output file.
 * `--loss_file_name`: Loss output file.
+* `--dataset_type`: tfrecord/mindrecord/h5.
 
 To train the model in distributed mode, run the following command:
 ```
@@ -95,6 +97,19 @@ bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
 bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
 ```
+
+To train the model on a cluster, run the following commands:
+```
+# Deploy the wide&deep scripts to the cluster nodes.
+# CLUSTER_CONFIG is a JSON file; a sample is provided in script/.
+# EXECUTE_PATH is the path of the scripts after deployment.
+bash deploy_cluster.sh CLUSTER_CONFIG_PATH EXECUTE_PATH
+# Enter EXECUTE_PATH and execute start_cluster.sh as follows.
+# MODE: "host_device_mix"
+bash start_cluster.sh CLUSTER_CONFIG_PATH EPOCH_SIZE VOCAB_SIZE EMB_DIM \
+     DATASET RANK_TABLE_FILE ENV_SH MODE
+```
+
 To evaluate the model, run the following command:
 ```
 python eval.py
@@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset
+from src.datasets import create_dataset, DataType
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig
@@ -69,8 +69,14 @@ def test_eval(config):
     """
     data_path = config.data_path
     batch_size = config.batch_size
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=2,
-                             batch_size=batch_size)
+    if config.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif config.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        dataset_type = DataType.H5
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                             batch_size=batch_size, data_type=dataset_type)
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
     net_builder = ModelBuilder()
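Note: the same string-to-`DataType` mapping recurs verbatim in several entry scripts touched by this PR (eval, train, train_and_eval, and the distributed/auto-parallel variants). A small helper along these lines, which is hypothetical and not part of the change, would remove the duplication:

```python
# Hypothetical helper, not part of this PR: centralizes the config string ->
# DataType mapping that each entry script currently repeats.
from src.datasets import DataType

def resolve_dataset_type(name):
    """Map the --dataset_type string to the DataType enum; anything else falls back to H5."""
    mapping = {"tfrecord": DataType.TFRECORD, "mindrecord": DataType.MINDRECORD}
    return mapping.get(name, DataType.H5)

# Usage sketch:
#   ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
#                            batch_size=batch_size,
#                            data_type=resolve_dataset_type(config.dataset_type))
```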
@@ -0,0 +1,21 @@
+{
+    "rank_size": 32,
+    "cluster": {
+        "xx.xx.xx.xx": {
+            "user": "",
+            "passwd": ""
+        },
+        "xx.xx.xx.xx": {
+            "user": "",
+            "passwd": ""
+        },
+        "xx.xx.xx.xx": {
+            "user": "",
+            "passwd": ""
+        },
+        "xx.xx.xx.xx": {
+            "user": "",
+            "passwd": ""
+        }
+    }
+}
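Note: the deploy and start scripts below read this file with inline `python3` one-liners. A standalone sanity check of the same structure might look like this; it is illustrative only, and the file name is a placeholder:

```python
import json

# Illustrative check of the cluster-config layout used by deploy_cluster.sh /
# start_cluster.sh: a top-level "rank_size" plus a "cluster" map of
# node IP -> {"user": ..., "passwd": ...}.
with open("cluster_config.json") as f:   # hypothetical file name
    cfg = json.load(f)

assert isinstance(cfg["rank_size"], int)
# The scripts walk the nodes in blocks of 8 devices (RANK_START += 8 per node),
# so the listed nodes should cover rank_size.
assert 8 * len(cfg["cluster"]) >= cfg["rank_size"]
for node, cred in cfg["cluster"].items():
    assert {"user", "passwd"} <= cred.keys()
```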
@@ -0,0 +1,95 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+SSH="ssh -o StrictHostKeyChecking=no"
+SCP="scp -o StrictHostKeyChecking=no"
+
+error_msg()
+{
+    local msg="$*"
+    echo "[ERROR]: $msg" 1>&2
+    exit 1
+}
+
+ssh_pass()
+{
+    local node="$1"
+    local user="$2"
+    local passwd="$3"
+    shift 3
+    local cmd="$*"
+    sshpass -p "${passwd}" ${SSH} "${user}"@"${node}" ${cmd}
+}
+
+scp_pass()
+{
+    local node="$1"
+    local user="$2"
+    local passwd="$3"
+    local src="$4"
+    local target="$5"
+    sshpass -p "${passwd}" ${SCP} -r "${src}" "${user}"@"${node}":"${target}"
+}
+
+rscp_pass()
+{
+    local node="$1"
+    local user="$2"
+    local passwd="$3"
+    local src="$4"
+    local target="$5"
+    sshpass -p "${passwd}" ${SCP} -r "${user}"@"${node}":"${src}" "${target}"
+}
+
+get_rank_size()
+{
+    local cluster_config=$1
+    cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["rank_size"])'
+}
+
+get_train_dataset()
+{
+    local cluster_config=$1
+    cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["train_dataset"])'
+}
+
+get_cluster_list()
+{
+    local cluster_config=$1
+    cat ${cluster_config} | python3 -c 'import sys,json;[print(node) for node in json.load(sys.stdin)["cluster"].keys()]' | sort
+}
+
+get_node_user()
+{
+    local cluster_config=$1
+    local node=$2
+    cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["user"])'
+}
+
+get_node_passwd()
+{
+    local cluster_config=$1
+    local node=$2
+    cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["passwd"])'
+}
+
+rsync_sshpass()
+{
+    local node=$1
+    local user="$2"
+    local passwd="$3"
+    scp_pass "${node}" "${user}" "${passwd}" /usr/local/bin/sshpass /usr/local/bin/sshpass
+}
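Note: the JSON lookups above are done with inline `python3` one-liners fed from stdin; the `'\"${node}\"'` sequence closes the single-quoted Python snippet, splices the node IP in as a double-quoted string key, and reopens the quote. In plain Python the same lookups reduce to the following illustrative restatement, which is not part of the scripts:

```python
import json
import sys

# Equivalent of get_rank_size / get_cluster_list / get_node_user / get_node_passwd
# when the cluster config JSON is piped in on stdin.
cfg = json.load(sys.stdin)

rank_size = cfg["rank_size"]              # get_rank_size
nodes = sorted(cfg["cluster"].keys())     # get_cluster_list (sorted, as in the script)
node = nodes[0]                           # example node
user = cfg["cluster"][node]["user"]       # get_node_user
passwd = cfg["cluster"][node]["passwd"]   # get_node_passwd
```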
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+SCRIPTPATH="$( cd "$(dirname "$0")" || exit ; pwd -P )"
+# shellcheck source=/dev/null
+source $SCRIPTPATH/common.sh
+
+cluster_config_path=$1
+execute_path=$2
+RANK_SIZE=$(get_rank_size ${cluster_config_path})
+RANK_START=0
+node_list=$(get_cluster_list ${cluster_config_path})
+
+for node in ${node_list}
+do
+    user=$(get_node_user ${cluster_config_path} ${node})
+    passwd=$(get_node_passwd ${cluster_config_path} ${node})
+    echo "------------------${user}@${node}---------------------"
+    ssh_pass ${node} ${user} ${passwd} "rm -rf ${execute_path}"
+    scp_pass ${node} ${user} ${passwd} $SCRIPTPATH/../../wide_and_deep ${execute_path}
+    RANK_START=$[RANK_START+8]
+    if [[ $RANK_START -ge $RANK_SIZE ]]; then
+        break;
+    fi
+done
@@ -0,0 +1,48 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+execute_path=$(pwd)
+echo ${execute_path}
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+echo ${self_path}
+
+export RANK_SIZE=$1
+RANK_START=$2
+EPOCH_SIZE=$3
+VOCAB_SIZE=$4
+EMB_DIM=$5
+DATASET=$6
+ENV_SH=$7
+MODE=$8
+export MINDSPORE_HCCL_CONFIG=$9
+export RANK_TABLE_FILE=$9
+DEVICE_START=0
+# shellcheck source=/dev/null
+source $ENV_SH
+
+for((i=0;i<=7;i++));
+do
+    export RANK_ID=$[i+RANK_START]
+    export DEVICE_ID=$[i+DEVICE_START]
+    rm -rf ${execute_path}/device_$RANK_ID
+    mkdir ${execute_path}/device_$RANK_ID
+    cd ${execute_path}/device_$RANK_ID || exit
+    if [ $MODE == "host_device_mix" ]; then
+        python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
+    else
+        python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
+    fi
+done
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+execute_path=$(pwd)
+echo ${execute_path}
+script_self=$(readlink -f "$0")
+SCRIPTPATH=$(dirname "${script_self}")
+echo ${SCRIPTPATH}
+# shellcheck source=/dev/null
+source $SCRIPTPATH/common.sh
+
+cluster_config_path=$1
+RANK_SIZE=$(get_rank_size ${cluster_config_path})
+RANK_START=0
+node_list=$(get_cluster_list ${cluster_config_path})
+EPOCH_SIZE=$2
+VOCAB_SIZE=$3
+EMB_DIM=$4
+DATASET=$5
+MINDSPORE_HCCL_CONFIG_PATH=$6
+ENV_SH=$7
+MODE=$8
+
+for node in ${node_list}
+do
+    user=$(get_node_user ${cluster_config_path} ${node})
+    passwd=$(get_node_passwd ${cluster_config_path} ${node})
+    echo "------------------${user}@${node}---------------------"
+    if [ $MODE == "host_device_mix" ]; then
+        ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${MINDSPORE_HCCL_CONFIG_PATH}"
+    else
+        echo "[ERROR] mode is wrong"
+        exit 1
+    fi
+    RANK_START=$[RANK_START+8]
+    if [[ $RANK_START -ge $RANK_SIZE ]]; then
+        break;
+    fi
+done
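Note: taken together, deploy_cluster.sh copies the code to every node, start_cluster.sh hands each node a block of eight consecutive ranks, and run_auto_parallel_train_cluster.sh launches one process per local device on that node. A small sketch of that rank bookkeeping follows; the values and IPs are made up for illustration:

```python
# Illustrative only: how the scripts above assign ranks, assuming 8 devices per node.
rank_size = 32
nodes = ["10.0.0.1", "10.0.0.2", "10.0.0.3", "10.0.0.4"]  # placeholder IPs

for rank_start, node in zip(range(0, rank_size, 8), nodes):
    # start_cluster.sh passes RANK_START to the node; on that node,
    # run_auto_parallel_train_cluster.sh sets RANK_ID = RANK_START + i and
    # DEVICE_ID = i for i in 0..7.
    ranks = list(range(rank_start, rank_start + 8))
    print(f"{node}: RANK_IDs {ranks}, DEVICE_IDs 0..7")
```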
@@ -51,7 +51,7 @@ class LossCallBack(Callback):
         wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
         cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
         cur_num = cb_params.cur_step_num
-        print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)
+        print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
         # raise ValueError
         if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
@@ -76,7 +76,7 @@ class EvalCallBack(Callback):
     Args:
         print_per_step (int): Print the evaluation result every `print_per_step` steps. Default: 1.
     """
-    def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
+    def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
         super(EvalCallBack, self).__init__()
         if not isinstance(print_per_step, int) or print_per_step < 0:
             raise ValueError("print_per_step must be int and >= 0.")
@@ -87,6 +87,7 @@ class EvalCallBack(Callback):
         self.aucMetric.clear()
         self.eval_file_name = config.eval_file_name
         self.eval_values = []
+        self.host_device_mix = host_device_mix
 
     def epoch_end(self, run_context):
         """
@@ -98,7 +99,7 @@ class EvalCallBack(Callback):
             context.set_auto_parallel_context(strategy_ckpt_save_file="",
                                               strategy_ckpt_load_file="./strategy_train.ckpt")
         start_time = time.time()
-        out = self.model.eval(self.eval_dataset)
+        out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
         end_time = time.time()
         eval_time = int(end_time - start_time)
@@ -38,6 +38,8 @@ def argparse_init():
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
     parser.add_argument("--eval_file_name", type=str, default="eval.log")
     parser.add_argument("--loss_file_name", type=str, default="loss.log")
+    parser.add_argument("--host_device_mix", type=int, default=0)
+    parser.add_argument("--dataset_type", type=str, default="tfrecord")
     return parser
@@ -68,6 +70,8 @@ class WideDeepConfig():
         self.eval_file_name = "eval.log"
         self.loss_file_name = "loss.log"
         self.ckpt_path = "./checkpoints/"
+        self.host_device_mix = 0
+        self.dataset_type = "tfrecord"
 
     def argparse_init(self):
         """
@@ -97,3 +101,5 @@ class WideDeepConfig():
         self.eval_file_name = args.eval_file_name
         self.loss_file_name = args.loss_file_name
         self.ckpt_path = args.ckpt_path
+        self.host_device_mix = args.host_device_mix
+        self.dataset_type = args.dataset_type
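Note: assuming `argparse_init` is importable from `src.config` as the hunk above suggests, the two new options surface on the command line as sketched below; the argument values are made up:

```python
from src.config import argparse_init

# Sketch only: how the new flags added in this diff would parse.
parser = argparse_init()
args, _ = parser.parse_known_args(["--dataset_type", "mindrecord", "--host_device_mix", "1"])
assert args.dataset_type == "mindrecord"   # "tfrecord" (default), "mindrecord", anything else -> H5
assert args.host_device_mix == 1           # 1 enables the host/device mixed embedding path
```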
@@ -20,7 +20,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
-from mindspore.nn.optim import Adam, FTRL
+from mindspore.nn.optim import Adam, FTRL, LazyAdam
 # from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
 # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
@@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
     """
     def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
-                 keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False):
+                 keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(
@@ -137,8 +137,10 @@ class WideDeepModel(nn.Cell):
     def __init__(self, config):
         super(WideDeepModel, self).__init__()
         self.batch_size = config.batch_size
+        host_device_mix = bool(config.host_device_mix)
         parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
         self.field_size = config.field_size
         self.vocab_size = config.vocab_size
@@ -187,16 +189,29 @@ class WideDeepModel(nn.Cell):
                                         self.weight_bias_init,
                                         self.deep_layer_act,
                                         use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
-        self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
-        self.mul = P.Mul()
+        self.wide_mul = P.Mul()
+        self.deep_mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
+        self.deep_reshape = P.Reshape()
         self.square = P.Square()
         self.shape = P.Shape()
         self.tile = P.Tile()
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
+        if is_auto_parallel and host_device_mix:
+            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.deep_embeddinglookup = nn.EmbeddingLookup()
+            self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+            self.wide_embeddinglookup = nn.EmbeddingLookup()
+            self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
+            self.deep_reshape.add_prim_attr("skip_redistribution", True)
+            self.reduce_sum.add_prim_attr("cross_batch", True)
+        else:
+            self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
+            self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
 
     def construct(self, id_hldr, wt_hldr):
         """
@@ -206,13 +221,13 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
-        wx = self.mul(wide_id_weight, mask)
+        wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
+        wx = self.wide_mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
-        vx = self.mul(deep_id_embs, mask)
-        deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
+        deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
+        vx = self.deep_mul(deep_id_embs, mask)
+        deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
@@ -233,19 +248,28 @@ class NetWithLossClass(nn.Cell):
     def __init__(self, network, config):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
+        host_device_mix = bool(config.host_device_mix)
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        self.no_l2loss = host_device_mix and is_auto_parallel
         self.network = network
         self.l2_coef = config.l2_coef
         self.loss = P.SigmoidCrossEntropyWithLogits()
         self.square = P.Square()
         self.reduceMean_false = P.ReduceMean(keep_dims=False)
+        if is_auto_parallel:
+            self.reduceMean_false.add_prim_attr("cross_batch", True)
         self.reduceSum_false = P.ReduceSum(keep_dims=False)
 
     def construct(self, batch_ids, batch_wts, label):
         predict, embedding_table = self.network(batch_ids, batch_wts)
         log_loss = self.loss(predict, label)
         wide_loss = self.reduceMean_false(log_loss)
-        l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
-        deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
+        if self.no_l2loss:
+            deep_loss = wide_loss
+        else:
+            l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
+            deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
         return wide_loss, deep_loss
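Note: restated outside the graph, `wide_loss` is the mean sigmoid cross-entropy, and `deep_loss` adds an L2 penalty on the embedding table unless the host/device-mix path disables it. A NumPy sketch of that branch, purely illustrative and not code from the PR:

```python
import numpy as np

def wide_deep_losses(log_loss, embedding_table, l2_coef, no_l2loss):
    """NumPy restatement (illustrative only) of NetWithLossClass.construct above."""
    wide_loss = log_loss.mean()
    if no_l2loss:                      # host_device_mix and (semi-)auto parallel
        return wide_loss, wide_loss
    l2_loss = np.square(embedding_table).sum() / 2
    return wide_loss, log_loss.mean() + l2_coef * l2_loss
```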
@@ -267,12 +291,15 @@
     Append Adam and FTRL optimizers to the training network, after which the construct
     function can be called to create the backward graph.
 
     Args:
-        network (Cell): the training network. Note that loss function should have been added.
-        sens (Number): The adjust parameter. Default: 1000.0
+        network (Cell): The training network. Note that the loss function should have been added.
+        sens (Number): The adjust parameter. Default: 1024.0
+        host_device_mix (bool): Whether to run in host-device mixed mode. Default: False
     """
-    def __init__(self, network, sens=1024.0):
+    def __init__(self, network, sens=1024.0, host_device_mix=False):
         super(TrainStepWrap, self).__init__()
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()
         self.trainable_params = network.trainable_params()
@@ -285,10 +312,19 @@
                 weights_d.append(params)
         self.weights_w = ParameterTuple(weights_w)
         self.weights_d = ParameterTuple(weights_d)
-        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
-                                l1=1e-8, l2=1e-8, initial_accum=1.0)
-        self.optimizer_d = Adam(
-            self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+        if host_device_mix and is_auto_parallel:
+            self.optimizer_d = LazyAdam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
+            self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
+            self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
+        else:
+            self.optimizer_d = Adam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
         self.hyper_map = C.HyperMap()
         self.grad_w = C.GradOperation('grad_w', get_by_list=True,
                                       sens_param=True)
@@ -17,7 +17,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack
-from src.datasets import create_dataset
+from src.datasets import create_dataset, DataType
 from src.config import WideDeepConfig
@@ -63,7 +63,14 @@ def test_train(configure):
     data_path = configure.data_path
     batch_size = configure.batch_size
     epochs = configure.epochs
-    ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
+    if configure.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif configure.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        dataset_type = DataType.H5
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                              batch_size=batch_size, data_type=dataset_type)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     net_builder = ModelBuilder()
@@ -19,7 +19,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset
+from src.datasets import create_dataset, DataType
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig
@@ -67,8 +67,16 @@ def test_train_eval(config):
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
-    ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
+    if config.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif config.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        dataset_type = DataType.H5
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                              batch_size=batch_size, data_type=dataset_type)
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                             batch_size=batch_size, data_type=dataset_type)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
@@ -27,13 +27,14 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset
+from src.datasets import create_dataset, DataType
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
-context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
+context.set_context(variable_memory_max_size="24GB")
+context.set_context(enable_sparse=True)
 cost_model_context.set_cost_model_context(multi_subgraphs=True)
 init()
@@ -46,7 +47,7 @@ def get_WideDeep_net(config):
     WideDeep_net = WideDeepModel(config)
     loss_net = NetWithLossClass(WideDeep_net, config)
     loss_net = VirtualDatasetCellTriple(loss_net)
-    train_net = TrainStepWrap(loss_net)
+    train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
     eval_net = PredictWithSigmoid(WideDeep_net)
     eval_net = VirtualDatasetCellTriple(eval_net)
     return train_net, eval_net
@@ -81,19 +82,28 @@ def train_and_eval(config):
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
+    if config.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif config.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        dataset_type = DataType.H5
+    host_device_mix = bool(config.host_device_mix)
     print("epochs is {}".format(epochs))
     if config.full_batch:
         context.set_auto_parallel_context(full_batch=True)
         de.config.set_seed(1)
         ds_train = create_dataset(data_path, train_mode=True, epochs=1,
-                                  batch_size=batch_size*get_group_size())
+                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
         ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
-                                 batch_size=batch_size*get_group_size())
+                                 batch_size=batch_size*get_group_size(), data_type=dataset_type)
     else:
         ds_train = create_dataset(data_path, train_mode=True, epochs=1,
-                                  batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+                                  batch_size=batch_size, rank_id=get_rank(),
+                                  rank_size=get_group_size(), data_type=dataset_type)
         ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
-                                 batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+                                 batch_size=batch_size, rank_id=get_rank(),
+                                 rank_size=get_group_size(), data_type=dataset_type)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
@@ -105,18 +115,24 @@ def train_and_eval(config):
     model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
-    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
+    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
     context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
-    model.train(epochs, ds_train,
-                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
+    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
+    if not host_device_mix:
+        callback_list.append(ckpoint_cb)
+    model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
 
 
 if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
+    if wide_deep_config.host_device_mix == 1:
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
+    else:
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
     train_and_eval(wide_deep_config)
@@ -25,7 +25,7 @@ from mindspore.communication.management import get_rank, get_group_size, init
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset
+from src.datasets import create_dataset, DataType
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig
@@ -73,11 +73,19 @@ def train_and_eval(config):
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
+    if config.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif config.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        dataset_type = DataType.H5
     print("epochs is {}".format(epochs))
     ds_train = create_dataset(data_path, train_mode=True, epochs=1,
-                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+                              batch_size=batch_size, rank_id=get_rank(),
+                              rank_size=get_group_size(), data_type=dataset_type)
     ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
-                             batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
+                             batch_size=batch_size, rank_id=get_rank(),
+                             rank_size=get_group_size(), data_type=dataset_type)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
@@ -21,10 +21,10 @@ export RANK_SIZE=$DEVICE_NUM
 unset SLOG_PRINT_TO_STDOUT
 export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
 CODE_DIR="./"
-if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then
-    CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep
-elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then
-    CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep
+if [ -d ${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep ]; then
+    CODE_DIR=${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep
+elif [ -d ${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep ]; then
+    CODE_DIR=${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep
 else
     echo "[ERROR] code dir is not found"
 fi