@@ -0,0 +1,129 @@ | |||||
import os | |||||
import torch | |||||
import sys | |||||
from torch import nn | |||||
from fastNLP.core.callback import Callback | |||||
from fastNLP.core.utils import _get_model_device | |||||
class MyCallback(Callback):
    """Callback that rewrites the optimizer learning rate after each parameter update.

    The rate is ``max_lr * 100 * min(s ** -0.5, s * warmup_steps ** -1.5)`` where ``s``
    is the number of parameter updates performed so far.

    :param args: object exposing ``max_lr`` and ``warmup_steps``
    """

    def __init__(self, args):
        super(MyCallback, self).__init__()
        # Counts parameter updates (one per ``update_every`` batches), not raw batches.
        self.real_step = 0
        self.args = args

    def on_step_end(self):
        # Nothing to do unless this batch completes a gradient-accumulation cycle.
        if self.step % self.update_every != 0 or self.step <= 0:
            return

        self.real_step += 1
        decay_term = self.real_step ** (-0.5)
        warmup_term = self.real_step * self.args.warmup_steps**(-1.5)
        cur_lr = self.args.max_lr * 100 * min(decay_term, warmup_term)

        for group in self.optimizer.param_groups:
            group['lr'] = cur_lr

        # Report the schedule every 1000 updates.
        if self.real_step % 1000 == 0:
            message = 'Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step)
            self.pbar.write(message)

    def on_epoch_end(self):
        message = 'Epoch {} is done !!!'.format(self.epoch)
        self.pbar.write(message)
def _save_model(model, model_name, save_dir, only_param=False):
    """Save a model, or only its state_dict, with every tensor on the CPU.

    :param model: the module to save; an ``nn.DataParallel`` wrapper is unwrapped first
    :param model_name: file name of the checkpoint inside ``save_dir``
    :param save_dir: directory to save into; created when it does not exist
    :param only_param: if True save only the state_dict, otherwise save the whole module
    :return: None
    """
    model_path = os.path.join(save_dir, model_name)
    os.makedirs(save_dir, exist_ok=True)
    if isinstance(model, nn.DataParallel):
        model = model.module
    if only_param:
        state_dict = model.state_dict()
        # Replace values in place so the mapping type returned by state_dict() is kept.
        for key in state_dict:
            state_dict[key] = state_dict[key].cpu()
        torch.save(state_dict, model_path)
    else:
        _model_device = _get_model_device(model)
        model.cpu()
        try:
            torch.save(model, model_path)
        finally:
            # Move the model back even when saving fails, so that training can
            # continue on the original device.
            model.to(_model_device)
class SaveModelCallback(Callback):
    """
    Save several checkpoints during training rather than only a single best model.

    A folder named after the training start time is created under ``save_dir`` and the
    checkpoints are stored inside it::

        -save_dir
            -2019-07-03-15-06-36
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt
                -epoch:1_step:40_{metric_key}:{evaluate_performance}.pt
            -2019-07-03-15-10-00
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt

    ``metric_key`` is the trainer's metric key and ``evaluate_performance`` its value.

    :param str save_dir: existing directory under which the timestamped folder is created
    :param int top: how many of the best dev checkpoints to keep; a negative value keeps all
    :param bool only_param: whether to save only the model weights
    :param save_on_exception: whether to save a checkpoint when an exception occurs
    """

    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
        super().__init__()

        if not os.path.isdir(save_dir):
            # NOTE(review): NotADirectoryError would describe this better; the type is
            # left unchanged so that existing callers catching it keep working.
            raise IsADirectoryError("{} is not a directory.".format(save_dir))
        self.save_dir = save_dir
        if top < 0:
            self.top = sys.maxsize
        else:
            self.top = top
        # List of (metric, file name) tuples ordered from worst to best, so the
        # entry to evict is always at the front.
        self._ordered_save_models = []

        self.only_param = only_param
        self.save_on_exception = save_on_exception

    def on_train_begin(self):
        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        # The value is read from the first metric's result dict.
        metric_value = list(eval_result.values())[0][metric_key]
        self._save_this_model(metric_value)

    def _insert_into_ordered_save_models(self, pair):
        """Insert ``pair`` into the ranking and report what to save and what to delete.

        :param pair: ``(metric_value, model_name)``
        :return: ``(save_pair, delete_pair)``; each is a pair or None
        """
        # ``index`` ends as the position of the last stored entry that ``pair`` beats
        # (-1 when it beats none of them).
        index = -1
        for _pair in self._ordered_save_models:
            if _pair[0] >= pair[0] and self.trainer.increase_better:
                break
            if not self.trainer.increase_better and _pair[0] <= pair[0]:
                break
            index += 1
        save_pair = None
        # Keep the pair while there is room, or when it beats at least one stored entry.
        if len(self._ordered_save_models) < self.top or index != -1:
            save_pair = pair
            self._ordered_save_models.insert(index + 1, pair)
        delete_pair = None
        if len(self._ordered_save_models) > self.top:
            delete_pair = self._ordered_save_models.pop(0)
        return save_pair, delete_pair

    def _save_this_model(self, metric_value):
        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
        if save_pair:
            # Best effort: a failed save is reported but must not stop training.
            try:
                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
            except Exception as e:
                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
        if delete_pair:
            try:
                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
                if os.path.exists(delete_model_path):
                    os.remove(delete_model_path)
            except Exception as e:
                # Report the evicted checkpoint, not the one that was just saved.
                print(f"Fail to delete model {delete_pair[1]} at {self.save_dir} caused by exception:{e}.")

    def on_exception(self, exception):
        if self.save_on_exception:
            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
