diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 93e854b1..9f7c7d68 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -1,4 +1,3 @@ -import math from .pipe import Pipe from .utils import get_tokenizer @@ -19,19 +18,17 @@ class MatchingBertPipe(Pipe): "...", "...", "[...]", ., . words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 - words列被设置为input,target列被设置为target. + words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, + 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). :param bool lower: 是否将word小写化。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 - :param int max_concat_sent_length: 如果concat后的句子长度超过了该值,则合并后的句子将被截断到这个长度,截断时同时对premise - 和hypothesis按比例截断。 """ - def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480): + def __init__(self, lower=False, tokenizer: str='raw'): super().__init__() self.lower = bool(lower) self.tokenizer = get_tokenizer(tokenizer=tokenizer) - self.max_concat_sent_length = int(max_concat_sent_length) def _tokenize(self, data_bundle, field_names, new_field_names): """ @@ -43,11 +40,15 @@ class MatchingBertPipe(Pipe): """ for name, dataset in data_bundle.datasets.items(): for field_name, new_field_name in zip(field_names, new_field_names): - dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, new_field_name=new_field_name) return data_bundle def process(self, data_bundle): + for dataset in data_bundle.datasets.values(): + if dataset.has_field(Const.TARGET): + dataset.drop(lambda x: x[Const.TARGET] == '-') + for name, dataset in data_bundle.datasets.items(): dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0)) dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1)) @@ -57,47 +58,38 @@ class MatchingBertPipe(Pipe): dataset[Const.INPUTS(0)].lower() dataset[Const.INPUTS(1)].lower() - data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)], + data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], [Const.INPUTS(0), Const.INPUTS(1)]) # concat两个words def concat(ins): words0 = ins[Const.INPUTS(0)] words1 = ins[Const.INPUTS(1)] - len0 = len(words0) - len1 = len(words1) - if len0 + len1 > self.max_concat_sent_length: - ratio = self.max_concat_sent_length / (len0 + len1) - len0 = math.floor(ratio * len0) - len1 = math.floor(ratio * len1) - words0 = words0[:len0] - words1 = words1[:len1] - words = words0 + ['[SEP]'] + words1 return words + for name, dataset in data_bundle.datasets.items(): dataset.apply(concat, new_field_name=Const.INPUT) dataset.delete_field(Const.INPUTS(0)) dataset.delete_field(Const.INPUTS(1)) word_vocab = Vocabulary() - word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name=Const.INPUT, no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if - name != 'train']) + 'train' not in name]) word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) target_vocab = Vocabulary(padding=None, unknown=None) target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) - has_target_datasets = [] - for name, dataset in data_bundle.datasets.items(): - if dataset.has_field(Const.TARGET): - has_target_datasets.append(dataset) + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if + dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) data_bundle.set_vocab(word_vocab, Const.INPUT) data_bundle.set_vocab(target_vocab, Const.TARGET) - input_fields = [Const.INPUT, Const.INPUT_LEN] + input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET] target_fields = [Const.TARGET] for name, dataset in data_bundle.datasets.items(): @@ -149,12 +141,14 @@ class MatchingPipe(Pipe): "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 "...", "...", "[...]", "[...]", ., ., . - words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target。 + words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target + 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 + 的形参名进行传参)。 :param bool lower: 是否将所有raw_words转为小写。 :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 """ - def __init__(self, lower=False, tokenizer:str='raw'): + def __init__(self, lower=False, tokenizer: str='raw'): super().__init__() self.lower = bool(lower) @@ -170,7 +164,7 @@ class MatchingPipe(Pipe): """ for name, dataset in data_bundle.datasets.items(): for field_name, new_field_name in zip(field_names, new_field_names): - dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, new_field_name=new_field_name) return data_bundle @@ -191,34 +185,37 @@ class MatchingPipe(Pipe): data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], [Const.INPUTS(0), Const.INPUTS(1)]) + for dataset in data_bundle.datasets.values(): + if dataset.has_field(Const.TARGET): + dataset.drop(lambda x: x[Const.TARGET] == '-') + if self.lower: for name, dataset in data_bundle.datasets.items(): dataset[Const.INPUTS(0)].lower() dataset[Const.INPUTS(1)].lower() word_vocab = Vocabulary() - word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)], + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name=[Const.INPUTS(0), Const.INPUTS(1)], no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if - name != 'train']) + 'train' not in name]) word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) target_vocab = Vocabulary(padding=None, unknown=None) target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) - has_target_datasets = [] - for name, dataset in data_bundle.datasets.items(): - if dataset.has_field(Const.TARGET): - has_target_datasets.append(dataset) + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if + dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) data_bundle.set_vocab(target_vocab, Const.TARGET) - input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LEN(0), Const.INPUT_LEN(1)] + input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET] target_fields = [Const.TARGET] for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LEN(0)) - dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LEN(1)) + dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) + dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) dataset.set_input(*input_fields, flag=True) dataset.set_target(*target_fields, flag=True)