@@ -228,7 +228,7 @@ class TestDataSetMethods(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()
         csv_loader = CSVLoader(headers=['raw_sentence', 'label'], sep='\t')
-        data_bundle = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv')
+        data_bundle = csv_loader.load('tests/data_for_tests/tutorial_sample_dataset.csv')
         dataset = data_bundle.datasets['train']
         dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
         dataset.apply(split_sent, new_field_name='words', is_input=True)
@@ -120,8 +120,8 @@ class TestCache(unittest.TestCase):
     def test_cache_save(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train')
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train')
             end_time = time.time()
             pre_time = end_time - start_time
             with open('test/demo1.pkl', 'rb') as f:
@@ -130,8 +130,8 @@ class TestCache(unittest.TestCase):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train')
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train')
             end_time = time.time()
             read_time = end_time - start_time
             print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
@@ -142,7 +142,7 @@ class TestCache(unittest.TestCase):
     def test_cache_save_overwrite_path(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'tests/data_for_tests/cws_train',
                                              _cache_fp='test/demo_overwrite.pkl')
             end_time = time.time()
             pre_time = end_time - start_time
@@ -152,8 +152,8 @@ class TestCache(unittest.TestCase):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _cache_fp='test/demo_overwrite.pkl')
             end_time = time.time()
             read_time = end_time - start_time
@@ -165,8 +165,8 @@ class TestCache(unittest.TestCase):
     def test_cache_refresh(self):
         try:
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _refresh=True)
             end_time = time.time()
             pre_time = end_time - start_time
@@ -176,8 +176,8 @@ class TestCache(unittest.TestCase):
             for i in range(embed.shape[0]):
                 self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
             start_time = time.time()
-            embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
-                                             'test/data_for_tests/cws_train',
+            embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
+                                             'tests/data_for_tests/cws_train',
                                              _refresh=True)
             end_time = time.time()
             read_time = end_time - start_time
@@ -32,7 +32,7 @@ class TestDownload(unittest.TestCase):
 class TestBertEmbedding(unittest.TestCase):
     def test_bert_embedding_1(self):
         vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         requires_grad = embed.requires_grad
         embed.requires_grad = not requires_grad
         embed.train()
@@ -40,14 +40,14 @@ class TestBertEmbedding(unittest.TestCase):
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         embed.eval()
         words = torch.LongTensor([[2, 3, 4, 0]])
         result = embed(words)
         self.assertEqual(result.size(), (1, 4, 16))
         # automatically truncate instead of raising an error
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
                               auto_truncate=True)
         words = torch.LongTensor([[2, 3, 4, 1]*10,
@@ -60,7 +60,7 @@ class TestBertEmbedding(unittest.TestCase):
         try:
             os.makedirs(bert_save_test, exist_ok=True)
             vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-            embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1,
+            embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
                                   auto_truncate=True)
             embed.save(bert_save_test)
@@ -76,7 +76,7 @@ class TestBertEmbedding(unittest.TestCase):
 class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_word_piece_encoder(self):
-        embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
+        embed = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
         ds = DataSet({'words': ["this is a test . [SEP]".split()]})
         embed.index_datasets(ds, field_name='words')
         self.assertTrue(ds.has_field('word_pieces'))
@@ -84,7 +84,7 @@ class TestBertWordPieceEncoder(unittest.TestCase):
     def test_bert_embed_eq_bert_piece_encoder(self):
         ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
-        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
+        encoder = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert')
         encoder.eval()
         encoder.index_datasets(ds, field_name='words')
         word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
@@ -95,7 +95,7 @@ class TestBertWordPieceEncoder(unittest.TestCase):
         vocab.index_dataset(ds, field_name='words', new_field_name='words')
         ds.set_input('words')
         words = torch.LongTensor(ds['words'].get([0, 1]))
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1)
         embed.eval()
         words_res = embed(words)
@@ -109,7 +109,7 @@ class TestBertWordPieceEncoder(unittest.TestCase):
         bert_save_test = 'bert_save_test'
         try:
             os.makedirs(bert_save_test, exist_ok=True)
-            embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0,
+            embed = BertWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.0,
                                          layers='-2')
             ds = DataSet({'words': ["this is a test . [SEP]".split()]})
             embed.index_datasets(ds, field_name='words')
@@ -21,7 +21,7 @@ class TestDownload(unittest.TestCase):
 class TestRunElmo(unittest.TestCase):
     def test_elmo_embedding(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
-        elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1')
+        elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_elmo', layers='0,1')
         words = torch.LongTensor([[0, 1, 2]])
         hidden = elmo_embed(words)
         print(hidden.size())
@@ -30,7 +30,7 @@ class TestRunElmo(unittest.TestCase):
     def test_elmo_embedding_layer_assertion(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         try:
-            elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo',
+            elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_elmo',
                                        layers='0,1,2')
         except AssertionError as e:
             print(e)
@@ -21,7 +21,7 @@ class TestGPT2Embedding(unittest.TestCase):
         print(embed(words).size())

     def test_gpt2_embedding(self):
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         vocab = Vocabulary().add_word_lst("this is a texta sentence".split())
         embed = GPT2Embedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
         requires_grad = embed.requires_grad
@@ -49,7 +49,7 @@ class TestGPT2Embedding(unittest.TestCase):
     def test_gpt2_ebembedding_2(self):
         # test that only_use_pretrain_vocab and truncate_embed work as expected
         Embedding = GPT2Embedding
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         vocab = Vocabulary().add_word_lst("this is a texta and".split())
         embed1 = Embedding(vocab, model_dir_or_name=weight_path,layers=list(range(3)),
                            only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
@@ -89,13 +89,13 @@ class TestGPT2Embedding(unittest.TestCase):
     def test_gpt2_tokenizer(self):
         from fastNLP.modules.tokenizer import GPT2Tokenizer

-        tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        tokenizer = GPT2Tokenizer.from_pretrained('tests/data_for_tests/embedding/small_gpt2')
         print(tokenizer.encode("this is a texta a sentence"))
         print(tokenizer.encode('this is'))

     def test_gpt2_embed_eq_gpt2_piece_encoder(self):
         # mainly check that the embedding result matches the word piece encoder result
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
         encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path)
         encoder.eval()
@@ -187,7 +187,7 @@ class TestGPT2WordPieceEncoder(unittest.TestCase):
         print(used_pairs)
         import json
-        with open('test/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/vocab.json', 'w') as f:
             new_used_vocab = {}
             for idx, key in enumerate(used_vocab.keys()):
                 new_used_vocab[key] = len(new_used_vocab)
@@ -201,12 +201,12 @@ class TestGPT2WordPieceEncoder(unittest.TestCase):
             json.dump(new_used_vocab, f)

-        with open('test/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/merges.txt', 'w') as f:
             f.write('#version: small\n')
             for k,v in sorted(sorted(used_pairs.items(), key=lambda kv:kv[1])):
                 f.write('{} {}\n'.format(k[0], k[1]))

-        new_tokenizer = GPT2Tokenizer.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        new_tokenizer = GPT2Tokenizer.from_pretrained('tests/data_for_tests/embedding/small_gpt2')
         new_all_tokens = []
         for sent in [sent1, sent2, sent3]:
             tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
@@ -227,21 +227,21 @@ class TestGPT2WordPieceEncoder(unittest.TestCase):
             "n_positions": 20,
             "vocab_size": len(new_used_vocab)
         }
-        with open('test/data_for_tests/embedding/small_gpt2/config.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_gpt2/config.json', 'w') as f:
             json.dump(config, f)

         # generate a smaller merges.txt and vocab.json by recording the values inside the tokenizer
         from fastNLP.modules.encoder.gpt2 import GPT2LMHeadModel, GPT2Config

-        config = GPT2Config.from_pretrained('test/data_for_tests/embedding/small_gpt2')
+        config = GPT2Config.from_pretrained('tests/data_for_tests/embedding/small_gpt2')
         model = GPT2LMHeadModel(config)
