@@ -599,7 +599,8 @@ class Trainer(object):
         self._model_device = _get_model_device(self.model)
         self._mode(self.model, is_test=False)
         self._load_best_model = load_best_model
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
+        # append the sub-second '%f' field so two saves started close together do not collide
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
         start_time = time.time()
         self.logger.info("training epochs started " + self.start_time)
         self.step = 0
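
A quick illustration of what the new format string changes: `%f` appends a six-digit microsecond field, so two `Trainer` runs started within the same second no longer resolve to the same save directory name. A minimal sketch with plain `datetime` (the printed values are only examples):

    from datetime import datetime

    now = datetime.now()
    # old format: second-level resolution, identical for runs started in the same second
    print(now.strftime('%Y-%m-%d-%H-%M-%S'))     # e.g. 2019-07-08-14-03-27
    # new format: '%f' adds microseconds, keeping the directory names distinct
    print(now.strftime('%Y-%m-%d-%H-%M-%S-%f'))  # e.g. 2019-07-08-14-03-27-482913
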
@@ -294,7 +294,10 @@ class _WordBertModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+            _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
+            word_pieces = []
+            for w in _words:
+                word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
             if len(word_pieces) == 1:
                 if not vocab._is_word_no_create_entry(word):  # the word comes from train, but no wordpiece was found for it
                     if index != vocab.unknown_idx and word_pieces[0] == '[UNK]':  # i.e. the word is not in the original BERT vocab
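
Why the extra `_tokenize_chinese_chars` pass matters: feeding a whole multi-character Chinese word to `wordpiece_tokenizer.tokenize` can return '##'-prefixed pieces or a single '[UNK]', which does not line up with how BERT tokenizes full sentences (where Chinese characters are first padded with spaces). Splitting the word into characters first reproduces that sentence-level behaviour. A minimal sketch, using the transformers `BertTokenizer` as a stand-in for the tokenizer fastNLP bundles (the `bert-base-chinese` checkpoint and the sample word are assumptions):

    from transformers import BertTokenizer

    # stand-in for self.tokenzier; fastNLP ships its own copy of the same tokenizer classes
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    word = '自然'  # illustrative two-character word

    # old path: the whole word goes through the wordpiece tokenizer in one piece,
    # which can yield ['自', '##然'] or ['[UNK]'] depending on the vocab
    old_pieces = tokenizer.wordpiece_tokenizer.tokenize(word)

    # new path: pad Chinese characters with spaces first (what BERT does for full
    # sentences), then wordpiece each character separately
    # -> ['自', '然'] when both characters are in the vocab
    chars = tokenizer.basic_tokenizer._tokenize_chinese_chars(word).split()
    new_pieces = []
    for c in chars:
        new_pieces.extend(tokenizer.wordpiece_tokenizer.tokenize(c))

    print(old_pieces, new_pieces)
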
@@ -989,7 +989,10 @@ class _WordPieceBertModel(nn.Module):
         def convert_words_to_word_pieces(words):
             word_pieces = []
             for word in words:
-                tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+                _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
+                tokens = []
+                for w in _words:
+                    tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
                 word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
                 word_pieces.extend(word_piece_ids)
             if add_cls_sep:
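
For reference, the same logic as a standalone sketch of what `convert_words_to_word_pieces` now does end to end; the free-function form, the `tokenizer` argument, and the explicit [CLS]/[SEP] lookups are illustrative, not the diff's actual code:

    def convert_words_to_word_pieces(tokenizer, words, add_cls_sep=True):
        """Map a list of words to wordpiece ids, splitting Chinese characters first."""
        word_pieces = []
        for word in words:
            chars = tokenizer.basic_tokenizer._tokenize_chinese_chars(word).split()
            tokens = []
            for c in chars:
                tokens.extend(tokenizer.wordpiece_tokenizer.tokenize(c))
            word_pieces.extend(tokenizer.convert_tokens_to_ids(tokens))
        if add_cls_sep:
            # wrap the sequence in the special tokens BERT expects
            cls_id, sep_id = tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
            word_pieces = [cls_id] + word_pieces + [sep_id]
        return word_pieces
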
@@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase):
         """
         # should run correctly
         """
+
+    def test_save_path(self):
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier(2, 1)
+        save_path = 'test_save_models'
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
+                          use_tqdm=True, check_code_level=2)
+        trainer.train()

     def test_trainer_suggestion1(self):
         # check that the error message can correctly alert the user
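
One note on `test_save_path`: the run writes model checkpoints under `test_save_models`, and nothing in the shown hunk removes them afterwards. A hedged cleanup sketch (the `tearDown` hook is an assumption, not part of this diff):

    import os
    import shutil
    import unittest

    class TrainerTestGround(unittest.TestCase):
        # only the cleanup hook is sketched here
        def tearDown(self):
            # drop the checkpoint directory left behind by test_save_path, if any
            if os.path.isdir('test_save_models'):
                shutil.rmtree('test_save_models')
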