| @@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline | |||||
| from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
| from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
| import torch | |||||
| class DependencyParser(API): | class DependencyParser(API): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -18,19 +20,35 @@ class DependencyParser(API): | |||||
| pred = Predictor() | pred = Predictor() | ||||
| res = pred.predict(self.model, dataset) | res = pred.predict(self.model, dataset) | ||||
| heads, head_tags = [], [] | |||||
| for batch in res: | |||||
| heads.append(batch['heads']) | |||||
| head_tags.append(batch['labels']) | |||||
| heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0) | |||||
| return heads, head_tags | |||||
| return res | |||||
| def build(self): | def build(self): | ||||
| pipe = Pipeline() | |||||
| # build pipeline | |||||
| BOS = '<BOS>' | |||||
| NUM = '<NUM>' | |||||
| model_args = {} | |||||
| load_path = '' | |||||
| word_vocab = load(f'{load_path}/word_v.pkl') | |||||
| pos_vocab = load(f'{load_path}/pos_v.pkl') | |||||
| word_seq = 'word_seq' | word_seq = 'word_seq' | ||||
| pos_seq = 'pos_seq' | pos_seq = 'pos_seq' | ||||
| pipe.add_processor(Num2TagProcessor('<NUM>', 'raw_sentence', word_seq)) | |||||
| pipe = Pipeline() | |||||
| # build pipeline | |||||
| pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq)) | |||||
| pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None)) | |||||
| pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None)) | |||||
| pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) | pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) | ||||
| pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) | pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) | ||||
| pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len')) | |||||
| # load model parameters | # load model parameters | ||||
| self.model = BiaffineParser() | |||||
| self.model = BiaffineParser(**model_args) | |||||
| self.pipeline = pipe | self.pipeline = pipe | ||||
| @@ -145,7 +145,6 @@ class IndexerProcessor(Processor): | |||||
| class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
| def __init__(self, field_name): | def __init__(self, field_name): | ||||
| super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
| self.vocab = Vocabulary() | self.vocab = Vocabulary() | ||||
| @@ -172,3 +171,15 @@ class SeqLenProcessor(Processor): | |||||
| ins[self.new_added_field_name] = length | ins[self.new_added_field_name] = length | ||||
| dataset.set_need_tensor(**{self.new_added_field_name: True}) | dataset.set_need_tensor(**{self.new_added_field_name: True}) | ||||
| return dataset | return dataset | ||||
| class Index2WordProcessor(Processor): | |||||
| def __init__(self, vocab, field_name, new_added_field_name): | |||||
| super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||||
| self.vocab = vocab | |||||
| def process(self, dataset): | |||||
| for ins in dataset: | |||||
| new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]] | |||||
| ins[self.new_added_field_name] = new_sent | |||||
| return dataset | |||||
| @@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module): | |||||
| def fit(self, train_data, dev_data=None, **train_args): | def fit(self, train_data, dev_data=None, **train_args): | ||||
| trainer = Trainer(**train_args) | trainer = Trainer(**train_args) | ||||
| trainer.train(self, train_data, dev_data) | trainer.train(self, train_data, dev_data) | ||||
| def predict(self): | |||||
| pass | |||||
| @@ -9,6 +9,7 @@ from torch.nn import functional as F | |||||
| from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
| from fastNLP.modules.encoder.variational_rnn import VarLSTM | from fastNLP.modules.encoder.variational_rnn import VarLSTM | ||||
| from fastNLP.modules.dropout import TimestepDropout | from fastNLP.modules.dropout import TimestepDropout | ||||
| from fastNLP.models.base_model import BaseModel | |||||
| def mst(scores): | def mst(scores): | ||||
| """ | """ | ||||
| @@ -113,7 +114,7 @@ def _find_cycle(vertices, edges): | |||||
| return [SCC for SCC in _SCCs if len(SCC) > 1] | return [SCC for SCC in _SCCs if len(SCC) > 1] | ||||
| class GraphParser(nn.Module): | |||||
| class GraphParser(BaseModel): | |||||
| """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | ||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -370,4 +371,20 @@ class BiaffineParser(GraphParser): | |||||
| label_nll = -(label_loss*float_mask).mean() | label_nll = -(label_loss*float_mask).mean() | ||||
| return arc_nll + label_nll | return arc_nll + label_nll | ||||
| def predict(self, word_seq, pos_seq, word_seq_origin_len): | |||||
| """ | |||||
| :param word_seq: | |||||
| :param pos_seq: | |||||
| :param word_seq_origin_len: | |||||
| :return: head_pred: [B, L] | |||||
| label_pred: [B, L] | |||||
| seq_len: [B,] | |||||
| """ | |||||
| res = self(word_seq, pos_seq, word_seq_origin_len) | |||||
| output = {} | |||||
| output['head_pred'] = res.pop('head_pred') | |||||
| _, label_pred = res.pop('label_pred').max(2) | |||||
| output['label_pred'] = label_pred | |||||
| output['seq_len'] = word_seq_origin_len | |||||
| return output | |||||
| @@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase): | |||||
| for text, label in zip(texts, labels): | for text, label in zip(texts, labels): | ||||
| x = TextField(text, is_target=False) | x = TextField(text, is_target=False) | ||||
| y = LabelField(label, is_target=True) | y = LabelField(label, is_target=True) | ||||
| ins = Instance(text=x, label=y) | |||||
| ins = Instance(raw_text=x, label=y) | |||||
| data.append(ins) | data.append(ins) | ||||
| # use vocabulary to index data | # use vocabulary to index data | ||||
| data.index_field("text", vocab) | |||||
| # data.index_field("text", vocab) | |||||
| for ins in data: | |||||
| ins['text'] = [vocab.to_index(w) for w in ins['raw_text']] | |||||
| # define naive sampler for batch class | # define naive sampler for batch class | ||||
| class SeqSampler: | class SeqSampler: | ||||