-        torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin')
+        torch.save(model.state_dict(), 'tests/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin')
         print(model(torch.LongTensor([[0,1,2,3]])))

     def test_gpt2_word_piece_encoder(self):
         # mainly check that it runs
-        weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         ds = DataSet({'words': ["this is a test sentence".split()]})
         embed = GPT2WordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
         embed.index_datasets(ds, field_name='words')
@@ -256,7 +256,7 @@ class TestGPT2WordPieceEncoder(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_generate(self):
-        # weight_path = 'test/data_for_tests/embedding/small_gpt2'
+        # weight_path = 'tests/data_for_tests/embedding/small_gpt2'
         weight_path = 'en'

         encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True)
@@ -24,7 +24,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
     def test_robert_word_piece_encoder(self):
         # it only needs to run without errors
-        weight_path = 'test/data_for_tests/embedding/small_roberta'
+        weight_path = 'tests/data_for_tests/embedding/small_roberta'
         encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path, word_dropout=0.1)
         ds = DataSet({'words': ["this is a test . [SEP]".split()]})
         encoder.index_datasets(ds, field_name='words')
@@ -33,7 +33,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
     def test_roberta_embed_eq_roberta_piece_encoder(self):
         # mainly check that the embedding result matches the word piece encoder result
-        weight_path = 'test/data_for_tests/embedding/small_roberta'
+        weight_path = 'tests/data_for_tests/embedding/small_roberta'
         ds = DataSet({'words': ["this is a texta a sentence".split(), 'this is'.split()]})
         encoder = RobertaWordPieceEncoder(model_dir_or_name=weight_path)
         encoder.eval()
@@ -120,7 +120,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
             used_vocab.update({t:i for t,i in zip(tokens, token_ids)})

         import json
-        with open('test/data_for_tests/embedding/small_roberta/vocab.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_roberta/vocab.json', 'w') as f:
             new_used_vocab = {}
             for token in ['<s>', '<pad>', '</s>', '<unk>', '<mask>']: # <pad> must be 1
                 new_used_vocab[token] = len(new_used_vocab)
@@ -135,7 +135,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
                 new_used_vocab[key] = len(new_used_vocab)
             json.dump(new_used_vocab, f)

-        with open('test/data_for_tests/embedding/small_roberta/merges.txt', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_roberta/merges.txt', 'w') as f:
             f.write('#version: tiny\n')
             for k,v in sorted(sorted(used_pairs.items(), key=lambda kv:kv[1])):
                 f.write('{} {}\n'.format(k[0], k[1]))
@@ -162,10 +162,10 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
             "type_vocab_size": 1,
             "vocab_size": len(new_used_vocab)
         }
-        with open('test/data_for_tests/embedding/small_roberta/config.json', 'w') as f:
+        with open('tests/data_for_tests/embedding/small_roberta/config.json', 'w') as f:
             json.dump(config, f)

-        new_tokenizer = RobertaTokenizer.from_pretrained('test/data_for_tests/embedding/small_roberta')
+        new_tokenizer = RobertaTokenizer.from_pretrained('tests/data_for_tests/embedding/small_roberta')
         new_all_tokens = []
         for sent in [sent1, sent2, sent3]:
             tokens = new_tokenizer.tokenize(sent, add_prefix_space=True)
@@ -177,17 +177,17 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
         # generate a smaller merges.txt and vocab.json by recording the values inside the tokenizer
         from fastNLP.modules.encoder.roberta import RobertaModel, BertConfig

-        config = BertConfig.from_json_file('test/data_for_tests/embedding/small_roberta/config.json')
+        config = BertConfig.from_json_file('tests/data_for_tests/embedding/small_roberta/config.json')
         model = RobertaModel(config)
-        torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin')
+        torch.save(model.state_dict(), 'tests/data_for_tests/embedding/small_roberta/small_pytorch_model.bin')
         print(model(torch.LongTensor([[0,1,2,3]])))

     def test_save_load(self):
         bert_save_test = 'roberta_save_test'
         try:
             os.makedirs(bert_save_test, exist_ok=True)
