@@ -202,30 +202,30 @@ class Parser(API): | |||||
if model_path is None: | if model_path is None: | ||||
model_path = model_urls['parser'] | model_path = model_urls['parser'] | ||||
self.pos_tagger = POS(device=device) | |||||
self.load(model_path, device) | self.load(model_path, device) | ||||
def predict(self, content): | def predict(self, content): | ||||
if not hasattr(self, 'pipeline'): | if not hasattr(self, 'pipeline'): | ||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
sentence_list = [] | |||||
# 1. 检查sentence的类型 | |||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 1. 利用POS得到分词和pos tagging结果 | |||||
pos_out = self.pos_tagger.predict(content) | |||||
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()] | |||||
# 2. 组建dataset | # 2. 组建dataset | ||||
dataset = DataSet() | dataset = DataSet() | ||||
dataset.add_field('words', sentence_list) | |||||
# dataset.add_field('tag', sentence_list) | |||||
dataset.add_field('wp', pos_out) | |||||
dataset.apply(lambda x: ['<BOS>']+[w.split('/')[0] for w in x['wp']], new_field_name='words') | |||||
dataset.apply(lambda x: ['<BOS>']+[w.split('/')[1] for w in x['wp']], new_field_name='pos') | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
for ins in dataset: | |||||
ins['heads'] = ins['heads'].tolist() | |||||
return dataset['heads'], dataset['labels'] | |||||
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') | |||||
dataset.apply(lambda x: [arc + '/' + label for arc, label in | |||||
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') | |||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']] | |||||
return dataset.field_arrays['output'].content | |||||
def test(self, filepath): | def test(self, filepath): | ||||
data = ConllxDataLoader().load(filepath) | data = ConllxDataLoader().load(filepath) | ||||
@@ -301,12 +301,12 @@ class Analyzer: | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||||
pos = POS(pos_model_path, device='cpu') | |||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
print(pos.test("/home/zyfeng/data/sample.conllx")) | |||||
# pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||||
# pos = POS(pos_model_path, device='cpu') | |||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(pos.test("/home/zyfeng/data/sample.conllx")) | |||||
# print(pos.predict(s)) | # print(pos.predict(s)) | ||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | ||||
@@ -317,9 +317,10 @@ if __name__ == "__main__": | |||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | # print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | ||||
# print(cws.predict(s)) | # print(cws.predict(s)) | ||||
# parser = Parser(device='cpu') | |||||
parser_path = '/home/yfshao/workdir/fastnlp/reproduction/Biaffine_parser/pipe.pkl' | |||||
parser = Parser(parser_path, device='cpu') | |||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | ||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | ||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | ||||
'那么这款无人机到底有多厉害?'] | '那么这款无人机到底有多厉害?'] | ||||
# print(parser.predict(s)) | |||||
print(parser.predict(s)) |
@@ -302,15 +302,23 @@ class Index2WordProcessor(Processor): | |||||
return dataset | return dataset | ||||
class SetIsTargetProcessor(Processor): | |||||
class SetTargetProcessor(Processor): | |||||
# TODO; remove it. | # TODO; remove it. | ||||
def __init__(self, field_dict, default=False): | |||||
super(SetIsTargetProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetTargetProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | def process(self, dataset): | ||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset.set_target(*set_dict.keys()) | |||||
dataset.set_target(*self.fields, flag=self.flag) | |||||
return dataset | |||||
class SetInputProcessor(Processor): | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetInputProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | |||||
dataset.set_input(*self.fields, flag=self.flag) | |||||
return dataset | return dataset |
@@ -400,7 +400,7 @@ def seq_lens_to_masks(seq_lens, float=False): | |||||
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | ||||
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | ||||
raise NotImplemented | raise NotImplemented | ||||
elif isinstance(seq_lens, torch.LongTensor): | |||||
elif isinstance(seq_lens, torch.Tensor): | |||||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | ||||
batch_size = seq_lens.size(0) | batch_size = seq_lens.size(0) | ||||
max_len = seq_lens.max() | max_len = seq_lens.max() | ||||
@@ -134,17 +134,13 @@ class GraphParser(BaseModel): | |||||
def _mst_decoder(self, arc_matrix, mask=None): | def _mst_decoder(self, arc_matrix, mask=None): | ||||
batch_size, seq_len, _ = arc_matrix.shape | batch_size, seq_len, _ = arc_matrix.shape | ||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||||
matrix = arc_matrix.clone() | |||||
ans = matrix.new_zeros(batch_size, seq_len).long() | ans = matrix.new_zeros(batch_size, seq_len).long() | ||||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | ||||
mask[batch_idx, lens-1] = 0 | |||||
for i, graph in enumerate(matrix): | for i, graph in enumerate(matrix): | ||||
len_i = lens[i] | len_i = lens[i] | ||||
if len_i == seq_len: | |||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
else: | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
if mask is not None: | if mask is not None: | ||||
ans *= mask.long() | ans *= mask.long() | ||||
return ans | return ans | ||||
@@ -219,6 +215,7 @@ class BiaffineParser(GraphParser): | |||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | ||||
self.word_norm = nn.LayerNorm(word_hid_dim) | self.word_norm = nn.LayerNorm(word_hid_dim) | ||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | self.pos_norm = nn.LayerNorm(pos_hid_dim) | ||||
self.use_var_lstm = use_var_lstm | |||||
if use_var_lstm: | if use_var_lstm: | ||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | ||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
@@ -249,10 +246,9 @@ class BiaffineParser(GraphParser): | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.normal_dropout = nn.Dropout(p=dropout) | |||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
self.reset_parameters() | self.reset_parameters() | ||||
self.explore_p = 0.2 | |||||
self.dropout = dropout | |||||
def reset_parameters(self): | def reset_parameters(self): | ||||
for m in self.modules(): | for m in self.modules(): | ||||
@@ -278,18 +274,15 @@ class BiaffineParser(GraphParser): | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | ||||
""" | """ | ||||
# prepare embeddings | # prepare embeddings | ||||
device = self.parameters().__next__().device | |||||
word_seq = word_seq.long().to(device) | |||||
pos_seq = pos_seq.long().to(device) | |||||
seq_lens = seq_lens.long().to(device).view(-1) | |||||
batch_size, seq_len = word_seq.shape | batch_size, seq_len = word_seq.shape | ||||
# print('forward {} {}'.format(batch_size, seq_len)) | # print('forward {} {}'.format(batch_size, seq_len)) | ||||
# get sequence mask | # get sequence mask | ||||
mask = seq_mask(seq_lens, seq_len).long() | mask = seq_mask(seq_lens, seq_len).long() | ||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||||
word = self.word_embedding(word_seq) # [N,L] -> [N,L,C_0] | |||||
pos = self.pos_embedding(pos_seq) # [N,L] -> [N,L,C_1] | |||||
word, pos = self.word_fc(word), self.pos_fc(pos) | word, pos = self.word_fc(word), self.pos_fc(pos) | ||||
word, pos = self.word_norm(word), self.pos_norm(pos) | word, pos = self.word_norm(word), self.pos_norm(pos) | ||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
@@ -325,7 +318,7 @@ class BiaffineParser(GraphParser): | |||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
assert self.training # must be training mode | assert self.training # must be training mode | ||||
if torch.rand(1).item() < self.explore_p: | |||||
if gold_heads is None: | |||||
heads = self._greedy_decoder(arc_pred, mask) | heads = self._greedy_decoder(arc_pred, mask) | ||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
@@ -355,7 +348,7 @@ class BiaffineParser(GraphParser): | |||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
flip_mask = (mask == 0) | flip_mask = (mask == 0) | ||||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||||
_arc_pred = arc_pred.clone() | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | ||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | arc_logits = F.log_softmax(_arc_pred, dim=2) | ||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
@@ -421,7 +414,9 @@ class ParserMetric(MetricBase): | |||||
if seq_lens is None: | if seq_lens is None: | ||||
seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) | seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) | ||||
else: | else: | ||||
seq_mask = seq_lens_to_masks(seq_lens, float=False).long() | |||||
seq_mask = seq_lens_to_masks(seq_lens.long(), float=False).long() | |||||
# mask out <root> tag | |||||
seq_mask[:,0] = 0 | |||||
head_pred_correct = (arc_pred == arc_true).long() * seq_mask | head_pred_correct = (arc_pred == arc_true).long() * seq_mask | ||||
label_pred_correct = (label_pred == label_true).long() * head_pred_correct | label_pred_correct = (label_pred == label_true).long() * head_pred_correct | ||||
self.num_arc += head_pred_correct.sum().item() | self.num_arc += head_pred_correct.sum().item() | ||||
@@ -2,8 +2,7 @@ import math | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from torch.nn.utils.rnn import PackedSequence | |||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
try: | try: | ||||
@@ -25,30 +24,63 @@ class VarRnnCellWrapper(nn.Module): | |||||
self.input_p = input_p | self.input_p = input_p | ||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input, hidden, mask_x=None, mask_h=None): | |||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||||
""" | """ | ||||
:param input: [seq_len, batch_size, input_size] | |||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | |||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | ||||
for other RNN, h_0, [batch_size, hidden_size] | for other RNN, h_0, [batch_size, hidden_size] | ||||
:param mask_x: [batch_size, input_size] dropout mask for input | :param mask_x: [batch_size, input_size] dropout mask for input | ||||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | :param mask_h: [batch_size, hidden_size] dropout mask for hidden | ||||
:return output: [seq_len, bacth_size, hidden_size] | |||||
:return PackedSequence output: [seq_len, bacth_size, hidden_size] | |||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | ||||
for other RNN, h_n, [batch_size, hidden_size] | for other RNN, h_n, [batch_size, hidden_size] | ||||
""" | """ | ||||
def get_hi(hi, h0, size): | |||||
h0_size = size - hi.size(0) | |||||
if h0_size > 0: | |||||
return torch.cat([hi, h0[:h0_size]], dim=0) | |||||
return hi[:size] | |||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
input = input * mask_x.unsqueeze(0) if mask_x is not None else input | |||||
output_list = [] | |||||
for x in input: | |||||
input, batch_sizes = input_x | |||||
output = [] | |||||
cell = self.cell | |||||
if is_reversed: | |||||
batch_iter = flip(batch_sizes, [0]) | |||||
idx = input.size(0) | |||||
else: | |||||
batch_iter = batch_sizes | |||||
idx = 0 | |||||
if is_lstm: | |||||
hn = (hidden[0].clone(), hidden[1].clone()) | |||||
else: | |||||
hn = hidden.clone() | |||||
hi = hidden | |||||
for size in batch_iter: | |||||
if is_reversed: | |||||
input_i = input[idx-size: idx] * mask_x[:size] | |||||
idx -= size | |||||
else: | |||||
input_i = input[idx: idx+size] * mask_x[:size] | |||||
idx += size | |||||
mask_hi = mask_h[:size] | |||||
if is_lstm: | if is_lstm: | ||||
hx, cx = hidden | |||||
hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx) | |||||
hx, cx = hi | |||||
hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size)) | |||||
hi = cell(input_i, hi) | |||||
hn[0][:size] = hi[0] | |||||
hn[1][:size] = hi[1] | |||||
output.append(hi[0]) | |||||
else: | else: | ||||
hidden *= mask_h if mask_h is not None else hidden | |||||
hidden = self.cell(x, hidden) | |||||
output_list.append(hidden[0] if is_lstm else hidden) | |||||
output = torch.stack(output_list, dim=0) | |||||
return output, hidden | |||||
hi = get_hi(hi, hidden, size) * mask_hi | |||||
hi = cell(input_i, hi) | |||||
hn[:size] = hi | |||||
output.append(hi) | |||||
if is_reversed: | |||||
output = list(reversed(output)) | |||||
output = torch.cat(output, dim=0) | |||||
return PackedSequence(output, batch_sizes), hn | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
@@ -77,60 +109,67 @@ class VarRNNBase(nn.Module): | |||||
cell = Cell(input_size, self.hidden_size, bias) | cell = Cell(input_size, self.hidden_size, bias) | ||||
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | ||||
initial_parameter(self) | initial_parameter(self) | ||||
self.is_lstm = (self.mode == "LSTM") | |||||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | |||||
is_lstm = self.is_lstm | |||||
idx = self.num_directions * n_layer + n_direction | |||||
cell = self._all_cells[idx] | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||||
return output_x, hidden_x | |||||
def forward(self, input, hx=None): | def forward(self, input, hx=None): | ||||
is_lstm = self.is_lstm | |||||
is_packed = isinstance(input, PackedSequence) | is_packed = isinstance(input, PackedSequence) | ||||
is_lstm = (self.mode == "LSTM") | |||||
if is_packed: | |||||
input, batch_sizes = input | |||||
max_batch_size = int(batch_sizes[0]) | |||||
else: | |||||
batch_sizes = None | |||||
if not is_packed: | |||||
seq_len = input.size(1) if self.batch_first else input.size(0) | |||||
max_batch_size = input.size(0) if self.batch_first else input.size(1) | max_batch_size = input.size(0) if self.batch_first else input.size(1) | ||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | |||||
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) | |||||
else: | |||||
max_batch_size = int(input.batch_sizes[0]) | |||||
input, batch_sizes = input | |||||
if hx is None: | if hx is None: | ||||
hx = input.new_zeros(self.num_layers * self.num_directions, | hx = input.new_zeros(self.num_layers * self.num_directions, | ||||
max_batch_size, self.hidden_size, | |||||
requires_grad=False) | |||||
max_batch_size, self.hidden_size, requires_grad=True) | |||||
if is_lstm: | if is_lstm: | ||||
hx = (hx, hx) | |||||
if self.batch_first: | |||||
input = input.transpose(0, 1) | |||||
batch_size = input.shape[1] | |||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | |||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = input.new_ones((batch_size, self.hidden_size)) | |||||
mask_x = input.new_ones((max_batch_size, self.input_size)) | |||||
mask_out = input.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = input.new_ones((max_batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
hidden_list = [] | |||||
hidden = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
if is_lstm: | |||||
cellstate = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
input_seq = PackedSequence(input, batch_sizes) | |||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | ||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_x = input if direction == 0 else flip(input, [0]) | |||||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | |||||
mask_x if layer == 0 else mask_out, mask_h) | |||||
output_list.append(output_x.data) | |||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
cell = self._all_cells[idx] | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
mask_xi = mask_x if layer == 0 else mask_out | |||||
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) | |||||
output_list.append(output_x if direction == 0 else flip(output_x, [0])) | |||||
hidden_list.append(hidden_x) | |||||
if is_lstm: | |||||
hidden[idx] = hidden_x[0] | |||||
cellstate[idx] = hidden_x[1] | |||||
else: | |||||
hidden[idx] = hidden_x | |||||
input = torch.cat(output_list, dim=-1) | input = torch.cat(output_list, dim=-1) | ||||
output = input.transpose(0, 1) if self.batch_first else input | |||||
if is_lstm: | if is_lstm: | ||||
h_list, c_list = zip(*hidden_list) | |||||
hn = torch.stack(h_list, dim=0) | |||||
cn = torch.stack(c_list, dim=0) | |||||
hidden = (hn, cn) | |||||
else: | |||||
hidden = torch.stack(hidden_list, dim=0) | |||||
hidden = (hidden, cellstate) | |||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(output, batch_sizes) | |||||
output = PackedSequence(input, batch_sizes) | |||||
else: | |||||
input = PackedSequence(input, batch_sizes) | |||||
output, _ = pad_packed_sequence(input, batch_first=self.batch_first) | |||||
return output, hidden | return output, hidden | ||||
@@ -152,3 +191,36 @@ class VarGRU(VarRNNBase): | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | ||||
# if __name__ == '__main__': | |||||
# x = torch.Tensor([[1,2,3], [4,5,0], [6,0,0]])[:,:,None] * 0.1 | |||||
# mask = (x != 0).float().view(3, -1) | |||||
# seq_lens = torch.LongTensor([3,2,1]) | |||||
# y = torch.Tensor([[0,1,1], [1,1,0], [0,0,0]]) | |||||
# # rev = _reverse_packed_sequence(pack) | |||||
# # # print(rev) | |||||
# lstm = VarLSTM(input_size=1, num_layers=2, hidden_size=2, | |||||
# batch_first=True, bidirectional=True, | |||||
# input_dropout=0.0, hidden_dropout=0.0,) | |||||
# # lstm = nn.LSTM(input_size=1, num_layers=2, hidden_size=2, | |||||
# # batch_first=True, bidirectional=True,) | |||||
# loss = nn.BCELoss() | |||||
# m = nn.Sigmoid() | |||||
# optim = torch.optim.SGD(lstm.parameters(), lr=1e-3) | |||||
# for i in range(2000): | |||||
# optim.zero_grad() | |||||
# pack = pack_padded_sequence(x, seq_lens, batch_first=True) | |||||
# out, hidden = lstm(pack) | |||||
# out, lens = pad_packed_sequence(out, batch_first=True) | |||||
# # print(lens) | |||||
# # print(out) | |||||
# # print(hidden[0]) | |||||
# # print(hidden[0].size()) | |||||
# # print(hidden[1]) | |||||
# out = out.sum(-1) | |||||
# out = m(out) * mask | |||||
# l = loss(out, y) | |||||
# l.backward() | |||||
# optim.step() | |||||
# if i % 50 == 0: | |||||
# print(out) |
@@ -1,13 +1,8 @@ | |||||
[train] | [train] | ||||
epochs = -1 | |||||
batch_size = 16 | |||||
pickle_path = "./save/" | |||||
validate = true | |||||
save_best_dev = true | |||||
eval_sort_key = "UAS" | |||||
n_epochs = 40 | |||||
batch_size = 32 | |||||
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | |||||
print_every_step = 20 | |||||
validate_every = 500 | |||||
use_golden_train=true | use_golden_train=true | ||||
[test] | [test] | ||||
@@ -32,9 +27,9 @@ arc_mlp_size = 500 | |||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | dropout = 0.33 | ||||
use_var_lstm=false | |||||
use_var_lstm=true | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 2e-3 | |||||
weight_decay = 5e-5 | |||||
lr = 3e-4 | |||||
;weight_decay = 3e-5 |
@@ -3,24 +3,26 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
import fastNLP | |||||
import torch | import torch | ||||
import re | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import TextField, SeqLabelField | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | from fastNLP.io.config_io import ConfigLoader, ConfigSection | ||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.model_io import ModelLoader | |||||
from fastNLP.io.embed_loader import EmbedLoader | from fastNLP.io.embed_loader import EmbedLoader | ||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader | |||||
from fastNLP.api.processor import * | |||||
BOS = '<BOS>' | BOS = '<BOS>' | ||||
EOS = '<EOS>' | EOS = '<EOS>' | ||||
UNK = '<OOV>' | |||||
UNK = '<UNK>' | |||||
NUM = '<NUM>' | NUM = '<NUM>' | ||||
ENG = '<ENG>' | ENG = '<ENG>' | ||||
@@ -28,85 +30,25 @@ ENG = '<ENG>' | |||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
class ConlluDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet(name='conll') | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | |||||
pos_seq=TextField(res[1], is_target=False), | |||||
head_indices=SeqLabelField(res[2], is_target=True), | |||||
head_labels=TextField(res[3], is_target=True))) | |||||
return ds | |||||
def get_one(self, sample): | |||||
text = [] | |||||
pos_tags = [] | |||||
heads = [] | |||||
head_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
continue | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
heads.append(int(t3)) | |||||
head_tags.append(t4) | |||||
return (text, pos_tags, heads, head_tags) | |||||
class CTBDataLoader(object): | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return self.convert(data) | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i+1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
return data | |||||
def convert(self, data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] + [EOS] | |||||
pos_seq = [BOS] + sample[1] + [EOS] | |||||
heads = [0] + list(map(int, sample[2])) + [0] | |||||
head_tags = [BOS] + sample[3] + [EOS] | |||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||||
pos_seq=TextField(pos_seq, is_target=False), | |||||
gold_heads=SeqLabelField(heads, is_target=False), | |||||
head_indices=SeqLabelField(heads, is_target=True), | |||||
head_labels=TextField(head_tags, is_target=True))) | |||||
return dataset | |||||
def convert(data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] | |||||
pos_seq = [BOS] + sample[1] | |||||
heads = [0] + list(map(int, sample[2])) | |||||
head_tags = [BOS] + sample[3] | |||||
dataset.append(Instance(words=word_seq, | |||||
pos=pos_seq, | |||||
gold_heads=heads, | |||||
arc_true=heads, | |||||
tags=head_tags)) | |||||
return dataset | |||||
def load(path): | |||||
data = ConllxDataLoader().load(path) | |||||
return convert(data) | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | ||||
# datadir = "/home/yfshao/UD_English-EWT" | # datadir = "/home/yfshao/UD_English-EWT" | ||||
@@ -115,26 +57,29 @@ class CTBDataLoader(object): | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | # emb_file_name = '/home/yfshao/glove.6B.100d.txt' | ||||
# loader = ConlluDataLoader() | # loader = ConlluDataLoader() | ||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | |||||
dev_data_name = "dev_ctb5.txt" | |||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
loader = CTBDataLoader() | |||||
# datadir = '/home/yfshao/workdir/parser-data/' | |||||
# train_data_name = "train_ctb5.txt" | |||||
# dev_data_name = "dev_ctb5.txt" | |||||
# test_data_name = "test_ctb5.txt" | |||||
datadir = "/home/yfshao/workdir/ctb7.0/" | |||||
train_data_name = "train.conllx" | |||||
dev_data_name = "dev.conllx" | |||||
test_data_name = "test.conllx" | |||||
# emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
processed_datadir = './save' | processed_datadir = './save' | ||||
# Config Loader | # Config Loader | ||||
train_args = ConfigSection() | train_args = ConfigSection() | ||||
test_args = ConfigSection() | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
optim_args = ConfigSection() | optim_args = ConfigSection() | ||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | |||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "model": model_args, "optim": optim_args}) | |||||
print('trainre Args:', train_args.data) | print('trainre Args:', train_args.data) | ||||
print('test Args:', test_args.data) | |||||
print('optim Args:', optim_args.data) | |||||
print('model Args:', model_args.data) | |||||
print('optim_args', optim_args.data) | |||||
# Pickle Loader | # Pickle Loader | ||||
@@ -159,84 +104,36 @@ def load_data(dirpath): | |||||
return datas | return datas | ||||
def P2(data, field, length): | def P2(data, field, length): | ||||
ds = [ins for ins in data if ins[field].get_length() >= length] | |||||
ds = [ins for ins in data if len(ins[field]) >= length] | |||||
data.clear() | data.clear() | ||||
data.extend(ds) | data.extend(ds) | ||||
return ds | return ds | ||||
def P1(data, field): | |||||
def reeng(w): | |||||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||||
def renum(w): | |||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||||
for ins in data: | |||||
ori = ins[field].contents() | |||||
s = list(map(renum, map(reeng, ori))) | |||||
if s != ori: | |||||
# print(ori) | |||||
# print(s) | |||||
# print() | |||||
ins[field] = ins[field].new(s) | |||||
return data | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self, ignore_label): | |||||
super(ParserEvaluator, self).__init__() | |||||
self.ignore = ignore_label | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
seq_mask *= (head_labels != self.ignore).long() | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
try: | |||||
data_dict = load_data(processed_datadir) | |||||
word_v = data_dict['word_v'] | |||||
pos_v = data_dict['pos_v'] | |||||
tag_v = data_dict['tag_v'] | |||||
train_data = data_dict['train_data'] | |||||
dev_data = data_dict['dev_data'] | |||||
test_data = data_dict['test_data'] | |||||
print('use saved pickles') | |||||
except Exception as _: | |||||
print('load raw data and preprocess') | |||||
# use pretrain embedding | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | |||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary(need_default=True) | |||||
tag_v = Vocabulary(need_default=False) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | |||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||||
datasets = (train_data, dev_data, test_data) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
print(len(word_v)) | |||||
print(embed.size()) | |||||
def update_v(vocab, data, field): | |||||
data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None) | |||||
print('load raw data and preprocess') | |||||
# use pretrain embedding | |||||
word_v = Vocabulary() | |||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary() | |||||
tag_v = Vocabulary(unknown=None, padding=None) | |||||
train_data = load(os.path.join(datadir, train_data_name)) | |||||
dev_data = load(os.path.join(datadir, dev_data_name)) | |||||
test_data = load(os.path.join(datadir, test_data_name)) | |||||
print(train_data[0]) | |||||
num_p = Num2TagProcessor('words', 'words') | |||||
for ds in (train_data, dev_data, test_data): | |||||
num_p(ds) | |||||
update_v(word_v, train_data, 'words') | |||||
update_v(pos_v, train_data, 'pos') | |||||
update_v(tag_v, train_data, 'tags') | |||||
print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v))) | |||||
# embed, _ = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v) | |||||
# print(embed.size()) | |||||
# Model | # Model | ||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
@@ -245,50 +142,49 @@ model_args['num_label'] = len(tag_v) | |||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.reset_parameters() | model.reset_parameters() | ||||
datasets = (train_data, dev_data, test_data) | |||||
for ds in datasets: | |||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||||
ds.set_origin_len('word_seq') | |||||
word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | |||||
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | |||||
tag_idxp = IndexerProcessor(tag_v, 'tags', 'label_true') | |||||
seq_p = SeqLenProcessor('word_seq', 'seq_lens') | |||||
set_input_p = SetInputProcessor('word_seq', 'pos_seq', 'seq_lens', flag=True) | |||||
set_target_p = SetTargetProcessor('arc_true', 'label_true', 'seq_lens', flag=True) | |||||
label_toword_p = Index2WordProcessor(vocab=tag_v, field_name='label_pred', new_added_field_name='label_pred_seq') | |||||
for ds in (train_data, dev_data, test_data): | |||||
word_idxp(ds) | |||||
pos_idxp(ds) | |||||
tag_idxp(ds) | |||||
seq_p(ds) | |||||
set_input_p(ds) | |||||
set_target_p(ds) | |||||
if train_args['use_golden_train']: | if train_args['use_golden_train']: | ||||
train_data.set_target(gold_heads=False) | |||||
else: | |||||
train_data.set_target(gold_heads=None) | |||||
train_data.set_input('gold_heads', flag=True) | |||||
train_args.data.pop('use_golden_train') | train_args.data.pop('use_golden_train') | ||||
ignore_label = pos_v['P'] | |||||
ignore_label = pos_v['punct'] | |||||
print(test_data[0]) | print(test_data[0]) | ||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
print(len(test_data)) | |||||
print('train len {}'.format(len(train_data))) | |||||
print('dev len {}'.format(len(dev_data))) | |||||
print('test len {}'.format(len(test_data))) | |||||
def train(path): | def train(path): | ||||
# test saving pipeline | |||||
save_pipe(path) | |||||
# Trainer | # Trainer | ||||
trainer = Trainer(**train_args.data) | |||||
def _define_optim(obj): | |||||
lr = optim_args.data['lr'] | |||||
embed_params = set(obj._model.word_embedding.parameters()) | |||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||||
obj._optimizer = torch.optim.Adam([ | |||||
{'params': list(embed_params), 'lr':lr*0.1}, | |||||
{'params': list(decay_params), **optim_args.data}, | |||||
{'params': params} | |||||
], lr=lr, betas=(0.9, 0.9)) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | |||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||||
obj._scheduler.step() | |||||
obj._optimizer.step() | |||||
trainer.define_optimizer = lambda: _define_optim(trainer) | |||||
trainer.update = lambda: _update(trainer) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||||
**train_args.data, | |||||
optimizer=fastNLP.Adam(**optim_args.data), | |||||
save_path=path) | |||||
# model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
@@ -302,18 +198,23 @@ def train(path): | |||||
# pass | # pass | ||||
# Start training | # Start training | ||||
trainer.train(model, train_data, dev_data) | |||||
trainer.train() | |||||
print("Training finished!") | print("Training finished!") | ||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
# save pipeline | |||||
save_pipe(path) | |||||
print('pipe saved') | |||||
def save_pipe(path): | |||||
pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | |||||
pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | |||||
pipe.add_processor(label_toword_p) | |||||
torch.save(pipe, os.path.join(path, 'pipe.pkl')) | |||||
def test(path): | def test(path): | ||||
# Tester | # Tester | ||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||||
tester = Tester(**test_args.data) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
@@ -333,13 +234,18 @@ def test(path): | |||||
print("Testing Test data") | print("Testing Test data") | ||||
tester.test(model, test_data) | tester.test(model, test_data) | ||||
def build_pipe(parser_pipe_path): | |||||
parser_pipe = torch.load(parser_pipe_path) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer', 'save']) | |||||
parser.add_argument('--path', type=str, default='') | parser.add_argument('--path', type=str, default='') | ||||
# parser.add_argument('--dst', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train(args.path) | train(args.path) | ||||
@@ -347,6 +253,12 @@ if __name__ == "__main__": | |||||
test(args.path) | test(args.path) | ||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
pass | pass | ||||
# elif args.mode == 'save': | |||||
# print(f'save model from {args.path} to {args.dst}') | |||||
# save_model(args.path, args.dst) | |||||
# load_path = os.path.dirname(args.dst) | |||||
# print(f'save pipeline in {load_path}') | |||||
# build(load_path) | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |
@@ -10,6 +10,8 @@ data_file = """ | |||||
4 will _ AUX MD _ 6 aux _ _ | 4 will _ AUX MD _ 6 aux _ _ | ||||
5 be _ VERB VB _ 6 cop _ _ | 5 be _ VERB VB _ 6 cop _ _ | ||||
6 payable _ ADJ JJ _ 0 root _ _ | 6 payable _ ADJ JJ _ 0 root _ _ | ||||
7 mask _ ADJ JJ _ 6 punct _ _ | |||||
8 mask _ ADJ JJ _ 6 punct _ _ | |||||
9 cents _ NOUN NNS _ 4 nmod _ _ | 9 cents _ NOUN NNS _ 4 nmod _ _ | ||||
10 from _ ADP IN _ 12 case _ _ | 10 from _ ADP IN _ 12 case _ _ | ||||
11 seven _ NUM CD _ 12 nummod _ _ | 11 seven _ NUM CD _ 12 nummod _ _ | ||||
@@ -58,13 +60,13 @@ def init_data(): | |||||
data.append(line) | data.append(line) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: ['<st>']+list(x[name])+['<ed>'], new_field_name=name) | |||||
ds.apply(lambda x: ['<st>']+list(x[name]), new_field_name=name) | |||||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | ds.apply(lambda x: v[name].add_word_lst(x[name])) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ||||
ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true') | |||||
ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') | |||||
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ||||
ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ||||
ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ||||
@@ -75,8 +77,11 @@ class TestBiaffineParser(unittest.TestCase): | |||||
ds, v1, v2, v3 = init_data() | ds, v1, v2, v3 = init_data() | ||||
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | ||||
pos_vocab_size=len(v2), pos_emb_dim=30, | pos_vocab_size=len(v2), pos_emb_dim=30, | ||||
num_label=len(v3)) | |||||
num_label=len(v3), use_var_lstm=True) | |||||
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | ||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | ||||
n_epochs=10, use_cuda=False, use_tqdm=False) | n_epochs=10, use_cuda=False, use_tqdm=False) | ||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||
if __name__ == '__main__': | |||||
unittest.main() |