
fix bug in Trainer's handling of metric_key

Update Optimizer: support multiple initialization styles (a usage sketch follows the list):
1. SGD()
2. SGD(0.01)
3. SGD(lr=0.01)
4. SGD(lr=0.01, momentum=0.9)
5. SGD(model.parameters(), lr=0.1, momentum=0.9)
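A hedged usage sketch of these call styles against a toy model (the Linear model and the printed settings dict are illustrative only; the behaviour is taken from the diff below):

    import torch
    from fastNLP.core.optimizer import SGD

    model = torch.nn.Linear(10, 3)        # toy model, for illustration only

    SGD()                                 # defaults: lr=0.01, momentum=0.9
    SGD(0.01)                             # positional learning rate
    SGD(lr=0.01)
    SGD(lr=0.01, momentum=0.9)
    optim = SGD(model.parameters(), lr=0.1, momentum=0.9)
    print(optim.settings)                 # {'lr': 0.1, 'momentum': 0.9}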
tags/v0.2.0^2 · FengZiYjun · commit fb5215ae73
3 changed files with 99 additions and 22 deletions:
  1. fastNLP/core/optimizer.py (+54 / -4)
  2. fastNLP/core/trainer.py (+12 / -8)
  3. test/core/test_optimizer.py (+33 / -10)

fastNLP/core/optimizer.py (+54 / -4)

@@ -3,14 +3,41 @@ import torch

 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
-        if model_params is not None and not isinstance(model_params, torch.Tensor):
-            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        if model_params is not None and not hasattr(model_params, "__next__"):
+            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs


 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, momentum = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            # SGD()
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                # SGD(0.001)
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                # SGD(model.parameters()) args[0] is a generator
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "momentum") for key in kwargs):
+                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            momentum = kwargs.get("momentum", 0.9)
+        else:
+            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

     def construct_from_pytorch(self, model_params):
@@ -20,7 +47,30 @@ class SGD(Optimizer):


 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, weight_decay = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "weight_decay") for key in kwargs):
+                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            weight_decay = kwargs.get("weight_decay", 0.9)
+        else:
+            raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
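The body of construct_from_pytorch is not shown in this diff; as a hedged sketch of the deferred-construction idea, it presumably forwards the stored settings to torch.optim once the real parameters are available. The class and variable names below are illustrative stand-ins, not fastNLP's actual code:

    import torch

    class DeferredSGD:
        # Illustrative stand-in for the pattern above, assuming settings are
        # forwarded to torch.optim when the trainer supplies the parameters.
        def __init__(self, model_params=None, **settings):
            self.model_params = model_params
            self.settings = settings                  # e.g. {"lr": 0.1, "momentum": 0.9}

        def construct_from_pytorch(self, model_params):
            # Build the real optimizer once the parameters are known.
            return torch.optim.SGD(model_params, **self.settings)

    opt = DeferredSGD(lr=0.1, momentum=0.9)
    torch_opt = opt.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
    print(type(torch_opt).__name__)                   # SGD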


fastNLP/core/trainer.py (+12 / -8)

@@ -56,7 +56,10 @@ class Trainer(object):
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
         self.increase_better = False if metric_key[0] == "-" else True
-        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        if metric_key is not None:
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None

         # prepare loss
         losser = _prepare_losser(losser)
@@ -144,12 +147,13 @@ class Trainer(object):
         del self._summary_writer

     def _train_epoch(self, data_iterator, model, epoch, start):
-        """Training process in one epoch.
+        """

-        kwargs should contain:
-            - n_print: int, print training information every n steps.
-            - start: time.time(), the starting time of this step.
-            - epoch: int,
+        :param data_iterator:
+        :param model:
+        :param epoch:
+        :param start:
+        :return:
         """
         for batch_x, batch_y in data_iterator:
             # TODO Possible problem here: if the user changes the device of prediction inside the model, this will break.
@@ -188,7 +192,7 @@ class Trainer(object):
         """Train mode or Test mode. This is for PyTorch currently.

         :param model: a PyTorch model
-        :param is_test: bool, whether in test mode or not.
+        :param bool is_test: whether in test mode or not.

         """
         if is_test:
@@ -263,7 +267,7 @@ class Trainer(object):
         else:
             # metric_key is set
             if self.metric_key not in metric_dict:
-                raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
+                raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
             indicator_val = metric_dict[self.metric_key]

         is_better = True
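As context for the two Trainer hunks above, a "+" or "-" prefix on metric_key selects the comparison direction and is then stripped before the key is looked up in metric_dict. A minimal standalone sketch of that convention (the function name and the full None guard are mine, not part of the commit):

    def parse_metric_key(metric_key):
        # "-loss" -> ("loss", smaller is better); "acc" or "+acc" -> ("acc", larger is better)
        if metric_key is None:
            return None, True          # no explicit key; the Trainer handles this case separately
        increase_better = metric_key[0] != "-"
        key = metric_key[1:] if metric_key[0] in ("+", "-") else metric_key
        return key, increase_better

    print(parse_metric_key("-loss"))   # ('loss', False)
    print(parse_metric_key("+acc"))    # ('acc', True)
    print(parse_metric_key(None))      # (None, True)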


test/core/test_optimizer.py (+33 / -10)

@@ -2,20 +2,43 @@ import unittest

 import torch

-from fastNLP.core.optimizer import SGD
+from fastNLP.core.optimizer import SGD, Adam


 class TestOptim(unittest.TestCase):
-    def test_case(self):
-        optim = SGD(torch.LongTensor(10))
-        print(optim.__dict__)
+    def test_SGD(self):
+        optim = SGD(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("momentum" in optim.__dict__["settings"])

-        optim_2 = SGD(lr=0.001)
-        print(optim_2.__dict__)
+        optim = SGD(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-        optim_2 = SGD(lr=0.002, momentum=0.989)
-        print(optim_2.__dict__)
+        optim = SGD(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-    def test_case_2(self):
+        optim = SGD(lr=0.002, momentum=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)
+
+        with self.assertRaises(RuntimeError):
+            _ = SGD("???")
         with self.assertRaises(RuntimeError):
-            _ = SGD(0.001)
+            _ = SGD(0.001, lr=0.002)
+        with self.assertRaises(RuntimeError):
+            _ = SGD(lr=0.009, shit=9000)
+
+    def test_Adam(self):
+        optim = Adam(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("weight_decay" in optim.__dict__["settings"])
+
+        optim = Adam(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)
