@@ -178,7 +178,7 @@ class LossFunc(LossBase):
     r"""
     A class that lets the user plug in a custom loss function.
 
-    :param func: the user-defined loss function; should be a function, or an object for which callable(func) is True
+    :param func: the user-defined loss function; should be a plain function.
     :param dict key_map: parameter mapping table. Keys are the loss function's parameter names, values are Model/DataSet parameter names.
         At training time, fastNLP's trainer looks up the parameter named value in the model's return value or in the
         target=True fields of the training DataSet, and passes it to func as the parameter named key.
@@ -186,8 +186,8 @@ class LossFunc(LossBase):
 
     Usage::
 
-        func = torch.nn.CrossEntropyLoss()
-        loss_func = LossFunc(func, input="pred", target="label")
+        import torch.nn.functional as F
+        loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
         # This builds a loss object whose loss is computed by func: a parameter named `pred` is looked up in the
         # model's return value or in the target=True fields of the DataSet and passed to func as its `input`
         # argument; a parameter named `label` is looked up and passed to func as its `target` argument.
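For reference, a fuller sketch of the new usage (assuming LossFunc is importable from the fastNLP top level and also accepts an explicit key_map dict, as its signature suggests; the field names pred and label are illustrative)::

    import torch.nn.functional as F
    from fastNLP import LossFunc

    # Keys of the mapping are func's parameter names, values are the
    # Model/DataSet names. The two declarations below are equivalent.
    loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
    loss_func = LossFunc(F.cross_entropy, key_map={"input": "pred", "target": "label"})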
@@ -630,6 +630,11 @@ class Trainer(object):
                     self.logger.info("Reloaded the best model.")
                 else:
                     self.logger.info("Failed to reload the best model.")
+
+            if self.dev_data is None and self.save_path is not None:
+                model_name = "_".join([self.model.__class__.__name__, self.start_time])
+                self._save_model(self.model, model_name)
+
         finally:
             if self.dev_data is not None and self.best_dev_perf is not None:
                 self.logger.info(
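In effect, a run configured without a dev set now persists its final weights. A minimal sketch of the new behaviour (mirroring the test added at the bottom of this diff; train_set, model and the path are placeholders)::

    # Sketch only: with dev_data=None and save_path set, train() now writes
    # the final model to save_path as "<ModelClassName>_<start_time>".
    trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1),
                      loss=BCELoss(pred="predict", target="y"),
                      dev_data=None, metrics=None, save_path='./save_models')
    trainer.train()  # a model file exists afterwards, even without validation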
@@ -89,7 +89,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | ||||
来进行分类的任务将auto_truncate置为True。 | 来进行分类的任务将auto_truncate置为True。 | ||||
:param kwargs: | :param kwargs: | ||||
int min_freq: 小于该次数的词会被unk代替 | |||||
int min_freq: 小于该次数的词会被unk代替, 默认为1 | |||||
""" | """ | ||||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
@@ -110,7 +110,7 @@ class BertEmbedding(ContextualEmbedding):
         if '[CLS]' in vocab:
             self._word_cls_index = vocab['[CLS]']
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
         self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
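A minimal sketch of how the relaxed default plays out (assuming the usual fastNLP Vocabulary/BertEmbedding API and the 'en-base-uncased' weight tag; the sentence is illustrative)::

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo sentence".split())
    # min_freq now defaults to 1, so singleton words keep their own word
    # pieces; pass min_freq=2 explicitly to restore the old filtering.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', min_freq=2)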
@@ -83,7 +83,7 @@ class GPT2Embedding(ContextualEmbedding): | |||||
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | ||||
truncate_embed = kwargs.get('truncate_embed', True) | truncate_embed = kwargs.get('truncate_embed', True) | ||||
min_freq = kwargs.get('min_freq', 2) | |||||
min_freq = kwargs.get('min_freq', 1) | |||||
self.lm_loss =language_model | self.lm_loss =language_model | ||||
self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | ||||
@@ -315,7 +315,7 @@ class GPT2WordPieceEncoder(nn.Module):
 
 class _GPT2Model(nn.Module):
     def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
-                 only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
+                 only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
         super().__init__()
 
         self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
@@ -78,7 +78,7 @@ class RobertaEmbedding(ContextualEmbedding): | |||||
word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | ||||
来进行分类的任务将auto_truncate置为True。 | 来进行分类的任务将auto_truncate置为True。 | ||||
:param kwargs: | :param kwargs: | ||||
int min_freq: 小于该次数的词会被unk代替 | |||||
int min_freq: 小于该次数的词会被unk代替, 默认为1 | |||||
""" | """ | ||||
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
@@ -93,7 +93,7 @@ class RobertaEmbedding(ContextualEmbedding):
         if '<s>' in vocab:
             self._word_cls_index = vocab['<s>']
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
         self._min_freq = min_freq
         self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                        pool_method=pool_method, include_cls_sep=include_cls_sep,
@@ -464,7 +464,7 @@ class RobertaWordPieceEncoder(nn.Module):
         os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
         self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
-        logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
+        logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")
 
     @classmethod
     def load(cls, folder):
@@ -97,8 +97,8 @@ class StaticEmbedding(TokenEmbedding):
         :param int min_freq: words whose frequency in the Vocabulary is below this number are mapped to unk.
         :param dict kwargs:
             bool only_train_min_freq: apply the min_freq filter only to words that come from train;
-            bool only_norm_found_vector: whether to normalize only the vectors of words found in the pretrained embeddings;
-            bool only_use_pretrain_word: use only words that appear in the pretrained vocabulary; a word absent from it becomes unk. Recommended to set this to True if the embedding will not be updated.
+            bool only_norm_found_vector: defaults to False; whether to normalize only the vectors of words found in the pretrained embeddings;
+            bool only_use_pretrain_word: defaults to False; use only words that appear in the pretrained vocabulary; a word absent from it becomes unk. Recommended to set this to True if the embedding will not be updated.
         """
         super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
         if embedding_dim > 0:
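A hedged sketch of how the two flags above are typically combined for a frozen embedding (assuming the 'en-glove-6b-50d' weight tag; the word list is illustrative)::

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("the quick brown fox".split())
    # For an embedding that will not be updated, restrict the vocabulary to
    # pretrained words and normalize only the vectors actually found there.
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                            requires_grad=False,
                            only_norm_found_vector=True,
                            only_use_pretrain_word=True)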
@@ -308,7 +308,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             max_len_eos_mask = max_lengths.eq(cur_len+1)
             eos_scores = scores[:, _eos_token_id]
             # once the maximum length is reached, boost the eos score
-            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
 
     if do_sample:
         if temperature > 0 and temperature != 1:
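The jump from +100 to +1e12 matters because the per-token scores are compared across the whole vocabulary and may later be rescaled (for example by a length penalty), so a fixed +100 bonus cannot guarantee that eos is selected. A self-contained sketch of the forcing trick (shapes and the eos id are made up for illustration)::

    import torch

    scores = torch.randn(2, 5)                      # (batch, vocab) scores at this step
    max_len_eos_mask = torch.tensor([True, False])  # rows that just reached max_len
    eos_id = 3
    eos_scores = scores[:, eos_id]
    # Where the mask is True, add a huge constant so eos dominates any
    # subsequent top-k / argmax selection; other rows are left untouched.
    scores[:, eos_id] = torch.where(max_len_eos_mask, eos_scores + 1e12, eos_scores)
    assert scores[0].argmax().item() == eos_id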
@@ -76,8 +76,19 @@ class TrainerTestGround(unittest.TestCase):
                           use_tqdm=True, check_code_level=2)
         trainer.train()
         import os
+        import shutil
+        self.assertTrue(os.path.exists(save_path))
+        if os.path.exists(save_path):
+            shutil.rmtree(save_path)
+
+        # training without dev_data
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=None,
+                          metrics=None, validate_every=-1, save_path=save_path,
+                          use_tqdm=True, check_code_level=2)
+        trainer.train()
+        self.assertTrue(os.path.exists(save_path))
         if os.path.exists(save_path):
-            import shutil
             shutil.rmtree(save_path)
 
     def test_trainer_suggestion1(self):