@@ -0,0 +1,157 @@ | |||||
from time import time | |||||
from datetime import timedelta | |||||
from fastNLP.io.dataset_loader import JsonLoader | |||||
from fastNLP.modules.encoder._bert import BertTokenizer | |||||
from fastNLP.io.base_loader import DataInfo | |||||
from fastNLP.core.const import Const | |||||
class BertData(JsonLoader):
    """Loader that converts ``article``/``label`` json records into BERT model inputs.

    :param max_nsents: maximum number of sentences kept per article
    :param max_ntokens: maximum number of whitespace tokens kept per sentence
    :param max_len: length every token-id sequence is truncated or padded to
    :param bert_dir: directory handed to ``BertTokenizer.from_pretrained``
    """

    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512,
                 bert_dir='/path/to/uncased_L-12_H-768_A-12'):
        fields = {'article': 'article',
                  'label': 'label'}
        super(BertData, self).__init__(fields=fields)

        self.max_nsents = max_nsents
        self.max_ntokens = max_ntokens
        self.max_len = max_len

        self.tokenizer = BertTokenizer.from_pretrained(bert_dir)
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def _load(self, paths):
        dataset = super(BertData, self)._load(paths)
        return dataset

    def process(self, paths):
        """Load every dataset named in ``paths`` and derive the model input fields.

        :param paths: mapping from dataset name to the path passed to ``_load``
        :return: ``DataInfo`` wrapping the processed datasets
        """

        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
            # Lower-case, cut each sentence to max_ntokens tokens, keep max_nsents sentences.
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]

        def truncate_labels(instance):
            # Drop label indices that point past the (already truncated) article.
            n_sents = len(instance['article'])
            return [idx for idx in instance['label'] if idx < n_sents]

        def bert_tokenize(instance, tokenizer, max_len, pad_value):
            # Sentences are joined with '[SEP] [CLS]' so each one starts with its own [CLS].
            article = ' [SEP] [CLS] '.join(instance['article'])
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
            token_ids += [pad_value] * (max_len - len(token_ids))
            assert len(token_ids) == max_len
            return token_ids

        def get_seg_id(instance, max_len, sep_id):
            # Segment lengths are the gaps between consecutive [SEP] positions;
            # segments alternate between id 0 and id 1.
            sep_positions = [-1] + [i for i, tok in enumerate(instance['article']) if tok == sep_id]
            seg_lengths = [cur - prev for prev, cur in zip(sep_positions, sep_positions[1:])]
            segment_id = []
            for i, length in enumerate(seg_lengths):
                segment_id += length * [i % 2]
            segment_id += [0] * (max_len - len(segment_id))
            return segment_id

        def get_cls_id(instance, cls_id):
            # Positions of every [CLS] token in the token-id sequence.
            return [i for i, tok in enumerate(instance['article']) if tok == cls_id]

        def get_labels(instance):
            # One 0/1 label per [CLS] position.
            n_cls = len(instance['cls_id'])
            labels = [0] * n_cls
            for idx in instance['label']:
                if idx < n_cls:
                    labels[idx] = 1
            return labels

        datasets = {}
        for name in paths:
            dataset = self._load(paths[name])
            datasets[name] = dataset

            # remove empty samples
            dataset.drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)

            # truncate articles
            dataset.apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')

            # truncate labels
            dataset.apply(truncate_labels, new_field_name='label')

            # tokenize and convert tokens to id
            dataset.apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')

            # get segment id
            dataset.apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')

            # get classification id
            dataset.apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')

            # get label
            dataset.apply(get_labels, new_field_name='label')

            # rename fields
            dataset.rename_field('article', Const.INPUTS(0))
            dataset.rename_field('segment_id', Const.INPUTS(1))
            dataset.rename_field('cls_id', Const.INPUTS(2))
            # The field is called 'label'; the former 'lbael' spelling named no field.
            dataset.rename_field('label', Const.TARGET)

            # set input and target
            dataset.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
            dataset.set_target(Const.TARGET)

            # set padding value; 'article' was renamed above, so use its new name
            dataset.set_pad_val(Const.INPUTS(0), 0)

        return DataInfo(datasets=datasets)
