@@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import * | |||
from fastNLP.models.biaffine_parser import BiaffineParser | |||
from fastNLP.core.instance import Instance | |||
import torch | |||
@@ -13,42 +15,23 @@ class DependencyParser(API): | |||
super(DependencyParser, self).__init__() | |||
def predict(self, data):
    """Parse raw (words, pos_tags) pairs into dependency heads and labels.

    :param data: iterable of (word list, POS-tag list) pairs
    :return: (heads, head_tags) — predictions concatenated over all batches
    """
    self.load('xxx')
    # lazily restore the preprocessing pipeline from disk
    if self.pipeline is None:
        self.pipeline = torch.load('xxx')

    # wrap raw input into a DataSet so the pipeline can process it
    dataset = DataSet()
    for sentence, pos_tags in data:
        dataset.append(Instance(sentence=sentence, sent_pos=pos_tags))
    dataset = self.pipeline.process(dataset)

    predictor = Predictor()
    batches = predictor.predict(self.model, dataset)

    head_batches = [batch['heads'] for batch in batches]
    label_batches = [batch['labels'] for batch in batches]
    return torch.cat(head_batches, dim=0), torch.cat(label_batches, dim=0)
def build(self):
    """Construct the preprocessing pipeline and the parser model.

    Side effects: sets ``self.model`` and ``self.pipeline``.
    Returns None.
    """
    BOS = '<BOS>'
    NUM = '<NUM>'
    model_args = {}
    load_path = ''
    # vocabularies are serialized next to the trained model
    word_vocab = load(f'{load_path}/word_v.pkl')
    pos_vocab = load(f'{load_path}/pos_v.pkl')

    word_seq = 'word_seq'
    pos_seq = 'pos_seq'

    # build pipeline: normalize numbers, prepend BOS, index words/POS, record lengths
    pipe = Pipeline()
    pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq))
    pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None))
    pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None))
    pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq + '_idx'))
    pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq + '_idx'))
    pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len'))

    # load model parameters
    self.model = BiaffineParser(**model_args)
    self.pipeline = pipe
    # NOTE(review): the original ended with ``return dataset['heads'],
    # dataset['labels']``, but no ``dataset`` exists in this scope, which
    # always raised NameError; build() now simply returns None.
if __name__ == '__main__':
    # two toy Chinese sentences with their POS tags
    data = [
        (['我', '是', '谁'], ['NR', 'VV', 'NR']),
        (['自古', '英雄', '识', '英雄'], ['AD', 'NN', 'VV', 'NN']),
    ]
    parser = DependencyParser()
    pipeline_path = '/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe/pipeline.pkl'
    with open(pipeline_path, 'rb') as f:
        parser.pipeline = torch.load(f)
    output = parser.predict(data)
    print(output)
@@ -87,17 +87,30 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||
return dataset | |||
class MapFieldProcessor(Processor):
    """Apply an arbitrary function to one field of every instance.

    :param func: callable applied to the value of ``field_name``
    :param field_name: source field to read from
    :param new_added_field_name: destination field for the result
    """

    def __init__(self, func, field_name, new_added_field_name=None):
        super().__init__(field_name, new_added_field_name)
        self.func = func
class PreAppendProcessor(Processor):
    """Prepend a constant element (e.g. a BOS tag) to a list-valued field.

    :param data: the element to prepend to every instance's field
    :param field_name: source field (must hold a list per instance)
    :param new_added_field_name: destination field for the extended list
    """

    def __init__(self, data, field_name, new_added_field_name=None):
        super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
        self.data = data

    def process(self, dataset):
        # NOTE(review): the original body contained leftover lines that read
        # ``self.func`` (never set on this class -> AttributeError) and an
        # early ``return dataset`` that made the prepend logic below
        # unreachable; both are removed here.
        for ins in dataset:
            sent = ins[self.field_name]
            ins[self.new_added_field_name] = [self.data] + sent
        return dataset
class SliceProcessor(Processor):
    """Replace a list field with a slice of itself (``value[start:end:step]``).

    :param start: slice start, int or None
    :param end: slice end, int or None
    :param step: slice step, int or None
    :param field_name: source field to slice
    :param new_added_field_name: destination field for the sliced value
    """

    def __init__(self, start, end, step, field_name, new_added_field_name=None):
        super(SliceProcessor, self).__init__(field_name, new_added_field_name)
        # validate with a real exception: ``assert`` is stripped under -O,
        # which would silently accept bad arguments in optimized runs
        for o in (start, end, step):
            if not (isinstance(o, int) or o is None):
                raise TypeError(f'slice parameter must be int or None, got {type(o)}')
        self.slice = slice(start, end, step)

    def process(self, dataset):
        for ins in dataset:
            sent = ins[self.field_name]
            ins[self.new_added_field_name] = sent[self.slice]
        return dataset
class Num2TagProcessor(Processor): | |||
@@ -231,3 +244,16 @@ class Index2WordProcessor(Processor): | |||
new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]] | |||
ins[self.new_added_field_name] = new_sent | |||
return dataset | |||
class SetTensorProcessor(Processor):
    """Mark which dataset fields should be treated as tensors.

    Fields listed in ``field_dict`` keep their given flag; every other field
    of the dataset falls back to ``default``.
    """

    def __init__(self, field_dict, default=False):
        super(SetTensorProcessor, self).__init__(None, None)
        self.field_dict = field_dict
        self.default = default

    def process(self, dataset):
        # start every field at the default flag, then overlay explicit choices
        flags = dict.fromkeys(dataset.get_fields().keys(), self.default)
        flags.update(self.field_dict)
        dataset.set_need_tensor(**flags)
        return dataset
