Browse Source

conflict in trainer solved

tags/v0.2.0^2
yh 6 years ago
parent
commit
1b961f136c
6 changed files with 266 additions and 127 deletions
  1. +75
    -25
      fastNLP/core/losses.py
  2. +18
    -51
      fastNLP/core/optimizer.py
  3. +83
    -45
      fastNLP/core/trainer.py
  4. +68
    -6
      test/core/test_loss.py
  5. +21
    -0
      test/core/test_optimizer.py
  6. +1
    -0
      test/core/test_trainer.py

+ 75
- 25
fastNLP/core/losses.py View File

@@ -1,23 +1,29 @@
import torch
import torch.nn.functional as F

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _get_arg_list
from fastNLP.core.utils import _map_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_function_or_method


class LossBase(object):
def __init__(self):
# key: name in target function; value: name in output function
self.param_map = {}
self._checked = False

def get_loss(self, *args, **kwargs):
raise NotImplementedError

def __call__(self, output_dict, target_dict):
def __call__(self, output_dict, target_dict, force_check=False):
"""
:param output_dict: A dict from forward function of the network.
:param target_dict: A dict from DataSet.batch_y.
:param force_check: Boolean. Force to check the mapping functions when it is running.
:return:
"""
args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
@@ -27,50 +33,94 @@ class LossBase(object):
)

param_map = self.param_map
for keys in args:
if keys not in param_map:
param_map.update({keys: keys})
for keys in defaults:
if keys not in param_map:
param_map.update({keys: keys})
if args is None:
raise RuntimeError(
f"There is not any param in function{get_func_signature(self.get_loss)}"
)
self._checked = self._checked and not force_check
if not self._checked:
for keys in args:
if keys not in param_map:
param_map.update({keys: keys})
if defaults is not None:
for keys in defaults:
if keys not in param_map:
param_map.update({keys: keys})
self.param_map = param_map
# param map: key= name in get_loss function, value= name in param dict
reversed_param_map = {val: key for key, val in param_map}
reversed_param_map = {val: key for key, val in param_map.items()}
# reversed param map: key= name in param dict, value= name in get_loss function

duplicated = []
missing = []
if not self._checked:
for keys, val in output_dict.items():
if keys in target_dict.keys():
duplicated.append(keys)

param_val_dict = {}
for keys, val in output_dict.items():
if keys not in target_dict.keys():
param_val_dict.update({keys: val})
else:
raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
param_val_dict.update({keys: val})
for keys, val in target_dict.items():
if keys not in output_dict.keys():
param_val_dict.update({keys: val})
else:
raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
param_val_dict.update({keys: val})

for keys in args:
if param_map[keys] not in param_val_dict.keys():
raise RuntimeError(f"missing param {keys} in function {get_func_signature(self.get_loss)}")
if not self._checked:
for keys in args:
if param_map[keys] not in param_val_dict.keys():
missing.append(keys)

if len(duplicated) > 0 or len(missing) > 0:
raise CheckError(
CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
func_signature=get_func_signature(self.get_loss)
)

self._checked = True

param_map_val = _map_args(reversed_param_map, **param_val_dict)
param_value = _build_args(**param_map_val)
param_value = _build_args(self.get_loss, **param_map_val)

loss = self.get_loss(**param_value)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss)))
raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size())))
raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")

return loss


class NewLoss(LossBase):
def __init__(self, func, key_map=None, **kwargs):
super(NewLoss).__init__()
if not callable(func):
raise RuntimeError("")
super(NewLoss, self).__init__()
_check_function_or_method(func)
if key_map is not None:
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self.param_map = key_map
if len(kwargs) > 0:
for key, val in kwargs.items():
self.param_map.update({key: val})

self.get_loss = func


class L1Loss(LossBase):
    """Loss object whose ``get_loss`` is ``torch.nn.functional.l1_loss``."""

    def __init__(self):
        # Run the LossBase initialiser first, then bind the functional
        # L1 loss as this instance's loss function.
        super().__init__()
        self.get_loss = F.l1_loss


class BCELoss(LossBase):
    """Loss object whose ``get_loss`` is ``torch.nn.functional.binary_cross_entropy``."""

    def __init__(self):
        # Run the LossBase initialiser first, then bind the functional
        # binary cross-entropy as this instance's loss function.
        super().__init__()
        self.get_loss = F.binary_cross_entropy


