diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index d175d3b9..e366c6ea 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -28,6 +28,8 @@ from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
 from .base_loader import DataSetLoader
 from .data_loader.sst import SSTLoader
+from ..core.const import Const
+
 
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
@@ -257,9 +259,9 @@ class SNLILoader(JsonLoader):
 
     def __init__(self):
         fields = {
-            'sentence1_parse': 'words1',
-            'sentence2_parse': 'words2',
-            'gold_label': 'target',
+            'sentence1_parse': Const.INPUTS(0),
+            'sentence2_parse': Const.INPUTS(1),
+            'gold_label': Const.TARGET,
         }
         super(SNLILoader, self).__init__(fields=fields)
 
@@ -271,10 +273,10 @@ class SNLILoader(JsonLoader):
             return t.leaves()
 
         ds.apply(lambda ins: parse_tree(
-            ins['words1']), new_field_name='words1')
+            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
         ds.apply(lambda ins: parse_tree(
-            ins['words2']), new_field_name='words2')
-        ds.drop(lambda x: x['target'] == '-')
+            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+        ds.drop(lambda x: x[Const.TARGET] == '-')
         return ds
 
 
diff --git a/reproduction/matching/data/SNLIDataLoader.py b/reproduction/matching/data/SNLIDataLoader.py
new file mode 100644
index 00000000..6f6bbecd
--- /dev/null
+++ b/reproduction/matching/data/SNLIDataLoader.py
@@ -0,0 +1,6 @@
+
+from fastNLP.io.dataset_loader import SNLILoader
+
+# TODO: still in progress
+
+
diff --git a/reproduction/matching/model/bert.py b/reproduction/matching/model/bert.py
new file mode 100644
index 00000000..6b13ce2a
--- /dev/null
+++ b/reproduction/matching/model/bert.py
@@ -0,0 +1,15 @@
+
+from fastNLP.models import BaseModel
+from fastNLP.modules.encoder.bert import BertModel
+
+
+class BertForSNLI(BaseModel):
+    """Placeholder model intended for SNLI; ``forward`` is still a stub."""
+    # TODO: still in progress
+
+    def __init__(self):
+        super(BertForSNLI, self).__init__()
+
+    def forward(self, words, segment_id, seq_len):
+        # Placeholder: no computation yet, so this returns None.
+        pass
diff --git a/reproduction/matching/test/test_snlidataloader.py b/reproduction/matching/test/test_snlidataloader.py
new file mode 100644
index 00000000..9a0fb9ee
--- /dev/null
+++ b/reproduction/matching/test/test_snlidataloader.py
@@ -0,0 +1,12 @@
+import unittest
+from reproduction.matching.data.SNLIDataLoader import SNLILoader
+from fastNLP.core.vocabulary import VocabularyOption
+
+
+class TestCWSDataLoader(unittest.TestCase):
+    def test_case1(self):
+        # The data module only exposes SNLILoader; calling the module
+        # object itself would raise TypeError.
+        snli_loader = SNLILoader()
+        # TODO: still in progress
+