|
|
@@ -14,7 +14,6 @@ from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.batch import Batch |
|
|
|
from fastNLP.core.sampler import SequentialSampler |
|
|
|
from fastNLP.core.field import TextField, SeqLabelField |
|
|
|
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle |
|
|
|
from fastNLP.core.tester import Tester |
|
|
|
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection |
|
|
|
from fastNLP.loader.model_loader import ModelLoader |
|
|
@@ -26,11 +25,8 @@ from fastNLP.saver.model_saver import ModelSaver |
|
|
|
# Run relative to this script's own directory so the relative paths used
# below ('./cfg.cfg', './save') resolve regardless of the caller's cwd.
script_dir = os.path.dirname(__file__)
if script_dir:
    os.chdir(script_dir)
|
|
|
|
|
|
|
class MyDataLoader(object): |
|
|
|
def __init__(self, pickle_path):
    """Remember where pickled intermediate data lives.

    :param pickle_path: base path for pickle files; the script below
        instantiates this with '' — presumably meaning the current
        working directory (TODO confirm).
    """
    self.pickle_path = pickle_path
|
|
|
|
|
|
|
def load(self, path, word_v=None, pos_v=None, headtag_v=None): |
|
|
|
class ConlluDataLoader(object): |
|
|
|
def load(self, path): |
|
|
|
datalist = [] |
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
|
sample = [] |
|
|
@@ -49,15 +45,10 @@ class MyDataLoader(object): |
|
|
|
for sample in datalist: |
|
|
|
# print(sample) |
|
|
|
res = self.get_one(sample) |
|
|
|
if word_v is not None: |
|
|
|
word_v.update(res[0]) |
|
|
|
pos_v.update(res[1]) |
|
|
|
headtag_v.update(res[3]) |
|
|
|
ds.append(Instance(word_seq=TextField(res[0], is_target=False), |
|
|
|
pos_seq=TextField(res[1], is_target=False), |
|
|
|
head_indices=SeqLabelField(res[2], is_target=True), |
|
|
|
head_labels=TextField(res[3], is_target=True), |
|
|
|
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) |
|
|
|
head_labels=TextField(res[3], is_target=True))) |
|
|
|
|
|
|
|
return ds |
|
|
|
|
|
|
@@ -76,17 +67,57 @@ class MyDataLoader(object): |
|
|
|
head_tags.append(t4) |
|
|
|
return (text, pos_tags, heads, head_tags) |
|
|
|
|
|
|
|
def index_data(self, dataset, word_v, pos_v, tag_v):
    """Convert the token fields of *dataset* to integer indices in place,
    each with its matching vocabulary."""
    field_vocab_pairs = (
        ("word_seq", word_v),
        ("pos_seq", pos_v),
        ("head_labels", tag_v),
    )
    for field_name, vocab in field_vocab_pairs:
        dataset.index_field(field_name, vocab)
|
|
|
class CTBDataLoader(object):
    """Loader for CTB-style dependency-parse files.

    Each token is one line of whitespace-separated columns
    ``word pos head_index head_tag``; samples (sentences) are separated
    by blank lines.
    """

    def load(self, data_path):
        """Read *data_path* and return its contents as a ``DataSet``."""
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = self.parse(lines)
        return self.convert(data)

    def parse(self, lines):
        """Group *lines* into samples and transpose each into column lists.

        Returns a list shaped like::

            [
                [[word], [pos], [head_index], [head_tag]],
                ...
            ]
        """
        sample = []
        data = []
        for line in lines:
            line = line.strip()
            if line:
                sample.append(line.split())
            elif sample:
                # Blank line ends the current sample; guarding on
                # ``sample`` also skips runs of blank lines, which the
                # original turned into spurious empty samples.
                data.append(list(map(list, zip(*sample))))
                sample = []
        # BUG FIX: the original flushed on ``i+1 == len(lines)`` without
        # first appending the final line, so a file not ending in a blank
        # line lost the last token of its last sentence.
        if sample:
            data.append(list(map(list, zip(*sample))))
        return data

    def convert(self, data):
        """Build a ``DataSet`` from parsed samples, prepending an
        artificial ``<ROOT>`` token so that 1-based head indices
        (0 = attach to root) line up with sequence positions."""
        dataset = DataSet()
        for sample in data:
            word_seq = ["<ROOT>"] + sample[0]
            pos_seq = ["<ROOT>"] + sample[1]
            heads = [0] + list(map(int, sample[2]))
            head_tags = ["ROOT"] + sample[3]
            dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                    pos_seq=TextField(pos_seq, is_target=False),
                                    head_indices=SeqLabelField(heads, is_target=True),
                                    head_labels=TextField(head_tags, is_target=True)))
        return dataset