class BertSumLoader(JsonLoader):
    """Loader for json records carrying ``article``, ``segment_id``, ``cls_id`` and ``label``.

    The ``label`` field is mapped to ``Const.TARGET``.
    """

    def __init__(self):
        super(BertSumLoader, self).__init__(fields={
            'article': 'article',
            'segment_id': 'segment_id',
            'cls_id': 'cls_id',
            'label': Const.TARGET,
        })

    def _load(self, paths):
        return super(BertSumLoader, self)._load(paths)

    def process(self, paths):
        """Load each dataset in ``paths``, add ``seq_len`` and configure inputs, target and padding.

        :param paths: mapping from dataset name to the path passed to ``_load``
        :return: ``DataInfo`` wrapping the loaded datasets
        """
        print('Start loading datasets !!!')
        started_at = time()

        loaded = {}
        for split in paths:
            dataset = self._load(paths[split])
            loaded[split] = dataset

            # sequence length is the number of entries in 'article'
            dataset.apply(lambda instance: len(instance['article']), new_field_name='seq_len')

            # set input and target
            dataset.set_input('article', 'segment_id', 'cls_id')
            dataset.set_target(Const.TARGET)

            # set padding value
            for field_name, pad_value in (('article', 0), ('segment_id', 0),
                                          ('cls_id', -1), (Const.TARGET, 0)):
                dataset.set_pad_val(field_name, pad_value)

        elapsed = timedelta(seconds=time()-started_at)
        print('Finished in {}'.format(elapsed))

        return DataInfo(datasets=loaded)
