Browse Source

防止BERTEmbedding在中文场景下被错误使用

tags/v0.5.5
yh_cc 5 years ago
parent
commit
8fe7a4f191
4 changed files with 28 additions and 3 deletions
  1. +2
    -1
      fastNLP/core/trainer.py
  2. +4
    -1
      fastNLP/embeddings/bert_embedding.py
  3. +4
    -1
      fastNLP/modules/encoder/bert.py
  4. +18
    -0
      test/core/test_trainer.py

+ 2
- 1
fastNLP/core/trainer.py View File

@@ -599,7 +599,8 @@ class Trainer(object):
self._model_device = _get_model_device(self.model)
self._mode(self.model, is_test=False)
self._load_best_model = load_best_model
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
# 加上millsecond,防止两个太接近的保存
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
start_time = time.time()
self.logger.info("training epochs started " + self.start_time)
self.step = 0


+ 4
- 1
fastNLP/embeddings/bert_embedding.py View File

@@ -294,7 +294,10 @@ class _WordBertModel(nn.Module):
word = '[PAD]'
elif index == vocab.unknown_idx:
word = '[UNK]'
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
word_pieces = []
for w in _words:
word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
if len(word_pieces) == 1:
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面


+ 4
- 1
fastNLP/modules/encoder/bert.py View File

@@ -989,7 +989,10 @@ class _WordPieceBertModel(nn.Module):
def convert_words_to_word_pieces(words):
word_pieces = []
for word in words:
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
tokens = []
for word in _words:
tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(word))
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
word_pieces.extend(word_piece_ids)
if add_cls_sep:


+ 18
- 0
test/core/test_trainer.py View File

@@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase):
"""
# 应该正确运行
"""

def test_save_path(self):
data_set = prepare_fake_dataset()
data_set.set_input("x", flag=True)
data_set.set_target("y", flag=True)

train_set, dev_set = data_set.split(0.3)

model = NaiveClassifier(2, 1)

save_path = 'test_save_models'

trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
use_tqdm=True, check_code_level=2)
trainer.train()

def test_trainer_suggestion1(self):
# 检查报错提示能否正确提醒用户。


Loading…
Cancel
Save