diff --git a/fastNLP/api/parser.py b/fastNLP/api/parser.py index 67bcca4f..79c070d6 100644 --- a/fastNLP/api/parser.py +++ b/fastNLP/api/parser.py @@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline from fastNLP.api.processor import * from fastNLP.models.biaffine_parser import BiaffineParser +import torch + class DependencyParser(API): def __init__(self): @@ -18,19 +20,35 @@ class DependencyParser(API): pred = Predictor() res = pred.predict(self.model, dataset) + heads, head_tags = [], [] + for batch in res: + heads.append(batch['heads']) + head_tags.append(batch['labels']) + heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0) + return heads, head_tags - return res def build(self): - pipe = Pipeline() - - # build pipeline + BOS = '' + NUM = '' + model_args = {} + load_path = '' + word_vocab = load(f'{load_path}/word_v.pkl') + pos_vocab = load(f'{load_path}/pos_v.pkl') word_seq = 'word_seq' pos_seq = 'pos_seq' - pipe.add_processor(Num2TagProcessor('', 'raw_sentence', word_seq)) + + pipe = Pipeline() + # build pipeline + pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq)) + pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None)) + pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None)) pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx')) pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx')) + pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len')) + # load model parameters - self.model = BiaffineParser() + self.model = BiaffineParser(**model_args) self.pipeline = pipe + diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 109aa7b6..97e9b1b2 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -145,7 +145,6 @@ class IndexerProcessor(Processor): class VocabProcessor(Processor): def __init__(self, field_name): - super(VocabProcessor, self).__init__(field_name, None) self.vocab = Vocabulary() @@ -172,3 +171,15 @@ class SeqLenProcessor(Processor): ins[self.new_added_field_name] = length dataset.set_need_tensor(**{self.new_added_field_name: True}) return dataset + +class Index2WordProcessor(Processor): + def __init__(self, vocab, field_name, new_added_field_name): + super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) + self.vocab = vocab + + def process(self, dataset): + for ins in dataset: + new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]] + ins[self.new_added_field_name] = new_sent + return dataset + diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index c73bdfd9..59605f4f 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module): def fit(self, train_data, dev_data=None, **train_args): trainer = Trainer(**train_args) trainer.train(self, train_data, dev_data) + + def predict(self): + pass diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index 7e0a9cec..37070e1b 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -9,6 +9,7 @@ from torch.nn import functional as F from fastNLP.modules.utils import initial_parameter from fastNLP.modules.encoder.variational_rnn import VarLSTM from fastNLP.modules.dropout import TimestepDropout +from fastNLP.models.base_model import BaseModel def mst(scores): """ @@ -113,7 +114,7 @@ def _find_cycle(vertices, edges): return [SCC for SCC in _SCCs if len(SCC) > 1] -class GraphParser(nn.Module): +class GraphParser(BaseModel): """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding """ def __init__(self): @@ -370,4 +371,20 @@ class BiaffineParser(GraphParser): label_nll = -(label_loss*float_mask).mean() return arc_nll + label_nll + def predict(self, word_seq, pos_seq, word_seq_origin_len): + """ + :param word_seq: + :param pos_seq: + :param word_seq_origin_len: + :return: head_pred: [B, L] + label_pred: [B, L] + seq_len: [B,] + """ + res = self(word_seq, pos_seq, word_seq_origin_len) + output = {} + output['head_pred'] = res.pop('head_pred') + _, label_pred = res.pop('label_pred').max(2) + output['label_pred'] = label_pred + output['seq_len'] = word_seq_origin_len + return output diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 826167ac..6418cd99 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase): for text, label in zip(texts, labels): x = TextField(text, is_target=False) y = LabelField(label, is_target=True) - ins = Instance(text=x, label=y) + ins = Instance(raw_text=x, label=y) data.append(ins) # use vocabulary to index data - data.index_field("text", vocab) + # data.index_field("text", vocab) + for ins in data: + ins['text'] = [vocab.to_index(w) for w in ins['raw_text']] # define naive sampler for batch class class SeqSampler: