diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 40604deb..8c697e4a 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -22,6 +22,8 @@ __all__ = [ 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', + 'VocabularyOption', + 'EmbeddingOption', ] from nltk.tree import Tree @@ -32,8 +34,8 @@ from .file_reader import _read_csv, _read_json, _read_conll from typing import Union, Dict, Iterable import os from ..core.utils import Example -from ..core import Vocabulary -from ..io import EmbedLoader +from ..core.vocabulary import Vocabulary +from ..io.embed_loader import EmbedLoader import numpy as np @@ -448,14 +450,13 @@ class SSTLoader(DataSetLoader): info = DataInfo(datasets=self.load(paths)) _train_ds = [info.datasets[name] for name in train_ds] if train_ds else info.datasets.values() - - src_vocab.from_dataset(_train_ds, field_name=input_name) - tgt_vocab.from_dataset(_train_ds, field_name=target_name) + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) src_vocab.index_dataset( - info.datasets.values(), + *info.datasets.values(), field_name=input_name, new_field_name=input_name) tgt_vocab.index_dataset( - info.datasets.values(), + *info.datasets.values(), field_name=target_name, new_field_name=target_name) info.vocabs = { input_name: src_vocab,