Browse Source

add BertSum

tags/v0.4.10
maszhongming 5 years ago
parent
commit
248eefe9eb
6 changed files with 686 additions and 0 deletions
  1. +129
    -0
      reproduction/Summmarization/BertSum/callback.py
  2. +157
    -0
      reproduction/Summmarization/BertSum/dataloader.py
  3. +178
    -0
      reproduction/Summmarization/BertSum/metrics.py
  4. +51
    -0
      reproduction/Summmarization/BertSum/model.py
  5. +147
    -0
      reproduction/Summmarization/BertSum/train_BertSum.py
  6. +24
    -0
      reproduction/Summmarization/BertSum/utils.py

+ 129
- 0
reproduction/Summmarization/BertSum/callback.py View File

@@ -0,0 +1,129 @@
import os
import torch
import sys
from torch import nn

from fastNLP.core.callback import Callback
from fastNLP.core.utils import _get_model_device

class MyCallback(Callback):
    """Noam-style warmup/decay learning-rate schedule for BertSum training.

    Counts one "real" step per effective (gradient-accumulated) update and
    rescales every optimizer param group's lr accordingly.
    """

    def __init__(self, args):
        super(MyCallback, self).__init__()
        self.args = args
        self.real_step = 0

    def on_step_end(self):
        # Only adjust the lr once per effective (accumulated) update.
        if self.step <= 0 or self.step % self.update_every != 0:
            return
        self.real_step += 1
        decay = self.real_step ** (-0.5)
        warmup = self.real_step * self.args.warmup_steps ** (-1.5)
        cur_lr = self.args.max_lr * 100 * min(decay, warmup)
        for group in self.optimizer.param_groups:
            group['lr'] = cur_lr

        if self.real_step % 1000 == 0:
            self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step))

    def on_epoch_end(self):
        self.pbar.write('Epoch {} is done !!!'.format(self.epoch))

def _save_model(model, model_name, save_dir, only_param=False):
    """Save a model (or just its state_dict) with all tensors moved to CPU.

    :param model: model to persist; a ``nn.DataParallel`` wrapper is unwrapped first
    :param model_name: file name used under ``save_dir``
    :param save_dir: directory to save into (created if missing)
    :param only_param: if True, save only a CPU copy of the state_dict
    :return: None
    """
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, model_name)
    if isinstance(model, nn.DataParallel):
        model = model.module
    if not only_param:
        # Move the whole model to CPU, dump it, then restore its original device.
        original_device = _get_model_device(model)
        model.cpu()
        torch.save(model, model_path)
        model.to(original_device)
    else:
        cpu_state = model.state_dict()
        for key in cpu_state:
            cpu_state[key] = cpu_state[key].cpu()
        torch.save(cpu_state, model_path)

class SaveModelCallback(Callback):
    """
    Trainer only keeps the best model during training; this callback additionally
    saves checkpoints ranked by dev performance. A folder named after the training
    start timestamp is created under ``save_dir`` and models are stored inside it::

        -save_dir
            -2019-07-03-15-06-36
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt
                -epoch:1_step:40_...
            -2019-07-03-15-10-00
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt

    :param str save_dir: directory in which the timestamp-named folder is created
    :param int top: keep the best ``top`` models by dev performance; -1 keeps all
    :param bool only_param: save only the model weights (state_dict) instead of the whole model
    :param save_on_exception: whether to save a snapshot of the model when an exception occurs
    """
    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
        super().__init__()

        if not os.path.isdir(save_dir):
            # BUG FIX: previously raised IsADirectoryError, which states the
            # opposite of this failure; NotADirectoryError matches the message.
            raise NotADirectoryError("{} is not a directory.".format(save_dir))
        self.save_dir = save_dir
        if top < 0:
            self.top = sys.maxsize
        else:
            self.top = top
        # List[Tuple(metric_value, model_name)], ordered from worst to best,
        # so the worst model is popped from the front when the list overflows.
        self._ordered_save_models = []

        self.only_param = only_param
        self.save_on_exception = save_on_exception

    def on_train_begin(self):
        # Nest checkpoints under a per-run timestamp directory.
        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        metric_value = list(eval_result.values())[0][metric_key]
        self._save_this_model(metric_value)

    def _insert_into_ordered_save_models(self, pair):
        # pair: (metric_value, model_name)
        # Returns (saved_pair, deleted_pair); either may be None.
        index = -1
        for _pair in self._ordered_save_models:
            # Stop at the first model that is at least as good as the new one.
            if _pair[0] >= pair[0] and self.trainer.increase_better:
                break
            if not self.trainer.increase_better and _pair[0] <= pair[0]:
                break
            index += 1
        save_pair = None
        if len(self._ordered_save_models) < self.top or (len(self._ordered_save_models) >= self.top and index != -1):
            save_pair = pair
            self._ordered_save_models.insert(index + 1, pair)
        delete_pair = None
        if len(self._ordered_save_models) > self.top:
            delete_pair = self._ordered_save_models.pop(0)
        return save_pair, delete_pair

    def _save_this_model(self, metric_value):
        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
        if save_pair:
            try:
                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
            except Exception as e:
                # Best-effort: a failed save must not abort training.
                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
        if delete_pair:
            try:
                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
                if os.path.exists(delete_model_path):
                    os.remove(delete_model_path)
            except Exception as e:
                print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")

    def on_exception(self, exception):
        if self.save_on_exception:
            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)