|
|
|
|
|
|
|
# --- dataset / embedding configuration -------------------------------------
# NOTE(review): this section assigns several of the same names more than
# once; the later assignment wins at runtime.  It looks like stale lines
# from an edit were left behind — confirm which configuration is intended.

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
# datadir = "/home/yfshao/UD_English-EWT"
# train_data_name = "en_ewt-ud-train.conllu"
# dev_data_name = "en_ewt-ud-dev.conllu"
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
# loader = ConlluDataLoader()

# CTB5 configuration — overrides the EWT datadir above.
datadir = "/home/yfshao/parser-data"
train_data_name = "train_ctb5.txt"
dev_data_name = "dev_ctb5.txt"
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
loader = CTBDataLoader()

cfgfile = './cfg.cfg'
# NOTE(review): these three lines re-override the CTB5 file names with the
# EWT .conllu names while ``datadir`` still points at parser-data and
# ``loader`` is still CTBDataLoader — almost certainly unintended; verify.
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
|
|
@@ -96,7 +127,7 @@ model_args = ConfigSection() |
|
|
|
optim_args = ConfigSection()
# Populate the four ConfigSection objects in place from ./cfg.cfg.
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
|
|
|
|
|
|
|
# Data Loader |
|
|
|
# Pickle Loader |
|
|
|
def save_data(dirpath, **kwargs): |
|
|
|
import _pickle |
|
|
|
if not os.path.exists(dirpath): |
|
|
@@ -140,6 +171,7 @@ class MyTester(object): |
|
|
|
tmp[eval_name] = torch.cat(tensorlist, dim=0) |
|
|
|
|
|
|
|
self.res = self.model.metrics(**tmp) |
|
|
|
print(self.show_metrics()) |
|
|
|
|
|
|
|
def show_metrics(self): |
|
|
|
s = "" |
|
|
@@ -148,7 +180,6 @@ class MyTester(object): |
|
|
|
return s |
|
|
|
|
|
|
|
|
|
|
|
loader = MyDataLoader('') |
|
|
|
try: |
|
|
|
data_dict = load_data(processed_datadir) |
|
|
|
word_v = data_dict['word_v'] |
|
|
@@ -163,12 +194,17 @@ except Exception as _: |
|
|
|
word_v = Vocabulary(need_default=True, min_freq=2) |
|
|
|
pos_v = Vocabulary(need_default=True) |
|
|
|
tag_v = Vocabulary(need_default=False) |
|
|
|
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) |
|
|
|
train_data = loader.load(os.path.join(datadir, train_data_name)) |
|
|
|
dev_data = loader.load(os.path.join(datadir, dev_data_name)) |
|
|
|
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) |
|
|
|
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) |
|
|
|
|
|
|
|
loader.index_data(train_data, word_v, pos_v, tag_v) |
|
|
|
loader.index_data(dev_data, word_v, pos_v, tag_v) |
|
|
|
train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) |
|
|
|
dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) |
|
|
|
train_data.set_origin_len("word_seq") |
|
|
|
dev_data.set_origin_len("word_seq") |
|
|
|
|
|
|
|
print(train_data[:3]) |
|
|
|
print(len(train_data)) |
|
|
|
print(len(dev_data)) |
|
|
|
ep = train_args['epochs'] |
|
|
@@ -199,6 +235,7 @@ def train(): |
|
|
|
model = BiaffineParser(**model_args.data) |
|
|
|
|
|
|
|
# use pretrain embedding |
|
|
|
word_v.unknown_label = "<OOV>" |
|
|
|
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) |
|
|
|
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) |
|
|
|
model.word_embedding.padding_idx = word_v.padding_idx |
|
|
|