
Fix the bug where save_path was ignored when dev is empty

tags/v0.6.0
yh_cc committed 4 years ago · commit 40bec21684
8 changed files with 30 additions and 14 deletions

  1. +3  -3   fastNLP/core/losses.py
  2. +5  -0   fastNLP/core/trainer.py
  3. +2  -2   fastNLP/embeddings/bert_embedding.py
  4. +2  -2   fastNLP/embeddings/gpt2_embedding.py
  5. +3  -3   fastNLP/embeddings/roberta_embedding.py
  6. +2  -2   fastNLP/embeddings/static_embedding.py
  7. +1  -1   fastNLP/modules/generator/seq2seq_generator.py
  8. +12 -1   test/core/test_trainer.py

+3 -3  fastNLP/core/losses.py

@@ -178,7 +178,7 @@ class LossFunc(LossBase):
r""" r"""
提供给用户使用自定义损失函数的类 提供给用户使用自定义损失函数的类


:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect
:param func: 用户自行定义的损失函数,应当为一个函数
:param dict key_map: 参数映射表。键为Model/DataSet参数名,值为损失函数参数名。 :param dict key_map: 参数映射表。键为Model/DataSet参数名,值为损失函数参数名。
fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中 fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中
找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数 找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数
@@ -186,8 +186,8 @@ class LossFunc(LossBase):


     Usage::

-        func = torch.nn.CrossEntropyLoss()
-        loss_func = LossFunc(func, input="pred", target="label")
+        import torch.nn.functional as F
+        loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
         # This builds a loss-function object whose loss is computed by func; from the model's return value or from the
         # target=True fields of the DataSet, a parameter named `pred` is found and passed to func as its `input` argument,
         # and a parameter named `label` is found and passed to func as its `target` argument
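Note: the docstring example swaps the callable object torch.nn.CrossEntropyLoss() for the plain function torch.nn.functional.cross_entropy, matching the tightened ":param func" description in the hunk above. A minimal runnable sketch of the same pattern, assuming fastNLP's top-level LossFunc export; the field names `pred` and `label` are illustrative:

    import torch.nn.functional as F
    from fastNLP import LossFunc

    # Map the model output field `pred` to cross_entropy's `input` argument
    # and the DataSet target field `label` to its `target` argument.
    loss_func = LossFunc(F.cross_entropy, input="pred", target="label")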


+5 -0  fastNLP/core/trainer.py

@@ -630,6 +630,11 @@ class Trainer(object):
self.logger.info("Reloaded the best model.") self.logger.info("Reloaded the best model.")
else: else:
self.logger.info("Fail to reload best model.") self.logger.info("Fail to reload best model.")

if self.dev_data is None and self.save_path is not None:
model_name = "_".join([self.model.__class__.__name__, self.start_time])
self._save_model(self.model, model_name)

finally: finally:
if self.dev_data is not None and self.best_dev_perf is not None: if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info( self.logger.info(
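This hunk is the actual fix named in the commit message: previously a model was only saved when validation produced a new best result, so a run with dev_data=None silently ignored save_path. Now the final model is written at the end of train(). A minimal usage sketch mirroring the new test at the bottom of this commit; `model`, `train_set`, and the loss field names are assumed from test/core/test_trainer.py, and the save directory is illustrative:

    from fastNLP import Trainer, SGD, BCELoss

    # With dev_data=None there is no "best" checkpoint, so before this fix
    # nothing was ever written to save_path; after it the model is saved
    # as "<ModelClassName>_<start_time>" under save_path.
    trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1),
                      loss=BCELoss(pred="predict", target="y"),
                      n_epochs=10, dev_data=None, metrics=None,
                      save_path='./save_test')
    trainer.train()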


+2 -2  fastNLP/embeddings/bert_embedding.py

@@ -89,7 +89,7 @@ class BertEmbedding(ContextualEmbedding):
        word pieces, and sets the 512th word piece to [SEP]. The encode results of the part beyond that length are simply all zeroed.
        Generally only tasks that classify using just [CLS] should set auto_truncate to True.
    :param kwargs:
-        int min_freq: words occurring fewer than this many times are replaced by unk
+        int min_freq: words occurring fewer than this many times are replaced by unk; defaults to 1
    """
    super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)


@@ -110,7 +110,7 @@ class BertEmbedding(ContextualEmbedding):
        if '[CLS]' in vocab:
            self._word_cls_index = vocab['CLS']

-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
        self._min_freq = min_freq
        self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
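The same default change (min_freq 2 → 1) recurs in the GPT-2 and RoBERTa embeddings below, so by default no in-vocabulary word is remapped to unk any more. A sketch of how the kwarg surfaces in the constructor, assuming a built fastNLP Vocabulary; the model shortcut name is illustrative:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("every word here appears exactly once".split())

    # New default is min_freq=1, so singleton words keep their own word
    # pieces; pass min_freq=2 explicitly to restore the old behavior of
    # mapping words seen fewer than 2 times to unk.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', min_freq=1)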


+2 -2  fastNLP/embeddings/gpt2_embedding.py

@@ -83,7 +83,7 @@ class GPT2Embedding(ContextualEmbedding):


        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)

        self.lm_loss =language_model
        self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -315,7 +315,7 @@ class GPT2WordPieceEncoder(nn.Module):


class _GPT2Model(nn.Module):
    def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
-                 only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
+                 only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
        super().__init__()

        self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)


+3 -3  fastNLP/embeddings/roberta_embedding.py

@@ -78,7 +78,7 @@ class RobertaEmbedding(ContextualEmbedding):
        word pieces, and sets the 512th word piece to </s>. The encode results of the part beyond that length are simply all zeroed.
        Generally only tasks that classify using just <s> should set auto_truncate to True.
    :param kwargs:
-        int min_freq: words occurring fewer than this many times are replaced by unk
+        int min_freq: words occurring fewer than this many times are replaced by unk; defaults to 1
    """
    super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)


@@ -93,7 +93,7 @@ class RobertaEmbedding(ContextualEmbedding):
        if '<s>' in vocab:
            self._word_cls_index = vocab['<s>']

-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
        self._min_freq = min_freq

        self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -464,7 +464,7 @@ class RobertaWordPieceEncoder(nn.Module):


        os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
-        logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
+        logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")

    @classmethod
    def load(cls, folder):


+2 -2  fastNLP/embeddings/static_embedding.py

@@ -97,8 +97,8 @@ class StaticEmbedding(TokenEmbedding):
    :param int min_freq: words whose frequency in the Vocabulary is below this number are mapped to unk.
    :param dict kwargs:
        bool only_train_min_freq: apply the min_freq filter only to words from train;
-        bool only_norm_found_vector: whether to normalize only the words found in the pretrained vectors;
-        bool only_use_pretrain_word: use only the words that appear in the pretrain vocabulary; a word absent from the pretrained vocabulary becomes unk. Recommended to set True if the embedding does not need updating.
+        bool only_norm_found_vector: defaults to False; whether to normalize only the words found in the pretrained vectors;
+        bool only_use_pretrain_word: defaults to False; use only the words that appear in the pretrain vocabulary; a word absent from the pretrained vocabulary becomes unk. Recommended to set True if the embedding does not need updating.
    """
    super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
    if embedding_dim > 0:
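Only the documentation changed here (both flags now state their False default); runtime behavior is untouched. A sketch of passing these kwargs, assuming a fastNLP Vocabulary; the pretrained-vector shortcut name is illustrative:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("the quick brown fox".split())

    # Both kwargs default to False; only_use_pretrain_word=True maps words
    # missing from the pretrained vectors to unk, which the docstring
    # recommends when the embedding is frozen (requires_grad=False).
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                            requires_grad=False,
                            only_norm_found_vector=False,
                            only_use_pretrain_word=True)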


+1 -1  fastNLP/modules/generator/seq2seq_generator.py

@@ -308,7 +308,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
            max_len_eos_mask = max_lengths.eq(cur_len+1)
            eos_scores = scores[:, _eos_token_id]
            # if the maximum length has been reached, boost the eos score
-            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)

        if do_sample:
            if temperature > 0 and temperature != 1:
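A +100 bonus is not guaranteed to win: the scores are log-probabilities, so a heavily penalized eos (a very large negative log-prob) could still lose the argmax, leaving beams that exceed their length budget without an eos; +1e12 makes eos dominate unconditionally. A standalone sketch of the masking mechanism with toy tensors; the shapes and token id are made up:

    import torch

    # Toy next-token log-probs for 3 beams over a vocab of 5; let id 2
    # stand in for _eos_token_id.
    scores = torch.log_softmax(torch.randn(3, 5), dim=-1)
    eos_id = 2

    # Beams whose maximum length is reached at this step.
    max_len_eos_mask = torch.tensor([True, False, True])

    eos_scores = scores[:, eos_id]
    # For masked beams, +1e12 forces eos to be the argmax regardless of
    # how unlikely it was; the unmasked beam keeps its original score.
    scores[:, eos_id] = torch.where(max_len_eos_mask, eos_scores + 1e12, eos_scores)
    assert scores[0].argmax().item() == eos_id
    assert scores[2].argmax().item() == eos_id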


+12 -1  test/core/test_trainer.py

@@ -76,8 +76,19 @@ class TrainerTestGround(unittest.TestCase):
                          use_tqdm=True, check_code_level=2)
        trainer.train()
        import os
+        import shutil
+        self.assertTrue(os.path.exists(save_path))
+        if os.path.exists(save_path):
+            shutil.rmtree(save_path)
+
+        # training without dev_data
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=None,
+                          metrics=None, validate_every=-1, save_path=save_path,
+                          use_tqdm=True, check_code_level=2)
+        trainer.train()
+        self.assertTrue(os.path.exists(save_path))
        if os.path.exists(save_path):
-            import shutil
            shutil.rmtree(save_path)

    def test_trainer_suggestion1(self):

