Browse Source

add index to word processor

tags/v0.2.0
yunfan 5 years ago
parent
commit
82f4351540
5 changed files with 61 additions and 10 deletions
  1. +24
    -6
      fastNLP/api/parser.py
  2. +12
    -1
      fastNLP/api/processor.py
  3. +3
    -0
      fastNLP/models/base_model.py
  4. +18
    -1
      fastNLP/models/biaffine_parser.py
  5. +4
    -2
      test/core/test_batch.py

+ 24
- 6
fastNLP/api/parser.py View File

@@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser

import torch


class DependencyParser(API):
def __init__(self):
@@ -18,19 +20,35 @@ class DependencyParser(API):

pred = Predictor()
res = pred.predict(self.model, dataset)
heads, head_tags = [], []
for batch in res:
heads.append(batch['heads'])
head_tags.append(batch['labels'])
heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0)
return heads, head_tags

return res

def build(self):
pipe = Pipeline()

# build pipeline
BOS = '<BOS>'
NUM = '<NUM>'
model_args = {}
load_path = ''
word_vocab = load(f'{load_path}/word_v.pkl')
pos_vocab = load(f'{load_path}/pos_v.pkl')
word_seq = 'word_seq'
pos_seq = 'pos_seq'
pipe.add_processor(Num2TagProcessor('<NUM>', 'raw_sentence', word_seq))

pipe = Pipeline()
# build pipeline
pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq))
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None))
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None))
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx'))
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx'))
pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len'))


# load model parameters
self.model = BiaffineParser()
self.model = BiaffineParser(**model_args)
self.pipeline = pipe


+ 12
- 1
fastNLP/api/processor.py View File

@@ -145,7 +145,6 @@ class IndexerProcessor(Processor):

class VocabProcessor(Processor):
def __init__(self, field_name):

super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()

@@ -172,3 +171,15 @@ class SeqLenProcessor(Processor):
ins[self.new_added_field_name] = length
dataset.set_need_tensor(**{self.new_added_field_name: True})
return dataset

class Index2WordProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab

def process(self, dataset):
for ins in dataset:
new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]]
ins[self.new_added_field_name] = new_sent
return dataset


+ 3
- 0
fastNLP/models/base_model.py View File

@@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module):
def fit(self, train_data, dev_data=None, **train_args):
trainer = Trainer(**train_args)
trainer.train(self, train_data, dev_data)

def predict(self):
pass

+ 18
- 1
fastNLP/models/biaffine_parser.py View File

@@ -9,6 +9,7 @@ from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.models.base_model import BaseModel

def mst(scores):
"""
@@ -113,7 +114,7 @@ def _find_cycle(vertices, edges):
return [SCC for SCC in _SCCs if len(SCC) > 1]


class GraphParser(nn.Module):
class GraphParser(BaseModel):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
def __init__(self):
@@ -370,4 +371,20 @@ class BiaffineParser(GraphParser):
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll

def predict(self, word_seq, pos_seq, word_seq_origin_len):
"""

:param word_seq:
:param pos_seq:
:param word_seq_origin_len:
:return: head_pred: [B, L]
label_pred: [B, L]
seq_len: [B,]
"""
res = self(word_seq, pos_seq, word_seq_origin_len)
output = {}
output['head_pred'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
output['label_pred'] = label_pred
output['seq_len'] = word_seq_origin_len
return output

+ 4
- 2
test/core/test_batch.py View File

@@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase):
for text, label in zip(texts, labels):
x = TextField(text, is_target=False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
ins = Instance(raw_text=x, label=y)
data.append(ins)

# use vocabulary to index data
data.index_field("text", vocab)
# data.index_field("text", vocab)
for ins in data:
ins['text'] = [vocab.to_index(w) for w in ins['raw_text']]

# define naive sampler for batch class
class SeqSampler:


Loading…
Cancel
Save