@@ -0,0 +1,129 @@ | |||
import os | |||
import torch | |||
import sys | |||
from torch import nn | |||
from fastNLP.core.callback import Callback | |||
from fastNLP.core.utils import _get_model_device | |||
class MyCallback(Callback):
    """Callback implementing a Noam-style (warmup then inverse-sqrt decay)
    learning-rate schedule, applied once per effective optimizer step.

    ``args`` must provide ``max_lr`` and ``warmup_steps``.
    """

    def __init__(self, args):
        super(MyCallback, self).__init__()
        self.args = args
        # counts effective optimizer steps (gradient-accumulation aware)
        self.real_step = 0

    def on_step_end(self):
        # skip accumulation-only steps
        if self.step <= 0 or self.step % self.update_every != 0:
            return
        self.real_step += 1
        # warmup ramp vs. inverse-sqrt decay; the smaller branch is active
        lr = self.args.max_lr * 100 * min(
            self.real_step ** (-0.5),
            self.real_step * self.args.warmup_steps ** (-1.5),
        )
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        if self.real_step % 1000 == 0:
            self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(lr, self.real_step))

    def on_epoch_end(self):
        self.pbar.write('Epoch {} is done !!!'.format(self.epoch))
def _save_model(model, model_name, save_dir, only_param=False): | |||
""" 存储不含有显卡信息的 state_dict 或 model | |||
:param model: | |||
:param model_name: | |||
:param save_dir: 保存的 directory | |||
:param only_param: | |||
:return: | |||
""" | |||
model_path = os.path.join(save_dir, model_name) | |||
if not os.path.isdir(save_dir): | |||
os.makedirs(save_dir, exist_ok=True) | |||
if isinstance(model, nn.DataParallel): | |||
model = model.module | |||
if only_param: | |||
state_dict = model.state_dict() | |||
for key in state_dict: | |||
state_dict[key] = state_dict[key].cpu() | |||
torch.save(state_dict, model_path) | |||
else: | |||
_model_device = _get_model_device(model) | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.to(_model_device) | |||
class SaveModelCallback(Callback):
    """Save multiple checkpoints during training.

    Trainer itself only keeps the best model; this callback keeps the models
    with the top-n dev performance. A sub-directory named after the training
    start timestamp is created under ``save_dir`` and models are stored
    inside it::

        -save_dir
            -2019-07-03-15-06-36
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt
                -epoch:1_step:40_...
            -2019-07-03-15-10-00
                -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt

    :param str save_dir: directory under which the timestamped checkpoint
        directory is created
    :param int top: keep the models with the top-n dev performance;
        -1 keeps every model
    :param bool only_param: save only the model weights (state_dict)
    :param save_on_exception: whether to save the current model when an
        exception occurs
    """
    def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
        super().__init__()
        if not os.path.isdir(save_dir):
            # bug fix: this previously raised IsADirectoryError, which means
            # the opposite of the condition being reported
            raise NotADirectoryError("{} is not a directory.".format(save_dir))
        self.save_dir = save_dir
        if top < 0:
            self.top = sys.maxsize
        else:
            self.top = top
        # List[Tuple(metric_value, model_name)] ordered from worst to best,
        # so the worst checkpoint is popped from the head
        self._ordered_save_models = []
        self.only_param = only_param
        self.save_on_exception = save_on_exception

    def on_train_begin(self):
        # nest checkpoints inside a directory named after the start time
        self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        # eval_result: {metric_name: {metric_key: value}}; take the first metric
        metric_value = list(eval_result.values())[0][metric_key]
        self._save_this_model(metric_value)

    def _insert_into_ordered_save_models(self, pair):
        # pair: (metric_value, model_name)
        # Returns (saved pair, deleted pair); either may be None. The list
        # stays sorted from worst to best according to trainer.increase_better.
        index = -1
        for _pair in self._ordered_save_models:
            if _pair[0] >= pair[0] and self.trainer.increase_better:
                break
            if not self.trainer.increase_better and _pair[0] <= pair[0]:
                break
            index += 1
        save_pair = None
        # save if there is room, or if the new model beats at least one kept model
        if len(self._ordered_save_models) < self.top \
                or (len(self._ordered_save_models) >= self.top and index != -1):
            save_pair = pair
            self._ordered_save_models.insert(index + 1, pair)
        delete_pair = None
        if len(self._ordered_save_models) > self.top:
            # evict the worst checkpoint
            delete_pair = self._ordered_save_models.pop(0)
        return save_pair, delete_pair

    def _save_this_model(self, metric_value):
        name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
        save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
        if save_pair:
            try:
                _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
            except Exception as e:
                print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
        if delete_pair:
            try:
                delete_model_path = os.path.join(self.save_dir, delete_pair[1])
                if os.path.exists(delete_model_path):
                    os.remove(delete_model_path)
            except Exception as e:
                # bug fix: report the checkpoint actually being deleted
                # (delete_pair[1]), not the one just saved (name)
                print(f"Fail to delete model {delete_pair[1]} at {self.save_dir} caused by exception:{e}.")

    def on_exception(self, exception):
        if self.save_on_exception:
            name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
            _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