+ 157
- 0
reproduction/Summmarization/BertSum/dataloader.py View File

@@ -0,0 +1,157 @@
from time import time
from datetime import timedelta

from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.modules.encoder._bert import BertTokenizer
from fastNLP.io.base_loader import DataInfo
from fastNLP.core.const import Const

class BertData(JsonLoader):
    """Preprocess jsonl articles into BertSum inputs: truncate, WordPiece-tokenize,
    and derive segment ids, [CLS] positions, and per-sentence 0/1 oracle labels.

    :param max_nsents: maximum number of sentences kept per article
    :param max_ntokens: maximum number of tokens kept per sentence
    :param max_len: total WordPiece length each article is padded/truncated to
    """
    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
        fields = {'article': 'article',
                  'label': 'label'}
        super(BertData, self).__init__(fields=fields)

        self.max_nsents = max_nsents
        self.max_ntokens = max_ntokens
        self.max_len = max_len

        # TODO: make the vocab path configurable instead of hard-coded
        self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def _load(self, paths):
        dataset = super(BertData, self)._load(paths)
        return dataset

    def process(self, paths):
        """Load every split in ``paths`` and build the BertSum input fields."""
        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
            # lowercase, cap tokens per sentence, cap number of sentences
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]

        def truncate_labels(instance):
            # drop oracle indices that point beyond the truncated article
            return list(filter(lambda x: x < len(instance['article']), instance['label']))

        def bert_tokenize(instance, tokenizer, max_len, pad_value):
            # join sentences with "[SEP] [CLS]" separators, tokenize, pad to max_len
            article = instance['article']
            article = ' [SEP] [CLS] '.join(article)
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
            while len(token_ids) < max_len:
                token_ids.append(pad_value)
            assert len(token_ids) == max_len
            return token_ids

        def get_seg_id(instance, max_len, sep_id):
            # alternating 0/1 segment id per sentence, padded with 0 to max_len
            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
            segment_id = []
            for i, length in enumerate(segs):
                segment_id += length * [i % 2]
            while len(segment_id) < max_len:
                segment_id.append(0)
            return segment_id

        def get_cls_id(instance, cls_id):
            # positions of each [CLS] token: one per kept sentence
            return [i for i, idx in enumerate(instance['article']) if idx == cls_id]

        def get_labels(instance):
            # binary vector over kept sentences: 1 iff the sentence is in the oracle summary
            labels = [0] * len(instance['cls_id'])
            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
            for idx in label_idx:
                labels[idx] = 1
            return labels

        datasets = {}
        for name in paths:
            datasets[name] = self._load(paths[name])
            # remove empty samples
            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)
            # truncate articles
            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')
            # truncate labels
            datasets[name].apply(truncate_labels, new_field_name='label')
            # tokenize and convert tokens to id
            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')
            # get segment id
            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')
            # get classification id
            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')
            # get label
            datasets[name].apply(get_labels, new_field_name='label')

            # rename fields
            datasets[name].rename_field('article', Const.INPUTS(0))
            datasets[name].rename_field('segment_id', Const.INPUTS(1))
            datasets[name].rename_field('cls_id', Const.INPUTS(2))
            # BUG FIX: was rename_field('lbael', ...) — a typo that raised at runtime
            datasets[name].rename_field('label', Const.TARGET)

            # set input and target
            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
            datasets[name].set_target(Const.TARGET)
            # set padding value
            # BUG FIX: 'article' was renamed above, so pad the renamed field
            datasets[name].set_pad_val(Const.INPUTS(0), 0)

        return DataInfo(datasets=datasets)


