From 488ce6bbeceae75069b01321dbfc08ef929187a0 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Tue, 9 Jul 2019 13:56:01 +0800
Subject: [PATCH] Update test_dataLoader.py

---
 .../Baseline/test/test_dataLoader.py          | 24 ++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/reproduction/Summarization/Baseline/test/test_dataLoader.py b/reproduction/Summarization/Baseline/test/test_dataLoader.py
index 987f8778..53aab547 100644
--- a/reproduction/Summarization/Baseline/test/test_dataLoader.py
+++ b/reproduction/Summarization/Baseline/test/test_dataLoader.py
@@ -1,24 +1,36 @@
-
 import unittest
-from ..data.dataloader import SummarizationLoader
+import sys
+sys.path.append('..')
+
+from data.dataloader import SummarizationLoader
+
+vocab_size = 100000
+vocab_path = "testdata/vocab"
+sent_max_len = 100
+doc_max_timesteps = 50
 
 
 class TestSummarizationLoader(unittest.TestCase):
 
+
     def test_case1(self):
         sum_loader = SummarizationLoader()
         paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"}
-        data = sum_loader.process(paths=paths)
+        data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps)
         print(data.datasets)
 
     def test_case2(self):
         sum_loader = SummarizationLoader()
         paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"}
-        data = sum_loader.process(paths=paths, domain=True)
+        data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, domain=True)
         print(data.datasets, data.vocabs)
 
     def test_case3(self):
         sum_loader = SummarizationLoader()
         paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"}
-        data = sum_loader.process(paths=paths, tag=True)
-        print(data.datasets, data.vocabs)
\ No newline at end of file
+        data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, tag=True)
+        print(data.datasets, data.vocabs)
+
+
+