From e314a367784d47aeb2082bfe6e2cb75f7b6a79b6 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Thu, 12 Sep 2019 03:10:17 +0800 Subject: [PATCH] fix a bug function _indexize in fastNLP/io/pipe/utils.py --- fastNLP/io/pipe/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py index 3db9c4fe..92d61bfd 100644 --- a/fastNLP/io/pipe/utils.py +++ b/fastNLP/io/pipe/utils.py @@ -105,18 +105,20 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con target_field_names = [target_field_names] for input_field_name in input_field_names: src_vocab = Vocabulary() - src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name, - no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if - name != 'train']) + src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name=input_field_name, + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if ('train' not in name) and (ds.has_field(input_field_name))] + ) src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) data_bundle.set_vocab(src_vocab, input_field_name) for target_field_name in target_field_names: tgt_vocab = Vocabulary(unknown=None, padding=None) tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], - field_name=Const.TARGET, + field_name=target_field_name, no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() - if ('train' not in name) and (ds.has_field(Const.TARGET))] + if ('train' not in name) and (ds.has_field(target_field_name))] ) if len(tgt_vocab._no_create_word) > 0: warn_msg = f"There are {len(tgt_vocab._no_create_word)} target labels" \