@@ -24,6 +24,9 @@ class Field(object):
     def __repr__(self):
         return self.contents().__repr__()
 
+    def new(self, *args, **kwargs):
+        return self.__class__(*args, **kwargs, is_target=self.is_target)
+
 class TextField(Field):
     def __init__(self, text, is_target):
         """
@@ -35,6 +35,9 @@ class Instance(object):
         else:
             raise KeyError("{} not found".format(name))
 
+    def __setitem__(self, name, field):
+        return self.add_field(name, field)
+
     def get_length(self):
         """Fetch the length of all fields in the instance.
@@ -74,7 +74,7 @@ class Tester(object):
         output_list = []
         truth_list = []
 
-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
 
         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
@@ -1,6 +1,6 @@
 import os
 import time
-from datetime import timedelta
+from datetime import timedelta, datetime
 
 import torch
 from tensorboardX import SummaryWriter
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
 from fastNLP.saver.model_saver import ModelSaver
 
 logger = create_logger(__name__, "./train_test.log")
+logger.disabled = True
 
 class Trainer(object):
     """Operations of training a model, including data loading, gradient descent, and validation.
@@ -42,7 +42,7 @@ class Trainer(object): | |||||
""" | """ | ||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | ||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | ||||
"valid_step": 500, "eval_sort_key": None, | |||||
"valid_step": 500, "eval_sort_key": 'acc', | |||||
"loss": Loss(None), # used to pass type check | "loss": Loss(None), # used to pass type check | ||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
"evaluator": Evaluator() | "evaluator": Evaluator() | ||||
@@ -111,13 +111,17 @@ class Trainer(object):
         else:
             self._model = network
 
+        print(self._model)
+
         # define Tester over dev data
+        self.dev_data = None
         if self.validate:
             default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                   "use_cuda": self.use_cuda, "evaluator": self._evaluator}
             if self.validator is None:
                 self.validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(self.validator)))
+            self.dev_data = dev_data
 
         # optimizer and loss
         self.define_optimizer()
@@ -130,7 +134,7 @@ class Trainer(object):
         # main training procedure
         start = time.time()
-        self.start_time = str(start)
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
         logger.info("training epochs started " + self.start_time)
 
         epoch, iters = 1, 0
@@ -141,15 +145,17 @@ class Trainer(object):
             # prepare mini-batch iterator
             data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                  use_cuda=self.use_cuda)
+                                  use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
             logger.info("prepared data iterator")
 
             # one forward and backward pass
-            iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
+            iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
 
             # validation
             if self.validate:
                 self.valid_model()
+            self.save_model(self._model, 'training_model_'+self.start_time)
+            epoch += 1
 
     def _train_step(self, data_iterator, network, **kwargs):
         """Training process in one epoch.
@@ -160,13 +166,16 @@ class Trainer(object):
            - epoch: int,
        """
        step = kwargs['step']
-       dev_data = kwargs['dev_data']
        for batch_x, batch_y in data_iterator:
            prediction = self.data_forward(network, batch_x)
            loss = self.get_loss(prediction, batch_y)
            self.grad_backward(loss)
+           if torch.rand(1).item() < 0.001:
+               print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
+               for name, p in self._model.named_parameters():
+                   if p.requires_grad:
+                       print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
            self.update()
            self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
@@ -183,13 +192,14 @@ class Trainer(object):
         return step
 
     def valid_model(self):
-        if dev_data is None:
+        if self.dev_data is None:
             raise RuntimeError(
                 "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
         logger.info("validation started")
-        res = self.validator.test(network, dev_data)
+        res = self.validator.test(self._model, self.dev_data)
         if self.save_best_dev and self.best_eval_result(res):
             logger.info('save best result! {}'.format(res))
+            print('save best result! {}'.format(res))
             self.save_model(self._model, 'best_model_'+self.start_time)
         return res
@@ -282,14 +292,10 @@ class Trainer(object):
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics
-        else:
-            metrics = validator.metrics
 
         if isinstance(metrics, dict):
             if len(metrics) == 1:
                 accuracy = list(metrics.values())[0]
-            elif self.eval_sort_key is None:
-                raise ValueError('dict format metrics should provide sort key for eval best result')
             else:
                 accuracy = metrics[self.eval_sort_key]
         else:
@@ -199,6 +199,8 @@ class BiaffineParser(GraphParser):
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
+                 word_hid_dim,
+                 pos_hid_dim,
                  rnn_layers,
                  rnn_hidden_size,
                  arc_mlp_size,
@@ -209,10 +211,15 @@ class BiaffineParser(GraphParser):
                  use_greedy_infer=False):
         super(BiaffineParser, self).__init__()
+        rnn_out_size = 2 * rnn_hidden_size
         self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
+        self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
+        self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
+        self.word_norm = nn.LayerNorm(word_hid_dim)
+        self.pos_norm = nn.LayerNorm(pos_hid_dim)
         if use_var_lstm:
-            self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -221,7 +228,7 @@ class BiaffineParser(GraphParser):
                                 hidden_dropout=dropout,
                                 bidirectional=True)
         else:
-            self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
+            self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
                                 hidden_size=rnn_hidden_size,
                                 num_layers=rnn_layers,
                                 bias=True,
@@ -229,12 +236,13 @@ class BiaffineParser(GraphParser):
                                 dropout=dropout,
                                 bidirectional=True)
-        rnn_out_size = 2 * rnn_hidden_size
 
         self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
+                                          nn.LayerNorm(arc_mlp_size),
                                           nn.ELU(),
                                           TimestepDropout(p=dropout),)
         self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
         self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
+                                            nn.LayerNorm(label_mlp_size),
                                             nn.ELU(),
                                             TimestepDropout(p=dropout),)
         self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
@@ -242,10 +250,18 @@ class BiaffineParser(GraphParser):
         self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
         self.normal_dropout = nn.Dropout(p=dropout)
         self.use_greedy_infer = use_greedy_infer
-        initial_parameter(self)
-        self.word_norm = nn.LayerNorm(word_emb_dim)
-        self.pos_norm = nn.LayerNorm(pos_emb_dim)
-        self.lstm_norm = nn.LayerNorm(rnn_out_size)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, nn.Embedding):
+                continue
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            else:
+                for p in m.parameters():
+                    nn.init.normal_(p, 0, 0.01)
 
     def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         """
@@ -262,19 +278,21 @@ class BiaffineParser(GraphParser):
         # prepare embeddings
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
-        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
 
         # get sequence mask
         seq_mask = len2masks(word_seq_origin_len, seq_len).long()
 
         word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
+        word, pos = self.word_fc(word), self.pos_fc(pos)
         word, pos = self.word_norm(word), self.pos_norm(pos)
         x = torch.cat([word, pos], dim=2) # -> [N,L,C]
+        del word, pos
 
         # lstm, extract features
+        x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
         feat, _ = self.lstm(x) # -> [N,L,C]
-        feat = self.lstm_norm(feat)
+        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
 
         # for arc biaffine
         # mlp, reduce dim
@@ -282,6 +300,7 @@ class BiaffineParser(GraphParser):
         arc_head = self.arc_head_mlp(feat)
         label_dep = self.label_dep_mlp(feat)
         label_head = self.label_head_mlp(feat)
+        del feat
 
         # biaffine arc classifier
         arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
@@ -289,7 +308,7 @@ class BiaffineParser(GraphParser):
         arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
 
         # use gold or predicted arc to predict label
-        if gold_heads is None:
+        if gold_heads is None or not self.training:
             # use greedy decoding in training
             if self.training or self.use_greedy_infer:
                 heads = self._greedy_decoder(arc_pred, seq_mask)
@@ -301,6 +320,7 @@ class BiaffineParser(GraphParser):
             head_pred = None
             heads = gold_heads
 
+        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
         label_head = label_head[batch_range, heads].contiguous()
         label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
         res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
@@ -1,16 +1,14 @@
 [train]
 epochs = -1
-<<<<<<< HEAD
-batch_size = 16
-=======
 batch_size = 32
->>>>>>> update biaffine
 pickle_path = "./save/"
 validate = true
 save_best_dev = true
 eval_sort_key = "UAS"
 use_cuda = true
 model_saved_path = "./save/"
+print_every_step = 20
+use_golden_train=true
 
 [test]
 save_output = true
@@ -26,14 +24,17 @@ word_vocab_size = -1
 word_emb_dim = 100
 pos_vocab_size = -1
 pos_emb_dim = 100
+word_hid_dim = 100
+pos_hid_dim = 100
 rnn_layers = 3
 rnn_hidden_size = 400
 arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.33
-use_var_lstm=true
+use_var_lstm=false
 use_greedy_infer=false
 
 [optim]
 lr = 2e-3
 weight_decay = 0.0
@@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 from collections import defaultdict
 import math
 import torch
+import re
 
 from fastNLP.core.trainer import Trainer
 from fastNLP.core.metrics import Evaluator
@@ -55,10 +56,10 @@ class ConlluDataLoader(object):
         return ds
 
     def get_one(self, sample):
-        text = ['<root>']
-        pos_tags = ['<root>']
-        heads = [0]
-        head_tags = ['root']
+        text = []
+        pos_tags = []
+        heads = []
+        head_tags = []
         for w in sample:
             t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
             if t3 == '_':
@@ -96,12 +97,13 @@ class CTBDataLoader(object):
     def convert(self, data):
         dataset = DataSet()
         for sample in data:
-            word_seq = ["<ROOT>"] + sample[0]
-            pos_seq = ["<ROOT>"] + sample[1]
-            heads = [0] + list(map(int, sample[2]))
-            head_tags = ["ROOT"] + sample[3]
+            word_seq = ["<s>"] + sample[0] + ['</s>']
+            pos_seq = ["<s>"] + sample[1] + ['</s>']
+            heads = [0] + list(map(int, sample[2])) + [0]
+            head_tags = ["<s>"] + sample[3] + ['</s>']
             dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
                                     pos_seq=TextField(pos_seq, is_target=False),
+                                    gold_heads=SeqLabelField(heads, is_target=False),
                                     head_indices=SeqLabelField(heads, is_target=True),
                                     head_labels=TextField(head_tags, is_target=True)))
         return dataset
@@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/'
 train_data_name = "train_ctb5.txt"
 dev_data_name = "dev_ctb5.txt"
 test_data_name = "test_ctb5.txt"
-emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
+emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
+# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
 
 loader = CTBDataLoader()
 cfgfile = './cfg.cfg'
@@ -129,6 +132,10 @@ test_args = ConfigSection()
 model_args = ConfigSection()
 optim_args = ConfigSection()
 ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
+print('train Args:', train_args.data)
+print('test Args:', test_args.data)
+print('optim Args:', optim_args.data)
 
 # Pickle Loader
 def save_data(dirpath, **kwargs):
@@ -151,9 +158,31 @@ def load_data(dirpath):
             datas[name] = _pickle.load(f)
     return datas
 
+def P2(data, field, length):
+    ds = [ins for ins in data if ins[field].get_length() >= length]
+    data.clear()
+    data.extend(ds)
+    return ds
+
+def P1(data, field):
+    def reeng(w):
+        return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
+    def renum(w):
+        return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
+    for ins in data:
+        ori = ins[field].contents()
+        s = list(map(renum, map(reeng, ori)))
+        if s != ori:
+            # print(ori)
+            # print(s)
+            # print()
+            ins[field] = ins[field].new(s)
+    return data
+
 class ParserEvaluator(Evaluator):
-    def __init__(self):
+    def __init__(self, ignore_label):
         super(ParserEvaluator, self).__init__()
+        self.ignore = ignore_label
 
     def __call__(self, predict_list, truth_list):
         head_all, label_all, total_all = 0, 0, 0
@@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator):
             label_pred_correct: number of correct predicted labels.
             total_tokens: number of predicted tokens
         """
+        seq_mask *= (head_labels != self.ignore).long()
         head_pred_correct = (head_pred == head_indices).long() * seq_mask
         _, label_preds = torch.max(label_pred, dim=2)
         label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
@@ -181,72 +211,93 @@ class ParserEvaluator(Evaluator):
 try:
     data_dict = load_data(processed_datadir)
-    word_v = data_dict['word_v']
     pos_v = data_dict['pos_v']
     tag_v = data_dict['tag_v']
     train_data = data_dict['train_data']
     dev_data = data_dict['dev_data']
+    test_data = data_dict['test_data']
     print('use saved pickles')
 except Exception as _:
     print('load raw data and preprocess')
-    word_v = Vocabulary(need_default=True, min_freq=2)
+    # use pretrain embedding
     pos_v = Vocabulary(need_default=True)
     tag_v = Vocabulary(need_default=False)
     train_data = loader.load(os.path.join(datadir, train_data_name))
     dev_data = loader.load(os.path.join(datadir, dev_data_name))
     test_data = loader.load(os.path.join(datadir, test_data_name))
-    train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
-    save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
+    train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
+    save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
 
-train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
-train_data.set_origin_len("word_seq")
-dev_data.set_origin_len("word_seq")
+embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
+word_v.unknown_label = "<OOV>"
+print(train_data[:3])
-print(len(train_data))
-print(len(dev_data))
 
+# Model
 model_args['word_vocab_size'] = len(word_v)
 model_args['pos_vocab_size'] = len(pos_v)
 model_args['num_label'] = len(tag_v)
+model = BiaffineParser(**model_args.data)
+model.reset_parameters()
+
+datasets = (train_data, dev_data, test_data)
+for ds in datasets:
+    # print('====='*30)
+    P1(ds, 'word_seq')
+    P2(ds, 'word_seq', 5)
+    ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
+    ds.set_origin_len('word_seq')
+    if train_args['use_golden_train']:
+        ds.set_target(gold_heads=False)
+    else:
+        ds.set_target(gold_heads=None)
+train_args.data.pop('use_golden_train')
+ignore_label = pos_v['P']
+
+print(test_data[0])
+print(len(train_data))
+print(len(dev_data))
+print(len(test_data))
-def train():
+def train(path):
     # Trainer
     trainer = Trainer(**train_args.data)
 
     def _define_optim(obj):
-        obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
+        lr = optim_args.data['lr']
+        embed_params = set(obj._model.word_embedding.parameters())
+        decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
+        params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
+        obj._optimizer = torch.optim.Adam([
+            {'params': list(embed_params), 'lr':lr*0.1},
+            {'params': list(decay_params), **optim_args.data},
+            {'params': params}
+        ], lr=lr)
         obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
 
     def _update(obj):
+        # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
        obj._scheduler.step()
        obj._optimizer.step()
 
     trainer.define_optimizer = lambda: _define_optim(trainer)
     trainer.update = lambda: _update(trainer)
-    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
+    trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))
 
-    # Model
-    model = BiaffineParser(**model_args.data)
-    # use pretrain embedding
-    word_v.unknown_label = "<OOV>"
-    embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
     model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
     model.word_embedding.padding_idx = word_v.padding_idx
     model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
     model.pos_embedding.padding_idx = pos_v.padding_idx
     model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model parameter loaded!')
-    except Exception as _:
-        print("No saved model. Continue.")
-        pass
+    # try:
+    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+    #     print('model parameter loaded!')
+    # except Exception as _:
+    #     print("No saved model. Continue.")
+    #     pass
 
     # Start training
     trainer.train(model, train_data, dev_data)
@@ -258,15 +309,15 @@ def train():
     print("Model saved!")
 
 
-def test():
+def test(path):
     # Tester
-    tester = Tester(**test_args.data, evaluator=ParserEvaluator())
+    tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))
 
     # Model
     model = BiaffineParser(**model_args.data)
 
     try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
+        ModelLoader.load_pytorch(model, path)
         print('model parameter loaded!')
     except Exception as _:
         print("No saved model. Abort test.")
@@ -284,11 +335,12 @@ if __name__ == "__main__": | |||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | ||||
parser.add_argument('--path', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train() | |||||
train(args.path) | |||||
elif args.mode == 'test': | elif args.mode == 'test': | ||||
test() | |||||
test(args.path) | |||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
infer() | infer() | ||||
else: | else: | ||||