From 102259df399ad43102a761e47a705c3fe6ebb308 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 18 Oct 2018 22:27:22 +0800
Subject: [PATCH] update biaffine parser

---
 fastNLP/core/field.py               |  3 +
 fastNLP/core/instance.py            |  3 +
 fastNLP/core/vocabulary.py          |  7 ++-
 fastNLP/loader/embed_loader.py      |  6 +-
 fastNLP/models/biaffine_parser.py   | 10 +++-
 reproduction/Biaffine_parser/run.py | 87 ++++++++++++++++++++---------
 6 files changed, 85 insertions(+), 31 deletions(-)

diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index 1c5e7425..a3cf21d5 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -21,6 +21,9 @@ class Field(object):
     def contents(self):
         raise NotImplementedError
 
+    def __repr__(self):
+        return self.contents().__repr__()
+
 class TextField(Field):
     def __init__(self, text, is_target):
         """
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index a4eca1aa..0527a16f 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -82,3 +82,6 @@ class Instance(object):
             name, field_name = origin_len
             tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
         return tensor_x, tensor_y
+
+    def __repr__(self):
+        return self.fields.__repr__()
\ No newline at end of file
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 26d2e837..4f7f42ed 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -114,7 +114,7 @@ class Vocabulary(object):
         if w in self.word2idx:
             return self.word2idx[w]
         elif self.has_default:
-            return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+            return self.word2idx[self.unknown_label]
         else:
             raise ValueError("word {} not in vocabulary".format(w))
 
@@ -134,6 +134,11 @@ class Vocabulary(object):
             return None
         return self.word2idx[self.unknown_label]
 
+    def __setattr__(self, name, val):
+        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
+        self.__dict__[name] = val
+
     @property
     @check_build_vocab
     def padding_idx(self):
diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py
index 2f61830f..415cb1b9 100644
--- a/fastNLP/loader/embed_loader.py
+++ b/fastNLP/loader/embed_loader.py
@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader):
     def _load_glove(emb_file):
         """Read file as a glove embedding
 
-        file format: 
-        embeddings are split by line, 
+        file format:
+        embeddings are split by line,
         for one embedding, word and numbers split by space
         Example::
 
@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader):
             if len(line) > 0:
                 emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
-    
+
     @staticmethod
     def _load_pretrain(emb_file, emb_type):
         """Read txt data from embedding file and convert to np.array as pre-trained embedding
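Note on the vocabulary.py hunk above: because __setattr__ is now overridden, assigning a new unknown_label or padding_label after the vocabulary has been built renames the existing key in word2idx in place, so the token keeps its index. run.py (further down) relies on this when it repoints unknown_label at the pretrained embedding's OOV token before loading embeddings. A minimal sketch of the intended behavior; the toy words and the "<OOV>" spelling are illustrative, not taken from the patch::

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(need_default=True)
    vocab.update(["biaffine", "parser"])
    old_idx = vocab.unknown_idx            # property access builds word2idx
    # Reassigning the label goes through the new __setattr__, which pops
    # the old key out of word2idx and re-inserts it under the new name.
    vocab.unknown_label = "<OOV>"
    assert vocab.unknown_idx == old_idx    # same index, renamed key
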
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index 845e372f..a5461ee8 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -182,6 +182,12 @@ class LabelBilinear(nn.Module):
         output += self.lin(torch.cat([x1, x2], dim=2))
         return output
 
+def len2masks(origin_len, max_len):
+    if origin_len.dim() <= 1:
+        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device)  # [max_len,]
+    seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]
+    return seq_mask
 
 class BiaffineParser(GraphParser):
     """Biaffine Dependency Parser implementation.
@@ -238,7 +244,7 @@ class BiaffineParser(GraphParser):
         self.use_greedy_infer = use_greedy_infer
         initial_parameter(self)
 
-    def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
+    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
         :param word_seq: [batch_size, seq_len] sequence of word indices
         :param pos_seq: [batch_size, seq_len] sequence of POS tag indices
@@ -256,7 +262,7 @@ class BiaffineParser(GraphParser):
         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
 
         # get sequence mask
-        seq_mask = seq_mask.long()
+        seq_mask = len2masks(word_seq_origin_len, seq_len).long()
 
         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
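Note on the biaffine_parser.py hunks above: forward no longer takes a precomputed seq_mask; it derives the mask from the original sequence lengths with len2masks, which is a single broadcast comparison of a [batch_size, 1] length column against a [1, max_len] position row. A self-contained sketch of the same technique, with made-up lengths::

    import torch

    def len2masks(origin_len, max_len):
        # [batch_size] lengths -> [batch_size, 1] column vector
        if origin_len.dim() <= 1:
            origin_len = origin_len.unsqueeze(1)
        # position indices 0 .. max_len-1, shaped [1, max_len]
        seq_range = torch.arange(start=0, end=max_len, dtype=torch.long,
                                 device=origin_len.device)
        # broadcast compare: position < length -> True inside the sentence
        return torch.gt(origin_len, seq_range.unsqueeze(0))

    lengths = torch.tensor([3, 1, 2])
    print(len2masks(lengths, max_len=4).long())
    # tensor([[1, 1, 1, 0],
    #         [1, 0, 0, 0],
    #         [1, 1, 0, 0]])
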
diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/Biaffine_parser/run.py
index cc8e54ad..9404d195 100644
--- a/reproduction/Biaffine_parser/run.py
+++ b/reproduction/Biaffine_parser/run.py
@@ -14,7 +14,6 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.field import TextField, SeqLabelField
-from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.core.tester import Tester
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -26,11 +25,8 @@
 from fastNLP.saver.model_saver import ModelSaver
 
 if len(os.path.dirname(__file__)) != 0:
     os.chdir(os.path.dirname(__file__))
 
-class MyDataLoader(object):
-    def __init__(self, pickle_path):
-        self.pickle_path = pickle_path
-
-    def load(self, path, word_v=None, pos_v=None, headtag_v=None):
+class ConlluDataLoader(object):
+    def load(self, path):
         datalist = []
@@ -49,15 +45,10 @@ class ConlluDataLoader(object):
         for sample in datalist:
             # print(sample)
             res = self.get_one(sample)
-            if word_v is not None:
-                word_v.update(res[0])
-                pos_v.update(res[1])
-                headtag_v.update(res[3])
             ds.append(Instance(word_seq=TextField(res[0], is_target=False),
                                pos_seq=TextField(res[1], is_target=False),
                                head_indices=SeqLabelField(res[2], is_target=True),
-                               head_labels=TextField(res[3], is_target=True),
-                               seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))
+                               head_labels=TextField(res[3], is_target=True)))
 
         return ds
 
@@ -76,17 +67,57 @@ class ConlluDataLoader(object):
         head_tags.append(t4)
         return (text, pos_tags, heads, head_tags)
 
-    def index_data(self, dataset, word_v, pos_v, tag_v):
-        dataset.index_field('word_seq', word_v)
-        dataset.index_field('pos_seq', pos_v)
-        dataset.index_field('head_labels', tag_v)
+class CTBDataLoader(object):
+    def load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    def parse(self, lines):
+        """Group lines into sentence samples, each of the form:
+        [
+            [word], [pos], [head_index], [head_tag]
+        ]
+        """
+        sample = []
+        data = []
+        for line in lines:
+            line = line.strip()
+            if len(line) == 0:
+                if len(sample) > 0:
+                    data.append(list(map(list, zip(*sample))))
+                sample = []
+            else:
+                sample.append(line.split())
+        if len(sample) > 0:
+            data.append(list(map(list, zip(*sample))))
+        return data
+
+    def convert(self, data):
+        dataset = DataSet()
+        for sample in data:
+            word_seq = [""] + sample[0]
+            pos_seq = [""] + sample[1]
+            heads = [0] + list(map(int, sample[2]))
+            head_tags = ["ROOT"] + sample[3]
+            dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
+                                    pos_seq=TextField(pos_seq, is_target=False),
+                                    head_indices=SeqLabelField(heads, is_target=True),
+                                    head_labels=TextField(head_tags, is_target=True)))
+        return dataset
 
 # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
-datadir = "/home/yfshao/UD_English-EWT"
+# datadir = "/home/yfshao/UD_English-EWT"
+# train_data_name = "en_ewt-ud-train.conllu"
+# dev_data_name = "en_ewt-ud-dev.conllu"
+# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
+# loader = ConlluDataLoader()
+
+datadir = "/home/yfshao/parser-data"
+train_data_name = "train_ctb5.txt"
+dev_data_name = "dev_ctb5.txt"
+emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
+loader = CTBDataLoader()
+
 cfgfile = './cfg.cfg'
-train_data_name = "en_ewt-ud-train.conllu"
-dev_data_name = "en_ewt-ud-dev.conllu"
-emb_file_name = '/home/yfshao/glove.6B.100d.txt'
 processed_datadir = './save'
 
 # Config Loader
@@ -96,7 +127,7 @@
 model_args = ConfigSection()
 optim_args = ConfigSection()
 ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
 
-# Data Loader
+# Pickle Loader
 def save_data(dirpath, **kwargs):
     import _pickle
     if not os.path.exists(dirpath):
@@ -140,6 +171,7 @@ class MyTester(object):
                 tmp[eval_name] = torch.cat(tensorlist, dim=0)
 
         self.res = self.model.metrics(**tmp)
+        print(self.show_metrics())
 
     def show_metrics(self):
         s = ""
@@ -148,7 +180,6 @@ class MyTester(object):
         return s
 
 
-loader = MyDataLoader('')
 try:
     data_dict = load_data(processed_datadir)
     word_v = data_dict['word_v']
@@ -163,12 +194,17 @@
 except Exception as _:
     word_v = Vocabulary(need_default=True, min_freq=2)
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
-    train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
+    train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
+    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
     save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
 
-loader.index_data(train_data, word_v, pos_v, tag_v)
-loader.index_data(dev_data, word_v, pos_v, tag_v)
+train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+train_data.set_origin_len("word_seq")
+dev_data.set_origin_len("word_seq")
+
+print(train_data[:3])
 print(len(train_data))
 print(len(dev_data))
 ep = train_args['epochs']
@@ -199,6 +235,7 @@ def train():
     model = BiaffineParser(**model_args.data)
 
     # use pretrain embedding
+    word_v.unknown_label = ""
    embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
     model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
     model.word_embedding.padding_idx = word_v.padding_idx
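Note on the run.py changes above: CTBDataLoader.parse expects one token per line with four whitespace-separated columns (word, POS tag, head index, dependency label) and a blank line between sentences; it transposes each block into [[words], [pos], [heads], [labels]], and convert then prepends a root placeholder at position 0 so the 1-based head indices stay aligned. A small sketch of that row-to-column transpose; the input rows are invented, but real files such as train_ctb5.txt are assumed to follow this shape::

    # Hypothetical 4-column rows: word, POS tag, head index, dependency label.
    lines = [
        "I    PN  2  SUB",
        "saw  VV  0  ROOT",
        "it   PN  2  OBJ",
        "",               # blank line marks the sentence boundary
    ]

    sample = [line.split() for line in lines if line.strip()]
    words, pos_tags, heads, labels = map(list, zip(*sample))
    print(words)  # ['I', 'saw', 'it']
    print(heads)  # ['2', '0', '2'] -- convert() casts these to int and adds the root at index 0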