class NLLLoss(LossBase):
    """Loss object whose ``get_loss`` is ``torch.nn.functional.nll_loss``."""

    def __init__(self):
        # Run the LossBase initialiser first, then bind the functional
        # negative log-likelihood loss as this instance's loss function.
        super().__init__()
        self.get_loss = F.nll_loss


class LossInForward(LossBase):


+ 18
- 51
fastNLP/core/optimizer.py View File

@@ -2,61 +2,28 @@ import torch


class Optimizer(object):
"""Wrapper of optimizer from framework
def __init__(self, model_params, **kwargs):
if model_params is not None and not isinstance(model_params, torch.Tensor):
raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
self.model_params = model_params
self.settings = kwargs

1. Adam: lr (float), weight_decay (float)
2. AdaGrad
3. RMSProp
4. SGD: lr (float), momentum (float)

"""
class SGD(Optimizer):
def __init__(self, model_params=None, lr=0.001, momentum=0.9):
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

def __init__(self, optimizer_name, **kwargs):
"""
:param optimizer_name: str, the name of the optimizer
:param kwargs: the arguments

"""
self.optim_name = optimizer_name
self.kwargs = kwargs

@property
def name(self):
"""The name of the optimizer.

:return: str
"""
return self.optim_name
def construct_from_pytorch(self, model_params):
if self.model_params is None:
self.model_params = model_params
return torch.optim.SGD(self.model_params, **self.settings)

@property
def params(self):
"""The arguments used to create the optimizer.

:return: dict of (str, *)
"""
return self.kwargs
class Adam(Optimizer):
def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

def construct_from_pytorch(self, model_params):
"""Construct a optimizer from framework over given model parameters."""

if self.optim_name in ["SGD", "sgd"]:
if "lr" in self.kwargs:
if "momentum" not in self.kwargs:
self.kwargs["momentum"] = 0
optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
else:
raise ValueError("requires learning rate for SGD optimizer")

elif self.optim_name in ["adam", "Adam"]:
if "lr" in self.kwargs:
if "weight_decay" not in self.kwargs:
self.kwargs["weight_decay"] = 0
optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
weight_decay=self.kwargs["weight_decay"])
else:
raise ValueError("requires learning rate for Adam optimizer")

else:
raise NotImplementedError

return optimizer
if self.model_params is None:
self.model_params = model_params
return torch.optim.Adam(self.model_params, **self.settings)

+ 83
- 45
fastNLP/core/trainer.py View File

@@ -1,20 +1,22 @@
import itertools
import os
import time
import warnings
from collections import defaultdict
from datetime import datetime
from datetime import timedelta

import torch
from torch import nn
from tensorboardX import SummaryWriter
from torch import nn

from fastNLP.core.batch import Batch
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _move_dict_value_to_device
@@ -30,9 +32,12 @@ class Trainer(object):
"""Main Training Loop

"""
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,

def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None,
**kwargs):
super(Trainer, self).__init__()

@@ -49,6 +54,13 @@ class Trainer(object):

# prepare evaluate
metrics = _prepare_metrics(metrics)

# parse metric_key
# A leading "-" means a smaller value is better; a leading "+" or no prefix means a larger value is better.
# increase_better is therefore True unless metric_key starts with "-".
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key

# prepare loss
losser = _prepare_losser(losser)

@@ -67,12 +79,10 @@ class Trainer(object):
self.save_path = save_path
self.print_every = int(print_every)
self.validate_every = int(validate_every)
self._best_accuracy = 0
self.best_metric_indicator = None

self._model_device = model.parameters().__next__().device

# TODO self._best_accuracy不能表现出当前的metric多种的情况

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
else:
@@ -102,7 +112,7 @@ class Trainer(object):
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()

self.mode(self.model, is_test=False)
self._mode(self.model, is_test=False)

start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
@@ -112,7 +122,9 @@ class Trainer(object):
def __getattr__(self, item):
def pass_func(*args, **kwargs):
pass

return pass_func

self._summary_writer = psudoSW()
else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
@@ -121,19 +133,20 @@ class Trainer(object):
epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
as_numpy=False)

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
self._train_epoch(data_iterator, self.model, epoch, start)

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation()
self._do_validation()
epoch += 1
finally:
self._summary_writer.close()
del self._summary_writer

