Browse Source

add test case for extcnndm

tags/v0.4.10
Danqing Wang 5 years ago
parent
commit
41995683ca
6 changed files with 400070 additions and 2 deletions
  1. +1
    -1
      fastNLP/io/pipe/summarization.py
  2. +2
    -1
      reproduction/Summarization/Baseline/train.py
  3. +10
    -0
      test/data_for_tests/cnndm.jsonl
  4. +200000
    -0
      test/data_for_tests/cnndm.vocab
  5. +200000
    -0
      test/io/pipe/cnndm.vocab
  6. +57
    -0
      test/io/pipe/test_extcnndm.py

+ 1
- 1
fastNLP/io/pipe/summarization.py View File

@@ -79,7 +79,7 @@ class ExtCNNDMPipe(Pipe):
if self.domain == True:
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
domaindict.from_dataset(db, field_name="publication")
domaindict.from_dataset(db.get_dataset("train"), field_name="publication")
db.set_vocab(domaindict, "domain")
return db


+ 2
- 1
reproduction/Summarization/Baseline/train.py View File

@@ -216,7 +216,8 @@ def main():
hps.atten_dropout_prob = 0.0
hps.ffn_dropout_prob = 0.0
logger.info(hps)
db = dbPipe.process_from_file(DATA_FILE)
paths = {"test": DATA_FILE}
db = dbPipe.process_from_file(paths)
else:
paths = {"train": DATA_FILE, "valid": VALID_FILE}
db = dbPipe.process_from_file(paths)


+ 10
- 0
test/data_for_tests/cnndm.jsonl
File diff suppressed because it is too large
View File


+ 200000
- 0
test/data_for_tests/cnndm.vocab
File diff suppressed because it is too large
View File


+ 200000
- 0
test/io/pipe/cnndm.vocab
File diff suppressed because it is too large
View File


+ 57
- 0
test/io/pipe/test_extcnndm.py View File

@@ -0,0 +1,57 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest
import os
# import sys
#
# sys.path.append("../../../")
from fastNLP.io import DataBundle
from fastNLP.io.pipe.summarization import ExtCNNDMPipe
class TestRunExtCNNDMPipe(unittest.TestCase):
def test_load(self):
data_set_dict = {
'CNNDM': {"train": 'test/data_for_tests/cnndm.jsonl'},
}
vocab_size = 100000
VOCAL_FILE = 'test/data_for_tests/cnndm.vocab'
sent_max_len = 100
doc_max_timesteps = 50
dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
vocab_path=VOCAL_FILE,
sent_max_len=sent_max_len,
doc_max_timesteps=doc_max_timesteps)
dbPipe2 = ExtCNNDMPipe(vocab_size=vocab_size,
vocab_path=VOCAL_FILE,
sent_max_len=sent_max_len,
doc_max_timesteps=doc_max_timesteps,
domain=True)
for k, v in data_set_dict.items():
db = dbPipe.process_from_file(v)
db2 = dbPipe2.process_from_file(v)
self.assertTrue(isinstance(db, DataBundle))
self.assertTrue(isinstance(db2, DataBundle))

Loading…
Cancel
Save