diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index b62d5624..44f30fad 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -2,9 +2,10 @@ from .batch import Batch
 from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
-from .losses import LossFromTorch
-from .optimizer import Optimizer
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
+from .metrics import AccuracyMetric
+from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
\ No newline at end of file
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index dfcf83f9..f123ae40 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -10,13 +10,15 @@ class Optimizer(object):
 
 
 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, momentum=0):
+    def __init__(self, lr=0.01, momentum=0, model_params=None):
         """
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate. Default: 0.01
         :param float momentum: momentum. Default: 0
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -28,13 +30,15 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
+    def __init__(self, lr=0.01, weight_decay=0, model_params=None):
         """
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate
         :param float weight_decay:
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
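
Usage sketch for the reordered constructors (a minimal example, not part of the diff; it assumes SGD and Adam are imported from fastNLP.core as exported above, and that construct_from_pytorch returns a torch.optim optimizer, which the hunks here do not show):

    import torch
    from fastNLP.core import SGD, Adam

    model = torch.nn.Linear(10, 2)

    # lr is now the first positional argument and must be a float;
    # SGD(lr=1) would raise TypeError under the new isinstance check.
    sgd = SGD(lr=0.01, momentum=0.9)
    adam = Adam(lr=0.001, weight_decay=1e-4)

    # model_params defaults to None and can be bound later, e.g. by a
    # Trainer, through construct_from_pytorch (assumed to wrap torch.optim).
    torch_sgd = sgd.construct_from_pytorch(model.parameters())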