From 3192c9ac666fcb2b7b1d2410f67718e684ebac35 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sun, 4 Nov 2018 17:57:35 +0800 Subject: [PATCH] update trainer --- fastNLP/core/field.py | 3 + fastNLP/core/instance.py | 3 + fastNLP/core/tester.py | 2 +- fastNLP/core/trainer.py | 34 ++++--- fastNLP/models/biaffine_parser.py | 40 ++++++-- reproduction/Biaffine_parser/cfg.cfg | 11 ++- reproduction/Biaffine_parser/run.py | 136 ++++++++++++++++++--------- 7 files changed, 157 insertions(+), 72 deletions(-) diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index a3cf21d5..5e0895d1 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -24,6 +24,9 @@ class Field(object): def __repr__(self): return self.contents().__repr__() + def new(self, *args, **kwargs): + return self.__class__(*args, **kwargs, is_target=self.is_target) + class TextField(Field): def __init__(self, text, is_target): """ diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 0527a16f..50787fd1 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -35,6 +35,9 @@ class Instance(object): else: raise KeyError("{} not found".format(name)) + def __setitem__(self, name, field): + return self.add_field(name, field) + def get_length(self): """Fetch the length of all fields in the instance. diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 51f84691..4c0cfb41 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -74,7 +74,7 @@ class Tester(object): output_list = [] truth_list = [] - data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) + data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') with torch.no_grad(): for batch_x, batch_y in data_iterator: diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 49761725..8334a960 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,6 +1,6 @@ import os import time -from datetime import timedelta +from datetime import timedelta, datetime import torch from tensorboardX import SummaryWriter @@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger from fastNLP.saver.model_saver import ModelSaver logger = create_logger(__name__, "./train_test.log") - +logger.disabled = True class Trainer(object): """Operations of training a model, including data loading, gradient descent, and validation. @@ -42,7 +42,7 @@ class Trainer(object): """ default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, - "valid_step": 500, "eval_sort_key": None, + "valid_step": 500, "eval_sort_key": 'acc', "loss": Loss(None), # used to pass type check "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), "evaluator": Evaluator() @@ -111,13 +111,17 @@ class Trainer(object): else: self._model = network + print(self._model) + # define Tester over dev data + self.dev_data = None if self.validate: default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, "use_cuda": self.use_cuda, "evaluator": self._evaluator} if self.validator is None: self.validator = self._create_validator(default_valid_args) logger.info("validator defined as {}".format(str(self.validator))) + self.dev_data = dev_data # optimizer and loss self.define_optimizer() @@ -130,7 +134,7 @@ class Trainer(object): # main training procedure start = time.time() - self.start_time = str(start) + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M')) logger.info("training epochs started " + self.start_time) epoch, iters = 1, 0 @@ -141,15 +145,17 @@ class Trainer(object): # prepare mini-batch iterator data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), - use_cuda=self.use_cuda) + use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') logger.info("prepared data iterator") # one forward and backward pass - iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data) + iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data) # validation if self.validate: self.valid_model() + self.save_model(self._model, 'training_model_'+self.start_time) + epoch += 1 def _train_step(self, data_iterator, network, **kwargs): """Training process in one epoch. @@ -160,13 +166,16 @@ class Trainer(object): - epoch: int, """ step = kwargs['step'] - dev_data = kwargs['dev_data'] for batch_x, batch_y in data_iterator: - prediction = self.data_forward(network, batch_x) loss = self.get_loss(prediction, batch_y) self.grad_backward(loss) + if torch.rand(1).item() < 0.001: + print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step)) + for name, p in self._model.named_parameters(): + if p.requires_grad: + print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item())) self.update() self._summary_writer.add_scalar("loss", loss.item(), global_step=step) @@ -183,13 +192,14 @@ class Trainer(object): return step def valid_model(self): - if dev_data is None: + if self.dev_data is None: raise RuntimeError( "self.validate is True in trainer, but dev_data is None. Please provide the validation data.") logger.info("validation started") - res = self.validator.test(network, dev_data) + res = self.validator.test(self._model, self.dev_data) if self.save_best_dev and self.best_eval_result(res): logger.info('save best result! {}'.format(res)) + print('save best result! {}'.format(res)) self.save_model(self._model, 'best_model_'+self.start_time) return res @@ -282,14 +292,10 @@ class Trainer(object): """ if isinstance(metrics, tuple): loss, metrics = metrics - else: - metrics = validator.metrics if isinstance(metrics, dict): if len(metrics) == 1: accuracy = list(metrics.values())[0] - elif self.eval_sort_key is None: - raise ValueError('dict format metrics should provide sort key for eval best result') else: accuracy = metrics[self.eval_sort_key] else: diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index 4561dbd2..0cc40cb4 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -199,6 +199,8 @@ class BiaffineParser(GraphParser): word_emb_dim, pos_vocab_size, pos_emb_dim, + word_hid_dim, + pos_hid_dim, rnn_layers, rnn_hidden_size, arc_mlp_size, @@ -209,10 +211,15 @@ class BiaffineParser(GraphParser): use_greedy_infer=False): super(BiaffineParser, self).__init__() + rnn_out_size = 2 * rnn_hidden_size self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) + self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) + self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) + self.word_norm = nn.LayerNorm(word_hid_dim) + self.pos_norm = nn.LayerNorm(pos_hid_dim) if use_var_lstm: - self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, + self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, hidden_size=rnn_hidden_size, num_layers=rnn_layers, bias=True, @@ -221,7 +228,7 @@ class BiaffineParser(GraphParser): hidden_dropout=dropout, bidirectional=True) else: - self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, + self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, hidden_size=rnn_hidden_size, num_layers=rnn_layers, bias=True, @@ -229,12 +236,13 @@ class BiaffineParser(GraphParser): dropout=dropout, bidirectional=True) - rnn_out_size = 2 * rnn_hidden_size self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), + nn.LayerNorm(arc_mlp_size), nn.ELU(), TimestepDropout(p=dropout),) self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), + nn.LayerNorm(label_mlp_size), nn.ELU(), TimestepDropout(p=dropout),) self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) @@ -242,10 +250,18 @@ class BiaffineParser(GraphParser): self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) self.normal_dropout = nn.Dropout(p=dropout) self.use_greedy_infer = use_greedy_infer - initial_parameter(self) - self.word_norm = nn.LayerNorm(word_emb_dim) - self.pos_norm = nn.LayerNorm(pos_emb_dim) - self.lstm_norm = nn.LayerNorm(rnn_out_size) + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Embedding): + continue + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + else: + for p in m.parameters(): + nn.init.normal_(p, 0, 0.01) def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): """ @@ -262,19 +278,21 @@ class BiaffineParser(GraphParser): # prepare embeddings batch_size, seq_len = word_seq.shape # print('forward {} {}'.format(batch_size, seq_len)) - batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) # get sequence mask seq_mask = len2masks(word_seq_origin_len, seq_len).long() word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] + word, pos = self.word_fc(word), self.pos_fc(pos) word, pos = self.word_norm(word), self.pos_norm(pos) x = torch.cat([word, pos], dim=2) # -> [N,L,C] + del word, pos # lstm, extract features + x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True) feat, _ = self.lstm(x) # -> [N,L,C] - feat = self.lstm_norm(feat) + feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) # for arc biaffine # mlp, reduce dim @@ -282,6 +300,7 @@ class BiaffineParser(GraphParser): arc_head = self.arc_head_mlp(feat) label_dep = self.label_dep_mlp(feat) label_head = self.label_head_mlp(feat) + del feat # biaffine arc classifier arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] @@ -289,7 +308,7 @@ class BiaffineParser(GraphParser): arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) # use gold or predicted arc to predict label - if gold_heads is None: + if gold_heads is None or not self.training: # use greedy decoding in training if self.training or self.use_greedy_infer: heads = self._greedy_decoder(arc_pred, seq_mask) @@ -301,6 +320,7 @@ class BiaffineParser(GraphParser): head_pred = None heads = gold_heads + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) label_head = label_head[batch_range, heads].contiguous() label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/Biaffine_parser/cfg.cfg index 3adb6937..e967ac46 100644 --- a/reproduction/Biaffine_parser/cfg.cfg +++ b/reproduction/Biaffine_parser/cfg.cfg @@ -1,16 +1,14 @@ [train] epochs = -1 -<<<<<<< HEAD -batch_size = 16 -======= batch_size = 32 ->>>>>>> update biaffine pickle_path = "./save/" validate = true save_best_dev = true eval_sort_key = "UAS" use_cuda = true model_saved_path = "./save/" +print_every_step = 20 +use_golden_train=true [test] save_output = true @@ -26,14 +24,17 @@ word_vocab_size = -1 word_emb_dim = 100 pos_vocab_size = -1 pos_emb_dim = 100 +word_hid_dim = 100 +pos_hid_dim = 100 rnn_layers = 3 rnn_hidden_size = 400 arc_mlp_size = 500 label_mlp_size = 100 num_label = -1 dropout = 0.33 -use_var_lstm=true +use_var_lstm=false use_greedy_infer=false [optim] lr = 2e-3 +weight_decay = 0.0 diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py index 5bab554a..a1bce780 100644 --- a/reproduction/Biaffine_parser/run.py +++ b/reproduction/Biaffine_parser/run.py @@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) from collections import defaultdict import math import torch +import re from fastNLP.core.trainer import Trainer from fastNLP.core.metrics import Evaluator @@ -55,10 +56,10 @@ class ConlluDataLoader(object): return ds def get_one(self, sample): - text = [''] - pos_tags = [''] - heads = [0] - head_tags = ['root'] + text = [] + pos_tags = [] + heads = [] + head_tags = [] for w in sample: t1, t2, t3, t4 = w[1], w[3], w[6], w[7] if t3 == '_': @@ -96,12 +97,13 @@ class CTBDataLoader(object): def convert(self, data): dataset = DataSet() for sample in data: - word_seq = [""] + sample[0] - pos_seq = [""] + sample[1] - heads = [0] + list(map(int, sample[2])) - head_tags = ["ROOT"] + sample[3] + word_seq = [""] + sample[0] + [''] + pos_seq = [""] + sample[1] + [''] + heads = [0] + list(map(int, sample[2])) + [0] + head_tags = [""] + sample[3] + [''] dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), pos_seq=TextField(pos_seq, is_target=False), + gold_heads=SeqLabelField(heads, is_target=False), head_indices=SeqLabelField(heads, is_target=True), head_labels=TextField(head_tags, is_target=True))) return dataset @@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/' train_data_name = "train_ctb5.txt" dev_data_name = "dev_ctb5.txt" test_data_name = "test_ctb5.txt" -emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt" +emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" +# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" loader = CTBDataLoader() cfgfile = './cfg.cfg' @@ -129,6 +132,10 @@ test_args = ConfigSection() model_args = ConfigSection() optim_args = ConfigSection() ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) +print('trainre Args:', train_args.data) +print('test Args:', test_args.data) +print('optim Args:', optim_args.data) + # Pickle Loader def save_data(dirpath, **kwargs): @@ -151,9 +158,31 @@ def load_data(dirpath): datas[name] = _pickle.load(f) return datas +def P2(data, field, length): + ds = [ins for ins in data if ins[field].get_length() >= length] + data.clear() + data.extend(ds) + return ds + +def P1(data, field): + def reeng(w): + return w if w == '' or w == '' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG' + def renum(w): + return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER' + for ins in data: + ori = ins[field].contents() + s = list(map(renum, map(reeng, ori))) + if s != ori: + # print(ori) + # print(s) + # print() + ins[field] = ins[field].new(s) + return data + class ParserEvaluator(Evaluator): - def __init__(self): + def __init__(self, ignore_label): super(ParserEvaluator, self).__init__() + self.ignore = ignore_label def __call__(self, predict_list, truth_list): head_all, label_all, total_all = 0, 0, 0 @@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator): label_pred_correct: number of correct predicted labels. total_tokens: number of predicted tokens """ + seq_mask *= (head_labels != self.ignore).long() head_pred_correct = (head_pred == head_indices).long() * seq_mask _, label_preds = torch.max(label_pred, dim=2) label_pred_correct = (label_preds == head_labels).long() * head_pred_correct @@ -181,72 +211,93 @@ class ParserEvaluator(Evaluator): try: data_dict = load_data(processed_datadir) - word_v = data_dict['word_v'] pos_v = data_dict['pos_v'] tag_v = data_dict['tag_v'] train_data = data_dict['train_data'] dev_data = data_dict['dev_data'] + test_data = data_dict['test_datas'] print('use saved pickles') except Exception as _: print('load raw data and preprocess') - word_v = Vocabulary(need_default=True, min_freq=2) + # use pretrain embedding pos_v = Vocabulary(need_default=True) tag_v = Vocabulary(need_default=False) train_data = loader.load(os.path.join(datadir, train_data_name)) dev_data = loader.load(os.path.join(datadir, dev_data_name)) test_data = loader.load(os.path.join(datadir, test_data_name)) - train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) - save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) + train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v) + save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) -train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) -dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) -train_data.set_origin_len("word_seq") -dev_data.set_origin_len("word_seq") +embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl')) +word_v.unknown_label = "" -print(train_data[:3]) -print(len(train_data)) -print(len(dev_data)) +# Model model_args['word_vocab_size'] = len(word_v) model_args['pos_vocab_size'] = len(pos_v) model_args['num_label'] = len(tag_v) +model = BiaffineParser(**model_args.data) +model.reset_parameters() + +datasets = (train_data, dev_data, test_data) +for ds in datasets: + # print('====='*30) + P1(ds, 'word_seq') + P2(ds, 'word_seq', 5) + ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) + ds.set_origin_len('word_seq') + if train_args['use_golden_train']: + ds.set_target(gold_heads=False) + else: + ds.set_target(gold_heads=None) +train_args.data.pop('use_golden_train') +ignore_label = pos_v['P'] + +print(test_data[0]) +print(len(train_data)) +print(len(dev_data)) +print(len(test_data)) + -def train(): + +def train(path): # Trainer trainer = Trainer(**train_args.data) def _define_optim(obj): - obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) + lr = optim_args.data['lr'] + embed_params = set(obj._model.word_embedding.parameters()) + decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) + params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] + obj._optimizer = torch.optim.Adam([ + {'params': list(embed_params), 'lr':lr*0.1}, + {'params': list(decay_params), **optim_args.data}, + {'params': params} + ], lr=lr) obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) def _update(obj): + # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) obj._scheduler.step() obj._optimizer.step() trainer.define_optimizer = lambda: _define_optim(trainer) trainer.update = lambda: _update(trainer) - trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator())) + trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) - # Model - model = BiaffineParser(**model_args.data) - - # use pretrain embedding - word_v.unknown_label = "" - embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) - model.word_embedding.padding_idx = word_v.padding_idx model.word_embedding.weight.data[word_v.padding_idx].fill_(0) model.pos_embedding.padding_idx = pos_v.padding_idx model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) - try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") - print('model parameter loaded!') - except Exception as _: - print("No saved model. Continue.") - pass + # try: + # ModelLoader.load_pytorch(model, "./save/saved_model.pkl") + # print('model parameter loaded!') + # except Exception as _: + # print("No saved model. Continue.") + # pass # Start training trainer.train(model, train_data, dev_data) @@ -258,15 +309,15 @@ def train(): print("Model saved!") -def test(): +def test(path): # Tester - tester = Tester(**test_args.data, evaluator=ParserEvaluator()) + tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) # Model model = BiaffineParser(**model_args.data) try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") + ModelLoader.load_pytorch(model, path) print('model parameter loaded!') except Exception as _: print("No saved model. Abort test.") @@ -284,11 +335,12 @@ if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) + parser.add_argument('--path', type=str, default='') args = parser.parse_args() if args.mode == 'train': - train() + train(args.path) elif args.mode == 'test': - test() + test(args.path) elif args.mode == 'infer': infer() else: