@@ -28,6 +28,12 @@ class Field(object): | |||
""" | |||
raise NotImplementedError | |||
def __repr__(self): | |||
return self.contents().__repr__() | |||
def new(self, *args, **kwargs): | |||
return self.__class__(*args, **kwargs, is_target=self.is_target) | |||
class TextField(Field): | |||
def __init__(self, name, text, is_target): | |||
""" | |||
@@ -35,6 +35,9 @@ class Instance(object): | |||
else: | |||
raise KeyError("{} not found".format(name)) | |||
def __setitem__(self, name, field): | |||
return self.add_field(name, field) | |||
def get_length(self): | |||
"""Fetch the length of all fields in the instance. | |||
@@ -82,3 +85,6 @@ class Instance(object): | |||
name, field_name = origin_len | |||
tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) | |||
return tensor_x, tensor_y | |||
def __repr__(self): | |||
return self.fields.__repr__() |
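The additions above give Field a readable repr and a new() helper that clones a field while keeping its is_target flag, and give Instance dict-style assignment plus a fields-based repr. A minimal, hypothetical usage sketch (values are illustrative; the calls mirror those made by the P1 helper later in this diff):

from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance

ins = Instance(word_seq=TextField(['I', 'like', 'apples'], is_target=False))
ins['pos_seq'] = TextField(['PRP', 'VBP', 'NNS'], is_target=False)   # __setitem__ -> add_field
print(ins)                                   # __repr__ prints the underlying fields dict
# new() with a single positional argument, as used by P1 below; is_target is preserved
ins['word_seq'] = ins['word_seq'].new(['I', 'like', 'pears'])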
@@ -17,9 +17,9 @@ class Tester(object): | |||
""" | |||
super(Tester, self).__init__()
"""
"default_args" provides default values for important settings.
An initialization argument in "kwargs" with the same key (name) overrides the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, an error is raised.
""" | |||
default_args = {"batch_size": 8, | |||
@@ -29,8 +29,8 @@ class Tester(object): | |||
"evaluator": Evaluator() | |||
}
"""
"required_args" is the collection of arguments that users must pass to the Tester explicitly.
This is used to warn users of essential settings for evaluation.
Note that "required_args" have no default values, so they are unrelated to "default_args".
""" | |||
required_args = {} | |||
@@ -74,16 +74,19 @@ class Tester(object): | |||
output_list = [] | |||
truth_list = [] | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') | |||
for batch_x, batch_y in data_iterator: | |||
with torch.no_grad(): | |||
with torch.no_grad(): | |||
for batch_x, batch_y in data_iterator: | |||
prediction = self.data_forward(network, batch_x) | |||
output_list.append(prediction) | |||
truth_list.append(batch_y) | |||
eval_results = self.evaluate(output_list, truth_list) | |||
output_list.append(prediction) | |||
truth_list.append(batch_y) | |||
eval_results = self.evaluate(output_list, truth_list) | |||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | |||
self.mode(network, is_test=False) | |||
self.metrics = eval_results | |||
return eval_results | |||
def mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -1,6 +1,6 @@ | |||
import os | |||
import time | |||
from datetime import timedelta | |||
from datetime import timedelta, datetime | |||
import torch | |||
from tensorboardX import SummaryWriter | |||
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger | |||
from fastNLP.saver.model_saver import ModelSaver | |||
logger = create_logger(__name__, "./train_test.log") | |||
logger.disabled = True | |||
class Trainer(object): | |||
"""Operations of training a model, including data loading, gradient descent, and validation. | |||
@@ -35,20 +35,21 @@ class Trainer(object): | |||
super(Trainer, self).__init__()
"""
"default_args" provides default values for important settings.
An initialization argument in "kwargs" with the same key (name) overrides the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, an error is raised.
""" | |||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||
"valid_step": 500, "eval_sort_key": 'acc', | |||
"loss": Loss(None), # used to pass type check | |||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||
"evaluator": Evaluator() | |||
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Note that "required_args" have no default values, so they are unrelated to "default_args".
""" | |||
required_args = {} | |||
@@ -70,16 +71,20 @@ class Trainer(object): | |||
else: | |||
# Trainer doesn't care about extra arguments | |||
pass | |||
print(default_args) | |||
print("Training Args {}".format(default_args)) | |||
logger.info("Training Args {}".format(default_args)) | |||
self.n_epochs = default_args["epochs"] | |||
self.batch_size = default_args["batch_size"] | |||
self.n_epochs = int(default_args["epochs"]) | |||
self.batch_size = int(default_args["batch_size"]) | |||
self.pickle_path = default_args["pickle_path"] | |||
self.validate = default_args["validate"] | |||
self.save_best_dev = default_args["save_best_dev"] | |||
self.use_cuda = default_args["use_cuda"] | |||
self.model_name = default_args["model_name"] | |||
self.print_every_step = default_args["print_every_step"] | |||
self.print_every_step = int(default_args["print_every_step"]) | |||
self.valid_step = int(default_args["valid_step"]) | |||
if self.validate is not None: | |||
assert self.valid_step > 0 | |||
self._model = None | |||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | |||
@@ -89,6 +94,8 @@ class Trainer(object): | |||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | |||
self._graph_summaried = False | |||
self._best_accuracy = 0.0 | |||
self.eval_sort_key = default_args['eval_sort_key'] | |||
self.validator = None | |||
def train(self, network, train_data, dev_data=None): | |||
"""General Training Procedure | |||
@@ -104,12 +111,17 @@ class Trainer(object): | |||
else: | |||
self._model = network | |||
print(self._model) | |||
# define Tester over dev data | |||
self.dev_data = None | |||
if self.validate: | |||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | |||
validator = self._create_validator(default_valid_args) | |||
logger.info("validator defined as {}".format(str(validator))) | |||
if self.validator is None: | |||
self.validator = self._create_validator(default_valid_args) | |||
logger.info("validator defined as {}".format(str(self.validator))) | |||
self.dev_data = dev_data | |||
# optimizer and loss | |||
self.define_optimizer() | |||
@@ -117,29 +129,33 @@ class Trainer(object): | |||
self.define_loss() | |||
logger.info("loss function defined as {}".format(str(self._loss_func))) | |||
# turn on network training mode | |||
self.mode(network, is_test=False) | |||
# main training procedure | |||
start = time.time() | |||
logger.info("training epochs started") | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
print("training epochs started " + self.start_time) | |||
logger.info("training epochs started " + self.start_time) | |||
epoch, iters = 1, 0 | |||
while(1): | |||
if self.n_epochs != -1 and epoch > self.n_epochs: | |||
break | |||
logger.info("training epoch {}".format(epoch)) | |||
# turn on network training mode | |||
self.mode(network, is_test=False) | |||
# prepare mini-batch iterator | |||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | |||
use_cuda=self.use_cuda) | |||
use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq') | |||
logger.info("prepared data iterator") | |||
# one forward and backward pass | |||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) | |||
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data) | |||
# validation | |||
if self.validate: | |||
if dev_data is None: | |||
raise RuntimeError( | |||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||
logger.info("validation started") | |||
validator.test(network, dev_data) | |||
self.valid_model() | |||
self.save_model(self._model, 'training_model_'+self.start_time) | |||
epoch += 1 | |||
def _train_step(self, data_iterator, network, **kwargs): | |||
"""Training process in one epoch. | |||
@@ -149,13 +165,17 @@ class Trainer(object): | |||
- start: time.time(), the starting time of this step.
- epoch: int, the current epoch index.
- step: int, the global step count carried over from previous epochs.
- dev_data: the validation set passed through from train().
"""
step = 0 | |||
step = kwargs['step'] | |||
for batch_x, batch_y in data_iterator: | |||
prediction = self.data_forward(network, batch_x) | |||
loss = self.get_loss(prediction, batch_y) | |||
self.grad_backward(loss) | |||
# if torch.rand(1).item() < 0.001: | |||
# print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step)) | |||
# for name, p in self._model.named_parameters(): | |||
# if p.requires_grad: | |||
# print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item())) | |||
self.update() | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | |||
@@ -166,7 +186,22 @@ class Trainer(object): | |||
kwargs["epoch"], step, loss.data, diff) | |||
print(print_output) | |||
logger.info(print_output) | |||
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0: | |||
self.valid_model() | |||
step += 1 | |||
return step | |||
def valid_model(self): | |||
if self.dev_data is None: | |||
raise RuntimeError( | |||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||
logger.info("validation started") | |||
res = self.validator.test(self._model, self.dev_data) | |||
if self.save_best_dev and self.best_eval_result(res): | |||
logger.info('save best result! {}'.format(res)) | |||
print('save best result! {}'.format(res)) | |||
self.save_model(self._model, 'best_model_'+self.start_time) | |||
return res | |||
def mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -180,11 +215,17 @@ class Trainer(object): | |||
else: | |||
model.train() | |||
def define_optimizer(self): | |||
def define_optimizer(self, optim=None): | |||
"""Define framework-specific optimizer specified by the models. | |||
""" | |||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||
if optim is not None: | |||
# optimizer constructed by user | |||
self._optimizer = optim | |||
elif self._optimizer is None: | |||
# optimizer constructed by proto | |||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||
return self._optimizer | |||
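A rough sketch of the new optional-optimizer hook (names and hyper-parameters are illustrative; Trainer's __init__ also creates a SummaryWriter under pickle_path):

import torch
from torch import nn
from fastNLP.core.trainer import Trainer

model = nn.Linear(4, 2)                                   # toy model, for illustration only
trainer = Trainer(epochs=1, batch_size=4, validate=False)
trainer.define_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))  # user-built optimizer wins
trainer.define_optimizer()   # no-op afterwards: self._optimizer is already set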
def update(self): | |||
"""Perform weight update on a model. | |||
@@ -217,6 +258,8 @@ class Trainer(object): | |||
:param truth: ground truth label vector | |||
:return: a scalar | |||
""" | |||
if isinstance(predict, dict) and isinstance(truth, dict): | |||
return self._loss_func(**predict, **truth) | |||
if len(truth) > 1: | |||
raise NotImplementedError("Not ready to handle multi-labels.") | |||
truth = list(truth.values())[0] if len(truth) > 0 else None | |||
@@ -241,13 +284,23 @@ class Trainer(object): | |||
raise ValueError("Please specify a loss function.") | |||
logger.info("The model didn't define loss, use Trainer's loss.") | |||
def best_eval_result(self, validator): | |||
def best_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
:param metrics: the validation results, either a dict of metrics, a (loss, metrics) tuple, or a single number.
:return: bool, True means the current results on the dev set are the best so far.
""" | |||
loss, accuracy = validator.metrics | |||
if isinstance(metrics, tuple): | |||
loss, metrics = metrics | |||
if isinstance(metrics, dict): | |||
if len(metrics) == 1: | |||
accuracy = list(metrics.values())[0] | |||
else: | |||
accuracy = metrics[self.eval_sort_key] | |||
else: | |||
accuracy = metrics | |||
if accuracy > self._best_accuracy: | |||
self._best_accuracy = accuracy | |||
return True | |||
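best_eval_result now accepts a (loss, metrics) tuple, a metrics dict (using eval_sort_key to pick the score when several metrics are present), or a bare number. A hedged illustration with made-up values, assuming a Trainer instance trainer configured with eval_sort_key='UAS':

trainer.best_eval_result({'UAS': 85.3, 'LAS': 80.1})   # dict with >1 metric -> picks metrics['UAS']
trainer.best_eval_result((0.42, {'acc': 0.91}))        # (loss, metrics) tuple is unpacked first
trainer.best_eval_result(0.93)                         # a plain scalar is compared directly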
@@ -268,6 +321,8 @@ class Trainer(object): | |||
def _create_validator(self, valid_args): | |||
raise NotImplementedError | |||
def set_validator(self, validator):
self.validator = validator
class SeqLabelTrainer(Trainer): | |||
"""Trainer for Sequence Labeling | |||
@@ -51,6 +51,12 @@ class Vocabulary(object): | |||
self.min_freq = min_freq | |||
self.word_count = {} | |||
self.has_default = need_default | |||
if self.has_default: | |||
self.padding_label = DEFAULT_PADDING_LABEL | |||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||
else: | |||
self.padding_label = None | |||
self.unknown_label = None | |||
self.word2idx = None | |||
self.idx2word = None | |||
@@ -77,12 +83,10 @@ class Vocabulary(object): | |||
""" | |||
if self.has_default: | |||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
self.padding_label = DEFAULT_PADDING_LABEL | |||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||
else: | |||
self.word2idx = {} | |||
self.padding_label = None | |||
self.unknown_label = None | |||
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | |||
if self.min_freq is not None: | |||
@@ -114,7 +118,7 @@ class Vocabulary(object): | |||
if w in self.word2idx: | |||
return self.word2idx[w] | |||
elif self.has_default: | |||
return self.word2idx[DEFAULT_UNKNOWN_LABEL] | |||
return self.word2idx[self.unknown_label] | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
@@ -134,6 +138,11 @@ class Vocabulary(object): | |||
return None | |||
return self.word2idx[self.unknown_label] | |||
def __setattr__(self, name, val):
self.__dict__[name] = val
# changing a special label invalidates the cached word2idx so the vocab is rebuilt lazily
if name in ("unknown_label", "padding_label"):
self.word2idx = None
@property | |||
@check_build_vocab | |||
def padding_idx(self): | |||
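With the default labels now set in __init__ and the new __setattr__ hook, renaming a special label (as the parser script below does with word_v.unknown_label = UNK) simply drops the cached word2idx so the vocabulary is rebuilt on the next lookup. A small sketch, assuming update() accepts a list of words and lookup triggers the lazy build:

from fastNLP.core.vocabulary import Vocabulary

word_v = Vocabulary(need_default=True, min_freq=2)
word_v.unknown_label = '<OOV>'            # __setattr__ clears word2idx; rebuilt lazily
word_v.update(['I', 'like', 'apples', 'apples'])
word_v['apples']                          # frequent enough, should get its own index
word_v['durian']                          # unseen, falls back to the index of word_v.unknown_label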
@@ -87,7 +87,6 @@ class DataSetLoader(BaseLoader): | |||
""" | |||
raise NotImplementedError | |||
@DataSet.set_reader('read_raw') | |||
class RawDataSetLoader(DataSetLoader): | |||
def __init__(self): | |||
@@ -103,7 +102,6 @@ class RawDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq_dataset(data) | |||
@DataSet.set_reader('read_pos') | |||
class POSDataSetLoader(DataSetLoader): | |||
"""Dataset Loader for POS Tag datasets. | |||
@@ -173,7 +171,6 @@ class POSDataSetLoader(DataSetLoader): | |||
""" | |||
return convert_seq2seq_dataset(data) | |||
@DataSet.set_reader('read_tokenize') | |||
class TokenizeDataSetLoader(DataSetLoader): | |||
""" | |||
@@ -233,7 +230,6 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq2seq_dataset(data) | |||
@DataSet.set_reader('read_class') | |||
class ClassDataSetLoader(DataSetLoader): | |||
"""Loader for classification data sets""" | |||
@@ -272,7 +268,6 @@ class ClassDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq2tag_dataset(data) | |||
@DataSet.set_reader('read_conll') | |||
class ConllLoader(DataSetLoader): | |||
"""loader for conll format files""" | |||
@@ -314,7 +309,6 @@ class ConllLoader(DataSetLoader): | |||
def convert(self, data): | |||
pass | |||
@DataSet.set_reader('read_lm') | |||
class LMDataSetLoader(DataSetLoader): | |||
"""Language Model Dataset Loader | |||
@@ -351,7 +345,6 @@ class LMDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
pass | |||
@DataSet.set_reader('read_people_daily') | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
""" | |||
@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader): | |||
def _load_glove(emb_file): | |||
"""Read file as a glove embedding | |||
file format: | |||
embeddings are split by line, | |||
for one embedding, word and numbers split by space | |||
Example:: | |||
@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader): | |||
if len(line) > 0: | |||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | |||
return emb | |||
@staticmethod | |||
def _load_pretrain(emb_file, emb_type): | |||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | |||
@@ -16,10 +16,9 @@ def mst(scores): | |||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||
""" | |||
length = scores.shape[0] | |||
min_score = -np.inf | |||
mask = np.zeros((length, length)) | |||
np.fill_diagonal(mask, -np.inf) | |||
scores = scores + mask | |||
min_score = scores.min() - 1 | |||
eye = np.eye(length) | |||
scores = scores * (1 - eye) + min_score * eye | |||
heads = np.argmax(scores, axis=1) | |||
heads[0] = 0 | |||
tokens = np.arange(1, length) | |||
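The hunk above replaces the -inf diagonal mask with a finite minimum score, presumably to keep infinities out of the downstream MST arithmetic; self-arcs are still never selected. The same masking in isolation, with toy numbers:

import numpy as np

scores = np.array([[0.9, 0.2, 0.1],
                   [0.3, 0.8, 0.4],
                   [0.5, 0.6, 0.7]])
min_score = scores.min() - 1                     # strictly below every real score
eye = np.eye(len(scores))
masked = scores * (1 - eye) + min_score * eye    # diagonal -> min_score, rest unchanged
heads = np.argmax(masked, axis=1)                # no token picks itself as head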
@@ -126,6 +125,8 @@ class GraphParser(nn.Module): | |||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||
_, seq_len, _ = arc_matrix.shape | |||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | |||
flip_mask = (seq_mask == 0).byte() | |||
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
_, heads = torch.max(matrix, dim=2) | |||
if seq_mask is not None: | |||
heads *= seq_mask.long() | |||
@@ -135,8 +136,15 @@ class GraphParser(nn.Module): | |||
batch_size, seq_len, _ = arc_matrix.shape | |||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||
ans = matrix.new_zeros(batch_size, seq_len).long() | |||
lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | |||
seq_mask[batch_idx, lens-1] = 0 | |||
for i, graph in enumerate(matrix): | |||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||
len_i = lens[i] | |||
if len_i == seq_len: | |||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||
else: | |||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||
if seq_mask is not None: | |||
ans *= seq_mask.long() | |||
return ans | |||
@@ -175,14 +183,19 @@ class LabelBilinear(nn.Module): | |||
def __init__(self, in1_features, in2_features, num_label, bias=True): | |||
super(LabelBilinear, self).__init__() | |||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | |||
self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||
self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||
def forward(self, x1, x2): | |||
output = self.bilinear(x1, x2) | |||
output += self.lin1(x1) + self.lin2(x2) | |||
output += self.lin(torch.cat([x1, x2], dim=2)) | |||
return output | |||
def len2masks(origin_len, max_len): | |||
if origin_len.dim() <= 1: | |||
origin_len = origin_len.unsqueeze(1) # [batch_size, 1] | |||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,] | |||
seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len] | |||
return seq_mask | |||
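len2masks expands a vector of true sequence lengths into a [batch_size, max_len] mask; forward() below uses it to rebuild seq_mask from word_seq_origin_len. For example:

import torch

origin_len = torch.tensor([3, 1, 2])
mask = len2masks(origin_len, max_len=4)
# row i has origin_len[i] leading ones: [[1,1,1,0], [1,0,0,0], [1,1,0,0]]
# (the dtype here is bool/uint8; forward() casts it with .long())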
class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implementation.
@@ -194,6 +207,8 @@ class BiaffineParser(GraphParser): | |||
word_emb_dim, | |||
pos_vocab_size, | |||
pos_emb_dim, | |||
word_hid_dim, | |||
pos_hid_dim, | |||
rnn_layers, | |||
rnn_hidden_size, | |||
arc_mlp_size, | |||
@@ -204,10 +219,15 @@ class BiaffineParser(GraphParser): | |||
use_greedy_infer=False): | |||
super(BiaffineParser, self).__init__() | |||
rnn_out_size = 2 * rnn_hidden_size | |||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | |||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | |||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | |||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | |||
self.word_norm = nn.LayerNorm(word_hid_dim) | |||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | |||
if use_var_lstm: | |||
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, | |||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
@@ -216,7 +236,7 @@ class BiaffineParser(GraphParser): | |||
hidden_dropout=dropout, | |||
bidirectional=True) | |||
else: | |||
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, | |||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
@@ -224,21 +244,35 @@ class BiaffineParser(GraphParser): | |||
dropout=dropout, | |||
bidirectional=True) | |||
rnn_out_size = 2 * rnn_hidden_size | |||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | |||
nn.ELU()) | |||
nn.LayerNorm(arc_mlp_size), | |||
nn.ELU(), | |||
TimestepDropout(p=dropout),) | |||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | |||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | |||
nn.ELU()) | |||
nn.LayerNorm(label_mlp_size), | |||
nn.ELU(), | |||
TimestepDropout(p=dropout),) | |||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | |||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||
self.normal_dropout = nn.Dropout(p=dropout) | |||
self.timestep_dropout = TimestepDropout(p=dropout) | |||
self.use_greedy_infer = use_greedy_infer | |||
initial_parameter(self) | |||
self.reset_parameters() | |||
self.explore_p = 0.2 | |||
def reset_parameters(self): | |||
for m in self.modules(): | |||
if isinstance(m, nn.Embedding): | |||
continue | |||
elif isinstance(m, nn.LayerNorm): | |||
nn.init.constant_(m.weight, 0.1) | |||
nn.init.constant_(m.bias, 0) | |||
else: | |||
for p in m.parameters(): | |||
nn.init.normal_(p, 0, 0.1) | |||
def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): | |||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): | |||
""" | |||
:param word_seq: [batch_size, seq_len] sequence of word indices
:param pos_seq: [batch_size, seq_len] sequence of POS tag indices
@@ -253,32 +287,35 @@ class BiaffineParser(GraphParser): | |||
# prepare embeddings | |||
batch_size, seq_len = word_seq.shape | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||
# get sequence mask | |||
seq_mask = seq_mask.long() | |||
seq_mask = len2masks(word_seq_origin_len, seq_len).long() | |||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||
word, pos = self.word_fc(word), self.pos_fc(pos) | |||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | |||
del word, pos | |||
# lstm, extract features | |||
x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True) | |||
feat, _ = self.lstm(x) # -> [N,L,C] | |||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||
# for arc biaffine | |||
# mlp, reduce dim | |||
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||
arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||
label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||
label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||
arc_dep = self.arc_dep_mlp(feat) | |||
arc_head = self.arc_head_mlp(feat) | |||
label_dep = self.label_dep_mlp(feat) | |||
label_head = self.label_head_mlp(feat) | |||
del feat | |||
# biaffine arc classifier | |||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||
flip_mask = (seq_mask == 0) | |||
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
# use gold or predicted arc to predict label | |||
if gold_heads is None: | |||
if gold_heads is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||
@@ -286,9 +323,15 @@ class BiaffineParser(GraphParser): | |||
heads = self._mst_decoder(arc_pred, seq_mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = gold_heads | |||
assert self.training # must be training mode | |||
if torch.rand(1).item() < self.explore_p: | |||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = gold_heads | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||
@@ -301,7 +344,7 @@ class BiaffineParser(GraphParser): | |||
Compute loss. | |||
:param arc_pred: [batch_size, seq_len, seq_len] | |||
:param label_pred: [batch_size, seq_len, seq_len] | |||
:param label_pred: [batch_size, seq_len, n_tags] | |||
:param head_indices: [batch_size, seq_len] | |||
:param head_labels: [batch_size, seq_len] | |||
:param seq_mask: [batch_size, seq_len] | |||
@@ -309,10 +352,13 @@ class BiaffineParser(GraphParser): | |||
""" | |||
batch_size, seq_len, _ = arc_pred.shape | |||
arc_logits = F.log_softmax(arc_pred, dim=2) | |||
flip_mask = (seq_mask == 0) | |||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||
label_logits = F.log_softmax(label_pred, dim=2) | |||
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) | |||
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) | |||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||
arc_loss = arc_logits[batch_index, child_index, head_indices] | |||
label_loss = label_logits[batch_index, child_index, head_labels] | |||
@@ -320,45 +366,8 @@ class BiaffineParser(GraphParser): | |||
label_loss = label_loss[:, 1:] | |||
float_mask = seq_mask[:, 1:].float() | |||
length = (seq_mask.sum() - batch_size).float() | |||
arc_nll = -(arc_loss*float_mask).sum() / length | |||
label_nll = -(label_loss*float_mask).sum() / length | |||
arc_nll = -(arc_loss*float_mask).mean() | |||
label_nll = -(label_loss*float_mask).mean() | |||
return arc_nll + label_nll | |||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||
""" | |||
Evaluate the performance of prediction. | |||
:return dict: performance results. | |||
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correctly predicted labels.
total_tokens: number of predicted tokens
""" | |||
if 'head_pred' in kwargs: | |||
head_pred = kwargs['head_pred'] | |||
elif self.use_greedy_infer: | |||
head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||
else: | |||
head_pred = self._mst_decoder(arc_pred, seq_mask) | |||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||
_, label_preds = torch.max(label_pred, dim=2) | |||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||
return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||
"label_pred_correct": label_pred_correct.sum(dim=1), | |||
"total_tokens": seq_mask.sum(dim=1)} | |||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||
""" | |||
Compute the metrics of model | |||
:param head_pred_correct: number of correctly predicted heads.
:param label_pred_correct: number of correctly predicted labels.
:param total_tokens: number of predicted tokens
:return dict: the metrics results | |||
UAS: the head predicted accuracy | |||
LAS: the label predicted accuracy | |||
""" | |||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||
@@ -1,5 +1,6 @@ | |||
import torch | |||
from torch import nn | |||
import math | |||
from fastNLP.modules.utils import mask_softmax | |||
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module): | |||
def _atten_forward(self, query, memory): | |||
raise NotImplementedError | |||
class DotAtte(nn.Module): | |||
def __init__(self, key_size, value_size): | |||
super(DotAtte, self).__init__() | |||
self.key_size = key_size | |||
self.value_size = value_size | |||
self.scale = math.sqrt(key_size) | |||
def forward(self, Q, K, V, seq_mask=None): | |||
""" | |||
:param Q: [batch, seq_len, key_size] | |||
:param K: [batch, seq_len, key_size] | |||
:param V: [batch, seq_len, value_size] | |||
:param seq_mask: [batch, seq_len] | |||
""" | |||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | |||
if seq_mask is not None: | |||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||
output = nn.functional.softmax(output, dim=2) | |||
return torch.matmul(output, V) | |||
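DotAtte implements scaled dot-product attention: softmax(Q·K^T / sqrt(key_size))·V, with optional masking of padded keys. A quick self-attention sketch with illustrative shapes (mask omitted here; see the seq_mask shape noted in the docstring):

import torch
from fastNLP.modules.aggregator.attention import DotAtte

atte = DotAtte(key_size=16, value_size=16)
x = torch.rand(2, 5, 16)          # [batch, seq_len, key_size]
out = atte(x, x, x)               # self-attention -> [2, 5, 16]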
class MultiHeadAtte(nn.Module): | |||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||
super(MultiHeadAtte, self).__init__() | |||
self.in_linear = nn.ModuleList() | |||
for i in range(num_atte * 3): | |||
out_feat = key_size if (i % 3) != 2 else value_size | |||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||
def forward(self, Q, K, V, seq_mask=None): | |||
heads = [] | |||
for i in range(len(self.attes)): | |||
j = i * 3 | |||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||
heads.append(headi) | |||
output = torch.cat(heads, dim=2) | |||
return self.out_linear(output) |
@@ -0,0 +1,32 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from ..aggregator.attention import MultiHeadAtte | |||
from ..other_modules import LayerNormalization | |||
class TransformerEncoder(nn.Module): | |||
class SubLayer(nn.Module): | |||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||
super(TransformerEncoder.SubLayer, self).__init__() | |||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||
self.norm1 = LayerNormalization(output_size) | |||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||
nn.ReLU(), | |||
nn.Linear(output_size, output_size)) | |||
self.norm2 = LayerNormalization(output_size) | |||
def forward(self, input, seq_mask=None):
# self-attention: the same tensor serves as query, key and value
attention = self.atte(input, input, input, seq_mask)
norm_atte = self.norm1(attention + input)
output = self.ffn(norm_atte)
return self.norm2(output + norm_atte)
def __init__(self, num_layers, **kargs):
super(TransformerEncoder, self).__init__()
# a ModuleList (rather than Sequential) so seq_mask can be passed to every sub-layer
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
def forward(self, x, seq_mask=None):
for layer in self.layers:
    x = layer(x, seq_mask)
return x
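A hedged usage sketch of the encoder above (sizes are arbitrary; input_size must equal output_size for the residual additions to line up):

import torch

enc = TransformerEncoder(num_layers=2, input_size=32, output_size=32,
                         key_size=16, value_size=16, num_atte=4)
x = torch.rand(2, 7, 32)          # [batch, seq_len, input_size]
out = enc(x)                      # -> [2, 7, 32]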
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module): | |||
mask_x = input.new_ones((batch_size, self.input_size)) | |||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | |||
mask_h = input.new_ones((batch_size, self.hidden_size)) | |||
mask_h_ones = input.new_ones((batch_size, self.hidden_size)) | |||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | |||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | |||
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) | |||
hidden_list = [] | |||
for layer in range(self.num_layers): | |||
output_list = [] | |||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||
for direction in range(self.num_directions): | |||
input_x = input if direction == 0 else flip(input, [0]) | |||
idx = self.num_directions * layer + direction | |||
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module): | |||
class LayerNormalization(nn.Module): | |||
""" Layer normalization module """ | |||
def __init__(self, d_hid, eps=1e-3): | |||
def __init__(self, layer_size, eps=1e-3): | |||
super(LayerNormalization, self).__init__() | |||
self.eps = eps | |||
self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) | |||
self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) | |||
self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True)) | |||
self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True)) | |||
def forward(self, z): | |||
if z.size(1) == 1: | |||
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module): | |||
mu = torch.mean(z, keepdim=True, dim=-1) | |||
sigma = torch.std(z, keepdim=True, dim=-1) | |||
ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) | |||
ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out) | |||
ln_out = (z - mu) / (sigma + self.eps) | |||
ln_out = ln_out * self.a_2 + self.b_2 | |||
return ln_out | |||
@@ -1,37 +1,40 @@ | |||
[train] | |||
epochs = 50 | |||
epochs = -1 | |||
batch_size = 16 | |||
pickle_path = "./save/" | |||
validate = true | |||
save_best_dev = false | |||
save_best_dev = true | |||
eval_sort_key = "UAS" | |||
use_cuda = true | |||
model_saved_path = "./save/" | |||
task = "parse" | |||
print_every_step = 20 | |||
use_golden_train=true | |||
[test] | |||
save_output = true | |||
validate_in_training = true | |||
save_dev_input = false | |||
save_loss = true | |||
batch_size = 16 | |||
batch_size = 64 | |||
pickle_path = "./save/" | |||
use_cuda = true | |||
task = "parse" | |||
[model] | |||
word_vocab_size = -1 | |||
word_emb_dim = 100 | |||
pos_vocab_size = -1 | |||
pos_emb_dim = 100 | |||
word_hid_dim = 100 | |||
pos_hid_dim = 100 | |||
rnn_layers = 3 | |||
rnn_hidden_size = 400 | |||
arc_mlp_size = 500 | |||
label_mlp_size = 100 | |||
num_label = -1 | |||
dropout = 0.33 | |||
use_var_lstm=true | |||
use_var_lstm=false | |||
use_greedy_infer=false | |||
[optim] | |||
lr = 2e-3 | |||
weight_decay = 5e-5 |
@@ -6,15 +6,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
from collections import defaultdict | |||
import math | |||
import torch | |||
import re | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.core.metrics import Evaluator | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.field import TextField, SeqLabelField | |||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||
from fastNLP.core.preprocess import load_pickle | |||
from fastNLP.core.tester import Tester | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.loader.model_loader import ModelLoader | |||
@@ -22,15 +24,18 @@ from fastNLP.loader.embed_loader import EmbedLoader | |||
from fastNLP.models.biaffine_parser import BiaffineParser | |||
from fastNLP.saver.model_saver import ModelSaver | |||
BOS = '<BOS>' | |||
EOS = '<EOS>' | |||
UNK = '<OOV>' | |||
NUM = '<NUM>' | |||
ENG = '<ENG>' | |||
# change to this script's directory if we are not already in it
if len(os.path.dirname(__file__)) != 0: | |||
os.chdir(os.path.dirname(__file__)) | |||
class MyDataLoader(object): | |||
def __init__(self, pickle_path): | |||
self.pickle_path = pickle_path | |||
def load(self, path, word_v=None, pos_v=None, headtag_v=None): | |||
class ConlluDataLoader(object): | |||
def load(self, path): | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
@@ -49,23 +54,18 @@ class MyDataLoader(object): | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_one(sample) | |||
if word_v is not None: | |||
word_v.update(res[0]) | |||
pos_v.update(res[1]) | |||
headtag_v.update(res[3]) | |||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | |||
pos_seq=TextField(res[1], is_target=False), | |||
head_indices=SeqLabelField(res[2], is_target=True), | |||
head_labels=TextField(res[3], is_target=True), | |||
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) | |||
head_labels=TextField(res[3], is_target=True))) | |||
return ds | |||
def get_one(self, sample): | |||
text = ['<root>'] | |||
pos_tags = ['<root>'] | |||
heads = [0] | |||
head_tags = ['root'] | |||
text = [] | |||
pos_tags = [] | |||
heads = [] | |||
head_tags = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
@@ -76,17 +76,60 @@ class MyDataLoader(object): | |||
head_tags.append(t4) | |||
return (text, pos_tags, heads, head_tags) | |||
def index_data(self, dataset, word_v, pos_v, tag_v): | |||
dataset.index_field('word_seq', word_v) | |||
dataset.index_field('pos_seq', pos_v) | |||
dataset.index_field('head_labels', tag_v) | |||
class CTBDataLoader(object): | |||
def load(self, data_path): | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
return self.convert(data) | |||
def parse(self, lines):
"""Parse the raw lines into samples, each of the form
[[words], [pos tags], [head indices], [head tags]].
"""
sample = [] | |||
data = [] | |||
for i, line in enumerate(lines): | |||
line = line.strip() | |||
if len(line) == 0 or i+1 == len(lines): | |||
data.append(list(map(list, zip(*sample)))) | |||
sample = [] | |||
else: | |||
sample.append(line.split()) | |||
return data | |||
def convert(self, data): | |||
dataset = DataSet() | |||
for sample in data: | |||
word_seq = [BOS] + sample[0] + [EOS] | |||
pos_seq = [BOS] + sample[1] + [EOS] | |||
heads = [0] + list(map(int, sample[2])) + [0] | |||
head_tags = [BOS] + sample[3] + [EOS] | |||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||
pos_seq=TextField(pos_seq, is_target=False), | |||
gold_heads=SeqLabelField(heads, is_target=False), | |||
head_indices=SeqLabelField(heads, is_target=True), | |||
head_labels=TextField(head_tags, is_target=True))) | |||
return dataset | |||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | |||
datadir = "/home/yfshao/UD_English-EWT" | |||
# datadir = "/home/yfshao/UD_English-EWT" | |||
# train_data_name = "en_ewt-ud-train.conllu" | |||
# dev_data_name = "en_ewt-ud-dev.conllu" | |||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||
# loader = ConlluDataLoader() | |||
datadir = '/home/yfshao/workdir/parser-data/' | |||
train_data_name = "train_ctb5.txt" | |||
dev_data_name = "dev_ctb5.txt" | |||
test_data_name = "test_ctb5.txt" | |||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||
loader = CTBDataLoader() | |||
cfgfile = './cfg.cfg' | |||
train_data_name = "en_ewt-ud-train.conllu" | |||
dev_data_name = "en_ewt-ud-dev.conllu" | |||
emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||
processed_datadir = './save' | |||
# Config Loader | |||
@@ -95,8 +138,12 @@ test_args = ConfigSection() | |||
model_args = ConfigSection() | |||
optim_args = ConfigSection() | |||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | |||
print('train Args:', train_args.data)
print('test Args:', test_args.data) | |||
print('optim Args:', optim_args.data) | |||
# Data Loader | |||
# Pickle Loader | |||
def save_data(dirpath, **kwargs): | |||
import _pickle | |||
if not os.path.exists(dirpath): | |||
@@ -117,38 +164,57 @@ def load_data(dirpath): | |||
datas[name] = _pickle.load(f) | |||
return datas | |||
class MyTester(object): | |||
def __init__(self, batch_size, use_cuda=False, **kwargs):
self.batch_size = batch_size | |||
self.use_cuda = use_cuda | |||
def test(self, model, dataset): | |||
self.model = model.cuda() if self.use_cuda else model | |||
self.model.eval() | |||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||
eval_res = defaultdict(list) | |||
i = 0 | |||
for batch_x, batch_y in batchiter: | |||
with torch.no_grad(): | |||
pred_y = self.model(**batch_x) | |||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||
i += self.batch_size | |||
for eval_name, tensor in eval_one.items(): | |||
eval_res[eval_name].append(tensor) | |||
tmp = {} | |||
for eval_name, tensorlist in eval_res.items(): | |||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||
self.res = self.model.metrics(**tmp) | |||
def show_metrics(self): | |||
s = "" | |||
for name, val in self.res.items(): | |||
s += '{}: {:.2f}\t'.format(name, val) | |||
return s | |||
loader = MyDataLoader('') | |||
def P2(data, field, length): | |||
ds = [ins for ins in data if ins[field].get_length() >= length] | |||
data.clear() | |||
data.extend(ds) | |||
return ds | |||
def P1(data, field): | |||
def reeng(w): | |||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||
def renum(w): | |||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||
for ins in data: | |||
ori = ins[field].contents() | |||
s = list(map(renum, map(reeng, ori))) | |||
if s != ori: | |||
# print(ori) | |||
# print(s) | |||
# print() | |||
ins[field] = ins[field].new(s) | |||
return data | |||
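P1 normalizes rare token shapes in place: purely alphabetic tokens (possibly containing '.' or '-') collapse to ENG and numerals to NUM, while BOS/EOS are left alone. A small illustration with the regexes copied from above (the fourth token is just an arbitrary Chinese word):

import re
reeng = lambda w: w if w in ('<BOS>', '<EOS>') or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else '<ENG>'
renum = lambda w: w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else '<NUM>'
[renum(reeng(w)) for w in ['<BOS>', 'apple-pie', '3.14', '中国']]
# -> ['<BOS>', '<ENG>', '<NUM>', '中国']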
class ParserEvaluator(Evaluator): | |||
def __init__(self, ignore_label): | |||
super(ParserEvaluator, self).__init__() | |||
self.ignore = ignore_label | |||
def __call__(self, predict_list, truth_list): | |||
head_all, label_all, total_all = 0, 0, 0 | |||
for pred, truth in zip(predict_list, truth_list): | |||
head, label, total = self.evaluate(**pred, **truth) | |||
head_all += head | |||
label_all += label | |||
total_all += total | |||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||
""" | |||
Evaluate the performance of prediction. | |||
:return: performance counts.
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correctly predicted labels.
total_tokens: number of evaluated tokens.
""" | |||
seq_mask *= (head_labels != self.ignore).long() | |||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||
_, label_preds = torch.max(label_pred, dim=2) | |||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||
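ParserEvaluator plugs into the refactored Tester through its evaluator argument, replacing the old MyTester/model.evaluate path. A hedged sketch, with model, dev_data and ignore_label as defined further down in this script and illustrative batch settings:

tester = Tester(batch_size=64, use_cuda=False, evaluator=ParserEvaluator(ignore_label))
eval_results = tester.test(model, dev_data)        # -> {'UAS': ..., 'LAS': ...}
print(tester.print_eval_results(eval_results))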
try: | |||
data_dict = load_data(processed_datadir) | |||
word_v = data_dict['word_v'] | |||
@@ -156,62 +222,90 @@ try: | |||
tag_v = data_dict['tag_v'] | |||
train_data = data_dict['train_data'] | |||
dev_data = data_dict['dev_data'] | |||
test_data = data_dict['test_data'] | |||
print('use saved pickles') | |||
except Exception as _: | |||
print('load raw data and preprocess') | |||
# use pretrain embedding | |||
word_v = Vocabulary(need_default=True, min_freq=2) | |||
word_v.unknown_label = UNK | |||
pos_v = Vocabulary(need_default=True) | |||
tag_v = Vocabulary(need_default=False) | |||
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | |||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||
datasets = (train_data, dev_data, test_data) | |||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||
loader.index_data(train_data, word_v, pos_v, tag_v) | |||
loader.index_data(dev_data, word_v, pos_v, tag_v) | |||
print(len(train_data)) | |||
print(len(dev_data)) | |||
ep = train_args['epochs'] | |||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||
print(len(word_v)) | |||
print(embed.size()) | |||
# Model | |||
model_args['word_vocab_size'] = len(word_v) | |||
model_args['pos_vocab_size'] = len(pos_v) | |||
model_args['num_label'] = len(tag_v) | |||
model = BiaffineParser(**model_args.data) | |||
model.reset_parameters() | |||
datasets = (train_data, dev_data, test_data) | |||
for ds in datasets: | |||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||
ds.set_origin_len('word_seq') | |||
if train_args['use_golden_train']: | |||
train_data.set_target(gold_heads=False) | |||
else: | |||
train_data.set_target(gold_heads=None) | |||
train_args.data.pop('use_golden_train') | |||
ignore_label = pos_v['P'] | |||
print(test_data[0]) | |||
print(len(train_data)) | |||
print(len(dev_data)) | |||
print(len(test_data)) | |||
def train(): | |||
def train(path): | |||
# Trainer | |||
trainer = Trainer(**train_args.data) | |||
def _define_optim(obj): | |||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||
lr = optim_args.data['lr'] | |||
embed_params = set(obj._model.word_embedding.parameters()) | |||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||
obj._optimizer = torch.optim.Adam([ | |||
{'params': list(embed_params), 'lr':lr*0.1}, | |||
{'params': list(decay_params), **optim_args.data}, | |||
{'params': params} | |||
], lr=lr, betas=(0.9, 0.9)) | |||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||
def _update(obj): | |||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||
obj._scheduler.step() | |||
obj._optimizer.step() | |||
trainer.define_optimizer = lambda: _define_optim(trainer) | |||
trainer.update = lambda: _update(trainer) | |||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||
# Model | |||
model = BiaffineParser(**model_args.data) | |||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||
# use pretrain embedding | |||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||
model.word_embedding.padding_idx = word_v.padding_idx | |||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | |||
model.pos_embedding.padding_idx = pos_v.padding_idx | |||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | |||
try: | |||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
print('model parameter loaded!') | |||
except Exception as _: | |||
print("No saved model. Continue.") | |||
pass | |||
# try: | |||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
# print('model parameter loaded!') | |||
# except Exception as _: | |||
# print("No saved model. Continue.") | |||
# pass | |||
# Start training | |||
trainer.train(model, train_data, dev_data) | |||
@@ -223,24 +317,27 @@ def train(): | |||
print("Model saved!") | |||
def test(): | |||
def test(path): | |||
# Tester | |||
tester = MyTester(**test_args.data) | |||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||
# Model | |||
model = BiaffineParser(**model_args.data) | |||
model.eval() | |||
try: | |||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
ModelLoader.load_pytorch(model, path) | |||
print('model parameter loaded!') | |||
except Exception as _: | |||
print("No saved model. Abort test.") | |||
raise | |||
# Start training | |||
print("Testing Train data") | |||
tester.test(model, train_data) | |||
print("Testing Dev data") | |||
tester.test(model, dev_data) | |||
print(tester.show_metrics()) | |||
print("Testing finished!") | |||
print("Testing Test data") | |||
tester.test(model, test_data) | |||
@@ -248,11 +345,12 @@ if __name__ == "__main__": | |||
import argparse | |||
parser = argparse.ArgumentParser(description='Run a biaffine dependency parser')
parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer'])
parser.add_argument('--path', type=str, default='') | |||
args = parser.parse_args() | |||
if args.mode == 'train': | |||
train() | |||
train(args.path) | |||
elif args.mode == 'test': | |||
test() | |||
test(args.path) | |||
elif args.mode == 'infer': | |||
infer() | |||
else: | |||