From 3ea7de16732c14ddeed4655669a4be89241c9c99 Mon Sep 17 00:00:00 2001
From: yh
Date: Thu, 14 Feb 2019 13:18:50 +0800
Subject: [PATCH] 1. Fix a bug in GradientClipCallback; remove the print in
 LRScheduler (a pbar should be passed in later to do the printing). 2. Add a
 comment to the MLP docstring.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callback.py       | 6 ++++--
 fastNLP/modules/decoder/MLP.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index b1a480cc..d941c235 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -248,7 +248,10 @@ class GradientClipCallback(Callback):
         self.clip_value = clip_value
 
     def on_backward_end(self, model):
-        self.clip_fun(model.parameters(), self.clip_value)
+        if self.parameters is None:
+            self.clip_fun(model.parameters(), self.clip_value)
+        else:
+            self.clip_fun(self.parameters, self.clip_value)
 
 
 class CallbackException(BaseException):
@@ -306,7 +309,6 @@ class LRScheduler(Callback):
 
     def on_epoch_begin(self, cur_epoch, total_epoch):
         self.scheduler.step()
-        print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])
 
 
 class ControlC(Callback):
diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py
index c9198859..b76fdab7 100644
--- a/fastNLP/modules/decoder/MLP.py
+++ b/fastNLP/modules/decoder/MLP.py
@@ -7,7 +7,7 @@ from fastNLP.modules.utils import initial_parameter
 class MLP(nn.Module):
     """Multilayer Perceptrons as a decoder
 
-    :param list size_layer: list of int, define the size of MLP layers.
+    :param list size_layer: list of int, define the size of MLP layers. The number of layers is (len(size_layer) - 1) // 2 + 1.
     :param str activation: str or function, the activation function for hidden layers.
     :param str initial_method: the name of initialization method.
     :param float dropout: the probability of dropout.
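
Note (usage illustration, not part of the patch): a minimal, self-contained sketch
of the dispatch this fix introduces. clip_after_backward below is a hypothetical
stand-in for the patched GradientClipCallback.on_backward_end; only the
parameters-vs-model.parameters() branching comes from the diff, everything else
is assumed for the example.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the patched on_backward_end: clip the
    # explicitly configured parameters when they are given, otherwise fall
    # back to clipping all of the model's parameters. The pre-patch code
    # always clipped model.parameters() and ignored the configured ones.
    def clip_after_backward(model, parameters, clip_fun, clip_value):
        if parameters is None:
            clip_fun(model.parameters(), clip_value)
        else:
            clip_fun(parameters, clip_value)

    model = nn.Linear(4, 2)
    model(torch.randn(3, 4)).sum().backward()

    # parameters=None: clip every parameter, matching the old behaviour.
    clip_after_backward(model, None, nn.utils.clip_grad_norm_, 1.0)

    # An explicit parameter list: clip only the bias, which the buggy
    # version would have silently ignored.
    clip_after_backward(model, [model.bias], nn.utils.clip_grad_value_, 0.5)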