@@ -4,9 +4,9 @@ from typing import Union, Dict
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ...io.base_loader import DataInfo, DataSetLoader
-from ...io.dataset_loader import JsonLoader, CSVLoader
-from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import JsonLoader, CSVLoader
+from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer
@@ -16,8 +16,6 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
-    'SNLILoader',
-    'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 ]
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
 from .base_loader import DataSetLoader, DataInfo
-from .data_loader.sst import SSTLoader
 from ..core.const import Const
 from ..modules.encoder._bert import BertTokenizer
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
             else:
                 instance = Instance(words=sent_words)
             data_set.append(instance)
-        data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
+        data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
         return data_set
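For context: the replaced line swaps the hardcoded field-name strings "words" and "seq_len" for the shared constants in fastNLP.core.const, so loaders and downstream models agree on field names. A minimal sketch of the pattern, assuming fastNLP 0.4.x, where Const.INPUT resolves to 'words' and Const.INPUT_LEN to 'seq_len':

from fastNLP import DataSet, Instance
from fastNLP.core.const import Const

# Build a tiny dataset the way PeopleDailyCorpusLoader does:
# one Instance per sentence, holding a 'words' field.
data_set = DataSet()
data_set.append(Instance(words=['中', '国', '人']))
data_set.append(Instance(words=['你', '好']))

# Deriving seq_len through Const instead of literal strings keeps the
# field names consistent with the rest of the library.
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
print(data_set)  # each instance now also carries a seq_len field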
@@ -1,7 +1,7 @@
 import unittest
 import os
 from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
-from fastNLP.io.dataset_loader import SSTLoader, SNLILoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
 from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
         print(info.vocabs)
         print(info.datasets)
         os.remove(train), os.remove(test)
+
+    def test_import(self):
+        import fastNLP
+        from fastNLP.io import SNLILoader
+        ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+                                  get_index=True, seq_len_type='seq_len')
+        assert 'train' in ds.datasets
+        assert len(ds.datasets) == 1
+        assert len(ds.datasets['train']) == 3
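The new test exercises SNLILoader through the public fastNLP.io namespace after the move to fastNLP.io.data_loader. For reference, a standalone usage sketch under the same assumptions as the test (process() returns a DataInfo bundle whose datasets and vocabs attributes are plain dicts; the sample path and keyword arguments are taken from the test above):

from fastNLP.io import SNLILoader

# process() reads the raw file, builds vocabularies, and optionally
# converts tokens to indices (get_index=True) and adds a seq_len field.
info = SNLILoader().process('test/data_for_tests/sample_snli.jsonl',
                            to_lower=True, get_index=True, seq_len_type='seq_len')

# DataInfo groups the loaded splits and the vocabularies built over them.
for name, dataset in info.datasets.items():
    print(name, len(dataset))   # e.g. "train 3" for the 3-line sample
for name, vocab in info.vocabs.items():
    print(name, len(vocab))     # vocabulary sizes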