
fix a bug in predictor

tags/v0.4.10
xuyige committed 5 years ago · commit 15d9581e6d
2 changed files with 82 additions and 15 deletions:
  1. fastNLP/core/predictor.py (+3 −1)
  2. reproduction/matching/data/MatchingDataLoader.py (+79 −14)

fastNLP/core/predictor.py (+3 −1)

@@ -9,7 +9,7 @@ import torch
 from . import DataSetIter
 from . import DataSet
 from . import SequentialSampler
-from .utils import _build_args
+from .utils import _build_args, _move_dict_value_to_device, _get_model_device
 
 
 class Predictor(object):
@@ -43,6 +43,7 @@ class Predictor(object):
             raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
 
         self.network.eval()
+        network_device = _get_model_device(self.network)
         batch_output = defaultdict(list)
         data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
 
@@ -53,6 +54,7 @@ class Predictor(object):
 
         with torch.no_grad():
             for batch_x, _ in data_iterator:
+                _move_dict_value_to_device(batch_x, _, device=network_device)
                 refined_batch_x = _build_args(predict_func, **batch_x)
                 prediction = predict_func(**refined_batch_x)
 

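The predictor bug was a device mismatch: DataSetIter yields CPU tensors, so calling a model that lives on a GPU raised a runtime error. The hunk above looks up the model's device once and moves every batch onto it before the forward call. A minimal self-contained sketch of the same pattern in plain PyTorch (move_batch_to_device is a hypothetical stand-in for fastNLP's _move_dict_value_to_device):

    import torch

    def move_batch_to_device(batch_x: dict, device: torch.device) -> None:
        # In-place analogue of _move_dict_value_to_device: push every
        # tensor in the batch dict onto the model's device.
        for name, value in batch_x.items():
            if isinstance(value, torch.Tensor):
                batch_x[name] = value.to(device)

    network = torch.nn.Linear(4, 2)              # stand-in for self.network
    network.eval()
    device = next(network.parameters()).device   # what _get_model_device boils down to
    batch_x = {'words': torch.randn(3, 4)}       # iterators hand back CPU tensors
    move_batch_to_device(batch_x, device)
    with torch.no_grad():
        prediction = network(batch_x['words'])   # devices now agree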

reproduction/matching/data/MatchingDataLoader.py (+79 −14)

@@ -86,7 +86,8 @@ class MatchingLoader(DataSetLoader):
             if auto_set_input:
                 data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
             if auto_set_target:
-                data_set.set_target(Const.TARGET)
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.set_target(Const.TARGET)
 
         if to_lower:
             for data_name, data_set in data_info.datasets.items():
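This guard matters because unlabeled splits (for example the MNLI and Quora test sets added below) ship without a target field, and calling set_target on a missing field fails. A minimal sketch of the guard with a plain dict standing in for a DataSet (the helper name is illustrative, not fastNLP API):

    TARGET = 'target'   # the string Const.TARGET resolves to in fastNLP

    def mark_target_if_present(fields: dict, name: str = TARGET) -> bool:
        # Mirrors: if Const.TARGET in data_set.get_field_names(): data_set.set_target(...)
        return name in fields

    train_split = {'words1': [], 'words2': [], 'target': []}
    test_split = {'words1': [], 'words2': []}             # no gold labels shipped
    assert mark_target_if_present(train_split) is True
    assert mark_target_if_present(test_split) is False    # skipped instead of crashing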
@@ -107,6 +108,13 @@ class MatchingLoader(DataSetLoader):
             else:
                 raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
 
+            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
+                lines = f.readlines()
+            lines = [line.strip() for line in lines]
+            words_vocab.add_word_lst(lines)
+            words_vocab.build_vocab()
+
             tokenizer = BertTokenizer.from_pretrained(model_dir)
 
             for data_name, data_set in data_info.datasets.items():
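Building words_vocab from the checkpoint's vocab.txt, rather than from the corpus, keeps word indices aligned with the pretrained BERT weights: in BERT's vocabulary format, a token's id is simply its line number. A minimal sketch of reading that file format (the checkpoint path is hypothetical):

    import os

    def load_bert_vocab(model_dir: str) -> dict:
        # BERT's vocab.txt stores one token per line; the id is the line number.
        vocab = {}
        with open(os.path.join(model_dir, 'vocab.txt'), 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                vocab[line.strip()] = idx
        return vocab

    # vocab = load_bert_vocab('/path/to/bert-base-uncased')   # hypothetical dir
    # For the standard bert-base-uncased vocabulary:
    # vocab['[PAD]'] == 0 and vocab['[UNK]'] == 100.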
@@ -171,14 +179,7 @@ class MatchingLoader(DataSetLoader):
             data_set_list = [d for n, d in data_info.datasets.items()]
             assert len(data_set_list) > 0, f'There are NO data sets in data info!'
 
-            if bert_tokenizer is not None:
-                words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
-                with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
-                    lines = f.readlines()
-                lines = [line.strip() for line in lines]
-                words_vocab.add_word_lst(lines)
-                words_vocab.build_vocab()
-            else:
+            if bert_tokenizer is None:
                 words_vocab = Vocabulary()
             words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                    field_name=[n for n in data_set_list[0].get_field_names()
@@ -186,7 +187,8 @@ class MatchingLoader(DataSetLoader):
                                                    no_create_entry_dataset=[d for n, d in data_info.datasets.items()
                                                                             if 'train' not in n])
             target_vocab = Vocabulary(padding=None, unknown=None)
-            target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
+            target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+                                                     field_name=Const.TARGET)
             data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
 
         if get_index:
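Restricting target_vocab to the training splits is the companion fix: the old call folded every split into the label vocabulary, but unlabeled test splits have no target field at all, so from_dataset over data_set_list could fail. A sketch of the intended outcome with plain dicts (all names illustrative):

    def build_label_vocab(splits: dict) -> dict:
        # Index labels from the training splits only, mirroring the
        # "'train' in n" filter used for target_vocab above.
        labels = {}
        for name, example_labels in splits.items():
            if 'train' not in name:
                continue                          # dev/test contribute nothing
            for label in example_labels:
                labels.setdefault(label, len(labels))
        return labels

    splits = {'train': ['entailment', 'neutral', 'contradiction', 'neutral'],
              'test_matched': []}                 # unlabeled split is skipped
    assert build_label_vocab(splits) == {'entailment': 0, 'neutral': 1,
                                         'contradiction': 2}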
@@ -196,14 +198,15 @@ class MatchingLoader(DataSetLoader):
                     data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
                                    is_input=auto_set_input)
 
-                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
-                               is_input=auto_set_input, is_target=auto_set_target)
+                if Const.TARGET in data_set.get_field_names():
+                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+                                   is_input=auto_set_input, is_target=auto_set_target)
 
         for data_name, data_set in data_info.datasets.items():
             if isinstance(set_input, list):
-                data_set.set_input(*set_input)
+                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
             if isinstance(set_target, list):
-                data_set.set_target(*set_target)
+                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
 
         return data_info

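The last two changes apply the same existence filter to the user-supplied set_input and set_target lists, so requesting a field that one split lacks no longer aborts the whole pipeline. A short sketch of the filtering comprehension (plain data structures, illustrative names):

    def present_fields(requested: list, field_names: set) -> list:
        # Mirrors: data_set.set_input(*[f for f in set_input
        #                               if f in data_set.get_field_names()])
        return [f for f in requested if f in field_names]

    test_fields = {'words1', 'words2', 'seq_len1', 'seq_len2'}
    assert present_fields(['words1', 'words2', 'target'], test_fields) == \
        ['words1', 'words2']   # 'target' is silently dropped for the test split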
@@ -324,3 +327,65 @@ class QNLILoader(MatchingLoader, CSVLoader):
 
         return ds
 

class MNLILoader(MatchingLoader, CSVLoader):
"""
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

读取SNLI数据集,读取的DataSet包含fields::

words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis
target: str, 真实标签

数据来源:
"""

def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev_matched': 'dev_matched.tsv',
'dev_mismatched': 'dev_mismatched.tsv',
'test_matched': 'test_matched.tsv',
'test_mismatched': 'test_mismatched.tsv',
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t')
self.fields = {
'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}

def _load(self, path):
ds = CSVLoader._load(self, path)

for k, v in self.fields.items():
if k in ds.get_field_names():
ds.rename_field(k, v)

parentheses_table = str.maketrans({'(': None, ')': None})

ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(1))
if Const.TARGET in ds.get_field_names():
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds
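MNLILoader reads the sentence*_binary_parse columns, which store constituency parses rather than raw text, so _load strips the parentheses and re-tokenizes on whitespace; the drop call then discards examples labeled '-', MNLI's marker for items with no annotator consensus, guarded so it only runs on labeled splits. A small worked example of the translate/strip/split chain:

    parentheses_table = str.maketrans({'(': None, ')': None})

    parse = '( ( The cat ) ( sat ) )'     # a binary-parse column value
    tokens = parse.translate(parentheses_table).strip().split()
    assert tokens == ['The', 'cat', 'sat']   # the field becomes list(str)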
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+
+    def __init__(self, paths: dict=None):
+        paths = paths if paths is not None else {
+            'train': 'train.tsv',
+            'dev': 'dev.tsv',
+            'test': 'test.tsv',
+        }
+        MatchingLoader.__init__(self, paths=paths)
+        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+    def _load(self, path):
+        ds = CSVLoader._load(self, path)
+        return ds

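QuoraLoader's files carry no header row, so CSVLoader is constructed with explicit headers that map the four tab-separated columns to (target, words1, words2, pairID). A minimal sketch of reading that layout with the standard csv module (the sample row is invented for illustration):

    import csv
    import io

    HEADERS = ('target', 'words1', 'words2', 'pairID')   # what the Const names resolve to

    sample = '1\tWhat is machine learning ?\tHow does machine learning work ?\tq1-q2\n'
    reader = csv.reader(io.StringIO(sample), delimiter='\t')
    rows = [dict(zip(HEADERS, row)) for row in reader]
    assert rows[0]['target'] == '1'
    assert rows[0]['pairID'] == 'q1-q2'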