From 248eefe9ebf47a14441f375b402b1616d6ef9781 Mon Sep 17 00:00:00 2001
From: maszhongming <1521951528@163.com>
Date: Mon, 8 Jul 2019 16:10:01 +0900
Subject: [PATCH] add BertSum

---
 .../Summmarization/BertSum/callback.py        | 129 +++++++++++++
 .../Summmarization/BertSum/dataloader.py      | 157 +++++++++++++++
 .../Summmarization/BertSum/metrics.py         | 178 ++++++++++++++++++
 reproduction/Summmarization/BertSum/model.py  |  51 +++++
 .../Summmarization/BertSum/train_BertSum.py   | 147 +++++++++++++++
 reproduction/Summmarization/BertSum/utils.py  |  24 +++
 6 files changed, 686 insertions(+)
 create mode 100644 reproduction/Summmarization/BertSum/callback.py
 create mode 100644 reproduction/Summmarization/BertSum/dataloader.py
 create mode 100644 reproduction/Summmarization/BertSum/metrics.py
 create mode 100644 reproduction/Summmarization/BertSum/model.py
 create mode 100644 reproduction/Summmarization/BertSum/train_BertSum.py
 create mode 100644 reproduction/Summmarization/BertSum/utils.py

diff --git a/reproduction/Summmarization/BertSum/callback.py b/reproduction/Summmarization/BertSum/callback.py
new file mode 100644
index 00000000..a1bb4f54
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/callback.py
@@ -0,0 +1,129 @@
+import os
+import torch
+import sys
+from torch import nn
+
+from fastNLP.core.callback import Callback
+from fastNLP.core.utils import _get_model_device
+
+class MyCallback(Callback):  # learning-rate schedule: linear warmup, then decay proportional to real_step ** -0.5
+    def __init__(self, args):
+        super(MyCallback, self).__init__()
+        self.args = args  # reads args.max_lr and args.warmup_steps
+        self.real_step = 0  # number of learning-rate updates performed so far
+
+    def on_step_end(self):
+        if self.step % self.update_every == 0 and self.step > 0:  # once every `update_every` steps
+            self.real_step += 1
+            cur_lr = self.args.max_lr * 100 * min(self.real_step ** (-0.5), self.real_step * self.args.warmup_steps**(-1.5))
+            for param_group in self.optimizer.param_groups:  # the peak equals max_lr when warmup_steps == 10000
+                param_group['lr'] = cur_lr
+
+            if self.real_step % 1000 == 0:
+                self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step))
+
+    def on_epoch_end(self):
+        self.pbar.write('Epoch {} is done !!!'.format(self.epoch))
+
+def _save_model(model, model_name, save_dir, only_param=False):
+    """ Save a state_dict or a whole model with its tensors moved to the CPU.
+    :param model: model to save (unwrapped first when it is an nn.DataParallel)
+    :param model_name: file name of the checkpoint inside save_dir
+    :param save_dir: directory to save into (created when missing)
+    :param only_param: if True only the state_dict is saved, otherwise the whole model object
+    :return:
+    """
+    model_path = os.path.join(save_dir, model_name)
+    if not os.path.isdir(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+    if isinstance(model, nn.DataParallel):
+        model = model.module
+    if only_param:
+        state_dict = model.state_dict()
+        for key in state_dict:
+            state_dict[key] = state_dict[key].cpu()
+        torch.save(state_dict, model_path)
+    else:
+        _model_device = _get_model_device(model)
+        model.cpu()
+        torch.save(model, model_path)
+        model.to(_model_device)  # move the model back to the device it was on before saving
+
+class SaveModelCallback(Callback):
+    """
+    The Trainer only keeps the best model during training; this callback allows results to be stored in more ways.
+    A folder named after the training start timestamp is created under save_dir and the models are stored in it:
+        -save_dir
+            -2019-07-03-15-06-36
+                -epoch0step20{metric_key}{evaluate_performance}.pt   # metric_key is the given metric key, evaluate_performance is the score
+                -epoch1step40
+            -2019-07-03-15-10-00
+                -epoch:0step:20{metric_key}:{evaluate_performance}.pt   # metric_key is the given metric key, evaluate_performance is the score
+    :param str save_dir: directory in which a timestamp-named sub-directory is created to hold the models
+    :param int top: how many of the best models (by dev performance) to keep; -1 keeps every model
+    :param bool only_param: whether to save only the model weights
+    :param save_on_exception: whether to save a copy of the current model when an exception occurs
+    """
+    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
+        super().__init__()
+
+        if not os.path.isdir(save_dir):
+            raise IsADirectoryError("{} is not a directory.".format(save_dir))  # NOTE(review): NotADirectoryError fits better; type kept for callers
+        self.save_dir = save_dir
+        if top < 0:
+            self.top = sys.maxsize
+        else:
+            self.top = top
+        self._ordered_save_models = []  # List[Tuple]: Tuple[0] is the metric, Tuple[1] is the model name; ordered worst to best, so delete from the head
+
+        self.only_param = only_param
+        self.save_on_exception = save_on_exception
+
+    def on_train_begin(self):
+        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)
+
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
+        metric_value = list(eval_result.values())[0][metric_key]  # only the first metric's result is used
+        self._save_this_model(metric_value)
+
+    def _insert_into_ordered_save_models(self, pair):
+        # pair:(metric_value, model_name)
+        # Returns the pair to save and the pair to delete (either may be None); pair[0] is the metric value, pair[1] the model name
+        index = -1
+        for _pair in self._ordered_save_models:
+            if _pair[0]>=pair[0] and self.trainer.increase_better:
+                break
+            if not self.trainer.increase_better and _pair[0]<=pair[0]:
+                break
+            index += 1
+        save_pair = None
+        if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1):
+            save_pair = pair
+            self._ordered_save_models.insert(index+1, pair)
+        delete_pair = None
+        if len(self._ordered_save_models)>self.top:
+            delete_pair = self._ordered_save_models.pop(0)
+        return save_pair, delete_pair
+
+    def _save_this_model(self, metric_value):
+        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
+        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
+        if save_pair:
+            try:
+                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
+            except Exception as e:
+                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
+        if delete_pair:
+            try:
+                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
+                if os.path.exists(delete_model_path):
+                    os.remove(delete_model_path)
+            except Exception as e:
+                print(f"Fail to delete model {delete_pair[1]} at {self.save_dir} caused by exception:{e}.")
+
+    def on_exception(self, exception):
+        if self.save_on_exception:
+            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
+            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
+
+
diff --git a/reproduction/Summmarization/BertSum/dataloader.py b/reproduction/Summmarization/BertSum/dataloader.py
new file mode 100644
index 00000000..cb1acd53
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/dataloader.py
@@ -0,0 +1,157 @@
+from time import time
+from datetime import timedelta
+
+from fastNLP.io.dataset_loader import JsonLoader
+from fastNLP.modules.encoder._bert import BertTokenizer
+from fastNLP.io.base_loader import DataInfo
+from fastNLP.core.const import Const
+
+class BertData(JsonLoader):
+    """Turn raw 'article'/'label' JSON samples into BERT token ids, segment ids, [CLS] positions and 0/1 labels."""
+    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
+        """Keep at most max_nsents sentences of max_ntokens whitespace tokens each; token ids are cut/padded to max_len."""
+        fields = {'article': 'article',
+                  'label': 'label'}
+        super(BertData, self).__init__(fields=fields)
+
+        self.max_nsents = max_nsents
+        self.max_ntokens = max_ntokens
+        self.max_len = max_len
+
+        self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
+        self.cls_id = self.tokenizer.vocab['[CLS]']
+        self.sep_id = self.tokenizer.vocab['[SEP]']
+        self.pad_id = self.tokenizer.vocab['[PAD]']
+
+    def _load(self, paths):
+        dataset = super(BertData, self)._load(paths)
+        return dataset
+
+    def process(self, paths):
+        """Load every dataset in the {name: path} dict `paths`, preprocess it and return a DataInfo holding them."""
+        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
+            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
+            return article[:max_nsents]
+
+        def truncate_labels(instance):
+            label = list(filter(lambda x: x < len(instance['article']), instance['label']))
+            return label
+
+        def bert_tokenize(instance, tokenizer, max_len, pad_value):
+            article = instance['article']
+            article = ' [SEP] [CLS] '.join(article)  # every sentence gets its own [CLS] ... [SEP] pair
+            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
+            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
+            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
+            while len(token_ids) < max_len:
+                token_ids.append(pad_value)
+            assert len(token_ids) == max_len
+            return token_ids
+
+        def get_seg_id(instance, max_len, sep_id):
+            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
+            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]  # length of each [SEP]-terminated span
+            segment_id = []
+            for i, length in enumerate(segs):
+                if i % 2 == 0:
+                    segment_id += length * [0]
+                else:
+                    segment_id += length * [1]
+            while len(segment_id) < max_len:
+                segment_id.append(0)
+            return segment_id
+
+        def get_cls_id(instance, cls_id):
+            classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
+            return classification_id
+
+        def get_labels(instance):
+            labels = [0] * len(instance['cls_id'])
+            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
+            for idx in label_idx:
+                labels[idx] = 1
+            return labels
+
+        datasets = {}
+        for name in paths:
+            datasets[name] = self._load(paths[name])
+
+            # remove empty samples
+            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)
+
+            # truncate articles
+            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')
+
+            # truncate labels
+            datasets[name].apply(truncate_labels, new_field_name='label')
+
+            # tokenize and convert tokens to id
+            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')
+
+            # get segment id
+            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')
+
+            # get classification id
+            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')
+
+            # get label
+            datasets[name].apply(get_labels, new_field_name='label')
+
+            # rename field
+            datasets[name].rename_field('article', Const.INPUTS(0))
+            datasets[name].rename_field('segment_id', Const.INPUTS(1))
+            datasets[name].rename_field('cls_id', Const.INPUTS(2))
+            datasets[name].rename_field('label', Const.TARGET)
+
+            # set input and target
+            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
+            datasets[name].set_target(Const.TARGET)
+
+            # set padding value ('article' was renamed above, so address it by its new name)
+            datasets[name].set_pad_val(Const.INPUTS(0), 0)
+
+        return DataInfo(datasets=datasets)
+
+
+class BertSumLoader(JsonLoader):
+    """Load already preprocessed samples that carry 'article', 'segment_id', 'cls_id' and 'label' fields."""
+    def __init__(self):
+        fields = {'article': 'article',
+                  'segment_id': 'segment_id',
+                  'cls_id': 'cls_id',
+                  'label': Const.TARGET
+                  }
+        super(BertSumLoader, self).__init__(fields=fields)
+
+    def _load(self, paths):
+        dataset = super(BertSumLoader, self)._load(paths)
+        return dataset
+
+    def process(self, paths):
+        """Load every dataset in the {name: path} dict `paths`, mark inputs/targets, set pad values; return a DataInfo."""
+        def get_seq_len(instance):
+            return len(instance['article'])
+
+        print('Start loading datasets !!!')
+        start = time()
+
+        # load datasets
+        datasets = {}
+        for name in paths:
+            datasets[name] = self._load(paths[name])
+
+            datasets[name].apply(get_seq_len, new_field_name='seq_len')
+
+            # set input and target
+            datasets[name].set_input('article', 'segment_id', 'cls_id')
+            datasets[name].set_target(Const.TARGET)
+
+            # set padding value
+            datasets[name].set_pad_val('article', 0)
+            datasets[name].set_pad_val('segment_id', 0)
+            datasets[name].set_pad_val('cls_id', -1)
+            datasets[name].set_pad_val(Const.TARGET, 0)
+
+        print('Finished in {}'.format(timedelta(seconds=time()-start)))
+
+        return DataInfo(datasets=datasets)
diff --git a/reproduction/Summmarization/BertSum/metrics.py b/reproduction/Summmarization/BertSum/metrics.py
new file mode 100644
index 00000000..228f6789
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/metrics.py
@@ -0,0 +1,178 @@
+import numpy as np
+import json
+from os.path import join
+import torch
+import logging
+import tempfile
+import subprocess as sp
+from datetime import timedelta
+from time import time
+
+from pyrouge import Rouge155
+from pyrouge.utils import log
+
+from fastNLP.core.losses import LossBase
+from fastNLP.core.metrics import MetricBase
+
+_ROUGE_PATH = '/path/to/RELEASE-1.5.5'
+
+class MyBCELoss(LossBase):
+    """Binary cross-entropy, weighted element-wise by `mask` and summed over the batch."""
+    def __init__(self, pred=None, target=None, mask=None):
+        super(MyBCELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target, mask=mask)
+        self.loss_func = torch.nn.BCELoss(reduction='none')  # per-element loss so that it can be masked
+
+    def get_loss(self, pred, target, mask):
+        loss = self.loss_func(pred, target.float())
+        loss = (loss * mask.float()).sum()
+        return loss
+
+class LossMetric(MetricBase):  # masked BCE summed over all evaluated batches, divided by the number of samples
+    def __init__(self, pred=None, target=None, mask=None):
+        super(LossMetric, self).__init__()
+        self._init_param_map(pred=pred, target=target, mask=mask)
+        self.loss_func = torch.nn.BCELoss(reduction='none')
+        self.avg_loss = 0.0  # running sum of the loss until get_metric() divides it
+        self.nsamples = 0
+
+    def evaluate(self, pred, target, mask):
+        batch_size = pred.size(0)
+        loss = self.loss_func(pred, target.float())
+        loss = (loss * mask.float()).sum()
+        self.avg_loss += loss
+        self.nsamples += batch_size
+
+    def get_metric(self, reset=True):
+        self.avg_loss = self.avg_loss / self.nsamples  # NOTE(review): fails if evaluate() was never called (nsamples == 0)
+        eval_result = {'loss': self.avg_loss}
+        if reset:
+            self.avg_loss = 0
+            self.nsamples = 0
+        return eval_result
+
+class RougeMetric(MetricBase):  # collects sentence rankings per sample, then scores the extracted summaries with ROUGE
+    def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
+        super(RougeMetric, self).__init__()
+        self._init_param_map(pred=pred, target=target, mask=mask)
+        self.data_path = data_path  # JSON-lines file holding the original sentences and reference summaries
+        self.dec_path = dec_path  # directory the extracted summaries ('<i>.dec') are written to
+        self.ref_path = ref_path  # directory the references ('<i>.ref') are written to
+        self.n_total = n_total  # only used for the progress display
+        self.n_ext = n_ext  # maximum number of sentences to extract per article
+        self.ngram_block = ngram_block  # n of the n-gram blocking
+
+        self.cur_idx = 0
+        self.ext = []  # one array of sentence indices, best first, per evaluated sample
+        self.start = time()
+
+    @staticmethod
+    def eval_rouge(dec_dir, ref_dir):
+        assert _ROUGE_PATH is not None
+        log.get_global_console_logger().setLevel(logging.WARNING)
+        dec_pattern = r'(\d+).dec'
+        ref_pattern = '#ID#.ref'
+        cmd = '-c 95 -r 1000 -n 2 -m'
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            Rouge155.convert_summaries_to_rouge_format(
+                dec_dir, join(tmp_dir, 'dec'))
+            Rouge155.convert_summaries_to_rouge_format(
+                ref_dir, join(tmp_dir, 'ref'))
+            Rouge155.write_config_static(
+                join(tmp_dir, 'dec'), dec_pattern,
+                join(tmp_dir, 'ref'), ref_pattern,
+                join(tmp_dir, 'settings.xml'), system_id=1
+            )
+            cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl')
+                   + ' -e {} '.format(join(_ROUGE_PATH, 'data'))
+                   + cmd
+                   + ' -a {}'.format(join(tmp_dir, 'settings.xml')))
+            output = sp.check_output(cmd.split(' '), universal_newlines=True)
+            R_1 = float(output.split('\n')[3].split(' ')[3])  # the scores are read from fixed lines of the ROUGE report
+            R_2 = float(output.split('\n')[7].split(' ')[3])
+            R_L = float(output.split('\n')[11].split(' ')[3])
+            print(output)
+        return R_1, R_2, R_L
+
+    def evaluate(self, pred, target, mask):
+        pred = pred + mask.float()  # lift masked-in positions so they sort ahead of masked-out ones
+        pred = pred.cpu().data.numpy()
+        ext_ids = np.argsort(-pred, 1)  # indices sorted by descending score, per sample
+        for sent_id in ext_ids:
+            self.ext.append(sent_id)
+            self.cur_idx += 1
+            print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
+                  self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
+                 ), end='')
+
+    def get_metric(self, use_ngram_block=True, reset=True):
+        """Write the extracted sentences and the references to disk, run ROUGE and return ROUGE-1/2/L in a dict."""
+        def check_n_gram(sentence, n, dic):
+            tokens = sentence.split(' ')
+            s_len = len(tokens)
+            for i in range(s_len):
+                if i + n > s_len:
+                    break
+                if ' '.join(tokens[i: i + n]) in dic:
+                    return False
+            return True # no n_gram overlap
+
+        # load original data
+        data = []
+        with open(self.data_path) as f:
+            for line in f:
+                cur_data = json.loads(line)
+                if 'text' in cur_data:
+                    new_data = {}
+                    new_data['article'] = cur_data['text']
+                    new_data['abstract'] = cur_data['summary']
+                    data.append(new_data)
+                else:
+                    data.append(cur_data)
+
+        # write decode sentences and references
+        if use_ngram_block == True:
+            print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
+        for i, ext_ids in enumerate(self.ext):
+            dec, ref = [], []
+            if use_ngram_block == False:
+                n_sent = min(len(data[i]['article']), self.n_ext)
+                for j in range(n_sent):
+                    idx = ext_ids[j]
+                    dec.append(data[i]['article'][idx])
+            else:
+                n_sent = len(ext_ids)
+                dic = {}  # n-grams of the sentences selected so far
+                for j in range(n_sent):
+                    sent = data[i]['article'][ext_ids[j]]
+                    if check_n_gram(sent, self.ngram_block, dic) == True:
+                        dec.append(sent)
+                        # update dic
+                        tokens = sent.split(' ')
+                        s_len = len(tokens)
+                        for k in range(s_len):
+                            if k + self.ngram_block > s_len:
+                                break
+                            dic[' '.join(tokens[k: k + self.ngram_block])] = 1
+                    if len(dec) >= self.n_ext:
+                        break
+
+            for sent in data[i]['abstract']:
+                ref.append(sent)
+
+            with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
+                for sent in dec:
+                    print(sent, file=f)
+            with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
+                for sent in ref:
+                    print(sent, file=f)
+
+        print('\nStart evaluating ROUGE score !!!')
+        R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
+        eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L}
+
+        if reset == True:
+            self.cur_idx = 0
+            self.ext = []
+            self.start = time()
+        return eval_result
diff --git a/reproduction/Summmarization/BertSum/model.py b/reproduction/Summmarization/BertSum/model.py
new file mode 100644
index 00000000..655ad16e
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/model.py
@@ -0,0 +1,51 @@
+import torch
+from torch import nn
+from torch.nn import init
+
+from fastNLP.modules.encoder._bert import BertModel
+
+
+class Classifier(nn.Module):  # linear layer + sigmoid producing one score per sentence vector
+    def __init__(self, hidden_size):
+        super(Classifier, self).__init__()
+        self.linear = nn.Linear(hidden_size, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inputs, mask_cls):
+        h = self.linear(inputs).squeeze(-1) # [batch_size, seq_len]
+        sent_scores = self.sigmoid(h) * mask_cls.float()  # scores of masked-out positions become 0
+        return sent_scores
+
+
+class BertSum(nn.Module):
+    """BERT encoder followed by a per-sentence sigmoid classifier applied to the vectors at the [CLS] positions."""
+    def __init__(self, hidden_size=768):
+        super(BertSum, self).__init__()
+
+        self.hidden_size = hidden_size
+
+        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
+        self.decoder = Classifier(self.hidden_size)
+
+    def forward(self, article, segment_id, cls_id):
+        """Return {'pred': per-sentence scores, 'mask': 1 for real [CLS] slots and 0 for padded ones}."""
+        # `article` entries equal to 0 and `cls_id` entries equal to -1 are treated as padding.
+        # The masks are built with a comparison followed by .long() because `1 - bool_tensor`
+        # raises in PyTorch versions whose comparison operators return bool tensors.
+
+        input_mask = (article != 0).long()
+        mask_cls = (cls_id != -1).long()
+        assert input_mask.size() == article.size()
+        assert mask_cls.size() == cls_id.size()
+
+        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
+        bert_out = bert_out[0][-1] # last layer
+
+        sent_emb = bert_out[torch.arange(bert_out.size(0), device=cls_id.device).unsqueeze(1), cls_id]
+        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()  # zero the vectors gathered through the -1 padding index
+        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size) # [batch_size, seq_len, hidden_size]
+
+        sent_scores = self.decoder(sent_emb, mask_cls) # [batch_size, seq_len]
+        assert sent_scores.size() == (article.size(0), cls_id.size(1))
+
+        return {'pred': sent_scores, 'mask': mask_cls}
diff --git a/reproduction/Summmarization/BertSum/train_BertSum.py b/reproduction/Summmarization/BertSum/train_BertSum.py
new file mode 100644
index 00000000..d34fa0b9
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/train_BertSum.py
@@ -0,0 +1,147 @@
+import sys
+import argparse
+import os
+import json
+import torch
+from time import time
+from datetime import timedelta
+from os.path import join, exists
+from torch.optim import Adam
+
+from utils import get_data_path, get_rouge_path
+
+from dataloader import BertSumLoader
+from model import BertSum
+from fastNLP.core.optimizer import AdamW
+from metrics import MyBCELoss, LossMetric, RougeMetric
+from fastNLP.core.sampler import BucketSampler
+from callback import MyCallback, SaveModelCallback
+from fastNLP.core.trainer import Trainer
+from fastNLP.core.tester import Tester
+
+
+def configure_training(args):  # returns (list of gpu ids, dict of the hyper-parameters that get dumped to params.json)
+    devices = [int(gpu) for gpu in args.gpus.split(',')]
+    params = {}
+    params['label_type'] = args.label_type
+    params['batch_size'] = args.batch_size
+    params['accum_count'] = args.accum_count
+    params['max_lr'] = args.max_lr
+    params['warmup_steps'] = args.warmup_steps
+    params['n_epochs'] = args.n_epochs
+    params['valid_steps'] = args.valid_steps
+    return devices, params
+
+def train_model(args):
+    """Load the train/val sets, build BertSum and train it with the fastNLP Trainer."""
+    # check if the data_path and save_path exists
+    data_paths = get_data_path(args.mode, args.label_type)
+    for name in data_paths:
+        assert exists(data_paths[name]), '{} does not exist'.format(data_paths[name])
+    if not exists(args.save_path):
+        os.makedirs(args.save_path)
+
+    # load summarization datasets
+    datasets = BertSumLoader().process(data_paths)
+    print('Information of dataset is:')
+    print(datasets)
+    train_set = datasets.datasets['train']
+    valid_set = datasets.datasets['val']
+
+    # configure training
+    devices, train_params = configure_training(args)
+    with open(join(args.save_path, 'params.json'), 'w') as f:
+        json.dump(train_params, f, indent=4)
+    print('Devices is:')
+    print(devices)
+
+    # configure model
+    model = BertSum()
+    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)  # the lr is set by MyCallback
+    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
+    criterion = MyBCELoss()
+    val_metric = [LossMetric()]
+    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
+    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
+                      loss=criterion, batch_size=args.batch_size, # sampler=sampler,
+                      update_every=args.accum_count, n_epochs=args.n_epochs,
+                      print_every=100, dev_data=valid_set, metrics=val_metric,
+                      metric_key='-loss', validate_every=args.valid_steps,
+                      save_path=args.save_path, device=devices, callbacks=callbacks)
+
+    print('Start training with the following hyper-parameters:')
+    print(train_params)
+    trainer.train()
+
+def test_model(args):
+    """Evaluate every file found in args.save_path on the test set with RougeMetric (batch size 1, a single gpu)."""
+    models = os.listdir(args.save_path)
+
+    # load dataset
+    data_paths = get_data_path(args.mode, args.label_type)
+    datasets = BertSumLoader().process(data_paths)
+    print('Information of dataset is:')
+    print(datasets)
+    test_set = datasets.datasets['test']
+
+    # only need 1 gpu for testing
+    device = int(args.gpus)
+
+    args.batch_size = 1
+
+    for cur_model in models:
+
+        print('Current model is {}'.format(cur_model))
+
+        # load model
+        model = torch.load(join(args.save_path, cur_model))  # NOTE: torch.load unpickles the file; only load trusted checkpoints
+
+        # configure testing
+        original_path, dec_path, ref_path = get_rouge_path(args.label_type)
+        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
+                                  ref_path=ref_path, n_total = len(test_set))
+        tester = Tester(data=test_set, model=model, metrics=[test_metric],
+                        batch_size=args.batch_size, device=device)
+        tester.test()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='training/testing of BertSum(liu et al. 2019)'
+    )
+    parser.add_argument('--mode', required=True,
+                        help='training or testing of BertSum', type=str)
+
+    parser.add_argument('--label_type', default='greedy',
+                        help='greedy/limit', type=str)
+    parser.add_argument('--save_path', required=True,
+                        help='root of the model', type=str)
+    # example for gpus input: '0,1,2,3'
+    parser.add_argument('--gpus', required=True,
+                        help='available gpus for training(separated by commas)', type=str)
+
+    parser.add_argument('--batch_size', default=18,
+                        help='the training batch size', type=int)
+    parser.add_argument('--accum_count', default=2,
+                        help='number of updates steps to accumulate before performing a backward/update pass.', type=int)
+    parser.add_argument('--max_lr', default=2e-5,
+                        help='max learning rate for warm up', type=float)
+    parser.add_argument('--warmup_steps', default=10000,
+                        help='warm up steps for training', type=int)
+    parser.add_argument('--n_epochs', default=10,
+                        help='total number of training epochs', type=int)
+    parser.add_argument('--valid_steps', default=1000,
+                        help='number of update steps for checkpoint and validation', type=int)
+
+    args = parser.parse_args()
+
+    if args.mode == 'train':
+        print('Training process of BertSum !!!')
+        train_model(args)
+    else:
+        print('Testing process of BertSum !!!')
+        test_model(args)
+
+
+
+
diff --git a/reproduction/Summmarization/BertSum/utils.py b/reproduction/Summmarization/BertSum/utils.py
new file mode 100644
index 00000000..2ba848b7
--- /dev/null
+++ b/reproduction/Summmarization/BertSum/utils.py
@@ -0,0 +1,24 @@
+import os
+from os.path import exists
+
+def get_data_path(mode, label_type):  # {'train', 'val'} paths when mode == 'train', otherwise only {'test'}
+    paths = {}
+    if mode == 'train':
+        paths['train'] = 'data/' + label_type + '/bert.train.jsonl'
+        paths['val'] = 'data/' + label_type + '/bert.val.jsonl'
+    else:
+        paths['test'] = 'data/' + label_type + '/bert.test.jsonl'
+    return paths
+
+def get_rouge_path(label_type):  # returns (test data path, 'dec', 'ref') and creates both output directories
+    if label_type == 'others':
+        data_path = 'data/' + label_type + '/bert.test.jsonl'
+    else:
+        data_path = 'data/' + label_type + '/test.jsonl'
+    dec_path = 'dec'
+    ref_path = 'ref'
+    if not exists(ref_path):
+        os.makedirs(ref_path)
+    if not exists(dec_path):
+        os.makedirs(dec_path)
+    return data_path, dec_path, ref_path