@@ -0,0 +1,178 @@ | |||||
import numpy as np | |||||
import json | |||||
from os.path import join | |||||
import torch | |||||
import logging | |||||
import tempfile | |||||
import subprocess as sp | |||||
from datetime import timedelta | |||||
from time import time | |||||
from pyrouge import Rouge155 | |||||
from pyrouge.utils import log | |||||
from fastNLP.core.losses import LossBase | |||||
from fastNLP.core.metrics import MetricBase | |||||
# Root directory of the ROUGE release. RougeMetric.eval_rouge runs `ROUGE-1.5.5.pl`
# from this directory and passes its `data` sub-directory through the `-e` option.
# NOTE(review): the value looks like a placeholder - presumably it has to be edited
# to point at a real installation before evaluation can run.
_ROUGE_PATH = '/path/to/RELEASE-1.5.5'
class MyBCELoss(LossBase):
    """Binary cross-entropy summed over the positions selected by ``mask``.

    :param pred: name mapping for the prediction argument
    :param target: name mapping for the target argument
    :param mask: name mapping for the mask argument
    """

    def __init__(self, pred=None, target=None, mask=None):
        super(MyBCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        # reduction='none' keeps the element-wise losses so they can be masked.
        self.loss_func = torch.nn.BCELoss(reduction='none')

    def get_loss(self, pred, target, mask):
        elementwise = self.loss_func(pred, target.float())
        masked = elementwise * mask.float()
        return masked.sum()
class LossMetric(MetricBase):
    """Metric reporting the masked binary cross-entropy averaged per sample.

    :param pred: name mapping for the prediction argument
    :param target: name mapping for the target argument
    :param mask: name mapping for the mask argument
    """

    def __init__(self, pred=None, target=None, mask=None):
        super(LossMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        # reduction='none' keeps the element-wise losses so they can be masked.
        self.loss_func = torch.nn.BCELoss(reduction='none')
        # Despite its name this holds the running SUM of the loss; the mean is
        # only computed in get_metric.
        self.avg_loss = 0.0
        self.nsamples = 0

    def evaluate(self, pred, target, mask):
        """Accumulate the masked loss of one batch; the first dim of ``pred`` is the sample count."""
        batch_size = pred.size(0)
        loss = self.loss_func(pred, target.float())
        loss = (loss * mask.float()).sum()
        self.avg_loss += loss
        self.nsamples += batch_size

    def get_metric(self, reset=True):
        """Return ``{'loss': accumulated loss / number of samples}``.

        :param reset: whether to clear the accumulated state afterwards
        :raises ZeroDivisionError: when no sample has been evaluated and the accumulator
            is still a plain number
        """
        # Compute the mean into a local so that reset=False leaves the running sum
        # intact; overwriting it would divide by nsamples again on the next call.
        mean_loss = self.avg_loss / self.nsamples
        eval_result = {'loss': mean_loss}
        if reset:
            self.avg_loss = 0
            self.nsamples = 0
        return eval_result
class RougeMetric(MetricBase):
    """Metric that writes extracted summaries to disk and scores them with ROUGE-1.5.5.

    :param data_path: jsonl file with the original articles and reference summaries
    :param dec_path: directory receiving the ``{i}.dec`` decoded summaries
    :param ref_path: directory receiving the ``{i}.ref`` reference summaries
    :param n_total: total number of samples, used only for the progress display
    :param n_ext: number of sentences to extract per article
    :param ngram_block: n-gram size used for blocking repeated content
    :param pred: name mapping for the prediction argument
    :param target: name mapping for the target argument
    :param mask: name mapping for the mask argument
    """

    def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
        super(RougeMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.data_path = data_path
        self.dec_path = dec_path
        self.ref_path = ref_path
        self.n_total = n_total
        self.n_ext = n_ext
        self.ngram_block = ngram_block

        self.cur_idx = 0
        self.ext = []
        self.start = time()

    @staticmethod
    def eval_rouge(dec_dir, ref_dir):
        """Run ROUGE-1.5.5 on the two directories and return ``(R_1, R_2, R_L)``."""
        assert _ROUGE_PATH is not None
        log.get_global_console_logger().setLevel(logging.WARNING)
        # Raw string: '\d' is not a valid escape in a normal string literal.
        dec_pattern = r'(\d+).dec'
        ref_pattern = '#ID#.ref'
        with tempfile.TemporaryDirectory() as tmp_dir:
            Rouge155.convert_summaries_to_rouge_format(
                dec_dir, join(tmp_dir, 'dec'))
            Rouge155.convert_summaries_to_rouge_format(
                ref_dir, join(tmp_dir, 'ref'))
            Rouge155.write_config_static(
                join(tmp_dir, 'dec'), dec_pattern,
                join(tmp_dir, 'ref'), ref_pattern,
                join(tmp_dir, 'settings.xml'), system_id=1
            )
            # Build the argument list directly; splitting a joined string on spaces
            # would break any path that contains a space.
            cmd = [
                join(_ROUGE_PATH, 'ROUGE-1.5.5.pl'),
                '-e', join(_ROUGE_PATH, 'data'),
                '-c', '95', '-r', '1000', '-n', '2', '-m',
                '-a', join(tmp_dir, 'settings.xml'),
            ]
            output = sp.check_output(cmd, universal_newlines=True)
            # The scores are read from fixed line/column positions of the ROUGE report.
            lines = output.split('\n')
            R_1 = float(lines[3].split(' ')[3])
            R_2 = float(lines[7].split(' ')[3])
            R_L = float(lines[11].split(' ')[3])
            print(output)
        return R_1, R_2, R_L

    def evaluate(self, pred, target, mask):
        """Store, for every sample of the batch, its sentence indices sorted by score."""
        # Adding the mask lifts every unmasked position above every masked one.
        pred = pred + mask.float()
        pred = pred.detach().cpu().numpy()
        ext_ids = np.argsort(-pred, 1)
        for sent_id in ext_ids:
            self.ext.append(sent_id)
            self.cur_idx += 1
        print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
            self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
        ), end='')

    @staticmethod
    def _ngrams(sentence, n):
        """Return the space-joined n-grams of ``sentence`` (split on single spaces)."""
        tokens = sentence.split(' ')
        return [' '.join(tokens[i: i + n]) for i in range(len(tokens) - n + 1)]

    def _load_original_data(self):
        """Read ``data_path`` (one json object per line) into article/abstract records."""
        data = []
        with open(self.data_path, encoding='utf-8') as f:
            for line in f:
                cur_data = json.loads(line)
                if 'text' in cur_data:
                    # Normalise the text/summary schema to article/abstract.
                    data.append({'article': cur_data['text'],
                                 'abstract': cur_data['summary']})
                else:
                    data.append(cur_data)
        return data

    def _select_top(self, article, ext_ids):
        """Take the best ``n_ext`` ranked sentences without any blocking."""
        n_sent = min(len(article), self.n_ext)
        # Indices past the article can only come from padded positions; skip them.
        return [article[idx] for idx in ext_ids[:n_sent] if idx < len(article)]

    def _select_with_ngram_block(self, article, ext_ids):
        """Take ranked sentences, skipping any that repeats an already selected n-gram."""
        selected = []
        seen = set()
        for idx in ext_ids:
            if idx >= len(article):
                # Padded position with no matching sentence.
                continue
            sent = article[idx]
            grams = self._ngrams(sent, self.ngram_block)
            if seen.isdisjoint(grams):
                selected.append(sent)
                seen.update(grams)
            if len(selected) >= self.n_ext:
                break
        return selected

    def get_metric(self, use_ngram_block=True, reset=True):
        """Write decoded and reference summaries, run ROUGE and return the scores.

        :param use_ngram_block: whether to apply n-gram blocking during selection
        :param reset: whether to clear the collected predictions afterwards
        :return: dict with keys ``'ROUGE-1'``, ``'ROUGE-2'`` and ``'ROUGE-L'``
        """
        # load original data
        data = self._load_original_data()

        # write decode sentences and references
        if use_ngram_block:
            print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
        for i, ext_ids in enumerate(self.ext):
            article = data[i]['article']
            if use_ngram_block:
                dec = self._select_with_ngram_block(article, ext_ids)
            else:
                dec = self._select_top(article, ext_ids)
            ref = list(data[i]['abstract'])

            with open(join(self.dec_path, '{}.dec'.format(i)), 'w', encoding='utf-8') as f:
                for sent in dec:
                    print(sent, file=f)
            with open(join(self.ref_path, '{}.ref'.format(i)), 'w', encoding='utf-8') as f:
                for sent in ref:
                    print(sent, file=f)

        print('\nStart evaluating ROUGE score !!!')
        R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
        eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L': R_L}

        if reset:
            self.cur_idx = 0
            self.ext = []
            self.start = time()
        return eval_result
@@ -0,0 +1,51 @@ | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import init | |||||
from fastNLP.modules.encoder._bert import BertModel | |||||
class Classifier(nn.Module):
    """Linear layer followed by a sigmoid, producing one masked score per position.

    :param hidden_size: size of the last dimension of the input
    """

    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        # Project to a single logit and drop the trailing dim: [batch_size, seq_len]
        logits = self.linear(inputs).squeeze(-1)
        probabilities = self.sigmoid(logits)
        # Positions where the mask is 0 get a score of 0.
        return probabilities * mask_cls.float()
class BertSum(nn.Module):
    """BERT encoder followed by a per-sentence classifier.

    :param hidden_size: hidden size of the encoder output
    :param bert_dir: directory handed to ``BertModel.from_pretrained``
    """

    def __init__(self, hidden_size=768, bert_dir='/path/to/uncased_L-12_H-768_A-12'):
        super(BertSum, self).__init__()

        self.hidden_size = hidden_size

        self.encoder = BertModel.from_pretrained(bert_dir)
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        """Score every sentence of every article.

        :param article: token ids, where id 0 is treated as padding
        :param segment_id: token type ids passed to the encoder
        :param cls_id: positions of the [CLS] tokens, where -1 is treated as padding
        :return: dict with ``'pred'`` (sentence scores) and ``'mask'`` (1 for real sentences)
        """
        # Comparisons return bool tensors, and `1 - bool_tensor` is rejected by recent
        # PyTorch, so compare with `!=` and convert to an integer 0/1 mask.
        input_mask = (article != 0).long()
        mask_cls = (cls_id != -1).long()
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()

        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # last layer

        # Gather the hidden state at each [CLS] position. A padded cls_id of -1 selects
        # the last token, and that row is zeroed by the mask right below.
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, seq_len, hidden_size]

        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, seq_len]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))

        return {'pred': sent_scores, 'mask': mask_cls}