class BertSumLoader(JsonLoader):
    """Loads already-preprocessed BertSum jsonl files (article / segment_id / cls_id / label)."""

    def __init__(self):
        fields = {'article': 'article',
                  'segment_id': 'segment_id',
                  'cls_id': 'cls_id',
                  'label': Const.TARGET
                  }
        super(BertSumLoader, self).__init__(fields=fields)

    def _load(self, paths):
        dataset = super(BertSumLoader, self)._load(paths)
        return dataset

    def process(self, paths):
        """Load every split in ``paths``, add seq_len, and configure input/target/padding."""
        def get_seq_len(instance):
            return len(instance['article'])

        print('Start loading datasets !!!')
        start = time()

        datasets = {}
        for split, path in paths.items():
            ds = self._load(path)
            ds.apply(get_seq_len, new_field_name='seq_len')

            # mark model inputs and the training target
            ds.set_input('article', 'segment_id', 'cls_id')
            ds.set_target(Const.TARGET)

            # padding values per field (cls_id pads with -1 so it can be masked)
            for field, pad in (('article', 0), ('segment_id', 0), ('cls_id', -1), (Const.TARGET, 0)):
                ds.set_pad_val(field, pad)

            datasets[split] = ds

        print('Finished in {}'.format(timedelta(seconds=time()-start)))

        return DataInfo(datasets=datasets)

+ 178
- 0
reproduction/Summmarization/BertSum/metrics.py View File

@@ -0,0 +1,178 @@
import numpy as np
import json
from os.path import join
import torch
import logging
import tempfile
import subprocess as sp
from datetime import timedelta
from time import time

from pyrouge import Rouge155
from pyrouge.utils import log

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase

_ROUGE_PATH = '/path/to/RELEASE-1.5.5'