@@ -0,0 +1,157 @@ | |||
from time import time | |||
from datetime import timedelta | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.core.const import Const | |||
class BertData(JsonLoader):
    """Load article/label jsonl data and preprocess it for BertSum:
    truncation, BERT word-piece tokenization, segment ids, [CLS] positions
    and per-sentence binary labels.

    :param int max_nsents: maximum number of sentences kept per article
    :param int max_ntokens: maximum number of tokens kept per sentence
    :param int max_len: maximum number of word pieces per article
    """
    def __init__(self, max_nsents=60, max_ntokens=100, max_len=512):
        fields = {'article': 'article',
                  'label': 'label'}
        super(BertData, self).__init__(fields=fields)
        self.max_nsents = max_nsents
        self.max_ntokens = max_ntokens
        self.max_len = max_len
        self.tokenizer = BertTokenizer.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def _load(self, paths):
        dataset = super(BertData, self)._load(paths)
        return dataset

    def process(self, paths):
        """Load the datasets in ``paths`` and run the full preprocessing
        pipeline; returns a :class:`DataInfo` holding the processed datasets."""
        def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
            # lowercase, keep at most max_ntokens tokens per sentence and
            # at most max_nsents sentences per article
            article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
            return article[:max_nsents]

        def truncate_labels(instance):
            # drop labels pointing at sentences removed by truncation
            label = list(filter(lambda x: x < len(instance['article']), instance['label']))
            return label

        def bert_tokenize(instance, tokenizer, max_len, pad_value):
            # join sentences with ' [SEP] [CLS] ', wrap the article with
            # [CLS] ... [SEP], convert to ids and pad to max_len
            article = instance['article']
            article = ' [SEP] [CLS] '.join(article)
            word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
            word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
            token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
            while len(token_ids) < max_len:
                token_ids.append(pad_value)
            assert len(token_ids) == max_len
            return token_ids

        def get_seg_id(instance, max_len, sep_id):
            # alternate segment ids 0/1 per [SEP]-delimited sentence span
            _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
            segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
            segment_id = []
            for i, length in enumerate(segs):
                if i % 2 == 0:
                    segment_id += length * [0]
                else:
                    segment_id += length * [1]
            while len(segment_id) < max_len:
                segment_id.append(0)
            return segment_id

        def get_cls_id(instance, cls_id):
            # positions of the [CLS] token that fronts each sentence
            classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
            return classification_id

        def get_labels(instance):
            # binary vector over kept sentences: 1 if the sentence is extracted
            labels = [0] * len(instance['cls_id'])
            label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
            for idx in label_idx:
                labels[idx] = 1
            return labels

        datasets = {}
        for name in paths:
            datasets[name] = self._load(paths[name])
            # remove empty samples
            datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)
            # truncate articles
            datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens), new_field_name='article')
            # truncate labels
            datasets[name].apply(truncate_labels, new_field_name='label')
            # tokenize and convert tokens to id
            datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id), new_field_name='article')
            # get segment id
            datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id), new_field_name='segment_id')
            # get classification id
            datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id), new_field_name='cls_id')
            # get label
            datasets[name].apply(get_labels, new_field_name='label')
            # rename fields
            datasets[name].rename_field('article', Const.INPUTS(0))
            datasets[name].rename_field('segment_id', Const.INPUTS(1))
            datasets[name].rename_field('cls_id', Const.INPUTS(2))
            # bug fix: was rename_field('lbael', ...) — no such field exists,
            # so the rename raised; the intended field is 'label'
            datasets[name].rename_field('label', Const.TARGET)
            # set input and target
            datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
            datasets[name].set_target(Const.TARGET)
            # bug fix: 'article' was renamed to Const.INPUTS(0) above; the
            # original set_pad_val('article', 0) addressed the old name
            datasets[name].set_pad_val(Const.INPUTS(0), 0)
        return DataInfo(datasets=datasets)
