Browse Source

fix and update tester, trainer, seq_model, add parser pipeline builder

tags/v0.2.0
yunfan 5 years ago
parent
commit
822aaf6286
7 changed files with 208 additions and 131 deletions
  1. +5
    -7
      fastNLP/core/metrics.py
  2. +10
    -12
      fastNLP/core/tester.py
  3. +24
    -14
      fastNLP/core/trainer.py
  4. +21
    -27
      fastNLP/models/biaffine_parser.py
  5. +62
    -67
      fastNLP/models/sequence_modeling.py
  6. +6
    -4
      fastNLP/modules/utils.py
  7. +80
    -0
      reproduction/Biaffine_parser/infer.py

+ 5
- 7
fastNLP/core/metrics.py View File

@@ -35,23 +35,21 @@ class SeqLabelEvaluator(Evaluator):
def __init__(self):
super(SeqLabelEvaluator, self).__init__()

def __call__(self, predict, truth):
def __call__(self, predict, truth, **_):
"""

:param predict: list of dict, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return accuracy:
"""
truth = [item["truth"] for item in truth]
predict = [item["predict"] for item in predict]
total_correct, total_count = 0., 0.
total_correct, total_count = 0., 0.
for x, y in zip(predict, truth):
# x = torch.tensor(x)
y = y.to(x) # make sure they are in the same device
mask = x.ge(1).long()
correct = torch.sum(x * mask == y * mask) - torch.sum(x.le(0))
mask = (y > 0)
correct = torch.sum(((x == y) * mask).long())
total_correct += float(correct)
total_count += float(torch.sum(mask))
total_count += float(torch.sum(mask.long()))
accuracy = total_correct / total_count
return {"accuracy": float(accuracy)}



+ 10
- 12
fastNLP/core/tester.py View File

@@ -1,4 +1,5 @@
import torch
from collections import defaultdict

from fastNLP.core.batch import Batch
from fastNLP.core.metrics import Evaluator
@@ -71,17 +72,18 @@ class Tester(object):
# turn on the testing mode; clean up the history
self.mode(network, is_test=True)
self.eval_history.clear()
output_list = []
truth_list = []

output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)

with torch.no_grad():
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
output_list.append(prediction)
truth_list.append(batch_y)
eval_results = self.evaluate(output_list, truth_list)
assert isinstance(prediction, dict)
for k, v in prediction.items():
output[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
eval_results = self.evaluate(**output, **truths)
print("[tester] {}".format(self.print_eval_results(eval_results)))
logger.info("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
@@ -105,14 +107,10 @@ class Tester(object):
y = network(**x)
return y

def evaluate(self, predict, truth):
def evaluate(self, **kwargs):
"""Compute evaluation metrics.

:param predict: list of Tensor
:param truth: list of dict
:return eval_results: can be anything. It will be stored in self.eval_history
"""
return self._evaluator(predict, truth)
return self._evaluator(**kwargs)

def print_eval_results(self, results):
"""Override this method to support more print formats.


+ 24
- 14
fastNLP/core/trainer.py View File

@@ -47,7 +47,8 @@ class Trainer(object):
"valid_step": 500, "eval_sort_key": 'acc',
"loss": Loss(None), # used to pass type check
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"evaluator": Evaluator()
"eval_batch_size": 64,
"evaluator": Evaluator(),
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
@@ -78,6 +79,7 @@ class Trainer(object):

self.n_epochs = int(default_args["epochs"])
self.batch_size = int(default_args["batch_size"])
self.eval_batch_size = int(default_args['eval_batch_size'])
self.pickle_path = default_args["pickle_path"]
self.validate = default_args["validate"]
self.save_best_dev = default_args["save_best_dev"]
@@ -98,6 +100,8 @@ class Trainer(object):
self._best_accuracy = 0.0
self.eval_sort_key = default_args['eval_sort_key']
self.validator = None
self.epoch = 0
self.step = 0

def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@@ -118,7 +122,7 @@ class Trainer(object):
# define Tester over dev data
self.dev_data = None
if self.validate:
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
if self.validator is None:
self.validator = self._create_validator(default_valid_args)
@@ -139,9 +143,9 @@ class Trainer(object):
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
while epoch <= self.n_epochs:
logger.info("training epoch {}".format(epoch))
self.epoch, self.step = 1, 0
while self.epoch <= self.n_epochs:
logger.info("training epoch {}".format(self.epoch))

# prepare mini-batch iterator
data_iterator = Batch(train_data, batch_size=self.batch_size,
@@ -150,14 +154,13 @@ class Trainer(object):
logger.info("prepared data iterator")

# one forward and backward pass
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch,
step=iters, dev_data=dev_data)
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data)

# validation
if self.validate:
self.valid_model()
self.save_model(self._model, 'training_model_' + self.start_time)
epoch += 1
self.epoch += 1

def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch.
@@ -167,7 +170,6 @@ class Trainer(object):
- start: time.time(), the starting time of this step.
- epoch: int,
"""
step = kwargs['step']
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)

@@ -177,25 +179,31 @@ class Trainer(object):

self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self._model.named_parameters():
if param.requires_grad:
<<<<<<< HEAD
# self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step)
pass

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
=======
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
>>>>>>> 5924fe0... fix and update tester, trainer, seq_model, add parser pipeline builder
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
kwargs["epoch"], step, loss.data, diff)
self.epoch, self.step, loss.data, diff)
print(print_output)
logger.info(print_output)
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0:
self.valid_model()
step += 1
return step
self.step += 1