class MyBCELoss(LossBase):
    """Masked binary cross-entropy: per-position BCE, summed over unpadded positions only."""

    def __init__(self, pred=None, target=None, mask=None):
        super(MyBCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.loss_func = torch.nn.BCELoss(reduction='none')

    def get_loss(self, pred, target, mask):
        per_position = self.loss_func(pred, target.float())
        return (per_position * mask.float()).sum()

class LossMetric(MetricBase):
    """Dev-set metric: average masked BCE loss per sample (lower is better)."""

    def __init__(self, pred=None, target=None, mask=None):
        super(LossMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.loss_func = torch.nn.BCELoss(reduction='none')
        self.avg_loss = 0.0   # running sum of per-batch losses (averaged in get_metric)
        self.nsamples = 0     # number of examples seen since last reset

    def evaluate(self, pred, target, mask):
        batch_size = pred.size(0)
        loss = self.loss_func(pred, target.float())
        loss = (loss * mask.float()).sum()
        # BUG FIX: .item() converts to a Python float; previously the tensor itself
        # was accumulated, keeping graph/device memory alive across the evaluation.
        self.avg_loss += loss.item()
        self.nsamples += batch_size

    def get_metric(self, reset=True):
        # NOTE(review): calling with reset=False and evaluating again would
        # re-average an already-averaged value; use the default reset=True.
        self.avg_loss = self.avg_loss / self.nsamples
        eval_result = {'loss': self.avg_loss}
        if reset:
            self.avg_loss = 0
            self.nsamples = 0
        return eval_result
class RougeMetric(MetricBase):
    """Decodes extractive summaries from sentence scores and reports ROUGE-1/2/L.

    Requires the external ROUGE-1.5.5 perl package (see ``_ROUGE_PATH``) via pyrouge.

    :param data_path: jsonl file with the original articles/abstracts
    :param dec_path: directory where decoded summaries ({i}.dec) are written
    :param ref_path: directory where reference summaries ({i}.ref) are written
    :param n_total: total number of test examples (used for progress display)
    :param n_ext: number of sentences to extract per article
    :param ngram_block: candidate sentence is skipped if it shares an n-gram of this
        size with the summary built so far (trigram blocking by default)
    """
    def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
        super(RougeMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.data_path = data_path
        self.dec_path = dec_path
        self.ref_path = ref_path
        self.n_total = n_total
        self.n_ext = n_ext
        self.ngram_block = ngram_block

        self.cur_idx = 0   # number of examples decoded so far
        self.ext = []      # per-example sentence indices sorted by descending score
        self.start = time()

    @staticmethod
    def eval_rouge(dec_dir, ref_dir):
        """Run the official ROUGE-1.5.5 perl script over the two directories and
        return the (ROUGE-1, ROUGE-2, ROUGE-L) scores parsed from its output."""
        assert _ROUGE_PATH is not None
        log.get_global_console_logger().setLevel(logging.WARNING)
        dec_pattern = '(\d+).dec'
        ref_pattern = '#ID#.ref'
        cmd = '-c 95 -r 1000 -n 2 -m'
        with tempfile.TemporaryDirectory() as tmp_dir:
            Rouge155.convert_summaries_to_rouge_format(
                dec_dir, join(tmp_dir, 'dec'))
            Rouge155.convert_summaries_to_rouge_format(
                ref_dir, join(tmp_dir, 'ref'))
            Rouge155.write_config_static(
                join(tmp_dir, 'dec'), dec_pattern,
                join(tmp_dir, 'ref'), ref_pattern,
                join(tmp_dir, 'settings.xml'), system_id=1
            )
            cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl')
                   + ' -e {} '.format(join(_ROUGE_PATH, 'data'))
                   + cmd
                   + ' -a {}'.format(join(tmp_dir, 'settings.xml')))
            output = sp.check_output(cmd.split(' '), universal_newlines=True)
            # the perl script prints a fixed layout: field [3] of lines 3/7/11
            # holds the ROUGE-1/2/L scores respectively
            R_1 = float(output.split('\n')[3].split(' ')[3])
            R_2 = float(output.split('\n')[7].split(' ')[3])
            R_L = float(output.split('\n')[11].split(' ')[3])
            print(output)
        return R_1, R_2, R_L

    def evaluate(self, pred, target, mask):
        # Adding the 0/1 mask lifts every valid score above every padded one,
        # so padded positions always sort last in argsort below.
        pred = pred + mask.float()
        pred = pred.cpu().data.numpy()
        ext_ids = np.argsort(-pred, 1)
        for sent_id in ext_ids:
            self.ext.append(sent_id)
            self.cur_idx += 1
            print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                  self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
            ), end='')

    def get_metric(self, use_ngram_block=True, reset=True):
        """Write decoded and reference summaries to disk, run ROUGE, and return scores."""
        def check_n_gram(sentence, n, dic):
            # True iff `sentence` shares no n-gram with those recorded in `dic`
            tokens = sentence.split(' ')
            s_len = len(tokens)
            for i in range(s_len):
                if i + n > s_len:
                    break
                if ' '.join(tokens[i: i + n]) in dic:
                    return False
            return True  # no n_gram overlap

        # load original data
        data = []
        with open(self.data_path) as f:
            for line in f:
                cur_data = json.loads(line)
                if 'text' in cur_data:
                    # normalize the alternative {'text', 'summary'} schema
                    new_data = {}
                    new_data['article'] = cur_data['text']
                    new_data['abstract'] = cur_data['summary']
                    data.append(new_data)
                else:
                    data.append(cur_data)

        # write decode sentences and references
        if use_ngram_block == True:
            print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
        for i, ext_ids in enumerate(self.ext):
            dec, ref = [], []
            if use_ngram_block == False:
                # plain top-k extraction by score
                n_sent = min(len(data[i]['article']), self.n_ext)
                for j in range(n_sent):
                    idx = ext_ids[j]
                    dec.append(data[i]['article'][idx])
            else:
                # greedy selection with n-gram blocking against sentences chosen so far
                n_sent = len(ext_ids)
                dic = {}
                for j in range(n_sent):
                    sent = data[i]['article'][ext_ids[j]]
                    if check_n_gram(sent, self.ngram_block, dic) == True:
                        dec.append(sent)
                        # update dic
                        tokens = sent.split(' ')
                        s_len = len(tokens)
                        for k in range(s_len):
                            if k + self.ngram_block > s_len:
                                break
                            dic[' '.join(tokens[k: k + self.ngram_block])] = 1
                        if len(dec) >= self.n_ext:
                            break

            for sent in data[i]['abstract']:
                ref.append(sent)

            with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
                for sent in dec:
                    print(sent, file=f)
            with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
                for sent in ref:
                    print(sent, file=f)

        print('\nStart evaluating ROUGE score !!!')
        R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
        eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L}

        if reset == True:
            self.cur_idx = 0
            self.ext = []
            self.start = time()
        return eval_result

