|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- '''
- Bert finetune and evaluation script.
- '''
- import os
- import argparse
- import collections
- from src.bert_for_finetune import BertSquadCell, BertSquad
- from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
- from src.dataset import create_squad_dataset
- from src import tokenization
- from src.create_squad_data import read_squad_examples, convert_examples_to_features
- from src.run_squad import write_predictions
- from src.utils import make_directory, LossCallBack, LoadNewestCkpt
- import mindspore.common.dtype as mstype
- from mindspore import context
- from mindspore import log as logger
- from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
- from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
- from mindspore.common.tensor import Tensor
- from mindspore.train.model import Model
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
# Working directory captured once at import time. run_squad uses it as the
# directory searched for the newest "squad" fine-tuned checkpoint when
# --save_finetune_checkpoint_path is left empty.
_cur_dir = os.getcwd()
-
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """Fine-tune a BERT SQuAD network starting from a pre-trained checkpoint.

    Args:
        dataset: Training dataset. ``get_dataset_size()`` is used as the number of
            steps per epoch and ``get_repeat_count()`` as the number of epochs.
        network: Network to fine-tune. The pre-trained checkpoint is loaded into it
            and its ``trainable_params()`` are handed to the optimizer.
        load_checkpoint_path: Pre-trained checkpoint file. Required.
        save_checkpoint_path: Directory passed to ``ModelCheckpoint`` for the
            fine-tuned checkpoints (file prefix ``"squad"``).

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty or None, or if
            ``optimizer_cfg.optimizer`` is not one of the supported optimizers.
    """
    # `not` rejects None as well as the empty string.
    if not load_checkpoint_path:
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # The decaying schedules span the whole run and warm up over its first 10%.
    total_steps = steps_per_epoch * epoch_num
    warmup_steps = int(total_steps * 0.1)

    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        adam_cfg = optimizer_cfg.AdamWeightDecayDynamicLR
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=total_steps,
                                             learning_rate=adam_cfg.learning_rate,
                                             end_learning_rate=adam_cfg.end_learning_rate,
                                             power=adam_cfg.power,
                                             warmup_steps=warmup_steps,
                                             weight_decay=adam_cfg.weight_decay,
                                             eps=adam_cfg.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        lamb_cfg = optimizer_cfg.Lamb
        optimizer = Lamb(network.trainable_params(),
                         decay_steps=total_steps,
                         start_learning_rate=lamb_cfg.start_learning_rate,
                         end_learning_rate=lamb_cfg.end_learning_rate,
                         power=lamb_cfg.power,
                         weight_decay=lamb_cfg.weight_decay,
                         warmup_steps=warmup_steps,
                         decay_filter=lamb_cfg.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(),
                             learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still match.
        raise ValueError("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")

    # Save a checkpoint once per epoch and keep only the most recent one.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config)

    # Load the pre-trained weights before wrapping the network for training.
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)

    # Dynamic loss scaling: initial scale 2**32, scale_factor 2, scale_window 1000.
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(steps_per_epoch), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
-
-
def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384,
            output_file="./predictions.json"):
    """Run SQuAD inference with a fine-tuned checkpoint and write the predictions.

    Args:
        dataset: Evaluation dataset whose dict iterator yields the columns
            ``input_ids``, ``input_mask``, ``segment_ids`` and ``unique_ids``.
        vocab_file: Vocabulary file for the lower-casing ``FullTokenizer``.
        eval_json: Evaluation examples file read by ``read_squad_examples``.
        load_checkpoint_path: Fine-tuned checkpoint to load. Required.
        seq_length: ``max_seq_length`` used when converting examples to features.
        output_file: Prediction file passed to ``write_predictions``.

    Raises:
        ValueError: If ``load_checkpoint_path`` is empty or None.
    """
    # `not` rejects None as well as the empty string.
    if not load_checkpoint_path:
        raise ValueError("Finetune model missed, evaluation task must load finetune model!")
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
    eval_examples = read_squad_examples(eval_json, False)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=seq_length,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        output_fn=None,
        verbose_logging=False)

    net = BertSquad(bert_net_cfg, False, 2)
    net.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net, param_dict)
    model = Model(net)

    RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
    columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
    # Constant placeholders for the label inputs that predict() is called with;
    # presumably ignored when the net is built with is_training=False -- confirm
    # in src.bert_for_finetune. Built once instead of once per batch.
    start_positions = Tensor([1], mstype.float32)
    end_positions = Tensor([1], mstype.float32)
    is_impossible = Tensor([1], mstype.float32)

    output = []
    for data in dataset.create_dict_iterator():
        input_ids, input_mask, segment_ids, unique_ids = (Tensor(data[col]) for col in columns_list)
        logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
                               end_positions, unique_ids, is_impossible)
        # The outputs are consumed as (unique ids, start logits, end logits).
        ids = logits[0].asnumpy()
        start = logits[1].asnumpy()
        end = logits[2].asnumpy()

        # Walk the rows actually returned rather than assuming a full
        # bert_net_cfg.batch_size batch, so a short final batch cannot IndexError.
        for unique_id, start_row, end_row in zip(ids, start, end):
            output.append(RawResult(
                unique_id=int(unique_id),
                start_logits=[float(x) for x in start_row.flat],
                end_logits=[float(x) for x in end_row.flat]))
    # NOTE(review): positional args 20, 30, True look like n_best_size,
    # max_answer_length and do_lower_case -- confirm against src.run_squad.
    write_predictions(eval_examples, eval_features, output, 20, 30, True, output_file, None, None)
