
fix bug in test

tags/v0.4.10
xuyige committed 5 years ago
commit 66a7cf084e
3 changed files with 14 additions and 8 deletions
  1. fastNLP/io/data_loader/matching.py   (+3, -3)
  2. fastNLP/io/dataset_loader.py   (+1, -4)
  3. test/io/test_dataset_loader.py   (+10, -1)

fastNLP/io/data_loader/matching.py   (+3, -3)

@@ -4,9 +4,9 @@ from typing import Union, Dict
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ...io.base_loader import DataInfo, DataSetLoader
-from ...io.dataset_loader import JsonLoader, CSVLoader
-from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import JsonLoader, CSVLoader
+from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder._bert import BertTokenizer
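A note on the import change above (a reading of the diff, not part of the commit): matching.py lives in fastNLP/io/data_loader/, so the old three-dot form (...io.base_loader) and the new two-dot form (..base_loader) both resolve to fastNLP.io.base_loader; the shorter form simply stops climbing to the top-level fastNLP package before coming back into io. A minimal sanity check one could run after the change, assuming fastNLP and its dependencies are importable from the working tree:

# Hypothetical check, not from the commit: import every module the new
# relative paths point at; a wrong path would surface as an ImportError here.
import importlib

for name in ("fastNLP.io.base_loader",
             "fastNLP.io.dataset_loader",
             "fastNLP.io.file_utils",
             "fastNLP.io.data_loader.matching"):
    importlib.import_module(name)
print("all modules touched by the matching.py imports resolve")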

fastNLP/io/dataset_loader.py   (+1, -4)

@@ -16,8 +16,6 @@ __all__ = [
     'CSVLoader',
     'JsonLoader',
     'ConllLoader',
-    'SNLILoader',
-    'SSTLoader',
     'PeopleDailyCorpusLoader',
     'Conll2003Loader',
 ]
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
 from .base_loader import DataSetLoader, DataInfo
-from .data_loader.sst import SSTLoader
 from ..core.const import Const
 from ..modules.encoder._bert import BertTokenizer
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
             else:
                 instance = Instance(words=sent_words)
             data_set.append(instance)
-        data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
+        data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
         return data_set
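The last hunk swaps the hard-coded "words" and "seq_len" field names for fastNLP's column-name constants. A toy illustration of the new call (a sketch, not part of the commit; it assumes Const.INPUT and Const.INPUT_LEN carry the usual 'words' and 'seq_len' values):

# Sketch: the changed apply() call on a two-sentence DataSet.
from fastNLP import DataSet
from fastNLP.core.const import Const

ds = DataSet({Const.INPUT: [["I", "like", "it"], ["ok"]]})
# Same behaviour as the removed line, but routed through the constants:
ds.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
print(ds[0][Const.INPUT_LEN])  # 3 for the first, three-token sentence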

test/io/test_dataset_loader.py   (+10, -1)

@@ -1,7 +1,7 @@
 import unittest
 import os
 from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
-from fastNLP.io.dataset_loader import SSTLoader, SNLILoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
 from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
         print(info.vocabs)
         print(info.datasets)
         os.remove(train), os.remove(test)
+
+    def test_import(self):
+        import fastNLP
+        from fastNLP.io import SNLILoader
+        ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+                                  get_index=True, seq_len_type='seq_len')
+        assert 'train' in ds.datasets
+        assert len(ds.datasets) == 1
+        assert len(ds.datasets['train']) == 3
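For reference, the new case should be runnable on its own with something like "python -m pytest test/io/test_dataset_loader.py -k test_import" from the repository root, provided the test/data_for_tests/sample_snli.jsonl fixture is present in the checkout.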
