
update matching pipe.

tags/v0.4.10
xuyige 6 years ago
parent
commit 09e24b3bd7
1 changed file with 33 additions and 36 deletions:
  fastNLP/io/pipe/matching.py  +33  -36

fastNLP/io/pipe/matching.py  +33  -36

@@ -1,4 +1,3 @@
-import math
 
 from .pipe import Pipe
 from .utils import get_tokenizer
@@ -19,19 +18,17 @@ class MatchingBertPipe(Pipe):
     "...", "...", "[...]", ., .
 
     The words column is raw_words1 (the premise) and raw_words2 (the hypothesis) joined with "[SEP]" and converted to indices.
-    The words column is set as input, and the target column is set as target.
+    The words column is set as input, and the target column is set as both target and input (it is set as input to make it easy to
+    compute the loss inside the forward function; not computing the loss in forward is also fine, since fastNLP passes arguments by the parameter names of the forward function).
 
     :param bool lower: whether to lowercase the words.
     :param str tokenizer: which tokenizer to use to split sentences into words. Supports spacy and raw; raw splits on whitespace.
-    :param int max_concat_sent_length: if the concatenated sentence is longer than this value, it is truncated to this length;
-        the premise and the hypothesis are truncated proportionally.
     """
-    def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480):
+    def __init__(self, lower=False, tokenizer: str='raw'):
         super().__init__()
 
         self.lower = bool(lower)
         self.tokenizer = get_tokenizer(tokenizer=tokenizer)
-        self.max_concat_sent_length = int(max_concat_sent_length)
 
     def _tokenize(self, data_bundle, field_names, new_field_names):
         """
@@ -43,11 +40,15 @@ class MatchingBertPipe(Pipe):
         """
         for name, dataset in data_bundle.datasets.items():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
+                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                     new_field_name=new_field_name)
         return data_bundle
 
     def process(self, data_bundle):
+        for dataset in data_bundle.datasets.values():
+            if dataset.has_field(Const.TARGET):
+                dataset.drop(lambda x: x[Const.TARGET] == '-')
+
         for name, dataset in data_bundle.datasets.items():
             dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
             dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))
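
The lines added to MatchingBertPipe.process above drop every instance whose target is '-'; in SNLI/MultiNLI that value marks examples on which the annotators did not reach a consensus gold label, so they cannot be used for training or evaluation. A self-contained sketch of the same DataSet.drop call on toy data (the field values are made up for illustration):

from fastNLP import DataSet

# Toy stand-in for an NLI split; '-' marks an example without a gold label.
ds = DataSet({'target': ['entailment', '-', 'neutral', 'contradiction', '-']})

# Same call as in the diff: drop() filters the dataset in place.
ds.drop(lambda ins: ins['target'] == '-')
print(len(ds))  # 3
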
@@ -57,47 +58,38 @@ class MatchingBertPipe(Pipe):
             dataset[Const.INPUTS(0)].lower()
             dataset[Const.INPUTS(1)].lower()
 
-        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
+        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                      [Const.INPUTS(0), Const.INPUTS(1)])
 
         # concatenate the two words fields
         def concat(ins):
             words0 = ins[Const.INPUTS(0)]
             words1 = ins[Const.INPUTS(1)]
-            len0 = len(words0)
-            len1 = len(words1)
-            if len0 + len1 > self.max_concat_sent_length:
-                ratio = self.max_concat_sent_length / (len0 + len1)
-                len0 = math.floor(ratio * len0)
-                len1 = math.floor(ratio * len1)
-                words0 = words0[:len0]
-                words1 = words1[:len1]
-
             words = words0 + ['[SEP]'] + words1
             return words
+
         for name, dataset in data_bundle.datasets.items():
             dataset.apply(concat, new_field_name=Const.INPUT)
             dataset.delete_field(Const.INPUTS(0))
             dataset.delete_field(Const.INPUTS(1))
 
         word_vocab = Vocabulary()
-        word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
+        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
+                                field_name=Const.INPUT,
                                 no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
-                                                         name != 'train'])
+                                                         'train' not in name])
         word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
 
         target_vocab = Vocabulary(padding=None, unknown=None)
         target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
-        has_target_datasets = []
-        for name, dataset in data_bundle.datasets.items():
-            if dataset.has_field(Const.TARGET):
-                has_target_datasets.append(dataset)
+        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
+                               dataset.has_field(Const.TARGET)]
         target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
 
         data_bundle.set_vocab(word_vocab, Const.INPUT)
         data_bundle.set_vocab(target_vocab, Const.TARGET)
 
-        input_fields = [Const.INPUT, Const.INPUT_LEN]
+        input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET]
         target_fields = [Const.TARGET]
 
         for name, dataset in data_bundle.datasets.items():
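
The vocabulary changes in the hunk above do two things: the word vocabulary is now built from every dataset whose name contains 'train', rather than only the one literally named 'train', and all remaining splits are passed as no_create_entry_dataset, so words seen only in dev/test are indexable but are flagged as out-of-training-vocabulary for any embedding built on top of the vocabulary. An isolated sketch of that pattern on toy data (the contents are made up for illustration):

from fastNLP import DataSet, Vocabulary

# Toy splits standing in for the DataBundle contents.
train_ds = DataSet({'words': [['a', 'man', 'sleeps'], ['two', 'dogs', 'run']]})
dev_ds = DataSet({'words': [['a', 'cat', 'runs']]})

vocab = Vocabulary()
# 'cat' and 'runs' appear only in dev_ds, so they are registered as
# no-create-entry words: they still receive indices, but downstream embeddings
# can treat them as words without dedicated trained vectors.
vocab.from_dataset(train_ds, field_name='words', no_create_entry_dataset=[dev_ds])
vocab.index_dataset(train_ds, dev_ds, field_name='words')
print(dev_ds[0]['words'])  # the token list is now a list of indices
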
@@ -149,12 +141,14 @@ class MatchingPipe(Pipe):
     "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7
     "...", "...", "[...]", "[...]", ., ., .
 
-    words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as target.
+    words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as target
+    and input (it is set as input to make it easy to compute the loss inside the forward function; not computing the loss in forward
+    is also fine, since fastNLP passes arguments by the parameter names of the forward function).
 
     :param bool lower: whether to lowercase all raw_words.
     :param str tokenizer: how to tokenize the raw data. Supports spacy and raw; spacy splits with spacy, raw splits on whitespace.
     """
-    def __init__(self, lower=False, tokenizer:str='raw'):
+    def __init__(self, lower=False, tokenizer: str='raw'):
         super().__init__()
 
         self.lower = bool(lower)
@@ -170,7 +164,7 @@ class MatchingPipe(Pipe):
         """
         for name, dataset in data_bundle.datasets.items():
             for field_name, new_field_name in zip(field_names, new_field_names):
-                dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
+                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                     new_field_name=new_field_name)
         return data_bundle
 
@@ -191,34 +185,37 @@ class MatchingPipe(Pipe):
         data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
                                      [Const.INPUTS(0), Const.INPUTS(1)])
 
+        for dataset in data_bundle.datasets.values():
+            if dataset.has_field(Const.TARGET):
+                dataset.drop(lambda x: x[Const.TARGET] == '-')
+
         if self.lower:
             for name, dataset in data_bundle.datasets.items():
                 dataset[Const.INPUTS(0)].lower()
                 dataset[Const.INPUTS(1)].lower()
 
         word_vocab = Vocabulary()
-        word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)],
+        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
+                                field_name=[Const.INPUTS(0), Const.INPUTS(1)],
                                 no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
-                                                         name != 'train'])
+                                                         'train' not in name])
         word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
 
         target_vocab = Vocabulary(padding=None, unknown=None)
         target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
-        has_target_datasets = []
-        for name, dataset in data_bundle.datasets.items():
-            if dataset.has_field(Const.TARGET):
-                has_target_datasets.append(dataset)
+        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
+                               dataset.has_field(Const.TARGET)]
         target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
 
         data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
         data_bundle.set_vocab(target_vocab, Const.TARGET)
 
-        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LEN(0), Const.INPUT_LEN(1)]
+        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET]
         target_fields = [Const.TARGET]
 
         for name, dataset in data_bundle.datasets.items():
-            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LEN(0))
-            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LEN(1))
+            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
+            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
             dataset.set_input(*input_fields, flag=True)
             dataset.set_target(*target_fields, flag=True)
 
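
Closing note on the field names: the Const helpers used throughout this diff resolve to plain strings, which is what the docstrings call words1, seq_len1 and so on. The values sketched below reflect my reading of fastNLP's core.const around this release and should be checked against the installed version:

from fastNLP.core.const import Const

print(Const.INPUT, Const.INPUT_LEN, Const.TARGET)   # expected: words seq_len target
print(Const.RAW_WORDS(0), Const.RAW_WORDS(1))       # expected: raw_words1 raw_words2
print(Const.INPUTS(0), Const.INPUTS(1))             # expected: words1 words2
print(Const.INPUT_LENS(0), Const.INPUT_LENS(1))     # expected: seq_len1 seq_len2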



