From 66a7cf084ec7caa9d95319619c4e7cb1720d2818 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sat, 6 Jul 2019 01:36:11 +0800 Subject: [PATCH] fix bug in test --- fastNLP/io/data_loader/matching.py | 6 +++--- fastNLP/io/dataset_loader.py | 5 +---- test/io/test_dataset_loader.py | 11 ++++++++++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 1cde950f..3d131bcb 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -4,9 +4,9 @@ from typing import Union, Dict from ...core.const import Const from ...core.vocabulary import Vocabulary -from ...io.base_loader import DataInfo, DataSetLoader -from ...io.dataset_loader import JsonLoader, CSVLoader -from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR +from ..base_loader import DataInfo, DataSetLoader +from ..dataset_loader import JsonLoader, CSVLoader +from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from ...modules.encoder._bert import BertTokenizer diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 26edd8bd..2881e6e9 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -16,8 +16,6 @@ __all__ = [ 'CSVLoader', 'JsonLoader', 'ConllLoader', - 'SNLILoader', - 'SSTLoader', 'PeopleDailyCorpusLoader', 'Conll2003Loader', ] @@ -30,7 +28,6 @@ from ..core.dataset import DataSet from ..core.instance import Instance from .file_reader import _read_csv, _read_json, _read_conll from .base_loader import DataSetLoader, DataInfo -from .data_loader.sst import SSTLoader from ..core.const import Const from ..modules.encoder._bert import BertTokenizer @@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): else: instance = Instance(words=sent_words) data_set.append(instance) - data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") + data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) return data_set diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py index b091339e..09ad8c83 100644 --- a/test/io/test_dataset_loader.py +++ b/test/io/test_dataset_loader.py @@ -1,7 +1,7 @@ import unittest import os from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader -from fastNLP.io.dataset_loader import SSTLoader, SNLILoader +from fastNLP.io.data_loader import SSTLoader, SNLILoader from reproduction.text_classification.data.yelpLoader import yelpLoader @@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase): print(info.vocabs) print(info.datasets) os.remove(train), os.remove(test) + + def test_import(self): + import fastNLP + from fastNLP.io import SNLILoader + ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, + get_index=True, seq_len_type='seq_len') + assert 'train' in ds.datasets + assert len(ds.datasets) == 1 + assert len(ds.datasets['train']) == 3