Browse Source

update LossBase class

tags/v0.2.0^2
xuyige 7 years ago
parent
commit
6d36190be4
2 changed files with 143 additions and 31 deletions
  1. +75
    -25
      fastNLP/core/losses.py
  2. +68
    -6
      test/core/test_loss.py

+ 75
- 25
fastNLP/core/losses.py View File

@@ -1,23 +1,29 @@
import torch
import torch.nn.functional as F

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _get_arg_list
from fastNLP.core.utils import _map_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_function_or_method


class LossBase(object):
def __init__(self):
# key: name in target function; value: name in output function
self.param_map = {}
self._checked = False

def get_loss(self, *args, **kwargs):
raise NotImplementedError

def __call__(self, output_dict, target_dict):
def __call__(self, output_dict, target_dict, force_check=False):
"""
:param output_dict: A dict from forward function of the network.
:param target_dict: A dict from DataSet.batch_y.
:param force_check: Boolean. Force to check the mapping functions when it is running.
:return:
"""
args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
@@ -27,50 +33,94 @@ class LossBase(object):
)

param_map = self.param_map
for keys in args:
if keys not in param_map:
param_map.update({keys: keys})
for keys in defaults:
if keys not in param_map:
param_map.update({keys: keys})
if args is None:
raise RuntimeError(
f"There is not any param in function{get_func_signature(self.get_loss)}"
)
self._checked = self._checked and not force_check
if not self._checked:
for keys in args:
if keys not in param_map:
param_map.update({keys: keys})
if defaults is not None:
for keys in defaults:
if keys not in param_map:
param_map.update({keys: keys})
self.param_map = param_map
# param map: key= name in get_loss function, value= name in param dict
reversed_param_map = {val: key for key, val in param_map}
reversed_param_map = {val: key for key, val in param_map.items()}
# reversed param map: key= name in param dict, value= name in get_loss function

duplicated = []
missing = []
if not self._checked:
for keys, val in output_dict.items():
if keys in target_dict.keys():
duplicated.append(keys)

param_val_dict = {}
for keys, val in output_dict.items():
if keys not in target_dict.keys():
param_val_dict.update({keys: val})
else:
raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
param_val_dict.update({keys: val})
for keys, val in target_dict.items():
if keys not in output_dict.keys():
param_val_dict.update({keys: val})
else:
raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
param_val_dict.update({keys: val})

for keys in args:
if param_map[keys] not in param_val_dict.keys():
raise RuntimeError(f"missing param {keys} in function {get_func_signature(self.get_loss)}")
if not self._checked:
for keys in args:
if param_map[keys] not in param_val_dict.keys():
missing.append(keys)

if len(duplicated) > 0 or len(missing) > 0:
raise CheckError(
CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
func_signature=get_func_signature(self.get_loss)
)

self._checked = True

param_map_val = _map_args(reversed_param_map, **param_val_dict)
param_value = _build_args(**param_map_val)
param_value = _build_args(self.get_loss, **param_map_val)

loss = self.get_loss(**param_value)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss)))
raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size())))
raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")

return loss


class NewLoss(LossBase):
def __init__(self, func, key_map=None, **kwargs):
super(NewLoss).__init__()
if not callable(func):
raise RuntimeError("")
super(NewLoss, self).__init__()
_check_function_or_method(func)
if key_map is not None:
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self.param_map = key_map
if len(kwargs) > 0:
for key, val in kwargs.items():
self.param_map.update({key: val})

self.get_loss = func


class L1Loss(LossBase):
def __init__(self):
super(L1Loss, self).__init__()
self.get_loss = F.l1_loss


class BCELoss(LossBase):
def __init__(self):
super(BCELoss, self).__init__()
self.get_loss = F.binary_cross_entropy


class NLLLoss(LossBase):
def __init__(self):
super(NLLLoss, self).__init__()
self.get_loss = F.nll_loss


class LossInForward(LossBase):


+ 68
- 6
test/core/test_loss.py View File

@@ -2,6 +2,7 @@ import math
import unittest

import torch as tc
import torch.nn.functional as F

import fastNLP.core.losses as loss

@@ -13,7 +14,11 @@ class TestLoss(unittest.TestCase):

print (".----------------------------------")

loss_func = loss.Loss("nll")
# loss_func = loss.Loss("nll")
print(callable(tc.nn.NLLLoss))
loss_func = loss.NewLoss(F.nll_loss)

nll_loss = loss.NLLLoss()

#pdb.set_trace()

@@ -35,16 +40,18 @@ class TestLoss(unittest.TestCase):


y = tc.log(y)
los = loss_func(y , gy)
los = loss_func({'input': y}, {'target': gy})
losses = nll_loss({'input': y}, {'target': gy})

r = -math.log(.3) - math.log(.3) - math.log(.1)
r /= 3
print ("loss = %f" % (los))
print ("r = %f" % (r))
print ("nll_loss = %f" % (losses))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
def _test_case_2(self):
#验证squash()的正确性
print ("----------------------------------")

@@ -74,7 +81,8 @@ class TestLoss(unittest.TestCase):
#pdb.set_trace()

y = tc.log(y)
los = loss_func(y , gy)
#los = loss_func({'input': y}, {'target': gy})
los = loss_func(y, gy)
print ("loss = %f" % (los))

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
@@ -89,7 +97,8 @@ class TestLoss(unittest.TestCase):

log = math.log

loss_func = loss.Loss("nll")
#loss_func = loss.Loss("nll")
loss_func = loss.NLLLoss()

#pdb.set_trace()

@@ -117,7 +126,7 @@ class TestLoss(unittest.TestCase):

yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data
los = loss_func(yy , gyy)
los = loss_func({'input': yy}, {'target': gyy})
print ("loss = %f" % (los))


@@ -303,5 +312,58 @@ class TestLoss(unittest.TestCase):
print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_8(self):
def func(a, b):
import torch.nn.functional as F
return F.cross_entropy(a, b)

def func2(a, truth):
return func(a, truth)

def func3(predict, truth):
return func(predict, truth)

def func4(a, b, c=2):
return (a + b) * c

def func6(a, b, **kwargs):
c = kwargs['c']
return (a + b) * c

import torch
from fastNLP.core.losses import LossBase, NewLoss

get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'})
predict = torch.randn(5, 3)
truth = torch.LongTensor([1, 0, 1, 2, 1])
loss1 = get_loss({'predict': predict}, {'truth': truth})
get_loss_2 = NewLoss(func2, {'a': 'predict'})
loss2 = get_loss_2({'predict': predict}, {'truth': truth})
get_loss_3 = NewLoss(func3)
loss3 = get_loss_3({'predict': predict}, {'truth': truth})
print(loss1, loss2, loss3)
assert loss1 == loss2 and loss1 == loss3

get_loss_4 = NewLoss(func4)
loss4 = get_loss_4({'a': 1, 'b': 3}, {})
print(loss4)
assert loss4 == (1 + 3) * 2

get_loss_5 = NewLoss(func4)
loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
print(loss5)
assert loss5 == (1 + 3) * 4

get_loss_6 = NewLoss(func6)
loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
print(loss6)
assert loss6 == (1 + 3) * 4

get_loss_7 = NewLoss(func6, c='cc')
loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
print(loss7)
assert loss7 == (1 + 3) * 4


if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save