|
|
@@ -105,18 +105,20 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con |
|
|
|
target_field_names = [target_field_names] |
|
|
|
for input_field_name in input_field_names: |
|
|
|
src_vocab = Vocabulary() |
|
|
|
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name, |
|
|
|
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if |
|
|
|
name != 'train']) |
|
|
|
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], |
|
|
|
field_name=input_field_name, |
|
|
|
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() |
|
|
|
if ('train' not in name) and (ds.has_field(input_field_name))] |
|
|
|
) |
|
|
|
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) |
|
|
|
data_bundle.set_vocab(src_vocab, input_field_name) |
|
|
|
|
|
|
|
for target_field_name in target_field_names: |
|
|
|
tgt_vocab = Vocabulary(unknown=None, padding=None) |
|
|
|
tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], |
|
|
|
field_name=Const.TARGET, |
|
|
|
field_name=target_field_name, |
|
|
|
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() |
|
|
|
if ('train' not in name) and (ds.has_field(Const.TARGET))] |
|
|
|
if ('train' not in name) and (ds.has_field(target_field_name))] |
|
|
|
) |
|
|
|
if len(tgt_vocab._no_create_word) > 0: |
|
|
|
warn_msg = f"There are {len(tgt_vocab._no_create_word)} target labels" \ |
|
|
|