class BertSumLoader(JsonLoader):
    """Load already-preprocessed BertSum jsonl data (token ids, segment ids,
    [CLS] positions, labels) and configure inputs, targets and padding."""

    def __init__(self):
        field_map = {'article': 'article',
                     'segment_id': 'segment_id',
                     'cls_id': 'cls_id',
                     'label': Const.TARGET}
        super(BertSumLoader, self).__init__(fields=field_map)

    def _load(self, paths):
        return super(BertSumLoader, self)._load(paths)

    def process(self, paths):
        """Load every dataset in ``paths`` and return a :class:`DataInfo`."""
        def get_seq_len(instance):
            # word-piece length of the (already tokenized) article
            return len(instance['article'])

        print('Start loading datasets !!!')
        start = time()
        datasets = {}
        for name in paths:
            ds = self._load(paths[name])
            ds.apply(get_seq_len, new_field_name='seq_len')
            # declare model inputs and the training target
            ds.set_input('article', 'segment_id', 'cls_id')
            ds.set_target(Const.TARGET)
            # padding values: 0 for tokens/segments/labels, -1 for cls positions
            ds.set_pad_val('article', 0)
            ds.set_pad_val('segment_id', 0)
            ds.set_pad_val('cls_id', -1)
            ds.set_pad_val(Const.TARGET, 0)
            datasets[name] = ds
        print('Finished in {}'.format(timedelta(seconds=time()-start)))
        return DataInfo(datasets=datasets)
