Browse Source

update trainer

tags/v0.2.0
yunfan 6 years ago
parent
commit
cf0b2c2d35
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      fastNLP/core/trainer.py

+ 5
- 5
fastNLP/core/trainer.py View File

@@ -171,11 +171,11 @@ class Trainer(object):

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
if torch.rand(1).item() < 0.001:
print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
for name, p in self._model.named_parameters():
if p.requires_grad:
print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
# if torch.rand(1).item() < 0.001:
# print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
# for name, p in self._model.named_parameters():
# if p.requires_grad:
# print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)



Loading…
Cancel
Save