-
def _parse_squad_args():
    """Parse the run_squad command-line arguments from sys.argv."""
    parser = argparse.ArgumentParser(description="run squad")
    parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend")
    parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false")
    parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--epoch_num", type=int, default=1, help="Epoch number, default is 1.")
    parser.add_argument("--num_class", type=int, default=2, help="The number of class, default is 2.")
    parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path")
    parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json")
    parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="",
                        help="Load pretrain checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path", type=str, default="",
                        help="Load finetune checkpoint file path")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    return parser.parse_args()


def _check_squad_args(args_opt, train_enabled, eval_enabled):
    """Raise ValueError when no task is enabled or a path the enabled tasks need is missing."""
    if not (train_enabled or eval_enabled):
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if train_enabled and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when do finetune task")
    if eval_enabled:
        if args_opt.eval_data_file_path == "":
            raise ValueError("'eval_data_file_path' must be set when do evaluation task")
        if args_opt.vocab_file_path == "":
            raise ValueError("'vocab_file_path' must be set when do evaluation task")
        if args_opt.eval_json_path == "":
            raise ValueError("'eval_json_path' must be set when do evaluation task")


def _set_device_context(args_opt):
    """Switch MindSpore to graph mode on the requested device (forcing fp32 on GPU)."""
    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only support fp32 temporarily, run with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still match.
        raise ValueError("Target error, GPU or Ascend is supported.")


def run_squad():
    """Run SQuAD fine-tuning and/or evaluation as selected on the command line.

    When both tasks run, evaluation uses the newest checkpoint written by training.
    Otherwise it uses --load_finetune_checkpoint_path.

    Raises:
        ValueError: If neither task is enabled, a required path is missing, or the
            device target is unsupported.
    """
    args_opt = _parse_squad_args()
    # Local names must not shadow the module-level do_train / do_eval functions.
    train_enabled = args_opt.do_train.lower() == "true"
    eval_enabled = args_opt.do_eval.lower() == "true"
    _check_squad_args(args_opt, train_enabled, eval_enabled)

    epoch_num = args_opt.epoch_num
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path

    _set_device_context(args_opt)

    if train_enabled:
        # The training network (with dropout) is only built when fine-tuning.
        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
                                  data_file_path=args_opt.train_data_file_path,
                                  schema_file_path=args_opt.schema_file_path)
        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
        if eval_enabled:
            # Evaluate the checkpoint that was just produced by training.
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           ds.get_dataset_size(), epoch_num, "squad")

    if eval_enabled:
        # Evaluate each example once; repeating the eval set only recomputes
        # predictions for the same unique ids.
        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                  data_file_path=args_opt.eval_data_file_path,
                                  schema_file_path=args_opt.schema_file_path, is_training=False)
        do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
                load_finetune_checkpoint_path, bert_net_cfg.seq_length)
-
if __name__ == "__main__":
    # Script entry point; run_squad reads its options from sys.argv via argparse.
    run_squad()
|