@@ -0,0 +1,178 @@ | |||
import numpy as np | |||
import json | |||
from os.path import join | |||
import torch | |||
import logging | |||
import tempfile | |||
import subprocess as sp | |||
from datetime import timedelta | |||
from time import time | |||
from pyrouge import Rouge155 | |||
from pyrouge.utils import log | |||
from fastNLP.core.losses import LossBase | |||
from fastNLP.core.metrics import MetricBase | |||
_ROUGE_PATH = '/path/to/RELEASE-1.5.5' | |||
class MyBCELoss(LossBase):
    """Masked binary cross-entropy: per-element BCE zeroed at masked
    positions, summed over the batch."""

    def __init__(self, pred=None, target=None, mask=None):
        super(MyBCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        # reduction='none' keeps per-element losses so the mask can be applied
        self.loss_func = torch.nn.BCELoss(reduction='none')

    def get_loss(self, pred, target, mask):
        elementwise = self.loss_func(pred, target.float())
        return (elementwise * mask.float()).sum()
class LossMetric(MetricBase):
    """Validation metric that accumulates masked BCE loss over batches and
    reports the average loss per sample."""

    def __init__(self, pred=None, target=None, mask=None):
        super(LossMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.loss_func = torch.nn.BCELoss(reduction='none')
        self.avg_loss = 0.0  # running sum of batch losses
        self.nsamples = 0    # number of samples accumulated

    def evaluate(self, pred, target, mask):
        """Accumulate the masked BCE loss of one batch."""
        batch_size = pred.size(0)
        loss = self.loss_func(pred, target.float())
        loss = (loss * mask.float()).sum()
        self.avg_loss += loss
        self.nsamples += batch_size

    def get_metric(self, reset=True):
        """Return {'loss': average loss per sample}.

        Bug fix: the original divided ``self.avg_loss`` in place, so calling
        ``get_metric(reset=False)`` more than once divided the running sum
        repeatedly. Compute the average locally instead.
        """
        avg_loss = self.avg_loss / self.nsamples
        eval_result = {'loss': avg_loss}
        if reset:
            self.avg_loss = 0
            self.nsamples = 0
        return eval_result
class RougeMetric(MetricBase):
    """Collect extractive predictions batch by batch, write the selected
    sentences and references to disk, and score them with ROUGE-1.5.5.

    :param data_path: jsonl file with the original articles and abstracts
    :param dec_path: directory where decoded summaries are written (``<i>.dec``)
    :param ref_path: directory where reference summaries are written (``<i>.ref``)
    :param n_total: total number of evaluation samples (progress display only)
    :param int n_ext: number of sentences to extract per article
    :param int ngram_block: n for n-gram blocking during selection
    :param pred, target, mask: fastNLP parameter-name mapping
    """
    def __init__(self, data_path, dec_path, ref_path, n_total, n_ext=3, ngram_block=3, pred=None, target=None, mask=None):
        super(RougeMetric, self).__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.data_path = data_path
        self.dec_path = dec_path
        self.ref_path = ref_path
        self.n_total = n_total
        self.n_ext = n_ext
        self.ngram_block = ngram_block
        self.cur_idx = 0    # number of samples decoded so far
        self.ext = []       # per-sample sentence indices, sorted by descending score
        self.start = time() # wall-clock start, for the progress display

    @staticmethod
    def eval_rouge(dec_dir, ref_dir):
        """Run the official ROUGE-1.5.5 perl script over ``dec_dir`` vs
        ``ref_dir`` and return the (ROUGE-1, ROUGE-2, ROUGE-L) scores."""
        assert _ROUGE_PATH is not None
        # silence pyrouge's console chatter below WARNING
        log.get_global_console_logger().setLevel(logging.WARNING)
        dec_pattern = '(\d+).dec'
        ref_pattern = '#ID#.ref'
        cmd = '-c 95 -r 1000 -n 2 -m'
        with tempfile.TemporaryDirectory() as tmp_dir:
            # convert both sides to the ROUGE html-ish input format
            Rouge155.convert_summaries_to_rouge_format(
                dec_dir, join(tmp_dir, 'dec'))
            Rouge155.convert_summaries_to_rouge_format(
                ref_dir, join(tmp_dir, 'ref'))
            Rouge155.write_config_static(
                join(tmp_dir, 'dec'), dec_pattern,
                join(tmp_dir, 'ref'), ref_pattern,
                join(tmp_dir, 'settings.xml'), system_id=1
            )
            cmd = (join(_ROUGE_PATH, 'ROUGE-1.5.5.pl')
                   + ' -e {} '.format(join(_ROUGE_PATH, 'data'))
                   + cmd
                   + ' -a {}'.format(join(tmp_dir, 'settings.xml')))
            output = sp.check_output(cmd.split(' '), universal_newlines=True)
            # NOTE(review): these line/column offsets assume the standard
            # ROUGE output layout for '-n 2' (lines 3/7/11 carrying the
            # ROUGE-1/2/L averages) — confirm against the installed version
            R_1 = float(output.split('\n')[3].split(' ')[3])
            R_2 = float(output.split('\n')[7].split(' ')[3])
            R_L = float(output.split('\n')[11].split(' ')[3])
            print(output)
        return R_1, R_2, R_L

    def evaluate(self, pred, target, mask):
        """Accumulate the score-ranked sentence ids of one batch."""
        # adding the 0/1 mask lifts valid sentences above padded ones
        # (scores are presumably sigmoid outputs in [0, 1] — see Classifier),
        # so padded positions always sort last
        pred = pred + mask.float()
        pred = pred.cpu().data.numpy()
        # one row per sample: sentence indices ordered by descending score
        ext_ids = np.argsort(-pred, 1)
        for sent_id in ext_ids:
            self.ext.append(sent_id)
            self.cur_idx += 1
            print('{}/{} ({:.2f}%) decoded in {} seconds\r'.format(
                self.cur_idx, self.n_total, self.cur_idx/self.n_total*100, timedelta(seconds=int(time()-self.start))
            ), end='')

    def get_metric(self, use_ngram_block=True, reset=True):
        """Write decoded and reference summaries to disk, run ROUGE, and
        return {'ROUGE-1', 'ROUGE-2', 'ROUGE-L'}.

        :param use_ngram_block: skip sentences that repeat an n-gram already
            present in the selected sentences (trigram blocking by default)
        :param reset: clear the accumulated state afterwards
        """
        def check_n_gram(sentence, n, dic):
            # True iff sentence shares no n-gram with previously selected ones
            tokens = sentence.split(' ')
            s_len = len(tokens)
            for i in range(s_len):
                if i + n > s_len:
                    break
                if ' '.join(tokens[i: i + n]) in dic:
                    return False
            return True # no n_gram overlap

        # load original data
        data = []
        with open(self.data_path) as f:
            for line in f:
                cur_data = json.loads(line)
                if 'text' in cur_data:
                    # normalize the alternative {'text','summary'} schema
                    # to {'article','abstract'}
                    new_data = {}
                    new_data['article'] = cur_data['text']
                    new_data['abstract'] = cur_data['summary']
                    data.append(new_data)
                else:
                    data.append(cur_data)

        # write decode sentences and references
        if use_ngram_block == True:
            print('\nStart {}-gram blocking !!!'.format(self.ngram_block))
        for i, ext_ids in enumerate(self.ext):
            dec, ref = [], []
            if use_ngram_block == False:
                # plain top-n_ext extraction in score order
                n_sent = min(len(data[i]['article']), self.n_ext)
                for j in range(n_sent):
                    idx = ext_ids[j]
                    dec.append(data[i]['article'][idx])
            else:
                # greedy selection with n-gram blocking
                n_sent = len(ext_ids)
                dic = {}
                for j in range(n_sent):
                    sent = data[i]['article'][ext_ids[j]]
                    if check_n_gram(sent, self.ngram_block, dic) == True:
                        dec.append(sent)
                        # update dic with every n-gram of the accepted sentence
                        tokens = sent.split(' ')
                        s_len = len(tokens)
                        for k in range(s_len):
                            if k + self.ngram_block > s_len:
                                break
                            dic[' '.join(tokens[k: k + self.ngram_block])] = 1
                    if len(dec) >= self.n_ext:
                        break
            for sent in data[i]['abstract']:
                ref.append(sent)

            with open(join(self.dec_path, '{}.dec'.format(i)), 'w') as f:
                for sent in dec:
                    print(sent, file=f)
            with open(join(self.ref_path, '{}.ref'.format(i)), 'w') as f:
                for sent in ref:
                    print(sent, file=f)

        print('\nStart evaluating ROUGE score !!!')
        R_1, R_2, R_L = RougeMetric.eval_rouge(self.dec_path, self.ref_path)
        eval_result = {'ROUGE-1': R_1, 'ROUGE-2': R_2, 'ROUGE-L':R_L}
        if reset == True:
            self.cur_idx = 0
            self.ext = []
            self.start = time()
        return eval_result
@@ -0,0 +1,51 @@ | |||
import torch | |||
from torch import nn | |||
from torch.nn import init | |||
from fastNLP.modules.encoder._bert import BertModel | |||
class Classifier(nn.Module):
    """Sentence scorer: project each sentence embedding to a scalar,
    squash it through a sigmoid, and zero out padded positions."""

    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        # [batch_size, seq_len, hidden] -> [batch_size, seq_len]
        logits = self.linear(inputs).squeeze(-1)
        # sigmoid probability, forced to 0 where mask_cls == 0
        return self.sigmoid(logits) * mask_cls.float()
class BertSum(nn.Module):
    """BertSum extractive summarizer (Liu et al. 2019): a BERT encoder
    followed by a per-sentence sigmoid classifier over the [CLS] positions."""

    def __init__(self, hidden_size=768):
        super(BertSum, self).__init__()
        self.hidden_size = hidden_size
        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        """
        :param article: word-piece ids, 0 is padding. [batch_size, seq_len]
        :param segment_id: BERT token-type ids. [batch_size, seq_len]
        :param cls_id: positions of the [CLS] tokens, -1 is padding. [batch_size, n_sents]
        :return: dict with 'pred' ([batch_size, n_sents] sentence scores)
            and 'mask' (valid-sentence mask for the loss/metric)
        """
        # bug fix: the original used `1 - (x == 0)`; subtracting a bool
        # tensor is unsupported on modern PyTorch. `!=` yields the same mask.
        input_mask = (article != 0)
        mask_cls = (cls_id != -1)
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()
        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # hidden states of the last layer
        # pick the vector at each [CLS] position -> one embedding per sentence
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        # zero the embeddings gathered at padded (-1) positions
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, n_sents, hidden_size]
        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, n_sents]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))
        return {'pred': sent_scores, 'mask': mask_cls}