@@ -0,0 +1,147 @@ | |||||
import sys | |||||
import argparse | |||||
import os | |||||
import json | |||||
import torch | |||||
from time import time | |||||
from datetime import timedelta | |||||
from os.path import join, exists | |||||
from torch.optim import Adam | |||||
from utils import get_data_path, get_rouge_path | |||||
from dataloader import BertSumLoader | |||||
from model import BertSum | |||||
from fastNLP.core.optimizer import AdamW | |||||
from metrics import MyBCELoss, LossMetric, RougeMetric | |||||
from fastNLP.core.sampler import BucketSampler | |||||
from callback import MyCallback, SaveModelCallback | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.core.tester import Tester | |||||
def configure_training(args):
    """Extract the GPU ids and the training hyper-parameters from ``args``.

    :param args: object exposing ``gpus`` (comma-separated ids) and the hyper-parameter attributes
    :return: ``(devices, params)`` - a list of int GPU ids and a dict of hyper-parameters
    """
    devices = list(map(int, args.gpus.split(',')))

    # Kept in this order so the dict is written out in the same key order.
    hyper_parameter_names = (
        'label_type',
        'batch_size',
        'accum_count',
        'max_lr',
        'warmup_steps',
        'n_epochs',
        'valid_steps',
    )
    params = {key: getattr(args, key) for key in hyper_parameter_names}

    return devices, params
