diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index f4e64c5d..b763ada2 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -275,7 +275,7 @@ class DataSet(object):
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
         results = [func(ins) for ins in self._inner_iter()]
-        if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None):  # all None
+        if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
             raise ValueError("{} always return None.".format(get_func_signature(func=func)))

         extra_param = {}
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index c1092e53..211d6cc9 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -897,7 +897,10 @@ class ConllxDataLoader(object):
         if return_dataset is True:
             ds = DataSet()
             for example in data_list:
-                ds.append(Instance(words=example[0], tag=example[1]))
+                ds.append(Instance(words=example[0],
+                                   pos_tags=example[1],
+                                   heads=example[2],
+                                   labels=example[3]))
             data_list = ds

         return data_list
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index b9b9dd56..dfbaac58 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -216,6 +216,7 @@ class BiaffineParser(GraphParser):
         self.word_norm = nn.LayerNorm(word_hid_dim)
         self.pos_norm = nn.LayerNorm(pos_hid_dim)
         self.encoder_name = encoder
+        self.max_len = 512
         if encoder == 'var-lstm':
             self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
                                    hidden_size=rnn_hidden_size,
@@ -233,6 +234,20 @@ class BiaffineParser(GraphParser):
                                    batch_first=True,
                                    dropout=dropout,
                                    bidirectional=True)
+        elif encoder == 'transformer':
+            n_head = 16
+            d_k = d_v = int(rnn_out_size / n_head)
+            if (d_k * n_head) != rnn_out_size:
+                raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
+            self.position_emb = nn.Embedding(num_embeddings=self.max_len,
+                                             embedding_dim=rnn_out_size,)
+            self.encoder = TransformerEncoder(num_layers=rnn_layers,
+                                              model_size=rnn_out_size,
+                                              inner_size=1024,
+                                              key_size=d_k,
+                                              value_size=d_v,
+                                              num_head=n_head,
+                                              dropout=dropout,)
         else:
             raise ValueError('unsupported encoder type: {}'.format(encoder))

@@ -285,13 +300,18 @@ class BiaffineParser(GraphParser):
         x = torch.cat([word, pos], dim=2)  # -> [N,L,C]

         # encoder, extract features
-        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
-        x = x[sort_idx]
-        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
-        feat, _ = self.encoder(x)  # -> [N,L,C]
-        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
-        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
-        feat = feat[unsort_idx]
+        if self.encoder_name.endswith('lstm'):
+            sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
+            x = x[sort_idx]
+            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
+            feat, _ = self.encoder(x)  # -> [N,L,C]
+            feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
+            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
+            feat = feat[unsort_idx]
+        else:
+            seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:]
+            x = x + self.position_emb(seq_range)
+            feat = self.encoder(x, mask.float())

         # for arc biaffine
         # mlp, reduce dim
diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/Biaffine_parser/cfg.cfg
index ad06598f..4a56bad5 100644
--- a/reproduction/Biaffine_parser/cfg.cfg
+++ b/reproduction/Biaffine_parser/cfg.cfg
@@ -1,9 +1,9 @@
 [train]
-n_epochs = 40
+n_epochs = 1
 batch_size = 32
 use_cuda = true
 use_tqdm=true
-validate_every = -1
+validate_every = 1000
 use_golden_train=true

 [test]
@@ -17,7 +17,7 @@ use_cuda = true

 [model]
 word_vocab_size = -1
-word_emb_dim = 100
+word_emb_dim = 300
 pos_vocab_size = -1
 pos_emb_dim = 100
 rnn_layers = 3
@@ -30,5 +30,5 @@ encoder="transformer"
 use_greedy_infer=false

 [optim]
-lr = 3e-4
+lr = 2e-3
 ;weight_decay = 3e-5
diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py
index ded7487d..e74018ba 100644
--- a/reproduction/Biaffine_parser/run.py
+++ b/reproduction/Biaffine_parser/run.py
@@ -4,6 +4,7 @@ import sys
 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

 import fastNLP
+import torch

 from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
@@ -14,10 +15,13 @@ from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import *
+from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP.core.callback import Callback

 BOS = '<BOS>'
 EOS = '<EOS>'
 UNK = '<UNK>'
+PAD = '<PAD>'
 NUM = '<NUM>'
 ENG = '<ENG>'

@@ -28,11 +32,11 @@ if len(os.path.dirname(__file__)) != 0:
 def convert(data):
     dataset = DataSet()
     for sample in data:
