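"""Training / testing entry point for BertSum (Liu et al., 2019), built on fastNLP.

Example usage (assuming this file is saved as train.py):
    python train.py --mode train --save_path ./models --gpus 0,1,2,3
    python train.py --mode test  --save_path ./models --gpus 0
"""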
import argparse
import json
import os
from os.path import join, exists

import torch
from torch.optim import Adam

from fastNLP.core.sampler import BucketSampler
from fastNLP.core.trainer import Trainer
from fastNLP.core.tester import Tester

from utils import get_data_path, get_rouge_path
from dataloader import BertSumLoader
from model import BertSum
from metrics import MyBCELoss, LossMetric, RougeMetric
from callback import MyCallback, SaveModelCallback


def configure_training(args):
    """Parse the gpu list and collect the hyper-parameters that describe this run."""
    devices = [int(gpu) for gpu in args.gpus.split(',')]
    params = {
        'label_type': args.label_type,
        'batch_size': args.batch_size,
        'accum_count': args.accum_count,
        'max_lr': args.max_lr,
        'warmup_steps': args.warmup_steps,
        'n_epochs': args.n_epochs,
        'valid_steps': args.valid_steps,
    }
    return devices, params


def train_model(args):
    # check that the data files exist and create save_path if necessary
    data_paths = get_data_path(args.mode, args.label_type)
    for name in data_paths:
        assert exists(data_paths[name]), 'cannot find dataset file: {}'.format(data_paths[name])
    if not exists(args.save_path):
        os.makedirs(args.save_path)

    # load the summarization datasets
    datasets = BertSumLoader().process(data_paths)
    print('Dataset information:')
    print(datasets)
    train_set = datasets.datasets['train']
    valid_set = datasets.datasets['val']

    # configure training and dump the hyper-parameters to save_path/params.json
    devices, train_params = configure_training(args)
    with open(join(args.save_path, 'params.json'), 'w') as f:
        json.dump(train_params, f, indent=4)
    print('Devices are:')
    print(devices)

    # configure the model, optimizer, loss, metrics and callbacks
    model = BertSum()
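    # NOTE: the learning rate is initialised to 0 on purpose; MyCallback (see
    # callback.py) presumably implements the warm-up schedule, raising the lr
    # towards args.max_lr over args.warmup_steps update steps.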
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
    callbacks = [MyCallback(args), SaveModelCallback(args.save_path)]
    criterion = MyBCELoss()
    val_metric = [LossMetric()]
    # sampler = BucketSampler(num_buckets=32, batch_size=args.batch_size)
    trainer = Trainer(train_data=train_set, model=model, optimizer=optimizer,
                      loss=criterion, batch_size=args.batch_size,  # sampler=sampler,
                      update_every=args.accum_count,  # gradient accumulation steps
                      n_epochs=args.n_epochs, print_every=100,
                      dev_data=valid_set, metrics=val_metric,
                      metric_key='-loss',  # the leading '-' marks lower-is-better in fastNLP
                      validate_every=args.valid_steps,
                      save_path=args.save_path, device=devices, callbacks=callbacks)

    print('Start training with the following hyper-parameters:')
    print(train_params)
    trainer.train()


def test_model(args):
    # every file in save_path is treated as a saved checkpoint; skip params.json,
    # which train_model writes into the same directory
    models = [m for m in os.listdir(args.save_path) if m != 'params.json']

    # load the summarization datasets
    data_paths = get_data_path(args.mode, args.label_type)
    datasets = BertSumLoader().process(data_paths)
    print('Dataset information:')
    print(datasets)
    test_set = datasets.datasets['test']

    # only one gpu is needed for testing; use the first one listed
    device = int(args.gpus.split(',')[0])

    args.batch_size = 1
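    # NOTE: batch_size is forced to 1 above, presumably so that RougeMetric
    # processes exactly one document per step when writing decoded summaries.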
    for cur_model in models:
        print('Current model is {}'.format(cur_model))

        # load the saved model
        model = torch.load(join(args.save_path, cur_model))

        # configure testing: RougeMetric writes decoded summaries to dec_path
        # and compares them against the references in ref_path
        original_path, dec_path, ref_path = get_rouge_path(args.label_type)
        test_metric = RougeMetric(data_path=original_path, dec_path=dec_path,
                                  ref_path=ref_path, n_total=len(test_set))
        tester = Tester(data=test_set, model=model, metrics=[test_metric],
                        batch_size=args.batch_size, device=device)
        tester.test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='training/testing of BertSum (Liu et al., 2019)'
    )
    parser.add_argument('--mode', required=True, choices=['train', 'test'],
                        help='whether to train or to test BertSum', type=str)

    parser.add_argument('--label_type', default='greedy',
                        help='type of the labels: greedy/limit', type=str)
    parser.add_argument('--save_path', required=True,
                        help='root directory of the model', type=str)
    # example gpus input: '0,1,2,3'
    parser.add_argument('--gpus', required=True,
                        help='available gpus for training (comma-separated)', type=str)

    parser.add_argument('--batch_size', default=18,
                        help='the training batch size', type=int)
    parser.add_argument('--accum_count', default=2,
                        help='number of update steps to accumulate before performing a backward/update pass', type=int)
    parser.add_argument('--max_lr', default=2e-5,
                        help='maximum learning rate for warm-up', type=float)
    parser.add_argument('--warmup_steps', default=10000,
                        help='number of warm-up steps for training', type=int)
    parser.add_argument('--n_epochs', default=10,
                        help='total number of training epochs', type=int)
    parser.add_argument('--valid_steps', default=1000,
                        help='number of update steps between checkpointing/validation', type=int)

    args = parser.parse_args()

    if args.mode == 'train':
        print('Training process of BertSum!!!')
        train_model(args)
    else:
        print('Testing process of BertSum!!!')
        test_model(args)