
Merge branch 'dev' into dataset

tags/v0.2.0
yunfan 6 years ago
parent commit 8fae3bc2e7
14 changed files with 486 additions and 231 deletions
  1. fastNLP/core/field.py (+6, -0)
  2. fastNLP/core/instance.py (+6, -0)
  3. fastNLP/core/tester.py (+14, -11)
  4. fastNLP/core/trainer.py (+85, -30)
  5. fastNLP/core/vocabulary.py (+14, -5)
  6. fastNLP/loader/dataset_loader.py (+0, -7)
  7. fastNLP/loader/embed_loader.py (+3, -3)
  8. fastNLP/models/biaffine_parser.py (+79, -70)
  9. fastNLP/modules/aggregator/attention.py (+43, -1)
  10. fastNLP/modules/encoder/transformer.py (+32, -0)
  11. fastNLP/modules/encoder/variational_rnn.py (+2, -2)
  12. fastNLP/modules/other_modules.py (+5, -6)
  13. reproduction/Biaffine_parser/cfg.cfg (+10, -7)
  14. reproduction/Biaffine_parser/run.py (+187, -89)

fastNLP/core/field.py (+6, -0)

@@ -28,6 +28,12 @@ class Field(object):
"""
raise NotImplementedError

def __repr__(self):
return self.contents().__repr__()

def new(self, *args, **kwargs):
return self.__class__(*args, **kwargs, is_target=self.is_target)

class TextField(Field):
def __init__(self, name, text, is_target):
"""


fastNLP/core/instance.py (+6, -0)

@@ -35,6 +35,9 @@ class Instance(object):
else:
raise KeyError("{} not found".format(name))

def __setitem__(self, name, field):
return self.add_field(name, field)

def get_length(self):
"""Fetch the length of all fields in the instance.

@@ -82,3 +85,6 @@ class Instance(object):
name, field_name = origin_len
tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
return tensor_x, tensor_y

def __repr__(self):
return self.fields.__repr__()

fastNLP/core/tester.py (+14, -11)

@@ -17,9 +17,9 @@ class Tester(object):
"""
super(Tester, self).__init__()
"""
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"batch_size": 8,
@@ -29,8 +29,8 @@ class Tester(object):
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {}
@@ -74,16 +74,19 @@ class Tester(object):
output_list = []
truth_list = []

data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')

with torch.no_grad():
    for batch_x, batch_y in data_iterator:
        prediction = self.data_forward(network, batch_x)
        output_list.append(prediction)
        truth_list.append(batch_y)
    eval_results = self.evaluate(output_list, truth_list)
print("[tester] {}".format(self.print_eval_results(eval_results)))
logger.info("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
self.metrics = eval_results
return eval_results

def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.


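The tester change above moves the whole evaluation loop under one `torch.no_grad()` context and asks `Batch` to sort examples inside each batch by `'word_seq'` length. A minimal sketch of the same pattern in plain PyTorch; the model, batch iterator and metric function here are illustrative stand-ins, not fastNLP APIs:

```python
import torch

def evaluate(model, batches, metric_fn):
    """One evaluation pass with gradient tracking disabled (sketch of the loop above)."""
    model.eval()
    outputs, truths = [], []
    with torch.no_grad():                    # a single no_grad context wraps the whole loop
        for batch_x, batch_y in batches:     # batches are assumed pre-sorted by sequence length
            outputs.append(model(**batch_x))
            truths.append(batch_y)
    return metric_fn(outputs, truths)
```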
fastNLP/core/trainer.py (+85, -30)

@@ -1,6 +1,6 @@
import os
import time
from datetime import timedelta
from datetime import timedelta, datetime

import torch
from tensorboardX import SummaryWriter
@@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver

logger = create_logger(__name__, "./train_test.log")
logger.disabled = True

class Trainer(object):
"""Operations of training a model, including data loading, gradient descent, and validation.
@@ -35,20 +35,21 @@ class Trainer(object):
super(Trainer, self).__init__()

"""
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"valid_step": 500, "eval_sort_key": 'acc',
"loss": Loss(None), # used to pass type check
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {}
@@ -70,16 +71,20 @@ class Trainer(object):
else:
# Trainer doesn't care about extra arguments
pass
print(default_args)
print("Training Args {}".format(default_args))
logger.info("Training Args {}".format(default_args))

self.n_epochs = default_args["epochs"]
self.batch_size = default_args["batch_size"]
self.n_epochs = int(default_args["epochs"])
self.batch_size = int(default_args["batch_size"])
self.pickle_path = default_args["pickle_path"]
self.validate = default_args["validate"]
self.save_best_dev = default_args["save_best_dev"]
self.use_cuda = default_args["use_cuda"]
self.model_name = default_args["model_name"]
self.print_every_step = default_args["print_every_step"]
self.print_every_step = int(default_args["print_every_step"])
self.valid_step = int(default_args["valid_step"])
if self.validate is not None:
assert self.valid_step > 0

self._model = None
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
@@ -89,6 +94,8 @@ class Trainer(object):
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0
self.eval_sort_key = default_args['eval_sort_key']
self.validator = None

def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@@ -104,12 +111,17 @@ class Trainer(object):
else:
self._model = network

print(self._model)

# define Tester over dev data
self.dev_data = None
if self.validate:
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))
if self.validator is None:
self.validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(self.validator)))
self.dev_data = dev_data

# optimizer and loss
self.define_optimizer()
@@ -117,29 +129,33 @@ class Trainer(object):
self.define_loss()
logger.info("loss function defined as {}".format(str(self._loss_func)))

# turn on network training mode
self.mode(network, is_test=False)

# main training procedure
start = time.time()
logger.info("training epochs started")
for epoch in range(1, self.n_epochs + 1):
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
while(1):
if self.n_epochs != -1 and epoch > self.n_epochs:
break
logger.info("training epoch {}".format(epoch))

# turn on network training mode
self.mode(network, is_test=False)
# prepare mini-batch iterator
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
logger.info("prepared data iterator")

# one forward and backward pass
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)

# validation
if self.validate:
if dev_data is None:
raise RuntimeError(
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
validator.test(network, dev_data)
self.valid_model()
self.save_model(self._model, 'training_model_'+self.start_time)
epoch += 1

def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch.
@@ -149,13 +165,17 @@ class Trainer(object):
- start: time.time(), the starting time of this step.
- epoch: int,
"""
step = 0
step = kwargs['step']
for batch_x, batch_y in data_iterator:

prediction = self.data_forward(network, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
# if torch.rand(1).item() < 0.001:
# print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
# for name, p in self._model.named_parameters():
# if p.requires_grad:
# print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)

@@ -166,7 +186,22 @@ class Trainer(object):
kwargs["epoch"], step, loss.data, diff)
print(print_output)
logger.info(print_output)
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
self.valid_model()
step += 1
return step

def valid_model(self):
if self.dev_data is None:
raise RuntimeError(
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
res = self.validator.test(self._model, self.dev_data)
if self.save_best_dev and self.best_eval_result(res):
logger.info('save best result! {}'.format(res))
print('save best result! {}'.format(res))
self.save_model(self._model, 'best_model_'+self.start_time)
return res

def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -180,11 +215,17 @@ class Trainer(object):
else:
model.train()

def define_optimizer(self):
def define_optimizer(self, optim=None):
"""Define framework-specific optimizer specified by the models.

"""
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
if optim is not None:
# optimizer constructed by user
self._optimizer = optim
elif self._optimizer is None:
# optimizer constructed by proto
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
return self._optimizer

def update(self):
"""Perform weight update on a model.
@@ -217,6 +258,8 @@ class Trainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
if isinstance(predict, dict) and isinstance(truth, dict):
return self._loss_func(**predict, **truth)
if len(truth) > 1:
raise NotImplementedError("Not ready to handle multi-labels.")
truth = list(truth.values())[0] if len(truth) > 0 else None
@@ -241,13 +284,23 @@ class Trainer(object):
raise ValueError("Please specify a loss function.")
logger.info("The model didn't define loss, use Trainer's loss.")

def best_eval_result(self, validator):
def best_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.

:param validator: a Tester instance
:return: bool, True means current results on dev set is the best.
"""
loss, accuracy = validator.metrics
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
else:
accuracy = metrics[self.eval_sort_key]
else:
accuracy = metrics

if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
@@ -268,6 +321,8 @@ class Trainer(object):
def _create_validator(self, valid_args):
raise NotImplementedError

def set_validator(self, validator):
self.validator = validator

class SeqLabelTrainer(Trainer):
"""Trainer for Sequence Labeling


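Two behaviours added above are worth calling out: validation can now also run every `valid_step` training steps, and `best_eval_result` accepts a metrics dict and picks the value named by `eval_sort_key` (default `'acc'`). A small sketch of that selection logic; `pick_eval_value` is a hypothetical helper, not part of fastNLP:

```python
def pick_eval_value(metrics, sort_key='acc'):
    """Reduce a Tester result to one scalar, mirroring best_eval_result above (sketch)."""
    if isinstance(metrics, tuple):       # (loss, metrics) pairs: keep only the metrics part
        _, metrics = metrics
    if isinstance(metrics, dict):
        # a single-entry dict is used directly; otherwise the configured sort key decides
        return next(iter(metrics.values())) if len(metrics) == 1 else metrics[sort_key]
    return metrics

best = float('-inf')
for res in [{'UAS': 82.1, 'LAS': 79.5}, {'UAS': 84.3, 'LAS': 80.0}]:
    value = pick_eval_value(res, sort_key='UAS')
    if value > best:
        best = value                     # this is the point where save_best_dev would checkpoint
print(best)                              # 84.3
```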
fastNLP/core/vocabulary.py (+14, -5)

@@ -51,6 +51,12 @@ class Vocabulary(object):
self.min_freq = min_freq
self.word_count = {}
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.padding_label = None
self.unknown_label = None
self.word2idx = None
self.idx2word = None

@@ -77,12 +83,10 @@ class Vocabulary(object):
"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
self.padding_label = None
self.unknown_label = None

words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
if self.min_freq is not None:
@@ -114,7 +118,7 @@ class Vocabulary(object):
if w in self.word2idx:
return self.word2idx[w]
elif self.has_default:
return self.word2idx[DEFAULT_UNKNOWN_LABEL]
return self.word2idx[self.unknown_label]
else:
raise ValueError("word {} not in vocabulary".format(w))

@@ -134,6 +138,11 @@ class Vocabulary(object):
return None
return self.word2idx[self.unknown_label]

def __setattr__(self, name, val):
self.__dict__[name] = val
if name in self.__dict__ and name in ["unknown_label", "padding_label"]:
self.word2idx = None

@property
@check_build_vocab
def padding_idx(self):


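The `__setattr__` hook added above means that reassigning `unknown_label` or `padding_label` (as `run.py` does with `word_v.unknown_label = UNK`) drops the cached `word2idx`, so the index is rebuilt with the new special token. A self-contained sketch of that invalidation pattern; `LazyVocab` is an illustrative stand-in, not the fastNLP class:

```python
class LazyVocab:
    """Illustrates the invalidation pattern above: changing a special label clears word2idx."""
    def __init__(self):
        self.padding_label = '<pad>'
        self.unknown_label = '<unk>'
        self.word2idx = {'<pad>': 0, '<unk>': 1}

    def __setattr__(self, name, val):
        self.__dict__[name] = val
        if name in ("unknown_label", "padding_label"):
            self.__dict__["word2idx"] = None   # force a rebuild on next use

v = LazyVocab()
v.unknown_label = '<OOV>'
print(v.word2idx)   # None -> the vocabulary must be rebuilt before the next lookup
```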
fastNLP/loader/dataset_loader.py (+0, -7)

@@ -87,7 +87,6 @@ class DataSetLoader(BaseLoader):
"""
raise NotImplementedError


@DataSet.set_reader('read_raw')
class RawDataSetLoader(DataSetLoader):
def __init__(self):
@@ -103,7 +102,6 @@ class RawDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq_dataset(data)


@DataSet.set_reader('read_pos')
class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.
@@ -173,7 +171,6 @@ class POSDataSetLoader(DataSetLoader):
"""
return convert_seq2seq_dataset(data)


@DataSet.set_reader('read_tokenize')
class TokenizeDataSetLoader(DataSetLoader):
"""
@@ -233,7 +230,6 @@ class TokenizeDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2seq_dataset(data)


@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""
@@ -272,7 +268,6 @@ class ClassDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2tag_dataset(data)


@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader):
"""loader for conll format files"""
@@ -314,7 +309,6 @@ class ConllLoader(DataSetLoader):
def convert(self, data):
pass


@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader
@@ -351,7 +345,6 @@ class LMDataSetLoader(DataSetLoader):
def convert(self, data):
pass


@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
"""


fastNLP/loader/embed_loader.py (+3, -3)

@@ -17,8 +17,8 @@ class EmbedLoader(BaseLoader):
def _load_glove(emb_file):
"""Read file as a glove embedding

file format:
embeddings are split by line,
for one embedding, word and numbers split by space
Example::

@@ -33,7 +33,7 @@ class EmbedLoader(BaseLoader):
if len(line) > 0:
emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
return emb
@staticmethod
def _load_pretrain(emb_file, emb_type):
"""Read txt data from embedding file and convert to np.array as pre-trained embedding


fastNLP/models/biaffine_parser.py (+79, -70)

@@ -16,10 +16,9 @@ def mst(scores):
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = -np.inf
mask = np.zeros((length, length))
np.fill_diagonal(mask, -np.inf)
scores = scores + mask
min_score = scores.min() - 1
eye = np.eye(length)
scores = scores * (1 - eye) + min_score * eye
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
@@ -126,6 +125,8 @@ class GraphParser(nn.Module):
def _greedy_decoder(self, arc_matrix, seq_mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (seq_mask == 0).byte()
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
@@ -135,8 +136,15 @@ class GraphParser(nn.Module):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
seq_mask[batch_idx, lens-1] = 0
for i, graph in enumerate(matrix):
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
len_i = lens[i]
if len_i == seq_len:
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
else:
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
return ans
@@ -175,14 +183,19 @@ class LabelBilinear(nn.Module):
def __init__(self, in1_features, in2_features, num_label, bias=True):
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin1 = nn.Linear(in1_features, num_label, bias=False)
self.lin2 = nn.Linear(in2_features, num_label, bias=False)
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)

def forward(self, x1, x2):
output = self.bilinear(x1, x2)
output += self.lin1(x1) + self.lin2(x2)
output += self.lin(torch.cat([x1, x2], dim=2))
return output

def len2masks(origin_len, max_len):
if origin_len.dim() <= 1:
origin_len = origin_len.unsqueeze(1) # [batch_size, 1]
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,]
seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len]
return seq_mask

class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implementation.
@@ -194,6 +207,8 @@ class BiaffineParser(GraphParser):
word_emb_dim,
pos_vocab_size,
pos_emb_dim,
word_hid_dim,
pos_hid_dim,
rnn_layers,
rnn_hidden_size,
arc_mlp_size,
@@ -204,10 +219,15 @@ class BiaffineParser(GraphParser):
use_greedy_infer=False):

super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
self.word_norm = nn.LayerNorm(word_hid_dim)
self.pos_norm = nn.LayerNorm(pos_hid_dim)
if use_var_lstm:
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
@@ -216,7 +236,7 @@ class BiaffineParser(GraphParser):
hidden_dropout=dropout,
bidirectional=True)
else:
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
@@ -224,21 +244,35 @@ class BiaffineParser(GraphParser):
dropout=dropout,
bidirectional=True)

rnn_out_size = 2 * rnn_hidden_size
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
nn.ELU())
nn.LayerNorm(arc_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
nn.ELU())
nn.LayerNorm(label_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.normal_dropout = nn.Dropout(p=dropout)
self.timestep_dropout = TimestepDropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
initial_parameter(self)
self.reset_parameters()
self.explore_p = 0.2

def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 0.1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
nn.init.normal_(p, 0, 0.1)

def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
:param word_seq: [batch_size, seq_len] sequence of word's indices
:param pos_seq: [batch_size, seq_len] sequence of POS tags' indices
@@ -253,32 +287,35 @@ class BiaffineParser(GraphParser):
# prepare embeddings
batch_size, seq_len = word_seq.shape
# print('forward {} {}'.format(batch_size, seq_len))
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

# get sequence mask
seq_mask = seq_mask.long()
seq_mask = len2masks(word_seq_origin_len, seq_len).long()

word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
word, pos = self.word_fc(word), self.pos_fc(pos)
word, pos = self.word_norm(word), self.pos_norm(pos)
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
del word, pos

# lstm, extract features
x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
feat, _ = self.lstm(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)

# for arc biaffine
# mlp, reduce dim
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat))
arc_head = self.timestep_dropout(self.arc_head_mlp(feat))
label_dep = self.timestep_dropout(self.label_dep_mlp(feat))
label_head = self.timestep_dropout(self.label_head_mlp(feat))
arc_dep = self.arc_dep_mlp(feat)
arc_head = self.arc_head_mlp(feat)
label_dep = self.label_dep_mlp(feat)
label_head = self.label_head_mlp(feat)
del feat

# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
flip_mask = (seq_mask == 0)
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

# use gold or predicted arc to predict label
if gold_heads is None:
if gold_heads is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, seq_mask)
@@ -286,9 +323,15 @@ class BiaffineParser(GraphParser):
heads = self._mst_decoder(arc_pred, seq_mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
assert self.training # must be training mode
if torch.rand(1).item() < self.explore_p:
heads = self._greedy_decoder(arc_pred, seq_mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads

batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
@@ -301,7 +344,7 @@ class BiaffineParser(GraphParser):
Compute loss.

:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
@@ -309,10 +352,13 @@ class BiaffineParser(GraphParser):
"""

batch_size, seq_len, _ = arc_pred.shape
arc_logits = F.log_softmax(arc_pred, dim=2)
flip_mask = (seq_mask == 0)
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
arc_logits = F.log_softmax(_arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]

@@ -320,45 +366,8 @@ class BiaffineParser(GraphParser):
label_loss = label_loss[:, 1:]

float_mask = seq_mask[:, 1:].float()
length = (seq_mask.sum() - batch_size).float()
arc_nll = -(arc_loss*float_mask).sum() / length
label_nll = -(label_loss*float_mask).sum() / length
arc_nll = -(arc_loss*float_mask).mean()
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll

def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
"""
Evaluate the performance of prediction.

:return dict: performance results.
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correct predicted labels.
total_tokens: number of predicted tokens
"""
if 'head_pred' in kwargs:
head_pred = kwargs['head_pred']
elif self.use_greedy_infer:
head_pred = self._greedy_decoder(arc_pred, seq_mask)
else:
head_pred = self._mst_decoder(arc_pred, seq_mask)

head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return {"head_pred_correct": head_pred_correct.sum(dim=1),
"label_pred_correct": label_pred_correct.sum(dim=1),
"total_tokens": seq_mask.sum(dim=1)}

def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
"""
Compute the metrics of model

:param head_pred_correct: number of correctly predicted heads.
:param label_pred_correct: number of correct predicted labels.
:param total_tokens: number of predicted tokens
:return dict: the metrics results
UAS: the head predicted accuracy
LAS: the label predicted accuracy
"""
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}


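The new `len2masks` helper replaces the explicit `seq_mask` input: a vector of sequence lengths is broadcast-compared against a position range to get a `[batch, max_len]` mask, which the decoders and loss then use to ignore padding. A short numeric sketch of that computation in plain PyTorch:

```python
import torch

lengths = torch.tensor([5, 3, 1])                           # word_seq_origin_len for a batch of 3
max_len = 6
seq_range = torch.arange(max_len, dtype=torch.long)         # [0, 1, 2, 3, 4, 5]
seq_mask = lengths.unsqueeze(1).gt(seq_range.unsqueeze(0))  # broadcast compare: True on real tokens
print(seq_mask.long())
# tensor([[1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
```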
fastNLP/modules/aggregator/attention.py (+43, -1)

@@ -1,5 +1,6 @@
import torch

from torch import nn
import math
from fastNLP.modules.utils import mask_softmax


@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):

def _atten_forward(self, query, memory):
raise NotImplementedError

class DotAtte(nn.Module):
def __init__(self, key_size, value_size):
super(DotAtte, self).__init__()
self.key_size = key_size
self.value_size = value_size
self.scale = math.sqrt(key_size)

def forward(self, Q, K, V, seq_mask=None):
"""

:param Q: [batch, seq_len, key_size]
:param K: [batch, seq_len, key_size]
:param V: [batch, seq_len, value_size]
:param seq_mask: [batch, seq_len]
"""
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
if seq_mask is not None:
output.masked_fill_(seq_mask.lt(1), -float('inf'))
output = nn.functional.softmax(output, dim=2)
return torch.matmul(output, V)

class MultiHeadAtte(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
super(MultiHeadAtte, self).__init__()
self.in_linear = nn.ModuleList()
for i in range(num_atte * 3):
out_feat = key_size if (i % 3) != 2 else value_size
self.in_linear.append(nn.Linear(input_size, out_feat))
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
self.out_linear = nn.Linear(value_size * num_atte, output_size)

def forward(self, Q, K, V, seq_mask=None):
heads = []
for i in range(len(self.attes)):
j = i * 3
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
headi = self.attes[i](qi, ki, vi, seq_mask)
heads.append(headi)
output = torch.cat(heads, dim=2)
return self.out_linear(output)

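`DotAtte` above is scaled dot-product attention: `softmax(Q·K^T / sqrt(key_size))·V`, with optional masking of padded positions before the softmax. A hedged usage sketch; it assumes the module is importable from the path shown in this diff, and the mask argument is omitted to keep the example minimal:

```python
import torch
from fastNLP.modules.aggregator.attention import DotAtte

batch, seq_len, key_size, value_size = 2, 4, 8, 8
Q = torch.randn(batch, seq_len, key_size)
K = torch.randn(batch, seq_len, key_size)
V = torch.randn(batch, seq_len, value_size)

atte = DotAtte(key_size=key_size, value_size=value_size)
out = atte(Q, K, V)          # scores [batch, seq_len, seq_len] -> softmax -> weighted sum of V
print(out.shape)             # torch.Size([2, 4, 8])
```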
fastNLP/modules/encoder/transformer.py (+32, -0)

@@ -0,0 +1,32 @@
import torch
from torch import nn
import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization

class TransformerEncoder(nn.Module):
class SubLayer(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
super(TransformerEncoder.SubLayer, self).__init__()
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
self.norm1 = LayerNormalization(output_size)
self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
nn.ReLU(),
nn.Linear(output_size, output_size))
self.norm2 = LayerNormalization(output_size)

def forward(self, input, seq_mask):
attention = self.atte(input)
norm_atte = self.norm1(attention + input)
output = self.ffn(norm_atte)
return self.norm2(output + norm_atte)

def __init__(self, num_layers, **kargs):
super(TransformerEncoder, self).__init__()
self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)])

def forward(self, x, seq_mask=None):
return self.layers(x, seq_mask)


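The new encoder stacks sub-layers in the standard post-norm arrangement: multi-head self-attention, residual add and layer norm, then a position-wise feed-forward block with another residual add and norm. An independent sketch of that sub-layer pattern using stock PyTorch modules (not the classes in this diff; it assumes a PyTorch version with `batch_first` on `nn.MultiheadAttention`):

```python
import torch
from torch import nn

d_model = 16
attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=4, batch_first=True)
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
ffn = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))

x = torch.randn(2, 5, d_model)        # [batch, seq_len, d_model]
att_out, _ = attn(x, x, x)            # multi-head self-attention
h = norm1(att_out + x)                # residual connection + layer norm
out = norm2(ffn(h) + h)               # feed-forward, second residual + norm
print(out.shape)                      # torch.Size([2, 5, 16])
```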

fastNLP/modules/encoder/variational_rnn.py (+2, -2)

@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module):

mask_x = input.new_ones((batch_size, self.input_size))
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
mask_h = input.new_ones((batch_size, self.hidden_size))
mask_h_ones = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

hidden_list = []
for layer in range(self.num_layers):
output_list = []
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False)
for direction in range(self.num_directions):
input_x = input if direction == 0 else flip(input, [0])
idx = self.num_directions * layer + direction


fastNLP/modules/other_modules.py (+5, -6)

@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
class LayerNormalization(nn.Module):
""" Layer normalization module """

def __init__(self, d_hid, eps=1e-3):
def __init__(self, layer_size, eps=1e-3):
super(LayerNormalization, self).__init__()

self.eps = eps
self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

def forward(self, z):
if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):

mu = torch.mean(z, keepdim=True, dim=-1)
sigma = torch.std(z, keepdim=True, dim=-1)
ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)

ln_out = (z - mu) / (sigma + self.eps)
ln_out = ln_out * self.a_2 + self.b_2
return ln_out



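`LayerNormalization` above normalizes each position's feature vector to zero mean and unit standard deviation over the last dimension, then applies the learned per-feature scale `a_2` and shift `b_2`. A quick numeric check of that normalization step, with shapes assumed to be `[batch, seq_len, layer_size]`:

```python
import torch

z = torch.randn(2, 5, 8)
mu = z.mean(dim=-1, keepdim=True)
sigma = z.std(dim=-1, keepdim=True)
ln_out = (z - mu) / (sigma + 1e-3)             # a_2 = 1 and b_2 = 0 at initialization
print(ln_out.mean(dim=-1).abs().max().item())  # close to 0
print(ln_out.std(dim=-1).mean().item())        # close to 1
```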

reproduction/Biaffine_parser/cfg.cfg (+10, -7)

@@ -1,37 +1,40 @@
[train]
epochs = 50
epochs = -1
batch_size = 16
pickle_path = "./save/"
validate = true
save_best_dev = false
save_best_dev = true
eval_sort_key = "UAS"
use_cuda = true
model_saved_path = "./save/"
task = "parse"
print_every_step = 20
use_golden_train=true

[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 16
batch_size = 64
pickle_path = "./save/"
use_cuda = true
task = "parse"

[model]
word_vocab_size = -1
word_emb_dim = 100
pos_vocab_size = -1
pos_emb_dim = 100
word_hid_dim = 100
pos_hid_dim = 100
rnn_layers = 3
rnn_hidden_size = 400
arc_mlp_size = 500
label_mlp_size = 100
num_label = -1
dropout = 0.33
use_var_lstm=true
use_var_lstm=false
use_greedy_infer=false

[optim]
lr = 2e-3
weight_decay = 5e-5

reproduction/Biaffine_parser/run.py (+187, -89)

@@ -6,15 +6,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from collections import defaultdict
import math
import torch
import re

from fastNLP.core.trainer import Trainer
from fastNLP.core.metrics import Evaluator
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.core.preprocess import load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@@ -22,15 +24,18 @@ from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver

BOS = '<BOS>'
EOS = '<EOS>'
UNK = '<OOV>'
NUM = '<NUM>'
ENG = '<ENG>'

# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))

class MyDataLoader(object):
def __init__(self, pickle_path):
self.pickle_path = pickle_path

def load(self, path, word_v=None, pos_v=None, headtag_v=None):
class ConlluDataLoader(object):
def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
@@ -49,23 +54,18 @@ class MyDataLoader(object):
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if word_v is not None:
word_v.update(res[0])
pos_v.update(res[1])
headtag_v.update(res[3])
ds.append(Instance(word_seq=TextField(res[0], is_target=False),
pos_seq=TextField(res[1], is_target=False),
head_indices=SeqLabelField(res[2], is_target=True),
head_labels=TextField(res[3], is_target=True),
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))
head_labels=TextField(res[3], is_target=True)))

return ds

def get_one(self, sample):
text = ['<root>']
pos_tags = ['<root>']
heads = [0]
head_tags = ['root']
text = []
pos_tags = []
heads = []
head_tags = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
@@ -76,17 +76,60 @@ class MyDataLoader(object):
head_tags.append(t4)
return (text, pos_tags, heads, head_tags)

def index_data(self, dataset, word_v, pos_v, tag_v):
dataset.index_field('word_seq', word_v)
dataset.index_field('pos_seq', pos_v)
dataset.index_field('head_labels', tag_v)
class CTBDataLoader(object):
def load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
data = self.parse(lines)
return self.convert(data)

def parse(self, lines):
"""
[
[word], [pos], [head_index], [head_tag]
]
"""
sample = []
data = []
for i, line in enumerate(lines):
line = line.strip()
if len(line) == 0 or i+1 == len(lines):
data.append(list(map(list, zip(*sample))))
sample = []
else:
sample.append(line.split())
return data

def convert(self, data):
dataset = DataSet()
for sample in data:
word_seq = [BOS] + sample[0] + [EOS]
pos_seq = [BOS] + sample[1] + [EOS]
heads = [0] + list(map(int, sample[2])) + [0]
head_tags = [BOS] + sample[3] + [EOS]
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
pos_seq=TextField(pos_seq, is_target=False),
gold_heads=SeqLabelField(heads, is_target=False),
head_indices=SeqLabelField(heads, is_target=True),
head_labels=TextField(head_tags, is_target=True)))
return dataset

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
# datadir = "/home/yfshao/UD_English-EWT"
# train_data_name = "en_ewt-ud-train.conllu"
# dev_data_name = "en_ewt-ud-dev.conllu"
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
# loader = ConlluDataLoader()

datadir = '/home/yfshao/workdir/parser-data/'
train_data_name = "train_ctb5.txt"
dev_data_name = "dev_ctb5.txt"
test_data_name = "test_ctb5.txt"
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
loader = CTBDataLoader()

cfgfile = './cfg.cfg'
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
@@ -95,8 +138,12 @@ test_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
print('train Args:', train_args.data)
print('test Args:', test_args.data)
print('optim Args:', optim_args.data)

# Data Loader

# Pickle Loader
def save_data(dirpath, **kwargs):
import _pickle
if not os.path.exists(dirpath):
@@ -117,38 +164,57 @@ def load_data(dirpath):
datas[name] = _pickle.load(f)
return datas

class MyTester(object):
def __init__(self, batch_size, use_cuda=False, **kwagrs):
self.batch_size = batch_size
self.use_cuda = use_cuda

def test(self, model, dataset):
self.model = model.cuda() if self.use_cuda else model
self.model.eval()
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
eval_res = defaultdict(list)
i = 0
for batch_x, batch_y in batchiter:
with torch.no_grad():
pred_y = self.model(**batch_x)
eval_one = self.model.evaluate(**pred_y, **batch_y)
i += self.batch_size
for eval_name, tensor in eval_one.items():
eval_res[eval_name].append(tensor)
tmp = {}
for eval_name, tensorlist in eval_res.items():
tmp[eval_name] = torch.cat(tensorlist, dim=0)

self.res = self.model.metrics(**tmp)

def show_metrics(self):
s = ""
for name, val in self.res.items():
s += '{}: {:.2f}\t'.format(name, val)
return s


loader = MyDataLoader('')
def P2(data, field, length):
ds = [ins for ins in data if ins[field].get_length() >= length]
data.clear()
data.extend(ds)
return ds

def P1(data, field):
def reeng(w):
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG
def renum(w):
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM
for ins in data:
ori = ins[field].contents()
s = list(map(renum, map(reeng, ori)))
if s != ori:
# print(ori)
# print(s)
# print()
ins[field] = ins[field].new(s)
return data

class ParserEvaluator(Evaluator):
def __init__(self, ignore_label):
super(ParserEvaluator, self).__init__()
self.ignore = ignore_label

def __call__(self, predict_list, truth_list):
head_all, label_all, total_all = 0, 0, 0
for pred, truth in zip(predict_list, truth_list):
head, label, total = self.evaluate(**pred, **truth)
head_all += head
label_all += label
total_all += total

return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all}

def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_):
"""
Evaluate the performance of prediction.

:return : performance results.
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correct predicted labels.
total_tokens: number of predicted tokens
"""
seq_mask *= (head_labels != self.ignore).long()
head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item()

try:
data_dict = load_data(processed_datadir)
word_v = data_dict['word_v']
@@ -156,62 +222,90 @@ try:
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
test_data = data_dict['test_data']
print('use saved pickles')

except Exception as _:
print('load raw data and preprocess')
# use pretrain embedding
word_v = Vocabulary(need_default=True, min_freq=2)
word_v.unknown_label = UNK
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
train_data = loader.load(os.path.join(datadir, train_data_name))
dev_data = loader.load(os.path.join(datadir, dev_data_name))
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
test_data = loader.load(os.path.join(datadir, test_data_name))
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
datasets = (train_data, dev_data, test_data)
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)

loader.index_data(train_data, word_v, pos_v, tag_v)
loader.index_data(dev_data, word_v, pos_v, tag_v)
print(len(train_data))
print(len(dev_data))
ep = train_args['epochs']
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
print(len(word_v))
print(embed.size())
# Model
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)

model = BiaffineParser(**model_args.data)
model.reset_parameters()
datasets = (train_data, dev_data, test_data)
for ds in datasets:
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
ds.set_origin_len('word_seq')
if train_args['use_golden_train']:
train_data.set_target(gold_heads=False)
else:
train_data.set_target(gold_heads=None)
train_args.data.pop('use_golden_train')
ignore_label = pos_v['P']

print(test_data[0])
print(len(train_data))
print(len(dev_data))
print(len(test_data))



def train():
def train(path):
# Trainer
trainer = Trainer(**train_args.data)

def _define_optim(obj):
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))
lr = optim_args.data['lr']
embed_params = set(obj._model.word_embedding.parameters())
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
obj._optimizer = torch.optim.Adam([
{'params': list(embed_params), 'lr':lr*0.1},
{'params': list(decay_params), **optim_args.data},
{'params': params}
], lr=lr, betas=(0.9, 0.9))
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))

def _update(obj):
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
obj._scheduler.step()
obj._optimizer.step()

trainer.define_optimizer = lambda: _define_optim(trainer)
trainer.update = lambda: _update(trainer)
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
trainer._create_validator = lambda x: MyTester(**test_args.data)

# Model
model = BiaffineParser(**model_args.data)
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))

# use pretrain embedding
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
model.word_embedding.padding_idx = word_v.padding_idx
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
model.pos_embedding.padding_idx = pos_v.padding_idx
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Continue.")
pass
# try:
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
# print('model parameter loaded!')
# except Exception as _:
# print("No saved model. Continue.")
# pass

# Start training
trainer.train(model, train_data, dev_data)
@@ -223,24 +317,27 @@ def train():
print("Model saved!")


def test():
def test(path):
# Tester
tester = MyTester(**test_args.data)
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))

# Model
model = BiaffineParser(**model_args.data)
model.eval()
try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
ModelLoader.load_pytorch(model, path)
print('model parameter loaded!')
except Exception as _:
print("No saved model. Abort test.")
raise

# Start training
print("Testing Train data")
tester.test(model, train_data)
print("Testing Dev data")
tester.test(model, dev_data)
print(tester.show_metrics())
print("Testing finished!")
print("Testing Test data")
tester.test(model, test_data)



@@ -248,11 +345,12 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
parser.add_argument('--path', type=str, default='')
args = parser.parse_args()
if args.mode == 'train':
train()
train(args.path)
elif args.mode == 'test':
test()
test(args.path)
elif args.mode == 'infer':
infer()
else:


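`ParserEvaluator` above accumulates three counts per batch: heads predicted correctly, labels correct where the head is also correct, and total unmasked tokens; UAS and LAS are the two resulting ratios. A toy worked example of that bookkeeping (the tensors are made up for illustration):

```python
import torch

head_pred    = torch.tensor([[0, 2, 0, 3]])   # predicted head index per token
head_indices = torch.tensor([[0, 2, 0, 2]])   # gold heads
label_preds  = torch.tensor([[1, 3, 1, 4]])   # predicted label ids (after argmax)
head_labels  = torch.tensor([[1, 3, 1, 4]])   # gold label ids
seq_mask     = torch.tensor([[1, 1, 1, 1]])   # 1 on real tokens; ignored labels would be zeroed here

head_correct  = (head_pred == head_indices).long() * seq_mask
label_correct = (label_preds == head_labels).long() * head_correct  # counted only where the head is right
total = seq_mask.sum().item()
print('UAS', head_correct.sum().item() / total)    # 0.75
print('LAS', label_correct.sum().item() / total)   # 0.75: a wrong head drags LAS down too
```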