@@ -23,9 +23,9 @@ class DataSet(object): | |||
""" | |||
class DataSetIter(object): | |||
def __init__(self, dataset): | |||
def __init__(self, dataset, idx=-1): | |||
self.dataset = dataset | |||
self.idx = -1 | |||
self.idx = idx | |||
def __next__(self): | |||
self.idx += 1 | |||
@@ -88,7 +88,12 @@ class DataSet(object): | |||
return self.field_arrays | |||
def __getitem__(self, name):
    """Index the DataSet.

    :param name: int -> an iterator positioned at that instance index;
                 str -> the named field array.
    :raises KeyError: when ``name`` is neither int nor str
    """
    # NOTE(review): a stray ``return self.field_arrays[name]`` used to sit
    # directly after the ``def`` line, making the type dispatch below
    # unreachable (and breaking integer indexing); it has been removed.
    if isinstance(name, int):
        return self.DataSetIter(self, idx=name)
    elif isinstance(name, str):
        return self.field_arrays[name]
    else:
        raise KeyError(name)
def __len__(self): | |||
if len(self.field_arrays) == 0: | |||
@@ -33,7 +33,7 @@ class FieldArray(object): | |||
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0])) | |||
else: | |||
max_len = max([len(self.content[i]) for i in idxes]) | |||
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32) | |||
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64) | |||
for i, idx in enumerate(idxes): | |||
array[i][:len(self.content[idx])] = self.content[idx] | |||
@@ -286,6 +286,10 @@ class BiaffineParser(GraphParser): | |||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | |||
""" | |||
# prepare embeddings | |||
device = self.parameters().__next__().device | |||
word_seq = word_seq.long().to(device) | |||
pos_seq = pos_seq.long().to(device) | |||
word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1) | |||
batch_size, seq_len = word_seq.shape | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
@@ -300,9 +304,13 @@ class BiaffineParser(GraphParser): | |||
del word, pos | |||
# lstm, extract features | |||
x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True) | |||
sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True) | |||
x = x[sort_idx] | |||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||
feat, _ = self.lstm(x) # -> [N,L,C] | |||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
feat = feat[unsort_idx] | |||
# for arc biaffine | |||
# mlp, reduce dim | |||
@@ -386,5 +394,4 @@ class BiaffineParser(GraphParser): | |||
output['head_pred'] = res.pop('head_pred') | |||
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_pred'] = label_pred | |||
output['seq_len'] = word_seq_origin_len | |||
return output |
@@ -1,5 +1,6 @@ | |||
import numpy as np | |||
import torch | |||
import numpy as np | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder, encoder | |||
@@ -160,6 +161,7 @@ class AdvSeqLabel(SeqLabeling): | |||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | |||
x = self.Rnn(sent_packed) | |||
# print(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | |||
@@ -180,3 +182,42 @@ class AdvSeqLabel(SeqLabeling): | |||
def predict(self, **x):
    """Run a forward pass and expose only the label predictions."""
    forward_out = self.forward(**x)
    return {"predict": forward_out["predict"]}
def _demo_train():
    """Smoke-test AdvSeqLabel: train on random toy data and print loss/accuracy.

    Uses 20 random sequences of length 15 in mini-batches of 5, cycling over
    the data for 1000 optimizer steps.
    """
    args = {
        'vocab_size': 20,
        'word_emb_dim': 100,
        'rnn_hidden_units': 100,
        'num_classes': 10,
    }
    model = AdvSeqLabel(args)

    # random (word_seq, word_seq_len, truth) samples
    data = []
    for i in range(20):
        word_seq = torch.randint(20, (15,)).long()
        word_seq_len = torch.LongTensor([15])
        truth = torch.randint(10, (15,)).long()
        data.append((word_seq, word_seq_len, truth))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    print(model)

    curidx = 0
    for i in range(1000):
        endidx = min(len(data), curidx + 5)
        b_word, b_len, b_truth = [], [], []
        for word_seq, word_seq_len, truth in data[curidx: endidx]:
            b_word.append(word_seq)
            b_len.append(word_seq_len)
            b_truth.append(truth)
        word_seq = torch.stack(b_word, dim=0)
        word_seq_len = torch.cat(b_len, dim=0)
        truth = torch.stack(b_truth, dim=0)

        res = model(word_seq, word_seq_len, truth)
        loss = res['loss']
        pred = res['predict']
        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wrap around to the start of the toy dataset
        curidx = endidx
        if curidx == len(data):
            curidx = 0


# NOTE(review): this demo previously ran at module import time; guarding it
# keeps importing this library module side-effect free.
if __name__ == '__main__':
    _demo_train()
@@ -21,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens): | |||
class ConditionalRandomField(nn.Module): | |||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): | |||
""" | |||
:param tag_size: int, num of tags | |||
:param include_start_end_trans: bool, whether to include start/end tag | |||
@@ -87,7 +87,7 @@ class ConditionalRandomField(nn.Module): | |||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask | |||
# score [L-1, B] | |||
score = trans_score + emit_score[:seq_len-1, :] | |||
score = score.sum(0) + emit_score[-1] | |||
score = score.sum(0) + emit_score[-1] * mask[-1] | |||
if self.include_start_end_trans: | |||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
last_idx = mask.long().sum(0) - 1 | |||