
update biaffine parser

yunfan committed 6 years ago
commit 102259df39 (tags/v0.2.0)
6 changed files with 85 additions and 31 deletions
1. fastNLP/core/field.py (+3, -0)
2. fastNLP/core/instance.py (+3, -0)
3. fastNLP/core/vocabulary.py (+6, -1)
4. fastNLP/loader/embed_loader.py (+3, -3)
5. fastNLP/models/biaffine_parser.py (+8, -2)
6. reproduction/Biaffine_parser/run.py (+62, -25)

fastNLP/core/field.py (+3, -0)

```diff
@@ -21,6 +21,9 @@ class Field(object):
     def contents(self):
         raise NotImplementedError

+    def __repr__(self):
+        return self.contents().__repr__()
+
 class TextField(Field):
     def __init__(self, text, is_target):
         """
```


fastNLP/core/instance.py (+3, -0)

```diff
@@ -82,3 +82,6 @@ class Instance(object):
             name, field_name = origin_len
             tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
         return tensor_x, tensor_y
+
+    def __repr__(self):
+        return self.fields.__repr__()
```

fastNLP/core/vocabulary.py (+6, -1)

```diff
@@ -114,7 +114,7 @@ class Vocabulary(object):
         if w in self.word2idx:
             return self.word2idx[w]
         elif self.has_default:
-            return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+            return self.word2idx[self.unknown_label]
         else:
             raise ValueError("word {} not in vocabulary".format(w))

@@ -134,6 +134,11 @@ class Vocabulary(object):
             return None
         return self.word2idx[self.unknown_label]

+    def __setattr__(self, name, val):
+        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
+        self.__dict__[name] = val
+
     @property
     @check_build_vocab
     def padding_idx(self):
```
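
The new `__setattr__` hook makes renaming a special label an index-preserving operation: assigning a new `unknown_label` (or `padding_label`) moves the existing index to the new surface form instead of adding a fresh vocabulary entry. The run.py change later in this commit relies on exactly this when it sets `word_v.unknown_label = "<OOV>"` before loading pretrained embeddings. A minimal self-contained sketch of the behaviour (`TinyVocab` is a hypothetical stand-in; the real class also builds `word2idx` lazily via `check_build_vocab`):

```python
class TinyVocab:
    def __init__(self):
        self.word2idx = {"<unk>": 0, "<pad>": 1}
        self.unknown_label = "<unk>"
        self.padding_label = "<pad>"

    def __setattr__(self, name, val):
        # Renaming a special label moves its existing index to the new token.
        # On first assignment (name not yet in __dict__) this is a plain set.
        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
        self.__dict__[name] = val

v = TinyVocab()
v.unknown_label = "<OOV>"  # remap index 0, don't add a new entry
assert v.word2idx == {"<OOV>": 0, "<pad>": 1}
assert "<unk>" not in v.word2idx
```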


fastNLP/loader/embed_loader.py (+3, -3)

```diff
@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader):
     def _load_glove(emb_file):
         """Read file as a glove embedding

-        file format:
-        embeddings are split by line,
+        file format:
+        embeddings are split by line,
         for one embedding, word and numbers split by space
         Example::

@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader):
             if len(line) > 0:
                 emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
-
+
     @staticmethod
     def _load_pretrain(emb_file, emb_type):
         """Read txt data from embedding file and convert to np.array as pre-trained embedding
```

(Both hunks are whitespace-only changes, so the removed and added lines render identically here.)


fastNLP/models/biaffine_parser.py (+8, -2)

```diff
@@ -182,6 +182,12 @@ class LabelBilinear(nn.Module):
         output += self.lin(torch.cat([x1, x2], dim=2))
         return output

+def len2masks(origin_len, max_len):
+    if origin_len.dim() <= 1:
+        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device)  # [max_len,]
+    seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]
+    return seq_mask
+

 class BiaffineParser(GraphParser):
     """Biaffine Dependency Parser implementation.
@@ -238,7 +244,7 @@ class BiaffineParser(GraphParser):
         self.use_greedy_infer = use_greedy_infer
         initial_parameter(self)

-    def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
+    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
         :param word_seq: [batch_size, seq_len] sequence of word's indices
         :param pos_seq: [batch_size, seq_len] sequence of word's indices
@@ -256,7 +262,7 @@ class BiaffineParser(GraphParser):
         batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

         # get sequence mask
-        seq_mask = seq_mask.long()
+        seq_mask = len2masks(word_seq_origin_len, seq_len).long()

         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
```
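
`len2masks` turns true sequence lengths into a boolean padding mask by broadcasting each length (shape `[batch_size, 1]`) against a `[0, max_len)` position range (shape `[1, max_len]`). With this change callers no longer pass a precomputed `seq_mask` to `forward`; the parser derives it from `word_seq_origin_len`. A quick standalone check of the broadcasting (the sample lengths are made up):

```python
import torch

def len2masks(origin_len, max_len):
    """Same helper as in the diff: mask[i, j] = (j < origin_len[i])."""
    if origin_len.dim() <= 1:
        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long,
                             device=origin_len.device)  # [max_len,]
    return torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]

lengths = torch.tensor([3, 1])  # two sequences padded to max_len = 4
print(len2masks(lengths, 4))
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])
```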


reproduction/Biaffine_parser/run.py (+62, -25)

```diff
@@ -14,7 +14,6 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.field import TextField, SeqLabelField
-from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.core.tester import Tester
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -26,11 +25,8 @@ from fastNLP.saver.model_saver import ModelSaver
 if len(os.path.dirname(__file__)) != 0:
     os.chdir(os.path.dirname(__file__))

-class MyDataLoader(object):
-    def __init__(self, pickle_path):
-        self.pickle_path = pickle_path
-
-    def load(self, path, word_v=None, pos_v=None, headtag_v=None):
+class ConlluDataLoader(object):
+    def load(self, path):
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
             sample = []
@@ -49,15 +45,10 @@
         for sample in datalist:
             # print(sample)
             res = self.get_one(sample)
-            if word_v is not None:
-                word_v.update(res[0])
-                pos_v.update(res[1])
-                headtag_v.update(res[3])
             ds.append(Instance(word_seq=TextField(res[0], is_target=False),
                                pos_seq=TextField(res[1], is_target=False),
                                head_indices=SeqLabelField(res[2], is_target=True),
-                               head_labels=TextField(res[3], is_target=True),
-                               seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))
+                               head_labels=TextField(res[3], is_target=True)))

         return ds

@@ -76,17 +67,57 @@
             head_tags.append(t4)
         return (text, pos_tags, heads, head_tags)

-    def index_data(self, dataset, word_v, pos_v, tag_v):
-        dataset.index_field('word_seq', word_v)
-        dataset.index_field('pos_seq', pos_v)
-        dataset.index_field('head_labels', tag_v)
+class CTBDataLoader(object):
+    def load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    def parse(self, lines):
+        """
+        [
+            [word], [pos], [head_index], [head_tag]
+        ]
+        """
+        sample = []
+        data = []
+        for i, line in enumerate(lines):
+            line = line.strip()
+            if len(line) == 0 or i+1 == len(lines):
+                data.append(list(map(list, zip(*sample))))
+                sample = []
+            else:
+                sample.append(line.split())
+        return data
+
+    def convert(self, data):
+        dataset = DataSet()
+        for sample in data:
+            word_seq = ["<ROOT>"] + sample[0]
+            pos_seq = ["<ROOT>"] + sample[1]
+            heads = [0] + list(map(int, sample[2]))
+            head_tags = ["ROOT"] + sample[3]
+            dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
+                                    pos_seq=TextField(pos_seq, is_target=False),
+                                    head_indices=SeqLabelField(heads, is_target=True),
+                                    head_labels=TextField(head_tags, is_target=True)))
+        return dataset

 # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
-datadir = "/home/yfshao/UD_English-EWT"
+# datadir = "/home/yfshao/UD_English-EWT"
+# train_data_name = "en_ewt-ud-train.conllu"
+# dev_data_name = "en_ewt-ud-dev.conllu"
+# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
+# loader = ConlluDataLoader()
+
+datadir = "/home/yfshao/parser-data"
+train_data_name = "train_ctb5.txt"
+dev_data_name = "dev_ctb5.txt"
+emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
+loader = CTBDataLoader()
+
 cfgfile = './cfg.cfg'
-train_data_name = "en_ewt-ud-train.conllu"
-dev_data_name = "en_ewt-ud-dev.conllu"
-emb_file_name = '/home/yfshao/glove.6B.100d.txt'
 processed_datadir = './save'

@@ -96,7 +127,7 @@ model_args = ConfigSection()
 optim_args = ConfigSection()
 ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})

-# Data Loader
+# Pickle Loader
 def save_data(dirpath, **kwargs):
     import _pickle
     if not os.path.exists(dirpath):
@@ -140,6 +171,7 @@ class MyTester(object):
             tmp[eval_name] = torch.cat(tensorlist, dim=0)

         self.res = self.model.metrics(**tmp)
+        print(self.show_metrics())

     def show_metrics(self):
         s = ""
@@ -148,7 +180,6 @@
         return s


-loader = MyDataLoader('')
 try:
     data_dict = load_data(processed_datadir)
     word_v = data_dict['word_v']
@@ -163,12 +194,17 @@ except Exception as _:
     word_v = Vocabulary(need_default=True, min_freq=2)
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
-    train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
+    train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
+    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
     save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)

-loader.index_data(train_data, word_v, pos_v, tag_v)
-loader.index_data(dev_data, word_v, pos_v, tag_v)
+train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+train_data.set_origin_len("word_seq")
+dev_data.set_origin_len("word_seq")
+
+print(train_data[:3])
 print(len(train_data))
 print(len(dev_data))
 ep = train_args['epochs']
@@ -199,6 +235,7 @@ def train():
     model = BiaffineParser(**model_args.data)

     # use pretrain embedding
+    word_v.unknown_label = "<OOV>"
     embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
     model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
     model.word_embedding.padding_idx = word_v.padding_idx
```
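
Two details of the new `CTBDataLoader` are worth spelling out. `parse` accumulates whitespace-split rows until it hits a blank line, then transposes the sample into column lists; as written it relies on the file ending with a blank line, because the `i+1 == len(lines)` branch closes the sample before appending the final row. `convert` then prepends an artificial `<ROOT>` token (with head index 0 pointing to itself) so that real head indices have position 0 to attach to. A standalone replica of `parse` on a made-up sample (the word/POS/head-index/label column order is an assumption inferred from how `convert` consumes the columns):

```python
# Hypothetical sample in the whitespace-separated format parse() expects.
lines = [
    "I    PN  2  nsubj",
    "like VV  0  root",
    "cats NN  2  dobj",
    "",  # trailing blank line closes the sample (parse relies on this)
]

sample, data = [], []
for i, line in enumerate(lines):
    line = line.strip()
    if len(line) == 0 or i + 1 == len(lines):
        # transpose rows into column lists: [words], [pos], [heads], [tags]
        data.append(list(map(list, zip(*sample))))
        sample = []
    else:
        sample.append(line.split())

print(data)
# [[['I', 'like', 'cats'], ['PN', 'VV', 'NN'],
#   ['2', '0', '2'], ['nsubj', 'root', 'dobj']]]
```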

