@@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import *
 from fastNLP.models.biaffine_parser import BiaffineParser
+from fastNLP.core.instance import Instance
 import torch
@@ -13,42 +15,23 @@ class DependencyParser(API):
         super(DependencyParser, self).__init__()
     def predict(self, data):
-        self.load('xxx')
+        if self.pipeline is None:
+            self.pipeline = torch.load('xxx')
         dataset = DataSet()
+        for sent, pos_seq in data:
+            dataset.append(Instance(sentence=sent, sent_pos=pos_seq))
         dataset = self.pipeline.process(dataset)
-        pred = Predictor()
-        res = pred.predict(self.model, dataset)
-        heads, head_tags = [], []
-        for batch in res:
-            heads.append(batch['heads'])
-            head_tags.append(batch['labels'])
-        heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0)
-        return heads, head_tags
-    def build(self):
-        BOS = '<BOS>'
-        NUM = '<NUM>'
-        model_args = {}
-        load_path = ''
-        word_vocab = load(f'{load_path}/word_v.pkl')
-        pos_vocab = load(f'{load_path}/pos_v.pkl')
-        word_seq = 'word_seq'
-        pos_seq = 'pos_seq'
-        pipe = Pipeline()
-        # build pipeline
-        pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq))
-        pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None))
-        pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None))
-        pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx'))
-        pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx'))
-        pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len'))
-        # load model parameters
-        self.model = BiaffineParser(**model_args)
-        self.pipeline = pipe
+        return dataset['heads'], dataset['labels']
+if __name__ == '__main__':
+    data = [
+        (['我', '是', '谁'], ['NR', 'VV', 'NR']),
+        (['自古', '英雄', '识', '英雄'], ['AD', 'NN', 'VV', 'NN']),
+    ]
+    parser = DependencyParser()
+    with open('/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe/pipeline.pkl', 'rb') as f:
+        parser.pipeline = torch.load(f)
+    output = parser.predict(data)
+    print(output)
@@ -87,17 +87,30 @@ class FullSpaceToHalfSpaceProcessor(Processor):
         return dataset
-class MapFieldProcessor(Processor):
-    def __init__(self, func, field_name, new_added_field_name=None):
-        super(MapFieldProcessor, self).__init__(field_name, new_added_field_name)
-        self.func = func
+class PreAppendProcessor(Processor):
+    def __init__(self, data, field_name, new_added_field_name=None):
+        super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
+        self.data = data
     def process(self, dataset):
         for ins in dataset:
-            s = ins[self.field_name]
-            new_s = self.func(s)
-            ins[self.new_added_field_name] = new_s
-        return dataset
+            sent = ins[self.field_name]
+            ins[self.new_added_field_name] = [self.data] + sent
+        return dataset
+class SliceProcessor(Processor):
+    def __init__(self, start, end, step, field_name, new_added_field_name=None):
+        super(SliceProcessor, self).__init__(field_name, new_added_field_name)
+        for o in (start, end, step):
+            assert isinstance(o, int) or o is None
+        self.slice = slice(start, end, step)
+    def process(self, dataset):
+        for ins in dataset:
+            sent = ins[self.field_name]
+            ins[self.new_added_field_name] = sent[self.slice]
+        return dataset
 class Num2TagProcessor(Processor):
@@ -231,3 +244,16 @@ class Index2WordProcessor(Processor):
             new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]]
             ins[self.new_added_field_name] = new_sent
         return dataset
+class SetTensorProcessor(Processor):
+    def __init__(self, field_dict, default=False):
+        super(SetTensorProcessor, self).__init__(None, None)
+        self.field_dict = field_dict
+        self.default = default
+    def process(self, dataset):
+        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict.update(self.field_dict)
+        dataset.set_need_tensor(**set_dict)
+        return dataset
@@ -23,9 +23,9 @@ class DataSet(object):
     """
     class DataSetIter(object):
-        def __init__(self, dataset):
+        def __init__(self, dataset, idx=-1):
             self.dataset = dataset
-            self.idx = -1
+            self.idx = idx
         def __next__(self):
             self.idx += 1
@@ -88,7 +88,12 @@ class DataSet(object):
         return self.field_arrays
     def __getitem__(self, name):
-        return self.field_arrays[name]
+        if isinstance(name, int):
+            return self.DataSetIter(self, idx=name)
+        elif isinstance(name, str):
+            return self.field_arrays[name]
+        else:
+            raise KeyError
     def __len__(self):
         if len(self.field_arrays) == 0:
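A quick illustration of the new key dispatch (the field names and the DataSet import path are illustrative, and what the returned DataSetIter can do beyond holding its starting index depends on code outside this hunk):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet()
    ds.append(Instance(word_seq=['I', 'like', 'it'], pos_seq=['PN', 'VV', 'PN']))

    ds['word_seq']   # str key -> the FieldArray holding every instance's 'word_seq'
    ds[0]            # int key -> a DataSetIter created with idx=0
    # any other key type, e.g. ds[0.5], raises KeyError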
@@ -33,7 +33,7 @@ class FieldArray(object):
             array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
         else:
             max_len = max([len(self.content[i]) for i in idxes])
-            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
+            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64)
             for i, idx in enumerate(idxes):
                 array[i][:len(self.content[idx])] = self.content[idx]
@@ -286,6 +286,10 @@ class BiaffineParser(GraphParser):
                 head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
         """
         # prepare embeddings
+        device = self.parameters().__next__().device
+        word_seq = word_seq.long().to(device)
+        pos_seq = pos_seq.long().to(device)
+        word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1)
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
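The inputs are now cast to long and moved onto whatever device the model's parameters live on. The same pattern in plain PyTorch, as a standalone sketch outside fastNLP:

    import torch
    import torch.nn as nn

    model = nn.Embedding(100, 8)              # stands in for any nn.Module
    device = next(model.parameters()).device  # same as model.parameters().__next__().device
    word_seq = torch.randint(0, 100, (2, 5)).long().to(device)  # inputs follow the model, CPU or GPU
    emb = model(word_seq)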
@@ -300,9 +304,13 @@ class BiaffineParser(GraphParser):
         del word, pos
         # lstm, extract features
-        x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
+        sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True)
+        x = x[sort_idx]
+        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
         feat, _ = self.lstm(x) # -> [N,L,C]
         feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
+        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
+        feat = feat[unsort_idx]
         # for arc biaffine
         # mlp, reduce dim
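pack_padded_sequence expects the batch ordered by decreasing length (unless a newer PyTorch's enforce_sorted=False is available), so the hunk sorts before packing and restores the original order after unpacking. The pattern in isolation, as a sketch with made-up shapes:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(8, 16, batch_first=True)
    x = torch.randn(4, 10, 8)                  # [batch, max_len, features]
    lengths = torch.tensor([3, 10, 7, 5])

    sort_lens, sort_idx = torch.sort(lengths, dim=0, descending=True)
    packed = nn.utils.rnn.pack_padded_sequence(x[sort_idx], sort_lens, batch_first=True)
    out, _ = lstm(packed)
    out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

    _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
    out = out[unsort_idx]                      # back to the original batch order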
@@ -386,5 +394,4 @@ class BiaffineParser(GraphParser):
         output['head_pred'] = res.pop('head_pred')
         _, label_pred = res.pop('label_pred').max(2)
         output['label_pred'] = label_pred
-        output['seq_len'] = word_seq_origin_len
         return output
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import numpy as np
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder
@@ -160,6 +161,7 @@ class AdvSeqLabel(SeqLabeling):
         sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)
         x = self.Rnn(sent_packed)
+        # print(x)
         # [batch_size, max_len, hidden_size * direction]
         sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
@@ -180,3 +182,42 @@ class AdvSeqLabel(SeqLabeling):
     def predict(self, **x):
         out = self.forward(**x)
         return {"predict": out["predict"]}
+args = {
+    'vocab_size': 20,
+    'word_emb_dim': 100,
+    'rnn_hidden_units': 100,
+    'num_classes': 10,
+}
+model = AdvSeqLabel(args)
+data = []
+for i in range(20):
+    word_seq = torch.randint(20, (15,)).long()
+    word_seq_len = torch.LongTensor([15])
+    truth = torch.randint(10, (15,)).long()
+    data.append((word_seq, word_seq_len, truth))
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+print(model)
+curidx = 0
+for i in range(1000):
+    endidx = min(len(data), curidx + 5)
+    b_word, b_len, b_truth = [], [], []
+    for word_seq, word_seq_len, truth in data[curidx: endidx]:
+        b_word.append(word_seq)
+        b_len.append(word_seq_len)
+        b_truth.append(truth)
+    word_seq = torch.stack(b_word, dim=0)
+    word_seq_len = torch.cat(b_len, dim=0)
+    truth = torch.stack(b_truth, dim=0)
+    res = model(word_seq, word_seq_len, truth)
+    loss = res['loss']
+    pred = res['predict']
+    print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    curidx = endidx
+    if curidx == len(data):
+        curidx = 0
@@ -21,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens):
 class ConditionalRandomField(nn.Module):
-    def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
+    def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None):
         """
         :param tag_size: int, num of tags
         :param include_start_end_trans: bool, whether to include start/end tag
@@ -87,7 +87,7 @@ class ConditionalRandomField(nn.Module):
         emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
         # score [L-1, B]
         score = trans_score + emit_score[:seq_len-1, :]
-        score = score.sum(0) + emit_score[-1]
+        score = score.sum(0) + emit_score[-1] * mask[-1]
         if self.include_start_end_trans:
             st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
             last_idx = mask.long().sum(0) - 1
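For context, this method accumulates the score of the gold tag path. emit_score is already zeroed at padded positions when it is built (it is multiplied by mask above), and the change applies the same guard to the final-position term, so for a sequence of true length l in a batch padded to length L the result reduces to the following, assuming trans_score is masked the same way where it is built outside this hunk:

    score = sum over t = 0 .. l-2 of (trans_score[t] + emit_score[t]) + emit_score[l-1]

with the start and end transition scores added on top when include_start_end_trans is enabled.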