def valid_model(self):
if self.dev_data is None:
@@ -203,6 +211,8 @@ class Trainer(object):
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
res = self.validator.test(self._model, self.dev_data)
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_best_dev and self.best_eval_result(res):
logger.info('save best result! {}'.format(res))
print('save best result! {}'.format(res))


+ 21
- 27
fastNLP/models/biaffine_parser.py View File

@@ -10,6 +10,7 @@ from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.utils import seq_mask

def mst(scores):
"""
@@ -123,31 +124,31 @@ class GraphParser(BaseModel):
def forward(self, x):
raise NotImplementedError

def _greedy_decoder(self, arc_matrix, seq_mask=None):
def _greedy_decoder(self, arc_matrix, mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (seq_mask == 0).byte()
flip_mask = (mask == 0).byte()
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
if mask is not None:
heads *= mask.long()
return heads

def _mst_decoder(self, arc_matrix, seq_mask=None):
def _mst_decoder(self, arc_matrix, mask=None):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (seq_mask.long()).sum(1) if seq_mask is not None else torch.zeros(batch_size) + seq_len
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
seq_mask[batch_idx, lens-1] = 0
mask[batch_idx, lens-1] = 0
for i, graph in enumerate(matrix):
len_i = lens[i]
if len_i == seq_len:
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
else:
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
if mask is not None:
ans *= mask.long()
return ans


@@ -191,13 +192,6 @@ class LabelBilinear(nn.Module):
output += self.lin(torch.cat([x1, x2], dim=2))
return output

def len2masks(origin_len, max_len):
if origin_len.dim() <= 1:
origin_len = origin_len.unsqueeze(1) # [batch_size, 1]
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=origin_len.device) # [max_len,]
seq_mask = torch.gt(origin_len, seq_range.unsqueeze(0)) # [batch_size, max_len]
return seq_mask

class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implemantation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
@@ -277,12 +271,12 @@ class BiaffineParser(GraphParser):
"""
:param word_seq: [batch_size, seq_len] sequence of word's indices
:param pos_seq: [batch_size, seq_len] sequence of word's indices
:param seq_mask: [batch_size, seq_len] sequence of length masks
:param word_seq_origin_len: [batch_size, seq_len] sequence of length masks
:param gold_heads: [batch_size, seq_len] sequence of golden heads
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, seq_len]
seq_mask: [batch_size, seq_len]
mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
"""
# prepare embeddings
@@ -294,7 +288,7 @@ class BiaffineParser(GraphParser):
# print('forward {} {}'.format(batch_size, seq_len))

# get sequence mask
seq_mask = len2masks(word_seq_origin_len, seq_len).long()
mask = seq_mask(word_seq_origin_len, seq_len).long()

word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
@@ -327,14 +321,14 @@ class BiaffineParser(GraphParser):
if gold_heads is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, seq_mask)
heads = self._greedy_decoder(arc_pred, mask)
else:
heads = self._mst_decoder(arc_pred, seq_mask)
heads = self._mst_decoder(arc_pred, mask)
head_pred = heads
else:
assert self.training # must be training mode
if torch.rand(1).item() < self.explore_p:
heads = self._greedy_decoder(arc_pred, seq_mask)
heads = self._greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
@@ -343,12 +337,12 @@ class BiaffineParser(GraphParser):
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
if head_pred is not None:
res_dict['head_pred'] = head_pred
return res_dict

def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_):
def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
"""
Compute loss.

