@@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
import torch | |||||
class DependencyParser(API): | class DependencyParser(API): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -18,19 +20,35 @@ class DependencyParser(API): | |||||
pred = Predictor() | pred = Predictor() | ||||
res = pred.predict(self.model, dataset) | res = pred.predict(self.model, dataset) | ||||
heads, head_tags = [], [] | |||||
for batch in res: | |||||
heads.append(batch['heads']) | |||||
head_tags.append(batch['labels']) | |||||
heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0) | |||||
return heads, head_tags | |||||
return res | |||||
def build(self): | def build(self): | ||||
pipe = Pipeline() | |||||
# build pipeline | |||||
BOS = '<BOS>' | |||||
NUM = '<NUM>' | |||||
model_args = {} | |||||
load_path = '' | |||||
word_vocab = load(f'{load_path}/word_v.pkl') | |||||
pos_vocab = load(f'{load_path}/pos_v.pkl') | |||||
word_seq = 'word_seq' | word_seq = 'word_seq' | ||||
pos_seq = 'pos_seq' | pos_seq = 'pos_seq' | ||||
pipe.add_processor(Num2TagProcessor('<NUM>', 'raw_sentence', word_seq)) | |||||
pipe = Pipeline() | |||||
# build pipeline | |||||
pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq)) | |||||
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None)) | |||||
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None)) | |||||
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) | pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) | ||||
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) | pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) | ||||
pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len')) | |||||
# load model parameters | # load model parameters | ||||
self.model = BiaffineParser() | |||||
self.model = BiaffineParser(**model_args) | |||||
self.pipeline = pipe | self.pipeline = pipe | ||||
@@ -145,7 +145,6 @@ class IndexerProcessor(Processor): | |||||
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
def __init__(self, field_name): | def __init__(self, field_name): | ||||
super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
self.vocab = Vocabulary() | self.vocab = Vocabulary() | ||||
@@ -172,3 +171,15 @@ class SeqLenProcessor(Processor): | |||||
ins[self.new_added_field_name] = length | ins[self.new_added_field_name] = length | ||||
dataset.set_need_tensor(**{self.new_added_field_name: True}) | dataset.set_need_tensor(**{self.new_added_field_name: True}) | ||||
return dataset | return dataset | ||||
class Index2WordProcessor(Processor):
    """Processor that turns a field of indices back into words.

    For every instance of the dataset, each entry of the ``field_name`` field
    is passed through ``vocab.to_word`` and the resulting list is written to
    the ``new_added_field_name`` field.
    """

    def __init__(self, vocab, field_name, new_added_field_name):
        """
        :param vocab: object exposing ``to_word(index)``.
        :param field_name: name of the field holding the indices to convert.
        :param new_added_field_name: name of the field that receives the words.
        """
        # NOTE(review): the base Processor is presumed to store both field
        # names on ``self``; ``process`` below reads them from there.
        super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab

    def process(self, dataset):
        """Add the converted word list to every instance.

        :param dataset: iterable of instances supporting item get/set by field name.
        :return: the same ``dataset`` object, modified in place.
        """
        for instance in dataset:
            indices = instance[self.field_name]
            words = [self.vocab.to_word(index) for index in indices]
            instance[self.new_added_field_name] = words
        return dataset
@@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module): | |||||
def fit(self, train_data, dev_data=None, **train_args):
    """Train this model with a ``Trainer`` built from ``train_args``.

    :param train_data: training data handed to ``Trainer.train``.
    :param dev_data: optional development data handed to ``Trainer.train``.
    :param train_args: keyword arguments forwarded to the ``Trainer`` constructor.
    """
    # The trainer is only needed for this single call, so it is not kept around.
    Trainer(**train_args).train(self, train_data, dev_data)
def predict(self):
    """Placeholder prediction hook: does nothing and returns ``None``."""
    return None
@@ -9,6 +9,7 @@ from torch.nn import functional as F | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | from fastNLP.modules.encoder.variational_rnn import VarLSTM | ||||
from fastNLP.modules.dropout import TimestepDropout | from fastNLP.modules.dropout import TimestepDropout | ||||
from fastNLP.models.base_model import BaseModel | |||||
def mst(scores): | def mst(scores): | ||||
""" | """ | ||||
@@ -113,7 +114,7 @@ def _find_cycle(vertices, edges): | |||||
return [SCC for SCC in _SCCs if len(SCC) > 1] | return [SCC for SCC in _SCCs if len(SCC) > 1] | ||||
class GraphParser(nn.Module): | |||||
class GraphParser(BaseModel): | |||||
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -370,4 +371,20 @@ class BiaffineParser(GraphParser): | |||||
label_nll = -(label_loss*float_mask).mean() | label_nll = -(label_loss*float_mask).mean() | ||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def predict(self, word_seq, pos_seq, word_seq_origin_len):
    """Run the model and package its predictions into a fresh dict.

    The model is called with the three arguments; from its result only the
    predicted heads and the arg-max label index (taken over dimension 2 of
    the label scores) are kept, together with the given sequence lengths.

    :param word_seq: forwarded unchanged as the first model input.
    :param pos_seq: forwarded unchanged as the second model input.
    :param word_seq_origin_len: forwarded as the third model input and
        returned unchanged under ``seq_len``.
    :return: dict with
        head_pred: [B, L]
        label_pred: [B, L]
        seq_len: [B,]
    """
    forward_out = self(word_seq, pos_seq, word_seq_origin_len)
    head_pred = forward_out.pop('head_pred')
    label_scores = forward_out.pop('label_pred')
    # max over dim 2 yields (values, indices); only the indices are wanted.
    label_pred = label_scores.max(2)[1]
    return {
        'head_pred': head_pred,
        'label_pred': label_pred,
        'seq_len': word_seq_origin_len,
    }
@@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase): | |||||
for text, label in zip(texts, labels): | for text, label in zip(texts, labels): | ||||
x = TextField(text, is_target=False) | x = TextField(text, is_target=False) | ||||
y = LabelField(label, is_target=True) | y = LabelField(label, is_target=True) | ||||
ins = Instance(text=x, label=y) | |||||
ins = Instance(raw_text=x, label=y) | |||||
data.append(ins) | data.append(ins) | ||||
# use vocabulary to index data | # use vocabulary to index data | ||||
data.index_field("text", vocab) | |||||
# data.index_field("text", vocab) | |||||
for ins in data: | |||||
ins['text'] = [vocab.to_index(w) for w in ins['raw_text']] | |||||
# define naive sampler for batch class | # define naive sampler for batch class | ||||
class SeqSampler: | class SeqSampler: | ||||