@@ -0,0 +1,147 @@ | |||
import sys | |||
import argparse | |||
import os | |||
import json | |||
import torch | |||
from time import time | |||
from datetime import timedelta | |||
from os.path import join, exists | |||
from torch.optim import Adam | |||
from utils import get_data_path, get_rouge_path | |||
from dataloader import BertSumLoader | |||
from model import BertSum | |||
from fastNLP.core.optimizer import AdamW | |||
from metrics import MyBCELoss, LossMetric, RougeMetric | |||
from fastNLP.core.sampler import BucketSampler | |||
from callback import MyCallback, SaveModelCallback | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.core.tester import Tester | |||
def configure_training(args):
    """Split the gpu list and collect the hyper-parameters worth logging.

    :param args: parsed command-line arguments
    :return: (devices, params) — devices is a list of int gpu ids, params is
        a dict of the hyper-parameters dumped to params.json
    """
    devices = [int(gpu) for gpu in args.gpus.split(',')]
    logged = ('label_type', 'batch_size', 'accum_count', 'max_lr',
              'warmup_steps', 'n_epochs', 'valid_steps')
    params = {key: getattr(args, key) for key in logged}
    return devices, params
def train_model(args):
    """Train BertSum on the datasets selected by args.mode / args.label_type,
    validating by BCE loss and checkpointing into args.save_path."""
    # check if the data_path and save_path exists
    data_paths = get_data_path(args.mode, args.label_type)
    for name in data_paths:
        assert exists(data_paths[name])
    if not exists(args.save_path):
        os.makedirs(args.save_path)
    # load summarization datasets
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    train_set = datasets.datasets['train']
    valid_set = datasets.datasets['val']
    # configure training; dump the hyper-parameters next to the checkpoints
    devices, train_params = configure_training(args)
    with open(join(args.save_path, 'params.json'), 'w') as f:
        json.dump(train_params, f, indent=4)
    print('Devices is:')
    print(devices)
    # configure model
    model = BertSum()
    # lr=0 on purpose: MyCallback overwrites the learning rate every step
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
    criterion = MyBCELoss()
    val_metric = [LossMetric()]
    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
    # metric_key='-loss': the leading '-' tells fastNLP lower is better
    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
                      loss=criterion, batch_size=args.batch_size,  # sampler=sampler,
                      update_every=args.accum_count, n_epochs=args.n_epochs,
                      print_every=100, dev_data=valid_set, metrics=val_metric,
                      metric_key='-loss', validate_every=args.valid_steps,
                      save_path=args.save_path, device=devices, callbacks=callbacks)
    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()
def test_model(args):
    """Evaluate every checkpoint found under args.save_path on the test set
    with ROUGE."""
    # NOTE(review): os.listdir returns every entry in save_path; training
    # writes params.json and a timestamped sub-directory there, so this
    # assumes save_path points at a directory containing only .pt files —
    # confirm against the intended usage
    models = os.listdir(args.save_path)
    # load dataset
    data_paths = get_data_path(args.mode, args.label_type)
    datasets = BertSumLoader().process(data_paths)
    print('Information of dataset is:')
    print(datasets)
    test_set = datasets.datasets['test']
    # only need 1 gpu for testing
    device = int(args.gpus)
    args.batch_size = 1
    for cur_model in models:
        print('Current model is {}'.format(cur_model))
        # load model; no map_location — presumably the checkpoints were
        # saved from CPU (see _save_model), verify before loading GPU saves
        model = torch.load(join(args.save_path, cur_model))
        # configure testing
        original_path, dec_path, ref_path = get_rouge_path(args.label_type)
        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
                                  ref_path=ref_path, n_total = len(test_set))
        tester = Tester(data=test_set, model=model, metrics=[test_metric],
                        batch_size=args.batch_size, device=device)
        tester.test()
def _build_parser():
    """Build the command-line parser for BertSum training/testing."""
    parser = argparse.ArgumentParser(
        description='training/testing of BertSum(liu et al. 2019)'
    )
    parser.add_argument('--mode', required=True,
                        help='training or testing of BertSum', type=str)
    parser.add_argument('--label_type', default='greedy',
                        help='greedy/limit', type=str)
    parser.add_argument('--save_path', required=True,
                        help='root of the model', type=str)
    # example for gpus input: '0,1,2,3'
    parser.add_argument('--gpus', required=True,
                        help='available gpus for training(separated by commas)', type=str)
    parser.add_argument('--batch_size', default=18,
                        help='the training batch size', type=int)
    parser.add_argument('--accum_count', default=2,
                        help='number of updates steps to accumulate before performing a backward/update pass.', type=int)
    parser.add_argument('--max_lr', default=2e-5,
                        help='max learning rate for warm up', type=float)
    parser.add_argument('--warmup_steps', default=10000,
                        help='warm up steps for training', type=int)
    parser.add_argument('--n_epochs', default=10,
                        help='total number of training epochs', type=int)
    parser.add_argument('--valid_steps', default=1000,
                        help='number of update steps for checkpoint and validation', type=int)
    return parser


