Browse Source

firstly add matching in reproduction

tags/v0.4.10
xuyige 6 years ago
parent
commit
e05c182b05
4 changed files with 37 additions and 6 deletions
  1. +8
    -6
      fastNLP/io/dataset_loader.py
  2. +6
    -0
      reproduction/matching/data/SNLIDataLoader.py
  3. +13
    -0
      reproduction/matching/model/bert.py
  4. +10
    -0
      reproduction/matching/test/test_snlidataloader.py

+ 8
- 6
fastNLP/io/dataset_loader.py View File

@@ -28,6 +28,8 @@ from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader from .base_loader import DataSetLoader
from .data_loader.sst import SSTLoader from .data_loader.sst import SSTLoader
from ..core.const import Const



class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
""" """
@@ -257,9 +259,9 @@ class SNLILoader(JsonLoader):


def __init__(self): def __init__(self):
fields = { fields = {
'sentence1_parse': 'words1',
'sentence2_parse': 'words2',
'gold_label': 'target',
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
} }
super(SNLILoader, self).__init__(fields=fields) super(SNLILoader, self).__init__(fields=fields)


@@ -271,10 +273,10 @@ class SNLILoader(JsonLoader):
return t.leaves() return t.leaves()


ds.apply(lambda ins: parse_tree( ds.apply(lambda ins: parse_tree(
ins['words1']), new_field_name='words1')
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: parse_tree( ds.apply(lambda ins: parse_tree(
ins['words2']), new_field_name='words2')
ds.drop(lambda x: x['target'] == '-')
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds return ds






+ 6
- 0
reproduction/matching/data/SNLIDataLoader.py View File

@@ -0,0 +1,6 @@

from fastNLP.io.dataset_loader import SNLILoader

# TODO: still in progress



+ 13
- 0
reproduction/matching/model/bert.py View File

@@ -0,0 +1,13 @@

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.bert import BertModel


class BertForSNLI(BaseModel):
    """Placeholder BERT-based model for the SNLI matching task.

    TODO: still in progress -- ``forward`` has no implementation yet.
    """

    def __init__(self):
        # The constructor was previously spelled ``__init``; Python
        # name-mangles that to ``_BertForSNLI__init`` and never invokes it
        # on instantiation, so any setup placed here would be skipped.
        super(BertForSNLI, self).__init__()

    def forward(self, words, segment_id, seq_len):
        """Run the forward pass (not implemented yet; returns ``None``).

        NOTE(review): presumably ``words`` holds token indices,
        ``segment_id`` the BERT segment ids and ``seq_len`` the sequence
        lengths -- confirm once the implementation lands.
        """
        pass

+ 10
- 0
reproduction/matching/test/test_snlidataloader.py View File

@@ -0,0 +1,10 @@
import unittest
from reproduction.matching.data import SNLIDataLoader
from fastNLP.core.vocabulary import VocabularyOption


class TestCWSDataLoader(unittest.TestCase):
    """Smoke test for the SNLI data loader (still in progress).

    NOTE(review): the class name mentions CWS although the test exercises
    the SNLI loader; the name is kept so existing test selectors still work.
    """

    def test_case1(self):
        """Instantiate the SNLI loader; no assertions yet."""
        # ``SNLIDataLoader`` is imported as a module, so calling it directly
        # raises ``TypeError: 'module' object is not callable``. Instantiate
        # the ``SNLILoader`` class that the module imports instead.
        snli_loader = SNLIDataLoader.SNLILoader()
        # TODO: still in progress


Loading…
Cancel
Save