From 72877c6ed5b8011ad367eff42178594f53dd87df Mon Sep 17 00:00:00 2001
From: yh
Date: Fri, 7 Dec 2018 13:31:52 +0800
Subject: [PATCH] optimizer: adjust parameter order in initialization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/__init__.py  |  7 ++++---
 fastNLP/core/optimizer.py | 12 ++++++++----
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index b62d5624..44f30fad 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -2,9 +2,10 @@ from .batch import Batch
 from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
-from .losses import LossFromTorch
-from .optimizer import Optimizer
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
+from .metrics import AccuracyMetric
+from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
\ No newline at end of file
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index dfcf83f9..f123ae40 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -10,13 +10,15 @@ class Optimizer(object):
 
 
 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, momentum=0):
+    def __init__(self, lr=0.01, momentum=0, model_params=None):
         """
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate. Default: 0.01
         :param float momentum: momentum. Default: 0
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -28,13 +30,15 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
+    def __init__(self, lr=0.01, weight_decay=0, model_params=None):
         """
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate
         :param float weight_decay:
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
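
The user-visible effect of the reordering: `lr` now comes first and `model_params` becomes an optional trailing keyword, so an optimizer can be configured before any model exists. A minimal sketch of call sites before and after this patch (illustrative values only; the imports follow the updated fastNLP/core/__init__.py above):

    from fastNLP.core import SGD, Adam

    # Old order (before this patch): the parameter generator came first.
    # opt = SGD(model.parameters(), 0.01, 0.9)

    # New order: lr leads; model_params is an optional trailing keyword.
    opt = SGD(lr=0.01, momentum=0.9)
    adam = Adam(lr=0.001, weight_decay=1e-4)

    # The new type check rejects non-float learning rates outright.
    try:
        SGD(lr=1)  # int, not float
    except TypeError as err:
        print(err)  # "learning rate has to be float."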
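
Because `model_params` now defaults to None, `construct_from_pytorch` is the hook that binds parameters late. Its body is not part of this diff; the sketch below is an assumption that it falls back to the stored generator and delegates to torch.optim, consistent with the base `Optimizer` keeping `model_params` and the keyword settings:

    import torch

    class SGDSketch:
        """Hypothetical stand-in for fastNLP's SGD; the real method body
        is not shown in the diff above, so this flow is an assumption."""

        def __init__(self, lr=0.01, momentum=0, model_params=None):
            if not isinstance(lr, float):
                raise TypeError("learning rate has to be float.")
            self.model_params = model_params
            self.settings = {"lr": lr, "momentum": momentum}

        def construct_from_pytorch(self, model_params):
            # Bind parameters late (e.g. when the Trainer receives the
            # model), then build the real torch optimizer from settings.
            if self.model_params is None:
                self.model_params = model_params
            return torch.optim.SGD(self.model_params, **self.settings)

    model = torch.nn.Linear(4, 2)
    optimizer = SGDSketch(lr=0.1).construct_from_pytorch(model.parameters())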