if __name__ == '__main__':
    args = _build_parser().parse_args()
    if args.mode == 'train':
        print('Training process of BertSum !!!')
        train_model(args)
    else:
        print('Testing process of BertSum !!!')
        test_model(args)
@@ -0,0 +1,24 @@ | |||
import os | |||
from os.path import exists | |||
def get_data_path(mode, label_type):
    """Map a run mode to the jsonl files it needs.

    :param mode: 'train' selects train+val splits, anything else the test split
    :param label_type: sub-directory under data/ (e.g. 'greedy' or 'limit')
    :return: dict of {split_name: path}
    """
    root = 'data/' + label_type
    if mode == 'train':
        return {'train': root + '/bert.train.jsonl',
                'val': root + '/bert.val.jsonl'}
    return {'test': root + '/bert.test.jsonl'}
def get_rouge_path(label_type):
    """Return paths used for ROUGE evaluation, creating the output dirs.

    :param label_type: data sub-directory name; 'others' selects the
        bert-preprocessed test file, anything else the raw test file
    :return: (original data path, decode dir, reference dir)
    """
    if label_type == 'others':
        data_path = 'data/' + label_type + '/bert.test.jsonl'
    else:
        data_path = 'data/' + label_type + '/test.jsonl'
    dec_path, ref_path = 'dec', 'ref'
    # make sure both output directories exist before decoding starts
    for out_dir in (dec_path, ref_path):
        if not exists(out_dir):
            os.makedirs(out_dir)
    return data_path, dec_path, ref_path