-            embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.0,
+            embed = RobertaWordPieceEncoder(model_dir_or_name='tests/data_for_tests/embedding/small_roberta', word_dropout=0.0,
                                             layers='-2')
             ds = DataSet({'words': ["this is a test . [SEP]".split()]})
             embed.index_datasets(ds, field_name='words')
@@ -204,7 +204,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase):
 class TestRobertaEmbedding(unittest.TestCase):
     def test_roberta_embedding_1(self):
-        weight_path = 'test/data_for_tests/embedding/small_roberta'
+        weight_path = 'tests/data_for_tests/embedding/small_roberta'
         vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
         embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, word_dropout=0.1)
         requires_grad = embed.requires_grad
@@ -224,7 +224,7 @@ class TestRobertaEmbedding(unittest.TestCase):
     def test_roberta_ebembedding_2(self):
         # test that only_use_pretrain_vocab and truncate_embed work as expected
         Embedding = RobertaEmbedding
-        weight_path = 'test/data_for_tests/embedding/small_roberta'
+        weight_path = 'tests/data_for_tests/embedding/small_roberta'
         vocab = Vocabulary().add_word_lst("this is a texta and".split())
         embed1 = Embedding(vocab, model_dir_or_name=weight_path, layers=list(range(3)),
                            only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
@@ -266,7 +266,7 @@ class TestRobertaEmbedding(unittest.TestCase):
         try:
             os.makedirs(bert_save_test, exist_ok=True)
             vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-            embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta',
+            embed = RobertaEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_roberta',
                                      word_dropout=0.1,
                                      auto_truncate=True)
             embed.save(bert_save_test)
@@ -10,7 +10,7 @@ class TestLoad(unittest.TestCase):
     def test_norm1(self):
         # only norm the vectors that can be found
         vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 only_norm_found_vector=True)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
@@ -19,7 +19,7 @@ class TestLoad(unittest.TestCase):
     def test_norm2(self):
         # norm all vectors
         vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 normalize=True)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
@@ -50,13 +50,13 @@ class TestLoad(unittest.TestCase):
                 v2 = embed_dict[word]
                 for v1i, v2i in zip(v1, v2):
                     self.assertAlmostEqual(v1i, v2i, places=4)
-        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+        embed_dict = read_static_embed('tests/data_for_tests/embedding/small_static_embedding/'
                                        'glove.6B.50d_test.txt')

         # test whether only pretrained words are used
         vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
         vocab.add_word('of', no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 only_use_pretrain_word=True)
         # notinfile should be mapped to unk
@@ -66,13 +66,13 @@ class TestLoad(unittest.TestCase):
         # test behavior under different casing
         vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
         vocab.add_word('Of', no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 only_use_pretrain_word=True)
         check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # these words should not be found
         check_vector_equal(['a'], vocab, embed, embed_dict)

-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 only_use_pretrain_word=True, lower=True)
         check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
@@ -82,7 +82,7 @@ class TestLoad(unittest.TestCase):
         vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
         vocab.add_word('Of', no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt',
                                 only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
@@ -92,12 +92,12 @@ class TestLoad(unittest.TestCase):
     def test_sequential_index(self):
         # when there is no no_create_entry, words_to_words should be sequential
         vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt')
         for index,i in enumerate(embed.words_to_words):
             assert index==i

-        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+        embed_dict = read_static_embed('tests/data_for_tests/embedding/small_static_embedding/'
                                        'glove.6B.50d_test.txt')
         for word, index in vocab:
@@ -116,7 +116,7 @@ class TestLoad(unittest.TestCase):
         vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
         vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt')
         embed.save(static_test_folder)
         load_embed = StaticEmbedding.load(static_test_folder)
@@ -125,7 +125,7 @@ class TestLoad(unittest.TestCase):
         # test without no_create_entry
         vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt')
         embed.save(static_test_folder)
         load_embed = StaticEmbedding.load(static_test_folder)
@@ -134,7 +134,7 @@ class TestLoad(unittest.TestCase):
         # test lower, min_freq
         vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B'])
-        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+        embed = StaticEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_static_embedding/'
                                                          'glove.6B.50d_test.txt', min_freq=2, lower=True)
         embed.save(static_test_folder)
         load_embed = StaticEmbedding.load(static_test_folder)