@@ -356,12 +350,12 @@ class BiaffineParser(GraphParser):
:param label_pred: [batch_size, seq_len, n_tags]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
:param mask: [batch_size, seq_len]
:return: loss value
"""

batch_size, seq_len, _ = arc_pred.shape
flip_mask = (seq_mask == 0)
flip_mask = (mask == 0)
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred)
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
arc_logits = F.log_softmax(_arc_pred, dim=2)
@@ -374,7 +368,7 @@ class BiaffineParser(GraphParser):
arc_loss = arc_loss[:, 1:]
label_loss = label_loss[:, 1:]

float_mask = seq_mask[:, 1:].float()
float_mask = mask[:, 1:].float()
arc_nll = -(arc_loss*float_mask).mean()
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll


+ 62
- 67
fastNLP/models/sequence_modeling.py View File

@@ -4,20 +4,7 @@ import numpy as np

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder


def seq_mask(seq_len, max_len):
"""Create a mask for the sequences.

:param seq_len: list or torch.LongTensor
:param max_len: int
:return mask: torch.LongTensor
"""
if isinstance(seq_len, list):
seq_len = torch.LongTensor(seq_len)
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)]
mask = torch.stack(mask, 1)
return mask
from fastNLP.modules.utils import seq_mask


class SeqLabeling(BaseModel):
@@ -82,7 +69,7 @@ class SeqLabeling(BaseModel):
def make_mask(self, x, seq_len):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
mask = mask.view(batch_size, max_len)
mask = mask.to(x).float()
return mask

@@ -114,16 +101,20 @@ class AdvSeqLabel(SeqLabeling):
word_emb_dim = args["word_emb_dim"]
hidden_dim = args["rnn_hidden_units"]
num_classes = args["num_classes"]
dropout = args['dropout']

self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=1, dropout=0.5, bidirectional=True)
self.norm1 = torch.nn.LayerNorm(word_emb_dim)
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
self.relu = torch.nn.ReLU()
self.drop = torch.nn.Dropout(0.5)
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
self.relu = torch.nn.LeakyReLU()
self.drop = torch.nn.Dropout(dropout)
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)

def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
@@ -135,12 +126,10 @@ class AdvSeqLabel(SeqLabeling):
"""

word_seq = word_seq.long()
word_seq_origin_len = word_seq_origin_len.long()
self.mask = self.make_mask(word_seq, word_seq_origin_len)
word_seq_origin_len = word_seq_origin_len.cpu().numpy()
sent_len, idx_sort = np.sort(word_seq_origin_len)[::-1], np.argsort(-word_seq_origin_len)
idx_unsort = np.argsort(idx_sort)
idx_sort = torch.from_numpy(idx_sort)
idx_unsort = torch.from_numpy(idx_unsort)
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True)
_, idx_unsort = torch.sort(idx_sort, descending=False)

# word_seq_origin_len = word_seq_origin_len.long()
truth = truth.long() if truth is not None else None
@@ -155,26 +144,28 @@ class AdvSeqLabel(SeqLabeling):
truth = truth.cuda() if truth is not None else None

x = self.Embedding(word_seq)
x = self.norm1(x)
# [batch_size, max_len, word_emb_dim]

sent_variable = x.index_select(0, idx_sort)
sent_variable = x[idx_sort]
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)

x = self.Rnn(sent_packed)
x, _ = self.Rnn(sent_packed)
# print(x)
# [batch_size, max_len, hidden_size * direction]

sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
x = sent_output.index_select(0, idx_unsort)
x = sent_output[idx_unsort]

x = x.contiguous()
x = x.view(batch_size * max_len, -1)
# x = x.view(batch_size * max_len, -1)
x = self.Linear1(x)
# x = self.batch_norm(x)
x = self.norm2(x)
x = self.relu(x)
x = self.drop(x)
x = self.Linear2(x)
x = x.view(batch_size, max_len, -1)
# x = x.view(batch_size, max_len, -1)
# [batch_size, max_len, num_classes]
return {"loss": self._internal_loss(x, truth) if truth is not None else None,
"predict": self.decode(x)}
@@ -183,41 +174,45 @@ class AdvSeqLabel(SeqLabeling):
out = self.forward(**x)
return {"predict": out["predict"]}


