@@ -130,7 +130,8 @@ class Callback(object): | |||||
@property | @property | ||||
def pbar(self): | def pbar(self): | ||||
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。""" | |||||
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 | |||||
on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。""" | |||||
return self._trainer.pbar | return self._trainer.pbar | ||||
@property | @property | ||||
@@ -440,7 +441,7 @@ class LRScheduler(Callback): | |||||
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | ||||
def on_epoch_begin(self): | def on_epoch_begin(self): | ||||
self.scheduler.step() | |||||
self.scheduler.step(self.epoch) | |||||
class ControlC(Callback): | class ControlC(Callback): | ||||
@@ -526,7 +527,7 @@ class LRFinder(Callback): | |||||
if torch.isnan(loss) or self.stop is True: | if torch.isnan(loss) or self.stop is True: | ||||
self.stop = True | self.stop = True | ||||
return | return | ||||
loss_val = loss.detach().cpu().data | |||||
loss_val = loss.detach().mean().item() | |||||
self.loss_history.append(loss_val) | self.loss_history.append(loss_val) | ||||
self.smooth_value.add_value(loss_val) | self.smooth_value.add_value(loss_val) | ||||
if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: | if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: | ||||
@@ -548,7 +549,7 @@ class LRFinder(Callback): | |||||
self.find = False | self.find = False | ||||
# reset model | # reset model | ||||
ModelLoader().load_pytorch(self.trainer.model, "tmp") | ModelLoader().load_pytorch(self.trainer.model, "tmp") | ||||
print("Model reset. \nFind best lr={}".format(self.best_lr)) | |||||
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr)) | |||||
class TensorboardCallback(Callback): | class TensorboardCallback(Callback): | ||||
@@ -558,7 +558,7 @@ class Trainer(object): | |||||
start = time.time() | start = time.time() | ||||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | ||||
self.pbar = pbar if isinstance(pbar, tqdm) else None | |||||
self.pbar = pbar | |||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | ||||
prefetch=self.prefetch) | prefetch=self.prefetch) | ||||
@@ -43,7 +43,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
return torch.cat([hi, h0[:h0_size]], dim=0) | return torch.cat([hi, h0[:h0_size]], dim=0) | ||||
return hi[:size] | return hi[:size] | ||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
input, batch_sizes = input_x | |||||
input, batch_sizes = input_x.data, input_x.batch_sizes | |||||
output = [] | output = [] | ||||
cell = self.cell | cell = self.cell | ||||
if is_reversed: | if is_reversed: | ||||
@@ -148,10 +148,11 @@ class VarRNNBase(nn.Module): | |||||
seq_len = x.size(1) if self.batch_first else x.size(0) | seq_len = x.size(1) if self.batch_first else x.size(0) | ||||
max_batch_size = x.size(0) if self.batch_first else x.size(1) | max_batch_size = x.size(0) if self.batch_first else x.size(1) | ||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | ||||
x, batch_sizes = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||||
_tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||||
x, batch_sizes = _tmp.data, _tmp.batch_sizes | |||||
else: | else: | ||||
max_batch_size = int(x.batch_sizes[0]) | max_batch_size = int(x.batch_sizes[0]) | ||||
x, batch_sizes = x | |||||
x, batch_sizes = x.data, x.batch_sizes | |||||
if hx is None: | if hx is None: | ||||
hx = x.new_zeros(self.num_layers * self.num_directions, | hx = x.new_zeros(self.num_layers * self.num_directions, | ||||