From 846a1a515898fb37b8b1634f358e591afa3c7f75 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Wed, 15 May 2019 11:47:34 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=8Epytorch1.1=E4=B8=AD?=
 =?UTF-8?q?=E7=9A=84padsequence=E7=9A=84=E5=85=BC=E5=AE=B9=E9=97=AE?=
 =?UTF-8?q?=E9=A2=98;=20=E4=BF=AE=E6=94=B9Trainer=E7=9A=84pbar?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callback.py                   | 9 +++++----
 fastNLP/core/trainer.py                    | 2 +-
 fastNLP/modules/encoder/variational_rnn.py | 7 ++++---
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index f337975a..9dce426b 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -130,7 +130,8 @@ class Callback(object):
 
     @property
     def pbar(self):
-        """如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。"""
+        """如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在
+        on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。"""
         return self._trainer.pbar
 
     @property
@@ -440,7 +441,7 @@ class LRScheduler(Callback):
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
 
     def on_epoch_begin(self):
-        self.scheduler.step()
+        self.scheduler.step(self.epoch)
 
 
 class ControlC(Callback):
@@ -526,7 +527,7 @@ class LRFinder(Callback):
         if torch.isnan(loss) or self.stop is True:
             self.stop = True
             return
-        loss_val = loss.detach().cpu().data
+        loss_val = loss.detach().mean().item()
         self.loss_history.append(loss_val)
         self.smooth_value.add_value(loss_val)
         if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
@@ -548,7 +549,7 @@ class LRFinder(Callback):
             self.find = False
             # reset model
             ModelLoader().load_pytorch(self.trainer.model, "tmp")
-            print("Model reset. \nFind best lr={}".format(self.best_lr))
+            self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))
 
 
 class TensorboardCallback(Callback):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a6293167..9b56d834 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -558,7 +558,7 @@ class Trainer(object):
         start = time.time()
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False,
                         dynamic_ncols=True) as pbar:
-            self.pbar = pbar if isinstance(pbar, tqdm) else None
+            self.pbar = pbar
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                   as_numpy=False, prefetch=self.prefetch)
diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index 2657ebf4..5a2e99f3 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -43,7 +43,7 @@ class VarRnnCellWrapper(nn.Module):
                 return torch.cat([hi, h0[:h0_size]], dim=0)
             return hi[:size]
         is_lstm = isinstance(hidden, tuple)
-        input, batch_sizes = input_x
+        input, batch_sizes = input_x.data, input_x.batch_sizes
        output = []
         cell = self.cell
         if is_reversed:
@@ -148,10 +148,11 @@ class VarRNNBase(nn.Module):
             seq_len = x.size(1) if self.batch_first else x.size(0)
             max_batch_size = x.size(0) if self.batch_first else x.size(1)
             seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            x, batch_sizes = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
+            _tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
+            x, batch_sizes = _tmp.data, _tmp.batch_sizes
         else:
             max_batch_size = int(x.batch_sizes[0])
-            x, batch_sizes = x
+            x, batch_sizes = x.data, x.batch_sizes
 
         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,