Browse Source

Fix a bug in function _indexize in fastNLP/io/pipe/utils.py

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
e314a36778
1 changed file with 7 additions and 5 deletions
  1. +7
    -5
      fastNLP/io/pipe/utils.py

+ 7
- 5
fastNLP/io/pipe/utils.py View File

@@ -105,18 +105,20 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con
target_field_names = [target_field_names]
for input_field_name in input_field_names:
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
field_name=input_field_name,
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
if ('train' not in name) and (ds.has_field(input_field_name))]
)
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
data_bundle.set_vocab(src_vocab, input_field_name)
for target_field_name in target_field_names:
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
field_name=Const.TARGET,
field_name=target_field_name,
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
if ('train' not in name) and (ds.has_field(Const.TARGET))]
if ('train' not in name) and (ds.has_field(target_field_name))]
)
if len(tgt_vocab._no_create_word) > 0:
warn_msg = f"There are {len(tgt_vocab._no_create_word)} target labels" \


Loading…
Cancel
Save