@@ -947,13 +947,14 @@ class CheckPointCallback(Callback):
                 model = model.module
             model.load_state_dict(states['model'])
             self.optimizer.load_state_dict(states['optimizer'])
             self.trainer.epoch = states['epoch'] + 1  # the state was saved at epoch end, so resume from the next epoch
             self.trainer.step = states['step']
             if 'best_dev_epoch' in states:
                 self.trainer.best_dev_perf = states['best_dev_perf']
                 self.trainer.best_dev_epoch = states['best_dev_epoch']
                 self.trainer.best_dev_step = states['best_dev_step']
                 self.trainer.best_metric_indicator = states['best_metric_indicator']
+            logger.info("Load checkpoint from {}".format(os.path.expanduser(self.save_path)))

     def on_epoch_end(self):
         r"""
@@ -645,7 +645,8 @@ class Trainer(object):
         else:
             inner_tqdm = tqdm
         start = time.time()
-        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True,
+                        initial=self.step) as pbar:
             self.pbar = pbar
             avg_loss = 0
             self.batch_per_epoch = self.data_iterator.num_batches
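Passing `initial=self.step` matters when training resumes from a checkpoint: tqdm then starts the bar at the restored step instead of 0, so the display stays consistent with `total=self.n_steps`. A quick standalone illustration with toy numbers:

```python
from tqdm import tqdm

restored_step, n_steps = 300, 1000  # pretend step=300 was restored from a checkpoint
with tqdm(total=n_steps, initial=restored_step, dynamic_ncols=True, leave=False) as pbar:
    for _ in range(restored_step, n_steps):
        pbar.update(1)  # bar starts at 300/1000 rather than 0/1000
```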
@@ -796,9 +796,12 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx + 1}). {sugg}'
-        else:
+            err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+        elif len(suggestions):
             sugg_str += suggestions[0]
-        err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+            err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
+        else:
+            err_str = '\n' + '\n'.join(errs)
         raise NameError(err_str)
     if _unused:
         if check_level == WARNING_CHECK_LEVEL:
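The rewritten branching fixes two problems: with an empty `suggestions` list, the old `else` branch hit `suggestions[0]` and raised an IndexError, and the unconditional `err_str` line could emit a dangling "Suggestion:" header with nothing after it. The new code distinguishes three cases: several suggestions (numbered), exactly one, and none. A self-contained restatement of that logic (hypothetical function name):

```python
def build_error_message(errs, suggestions):
    # mirrors the new branching: numbered list, single suggestion, or no header at all
    if len(suggestions) > 1:
        sugg_str = ''.join(f'({idx + 1}). {sugg}' for idx, sugg in enumerate(suggestions))
        return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
    elif len(suggestions):
        return '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + suggestions[0]
    else:
        return '\n' + '\n'.join(errs)

print(build_error_message(['missing field `words`'], []))  # no dangling "Suggestion:"
```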
@@ -278,11 +278,11 @@ class GPT2WordPieceEncoder(nn.Module):
         return output_strs

-    def generate(self, word_pieces, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
+    def generate(self, word_pieces=None, max_len=40, do_sample=True, num_beams=1, temperature=1, top_k=50, top_p=1.0,
                  repetition_penalty=1.0, length_penalty=1.0):
         """
-        :param word_pieces:
+        :param torch.LongTensor,None word_pieces: if a tensor is passed, its shape should be batch_size x start_len; if None, a random prompt is generated.
         :param int max_len: maximum length of the sentence to generate
         :param bool do_sample: whether to generate by sampling; with sampling, the same parameters may yield different sentences.
         :param int num_beams: beam size to use
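The new default `word_pieces=None` makes the prompt optional, per the updated docstring: a `batch_size x start_len` LongTensor seeds generation, and `None` falls back to a random start. A hedged sketch of just that input contract (hypothetical helper; the method body itself is not implemented in this diff):

```python
import torch

def _prepare_prompt(word_pieces=None, vocab_size=50257, batch_size=1):
    # word_pieces=None: fall back to one random start token per sample
    if word_pieces is None:
        word_pieces = torch.randint(vocab_size, size=(batch_size, 1))
    assert word_pieces.dim() == 2, "expected shape batch_size x start_len"
    return word_pieces
```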
@@ -293,7 +293,7 @@ class GPT2WordPieceEncoder(nn.Module):
         :param float length_penalty: penalty applied to overly long sentences
         :return:
         """
-        pass
+        raise NotImplementedError

     def get_lm_loss(self, release=True):
         """
@@ -1062,10 +1062,3 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         # all hidden states, tuple: n_layer x batch_size x max_len x embed_size,
         # attention, tuple: n_layer x batch_size x n_head' x src_len x tgt_len
         return outputs  # (loss), lm_logits, presents, all hidden_states, (attentions)
-
-        # output each position's
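The comments above document a positional output tuple whose parenthesized members (`loss`, `attentions`) are present only when requested. Purely illustrative shapes with toy sizes (n_layer=2, batch_size=3, max_len=5, embed_size=8, n_head=4; `loss` and `presents` omitted in this sketch):

```python
import torch

lm_logits = torch.zeros(3, 5, 50257)                           # batch_size x max_len x vocab_size
hidden_states = tuple(torch.zeros(3, 5, 8) for _ in range(2))  # n_layer x (batch_size x max_len x embed_size)
attentions = tuple(torch.zeros(3, 4, 5, 5) for _ in range(2))  # n_layer x (batch_size x n_head x src_len x tgt_len)
```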
@@ -177,7 +177,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
             scores = scores / temperature
             scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
-            probs = F.softmax(scores, dim=-1)
+            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
+            probs = F.softmax(scores, dim=-1) + 1e-12
             # make sure at least one non-eos value remains
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
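The epsilon works around the referenced PyTorch issue, where `torch.multinomial` could occasionally draw an element whose probability is exactly zero (e.g. a token masked to `-inf` by the filtering step). Adding `1e-12` leaves no exact zeros, so the draw stays well defined. A standalone repro-style sketch:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, float('-inf'), float('-inf')]])  # logits after top-k/top-p filtering
probs = F.softmax(scores, dim=-1) + 1e-12  # no entry is exactly zero anymore
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
print(next_tokens)  # effectively always index 0 or 1
```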
@@ -230,7 +231,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
     if do_sample:
-        probs = F.softmax(scores, dim=-1)
+        probs = F.softmax(scores, dim=-1) + 1e-12
         next_tokens = torch.multinomial(probs, num_samples=num_beams)  # (batch_size, num_beams)
         logits = probs.log()
         next_scores = logits.gather(dim=1, index=next_tokens)  # (batch_size, num_beams)
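In the sampling variant of beam search, `num_beams` candidates are drawn per batch element and their log-probabilities are gathered as the initial beam scores. A minimal sketch with toy sizes:

```python
import torch
import torch.nn.functional as F

batch_size, vocab_size, num_beams = 2, 10, 3
scores = torch.randn(batch_size, vocab_size)
probs = F.softmax(scores, dim=-1) + 1e-12
next_tokens = torch.multinomial(probs, num_samples=num_beams)  # (batch_size, num_beams)
next_scores = probs.log().gather(dim=1, index=next_tokens)     # log-prob of each drawn candidate
print(next_tokens.shape, next_scores.shape)                    # torch.Size([2, 3]) twice
```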
| @@ -276,7 +277,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
| # 多召回一个防止eos | # 多召回一个防止eos | ||||
| scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | ||||
| probs = F.softmax(scores, dim=-1) | |||||
| # 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523 | |||||
| probs = F.softmax(scores, dim=-1) + 1e-12 | |||||
| # 保证至少有一个不是eos的值 | # 保证至少有一个不是eos的值 | ||||
| _tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | _tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | ||||
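Requesting `min_tokens_to_keep=num_beams + 1` and then sampling `num_beams + 1` tokens guarantees that even if eos is among the draws, at least `num_beams` non-eos continuations survive. A sketch using a simplified top-k-only stand-in for `top_k_top_p_filtering` (assumed behavior, not the library's exact helper):

```python
import torch
import torch.nn.functional as F

def top_k_filtering(scores, top_k, min_tokens_to_keep=1):
    top_k = max(top_k, min_tokens_to_keep)  # never filter below the floor
    threshold = torch.topk(scores, top_k)[0][..., -1, None]
    return scores.masked_fill(scores < threshold, float('-inf'))

num_beams = 3
scores = top_k_filtering(torch.randn(2, 10), top_k=5, min_tokens_to_keep=num_beams + 1)
probs = F.softmax(scores, dim=-1) + 1e-12
_tokens = torch.multinomial(probs, num_samples=num_beams + 1)  # batch_size x (num_beams+1)
# even if eos is drawn, num_beams non-eos candidates remain per row
```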