-        word_seq = [BOS] + sample[0]
-        pos_seq = [BOS] + sample[1]
-        heads = [0] + list(map(int, sample[2]))
-        head_tags = [BOS] + sample[3]
-        dataset.append(Instance(words=word_seq,
+        word_seq = [BOS] + sample['words']
+        pos_seq = [BOS] + sample['pos_tags']
+        heads = [0] + sample['heads']
+        head_tags = [BOS] + sample['labels']
+        dataset.append(Instance(raw_words=word_seq,
                                 pos=pos_seq,
                                 gold_heads=heads,
                                 arc_true=heads,
@@ -45,24 +49,11 @@ def load(path):
     return convert(data)


-# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
-# datadir = "/home/yfshao/UD_English-EWT"
-# train_data_name = "en_ewt-ud-train.conllu"
-# dev_data_name = "en_ewt-ud-dev.conllu"
-# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
-# loader = ConlluDataLoader()
-
-# datadir = '/home/yfshao/workdir/parser-data/'
-# train_data_name = "train_ctb5.txt"
-# dev_data_name = "dev_ctb5.txt"
-# test_data_name = "test_ctb5.txt"
-
-datadir = "/home/yfshao/workdir/ctb7.0/"
+datadir = "/remote-home/yfshao/workdir/ctb9.0/"
 train_data_name = "train.conllx"
 dev_data_name = "dev.conllx"
 test_data_name = "test.conllx"
-# emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
-emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
+emb_file_name = "/remote-home/yfshao/workdir/word_vector/cc.zh.300.vec"

 cfgfile = './cfg.cfg'
 processed_datadir = './save'
@@ -108,27 +99,23 @@ def update_v(vocab, data, field):
     data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None)


-print('load raw data and preprocess')
 # use pretrain embedding
-word_v = Vocabulary()
-word_v.unknown_label = UNK
-pos_v = Vocabulary()
+word_v = Vocabulary(unknown=UNK, padding=PAD)
+pos_v = Vocabulary(unknown=None, padding=PAD)
 tag_v = Vocabulary(unknown=None, padding=None)
 train_data = load(os.path.join(datadir, train_data_name))
 dev_data = load(os.path.join(datadir, dev_data_name))
 test_data = load(os.path.join(datadir, test_data_name))
-print(train_data[0])
-num_p = Num2TagProcessor('words', 'words')
+print('load raw data and preprocess')
+
+num_p = Num2TagProcessor(tag=NUM, field_name='raw_words', new_added_field_name='words')
 for ds in (train_data, dev_data, test_data):
     num_p(ds)
-
 update_v(word_v, train_data, 'words')
 update_v(pos_v, train_data, 'pos')
 update_v(tag_v, train_data, 'tags')

 print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v)))
-# embed, _ = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
-# print(embed.size())

 # Model
 model_args['word_vocab_size'] = len(word_v)
@@ -159,7 +146,6 @@ for ds in (train_data, dev_data, test_data):
 if train_args['use_golden_train']:
     train_data.set_input('gold_heads', flag=True)
 train_args.data.pop('use_golden_train')
-ignore_label = pos_v['punct']

 print(test_data[0])
 print('train len {}'.format(len(train_data)))
@@ -167,45 +153,60 @@ print('dev len {}'.format(len(dev_data)))
 print('test len {}'.format(len(test_data)))


-
 def train(path):
     # test saving pipeline
     save_pipe(path)

-    # Trainer
-    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
-                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
-                      **train_args.data,
-                      optimizer=fastNLP.Adam(**optim_args.data),
-                      save_path=path)
-
-    # model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
+    # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
+    # embed = torch.tensor(embed, dtype=torch.float32)
+    # model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=True)
     model.word_embedding.padding_idx = word_v.padding_idx
     model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
     model.pos_embedding.padding_idx = pos_v.padding_idx
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

-    # try:
-    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-    #     print('model parameter loaded!')
-    # except Exception as _:
-    #     print("No saved model. Continue.")
-    #     pass
+    class MyCallback(Callback):
+        def after_step(self, optimizer):
+            step = self.trainer.step
+            # learning rate decay
+            if step > 0 and step % 1000 == 0:
+                for pg in optimizer.param_groups:
+                    pg['lr'] *= 0.93
+                print('decay lr to {}'.format([pg['lr'] for pg in optimizer.param_groups]))
+
+            if step == 3000:
+                # start training embedding
+                print('start training embedding at {}'.format(step))
+                model = self.trainer.model
+                for m in model.modules():
+                    if isinstance(m, torch.nn.Embedding):
+                        m.weight.requires_grad = True

-    # Start training
-    trainer.train()
-    print("Training finished!")
+    # Trainer
+    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
+                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                      **train_args.data,
+                      optimizer=fastNLP.Adam(**optim_args.data),
+                      save_path=path,
+                      callbacks=[MyCallback()])

-    # save pipeline
-    save_pipe(path)
-    print('pipe saved')
+    # Start training
+    try:
+        trainer.train()
+        print("Training finished!")
+    finally:
+        # save pipeline
+        save_pipe(path)
+        print('pipe saved')

 def save_pipe(path):
     pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p])
     pipe.add_processor(ModelProcessor(model=model, batch_size=32))
     pipe.add_processor(label_toword_p)
     os.makedirs(path, exist_ok=True)
-    torch.save({'pipeline': pipe}, os.path.join(path, 'pipe.pkl'))
+    torch.save({'pipeline': pipe,
+                'names':['num word_idx pos_idx seq set_input model tag_to_word'.split()],
+                }, os.path.join(path, 'pipe.pkl'))


 def test(path):
@@ -230,16 +231,11 @@ def test(path):
     print("Testing Test data")
     tester.test(model, test_data)

-def build_pipe(parser_pipe_path):
-    parser_pipe = torch.load(parser_pipe_path)
-
-
-
 if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
-    parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer', 'save'])
+    parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
     parser.add_argument('--path', type=str, default='')
     # parser.add_argument('--dst', type=str, default='')
     args = parser.parse_args()
@@ -249,12 +245,6 @@ if __name__ == "__main__":
         test(args.path)
     elif args.mode == 'infer':
         pass
-    # elif args.mode == 'save':
-    #     print(f'save model from {args.path} to {args.dst}')
-    #     save_model(args.path, args.dst)
-    #     load_path = os.path.dirname(args.dst)
-    #     print(f'save pipeline in {load_path}')
-    #     build(load_path)
     else:
         print('no mode specified for model!')
         parser.print_help()