@@ -947,13 +947,14 @@ class CheckPointCallback(Callback):
                 model = model.module
             model.load_state_dict(states['model'])
             self.optimizer.load_state_dict(states['optimizer'])
             self.trainer.epoch = states['epoch'] + 1  # the checkpoint is saved at the end of an epoch, so resume from the next one
             self.trainer.step = states['step']
             if 'best_dev_epoch' in states:
                 self.trainer.best_dev_perf = states['best_dev_perf']
                 self.trainer.best_dev_epoch = states['best_dev_epoch']
                 self.trainer.best_dev_step = states['best_dev_step']
                 self.trainer.best_metric_indicator = states['best_metric_indicator']
+            logger.info("Load checkpoint from {}".format(os.path.expanduser(self.save_path)))
 
     def on_epoch_end(self):
         r"""
@@ -645,7 +645,8 @@ class Trainer(object):
         else:
             inner_tqdm = tqdm
         start = time.time()
-        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True,
+                        initial=self.step) as pbar:
             self.pbar = pbar
             avg_loss = 0
             self.batch_per_epoch = self.data_iterator.num_batches
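Passing `initial=self.step` is what keeps the progress bar consistent with a resumed run: tqdm starts counting from the restored global step instead of zero. A standalone sketch of that behaviour (the numbers are made up):

import time
from tqdm import tqdm

resumed_step, n_steps = 300, 1000
# The bar starts at 300/1000 rather than 0/1000, matching the restored trainer.step.
with tqdm(total=n_steps, initial=resumed_step, leave=False, dynamic_ncols=True) as pbar:
    for _ in range(resumed_step, n_steps):
        time.sleep(0.001)  # stand-in for one training step
        pbar.update(1)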
@@ -796,9 +796,12 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx + 1}). {sugg}'
-        else:
+            err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+        elif len(suggestions):
             sugg_str += suggestions[0]
-        err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+            err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+        else:
+            err_str = '\n' + '\n'.join(errs)
         raise NameError(err_str)
     if _unused:
         if check_level == WARNING_CHECK_LEVEL:
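The restructured branches change what the final error message looks like: multiple suggestions are numbered, a single suggestion is used as-is, and with no suggestions the `Suggestion:` line is dropped entirely. A small, self-contained sketch of the same logic (the example inputs are invented):

def build_err_str(errs, suggestions):
    sugg_str = ''
    if len(suggestions) > 1:
        for idx, sugg in enumerate(suggestions):
            sugg_str += f'({idx + 1}). {sugg}'
        return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
    elif len(suggestions):
        sugg_str += suggestions[0]
        return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
    else:
        return '\n' + '\n'.join(errs)

print(build_err_str(['missing param `words`'], ['provide `words` in your dataset']))
print(build_err_str(['missing param `words`'], []))  # no Suggestion line at all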
@@ -278,11 +278,11 @@ class GPT2WordPieceEncoder(nn.Module):
         return output_strs
 
-    def generate(self, word_pieces, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+    def generate(self, word_pieces=None, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
                  repetition_penalty=1.0, length_penalty=1.0):
         """
 
-        :param word_pieces:
+        :param torch.LongTensor,None word_pieces: if a tensor is passed, its shape should be batch_size x start_len; if None, the starting word pieces are generated at random.
         :param int max_len: maximum length of the generated sentence
         :param bool do_sample: whether to generate by sampling; with sampling, the same arguments may produce different sentences.
         :param int num_beams: beam size to use
@@ -293,7 +293,7 @@ class GPT2WordPieceEncoder(nn.Module):
         :param float length_penalty: penalty applied to overly long sentences
         :return:
         """
-        pass
+        raise NotImplementedError()
 
     def get_lm_loss(self, release=True):
         """
@@ -1062,10 +1062,3 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         # all hidden states, tuple: n_layer x batch_size x max_len x embed_size,
         # attention, tuple: n_layer x batch_size x n_head' x src_len x tgt_len
         return outputs  # (loss), lm_logits, presents, all hidden_states, (attentions)
-    # output for each position ...
@@ -177,7 +177,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
                scores = scores / temperature
            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
-           probs = F.softmax(scores, dim=-1)
+           # the 1e-12 is added to avoid https://github.com/pytorch/pytorch/pull/27523
+           probs = F.softmax(scores, dim=-1) + 1e-12
            # make sure at least one non-eos value can be sampled
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
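For context on the `+ 1e-12`: after top-k/top-p filtering, the masked logits are -inf and softmax turns them into exact zeros; the comment in the patch points at pytorch/pytorch#27523, a `multinomial` fix related to zero-probability entries, and adding the tiny constant sidesteps the problem by making every probability strictly positive. A minimal, self-contained illustration (the scores are made up):

import torch
import torch.nn.functional as F

# Hypothetical scores after top-k/top-p filtering: filtered positions are -inf.
scores = torch.tensor([[2.0, float('-inf'), 0.5, float('-inf')]])

probs = F.softmax(scores, dim=-1)      # -inf positions become exactly 0.0
probs = probs + 1e-12                  # strictly positive everywhere, as in the patch
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
print(probs)        # roughly [0.82, 1e-12, 0.18, 1e-12]
print(next_tokens)  # almost always index 0 or 2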
@@ -230,7 +231,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
    assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
    if do_sample:
-       probs = F.softmax(scores, dim=-1)
+       probs = F.softmax(scores, dim=-1) + 1e-12
        next_tokens = torch.multinomial(probs, num_samples=num_beams)  # (batch_size, num_beams)
        logits = probs.log()
        next_scores = logits.gather(dim=1, index=next_tokens)  # (batch_size, num_beams)
@@ -276,7 +277,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
            # recall one extra token to guard against eos
            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
-           probs = F.softmax(scores, dim=-1)
+           # the 1e-12 is added to avoid https://github.com/pytorch/pytorch/pull/27523
+           probs = F.softmax(scores, dim=-1) + 1e-12
            # make sure at least one non-eos value can be sampled
            _tokens = torch.multinomial(probs, num_samples=num_beams + 1)  # batch_size' x (num_beams+1)
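The `num_samples=num_beams + 1` draw pairs with `min_tokens_to_keep=num_beams + 1` above: since `multinomial` samples without replacement here, eos can show up at most once among the candidates, so at least `num_beams` non-eos continuations survive, which appears to be the intent of the "recall one extra" comment. A small sketch of that invariant (the vocabulary and beam size are arbitrary):

import torch

num_beams, eos_token_id, vocab_size = 4, 0, 50
probs = torch.rand(2, vocab_size)                 # two dummy distributions
probs = probs / probs.sum(dim=-1, keepdim=True)

# Without replacement, the eos id appears at most once per row.
_tokens = torch.multinomial(probs, num_samples=num_beams + 1)   # (2, num_beams + 1)
non_eos = (_tokens != eos_token_id).sum(dim=1)
print(non_eos >= num_beams)                        # tensor([True, True])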