
Fix compatibility with pack_padded_sequence / PackedSequence in PyTorch 1.1; change how Trainer assigns its pbar

tags/v0.4.10
yh_cc · 5 years ago · commit 846a1a5158
3 changed files with 10 additions and 8 deletions:
  1. fastNLP/core/callback.py (+5, -4)
  2. fastNLP/core/trainer.py (+1, -1)
  3. fastNLP/modules/encoder/variational_rnn.py (+4, -3)

fastNLP/core/callback.py (+5, -4)

@@ -130,7 +130,8 @@ class Callback(object):
     @property
     def pbar(self):
-        """If you need to print inside a Callback, use self.pbar.write(str); otherwise the command-line output may display poorly."""
+        """If you need to print inside a Callback, use self.pbar.write(str); otherwise the command-line output may display poorly.
+        Do not use this property in on_train_begin(), on_train_end(), or on_exception(); use plain print there instead."""
         return self._trainer.pbar

     @property
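The expanded docstring spells out the contract: self.pbar only exists while the training loop's progress bar is alive. A hedged illustration of a callback honoring it (the callback name is hypothetical, not part of fastNLP):

    # Illustrative sketch: a callback that follows the pbar rules documented above.
    from fastNLP.core.callback import Callback

    class EpochLogger(Callback):  # hypothetical example callback
        def on_train_begin(self):
            # no progress bar exists yet at this point; plain print is required
            print("training started")

        def on_epoch_end(self):
            # writes above the tqdm progress bar without breaking its redraw
            self.pbar.write("finished epoch {}".format(self.epoch))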
@@ -440,7 +441,7 @@ class LRScheduler(Callback):
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

     def on_epoch_begin(self):
-        self.scheduler.step()
+        self.scheduler.step(self.epoch)


 class ControlC(Callback):
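Passing the trainer's epoch counter pins the learning-rate schedule to the actual epoch rather than the scheduler's internal counter, which matters if step() has already been called or training resumes mid-way. The explicit epoch argument was valid in the PyTorch 1.x schedulers of this era (it was deprecated later). A minimal sketch of the behavior, outside fastNLP:

    # Sketch: step(epoch) ties the LR to an external epoch counter.
    import torch
    from torch.optim.lr_scheduler import StepLR

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = StepLR(optimizer, step_size=2, gamma=0.5)

    for epoch in range(1, 5):
        scheduler.step(epoch)  # explicit epoch, as in the callback above
        print(epoch, optimizer.param_groups[0]['lr'])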
@@ -526,7 +527,7 @@ class LRFinder(Callback):
if torch.isnan(loss) or self.stop is True:
self.stop = True
return
loss_val = loss.detach().cpu().data
loss_val = loss.detach().mean().item()
self.loss_history.append(loss_val)
self.smooth_value.add_value(loss_val)
if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
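.mean().item() yields a plain Python float even when the loss is not a scalar (for instance per-replica losses gathered by DataParallel), whereas the old .cpu().data kept a tensor reference in loss_history. A sketch of the difference, independent of fastNLP internals:

    # Sketch: the two conversions compared.
    import torch

    loss = torch.tensor([0.9, 1.1])      # e.g. one loss per GPU under DataParallel
    old = loss.detach().cpu().data       # still a tensor([0.9000, 1.1000])
    new = loss.detach().mean().item()    # a plain Python float: 1.0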
@@ -548,7 +549,7 @@ class LRFinder(Callback):
             self.find = False
             # reset model
             ModelLoader().load_pytorch(self.trainer.model, "tmp")
-            print("Model reset. \nFind best lr={}".format(self.best_lr))
+            self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))


 class TensorboardCallback(Callback):


fastNLP/core/trainer.py (+1, -1)

@@ -558,7 +558,7 @@ class Trainer(object):
         start = time.time()
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
-            self.pbar = pbar if isinstance(pbar, tqdm) else None
+            self.pbar = pbar
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
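Assigning pbar unconditionally means self.pbar is usable even when tqdm output is disabled and inner_tqdm yields a plain stand-in rather than a real tqdm, so callbacks never see None. A minimal sketch of such a tqdm-compatible stand-in (illustrative only; fastNLP provides its own, and this is not its actual implementation):

    # Sketch: a stand-in with the interface callbacks rely on.
    class PseudoTqdm:  # hypothetical name, for illustration
        def write(self, msg):
            print(msg)           # falls back to plain printing
        def update(self, n=1):
            pass                 # progress updates are ignored
        def set_postfix_str(self, s):
            pass
        def close(self):
            pass
        def __enter__(self):
            return self
        def __exit__(self, *exc):
            self.close()

With the new assignment, trainer.pbar.write(...) works whether pbar is the real tqdm or a stand-in like this.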


fastNLP/modules/encoder/variational_rnn.py (+4, -3)

@@ -43,7 +43,7 @@ class VarRnnCellWrapper(nn.Module):
                 return torch.cat([hi, h0[:h0_size]], dim=0)
             return hi[:size]
         is_lstm = isinstance(hidden, tuple)
-        input, batch_sizes = input_x
+        input, batch_sizes = input_x.data, input_x.batch_sizes
         output = []
         cell = self.cell
         if is_reversed:
@@ -148,10 +148,11 @@ class VarRNNBase(nn.Module):
             seq_len = x.size(1) if self.batch_first else x.size(0)
             max_batch_size = x.size(0) if self.batch_first else x.size(1)
             seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            x, batch_sizes = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
+            _tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
+            x, batch_sizes = _tmp.data, _tmp.batch_sizes
         else:
             max_batch_size = int(x.batch_sizes[0])
-            x, batch_sizes = x
+            x, batch_sizes = x.data, x.batch_sizes

         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,
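Both hunks are the same fix. Through PyTorch 1.0, a PackedSequence could be unpacked like a plain (data, batch_sizes) tuple; PyTorch 1.1 added the sorted_indices and unsorted_indices fields, so two-element unpacking now raises ValueError. Reading the attributes works on every version. A sketch:

    # Sketch: attribute access vs. tuple unpacking on PackedSequence.
    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    x = torch.randn(2, 3, 4)        # (batch, seq, feature)
    lens = torch.tensor([3, 2])     # lengths, sorted descending
    packed = pack_padded_sequence(x, lens, batch_first=True)

    data, batch_sizes = packed.data, packed.batch_sizes  # fine on all versions
    # data, batch_sizes = packed    # ValueError on torch >= 1.1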