args = {
'vocab_size': 20,
'word_emb_dim': 100,
'rnn_hidden_units': 100,
'num_classes': 10,
}
model = AdvSeqLabel(args)
data = []
for i in range(20):
word_seq = torch.randint(20, (15,)).long()
word_seq_len = torch.LongTensor([15])
truth = torch.randint(10, (15,)).long()
data.append((word_seq, word_seq_len, truth))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print(model)
curidx = 0
for i in range(1000):
endidx = min(len(data), curidx + 5)
b_word, b_len, b_truth = [], [], []
for word_seq, word_seq_len, truth in data[curidx: endidx]:
b_word.append(word_seq)
b_len.append(word_seq_len)
b_truth.append(truth)
word_seq = torch.stack(b_word, dim=0)
word_seq_len = torch.cat(b_len, dim=0)
truth = torch.stack(b_truth, dim=0)
res = model(word_seq, word_seq_len, truth)
loss = res['loss']
pred = res['predict']
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
optimizer.zero_grad()
loss.backward()
optimizer.step()
curidx = endidx
if curidx == len(data):
curidx = 0
def loss(self, **kwargs):
assert 'loss' in kwargs
return kwargs['loss']

if __name__ == '__main__':
args = {
'vocab_size': 20,
'word_emb_dim': 100,
'rnn_hidden_units': 100,
'num_classes': 10,
}
model = AdvSeqLabel(args)
data = []
for i in range(20):
word_seq = torch.randint(20, (15,)).long()
word_seq_len = torch.LongTensor([15])
truth = torch.randint(10, (15,)).long()
data.append((word_seq, word_seq_len, truth))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print(model)
curidx = 0
for i in range(1000):
endidx = min(len(data), curidx + 5)
b_word, b_len, b_truth = [], [], []
for word_seq, word_seq_len, truth in data[curidx: endidx]:
b_word.append(word_seq)
b_len.append(word_seq_len)
b_truth.append(truth)
word_seq = torch.stack(b_word, dim=0)
word_seq_len = torch.cat(b_len, dim=0)
truth = torch.stack(b_truth, dim=0)
res = model(word_seq, word_seq_len, truth)
loss = res['loss']
pred = res['predict']
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
optimizer.zero_grad()
loss.backward()
optimizer.step()
curidx = endidx
if curidx == len(data):
curidx = 0


+ 6
- 4
fastNLP/modules/utils.py View File

@@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None):
def seq_mask(seq_len, max_len):
"""Create sequence mask.

:param seq_len: list of int, the lengths of sequences in a batch.
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
:param max_len: int, the maximum sequence length in a batch.
:return mask: torch.LongTensor, [batch_size, max_len]

"""
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
mask = torch.stack(mask, 1)
return mask
if not isinstance(seq_len, torch.Tensor):
seq_len = torch.LongTensor(seq_len)
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1]
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len]
return torch.gt(seq_len, seq_range) # [batch_size, max_len]

+ 80
- 0
reproduction/Biaffine_parser/infer.py View File

@@ -0,0 +1,80 @@
import sys
import os

sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])

from fastNLP.api.processor import *
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.dataset import DataSet
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader

import _pickle as pickle
import torch

def _load(path):
with open(path, 'rb') as f:
obj = pickle.load(f)
return obj

def _load_all(src):
model_path = src
src = os.path.dirname(src)

word_v = _load(src+'/word_v.pkl')
pos_v = _load(src+'/pos_v.pkl')
tag_v = _load(src+'/tag_v.pkl')

model_args = ConfigSection()
ConfigLoader.load_config('cfg.cfg', {'model': model_args})
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)

model = BiaffineParser(**model_args.data)
model.load_state_dict(torch.load(model_path))
return {
'word_v': word_v,
'pos_v': pos_v,
'tag_v': tag_v,
'model': model,
}

def build(load_path, save_path):
BOS = '<BOS>'
NUM = '<NUM>'
_dict = _load_all(load_path)
word_vocab = _dict['word_v']
pos_vocab = _dict['pos_v']
tag_vocab = _dict['tag_v']
model = _dict['model']
print('load model from {}'.format(load_path))
word_seq = 'raw_word_seq'
pos_seq = 'raw_pos_seq'

# build pipeline
pipe = Pipeline()
pipe.add_processor(Num2TagProcessor(NUM, 'sentence', word_seq))
pipe.add_processor(PreAppendProcessor(BOS, word_seq))
pipe.add_processor(PreAppendProcessor(BOS, 'sent_pos', pos_seq))
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq'))
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq'))
pipe.add_processor(SeqLenProcessor(word_seq, 'word_seq_origin_len'))
pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False))
pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len'))
pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads'))
pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred'))
pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels'))
if not os.path.exists(save_path):
os.makedirs(save_path)
with open(save_path+'/pipeline.pkl', 'wb') as f:
torch.save(pipe, f)
print('save pipeline in {}'.format(save_path))


import argparse
parser = argparse.ArgumentParser(description='build pipeline for parser.')
parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save')
parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe')
args = parser.parse_args()
build(args.src, args.dst)

Loading…
Cancel
Save