@@ -23,14 +23,14 @@ class TestDownload(unittest.TestCase):
 class TestLoad(unittest.TestCase):
     def test_process_from_file(self):
         data_set_dict = {
-            'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityLoader, (6, 6, 6), False),
-            'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullLoader, (6, 6, 6), False),
-            'sst-2': ('test/data_for_tests/io/SST-2', SST2Loader, (5, 5, 5), True),
-            'sst': ('test/data_for_tests/io/SST', SSTLoader, (6, 6, 6), False),
-            'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
-            'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
-            'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
-            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
+            'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityLoader, (6, 6, 6), False),
+            'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullLoader, (6, 6, 6), False),
+            'sst-2': ('tests/data_for_tests/io/SST-2', SST2Loader, (5, 5, 5), True),
+            'sst': ('tests/data_for_tests/io/SST', SSTLoader, (6, 6, 6), False),
+            'imdb': ('tests/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
+            'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
+            'THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
+            'WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, data_set, warns = v
@@ -27,12 +27,12 @@ class TestWeiboNER(unittest.TestCase):
 class TestConll2003Loader(unittest.TestCase):
     def test_load(self):
-        Conll2003Loader()._load('test/data_for_tests/conll_2003_example.txt')
+        Conll2003Loader()._load('tests/data_for_tests/conll_2003_example.txt')

 class TestConllLoader(unittest.TestCase):
     def test_conll(self):
-        db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
+        db = Conll2003Loader().load('tests/data_for_tests/io/conll2003')
         print(db)

 class TestConllLoader(unittest.TestCase):
@@ -40,5 +40,5 @@ class TestConllLoader(unittest.TestCase):
         headers = [
             'raw_words', 'ner',
         ]
-        db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER')
+        db = ConllLoader(headers = headers,sep="\n").load('tests/data_for_tests/io/MSRA_NER')
         print(db)
@@ -5,7 +5,7 @@ import unittest
 class TestCR(unittest.TestCase):
     def test_load(self):
-        test_root = "test/data_for_tests/io/coreference/"
+        test_root = "tests/data_for_tests/io/coreference/"
         train_path = test_root+"coreference_train.json"
         dev_path = test_root+"coreference_dev.json"
         test_path = test_root+"coreference_test.json"
@@ -19,6 +19,6 @@ class TestRunCWSLoader(unittest.TestCase):
         for dataset_name in dataset_names:
             with self.subTest(dataset_name=dataset_name):
                 data_bundle = CWSLoader(dataset_name=dataset_name).load(
-                    f'test/data_for_tests/io/cws_{dataset_name}'
+                    f'tests/data_for_tests/io/cws_{dataset_name}'
                 )
                 print(data_bundle)
@@ -25,14 +25,14 @@ class TestMatchingDownload(unittest.TestCase):
 class TestMatchingLoad(unittest.TestCase):
     def test_load(self):
         data_set_dict = {
-            'RTE': ('test/data_for_tests/io/RTE', RTELoader, (5, 5, 5), True),
-            'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
-            'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
-            'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
-            'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
-            'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
+            'RTE': ('tests/data_for_tests/io/RTE', RTELoader, (5, 5, 5), True),
+            'SNLI': ('tests/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
+            'QNLI': ('tests/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
+            'MNLI': ('tests/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
+            'Quora': ('tests/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
+            'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
+            'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
+            'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, instance, warns = v
@@ -5,10 +5,10 @@ from fastNLP.io.loader.qa import CMRC2018Loader
 class TestCMRC2018Loader(unittest.TestCase):
     def test__load(self):
         loader = CMRC2018Loader()
-        dataset = loader._load('test/data_for_tests/io/cmrc/train.json')
+        dataset = loader._load('tests/data_for_tests/io/cmrc/train.json')
         print(dataset)

     def test_load(self):
         loader = CMRC2018Loader()
-        data_bundle = loader.load('test/data_for_tests/io/cmrc/')
+        data_bundle = loader.load('tests/data_for_tests/io/cmrc/')
         print(data_bundle)
@@ -20,7 +20,7 @@ class TestClassificationPipe(unittest.TestCase):
 class TestRunPipe(unittest.TestCase):
     def test_load(self):
         for pipe in [IMDBPipe]:
-            data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
+            data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb')
             print(data_bundle)
@@ -37,35 +37,35 @@ class TestCNClassificationPipe(unittest.TestCase):
 class TestRunClassificationPipe(unittest.TestCase):
     def test_process_from_file(self):
         data_set_dict = {
-            'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
+            'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
                        {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
                        False),
-            'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe,
+            'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe,
                        {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
                        False),
-            'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe,
+            'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe,
                       {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
                       True),
-            'sst': ('test/data_for_tests/io/SST', SSTPipe,
+            'sst': ('tests/data_for_tests/io/SST', SSTPipe,
                     {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
                     False),
-            'imdb': ('test/data_for_tests/io/imdb', IMDBPipe,
+            'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe,
                      {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
                      False),
-            'ag': ('test/data_for_tests/io/ag', AGsNewsPipe,
+            'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe,
                    {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
                    False),
-            'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe,
+            'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe,
                         {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
                         False),
-            'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
+            'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
                              {'train': 6, 'dev': 6, 'test': 6},
                              {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
                              False),
-            'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe,
+            'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe,
                              {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
                              False),
-            'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
+            'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
                                    {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
                                    False),
         }
@@ -21,7 +21,7 @@ class TestRunPipe(unittest.TestCase):
         for pipe in [Conll2003Pipe, Conll2003NERPipe]:
             with self.subTest(pipe=pipe):
                 print(pipe)
-                data_bundle = pipe().process_from_file('test/data_for_tests/conll_2003_example.txt')
+                data_bundle = pipe().process_from_file('tests/data_for_tests/conll_2003_example.txt')
                 print(data_bundle)
@@ -35,18 +35,18 @@ class TestNERPipe(unittest.TestCase):
         for k, v in data_dict.items():
             pipe = v
             with self.subTest(pipe=pipe):
-                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(f'test/data_for_tests/io/{k}')
+                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(f'tests/data_for_tests/io/{k}')
                 print(data_bundle)
-                data_bundle = pipe(encoding_type='bioes').process_from_file(f'test/data_for_tests/io/{k}')
+                data_bundle = pipe(encoding_type='bioes').process_from_file(f'tests/data_for_tests/io/{k}')
                 print(data_bundle)

 class TestConll2003Pipe(unittest.TestCase):
     def test_conll(self):
         with self.assertWarns(Warning):
-            data_bundle = Conll2003Pipe().process_from_file('test/data_for_tests/io/conll2003')
+            data_bundle = Conll2003Pipe().process_from_file('tests/data_for_tests/io/conll2003')
             print(data_bundle)

     def test_OntoNotes(self):
-        data_bundle = OntoNotesNERPipe().process_from_file('test/data_for_tests/io/OntoNotes')
+        data_bundle = OntoNotesNERPipe().process_from_file('tests/data_for_tests/io/OntoNotes')
         print(data_bundle)
@@ -11,7 +11,7 @@ class TestCR(unittest.TestCase):
         char_path = None
         config = Config()
-        file_root_path = "test/data_for_tests/io/coreference/"
+        file_root_path = "tests/data_for_tests/io/coreference/"
         train_path = file_root_path + "coreference_train.json"
         dev_path = file_root_path + "coreference_dev.json"
         test_path = file_root_path + "coreference_test.json"
@@ -31,11 +31,11 @@ class TestRunCWSPipe(unittest.TestCase):
         for dataset_name in dataset_names:
             with self.subTest(dataset_name=dataset_name):
                 data_bundle = CWSPipe(bigrams=True, trigrams=True).\
-                    process_from_file(f'test/data_for_tests/io/cws_{dataset_name}')
+                    process_from_file(f'tests/data_for_tests/io/cws_{dataset_name}')
                 print(data_bundle)

     def test_replace_number(self):
         data_bundle = CWSPipe(bigrams=True, replace_num_alpha=True).\
-            process_from_file(f'test/data_for_tests/io/cws_pku')
+            process_from_file(f'tests/data_for_tests/io/cws_pku')
         for word in ['<', '>', '<NUM>']:
             self.assertNotEqual(data_bundle.get_vocab('chars').to_index(word), 1)
@@ -33,13 +33,13 @@ class TestRunMatchingPipe(unittest.TestCase):
     def test_load(self):
         data_set_dict = {
-            'RTE': ('test/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
-            'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
-            'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
-            'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
-            'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
+            'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
+            'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
+            'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
+            'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
+            'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
+            'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
+            'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
         }
         for k, v in data_set_dict.items():
             path, pipe1, pipe2, data_set, vocab, warns = v
@@ -76,7 +76,7 @@ class TestRunMatchingPipe(unittest.TestCase):
     def test_spacy(self):
         data_set_dict = {
-            'Quora': ('test/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
+            'Quora': ('tests/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
         }
         for k, v in data_set_dict.items():
             path, pipe1, pipe2, data_set, vocab = v
@@ -6,7 +6,7 @@ from fastNLP.io.loader.qa import CMRC2018Loader
 class CMRC2018PipeTest(unittest.TestCase):
     def test_process(self):
-        data_bundle = CMRC2018Loader().load('test/data_for_tests/io/cmrc/')
+        data_bundle = CMRC2018Loader().load('tests/data_for_tests/io/cmrc/')
         pipe = CMRC2018BertPipe()
         data_bundle = pipe.process(data_bundle)
@@ -27,9 +27,9 @@ from fastNLP.io.pipe.summarization import ExtCNNDMPipe
 class TestRunExtCNNDMPipe(unittest.TestCase):
     def test_load(self):
-        data_dir = 'test/data_for_tests/io/cnndm'
+        data_dir = 'tests/data_for_tests/io/cnndm'
         vocab_size = 100000
-        VOCAL_FILE = 'test/data_for_tests/io/cnndm/vocab'
+        VOCAL_FILE = 'tests/data_for_tests/io/cnndm/vocab'
         sent_max_len = 100
         doc_max_timesteps = 50
         dbPipe = ExtCNNDMPipe(vocab_size=vocab_size,
@@ -8,8 +8,8 @@ from fastNLP.io import EmbedLoader | |||||
class TestEmbedLoader(unittest.TestCase): | class TestEmbedLoader(unittest.TestCase): | ||||
def test_load_with_vocab(self): | def test_load_with_vocab(self): | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
glove = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt" | |||||
word2vec = "test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt" | |||||
glove = "tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt" | |||||
word2vec = "tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt" | |||||
vocab.add_word('the') | vocab.add_word('the') | ||||
vocab.add_word('none') | vocab.add_word('none') | ||||
g_m = EmbedLoader.load_with_vocab(glove, vocab) | g_m = EmbedLoader.load_with_vocab(glove, vocab) | ||||
@@ -20,8 +20,8 @@ class TestEmbedLoader(unittest.TestCase): | |||||
def test_load_without_vocab(self): | def test_load_without_vocab(self): | ||||
words = ['the', 'of', 'in', 'a', 'to', 'and'] | words = ['the', 'of', 'in', 'a', 'to', 'and'] | ||||
glove = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt" | |||||
word2vec = "test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt" | |||||
glove = "tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt" | |||||
word2vec = "tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt" | |||||
g_m, vocab = EmbedLoader.load_without_vocab(glove) | g_m, vocab = EmbedLoader.load_without_vocab(glove) | ||||
self.assertEqual(g_m.shape, (8, 50)) | self.assertEqual(g_m.shape, (8, 50)) | ||||
for word in words: | for word in words: | ||||
@@ -11,7 +11,7 @@ from fastNLP.embeddings.bert_embedding import BertEmbedding
 class TestBert(unittest.TestCase):
     def test_bert_1(self):
         vocab = Vocabulary().add_word_lst("this is a test .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=True)
         model = BertForSequenceClassification(embed, 2)
@@ -30,7 +30,7 @@ class TestBert(unittest.TestCase):
     def test_bert_1_w(self):
         vocab = Vocabulary().add_word_lst("this is a test .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False)
         with self.assertWarns(Warning):
@@ -46,7 +46,7 @@ class TestBert(unittest.TestCase):
     def test_bert_2(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=True)
         model = BertForMultipleChoice(embed, 2)
@@ -62,7 +62,7 @@ class TestBert(unittest.TestCase):
     def test_bert_2_w(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False)
         with self.assertWarns(Warning):
@@ -79,7 +79,7 @@ class TestBert(unittest.TestCase):
     def test_bert_3(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False)
         model = BertForTokenClassification(embed, 7)
@@ -93,7 +93,7 @@ class TestBert(unittest.TestCase):
     def test_bert_3_w(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=True)
         with self.assertWarns(Warning):
@@ -108,7 +108,7 @@ class TestBert(unittest.TestCase):
    def test_bert_4(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False)
         model = BertForQuestionAnswering(embed)
@@ -126,12 +126,12 @@ class TestBert(unittest.TestCase):
         from fastNLP.io import CMRC2018BertPipe
         from fastNLP import Trainer
-        data_bundle = CMRC2018BertPipe().process_from_file('test/data_for_tests/io/cmrc')
+        data_bundle = CMRC2018BertPipe().process_from_file('tests/data_for_tests/io/cmrc')
         data_bundle.rename_field('chars', 'words')
         train_data = data_bundle.get_dataset('train')
         vocab = data_bundle.get_vocab('words')
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False, auto_truncate=True)
         model = BertForQuestionAnswering(embed)
         loss = CMRC2018Loss()
@@ -142,7 +142,7 @@ class TestBert(unittest.TestCase):
     def test_bert_5(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=True)
         model = BertForSentenceMatching(embed)
@@ -156,7 +156,7 @@ class TestBert(unittest.TestCase):
     def test_bert_5_w(self):
         vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
+        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert',
                               include_cls_sep=False)
         with self.assertWarns(Warning):
@@ -223,7 +223,7 @@ class TestCRF(unittest.TestCase):
         import torch
         from fastNLP import seq_len_to_mask
-        with open('test/data_for_tests/modules/decoder/crf.json', 'r') as f:
+        with open('tests/data_for_tests/modules/decoder/crf.json', 'r') as f:
             data = json.load(f)
         bio_logits = torch.FloatTensor(data['bio_logits'])
@@ -5,7 +5,7 @@ from fastNLP.modules.tokenizer import BertTokenizer
 class TestBertTokenizer(unittest.TestCase):
     def test_run(self):
         # Test the two supported encode modes
-        tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
+        tokenizer = BertTokenizer.from_pretrained('tests/data_for_tests/embedding/small_bert')
         tokens1 = tokenizer.encode("This is a demo")
         tokens2 = tokenizer.encode("This is a demo", add_special_tokens=False)
@@ -85,7 +85,7 @@ class TestTutorial(unittest.TestCase):
 class TestOldTutorial(unittest.TestCase):
     def test_fastnlp_10min_tutorial(self):
         # Read data from csv into a DataSet
-        sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
+        sample_path = "tests/data_for_tests/tutorial_sample_dataset.csv"
         dataset = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(sample_path)
         print(len(dataset))
         print(dataset[0])
@@ -183,7 +183,7 @@ class TestOldTutorial(unittest.TestCase):
     def test_fastnlp_1min_tutorial(self):
         # tutorials/fastnlp_1min_tutorial.ipynb
-        data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
+        data_path = "tests/data_for_tests/tutorial_sample_dataset.csv"
         ds = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(data_path)
         print(ds[1])