def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
def _train_epoch(self, data_iterator, model, epoch, start):
"""Training process in one epoch.

kwargs should contain:
@@ -144,10 +157,10 @@ class Trainer(object):
for batch_x, batch_y in data_iterator:
# TODO: this may cause problems if the user changes the device of the prediction inside the model.
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(model, batch_x)
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
prediction = self._data_forward(model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
@@ -162,18 +175,19 @@ class Trainer(object):
print(print_output)

if self.validate_every > 0 and self.step % self.validate_every == 0:
self.do_validation()
self._do_validation()

self.step += 1

def do_validation(self):
def _do_validation(self):
res = self.tester.test()
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_path is not None and self.best_eval_result(res):
self.save_model(self.model, 'best_model_' + self.start_time)
if self.save_path is not None and self._better_eval_result(res):
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

def mode(self, model, is_test=False):
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
@@ -185,20 +199,20 @@ class Trainer(object):
else:
model.train()

def update(self):
def _update(self):
"""Perform weight update on a model.

"""
self.optimizer.step()

def data_forward(self, network, x):
def _data_forward(self, network, x):
x = _build_args(network.forward, **x)
y = network(**x)
if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y

def grad_backward(self, loss):
def _grad_backward(self, loss):
"""Compute gradient with link rules.

:param loss: a scalar where back-prop starts
@@ -208,7 +222,7 @@ class Trainer(object):
self.model.zero_grad()
loss.backward()

def get_loss(self, predict, truth):
def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.

:param predict: prediction dict, produced by model.forward
@@ -217,34 +231,59 @@ class Trainer(object):
"""
return self.losser(predict, truth)

def save_model(self, model, model_name, only_param=False):
def _save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:
torch.save(model, model_name)

def best_eval_result(self, metrics):
def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.

:return: bool, True means current results on dev set is the best.
:return bool value: True means current results on dev set is the best.
"""
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
accuracy = metrics[self.eval_sort_key]
else:
accuracy = metrics

if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
else:
return False
metrics_name = self.metrics[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]

if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and self.metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if self.metric_key not in metric_dict:
raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
indicator_val = metric_dict[self.metric_key]

is_better = True
if self.best_metric_indicator is None:
# first-time validation
self.best_metric_indicator = indicator_val
else:
if self.increase_better is True:
if indicator_val > self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
else:
if indicator_val < self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
return is_better


DEFAULT_CHECK_BATCH_SIZE = 2
@@ -285,18 +324,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
f"should be torch.size([])")
loss.backward()
except CheckError as e:
_check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
pre_func_signature = get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level)
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break

if dev_data is not None:
tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1)
tester.test()





+ 68
- 6
test/core/test_loss.py View File

@@ -2,6 +2,7 @@ import math
import unittest

import torch as tc
import torch.nn.functional as F

import fastNLP.core.losses as loss

@@ -13,7 +14,11 @@ class TestLoss(unittest.TestCase):

print (".----------------------------------")

loss_func = loss.Loss("nll")
# loss_func = loss.Loss("nll")
print(callable(tc.nn.NLLLoss))
loss_func = loss.NewLoss(F.nll_loss)

nll_loss = loss.NLLLoss()

#pdb.set_trace()

@@ -35,16 +40,18 @@ class TestLoss(unittest.TestCase):


y = tc.log(y)
los = loss_func(y , gy)
los = loss_func({'input': y}, {'target': gy})
losses = nll_loss({'input': y}, {'target': gy})

r = -math.log(.3) - math.log(.3) - math.log(.1)
r /= 3
print ("loss = %f" % (los))
print ("r = %f" % (r))
print ("nll_loss = %f" % (losses))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
def _test_case_2(self):
# Verify the correctness of squash()
print ("----------------------------------")

@@ -74,7 +81,8 @@ class TestLoss(unittest.TestCase):
#pdb.set_trace()

y = tc.log(y)
los = loss_func(y , gy)
#los = loss_func({'input': y}, {'target': gy})
los = loss_func(y, gy)
print ("loss = %f" % (los))

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
@@ -89,7 +97,8 @@ class TestLoss(unittest.TestCase):

log = math.log

