From 8fe7a4f191f8a16806f5fab618808cb1eac4314d Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Thu, 19 Mar 2020 14:22:15 +0800
Subject: [PATCH] Prevent BERTEmbedding from being misused in Chinese scenarios
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py              |  3 ++-
 fastNLP/embeddings/bert_embedding.py |  5 ++++-
 fastNLP/modules/encoder/bert.py      |  5 ++++-
 test/core/test_trainer.py            | 18 ++++++++++++++++++
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index af68158c..27e1e2b1 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -599,7 +599,8 @@ class Trainer(object):
         self._model_device = _get_model_device(self.model)
         self._mode(self.model, is_test=False)
         self._load_best_model = load_best_model
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
+        # Append microseconds so that two saves made close together do not collide
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
         start_time = time.time()
         self.logger.info("training epochs started " + self.start_time)
         self.step = 0
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index c81a4463..1ea4abc2 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -294,7 +294,10 @@ class _WordBertModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+            _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
+            word_pieces = []
+            for w in _words:
+                word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
             if len(word_pieces) == 1:
                 if not vocab._is_word_no_create_entry(word):  # the word occurs in train, yet was not found
                     if index != vocab.unknown_idx and word_pieces[0] == '[UNK]':  # meaning this word is not among the original words
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 821b9c5c..c3f2fa8b 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -989,7 +989,10 @@ class _WordPieceBertModel(nn.Module):
         def convert_words_to_word_pieces(words):
             word_pieces = []
             for word in words:
-                tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+                _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
+                tokens = []
+                for w in _words:
+                    tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
                 word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
                 word_pieces.extend(word_piece_ids)
             if add_cls_sep:
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index dc1a531a..aea9d363 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase):
         """
         # should run correctly
         """
+
+    def test_save_path(self):
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+
+        train_set, dev_set = data_set.split(0.3)
+
+        model = NaiveClassifier(2, 1)
+
+        save_path = 'test_save_models'
+
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
+                          use_tqdm=True, check_code_level=2)
+        trainer.train()
+
     def test_trainer_suggestion1(self):
         # Check that the error message correctly alerts the user.
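
On the trainer.py change: strftime's %f is the six-digit microsecond
field (the original comment said "millisecond"), which is what keeps two
training runs started within the same second from writing to the same
save directory. A minimal sketch of the effect, runnable on its own:

    from datetime import datetime

    # Two timestamps taken back to back share the seconds field but
    # (almost always) differ in the trailing %f microsecond field.
    a = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
    b = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
    print(a)  # e.g. 2020-03-19-14-22-15-123456
    print(b)  # e.g. 2020-03-19-14-22-15-123502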
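
On the bert_embedding.py and bert.py changes: Chinese BERT vocabularies
were built with CJK characters split apart, so feeding a multi-character
Chinese word straight into the wordpiece tokenizer can greedily emit
'##'-prefixed continuation pieces that the pretrained model effectively
never saw; that is the misuse the patch title refers to. Pre-splitting
with basic_tokenizer._tokenize_chinese_chars first tokenizes each
character on its own. A minimal sketch of the difference, assuming a
HuggingFace-style BertTokenizer (fastNLP's bundled tokenizer exposes the
same basic_tokenizer/wordpiece_tokenizer attributes; the exact pieces
depend on the vocabulary, and 'bert-base-chinese' is used here only for
illustration):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    word = '中国'

    # Old behavior: the whole word goes through greedy wordpiece
    # matching and may come back as continuation pieces.
    print(tokenizer.wordpiece_tokenizer.tokenize(word))
    # e.g. ['中', '##国']

    # Patched behavior: CJK characters are space-separated first, then
    # each character is wordpiece-tokenized on its own.
    chars = tokenizer.basic_tokenizer._tokenize_chinese_chars(word).split()
    pieces = []
    for w in chars:
        pieces.extend(tokenizer.wordpiece_tokenizer.tokenize(w))
    print(pieces)
    # ['中', '国']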
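
One caveat on the new test: with save_path set, trainer.train() leaves a
test_save_models directory behind after the run. A hedged sketch of a
cleanup hook that could accompany it (the tearDown method below is an
addition for illustration, not part of the patch):

    import os
    import shutil

    def tearDown(self):
        # Remove the directory the Trainer created, if the run got far
        # enough to create it.
        if os.path.exists('test_save_models'):
            shutil.rmtree('test_save_models')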