+ 51
- 0
reproduction/Summmarization/BertSum/model.py View File

@@ -0,0 +1,51 @@
import torch
from torch import nn
from torch.nn import init

from fastNLP.modules.encoder._bert import BertModel


class Classifier(nn.Module):
    """Scores each sentence embedding with a sigmoid over a linear projection,
    zeroing out the scores of padded sentence slots."""

    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        logits = self.linear(inputs).squeeze(-1)  # [batch_size, seq_len]
        return self.sigmoid(logits) * mask_cls.float()


class BertSum(nn.Module):
    """Extractive summarizer: BERT encoder + per-sentence sigmoid classifier (Liu et al. 2019).

    :param hidden_size: BERT hidden dimension (768 for the base uncased model)
    """
    def __init__(self, hidden_size=768):
        super(BertSum, self).__init__()
        self.hidden_size = hidden_size

        # TODO: make the pretrained-model path configurable instead of hard-coded
        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        """
        :param article: [batch_size, max_len] padded WordPiece ids (0 = padding)
        :param segment_id: [batch_size, max_len] alternating 0/1 segment ids
        :param cls_id: [batch_size, n_sents] positions of [CLS] tokens (-1 = padding)
        :return: dict with 'pred' sentence scores and 'mask' for valid sentence slots
        """
        # BUG FIX: `1 - (article == 0)` raises on modern PyTorch, which forbids
        # the `-` operator on bool tensors; `!=` produces the same mask directly.
        input_mask = (article != 0)
        mask_cls = (cls_id != -1)
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()

        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # last layer

        # Gather the hidden state at each [CLS] position: one vector per sentence.
        # Padded cls_id slots (-1) index the last token but are zeroed by the mask.
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, seq_len, hidden_size]

        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, seq_len]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))

        return {'pred': sent_scores, 'mask': mask_cls}

+ 147
- 0
reproduction/Summmarization/BertSum/train_BertSum.py View File

@@ -0,0 +1,147 @@
import sys
import argparse
import os
import json
import torch
from time import time
from datetime import timedelta
from os.path import join, exists
from torch.optim import Adam

from utils import get_data_path, get_rouge_path

from dataloader import BertSumLoader
from model import BertSum
from fastNLP.core.optimizer import AdamW
from metrics import MyBCELoss, LossMetric, RougeMetric
from fastNLP.core.sampler import BucketSampler
from callback import MyCallback, SaveModelCallback
from fastNLP.core.trainer import Trainer
from fastNLP.core.tester import Tester


def configure_training(args):
    """Split parsed CLI args into a GPU device list and a dict of training hyper-parameters.

    :param args: argparse.Namespace-like object with gpus and training options
    :return: (devices, params) — list of int GPU ids and a serializable params dict
    """
    devices = [int(gpu) for gpu in args.gpus.split(',')]
    params = {
        'label_type': args.label_type,
        'batch_size': args.batch_size,
        'accum_count': args.accum_count,
        'max_lr': args.max_lr,
        'warmup_steps': args.warmup_steps,
        'n_epochs': args.n_epochs,
        'valid_steps': args.valid_steps,
    }
    return devices, params

def train_model(args):
    """Training entry point: load data, build BertSum, and run the fastNLP Trainer."""
    # check if the data_path and save_path exists
    data_paths = get_data_path(args.mode, args.label_type)
    for name in data_paths:
        assert exists(data_paths[name])
    if not exists(args.save_path):
        os.makedirs(args.save_path)

    # load summarization datasets
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    train_set = datasets.datasets['train']
    valid_set = datasets.datasets['val']

    # configure training; dump hyper-parameters alongside the checkpoints
    devices, train_params = configure_training(args)
    with open(join(args.save_path, 'params.json'), 'w') as f:
        json.dump(train_params, f, indent=4)
    print('Devices is:')
    print(devices)

    # configure model
    model = BertSum()
    # lr starts at 0: MyCallback drives the warmup/decay schedule every step
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
    criterion = MyBCELoss()
    val_metric = [LossMetric()]
    # metric_key='-loss': lower dev loss counts as better
    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
                      loss=criterion, batch_size=args.batch_size,  # sampler=sampler,
                      update_every=args.accum_count, n_epochs=args.n_epochs,
                      print_every=100, dev_data=valid_set, metrics=val_metric,
                      metric_key='-loss', validate_every=args.valid_steps,
                      save_path=args.save_path, device=devices, callbacks=callbacks)

    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()