def train_model(args):
    """Train BertSum with the settings in ``args`` and write ``params.json`` to ``args.save_path``.

    :param args: parsed command-line arguments
    """
    # check if the data_path and save_path exists
    data_paths = get_data_path(args.mode, args.label_type)
    for split_name in data_paths:
        assert exists(data_paths[split_name])
    if not exists(args.save_path):
        os.makedirs(args.save_path)

    # load summarization datasets
    data_info = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(data_info)
    train_set = data_info.datasets['train']
    valid_set = data_info.datasets['val']

    # configure training
    devices, train_params = configure_training(args)
    params_file = join(args.save_path, 'params.json')
    with open(params_file, 'w') as f:
        json.dump(train_params, f, indent=4)
    print('Devices is:')
    print(devices)

    # configure model
    model = BertSum()
    trainable_parameters = filter(lambda p: p.requires_grad, model.parameters())
    # lr starts at 0; MyCallback sets the learning rate after each parameter update
    optimizer = Adam(trainable_parameters, lr=0)
    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
    criterion = MyBCELoss()
    val_metric = [LossMetric()]
    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
    trainer_options = dict(
        train_data=train_set, model=model, optimizer=optimizer,
        loss=criterion, batch_size=args.batch_size,  # sampler=sampler,
        update_every=args.accum_count, n_epochs=args.n_epochs,
        print_every=100, dev_data=valid_set, metrics=val_metric,
        metric_key='-loss', validate_every=args.valid_steps,
        save_path=args.save_path, device=devices, callbacks=callbacks,
    )
    trainer = Trainer(**trainer_options)

    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()
