
update biaffine parser

tags/v0.2.0
yunfan committed 6 years ago
commit 102259df39
6 changed files with 85 additions and 31 deletions
  1. +3 -0 fastNLP/core/field.py
  2. +3 -0 fastNLP/core/instance.py
  3. +6 -1 fastNLP/core/vocabulary.py
  4. +3 -3 fastNLP/loader/embed_loader.py
  5. +8 -2 fastNLP/models/biaffine_parser.py
  6. +62 -25 reproduction/Biaffine_parser/run.py

+3 -0 fastNLP/core/field.py

@@ -21,6 +21,9 @@ class Field(object):
    def contents(self):
        raise NotImplementedError

    def __repr__(self):
        return self.contents().__repr__()

class TextField(Field):
    def __init__(self, text, is_target):
        """


+3 -0 fastNLP/core/instance.py

@@ -82,3 +82,6 @@ class Instance(object):
            name, field_name = origin_len
            tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
        return tensor_x, tensor_y

    def __repr__(self):
        return self.fields.__repr__()
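This mirrors the Field change above, and is what makes the print(train_data[:3]) call added to run.py below produce readable output. A quick illustration, assuming a keyword-argument constructor for brevity:

class Instance(object):
    def __init__(self, **fields):
        self.fields = fields  # assumed constructor, for illustration only

    def __repr__(self):
        return self.fields.__repr__()

print(Instance(word_seq=['<ROOT>', 'I'], head_indices=[0, 2]))
# {'word_seq': ['<ROOT>', 'I'], 'head_indices': [0, 2]}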

+6 -1 fastNLP/core/vocabulary.py

@@ -114,7 +114,7 @@ class Vocabulary(object):
        if w in self.word2idx:
            return self.word2idx[w]
        elif self.has_default:
            return self.word2idx[DEFAULT_UNKNOWN_LABEL]
            return self.word2idx[self.unknown_label]
        else:
            raise ValueError("word {} not in vocabulary".format(w))

@@ -134,6 +134,11 @@ class Vocabulary(object):
            return None
        return self.word2idx[self.unknown_label]

    def __setattr__(self, name, val):
        if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
        self.__dict__[name] = val

    @property
    @check_build_vocab
    def padding_idx(self):
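The new __setattr__ hook means assigning a fresh label re-keys the existing entry in word2idx rather than adding a new one, so indices already built with the old label stay valid; run.py relies on this below via word_v.unknown_label = "<OOV>". A standalone sketch of the behaviour (TinyVocab is a made-up stand-in, not fastNLP's class):

class TinyVocab(object):
    def __init__(self):
        # write through __dict__ so __setattr__'s remapping does not fire here
        self.__dict__['word2idx'] = {'<unk>': 0, '<pad>': 1}
        self.__dict__['unknown_label'] = '<unk>'

    def __setattr__(self, name, val):
        if name in self.__dict__ and name in ['unknown_label', 'padding_label']:
            # move the old label's index onto the new label
            self.word2idx[val] = self.word2idx.pop(self.__dict__[name])
        self.__dict__[name] = val

v = TinyVocab()
v.unknown_label = '<OOV>'
print(v.word2idx)  # {'<pad>': 1, '<OOV>': 0}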


+3 -3 fastNLP/loader/embed_loader.py

@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader):
    def _load_glove(emb_file):
        """Read file as a glove embedding

        file format:
            embeddings are split by line,
        file format:
            embeddings are split by line,
        for one embedding, word and numbers split by space
        Example::

@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader):
            if len(line) > 0:
                emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
        return emb

    @staticmethod
    def _load_pretrain(emb_file, emb_type):
        """Read txt data from embedding file and convert to np.array as pre-trained embedding


+8 -2 fastNLP/models/biaffine_parser.py

@@ -182,6 +182,12 @@ class LabelBilinear(nn.Module):
        output += self.lin(torch.cat([x1, x2], dim=2))
        return output

def len2masks(origin_len, max_len):
    if origin_len.dim() <= 1:
        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device)  # [max_len,]
    seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]
    return seq_mask

class BiaffineParser(GraphParser):
    """Biaffine Dependency Parser implementation.
@@ -238,7 +244,7 @@ class BiaffineParser(GraphParser):
        self.use_greedy_infer = use_greedy_infer
        initial_parameter(self)

    def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
        """
        :param word_seq: [batch_size, seq_len] sequence of word's indices
        :param pos_seq: [batch_size, seq_len] sequence of word's indices
@@ -256,7 +262,7 @@ class BiaffineParser(GraphParser):
        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

        # get sequence mask
        seq_mask = seq_mask.long()
        seq_mask = len2masks(word_seq_origin_len, seq_len).long()

        word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
        pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
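The new helper is small enough to check in isolation: given true lengths it builds a boolean mask over padded positions, replacing the precomputed seq_mask field the loader used to emit. A standalone check (requires only PyTorch):

import torch

def len2masks(origin_len, max_len):
    if origin_len.dim() <= 1:
        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long,
                             device=origin_len.device)  # [max_len,]
    return torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]

lens = torch.tensor([3, 1])
print(len2masks(lens, 4).long())
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0]])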


+62 -25 reproduction/Biaffine_parser/run.py

@@ -14,7 +14,6 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@@ -26,11 +25,8 @@ from fastNLP.saver.model_saver import ModelSaver
if len(os.path.dirname(__file__)) != 0:
    os.chdir(os.path.dirname(__file__))

class MyDataLoader(object):
    def __init__(self, pickle_path):
        self.pickle_path = pickle_path

    def load(self, path, word_v=None, pos_v=None, headtag_v=None):
class ConlluDataLoader(object):
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
@@ -49,15 +45,10 @@ class MyDataLoader(object):
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if word_v is not None:
                word_v.update(res[0])
                pos_v.update(res[1])
                headtag_v.update(res[3])
            ds.append(Instance(word_seq=TextField(res[0], is_target=False),
                               pos_seq=TextField(res[1], is_target=False),
                               head_indices=SeqLabelField(res[2], is_target=True),
                               head_labels=TextField(res[3], is_target=True),
                               seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))
                               head_labels=TextField(res[3], is_target=True)))

        return ds

@@ -76,17 +67,57 @@ class MyDataLoader(object):
            head_tags.append(t4)
        return (text, pos_tags, heads, head_tags)

    def index_data(self, dataset, word_v, pos_v, tag_v):
        dataset.index_field('word_seq', word_v)
        dataset.index_field('pos_seq', pos_v)
        dataset.index_field('head_labels', tag_v)
class CTBDataLoader(object):
    def load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = self.parse(lines)
        return self.convert(data)

    def parse(self, lines):
        """
        [
            [word], [pos], [head_index], [head_tag]
        ]
        """
        sample = []
        data = []
        for i, line in enumerate(lines):
            line = line.strip()
            if len(line) == 0 or i+1 == len(lines):
                data.append(list(map(list, zip(*sample))))
                sample = []
            else:
                sample.append(line.split())
        return data

    def convert(self, data):
        dataset = DataSet()
        for sample in data:
            word_seq = ["<ROOT>"] + sample[0]
            pos_seq = ["<ROOT>"] + sample[1]
            heads = [0] + list(map(int, sample[2]))
            head_tags = ["ROOT"] + sample[3]
            dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                    pos_seq=TextField(pos_seq, is_target=False),
                                    head_indices=SeqLabelField(heads, is_target=True),
                                    head_labels=TextField(head_tags, is_target=True)))
        return dataset

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
# datadir = "/home/yfshao/UD_English-EWT"
# train_data_name = "en_ewt-ud-train.conllu"
# dev_data_name = "en_ewt-ud-dev.conllu"
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
# loader = ConlluDataLoader()

datadir = "/home/yfshao/parser-data"
train_data_name = "train_ctb5.txt"
dev_data_name = "dev_ctb5.txt"
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
loader = CTBDataLoader()

cfgfile = './cfg.cfg'
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
@@ -96,7 +127,7 @@ model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})

# Data Loader
# Pickle Loader
def save_data(dirpath, **kwargs):
    import _pickle
    if not os.path.exists(dirpath):
@@ -140,6 +171,7 @@ class MyTester(object):
            tmp[eval_name] = torch.cat(tensorlist, dim=0)

        self.res = self.model.metrics(**tmp)
        print(self.show_metrics())

    def show_metrics(self):
        s = ""
@@ -148,7 +180,6 @@ class MyTester(object):
        return s


loader = MyDataLoader('')
try:
    data_dict = load_data(processed_datadir)
    word_v = data_dict['word_v']
@@ -163,12 +194,17 @@ except Exception as _:
    word_v = Vocabulary(need_default=True, min_freq=2)
    pos_v = Vocabulary(need_default=True)
    tag_v = Vocabulary(need_default=False)
    train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
    train_data = loader.load(os.path.join(datadir, train_data_name))
    dev_data = loader.load(os.path.join(datadir, dev_data_name))
    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)

loader.index_data(train_data, word_v, pos_v, tag_v)
loader.index_data(dev_data, word_v, pos_v, tag_v)
train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
train_data.set_origin_len("word_seq")
dev_data.set_origin_len("word_seq")

print(train_data[:3])
print(len(train_data))
print(len(dev_data))
ep = train_args['epochs']
@@ -199,6 +235,7 @@ def train():
    model = BiaffineParser(**model_args.data)

    # use pretrain embedding
    word_v.unknown_label = "<OOV>"
    embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
    model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
    model.word_embedding.padding_idx = word_v.padding_idx
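For orientation, CTBDataLoader.parse expects blank-line-separated token blocks with space-separated columns (word, pos, head index, head tag, per its docstring). A round trip of the same parsing logic on a fabricated two-token sample:

lines = [
    '我 PN 2 SUB',
    '来 VV 0 ROOT',
    '',
]

def parse(lines):  # same logic as CTBDataLoader.parse above
    sample, data = [], []
    for i, line in enumerate(lines):
        line = line.strip()
        if len(line) == 0 or i + 1 == len(lines):
            data.append(list(map(list, zip(*sample))))
            sample = []
        else:
            sample.append(line.split())
    return data

print(parse(lines))
# [[['我', '来'], ['PN', 'VV'], ['2', '0'], ['SUB', 'ROOT']]]

One edge case worth knowing: if the file's last line is non-empty, the flush fires before that line is added to sample, so a trailing newline in the data file matters.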