def test_model(args):
    """Evaluate every checkpoint found under args.save_path on the test set with ROUGE."""
    models = os.listdir(args.save_path)

    # load dataset
    data_paths = get_data_path(args.mode, args.label_type)
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    test_set = datasets.datasets['test']

    # only need 1 gpu for testing
    device = int(args.gpus)
    # NOTE(review): batch_size forced to 1, presumably so the decode order
    # matches the jsonl file line-by-line for RougeMetric — confirm
    args.batch_size = 1

    for cur_model in models:
        print('Current model is {}'.format(cur_model))

        # load model
        model = torch.load(join(args.save_path, cur_model))

        # configure testing
        original_path, dec_path, ref_path = get_rouge_path(args.label_type)
        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
                                  ref_path=ref_path, n_total = len(test_set))
        tester = Tester(data=test_set, model=model, metrics=[test_metric],
                        batch_size=args.batch_size, device=device)
        tester.test()


if __name__ == '__main__':
    # CLI entry point: dispatches to training or testing based on --mode.
    parser = argparse.ArgumentParser(
        description='training/testing of BertSum(liu et al. 2019)'
    )
    parser.add_argument('--mode', required=True,
                        help='training or testing of BertSum', type=str)
    parser.add_argument('--label_type', default='greedy',
                        help='greedy/limit', type=str)
    parser.add_argument('--save_path', required=True,
                        help='root of the model', type=str)
    # example for gpus input: '0,1,2,3'
    parser.add_argument('--gpus', required=True,
                        help='available gpus for training(separated by commas)', type=str)
    parser.add_argument('--batch_size', default=18,
                        help='the training batch size', type=int)
    parser.add_argument('--accum_count', default=2,
                        help='number of updates steps to accumulate before performing a backward/update pass.', type=int)
    parser.add_argument('--max_lr', default=2e-5,
                        help='max learning rate for warm up', type=float)
    parser.add_argument('--warmup_steps', default=10000,
                        help='warm up steps for training', type=int)
    parser.add_argument('--n_epochs', default=10,
                        help='total number of training epochs', type=int)
    parser.add_argument('--valid_steps', default=1000,
                        help='number of update steps for checkpoint and validation', type=int)

    args = parser.parse_args()
    # any --mode other than 'train' falls through to testing
    if args.mode == 'train':
        print('Training process of BertSum !!!')
        train_model(args)
    else:
        print('Testing process of BertSum !!!')
        test_model(args)





+ 24
- 0
reproduction/Summmarization/BertSum/utils.py View File

@@ -0,0 +1,24 @@
import os
from os.path import exists

def get_data_path(mode, label_type):
    """Return the jsonl data paths for the given run mode and label type.

    :param mode: 'train' yields train+val paths; anything else yields the test path
    :param label_type: sub-directory under data/ (e.g. 'greedy' or 'limit')
    :return: dict mapping split name to its jsonl path
    """
    root = 'data/' + label_type
    if mode == 'train':
        return {'train': root + '/bert.train.jsonl',
                'val': root + '/bert.val.jsonl'}
    return {'test': root + '/bert.test.jsonl'}

def get_rouge_path(label_type):
    """Return (original data path, decode dir, reference dir) for ROUGE evaluation,
    creating the decode/reference directories if they do not yet exist."""
    if label_type == 'others':
        data_path = 'data/' + label_type + '/bert.test.jsonl'
    else:
        data_path = 'data/' + label_type + '/test.jsonl'
    dec_path, ref_path = 'dec', 'ref'
    for directory in (dec_path, ref_path):
        if not exists(directory):
            os.makedirs(directory)
    return data_path, dec_path, ref_path

Loading…
Cancel
Save