def test_model(args):
    """Evaluate every checkpoint file found directly in ``args.save_path`` with ROUGE.

    :param args: parsed command-line arguments; ``args.batch_size`` is overwritten with 1
    """
    # Only regular files are candidate checkpoints. The directory can also hold
    # params.json and sub-folders, which torch.load cannot read.
    models = sorted(
        entry for entry in os.listdir(args.save_path)
        if os.path.isfile(join(args.save_path, entry)) and not entry.endswith('.json')
    )

    # load dataset
    data_paths = get_data_path(args.mode, args.label_type)
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    test_set = datasets.datasets['test']

    # only need 1 gpu for testing; accept a comma-separated list and use its first id
    device = int(args.gpus.split(',')[0])

    args.batch_size = 1

    # The ROUGE paths do not depend on the checkpoint, so resolve them once.
    original_path, dec_path, ref_path = get_rouge_path(args.label_type)

    for cur_model in models:

        print('Current model is {}'.format(cur_model))

        # load model
        # NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.
        model = torch.load(join(args.save_path, cur_model))

        # configure testing; a fresh metric per checkpoint keeps the results separate
        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
                                  ref_path=ref_path, n_total=len(test_set))
        tester = Tester(data=test_set, model=model, metrics=[test_metric],
                        batch_size=args.batch_size, device=device)
        tester.test()
if __name__ == '__main__':
    # Command-line entry point: parse the options, then train or test.
    parser = argparse.ArgumentParser(
        description='training/testing of BertSum(liu et al. 2019)'
    )

    # (flag, keyword arguments) pairs, registered in this order.
    # example for gpus input: '0,1,2,3'
    option_table = (
        ('--mode', dict(required=True,
                        help='training or testing of BertSum', type=str)),
        ('--label_type', dict(default='greedy',
                              help='greedy/limit', type=str)),
        ('--save_path', dict(required=True,
                             help='root of the model', type=str)),
        ('--gpus', dict(required=True,
                        help='available gpus for training(separated by commas)', type=str)),
        ('--batch_size', dict(default=18,
                              help='the training batch size', type=int)),
        ('--accum_count', dict(default=2,
                               help='number of updates steps to accumulate before performing a backward/update pass.', type=int)),
        ('--max_lr', dict(default=2e-5,
                          help='max learning rate for warm up', type=float)),
        ('--warmup_steps', dict(default=10000,
                                help='warm up steps for training', type=int)),
        ('--n_epochs', dict(default=10,
                            help='total number of training epochs', type=int)),
        ('--valid_steps', dict(default=1000,
                               help='number of update steps for checkpoint and validation', type=int)),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Any mode other than 'train' runs the testing path.
    is_training = args.mode == 'train'
    print('Training process of BertSum !!!' if is_training else 'Testing process of BertSum !!!')
    run = train_model if is_training else test_model
    run(args)
@@ -0,0 +1,24 @@ | |||||
import os | |||||
from os.path import exists | |||||
def get_data_path(mode, label_type):
    """Return the jsonl paths of the dataset splits needed for ``mode``.

    :param mode: ``'train'`` selects the train and val splits; any other value selects test
    :param label_type: name of the sub-directory of ``data/`` that holds the files
    :return: dict mapping split name to file path
    """
    data_dir = 'data/' + label_type
    splits = ('train', 'val') if mode == 'train' else ('test',)
    return {split: data_dir + '/bert.' + split + '.jsonl' for split in splits}
def get_rouge_path(label_type):
    """Return the paths used for ROUGE evaluation, creating the output directories.

    :param label_type: name of the sub-directory of ``data/`` that holds the test file
    :return: ``(data_path, dec_path, ref_path)``
    """
    file_name = '/bert.test.jsonl' if label_type == 'others' else '/test.jsonl'
    data_path = 'data/' + label_type + file_name

    dec_path, ref_path = 'dec', 'ref'
    # Both directories are relative to the current working directory.
    for directory in (ref_path, dec_path):
        if not exists(directory):
            os.makedirs(directory)

    return data_path, dec_path, ref_path