@@ -17,9 +17,9 @@ class Tester(object): | |||||
""" | """ | ||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
""" | """ | ||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | Otherwise, error will raise. | ||||
""" | """ | ||||
default_args = {"batch_size": 8, | default_args = {"batch_size": 8, | ||||
@@ -29,8 +29,8 @@ class Tester(object): | |||||
"evaluator": Evaluator() | "evaluator": Evaluator() | ||||
} | } | ||||
""" | """ | ||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | ||||
""" | """ | ||||
required_args = {} | required_args = {} | ||||
@@ -76,14 +76,17 @@ class Tester(object): | |||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | ||||
for batch_x, batch_y in data_iterator: | |||||
with torch.no_grad(): | |||||
with torch.no_grad(): | |||||
for batch_x, batch_y in data_iterator: | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
output_list.append(prediction) | |||||
truth_list.append(batch_y) | |||||
eval_results = self.evaluate(output_list, truth_list) | |||||
output_list.append(prediction) | |||||
truth_list.append(batch_y) | |||||
eval_results = self.evaluate(output_list, truth_list) | |||||
print("[tester] {}".format(self.print_eval_results(eval_results))) | print("[tester] {}".format(self.print_eval_results(eval_results))) | ||||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | ||||
self.mode(network, is_test=False) | |||||
self.metrics = eval_results | |||||
return eval_results | |||||
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -35,20 +35,21 @@ class Trainer(object): | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
""" | """ | ||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | Otherwise, error will raise. | ||||
""" | """ | ||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | ||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | ||||
"valid_step": 500, "eval_sort_key": None, | |||||
"loss": Loss(None), # used to pass type check | "loss": Loss(None), # used to pass type check | ||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
"evaluator": Evaluator() | "evaluator": Evaluator() | ||||
} | } | ||||
""" | """ | ||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | ||||
""" | """ | ||||
required_args = {} | required_args = {} | ||||
@@ -70,16 +71,20 @@ class Trainer(object): | |||||
else: | else: | ||||
# Trainer doesn't care about extra arguments | # Trainer doesn't care about extra arguments | ||||
pass | pass | ||||
print(default_args) | |||||
print("Training Args {}".format(default_args)) | |||||
logger.info("Training Args {}".format(default_args)) | |||||
self.n_epochs = default_args["epochs"] | |||||
self.batch_size = default_args["batch_size"] | |||||
self.n_epochs = int(default_args["epochs"]) | |||||
self.batch_size = int(default_args["batch_size"]) | |||||
self.pickle_path = default_args["pickle_path"] | self.pickle_path = default_args["pickle_path"] | ||||
self.validate = default_args["validate"] | self.validate = default_args["validate"] | ||||
self.save_best_dev = default_args["save_best_dev"] | self.save_best_dev = default_args["save_best_dev"] | ||||
self.use_cuda = default_args["use_cuda"] | self.use_cuda = default_args["use_cuda"] | ||||
self.model_name = default_args["model_name"] | self.model_name = default_args["model_name"] | ||||
self.print_every_step = default_args["print_every_step"] | |||||
self.print_every_step = int(default_args["print_every_step"]) | |||||
self.valid_step = int(default_args["valid_step"]) | |||||
if self.validate is not None: | |||||
assert self.valid_step > 0 | |||||
self._model = None | self._model = None | ||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | ||||
@@ -89,6 +94,8 @@ class Trainer(object): | |||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | ||||
self._graph_summaried = False | self._graph_summaried = False | ||||
self._best_accuracy = 0.0 | self._best_accuracy = 0.0 | ||||
self.eval_sort_key = default_args['eval_sort_key'] | |||||
self.validator = None | |||||
def train(self, network, train_data, dev_data=None): | def train(self, network, train_data, dev_data=None): | ||||
"""General Training Procedure | """General Training Procedure | ||||
@@ -108,8 +115,9 @@ class Trainer(object): | |||||
if self.validate: | if self.validate: | ||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | ||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | "use_cuda": self.use_cuda, "evaluator": self._evaluator} | ||||
validator = self._create_validator(default_valid_args) | |||||
logger.info("validator defined as {}".format(str(validator))) | |||||
if self.validator is None: | |||||
self.validator = self._create_validator(default_valid_args) | |||||
logger.info("validator defined as {}".format(str(self.validator))) | |||||
# optimizer and loss | # optimizer and loss | ||||
self.define_optimizer() | self.define_optimizer() | ||||
@@ -117,29 +125,31 @@ class Trainer(object): | |||||
self.define_loss() | self.define_loss() | ||||
logger.info("loss function defined as {}".format(str(self._loss_func))) | logger.info("loss function defined as {}".format(str(self._loss_func))) | ||||
# turn on network training mode | |||||
self.mode(network, is_test=False) | |||||
# main training procedure | # main training procedure | ||||
start = time.time() | start = time.time() | ||||
logger.info("training epochs started") | |||||
for epoch in range(1, self.n_epochs + 1): | |||||
self.start_time = str(start) | |||||
logger.info("training epochs started " + self.start_time) | |||||
epoch, iters = 1, 0 | |||||
while(1): | |||||
if self.n_epochs != -1 and epoch > self.n_epochs: | |||||
break | |||||
logger.info("training epoch {}".format(epoch)) | logger.info("training epoch {}".format(epoch)) | ||||
# turn on network training mode | |||||
self.mode(network, is_test=False) | |||||
# prepare mini-batch iterator | # prepare mini-batch iterator | ||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | ||||
use_cuda=self.use_cuda) | use_cuda=self.use_cuda) | ||||
logger.info("prepared data iterator") | logger.info("prepared data iterator") | ||||
# one forward and backward pass | # one forward and backward pass | ||||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) | |||||
iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data) | |||||
# validation | # validation | ||||
if self.validate: | if self.validate: | ||||
if dev_data is None: | |||||
raise RuntimeError( | |||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||||
logger.info("validation started") | |||||
validator.test(network, dev_data) | |||||
self.valid_model() | |||||
def _train_step(self, data_iterator, network, **kwargs): | def _train_step(self, data_iterator, network, **kwargs): | ||||
"""Training process in one epoch. | """Training process in one epoch. | ||||
@@ -149,7 +159,8 @@ class Trainer(object): | |||||
- start: time.time(), the starting time of this step. | - start: time.time(), the starting time of this step. | ||||
- epoch: int, | - epoch: int, | ||||
""" | """ | ||||
step = 0 | |||||
step = kwargs['step'] | |||||
dev_data = kwargs['dev_data'] | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
@@ -166,7 +177,21 @@ class Trainer(object): | |||||
kwargs["epoch"], step, loss.data, diff) | kwargs["epoch"], step, loss.data, diff) | ||||
print(print_output) | print(print_output) | ||||
logger.info(print_output) | logger.info(print_output) | ||||
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0: | |||||
self.valid_model() | |||||
step += 1 | step += 1 | ||||
return step | |||||
def valid_model(self): | |||||
if dev_data is None: | |||||
raise RuntimeError( | |||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||||
logger.info("validation started") | |||||
res = self.validator.test(network, dev_data) | |||||
if self.save_best_dev and self.best_eval_result(res): | |||||
logger.info('save best result! {}'.format(res)) | |||||
self.save_model(self._model, 'best_model_'+self.start_time) | |||||
return res | |||||
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -180,11 +205,17 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def define_optimizer(self): | |||||
def define_optimizer(self, optim=None): | |||||
"""Define framework-specific optimizer specified by the models. | """Define framework-specific optimizer specified by the models. | ||||
""" | """ | ||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||||
if optim is not None: | |||||
# optimizer constructed by user | |||||
self._optimizer = optim | |||||
elif self._optimizer is None: | |||||
# optimizer constructed by proto | |||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||||
return self._optimizer | |||||
def update(self): | def update(self): | ||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
@@ -217,6 +248,8 @@ class Trainer(object): | |||||
:param truth: ground truth label vector | :param truth: ground truth label vector | ||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
if isinstance(predict, dict) and isinstance(truth, dict): | |||||
return self._loss_func(**predict, **truth) | |||||
if len(truth) > 1: | if len(truth) > 1: | ||||
raise NotImplementedError("Not ready to handle multi-labels.") | raise NotImplementedError("Not ready to handle multi-labels.") | ||||
truth = list(truth.values())[0] if len(truth) > 0 else None | truth = list(truth.values())[0] if len(truth) > 0 else None | ||||
@@ -241,13 +274,27 @@ class Trainer(object): | |||||
raise ValueError("Please specify a loss function.") | raise ValueError("Please specify a loss function.") | ||||
logger.info("The model didn't define loss, use Trainer's loss.") | logger.info("The model didn't define loss, use Trainer's loss.") | ||||
def best_eval_result(self, validator): | |||||
def best_eval_result(self, metrics): | |||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
:param validator: a Tester instance | :param validator: a Tester instance | ||||
:return: bool, True means current results on dev set is the best. | :return: bool, True means current results on dev set is the best. | ||||
""" | """ | ||||
loss, accuracy = validator.metrics | |||||
if isinstance(metrics, tuple): | |||||
loss, metrics = metrics | |||||
else: | |||||
metrics = validator.metrics | |||||
if isinstance(metrics, dict): | |||||
if len(metrics) == 1: | |||||
accuracy = list(metrics.values())[0] | |||||
elif self.eval_sort_key is None: | |||||
raise ValueError('dict format metrics should provide sort key for eval best result') | |||||
else: | |||||
accuracy = metrics[self.eval_sort_key] | |||||
else: | |||||
accuracy = metrics | |||||
if accuracy > self._best_accuracy: | if accuracy > self._best_accuracy: | ||||
self._best_accuracy = accuracy | self._best_accuracy = accuracy | ||||
return True | return True | ||||
@@ -268,6 +315,8 @@ class Trainer(object): | |||||
def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def set_validator(self, validor): | |||||
self.validator = validor | |||||
class SeqLabelTrainer(Trainer): | class SeqLabelTrainer(Trainer): | ||||
"""Trainer for Sequence Labeling | """Trainer for Sequence Labeling | ||||
@@ -243,6 +243,9 @@ class BiaffineParser(GraphParser): | |||||
self.normal_dropout = nn.Dropout(p=dropout) | self.normal_dropout = nn.Dropout(p=dropout) | ||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
initial_parameter(self) | initial_parameter(self) | ||||
self.word_norm = nn.LayerNorm(word_emb_dim) | |||||
self.pos_norm = nn.LayerNorm(pos_emb_dim) | |||||
self.lstm_norm = nn.LayerNorm(rnn_out_size) | |||||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): | def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): | ||||
""" | """ | ||||
@@ -266,10 +269,12 @@ class BiaffineParser(GraphParser): | |||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | ||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | ||||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
# lstm, extract features | # lstm, extract features | ||||
feat, _ = self.lstm(x) # -> [N,L,C] | feat, _ = self.lstm(x) # -> [N,L,C] | ||||
feat = self.lstm_norm(feat) | |||||
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
@@ -292,6 +297,7 @@ class BiaffineParser(GraphParser): | |||||
heads = self._mst_decoder(arc_pred, seq_mask) | heads = self._mst_decoder(arc_pred, seq_mask) | ||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
assert self.training # must be training mode | |||||
head_pred = None | head_pred = None | ||||
heads = gold_heads | heads = gold_heads | ||||
@@ -331,40 +337,4 @@ class BiaffineParser(GraphParser): | |||||
label_nll = -(label_loss*float_mask).sum() / length | label_nll = -(label_loss*float_mask).sum() / length | ||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return dict: performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
if 'head_pred' in kwargs: | |||||
head_pred = kwargs['head_pred'] | |||||
elif self.use_greedy_infer: | |||||
head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||||
else: | |||||
head_pred = self._mst_decoder(arc_pred, seq_mask) | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||||
"label_pred_correct": label_pred_correct.sum(dim=1), | |||||
"total_tokens": seq_mask.sum(dim=1)} | |||||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||||
""" | |||||
Compute the metrics of model | |||||
:param head_pred_corrct: number of correct predicted heads. | |||||
:param label_pred_correct: number of correct predicted labels. | |||||
:param total_tokens: number of predicted tokens | |||||
:return dict: the metrics results | |||||
UAS: the head predicted accuracy | |||||
LAS: the label predicted accuracy | |||||
""" | |||||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||||
@@ -1,23 +1,25 @@ | |||||
[train] | [train] | ||||
epochs = -1 | epochs = -1 | ||||
<<<<<<< HEAD | |||||
batch_size = 16 | batch_size = 16 | ||||
======= | |||||
batch_size = 32 | |||||
>>>>>>> update biaffine | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = false | |||||
save_best_dev = true | |||||
eval_sort_key = "UAS" | |||||
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
task = "parse" | |||||
[test] | [test] | ||||
save_output = true | save_output = true | ||||
validate_in_training = true | validate_in_training = true | ||||
save_dev_input = false | save_dev_input = false | ||||
save_loss = true | save_loss = true | ||||
batch_size = 16 | |||||
batch_size = 64 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
use_cuda = true | use_cuda = true | ||||
task = "parse" | |||||
[model] | [model] | ||||
word_vocab_size = -1 | word_vocab_size = -1 | ||||
@@ -8,12 +8,14 @@ import math | |||||
import torch | import torch | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
@@ -111,9 +113,10 @@ class CTBDataLoader(object): | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | # emb_file_name = '/home/yfshao/glove.6B.100d.txt' | ||||
# loader = ConlluDataLoader() | # loader = ConlluDataLoader() | ||||
datadir = "/home/yfshao/parser-data" | |||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | train_data_name = "train_ctb5.txt" | ||||
dev_data_name = "dev_ctb5.txt" | dev_data_name = "dev_ctb5.txt" | ||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt" | emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt" | ||||
loader = CTBDataLoader() | loader = CTBDataLoader() | ||||
@@ -148,37 +151,33 @@ def load_data(dirpath): | |||||
datas[name] = _pickle.load(f) | datas[name] = _pickle.load(f) | ||||
return datas | return datas | ||||
class MyTester(object): | |||||
def __init__(self, batch_size, use_cuda=False, **kwagrs): | |||||
self.batch_size = batch_size | |||||
self.use_cuda = use_cuda | |||||
def test(self, model, dataset): | |||||
self.model = model.cuda() if self.use_cuda else model | |||||
self.model.eval() | |||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
eval_res = defaultdict(list) | |||||
i = 0 | |||||
for batch_x, batch_y in batchiter: | |||||
with torch.no_grad(): | |||||
pred_y = self.model(**batch_x) | |||||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
i += self.batch_size | |||||
for eval_name, tensor in eval_one.items(): | |||||
eval_res[eval_name].append(tensor) | |||||
tmp = {} | |||||
for eval_name, tensorlist in eval_res.items(): | |||||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
self.res = self.model.metrics(**tmp) | |||||
print(self.show_metrics()) | |||||
def show_metrics(self): | |||||
s = "" | |||||
for name, val in self.res.items(): | |||||
s += '{}: {:.2f}\t'.format(name, val) | |||||
return s | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self): | |||||
super(ParserEvaluator, self).__init__() | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
try: | try: | ||||
data_dict = load_data(processed_datadir) | data_dict = load_data(processed_datadir) | ||||
@@ -196,6 +195,7 @@ except Exception as _: | |||||
tag_v = Vocabulary(need_default=False) | tag_v = Vocabulary(need_default=False) | ||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | train_data = loader.load(os.path.join(datadir, train_data_name)) | ||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | dev_data = loader.load(os.path.join(datadir, dev_data_name)) | ||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | ||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | ||||
@@ -207,8 +207,6 @@ dev_data.set_origin_len("word_seq") | |||||
print(train_data[:3]) | print(train_data[:3]) | ||||
print(len(train_data)) | print(len(train_data)) | ||||
print(len(dev_data)) | print(len(dev_data)) | ||||
ep = train_args['epochs'] | |||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
model_args['pos_vocab_size'] = len(pos_v) | model_args['pos_vocab_size'] = len(pos_v) | ||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
@@ -220,7 +218,7 @@ def train(): | |||||
def _define_optim(obj): | def _define_optim(obj): | ||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | ||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | def _update(obj): | ||||
obj._scheduler.step() | obj._scheduler.step() | ||||
@@ -228,8 +226,7 @@ def train(): | |||||
trainer.define_optimizer = lambda: _define_optim(trainer) | trainer.define_optimizer = lambda: _define_optim(trainer) | ||||
trainer.update = lambda: _update(trainer) | trainer.update = lambda: _update(trainer) | ||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator())) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
@@ -238,6 +235,7 @@ def train(): | |||||
word_v.unknown_label = "<OOV>" | word_v.unknown_label = "<OOV>" | ||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | ||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | ||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
@@ -262,7 +260,7 @@ def train(): | |||||
def test(): | def test(): | ||||
# Tester | # Tester | ||||
tester = MyTester(**test_args.data) | |||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator()) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
@@ -275,9 +273,10 @@ def test(): | |||||
raise | raise | ||||
# Start training | # Start training | ||||
print("Testing Dev data") | |||||
tester.test(model, dev_data) | tester.test(model, dev_data) | ||||
print(tester.show_metrics()) | |||||
print("Testing finished!") | |||||
print("Testing Test data") | |||||
tester.test(model, test_data) | |||||