@@ -569,7 +569,7 @@ class FitlogCallback(Callback):
                                 batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
                                 metrics=self.trainer.metrics,
                                 verbose=0,
-                                use_tqdm=self.trainer.use_tqdm)
+                                use_tqdm=self.trainer.test_use_tqdm)
                 self.testers[key] = tester

         fitlog.add_progress(total_steps=self.n_steps)
@@ -654,7 +654,7 @@ class EvaluateCallback(Callback):
                 tester = Tester(data=data, model=self.model,
                                 batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
                                 metrics=self.trainer.metrics, verbose=0,
-                                use_tqdm=self.trainer.use_tqdm)
+                                use_tqdm=self.trainer.test_use_tqdm)
                 self.testers[key] = tester

     def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
@@ -545,6 +545,10 @@ class Trainer(object):
         self.logger = logger

         self.use_tqdm = use_tqdm
+        if 'test_use_tqdm' in kwargs:
+            self.test_use_tqdm = kwargs.get('test_use_tqdm')
+        else:
+            self.test_use_tqdm = self.use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
         self.kwargs = kwargs
@@ -555,7 +559,7 @@ class Trainer(object):
                                  batch_size=kwargs.get("dev_batch_size", self.batch_size),
                                  device=None,  # device is handled by the code above
                                  verbose=0,
-                                 use_tqdm=self.use_tqdm)
+                                 use_tqdm=self.test_use_tqdm)

         self.step = 0
         self.start_time = None  # start timestamp
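Note: the new test_use_tqdm option is read from **kwargs, so it is passed to Trainer like any other keyword argument and falls back to use_tqdm when omitted. A minimal usage sketch, assuming model, train_data, dev_data and metric objects have already been built in the usual way (they are placeholders here, not part of this diff):

from fastNLP import Trainer

# Hypothetical objects: model, train_data, dev_data and metric are assumed to exist already.
trainer = Trainer(train_data=train_data, model=model,
                  dev_data=dev_data, metrics=metric,
                  use_tqdm=True,         # keep the progress bar for training
                  test_use_tqdm=False)   # silence it for the internal validation Tester
trainer.train()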
@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
                 if self._word_sep_index:  # must not drop the sep token
                     sep_mask = words.eq(self._word_sep_index)
                 mask = torch.ones_like(words).float() * self.word_dropout
-                mask = torch.bernoulli(mask).byte()  # the larger word_dropout is, the more positions become 1
+                mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
                 words = words.masked_fill(mask, self._word_unk_index)
                 if self._word_sep_index:
                     words.masked_fill_(sep_mask, self._word_sep_index)
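The .byte() to .eq(1) changes in this and the following hunks all serve the same purpose. The diff does not state the motivation, but a likely one is that recent PyTorch versions (roughly 1.2 and later) expect a torch.bool mask for masked_fill and warn on uint8 masks; bernoulli(...).eq(1) yields a bool mask directly. A small illustrative sketch, not part of the patch:

import torch

words = torch.tensor([[5, 7, 9]])
probs = torch.full(words.shape, 0.5, dtype=torch.float)  # per-position drop probability
mask = torch.bernoulli(probs)
byte_mask = mask.byte()   # uint8 mask: deprecated for masking ops in newer PyTorch
bool_mask = mask.eq(1)    # bool mask: the dtype masked_fill expects
print(words.masked_fill(bool_mask, 0))  # sampled positions replaced by an unk-style index (0 here)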
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
                 if self._word_sep_index:  # must not drop the sep token
                     sep_mask = words.eq(self._wordpiece_unk_index)
                 mask = torch.ones_like(words).float() * self.word_dropout
-                mask = torch.bernoulli(mask).byte()  # the larger word_dropout is, the more positions become 1
+                mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
                 words = words.masked_fill(mask, self._word_unk_index)
                 if self._word_sep_index:
                     words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -63,7 +63,7 @@ class Embedding(nn.Module): | |||||
""" | """ | ||||
if self.word_dropout>0 and self.training: | if self.word_dropout>0 and self.training: | ||||
mask = torch.ones_like(words).float() * self.word_dropout | mask = torch.ones_like(words).float() * self.word_dropout | ||||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||||
words = words.masked_fill(mask, self.unk_index) | words = words.masked_fill(mask, self.unk_index) | ||||
words = self.embed(words) | words = self.embed(words) | ||||
return self.dropout(words) | return self.dropout(words) | ||||
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
         """
         if self.word_dropout > 0 and self.training:
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger word_dropout is, the more positions become 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions become 1
             words = words.masked_fill(mask, self._word_unk_index)
         return words
@@ -106,6 +106,7 @@ class StaticEmbedding(TokenEmbedding):
                 print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
             vocab = truncated_vocab

+        self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
         # read the pretrained embedding file
         if lower:
             lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
@@ -142,7 +143,7 @@ class StaticEmbedding(TokenEmbedding):
         else:
             embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
             self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
-        if normalize:
+        if not self.only_norm_found_vector and normalize:
             embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

         if truncate_vocab:
@@ -233,6 +234,7 @@ class StaticEmbedding(TokenEmbedding):
             if vocab.unknown:
                 matrix[vocab.unknown_idx] = torch.zeros(dim)
             found_count = 0
+            found_unknown = False
             for idx, line in enumerate(f, start_idx):
                 try:
                     parts = line.strip().split()
@@ -243,9 +245,12 @@ class StaticEmbedding(TokenEmbedding):
                         word = vocab.padding
                     elif word == unknown and vocab.unknown is not None:
                         word = vocab.unknown
+                        found_unknown = True
                     if word in vocab:
                         index = vocab.to_index(word)
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
+                        if self.only_norm_found_vector:
+                            matrix[index] = matrix[index]/np.linalg.norm(matrix[index])
                         found_count += 1
                 except Exception as e:
                     if error == 'ignore':
@@ -256,7 +261,7 @@ class StaticEmbedding(TokenEmbedding):
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
             for word, index in vocab:
                 if index not in matrix and not vocab._is_word_no_create_entry(word):
-                    if vocab.unknown_idx in matrix:  # if an unknown vector exists, initialize with it
+                    if found_unknown:  # if an unknown vector was found, initialize with it
                         matrix[index] = matrix[vocab.unknown_idx]
                     else:
                         matrix[index] = None
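Taken together, the only_norm_found_vector changes mean: when the flag is set, only rows actually read from the pretrained file are L2-normalized at load time, while randomly initialized rows (words missing from the file) are left untouched; the existing normalize flag still rescales every row and is skipped when the new flag is on. A usage sketch mirroring the new test added further below (the GloVe test file path comes from this diff; the rest is standard fastNLP API):

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
# 'the' and 'a' come from the file and end up with unit norm; 'notinfile' keeps its random vector.
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
                        only_norm_found_vector=True)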
@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
         """
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = (mask == 0).byte()
+        flip_mask = mask.eq(0)
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if mask is not None:
@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
             alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-                    alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+                    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

         # trans_score [L-1, B]
-        mask = mask.byte()
+        mask = mask.eq(1)
         flip_mask = mask.eq(0)
         trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
         # emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
         """
         batch_size, seq_len, n_tags = logits.size()
         logits = logits.transpose(0, 1).data  # L, B, H
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B

         # dp
         vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
         "compatible."
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B
     else:
         mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
@@ -9,6 +9,6 @@ class TestDownload(unittest.TestCase):
     def test_download(self):
         # import os
         vocab = Vocabulary().add_word_lst("This is a test .".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/embedding/bert-base-cased')
+        embed = BertEmbedding(vocab, model_dir_or_name='en')
         words = torch.LongTensor([[0, 1, 2]])
         print(embed(words).size())
@@ -5,6 +5,23 @@ from fastNLP import Vocabulary
 import torch
 import os

+class TestLoad(unittest.TestCase):
+    def test_norm1(self):
+        # test that only vectors found in the pretrained file are normalized
+        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+                                only_norm_found_vector=True)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+        self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
+
+    def test_norm2(self):
+        # test that all vectors are normalized
+        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+                                normalize=True)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
@@ -21,7 +38,7 @@ class TestRandomSameEntry(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_same_vector2(self):
         vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
         words = embed(words)
@@ -39,7 +56,7 @@ class TestRandomSameEntry(unittest.TestCase):
         no_create_word_lst = ['of', 'Of', 'With', 'with']
         vocab = Vocabulary().add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
         words = embed(words)
@@ -48,7 +65,7 @@ class TestRandomSameEntry(unittest.TestCase):
         lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
         lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
         lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
-        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
         lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
         lowered_words = lowered_embed(lowered_words)
@@ -67,7 +84,7 @@ class TestRandomSameEntry(unittest.TestCase):
         all_words = word_lst[:-2] + no_create_word_lst[:-2]
         vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
         words = embed(words)
@@ -76,7 +93,7 @@ class TestRandomSameEntry(unittest.TestCase):
         lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
         lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
         lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
-        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
         lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
         lowered_words = lowered_embed(lowered_words)
@@ -94,14 +111,14 @@ class TestRandomSameEntry(unittest.TestCase):
         all_words = word_lst[:-2] + no_create_word_lst[:-2]
         vocab = Vocabulary().add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=False, min_freq=2)
         words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
         words = embed(words)

         min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
         min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
                                          lower=False)
         min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
         min_freq_words = min_freq_embed(min_freq_words)
@@ -5,14 +5,13 @@ from fastNLP import Instance
 from fastNLP import Vocabulary
 from fastNLP.core.losses import CrossEntropyLoss
 from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.io.loader import CSVLoader

 class TestTutorial(unittest.TestCase):
     def test_fastnlp_10min_tutorial(self):
         # read data from csv into a DataSet
         sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
-        dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
-                                   sep='\t')
+        dataset = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(sample_path)
         print(len(dataset))
         print(dataset[0])
         print(dataset[-3])
@@ -110,7 +109,7 @@ class TestTutorial(unittest.TestCase):
     def test_fastnlp_1min_tutorial(self):
         # tutorials/fastnlp_1min_tutorial.ipynb
         data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
-        ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t')
+        ds = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(data_path)
         print(ds[1])

         # convert all text to lowercase
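For reference, the replacement pattern used in both tutorial tests: DataSet.read_csv is swapped for a CSVLoader, whose _load method returns a DataSet directly. A sketch under the assumption that the sample file is tab-separated (the separator character shown in the hunks above is whatever literal the original used and is kept as-is there):

from fastNLP.io.loader import CSVLoader

loader = CSVLoader(headers=['raw_sentence', 'label'], sep='\t')  # '\t' is an assumption here
dataset = loader._load("test/data_for_tests/tutorial_sample_dataset.csv")
print(dataset[0])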