Browse Source

fix some bugs in test code.

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
091f24e393
4 changed files with 15 additions and 7 deletions
  1. +3
    -0
      test/__init__.py
  2. +11
    -6
      test/core/test_utils.py
  3. +0
    -0
      test/models/__init__.py
  4. +1
    -1
      test/models/test_bert.py

+ 3
- 0
test/__init__.py View File

@@ -0,0 +1,3 @@
import fastNLP

__all__ = ["fastNLP"]

+ 11
- 6
test/core/test_utils.py View File

@@ -119,7 +119,8 @@ class TestCache(unittest.TestCase):
def test_cache_save(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train')
end_time = time.time()
pre_time = end_time - start_time
with open('test/demo1.pkl', 'rb') as f:
@@ -128,7 +129,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train')
end_time = time.time()
read_time = end_time - start_time
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
@@ -139,7 +141,7 @@ class TestCache(unittest.TestCase):
def test_cache_save_overwrite_path(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'test/data_for_tests/cws_train',
_cache_fp='test/demo_overwrite.pkl')
end_time = time.time()
pre_time = end_time - start_time
@@ -149,7 +151,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
_cache_fp='test/demo_overwrite.pkl')
end_time = time.time()
read_time = end_time - start_time
@@ -161,7 +164,8 @@ class TestCache(unittest.TestCase):
def test_cache_refresh(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
_refresh=True)
end_time = time.time()
pre_time = end_time - start_time
@@ -171,7 +175,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
_refresh=True)
end_time = time.time()
read_time = end_time - start_time


+ 0
- 0
test/models/__init__.py View File


+ 1
- 1
test/models/test_bert.py View File

@@ -82,7 +82,7 @@ class TestBert(unittest.TestCase):
def test_bert_5(self):

vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
embed = BertEmbedding(vocab, model_dir_or_name='./../data_for_tests/embedding/small_bert',
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
include_cls_sep=True)
model = BertForSentenceMatching(embed)



Loading…
Cancel
Save