
1. Fix the sampling in the Generator so that an all-zero distribution no longer makes sampling fail; 2. Fix how the Trainer works together with CheckPointCallback.

tags/v0.5.5
yh_cc · 4 years ago
commit ae7b916355
6 changed files with 17 additions and 17 deletions
  1. fastNLP/core/callback.py (+2 -1)
  2. fastNLP/core/trainer.py (+2 -1)
  3. fastNLP/core/utils.py (+5 -2)
  4. fastNLP/embeddings/gpt2_embedding.py (+3 -3)
  5. fastNLP/modules/encoder/gpt2.py (+0 -7)
  6. fastNLP/modules/generator/seq2seq_generator.py (+5 -3)

fastNLP/core/callback.py (+2 -1)

@@ -947,13 +947,14 @@ class CheckPointCallback(Callback):
model = model.module
model.load_state_dict(states['model'])
self.optimizer.load_state_dict(states['optimizer'])
- self.trainer.epoch = states['epoch'] + 1  # saved at the end of an epoch, so training resumes from the next one
+ self.trainer.epoch = states['epoch'] + 1  # saved at the end of an epoch, so training resumes from the next one
+ self.trainer.step = states['step']
if 'best_dev_epoch' in states:
self.trainer.best_dev_perf = states['best_dev_perf']
self.trainer.best_dev_epoch = states['best_dev_epoch']
self.trainer.best_dev_step = states['best_dev_step']
self.trainer.best_metric_indicator = states['best_metric_indicator']
logger.info("Load checkpoint from {}".format(os.path.expanduser(self.save_path)))

def on_epoch_end(self):
r"""


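The loader above reads a checkpoint dict keyed by 'model', 'optimizer', 'epoch' and, with this change, 'step' (plus the optional best_dev_* entries). A minimal sketch of a save side that would produce such a dict, illustrative only and not the callback's actual save code:

    import torch

    def save_checkpoint(path, model, optimizer, epoch, step):
        # keys mirror what CheckPointCallback reads back when resuming
        states = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,  # saved at the end of an epoch, so resume uses epoch + 1
            'step': step,    # lets the trainer's progress bar start at the right position
        }
        torch.save(states, path)
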
fastNLP/core/trainer.py (+2 -1)

@@ -645,7 +645,8 @@ class Trainer(object):
else:
inner_tqdm = tqdm
start = time.time()
- with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+ with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True,
+                 initial=self.step) as pbar:
self.pbar = pbar
avg_loss = 0
self.batch_per_epoch = self.data_iterator.num_batches


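Passing initial=self.step matters when training resumes from a checkpoint: the restored step count now also sets the starting position of the progress bar instead of letting it count from zero again. A standalone illustration of this tqdm behaviour (assumes the tqdm package; not fastNLP code):

    import time
    from tqdm import tqdm

    total_steps = 100
    resumed_step = 40  # e.g. the step restored by CheckPointCallback

    # the bar starts at 40/100 rather than 0/100
    with tqdm(total=total_steps, initial=resumed_step, dynamic_ncols=True) as pbar:
        for _ in range(resumed_step, total_steps):
            time.sleep(0.01)  # stand-in for one training step
            pbar.update(1)
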
fastNLP/core/utils.py (+5 -2)

@@ -796,9 +796,12 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx + 1}). {sugg}'
- else:
+     err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+ elif len(suggestions):
sugg_str += suggestions[0]
- err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+     err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+ else:
+     err_str = '\n' + '\n'.join(errs)
raise NameError(err_str)
if _unused:
if check_level == WARNING_CHECK_LEVEL:

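With the rewritten branching, an empty suggestions list no longer reaches sugg_str += suggestions[0] (which would raise an IndexError) and instead falls through to a plain error string. A standalone sketch of the resulting logic; the helper name is illustrative, not a fastNLP function:

    def build_err_str(errs, suggestions):
        # mirrors the corrected branching in _check_forward_error
        sugg_str = ""
        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
                sugg_str += f'({idx + 1}). {sugg}'
            return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
        elif len(suggestions):
            sugg_str += suggestions[0]
            return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
        else:
            return '\n' + '\n'.join(errs)

    print(build_err_str(['missing input `words`'], []))  # no suggestions: no IndexError
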

fastNLP/embeddings/gpt2_embedding.py (+3 -3)

@@ -278,11 +278,11 @@ class GPT2WordPieceEncoder(nn.Module):

return output_strs

- def generate(self, word_pieces, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+ def generate(self, word_pieces=None, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
repetition_penalty=1.0, length_penalty=1.0):
"""

- :param word_pieces:
+ :param torch.LongTensor,None word_pieces: if a tensor is passed, its shape should be batch_size x start_len; if None, the start is generated randomly.
:param int max_len: maximum length of the generated sentence
:param bool do_sample: whether to generate by sampling; with sampling, the same parameters may yield different sentences.
:param int num_beams: beam size to use
@@ -293,7 +293,7 @@ class GPT2WordPieceEncoder(nn.Module):
:param float length_penalty: penalty applied to overly long sentences
:return:
"""
- pass
+ raise NotImplementedError()

def get_lm_loss(self, release=True):
"""


fastNLP/modules/encoder/gpt2.py (+0 -7)

@@ -1062,10 +1062,3 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
# all hidden states, tuple: n_layer x batch_size x max_len x embed_size,
# attention, tuple: n_layer x batch_size x n_head' x src_len x tgt_len
return outputs # (loss), lm_logits, presents, all hidden_states, (attentions)
-
-
-
-
- # output for each position
-
-

fastNLP/modules/generator/seq2seq_generator.py (+5 -3)

@@ -177,7 +177,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
scores = scores / temperature

scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
- probs = F.softmax(scores, dim=-1)
+ # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
+ probs = F.softmax(scores, dim=-1) + 1e-12

# make sure at least one value is not eos
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
@@ -230,7 +231,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."

if do_sample:
- probs = F.softmax(scores, dim=-1)
+ probs = F.softmax(scores, dim=-1) + 1e-12
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams)
logits = probs.log()
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams)
@@ -276,7 +277,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2

# take one extra candidate to guard against eos
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
- probs = F.softmax(scores, dim=-1)
+ # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
+ probs = F.softmax(scores, dim=-1) + 1e-12

# make sure at least one value is not eos
_tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1)

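The + 1e-12 guards the torch.multinomial calls right below it: per the commit message, the filtered scores can yield an all-zero probability row, and multinomial cannot sample from a row whose weights sum to zero (see the PyTorch PR linked in the comment). A standalone illustration of the failure and the guard, not fastNLP code:

    import torch

    probs = torch.zeros(1, 5)  # degenerate row: every probability is exactly 0
    try:
        torch.multinomial(probs, num_samples=1)
    except RuntimeError as err:
        print("sampling fails:", err)  # invalid multinomial distribution

    safe_probs = probs + 1e-12  # every entry is strictly positive, so sampling succeeds
    print(torch.multinomial(safe_probs, num_samples=1))
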
