From 15d9581e6d0805ad52a3f7c367d329999e3841e2 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Sun, 30 Jun 2019 15:44:26 +0800
Subject: [PATCH] fix a bug in predictor

---
 fastNLP/core/predictor.py                    |  4 +-
 .../matching/data/MatchingDataLoader.py      | 93 ++++++++++++++++---
 2 files changed, 82 insertions(+), 15 deletions(-)

diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 06e586c6..ce016bb6 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -9,7 +9,7 @@ import torch
 from . import DataSetIter
 from . import DataSet
 from . import SequentialSampler
-from .utils import _build_args
+from .utils import _build_args, _move_dict_value_to_device, _get_model_device
 
 
 class Predictor(object):
@@ -43,6 +43,7 @@ class Predictor(object):
             raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
 
         self.network.eval()
+        network_device = _get_model_device(self.network)
         batch_output = defaultdict(list)
         data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
 
@@ -53,6 +54,7 @@ class Predictor(object):
 
         with torch.no_grad():
             for batch_x, _ in data_iterator:
+                _move_dict_value_to_device(batch_x, _, device=network_device)
                 refined_batch_x = _build_args(predict_func, **batch_x)
                 prediction = predict_func(**refined_batch_x)
 
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 0e4e1283..749b16c8 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -86,7 +86,8 @@ class MatchingLoader(DataSetLoader):
             if auto_set_input:
                 data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
             if auto_set_target:
-                data_set.set_target(Const.TARGET)
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.set_target(Const.TARGET)
 
         if to_lower:
             for data_name, data_set in data_info.datasets.items():
@@ -107,6 +108,13 @@ class MatchingLoader(DataSetLoader):
             else:
                 raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
 
+            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
+                lines = f.readlines()
+            lines = [line.strip() for line in lines]
+            words_vocab.add_word_lst(lines)
+            words_vocab.build_vocab()
+
             tokenizer = BertTokenizer.from_pretrained(model_dir)
 
             for data_name, data_set in data_info.datasets.items():
@@ -171,14 +179,7 @@ class MatchingLoader(DataSetLoader):
         data_set_list = [d for n, d in data_info.datasets.items()]
         assert len(data_set_list) > 0, f'There are NO data sets in data info!'
 
-        if bert_tokenizer is not None:
-            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
-            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
-                lines = f.readlines()
-            lines = [line.strip() for line in lines]
-            words_vocab.add_word_lst(lines)
-            words_vocab.build_vocab()
-        else:
+        if bert_tokenizer is None:
             words_vocab = Vocabulary()
             words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                    field_name=[n for n in data_set_list[0].get_field_names()
@@ -186,7 +187,8 @@ class MatchingLoader(DataSetLoader):
                                                    no_create_entry_dataset=[d for n, d in data_info.datasets.items()
                                                                             if 'train' not in n])
 
         target_vocab = Vocabulary(padding=None, unknown=None)
-        target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
+        target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+                                                 field_name=Const.TARGET)
         data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
 
         if get_index:
@@ -196,14 +198,15 @@ class MatchingLoader(DataSetLoader):
                     data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
                                    is_input=auto_set_input)
 
-                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
-                               is_input=auto_set_input, is_target=auto_set_target)
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+                                   is_input=auto_set_input, is_target=auto_set_target)
 
         for data_name, data_set in data_info.datasets.items():
             if isinstance(set_input, list):
-                data_set.set_input(*set_input)
+                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
             if isinstance(set_target, list):
-                data_set.set_target(*set_target)
+                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
 
         return data_info
@@ -324,3 +327,65 @@ class QNLILoader(MatchingLoader, CSVLoader):
 
         return ds
+
+
+class MNLILoader(MatchingLoader, CSVLoader):
+    """
+    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
+
+    Reads the MNLI dataset. The loaded DataSet contains the following fields::
+
+        words1: list(str), the first sentence (premise)
+        words2: list(str), the second sentence (hypothesis)
+        target: str, the gold label
+
+    Data source:
+    """
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev_matched': 'dev_matched.tsv',
+            'dev_mismatched': 'dev_mismatched.tsv',
+            'test_matched': 'test_matched.tsv',
+            'test_mismatched': 'test_mismatched.tsv',
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t')
+        self.fields = {
+            'sentence1_binary_parse': Const.INPUTS(0),
+            'sentence2_binary_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
+        }
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+        for k, v in self.fields.items():
+            if k in ds.get_field_names():
+                ds.rename_field(k, v)
+
+        parentheses_table = str.maketrans({'(': None, ')': None})
+
+        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(0))
+        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+                 new_field_name=Const.INPUTS(1))
+        if Const.TARGET in ds.get_field_names():
+            ds.drop(lambda x: x[Const.TARGET] == '-')
+        return ds
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv',
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+        return ds
 
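
Note on the predictor fix: DataSetIter always yields CPU tensors, so calling predict() on a network that had been moved to the GPU raised a device-mismatch RuntimeError. The patch records the network's device once via _get_model_device() and copies every batch there with _move_dict_value_to_device() before the forward call. A minimal usage sketch of the fixed behaviour (not part of the patch; it assumes fastNLP's 0.4-era Predictor/DataSet API, and ToyModel is an invented placeholder network):

    import torch
    import torch.nn as nn
    from fastNLP import DataSet
    from fastNLP.core.predictor import Predictor

    class ToyModel(nn.Module):
        """Placeholder network exposing the predict() hook that Predictor prefers over forward()."""
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)

        def predict(self, words):
            # with the patch applied, `words` arrives on the same device as self.embed
            return {'pred': self.embed(words).sum(dim=-1).argmax(dim=-1)}

    data = DataSet({'words': [[1, 2, 3, 4]] * 8})
    data.set_input('words')

    model = ToyModel()
    if torch.cuda.is_available():
        model = model.cuda()  # before the patch, this line made predict() crash

    predictions = Predictor(model).predict(data)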
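
Note on the get_field_names() guards in MatchingDataLoader.py: unlabeled evaluation splits (e.g. MNLI's test_matched.tsv and test_mismatched.tsv) ship without gold labels, so Const.TARGET never exists there and every target-related call has to be skipped rather than allowed to fail. A condensed sketch of the pattern (not from the patch; the field values are invented):

    from fastNLP import DataSet
    from fastNLP.core.const import Const

    train = DataSet({'words1': [['a', 'b']], 'words2': [['c']], Const.TARGET: ['entailment']})
    test = DataSet({'words1': [['a', 'b']], 'words2': [['c']]})  # no gold labels

    requested = [Const.TARGET]
    for ds in (train, test):
        # keep only the requested fields this split actually has, as the patch does
        ds.set_target(*[t for t in requested if t in ds.get_field_names()])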
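
Note on MNLILoader._load: the loader reads the sentence*_binary_parse columns, so tokenization reduces to deleting parentheses from the parse string and splitting on whitespace; the trailing ds.drop call then discards examples whose gold label is '-', i.e. items where annotators reached no consensus. The translate/strip/split chain in isolation, on a made-up parse:

    parentheses_table = str.maketrans({'(': None, ')': None})
    parse = '( ( The cat ) ( sat down ) )'
    tokens = parse.translate(parentheses_table).strip().split()
    print(tokens)  # ['The', 'cat', 'sat', 'down']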