loss_func = loss.Loss("nll")
#loss_func = loss.Loss("nll")
loss_func = loss.NLLLoss()

#pdb.set_trace()

@@ -117,7 +126,7 @@ class TestLoss(unittest.TestCase):

yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data
los = loss_func(yy , gyy)
los = loss_func({'input': yy}, {'target': gyy})
print ("loss = %f" % (los))


@@ -303,5 +312,58 @@ class TestLoss(unittest.TestCase):
print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_8(self):
    """Run NewLoss over plain functions with several argument-mapping styles.

    Covers an explicit key map, a partial key map, default name matching,
    a default argument, ``**kwargs`` and a keyword remapping.
    """
    import torch
    from fastNLP.core.losses import LossBase, NewLoss

    def func(a, b):
        import torch.nn.functional as F
        return F.cross_entropy(a, b)

    def func2(a, truth):
        return func(a, truth)

    def func3(predict, truth):
        return func(predict, truth)

    def func4(a, b, c=2):
        return (a + b) * c

    def func6(a, b, **kwargs):
        c = kwargs['c']
        return (a + b) * c

    # Same cross-entropy, reached through three different name mappings.
    fully_mapped = NewLoss(func, {'a': 'predict', 'b': 'truth'})
    predict = torch.randn(5, 3)
    truth = torch.LongTensor([1, 0, 1, 2, 1])
    loss1 = fully_mapped({'predict': predict}, {'truth': truth})

    partly_mapped = NewLoss(func2, {'a': 'predict'})
    loss2 = partly_mapped({'predict': predict}, {'truth': truth})

    name_matched = NewLoss(func3)
    loss3 = name_matched({'predict': predict}, {'truth': truth})

    print(loss1, loss2, loss3)
    assert loss1 == loss2
    assert loss1 == loss3

    # NOTE(review): func4 and func6 return plain numbers, not tensors --
    # confirm that calling a loss object accepts non-tensor results.

    # Default argument c=2 is used when the dicts do not supply c.
    with_default = NewLoss(func4)
    loss4 = with_default({'a': 1, 'b': 3}, {})
    print(loss4)
    assert loss4 == (1 + 3) * 2

    # A supplied c replaces the default.
    with_override = NewLoss(func4)
    loss5 = with_override({'a': 1, 'b': 3}, {'c': 4})
    print(loss5)
    assert loss5 == (1 + 3) * 4

    # c is delivered through **kwargs.
    via_kwargs = NewLoss(func6)
    loss6 = via_kwargs({'a': 1, 'b': 3}, {'c': 4})
    print(loss6)
    assert loss6 == (1 + 3) * 4

    # c is remapped to the dict key 'cc' by a constructor keyword.
    via_remap = NewLoss(func6, c='cc')
    loss7 = via_remap({'a': 1, 'b': 3}, {'cc': 4})
    print(loss7)
    assert loss7 == (1 + 3) * 4


if __name__ == "__main__":
unittest.main()

+ 21
- 0
test/core/test_optimizer.py View File

@@ -0,0 +1,21 @@
import unittest

import torch

from fastNLP.core.optimizer import SGD


class TestOptim(unittest.TestCase):
    """Checks on how the ``SGD`` wrapper stores its constructor arguments."""

    def test_case(self):
        """SGD keeps the given parameters and records lr/momentum settings."""
        # Assert on the stored state instead of printing it, so the test
        # fails when the wrapper stops recording its arguments.
        params = torch.LongTensor(10)
        optim = SGD(params)
        self.assertIs(optim.model_params, params)
        self.assertEqual(optim.settings, {"lr": 0.001, "momentum": 0.9})

        optim_2 = SGD(lr=0.001)
        self.assertIsNone(optim_2.model_params)
        self.assertEqual(optim_2.settings, {"lr": 0.001, "momentum": 0.9})

        optim_3 = SGD(lr=0.002, momentum=0.989)
        self.assertIsNone(optim_3.model_params)
        self.assertEqual(optim_3.settings, {"lr": 0.002, "momentum": 0.989})

    def test_case_2(self):
        """A non-tensor first positional argument raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            SGD(0.001)

+ 1
- 0
test/core/test_trainer.py View File

@@ -4,3 +4,4 @@ import unittest
class TestTrainer(unittest.TestCase):
def test_case_1(self):
pass


Loading…
Cancel
Save