Browse Source

Fix bug in matching DataLoader: invalid `List['-']` default for `extra_split`

tags/v0.4.10
xuyige 5 years ago
parent
commit
7d07b38e0a
3 changed files with 21 additions and 15 deletions
  1. +13
    -13
      fastNLP/io/data_loader/matching.py
  2. +1
    -1
      reproduction/matching/matching_mwan.py
  3. +7
    -1
      test/io/test_dataset_loader.py

+ 13
- 13
fastNLP/io/data_loader/matching.py View File

@@ -1,6 +1,6 @@
import os

from typing import Union, Dict , List
from typing import Union, Dict, List

from ...core.const import Const
from ...core.vocabulary import Vocabulary
@@ -34,7 +34,7 @@ class MatchingLoader(DataSetLoader):
cut_text: int = None, get_index=True, auto_pad_length: int=None,
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None,
extra_split: List[str]=List['-'], ) -> DataBundle:
extra_split: List[str]=None, ) -> DataBundle:
"""
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
@@ -91,22 +91,22 @@ class MatchingLoader(DataSetLoader):
if Const.TARGET in data_set.get_field_names():
data_set.set_target(Const.TARGET)

if extra_split:
if extra_split is not None:
for data_name, data_set in data_info.datasets.items():
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))

for s in extra_split:
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '),
new_field_name=Const.INPUTS(0))
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s , ' ' + s + ' '),
new_field_name=Const.INPUTS(0))
_filt = lambda x : x
data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(0)].split(' '))),
new_field_name=Const.INPUTS(0), is_input=auto_set_input)
data_set.apply(lambda x: list(filter(_filt , x[Const.INPUTS(1)].split(' '))),
new_field_name=Const.INPUTS(1), is_input=auto_set_input)
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
new_field_name=Const.INPUTS(0))
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
new_field_name=Const.INPUTS(0))
_filt = lambda x: x
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))),
new_field_name=Const.INPUTS(0), is_input=auto_set_input)
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))),
new_field_name=Const.INPUTS(1), is_input=auto_set_input)
_filt = None

if to_lower:


+ 1
- 1
reproduction/matching/matching_mwan.py View File

@@ -18,7 +18,7 @@ from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallb
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
from model.mwan import MwanModel
from reproduction.matching.model.mwan import MwanModel

import fitlog
fitlog.debug()


+ 7
- 1
test/io/test_dataset_loader.py View File

@@ -64,7 +64,13 @@ class TestDatasetLoader(unittest.TestCase):
def test_import(self):
import fastNLP
from fastNLP.io import SNLILoader
ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
ds = SNLILoader().process('./../data_for_tests/sample_snli.jsonl', to_lower=True,
get_index=True, seq_len_type='seq_len', extra_split=['-'])
assert 'train' in ds.datasets
assert len(ds.datasets) == 1
assert len(ds.datasets['train']) == 3

ds = SNLILoader().process('./../data_for_tests/sample_snli.jsonl', to_lower=True,
get_index=True, seq_len_type='seq_len')
assert 'train' in ds.datasets
assert len(ds.datasets) == 1


Loading…
Cancel
Save