@@ -28,6 +28,12 @@ class Field(object):
        """
        raise NotImplementedError

+    def __repr__(self):
+        return self.contents().__repr__()
+
+    def new(self, *args, **kwargs):
+        return self.__class__(*args, **kwargs, is_target=self.is_target)
+
class TextField(Field):
    def __init__(self, name, text, is_target):
        """
@@ -35,6 +35,9 @@ class Instance(object):
        else:
            raise KeyError("{} not found".format(name))

+    def __setitem__(self, name, field):
+        return self.add_field(name, field)
+
    def get_length(self):
        """Fetch the length of all fields in the instance.
@@ -82,3 +85,6 @@ class Instance(object):
            name, field_name = origin_len
            tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
        return tensor_x, tensor_y
+
+    def __repr__(self):
+        return self.fields.__repr__()
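A minimal standalone sketch of the pattern these dunder additions enable. The classes below are toy stand-ins (not the real fastNLP `Field`/`Instance`): `__setitem__` delegates to `add_field`, `__repr__` delegates to the wrapped data, and `new()` clones a field while preserving `is_target`.

```python
# Toy illustration of the Field/Instance dunder additions above.
class ToyField:
    def __init__(self, contents, is_target=False):
        self._contents = contents
        self.is_target = is_target

    def contents(self):
        return self._contents

    def __repr__(self):
        # delegate to the underlying data, as in the patch
        return self._contents.__repr__()

    def new(self, *args, **kwargs):
        # build a sibling field, preserving the is_target flag
        return self.__class__(*args, **kwargs, is_target=self.is_target)


class ToyInstance:
    def __init__(self, **fields):
        self.fields = dict(fields)

    def add_field(self, name, field):
        self.fields[name] = field

    def __setitem__(self, name, field):
        return self.add_field(name, field)

    def __repr__(self):
        return self.fields.__repr__()


ins = ToyInstance(word_seq=ToyField(["I", "like", "NLP"]))
ins["pos_seq"] = ToyField(["PN", "VV", "NN"])   # uses __setitem__
print(ins)                                      # uses the new __repr__ chain
```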
@@ -17,9 +17,9 @@ class Tester(object): | |||||
""" | """ | ||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
""" | """ | ||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | Otherwise, error will raise. | ||||
""" | """ | ||||
default_args = {"batch_size": 8, | default_args = {"batch_size": 8, | ||||
@@ -29,8 +29,8 @@ class Tester(object): | |||||
"evaluator": Evaluator() | "evaluator": Evaluator() | ||||
} | } | ||||
""" | """ | ||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | ||||
""" | """ | ||||
required_args = {} | required_args = {} | ||||
@@ -74,16 +74,19 @@ class Tester(object): | |||||
output_list = [] | output_list = [] | ||||
truth_list = [] | truth_list = [] | ||||
-       data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+       data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
-       for batch_x, batch_y in data_iterator:
-           with torch.no_grad():
+       with torch.no_grad():
+           for batch_x, batch_y in data_iterator:
                prediction = self.data_forward(network, batch_x)
                output_list.append(prediction)
                truth_list.append(batch_y)
        eval_results = self.evaluate(output_list, truth_list)
print("[tester] {}".format(self.print_eval_results(eval_results))) | print("[tester] {}".format(self.print_eval_results(eval_results))) | ||||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | ||||
+       self.mode(network, is_test=False)
+       self.metrics = eval_results
+       return eval_results
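The Tester change above wraps the whole dev pass in one `torch.no_grad()` block, collects predictions and gold labels, evaluates once at the end, and switches the model back to train mode. A hedged generic sketch of that pattern (placeholder names, not the fastNLP API):

```python
import torch

def run_eval(model, batches, evaluate):
    """Evaluate `model` on an iterable of (batch_x, batch_y) dicts."""
    model.eval()
    outputs, truths = [], []
    with torch.no_grad():                 # disable autograd for the whole pass
        for batch_x, batch_y in batches:
            outputs.append(model(**batch_x))
            truths.append(batch_y)
    results = evaluate(outputs, truths)   # aggregate once, after the loop
    model.train()                         # restore training mode, as the Tester does
    return results
```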
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -1,6 +1,6 @@ | |||||
import os | import os | ||||
import time | import time | ||||
-from datetime import timedelta
+from datetime import timedelta, datetime
import torch | import torch | ||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger | |||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
logger = create_logger(__name__, "./train_test.log") | logger = create_logger(__name__, "./train_test.log") | ||||
logger.disabled = True | |||||
class Trainer(object): | class Trainer(object): | ||||
"""Operations of training a model, including data loading, gradient descent, and validation. | """Operations of training a model, including data loading, gradient descent, and validation. | ||||
@@ -35,20 +35,21 @@ class Trainer(object): | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
""" | """ | ||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | Otherwise, error will raise. | ||||
""" | """ | ||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | ||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | ||||
"valid_step": 500, "eval_sort_key": 'acc', | |||||
"loss": Loss(None), # used to pass type check | "loss": Loss(None), # used to pass type check | ||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
"evaluator": Evaluator() | "evaluator": Evaluator() | ||||
} | } | ||||
""" | """ | ||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | ||||
""" | """ | ||||
required_args = {} | required_args = {} | ||||
@@ -70,16 +71,20 @@ class Trainer(object): | |||||
else: | else: | ||||
# Trainer doesn't care about extra arguments | # Trainer doesn't care about extra arguments | ||||
pass | pass | ||||
-       print(default_args)
+       print("Training Args {}".format(default_args))
+       logger.info("Training Args {}".format(default_args))

-       self.n_epochs = default_args["epochs"]
-       self.batch_size = default_args["batch_size"]
+       self.n_epochs = int(default_args["epochs"])
+       self.batch_size = int(default_args["batch_size"])
self.pickle_path = default_args["pickle_path"] | self.pickle_path = default_args["pickle_path"] | ||||
self.validate = default_args["validate"] | self.validate = default_args["validate"] | ||||
self.save_best_dev = default_args["save_best_dev"] | self.save_best_dev = default_args["save_best_dev"] | ||||
self.use_cuda = default_args["use_cuda"] | self.use_cuda = default_args["use_cuda"] | ||||
self.model_name = default_args["model_name"] | self.model_name = default_args["model_name"] | ||||
-       self.print_every_step = default_args["print_every_step"]
+       self.print_every_step = int(default_args["print_every_step"])
+       self.valid_step = int(default_args["valid_step"])
+       if self.validate is not None:
+           assert self.valid_step > 0
self._model = None | self._model = None | ||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | ||||
@@ -89,6 +94,8 @@ class Trainer(object): | |||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | ||||
self._graph_summaried = False | self._graph_summaried = False | ||||
self._best_accuracy = 0.0 | self._best_accuracy = 0.0 | ||||
+       self.eval_sort_key = default_args['eval_sort_key']
+       self.validator = None
def train(self, network, train_data, dev_data=None): | def train(self, network, train_data, dev_data=None): | ||||
"""General Training Procedure | """General Training Procedure | ||||
@@ -104,12 +111,17 @@ class Trainer(object): | |||||
else: | else: | ||||
self._model = network | self._model = network | ||||
+       print(self._model)

        # define Tester over dev data
+       self.dev_data = None
if self.validate: | if self.validate: | ||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | ||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | "use_cuda": self.use_cuda, "evaluator": self._evaluator} | ||||
-           validator = self._create_validator(default_valid_args)
-           logger.info("validator defined as {}".format(str(validator)))
+           if self.validator is None:
+               self.validator = self._create_validator(default_valid_args)
+           logger.info("validator defined as {}".format(str(self.validator)))
+           self.dev_data = dev_data
# optimizer and loss | # optimizer and loss | ||||
self.define_optimizer() | self.define_optimizer() | ||||
@@ -117,29 +129,33 @@ class Trainer(object): | |||||
self.define_loss() | self.define_loss() | ||||
logger.info("loss function defined as {}".format(str(self._loss_func))) | logger.info("loss function defined as {}".format(str(self._loss_func))) | ||||
-       # turn on network training mode
-       self.mode(network, is_test=False)

        # main training procedure
        start = time.time()
-       logger.info("training epochs started")
-       for epoch in range(1, self.n_epochs + 1):
+       self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
+       print("training epochs started " + self.start_time)
+       logger.info("training epochs started " + self.start_time)
+       epoch, iters = 1, 0
+       while(1):
+           if self.n_epochs != -1 and epoch > self.n_epochs:
+               break
logger.info("training epoch {}".format(epoch)) | logger.info("training epoch {}".format(epoch)) | ||||
+           # turn on network training mode
+           self.mode(network, is_test=False)
# prepare mini-batch iterator | # prepare mini-batch iterator | ||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | ||||
-                                 use_cuda=self.use_cuda)
+                                 use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
logger.info("prepared data iterator") | logger.info("prepared data iterator") | ||||
# one forward and backward pass | # one forward and backward pass | ||||
-           self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)
+           iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
# validation | # validation | ||||
if self.validate: | if self.validate: | ||||
-               if dev_data is None:
-                   raise RuntimeError(
-                       "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
-               logger.info("validation started")
-               validator.test(network, dev_data)
+               self.valid_model()
+           self.save_model(self._model, 'training_model_'+self.start_time)
+           epoch += 1
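The epoch loop now runs open-ended: `epochs = -1` means "train until stopped", and the running step counter is threaded through `_train_step` so that step-based validation (`valid_step`) keeps counting across epochs. A hedged sketch of that control flow with placeholder callables:

```python
def fit(n_epochs, train_one_epoch, validate=None):
    """Run epochs until n_epochs is reached, or indefinitely when n_epochs == -1."""
    epoch, step = 1, 0
    while True:
        if n_epochs != -1 and epoch > n_epochs:
            break
        # step-based validation (every valid_step updates) would live inside train_one_epoch
        step = train_one_epoch(epoch=epoch, start_step=step)
        if validate is not None:
            validate()          # per-epoch validation, as valid_model() does above
        epoch += 1
    return step
```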
def _train_step(self, data_iterator, network, **kwargs): | def _train_step(self, data_iterator, network, **kwargs): | ||||
"""Training process in one epoch. | """Training process in one epoch. | ||||
@@ -149,13 +165,17 @@ class Trainer(object): | |||||
- start: time.time(), the starting time of this step. | - start: time.time(), the starting time of this step. | ||||
- epoch: int, | - epoch: int, | ||||
""" | """ | ||||
-       step = 0
+       step = kwargs['step']
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
loss = self.get_loss(prediction, batch_y) | loss = self.get_loss(prediction, batch_y) | ||||
self.grad_backward(loss) | self.grad_backward(loss) | ||||
# if torch.rand(1).item() < 0.001: | |||||
# print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step)) | |||||
# for name, p in self._model.named_parameters(): | |||||
# if p.requires_grad: | |||||
# print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item())) | |||||
self.update() | self.update() | ||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | ||||
@@ -166,7 +186,22 @@ class Trainer(object): | |||||
kwargs["epoch"], step, loss.data, diff) | kwargs["epoch"], step, loss.data, diff) | ||||
print(print_output) | print(print_output) | ||||
logger.info(print_output) | logger.info(print_output) | ||||
+           if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
+               self.valid_model()
            step += 1
+       return step
+   def valid_model(self):
+       if self.dev_data is None:
+           raise RuntimeError(
+               "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
+       logger.info("validation started")
+       res = self.validator.test(self._model, self.dev_data)
+       if self.save_best_dev and self.best_eval_result(res):
+           logger.info('save best result! {}'.format(res))
+           print('save best result! {}'.format(res))
+           self.save_model(self._model, 'best_model_'+self.start_time)
+       return res
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -180,11 +215,17 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
-   def define_optimizer(self):
+   def define_optimizer(self, optim=None):
        """Define framework-specific optimizer specified by the models.
        """
-       self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
+       if optim is not None:
+           # optimizer constructed by user
+           self._optimizer = optim
+       elif self._optimizer is None:
+           # optimizer constructed by proto
+           self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
+       return self._optimizer
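`define_optimizer` now prefers an optimizer the caller built themselves and only falls back to the prototype. A hedged sketch of the same "user-supplied wins, otherwise construct from defaults" rule, with placeholder names:

```python
import torch

def make_optimizer(model, proto_cls=torch.optim.Adam, user_optim=None, **proto_kwargs):
    if user_optim is not None:
        return user_optim                                # caller constructed it explicitly
    return proto_cls(model.parameters(), **proto_kwargs)  # fall back to the prototype

# usage against the new signature above (illustrative):
# trainer.define_optimizer(optim=torch.optim.SGD(model.parameters(), lr=0.1))
```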
def update(self): | def update(self): | ||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
@@ -217,6 +258,8 @@ class Trainer(object): | |||||
:param truth: ground truth label vector | :param truth: ground truth label vector | ||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
+       if isinstance(predict, dict) and isinstance(truth, dict):
+           return self._loss_func(**predict, **truth)
if len(truth) > 1: | if len(truth) > 1: | ||||
raise NotImplementedError("Not ready to handle multi-labels.") | raise NotImplementedError("Not ready to handle multi-labels.") | ||||
truth = list(truth.values())[0] if len(truth) > 0 else None | truth = list(truth.values())[0] if len(truth) > 0 else None | ||||
@@ -241,13 +284,23 @@ class Trainer(object): | |||||
raise ValueError("Please specify a loss function.") | raise ValueError("Please specify a loss function.") | ||||
logger.info("The model didn't define loss, use Trainer's loss.") | logger.info("The model didn't define loss, use Trainer's loss.") | ||||
-   def best_eval_result(self, validator):
+   def best_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.

        :param validator: a Tester instance
        :return: bool, True means current results on dev set is the best.
        """
-       loss, accuracy = validator.metrics
+       if isinstance(metrics, tuple):
+           loss, metrics = metrics
+
+       if isinstance(metrics, dict):
+           if len(metrics) == 1:
+               accuracy = list(metrics.values())[0]
+           else:
+               accuracy = metrics[self.eval_sort_key]
+       else:
+           accuracy = metrics
if accuracy > self._best_accuracy: | if accuracy > self._best_accuracy: | ||||
self._best_accuracy = accuracy | self._best_accuracy = accuracy | ||||
return True | return True | ||||
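The reworked `best_eval_result` accepts several result shapes and, when more than one metric is reported, compares on `eval_sort_key`. A small illustrative helper that mirrors that selection rule (not the Trainer method itself):

```python
def pick_score(metrics, eval_sort_key="acc"):
    if isinstance(metrics, tuple):        # (loss, metrics) pairs from older testers
        _, metrics = metrics
    if isinstance(metrics, dict):
        if len(metrics) == 1:
            return next(iter(metrics.values()))
        return metrics[eval_sort_key]     # e.g. "UAS" with the parser config below
    return metrics                        # already a scalar

assert pick_score({"UAS": 85.1, "LAS": 82.3}, "UAS") == 85.1
```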
@@ -268,6 +321,8 @@ class Trainer(object): | |||||
def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
+   def set_validator(self, validor):
+       self.validator = validor
class SeqLabelTrainer(Trainer): | class SeqLabelTrainer(Trainer): | ||||
"""Trainer for Sequence Labeling | """Trainer for Sequence Labeling | ||||
@@ -51,6 +51,12 @@ class Vocabulary(object): | |||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = {} | self.word_count = {} | ||||
self.has_default = need_default | self.has_default = need_default | ||||
+       if self.has_default:
+           self.padding_label = DEFAULT_PADDING_LABEL
+           self.unknown_label = DEFAULT_UNKNOWN_LABEL
+       else:
+           self.padding_label = None
+           self.unknown_label = None
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
@@ -77,12 +83,10 @@ class Vocabulary(object): | |||||
""" | """ | ||||
if self.has_default: | if self.has_default: | ||||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | ||||
-           self.padding_label = DEFAULT_PADDING_LABEL
-           self.unknown_label = DEFAULT_UNKNOWN_LABEL
+           self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
+           self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
        else:
            self.word2idx = {}
-           self.padding_label = None
-           self.unknown_label = None
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | ||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
@@ -114,7 +118,7 @@ class Vocabulary(object): | |||||
if w in self.word2idx: | if w in self.word2idx: | ||||
return self.word2idx[w] | return self.word2idx[w] | ||||
elif self.has_default: | elif self.has_default: | ||||
-           return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+           return self.word2idx[self.unknown_label]
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@@ -134,6 +138,11 @@ class Vocabulary(object): | |||||
return None | return None | ||||
return self.word2idx[self.unknown_label] | return self.word2idx[self.unknown_label] | ||||
+   def __setattr__(self, name, val):
+       self.__dict__[name] = val
+       if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
+           self.word2idx = None
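The `__setattr__` hook above is a cache-invalidation trick: reassigning `unknown_label` or `padding_label` drops the built `word2idx`, so the vocabulary is rebuilt lazily with the new special tokens. A toy sketch of the idea (not the fastNLP `Vocabulary` class):

```python
class ToyVocab:
    def __setattr__(self, name, val):
        self.__dict__[name] = val
        if name in ("unknown_label", "padding_label"):
            self.__dict__["word2idx"] = None   # invalidate the built index

v = ToyVocab()
v.word2idx = {"<unk>": 0}
v.unknown_label = "<oov>"
assert v.word2idx is None                      # index must be rebuilt on next lookup
```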
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
@@ -87,7 +87,6 @@ class DataSetLoader(BaseLoader): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | @DataSet.set_reader('read_raw') | ||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -103,7 +102,6 @@ class RawDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | @DataSet.set_reader('read_pos') | ||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -173,7 +171,6 @@ class POSDataSetLoader(DataSetLoader): | |||||
""" | """ | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | @DataSet.set_reader('read_tokenize') | ||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -233,7 +230,6 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | @DataSet.set_reader('read_class') | ||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -272,7 +268,6 @@ class ClassDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | @DataSet.set_reader('read_conll') | ||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -314,7 +309,6 @@ class ConllLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | @DataSet.set_reader('read_lm') | ||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -351,7 +345,6 @@ class LMDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | @DataSet.set_reader('read_people_daily') | ||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader): | |||||
def _load_glove(emb_file): | def _load_glove(emb_file): | ||||
"""Read file as a glove embedding | """Read file as a glove embedding | ||||
        file format:
            embeddings are split by line,
for one embedding, word and numbers split by space | for one embedding, word and numbers split by space | ||||
Example:: | Example:: | ||||
@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader): | |||||
if len(line) > 0: | if len(line) > 0: | ||||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | ||||
return emb | return emb | ||||
@staticmethod | @staticmethod | ||||
def _load_pretrain(emb_file, emb_type): | def _load_pretrain(emb_file, emb_type): | ||||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | """Read txt data from embedding file and convert to np.array as pre-trained embedding | ||||
@@ -16,10 +16,9 @@ def mst(scores): | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | ||||
""" | """ | ||||
length = scores.shape[0] | length = scores.shape[0] | ||||
-   min_score = -np.inf
-   mask = np.zeros((length, length))
-   np.fill_diagonal(mask, -np.inf)
-   scores = scores + mask
+   min_score = scores.min() - 1
+   eye = np.eye(length)
+   scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1) | heads = np.argmax(scores, axis=1) | ||||
heads[0] = 0 | heads[0] = 0 | ||||
tokens = np.arange(1, length) | tokens = np.arange(1, length) | ||||
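The change above replaces an additive `-inf` diagonal mask with a finite sentinel: self-arcs get a score strictly below the current minimum, which keeps them out of `argmax` without injecting infinities into later arithmetic. A hedged numpy sketch of just that masking step:

```python
import numpy as np

def mask_self_arcs(scores):
    length = scores.shape[0]
    min_score = scores.min() - 1          # finite, but never chosen by argmax
    eye = np.eye(length)
    return scores * (1 - eye) + min_score * eye

s = np.array([[0.9, 0.1],
              [0.4, 0.8]])
print(mask_self_arcs(s))                  # diagonal entries replaced by -0.9
```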
@@ -126,6 +125,8 @@ class GraphParser(nn.Module): | |||||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | def _greedy_decoder(self, arc_matrix, seq_mask=None): | ||||
_, seq_len, _ = arc_matrix.shape | _, seq_len, _ = arc_matrix.shape | ||||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | ||||
+       flip_mask = (seq_mask == 0).byte()
+       matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2) | _, heads = torch.max(matrix, dim=2) | ||||
if seq_mask is not None: | if seq_mask is not None: | ||||
heads *= seq_mask.long() | heads *= seq_mask.long() | ||||
@@ -135,8 +136,15 @@ class GraphParser(nn.Module): | |||||
batch_size, seq_len, _ = arc_matrix.shape | batch_size, seq_len, _ = arc_matrix.shape | ||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | ||||
ans = matrix.new_zeros(batch_size, seq_len).long() | ans = matrix.new_zeros(batch_size, seq_len).long() | ||||
+       lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
+       batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
+       seq_mask[batch_idx, lens-1] = 0
        for i, graph in enumerate(matrix):
-           ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+           len_i = lens[i]
+           if len_i == seq_len:
+               ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
+           else:
+               ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if seq_mask is not None: | if seq_mask is not None: | ||||
ans *= seq_mask.long() | ans *= seq_mask.long() | ||||
return ans | return ans | ||||
@@ -175,14 +183,19 @@ class LabelBilinear(nn.Module): | |||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | def __init__(self, in1_features, in2_features, num_label, bias=True): | ||||
super(LabelBilinear, self).__init__() | super(LabelBilinear, self).__init__() | ||||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | ||||
-       self.lin1 = nn.Linear(in1_features, num_label, bias=False)
-       self.lin2 = nn.Linear(in2_features, num_label, bias=False)
+       self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
output = self.bilinear(x1, x2) | output = self.bilinear(x1, x2) | ||||
-       output += self.lin1(x1) + self.lin2(x2)
+       output += self.lin(torch.cat([x1, x2], dim=2))
return output | return output | ||||
+def len2masks(origin_len, max_len):
+    if origin_len.dim() <= 1:
+        origin_len = origin_len.unsqueeze(1)  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device)  # [max_len,]
+    seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0))  # [batch_size, max_len]
+    return seq_mask
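A usage sketch for the `len2masks` helper added above: a batch of lengths becomes a `[batch, max_len]` mask where positions before each length are true. (On recent PyTorch the comparison yields a bool tensor.)

```python
import torch

lens = torch.tensor([3, 1, 2])
seq_range = torch.arange(4).unsqueeze(0)          # [1, max_len]
mask = lens.unsqueeze(1).gt(seq_range)            # same rule as len2masks
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False],
#         [ True,  True, False, False]])
```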
class BiaffineParser(GraphParser): | class BiaffineParser(GraphParser): | ||||
"""Biaffine Dependency Parser implemantation. | """Biaffine Dependency Parser implemantation. | ||||
@@ -194,6 +207,8 @@ class BiaffineParser(GraphParser): | |||||
word_emb_dim, | word_emb_dim, | ||||
pos_vocab_size, | pos_vocab_size, | ||||
pos_emb_dim, | pos_emb_dim, | ||||
+                word_hid_dim,
+                pos_hid_dim,
rnn_layers, | rnn_layers, | ||||
rnn_hidden_size, | rnn_hidden_size, | ||||
arc_mlp_size, | arc_mlp_size, | ||||
@@ -204,10 +219,15 @@ class BiaffineParser(GraphParser): | |||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
super(BiaffineParser, self).__init__() | super(BiaffineParser, self).__init__() | ||||
+       rnn_out_size = 2 * rnn_hidden_size
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | ||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | ||||
+       self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
+       self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
+       self.word_norm = nn.LayerNorm(word_hid_dim)
+       self.pos_norm = nn.LayerNorm(pos_hid_dim)
if use_var_lstm: | if use_var_lstm: | ||||
-           self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
+           self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -216,7 +236,7 @@ class BiaffineParser(GraphParser): | |||||
hidden_dropout=dropout, | hidden_dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
else: | else: | ||||
-           self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
+           self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -224,21 +244,35 @@ class BiaffineParser(GraphParser): | |||||
dropout=dropout, | dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
-       rnn_out_size = 2 * rnn_hidden_size
        self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
-                                         nn.ELU())
+                                         nn.LayerNorm(arc_mlp_size),
+                                         nn.ELU(),
+                                         TimestepDropout(p=dropout),)
        self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
        self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
-                                           nn.ELU())
+                                           nn.LayerNorm(label_mlp_size),
+                                           nn.ELU(),
+                                           TimestepDropout(p=dropout),)
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.normal_dropout = nn.Dropout(p=dropout) | self.normal_dropout = nn.Dropout(p=dropout) | ||||
-       self.timestep_dropout = TimestepDropout(p=dropout)
        self.use_greedy_infer = use_greedy_infer
-       initial_parameter(self)
+       self.reset_parameters()
+       self.explore_p = 0.2
+
+   def reset_parameters(self):
+       for m in self.modules():
+           if isinstance(m, nn.Embedding):
+               continue
+           elif isinstance(m, nn.LayerNorm):
+               nn.init.constant_(m.weight, 0.1)
+               nn.init.constant_(m.bias, 0)
+           else:
+               for p in m.parameters():
+                   nn.init.normal_(p, 0, 0.1)
-   def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
+   def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
""" | """ | ||||
:param word_seq: [batch_size, seq_len] sequence of word's indices | :param word_seq: [batch_size, seq_len] sequence of word's indices | ||||
:param pos_seq: [batch_size, seq_len] sequence of word's indices | :param pos_seq: [batch_size, seq_len] sequence of word's indices | ||||
@@ -253,32 +287,35 @@ class BiaffineParser(GraphParser): | |||||
# prepare embeddings | # prepare embeddings | ||||
batch_size, seq_len = word_seq.shape | batch_size, seq_len = word_seq.shape | ||||
# print('forward {} {}'.format(batch_size, seq_len)) | # print('forward {} {}'.format(batch_size, seq_len)) | ||||
-       batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
        # get sequence mask
-       seq_mask = seq_mask.long()
+       seq_mask = len2masks(word_seq_origin_len, seq_len).long()

        word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
        pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
+       word, pos = self.word_fc(word), self.pos_fc(pos)
+       word, pos = self.word_norm(word), self.pos_norm(pos)
        x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
+       del word, pos

        # lstm, extract features
+       x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
        feat, _ = self.lstm(x)  # -> [N,L,C]
+       feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
-       arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat))
-       arc_head = self.timestep_dropout(self.arc_head_mlp(feat))
-       label_dep = self.timestep_dropout(self.label_dep_mlp(feat))
-       label_head = self.timestep_dropout(self.label_head_mlp(feat))
+       arc_dep = self.arc_dep_mlp(feat)
+       arc_head = self.arc_head_mlp(feat)
+       label_dep = self.label_dep_mlp(feat)
+       label_head = self.label_head_mlp(feat)
+       del feat
# biaffine arc classifier | # biaffine arc classifier | ||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||
+       flip_mask = (seq_mask == 0)
+       arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

        # use gold or predicted arc to predict label
-       if gold_heads is None:
+       if gold_heads is None or not self.training:
# use greedy decoding in training | # use greedy decoding in training | ||||
if self.training or self.use_greedy_infer: | if self.training or self.use_greedy_infer: | ||||
heads = self._greedy_decoder(arc_pred, seq_mask) | heads = self._greedy_decoder(arc_pred, seq_mask) | ||||
@@ -286,9 +323,15 @@ class BiaffineParser(GraphParser): | |||||
heads = self._mst_decoder(arc_pred, seq_mask) | heads = self._mst_decoder(arc_pred, seq_mask) | ||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
-           head_pred = None
-           heads = gold_heads
+           assert self.training  # must be training mode
+           if torch.rand(1).item() < self.explore_p:
+               heads = self._greedy_decoder(arc_pred, seq_mask)
+               head_pred = heads
+           else:
+               head_pred = None
+               heads = gold_heads
+
+       batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous() | label_head = label_head[batch_range, heads].contiguous() | ||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | ||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | ||||
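The `explore_p` branch above is a scheduled-sampling-like trick: even when gold heads are available, with probability `explore_p` the label classifier is conditioned on the model's own greedy head predictions. A hedged sketch of just that decision, with placeholder callables:

```python
import torch

def choose_heads(gold_heads, greedy_decode, arc_pred, seq_mask, explore_p=0.2):
    # teacher forcing most of the time; the model's own greedy heads otherwise
    if torch.rand(1).item() < explore_p:
        return greedy_decode(arc_pred, seq_mask)
    return gold_heads
```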
@@ -301,7 +344,7 @@ class BiaffineParser(GraphParser): | |||||
Compute loss. | Compute loss. | ||||
:param arc_pred: [batch_size, seq_len, seq_len] | :param arc_pred: [batch_size, seq_len, seq_len] | ||||
-       :param label_pred: [batch_size, seq_len, seq_len]
+       :param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len] | :param head_indices: [batch_size, seq_len] | ||||
:param head_labels: [batch_size, seq_len] | :param head_labels: [batch_size, seq_len] | ||||
:param seq_mask: [batch_size, seq_len] | :param seq_mask: [batch_size, seq_len] | ||||
@@ -309,10 +352,13 @@ class BiaffineParser(GraphParser): | |||||
""" | """ | ||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
-       arc_logits = F.log_softmax(arc_pred, dim=2)
+       flip_mask = (seq_mask == 0)
+       _arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
+       _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
+       arc_logits = F.log_softmax(_arc_pred, dim=2)
        label_logits = F.log_softmax(label_pred, dim=2)
-       batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
-       child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
+       batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
+       child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices] | arc_loss = arc_logits[batch_index, child_index, head_indices] | ||||
label_loss = label_logits[batch_index, child_index, head_labels] | label_loss = label_logits[batch_index, child_index, head_labels] | ||||
@@ -320,45 +366,8 @@ class BiaffineParser(GraphParser): | |||||
label_loss = label_loss[:, 1:] | label_loss = label_loss[:, 1:] | ||||
float_mask = seq_mask[:, 1:].float() | float_mask = seq_mask[:, 1:].float() | ||||
-       length = (seq_mask.sum() - batch_size).float()
-       arc_nll = -(arc_loss*float_mask).sum() / length
-       label_nll = -(label_loss*float_mask).sum() / length
+       arc_nll = -(arc_loss*float_mask).mean()
+       label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll | return arc_nll + label_nll | ||||
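Two things change in the loss above: pad columns are pushed to `-inf` on a copy of the arc scores before `log_softmax`, and the masked negative log-likelihoods are averaged with `.mean()` instead of being divided by the real token count (so the denominator now includes padded positions). A hedged sketch of the arc part:

```python
import torch
import torch.nn.functional as F

def masked_arc_nll(arc_pred, head_indices, seq_mask):
    B, L, _ = arc_pred.shape
    flip = (seq_mask == 0)
    scores = arc_pred.clone().masked_fill_(flip.unsqueeze(1), float("-inf"))
    logp = F.log_softmax(scores, dim=2)
    b = torch.arange(B).unsqueeze(1)               # [B, 1]
    c = torch.arange(L).unsqueeze(0)               # [1, L]
    ll = logp[b, c, head_indices][:, 1:]           # drop the root position
    return -(ll * seq_mask[:, 1:].float()).mean()  # mean over B*(L-1), incl. pads
```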
-   def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
-       """
-       Evaluate the performance of prediction.
-       :return dict: performance results.
-           head_pred_corrct: number of correct predicted heads.
-           label_pred_correct: number of correct predicted labels.
-           total_tokens: number of predicted tokens
-       """
-       if 'head_pred' in kwargs:
-           head_pred = kwargs['head_pred']
-       elif self.use_greedy_infer:
-           head_pred = self._greedy_decoder(arc_pred, seq_mask)
-       else:
-           head_pred = self._mst_decoder(arc_pred, seq_mask)
-       head_pred_correct = (head_pred == head_indices).long() * seq_mask
-       _, label_preds = torch.max(label_pred, dim=2)
-       label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
-       return {"head_pred_correct": head_pred_correct.sum(dim=1),
-               "label_pred_correct": label_pred_correct.sum(dim=1),
-               "total_tokens": seq_mask.sum(dim=1)}
-
-   def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
-       """
-       Compute the metrics of model
-       :param head_pred_corrct: number of correct predicted heads.
-       :param label_pred_correct: number of correct predicted labels.
-       :param total_tokens: number of predicted tokens
-       :return dict: the metrics results
-           UAS: the head predicted accuracy
-           LAS: the label predicted accuracy
-       """
-       return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
-               "LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}
@@ -1,5 +1,6 @@ | |||||
import torch | import torch | ||||
+from torch import nn
+import math
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module): | |||||
def _atten_forward(self, query, memory): | def _atten_forward(self, query, memory): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class DotAtte(nn.Module): | |||||
def __init__(self, key_size, value_size): | |||||
super(DotAtte, self).__init__() | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.scale = math.sqrt(key_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
""" | |||||
:param Q: [batch, seq_len, key_size] | |||||
:param K: [batch, seq_len, key_size] | |||||
:param V: [batch, seq_len, value_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
""" | |||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | |||||
if seq_mask is not None: | |||||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||||
output = nn.functional.softmax(output, dim=2) | |||||
return torch.matmul(output, V) | |||||
class MultiHeadAtte(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
super(MultiHeadAtte, self).__init__() | |||||
self.in_linear = nn.ModuleList() | |||||
for i in range(num_atte * 3): | |||||
out_feat = key_size if (i % 3) != 2 else value_size | |||||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
heads = [] | |||||
for i in range(len(self.attes)): | |||||
j = i * 3 | |||||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||||
heads.append(headi) | |||||
output = torch.cat(heads, dim=2) | |||||
return self.out_linear(output) |
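A usage sketch for the `DotAtte` / `MultiHeadAtte` modules added above, assuming self-attention (`Q = K = V`) and the module path shown in the transformer import below. The padding mask is omitted here: as written, `DotAtte`'s mask path needs a mask that broadcasts against the `[batch, seq_len, seq_len]` score matrix.

```python
import torch
# assumed import path, matching "from ..aggregator.attention import MultiHeadAtte"
from fastNLP.modules.aggregator.attention import MultiHeadAtte

batch, seq_len, d_in, d_k, d_v, heads = 2, 5, 16, 8, 8, 4
x = torch.randn(batch, seq_len, d_in)

atte = MultiHeadAtte(input_size=d_in, output_size=d_in,
                     key_size=d_k, value_size=d_v, num_atte=heads)
out = atte(x, x, x)        # per-head scaled dot-product, concat, project -> [2, 5, 16]
print(out.shape)
```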
@@ -0,0 +1,32 @@ | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from ..aggregator.attention import MultiHeadAtte | |||||
from ..other_modules import LayerNormalization | |||||
class TransformerEncoder(nn.Module): | |||||
class SubLayer(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
super(TransformerEncoder.SubLayer, self).__init__() | |||||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||||
self.norm1 = LayerNormalization(output_size) | |||||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||||
nn.ReLU(), | |||||
nn.Linear(output_size, output_size)) | |||||
self.norm2 = LayerNormalization(output_size) | |||||
def forward(self, input, seq_mask): | |||||
attention = self.atte(input) | |||||
norm_atte = self.norm1(attention + input) | |||||
output = self.ffn(norm_atte) | |||||
return self.norm2(output + norm_atte) | |||||
def __init__(self, num_layers, **kargs): | |||||
super(TransformerEncoder, self).__init__() | |||||
self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
def forward(self, x, seq_mask=None): | |||||
return self.layers(x, seq_mask) | |||||
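One design note on the stacking above: `nn.Sequential.forward` only passes a single tensor through, so a `(x, seq_mask)` pair cannot flow layer to layer that way. A hedged sketch of the same composition with an explicit `ModuleList` so the mask reaches every sub-layer; names mirror the classes above but this is illustrative, not the fastNLP implementation:

```python
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self, num_layers, make_layer):
        super().__init__()
        # make_layer is a factory returning one sub-layer (attention + FFN + norms)
        self.layers = nn.ModuleList([make_layer() for _ in range(num_layers)])

    def forward(self, x, seq_mask=None):
        for layer in self.layers:
            x = layer(x, seq_mask)   # each sub-layer sees both the input and the mask
        return x
```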
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module): | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | mask_x = input.new_ones((batch_size, self.input_size)) | ||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | ||||
-       mask_h = input.new_ones((batch_size, self.hidden_size))
+       mask_h_ones = input.new_ones((batch_size, self.hidden_size))
        nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
        nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
-       nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

        hidden_list = []
        for layer in range(self.num_layers):
            output_list = []
+           mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_x = input if direction == 0 else flip(input, [0]) | input_x = input if direction == 0 else flip(input, [0]) | ||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module): | |||||
class LayerNormalization(nn.Module): | class LayerNormalization(nn.Module): | ||||
""" Layer normalization module """ | """ Layer normalization module """ | ||||
-   def __init__(self, d_hid, eps=1e-3):
+   def __init__(self, layer_size, eps=1e-3):
        super(LayerNormalization, self).__init__()
        self.eps = eps
-       self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-       self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+       self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+       self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

    def forward(self, z):
        if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):
        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)
-       ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-       ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+       ln_out = (z - mu) / (sigma + self.eps)
+       ln_out = ln_out * self.a_2 + self.b_2
        return ln_out
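The rewritten forward drops the `expand_as` calls in favour of plain broadcasting with the `[1, layer_size]` gain and bias. An illustrative check that the two forms agree (toy shapes, not the module itself):

```python
import torch

z = torch.randn(4, 7, 16)
a_2, b_2 = torch.ones(1, 16), torch.zeros(1, 16)
eps = 1e-3
mu = z.mean(dim=-1, keepdim=True)
sigma = z.std(dim=-1, keepdim=True)

out_broadcast = (z - mu) / (sigma + eps) * a_2 + b_2
out_expand = ((z - mu.expand_as(z)) / (sigma.expand_as(z) + eps)) * a_2.expand_as(z) + b_2.expand_as(z)
assert torch.allclose(out_broadcast, out_expand)
```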
@@ -1,37 +1,40 @@ | |||||
[train] | [train] | ||||
-epochs = 50
+epochs = -1
batch_size = 16 | batch_size = 16 | ||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
-save_best_dev = false
+save_best_dev = true
+eval_sort_key = "UAS"
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
task = "parse" | |||||
print_every_step = 20 | |||||
use_golden_train=true | |||||
[test] | [test] | ||||
save_output = true | save_output = true | ||||
validate_in_training = true | validate_in_training = true | ||||
save_dev_input = false | save_dev_input = false | ||||
save_loss = true | save_loss = true | ||||
-batch_size = 16
+batch_size = 64
pickle_path = "./save/" | pickle_path = "./save/" | ||||
use_cuda = true | use_cuda = true | ||||
task = "parse" | |||||
[model] | [model] | ||||
word_vocab_size = -1 | word_vocab_size = -1 | ||||
word_emb_dim = 100 | word_emb_dim = 100 | ||||
pos_vocab_size = -1 | pos_vocab_size = -1 | ||||
pos_emb_dim = 100 | pos_emb_dim = 100 | ||||
+word_hid_dim = 100
+pos_hid_dim = 100
rnn_layers = 3 | rnn_layers = 3 | ||||
rnn_hidden_size = 400 | rnn_hidden_size = 400 | ||||
arc_mlp_size = 500 | arc_mlp_size = 500 | ||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | dropout = 0.33 | ||||
-use_var_lstm=true
+use_var_lstm=false
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 2e-3 | lr = 2e-3 | ||||
weight_decay = 5e-5 |
@@ -6,15 +6,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from collections import defaultdict | from collections import defaultdict | ||||
import math | import math | ||||
import torch | import torch | ||||
+import re

from fastNLP.core.trainer import Trainer
+from fastNLP.core.metrics import Evaluator
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
-from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
+from fastNLP.core.preprocess import load_pickle
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
@@ -22,15 +24,18 @@ from fastNLP.loader.embed_loader import EmbedLoader | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
+BOS = '<BOS>'
+EOS = '<EOS>'
+UNK = '<OOV>'
+NUM = '<NUM>'
+ENG = '<ENG>'
# not in the file's dir | # not in the file's dir | ||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
-class MyDataLoader(object):
-    def __init__(self, pickle_path):
-        self.pickle_path = pickle_path
-
-    def load(self, path, word_v=None, pos_v=None, headtag_v=None):
+class ConlluDataLoader(object):
+    def load(self, path):
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -49,23 +54,18 @@ class MyDataLoader(object): | |||||
for sample in datalist: | for sample in datalist: | ||||
# print(sample) | # print(sample) | ||||
res = self.get_one(sample) | res = self.get_one(sample) | ||||
-           if word_v is not None:
-               word_v.update(res[0])
-               pos_v.update(res[1])
-               headtag_v.update(res[3])
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ||||
pos_seq=TextField(res[1], is_target=False), | pos_seq=TextField(res[1], is_target=False), | ||||
head_indices=SeqLabelField(res[2], is_target=True), | head_indices=SeqLabelField(res[2], is_target=True), | ||||
-                              head_labels=TextField(res[3], is_target=True),
-                              seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))
+                              head_labels=TextField(res[3], is_target=True)))
return ds | return ds | ||||
def get_one(self, sample): | def get_one(self, sample): | ||||
-       text = ['<root>']
-       pos_tags = ['<root>']
-       heads = [0]
-       head_tags = ['root']
+       text = []
+       pos_tags = []
+       heads = []
+       head_tags = []
for w in sample: | for w in sample: | ||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | ||||
if t3 == '_': | if t3 == '_': | ||||
@@ -76,17 +76,60 @@ class MyDataLoader(object): | |||||
head_tags.append(t4) | head_tags.append(t4) | ||||
return (text, pos_tags, heads, head_tags) | return (text, pos_tags, heads, head_tags) | ||||
def index_data(self, dataset, word_v, pos_v, tag_v): | |||||
dataset.index_field('word_seq', word_v) | |||||
dataset.index_field('pos_seq', pos_v) | |||||
dataset.index_field('head_labels', tag_v) | |||||
class CTBDataLoader(object): | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return self.convert(data) | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i+1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
return data | |||||
def convert(self, data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] + [EOS] | |||||
pos_seq = [BOS] + sample[1] + [EOS] | |||||
heads = [0] + list(map(int, sample[2])) + [0] | |||||
head_tags = [BOS] + sample[3] + [EOS] | |||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||||
pos_seq=TextField(pos_seq, is_target=False), | |||||
gold_heads=SeqLabelField(heads, is_target=False), | |||||
head_indices=SeqLabelField(heads, is_target=True), | |||||
head_labels=TextField(head_tags, is_target=True))) | |||||
return dataset | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | ||||
datadir = "/home/yfshao/UD_English-EWT" | |||||
# datadir = "/home/yfshao/UD_English-EWT" | |||||
# train_data_name = "en_ewt-ud-train.conllu" | |||||
# dev_data_name = "en_ewt-ud-dev.conllu" | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
# loader = ConlluDataLoader() | |||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | |||||
dev_data_name = "dev_ctb5.txt" | |||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
loader = CTBDataLoader() | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
train_data_name = "en_ewt-ud-train.conllu" | |||||
dev_data_name = "en_ewt-ud-dev.conllu" | |||||
emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
processed_datadir = './save' | processed_datadir = './save' | ||||
# Config Loader | # Config Loader | ||||
@@ -95,8 +138,12 @@ test_args = ConfigSection() | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
optim_args = ConfigSection() | optim_args = ConfigSection() | ||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ||||
+print('trainre Args:', train_args.data)
+print('test Args:', test_args.data)
+print('optim Args:', optim_args.data)

-# Data Loader
+# Pickle Loader
def save_data(dirpath, **kwargs): | def save_data(dirpath, **kwargs): | ||||
import _pickle | import _pickle | ||||
if not os.path.exists(dirpath): | if not os.path.exists(dirpath): | ||||
@@ -117,38 +164,57 @@ def load_data(dirpath): | |||||
datas[name] = _pickle.load(f) | datas[name] = _pickle.load(f) | ||||
return datas | return datas | ||||
class MyTester(object): | |||||
def __init__(self, batch_size, use_cuda=False, **kwargs):
self.batch_size = batch_size | |||||
self.use_cuda = use_cuda | |||||
def test(self, model, dataset): | |||||
self.model = model.cuda() if self.use_cuda else model | |||||
self.model.eval() | |||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
eval_res = defaultdict(list) | |||||
i = 0 | |||||
for batch_x, batch_y in batchiter: | |||||
with torch.no_grad(): | |||||
pred_y = self.model(**batch_x) | |||||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
i += self.batch_size | |||||
for eval_name, tensor in eval_one.items(): | |||||
eval_res[eval_name].append(tensor) | |||||
tmp = {} | |||||
for eval_name, tensorlist in eval_res.items(): | |||||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
self.res = self.model.metrics(**tmp) | |||||
def show_metrics(self): | |||||
s = "" | |||||
for name, val in self.res.items(): | |||||
s += '{}: {:.2f}\t'.format(name, val) | |||||
return s | |||||
loader = MyDataLoader('') | |||||
def P2(data, field, length): | |||||
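# drop instances whose `field` is shorter than `length` tokens; `data` is filtered in place
# and the kept instances are also returned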
ds = [ins for ins in data if ins[field].get_length() >= length] | |||||
data.clear() | |||||
data.extend(ds) | |||||
return ds | |||||
def P1(data, field): | |||||
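# replace purely alphabetic tokens (optionally containing '.' or '-') with the ENG placeholder
# and numeric tokens with the NUM placeholder; BOS/EOS are left untouched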
def reeng(w): | |||||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||||
def renum(w): | |||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||||
for ins in data: | |||||
ori = ins[field].contents() | |||||
s = list(map(renum, map(reeng, ori))) | |||||
if s != ori: | |||||
# print(ori) | |||||
# print(s) | |||||
# print() | |||||
ins[field] = ins[field].new(s) | |||||
return data | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self, ignore_label): | |||||
super(ParserEvaluator, self).__init__() | |||||
self.ignore = ignore_label | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
seq_mask *= (head_labels != self.ignore).long() | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
try: | try: | ||||
data_dict = load_data(processed_datadir) | data_dict = load_data(processed_datadir) | ||||
word_v = data_dict['word_v'] | word_v = data_dict['word_v'] | ||||
@@ -156,62 +222,90 @@ try: | |||||
tag_v = data_dict['tag_v'] | tag_v = data_dict['tag_v'] | ||||
train_data = data_dict['train_data'] | train_data = data_dict['train_data'] | ||||
dev_data = data_dict['dev_data'] | dev_data = data_dict['dev_data'] | ||||
test_data = data_dict['test_data'] | |||||
print('use saved pickles') | print('use saved pickles') | ||||
except Exception as _: | except Exception as _: | ||||
print('load raw data and preprocess') | print('load raw data and preprocess') | ||||
# use pretrain embedding | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | word_v = Vocabulary(need_default=True, min_freq=2) | ||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary(need_default=True) | pos_v = Vocabulary(need_default=True) | ||||
tag_v = Vocabulary(need_default=False) | tag_v = Vocabulary(need_default=False) | ||||
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | dev_data = loader.load(os.path.join(datadir, dev_data_name)) | ||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||||
datasets = (train_data, dev_data, test_data) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||||
loader.index_data(train_data, word_v, pos_v, tag_v) | |||||
loader.index_data(dev_data, word_v, pos_v, tag_v) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
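# a non-positive 'epochs' setting means: train for roughly 50k optimizer steps at the configured batch size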
ep = train_args['epochs'] | |||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
print(len(word_v)) | |||||
print(embed.size()) | |||||
# Model | |||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
model_args['pos_vocab_size'] = len(pos_v) | model_args['pos_vocab_size'] = len(pos_v) | ||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
model = BiaffineParser(**model_args.data) | |||||
model.reset_parameters() | |||||
datasets = (train_data, dev_data, test_data) | |||||
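# map words, POS tags and arc labels to vocabulary indices, and record the original 'word_seq' length for each instance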
for ds in datasets: | |||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||||
ds.set_origin_len('word_seq') | |||||
if train_args['use_golden_train']: | |||||
train_data.set_target(gold_heads=False) | |||||
else: | |||||
train_data.set_target(gold_heads=None) | |||||
train_args.data.pop('use_golden_train') | |||||
ignore_label = pos_v['P'] | |||||
print(test_data[0]) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
print(len(test_data)) | |||||
def train(): | |||||
def train(path): | |||||
# Trainer | # Trainer | ||||
trainer = Trainer(**train_args.data) | trainer = Trainer(**train_args.data) | ||||
def _define_optim(obj): | def _define_optim(obj): | ||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
lr = optim_args.data['lr'] | |||||
embed_params = set(obj._model.word_embedding.parameters()) | |||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||||
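# three parameter groups: word embeddings at 0.1x the base lr, the arc/label predictors
# with the full optim config (which may include weight decay), and the remaining
# parameters with just the base lr and betas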
obj._optimizer = torch.optim.Adam([ | |||||
{'params': list(embed_params), 'lr':lr*0.1}, | |||||
{'params': list(decay_params), **optim_args.data}, | |||||
{'params': params} | |||||
], lr=lr, betas=(0.9, 0.9)) | |||||
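# decay the lr by a factor of 0.75 every 50k scheduler steps (stepped once per update), never below 5% of the base lr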
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | def _update(obj): | ||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||||
obj._scheduler.step() | obj._scheduler.step() | ||||
obj._optimizer.step() | obj._optimizer.step() | ||||
trainer.define_optimizer = lambda: _define_optim(trainer) | trainer.define_optimizer = lambda: _define_optim(trainer) | ||||
trainer.update = lambda: _update(trainer) | trainer.update = lambda: _update(trainer) | ||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
# Model | |||||
model = BiaffineParser(**model_args.data) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||||
# use pretrained word embeddings
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | ||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | ||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as _: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# try: | |||||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
# print('model parameter loaded!') | |||||
# except Exception as _: | |||||
# print("No saved model. Continue.") | |||||
# pass | |||||
# Start training | # Start training | ||||
trainer.train(model, train_data, dev_data) | trainer.train(model, train_data, dev_data) | ||||
@@ -223,24 +317,27 @@ def train(): | |||||
print("Model saved!") | print("Model saved!") | ||||
def test(): | |||||
def test(path): | |||||
# Tester | # Tester | ||||
tester = MyTester(**test_args.data) | |||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.eval() | |||||
try: | try: | ||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
ModelLoader.load_pytorch(model, path) | |||||
print('model parameter loaded!') | print('model parameter loaded!') | ||||
except Exception as _: | except Exception as _: | ||||
print("No saved model. Abort test.") | print("No saved model. Abort test.") | ||||
raise | raise | ||||
# Start testing | # Start testing
print("Testing Train data") | |||||
tester.test(model, train_data) | |||||
print("Testing Dev data") | |||||
tester.test(model, dev_data) | tester.test(model, dev_data) | ||||
print(tester.show_metrics()) | |||||
print("Testing finished!") | |||||
print("Testing Test data") | |||||
tester.test(model, test_data) | |||||
@@ -248,11 +345,12 @@ if __name__ == "__main__": | |||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a biaffine dependency parser') | parser = argparse.ArgumentParser(description='Run a biaffine dependency parser')
parser.add_argument('--mode', help='running mode', choices=['train', 'test', 'infer']) | parser.add_argument('--mode', help='running mode', choices=['train', 'test', 'infer'])
parser.add_argument('--path', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train() | |||||
train(args.path) | |||||
elif args.mode == 'test': | elif args.mode == 'test': | ||||
test() | |||||
test(args.path) | |||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
infer() | infer() | ||||
else: | else: | ||||