diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
index 21dcefb0..cecaee96 100644
--- a/fastNLP/io/data_loader/matching.py
+++ b/fastNLP/io/data_loader/matching.py
@@ -1,6 +1,6 @@
 import os
 
-from typing import Union, Dict , List
+from typing import Union, Dict, List
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
                 cut_text: int = None, get_index=True, auto_pad_length: int=None,
                 auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                 set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None,
-                extra_split: List[str]=List['-'], ) -> DataBundle:
+                extra_split: List[str]=None, ) -> DataBundle:
         """
         :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
             则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
@@ -91,22 +91,22 @@ class MatchingLoader(DataSetLoader):
             if Const.TARGET in data_set.get_field_names():
                 data_set.set_target(Const.TARGET)
 
-        if extra_split:
+        if extra_split is not None:
             for data_name, data_set in data_info.datasets.items():
                 data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
                 data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
 
                 for s in extra_split:
-                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '),
-                                   new_field_name=Const.INPUTS(0))
-                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '),
-                                   new_field_name=Const.INPUTS(0))
-
-                _filt = lambda x : x
-                data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(0)].split(' '))),
-                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
-                data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(1)].split(' '))),
-                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)
+                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(0))
+                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(0))
+
+                _filt = lambda x: x
+                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))),
+                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
+                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))),
+                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)
                 _filt = None
 
         if to_lower:
diff --git a/reproduction/matching/matching_mwan.py b/reproduction/matching/matching_mwan.py
index d2d3033f..e96ee0c9 100644
--- a/reproduction/matching/matching_mwan.py
+++ b/reproduction/matching/matching_mwan.py
@@ -18,7 +18,7 @@ from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallb
 from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
 from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
 
-from model.mwan import MwanModel
+from reproduction.matching.model.mwan import MwanModel
 
 import fitlog
 fitlog.debug()
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
index 3e3c54e2..9f0579a1 100644
--- a/test/io/test_dataset_loader.py
+++ b/test/io/test_dataset_loader.py
@@ -64,7 +64,13 @@ class TestDatasetLoader(unittest.TestCase):
     def test_import(self):
         import fastNLP
         from fastNLP.io import SNLILoader
-        ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+        ds = SNLILoader().process('./../data_for_tests/sample_snli.jsonl', to_lower=True,
+                                  get_index=True, seq_len_type='seq_len', extra_split=['-'])
+        assert 'train' in ds.datasets
+        assert len(ds.datasets) == 1
+        assert len(ds.datasets['train']) == 3
+
+        ds = SNLILoader().process('./../data_for_tests/sample_snli.jsonl', to_lower=True,
                                   get_index=True, seq_len_type='seq_len')
         assert 'train' in ds.datasets
         assert len(ds.datasets) == 1