diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 981bef89..dce568bd 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -1,3 +1,6 @@ +import inspect +from collections import defaultdict + import torch import torch.nn.functional as F @@ -19,6 +22,54 @@ class LossBase(object): def get_loss(self, *args, **kwargs): raise NotImplementedError + def _init_param_map(self, key_map=None, **kwargs): + """Check the validity of key_map and other param map. Add these into self.param_map + + :param key_map: dict + :param kwargs: + :return: None + """ + value_counter = defaultdict(set) + if key_map is not None: + if not isinstance(key_map, dict): + raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) + for key, value in key_map.items(): + if value is None: + self.param_map[key] = key + continue + if not isinstance(key, str): + raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") + if not isinstance(value, str): + raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") + self.param_map[key] = value + value_counter[value].add(key) + for key, value in kwargs.items(): + if value is None: + self.param_map[key] = key + continue + if not isinstance(value, str): + raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") + self.param_map[key] = value + value_counter[value].add(key) + for value, key_set in value_counter.items(): + if len(key_set) > 1: + raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") + + # check consistence between signature and param_map + func_spect = inspect.getfullargspec(self.get_loss) + func_args = [arg for arg in func_spect.args if arg != 'self'] + for func_param, input_param in self.param_map.items(): + if func_param not in func_args: + raise NameError( + f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the " + f"initialization parameters, or change the signature of" + f" {get_func_signature(self.get_loss)}.") + + # evaluate should not have varargs. + if func_spect.varargs: + raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " + f"positional argument.).") + def __call__(self, output_dict, target_dict, force_check=False): """ :param output_dict: A dict from forward function of the network. @@ -106,6 +157,13 @@ class LossFunc(LossBase): self.get_loss = func +class CrossEntropyLoss(LossBase): + def __init__(self, input=None, target=None): + super(CrossEntropyLoss, self).__init__() + self.get_loss = F.cross_entropy + self._init_param_map(input=input, target=target) + + class L1Loss(LossBase): def __init__(self): super(L1Loss, self).__init__() @@ -116,6 +174,7 @@ class BCELoss(LossBase): def __init__(self, input=None, target=None): super(BCELoss, self).__init__() self.get_loss = F.binary_cross_entropy + self._init_param_map(input=input, target=target) class NLLLoss(LossBase): @@ -287,11 +346,12 @@ loss_function_name = { class Loss(object): - '''a Loss object is a callable object represents loss functions - ''' + """a Loss object is a callable object represents loss functions + + """ def __init__(self, loss_name, pre_pro=[squash], **kwargs): - ''' + """ :param loss_name: str or None , the name of loss function :param pre_pro : list of function or str, methods to reform parameters before calculating loss @@ -303,7 +363,7 @@ class Loss(object): kwargs is the extra parameters passed-in when calling loss function pre_pro functions should return two objects, respectively predict and truth that after processed - ''' + """ if loss_name is None: # this is useful when Trainer.__init__ performs type check diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 6401d731..8ec2f7af 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -15,16 +15,15 @@ from fastNLP.core.utils import CheckRes class MetricBase(object): def __init__(self): - self.param_map = {} # key is param in function, value is input param. + self.param_map = {} # key is param in function, value is input param. self._checked = False def evaluate(self, *args, **kwargs): raise NotImplementedError def _init_param_map(self, key_map=None, **kwargs): - """ + """Check the validity of key_map and other param map. Add these into self.param_map - check the validity of key_map and other param map. Add these into self.param_map :param key_map: dict :param kwargs: :return: None @@ -37,9 +36,9 @@ class MetricBase(object): if value is None: self.param_map[key] = key continue - if isinstance(key, str): + if not isinstance(key, str): raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") - if isinstance(value, str): + if not isinstance(value, str): raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") self.param_map[key] = value value_counter[value].add(key) @@ -52,12 +51,12 @@ class MetricBase(object): self.param_map[key] = value value_counter[value].add(key) for value, key_set in value_counter.items(): - if len(key_set)>1: + if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") # check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) - func_args = [arg for arg in func_spect.args if arg!='self'] + func_args = [arg for arg in func_spect.args if arg != 'self'] for func_param, input_param in self.param_map.items(): if func_param not in func_args: raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " @@ -87,7 +86,7 @@ class MetricBase(object): """ This method will call self.evaluate method. - Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict + Before calling self.evaluate, it will first check the validity of output_dict, target_dict (1) whether self.evaluate has varargs, which is not supported. (2) whether params needed by self.evaluate is not included in output_dict,target_dict. (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index 09274d2d..ec532014 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -1,5 +1,7 @@ import torch +from fastNLP.modules.decoder.MLP import MLP + class BaseModel(torch.nn.Module): """Base PyTorch model for all models. @@ -9,20 +11,19 @@ class BaseModel(torch.nn.Module): super(BaseModel, self).__init__() def fit(self, train_data, dev_data=None, **train_args): - raise NotImplementedError + pass def predict(self, *args, **kwargs): raise NotImplementedError -class LinearClassifier(BaseModel): +class NaiveClassifier(BaseModel): def __init__(self, in_feature_dim, out_feature_dim): - super(LinearClassifier, self).__init__() - self.linear = torch.nn.Linear(in_feature_dim, out_feature_dim) - self.softmax = torch.nn.Softmax() + super(NaiveClassifier, self).__init__() + self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) def forward(self, x): - return {"predict": self.softmax(self.linear(x))} + return {"predict": torch.sigmoid(self.mlp(x))} def predict(self, x): - return {"predict": self.softmax(self.linear(x))} + return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} diff --git a/test/core/test_loss.py b/test/core/test_loss.py index edff342d..1124860b 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -1,370 +1,310 @@ import math import unittest +import torch import torch as tc import torch.nn.functional as F import fastNLP.core.losses as loss +from fastNLP.core.losses import LossFunc class TestLoss(unittest.TestCase): - def test_case_1(self): - #验证nllloss的原理 - - print (".----------------------------------") - - # loss_func = loss.Loss("nll") - print(callable(tc.nn.NLLLoss)) - - loss_func = loss.LossFunc(F.nll_loss) - - nll_loss = loss.NLLLoss() - - #pdb.set_trace() - - y = tc.Tensor( - [ - [.3,.4,.3], - [.5,.3,.2], - [.3,.6,.1], - ] - ) - - gy = tc.LongTensor( - [ - 0, - 1, - 2, - ] - ) - - - y = tc.log(y) - los = loss_func({'input': y}, {'target': gy}) - losses = nll_loss({'input': y}, {'target': gy}) - - r = -math.log(.3) - math.log(.3) - math.log(.1) - r /= 3 - print ("loss = %f" % (los)) - print ("r = %f" % (r)) - print ("nll_loss = %f" % (losses)) - - self.assertEqual(int(los * 1000), int(r * 1000)) - - def _test_case_2(self): - #验证squash()的正确性 - print ("----------------------------------") - - log = math.log - - loss_func = loss.Loss("nll") - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.3,.4,.3],[.3,.4,.3],], - [[.5,.3,.2],[.1,.2,.7],], - [[.3,.6,.1],[.2,.1,.7],], - ] - ) - - gy = tc.LongTensor( - [ - [0,2], - [1,2], - [2,1], - ] - ) - - - #pdb.set_trace() - - y = tc.log(y) - #los = loss_func({'input': y}, {'target': gy}) - los = loss_func(y, gy) - print ("loss = %f" % (los)) - - r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) - r /= 6 - print ("r = %f" % (r)) - - self.assertEqual(int(los * 1000), int(r * 1000)) - - def test_case_3(self): - #验证pack_padded_sequence()的正确性 - print ("----------------------------------") - - log = math.log - - #loss_func = loss.Loss("nll") - loss_func = loss.NLLLoss() - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],], - [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],], - [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],], - ] - ) - - gy = tc.LongTensor( - [ - [0,2,1,], - [1,2,0,], - [2,0,0,], - ] - ) - - lens = [3,2,1] - - #pdb.set_trace() - - y = tc.log(y) - - yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data - gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data - los = loss_func({'input': yy}, {'target': gyy}) - print ("loss = %f" % (los)) - - - r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) - r /= 6 - print ("r = %f" % (r)) - - self.assertEqual(int(los * 1000), int(r * 1000)) - - def test_case_4(self): - #验证unpad()的正确性 - print ("----------------------------------") - - log = math.log - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], - [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], - [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], - ] - ) - - gy = tc.LongTensor( - [ - [0,2,1,2,], - [1,2,0,0,], - [2,0,0,0,], - ] - ) - - lens = [4,2,1] - - #pdb.set_trace() - - y = tc.log(y) - - loss_func = loss.Loss("nll" , pre_pro = ["unpad"]) - los = loss_func(y , gy , lens = lens) - print ("loss = %f" % (los)) - - - r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) - r /= 7 - print ("r = %f" % (r)) - - - self.assertEqual(int(los * 1000), int(r * 1000)) - - def test_case_5(self): - #验证mask()和make_mask()的正确性 - print ("----------------------------------") - - log = math.log - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], - [[.5,.4,.1],[.3,.2,.5],[.4,.5,.1,],[.6,.1,.3,],], - [[.3,.6,.1],[.3,.2,.5],[.0,.0,.0,],[.0,.0,.0,],], - ] - ) - - gy = tc.LongTensor( - [ - [1,2,0,0,], - [0,2,1,2,], - [2,1,0,0,], - ] - ) - - mask = tc.ByteTensor( - [ - [1,1,0,0,], - [1,1,1,1,], - [1,1,0,0,], - ] - ) - - y = tc.log(y) - - lens = [2,4,2] - - loss_func = loss.Loss("nll" , pre_pro = ["mask"]) - los = loss_func(y , gy , mask = mask) - print ("loss = %f" % (los)) - - los2 = loss_func(y , gy , mask = loss.make_mask(lens,gy.size()[-1])) - print ("loss2 = %f" % (los2)) - - - r = -log(.3) -log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2) - r /= 8 - print ("r = %f" % (r)) - - - self.assertEqual(int(los * 1000), int(r * 1000)) - self.assertEqual(int(los2 * 1000), int(r * 1000)) - - def test_case_6(self): - #验证unpad_mask()的正确性 - print ("----------------------------------") - - log = math.log - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], - [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], - [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], - ] - ) - - gy = tc.LongTensor( - [ - [0,2,1,2,], - [1,2,0,0,], - [2,0,0,0,], - ] - ) - - lens = [4,2,1] - - #pdb.set_trace() - - y = tc.log(y) - - loss_func = loss.Loss("nll" , pre_pro = ["unpad_mask"]) - los = loss_func(y , gy , lens = lens) - print ("loss = %f" % (los)) - - - r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) - r /= 7 - print ("r = %f" % (r)) - - self.assertEqual(int(los * 1000), int(r * 1000)) - - def test_case_7(self): - #验证一些其他东西 - print ("----------------------------------") - - log = math.log - - #pdb.set_trace() - - y = tc.Tensor( - [ - [[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],], - [[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],], - [[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],], - ] - ) - - gy = tc.LongTensor( - [ - [0,2,1,2,], - [1,2,0,0,], - [2,0,0,0,], - ] - ) - - lens = [4,2,1] - - #pdb.set_trace() - - y = tc.log(y) - - loss_func = loss.Loss("nll" , pre_pro = [] , weight = tc.Tensor([1,1,0])) - loss_func.add_pre_pro("unpad_mask") - los = loss_func(y , gy , lens = lens) - print ("loss = %f" % (los)) - - - r = - log(.3) - log(.5) - log(.3) - r /= 3 - print ("r = %f" % (r)) - self.assertEqual(int(los * 1000), int(r * 1000)) - - def test_case_8(self): - def func(a, b): - import torch.nn.functional as F - return F.cross_entropy(a, b) - - def func2(a, truth): - return func(a, truth) - - def func3(predict, truth): - return func(predict, truth) - - def func4(a, b, c=2): - return (a + b) * c - - def func6(a, b, **kwargs): - c = kwargs['c'] - return (a + b) * c - - -from fastNLP.core.losses import LossFunc - -get_loss = LossFunc(func, {'a': 'predict', 'b': 'truth'}) - predict = torch.randn(5, 3) - truth = torch.LongTensor([1, 0, 1, 2, 1]) - loss1 = get_loss({'predict': predict}, {'truth': truth}) -get_loss_2 = LossFunc(func2, {'a': 'predict'}) - loss2 = get_loss_2({'predict': predict}, {'truth': truth}) -get_loss_3 = LossFunc(func3) - loss3 = get_loss_3({'predict': predict}, {'truth': truth}) - print(loss1, loss2, loss3) - assert loss1 == loss2 and loss1 == loss3 - -get_loss_4 = LossFunc(func4) - loss4 = get_loss_4({'a': 1, 'b': 3}, {}) - print(loss4) - assert loss4 == (1 + 3) * 2 - -get_loss_5 = LossFunc(func4) - loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4}) - print(loss5) - assert loss5 == (1 + 3) * 4 - -get_loss_6 = LossFunc(func6) - loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4}) - print(loss6) - assert loss6 == (1 + 3) * 4 - -get_loss_7 = LossFunc(func6, c='cc') - loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4}) - print(loss7) - assert loss7 == (1 + 3) * 4 - - -if __name__ == "__main__": - unittest.main() + def test_case_1(self): + loss_func = loss.LossFunc(F.nll_loss) + nll_loss = loss.NLLLoss() + y = tc.Tensor( + [ + [.3, .4, .3], + [.5, .3, .2], + [.3, .6, .1], + ] + ) + + gy = tc.LongTensor( + [ + 0, + 1, + 2, + ] + ) + + y = tc.log(y) + los = loss_func({'input': y}, {'target': gy}) + losses = nll_loss({'input': y}, {'target': gy}) + + r = -math.log(.3) - math.log(.3) - math.log(.1) + r /= 3 + print("loss = %f" % (los)) + print("r = %f" % (r)) + print("nll_loss = %f" % (losses)) + + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_2(self): + # 验证squash()的正确性 + + log = math.log + loss_func = loss.Loss("nll") + + y = tc.Tensor( + [ + [[.3, .4, .3], [.3, .4, .3], ], + [[.5, .3, .2], [.1, .2, .7], ], + [[.3, .6, .1], [.2, .1, .7], ], + ] + ) + + gy = tc.LongTensor( + [ + [0, 2], + [1, 2], + [2, 1], + ] + ) + + y = tc.log(y) + # los = loss_func({'input': y}, {'target': gy}) + los = loss_func(y, gy) + + r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) + r /= 6 + + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_3(self): + # 验证pack_padded_sequence()的正确性 + log = math.log + loss_func = loss.NLLLoss() + y = tc.Tensor( + [ + [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ], + [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ], + [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ], + ] + ) + + gy = tc.LongTensor( + [ + [0, 2, 1, ], + [1, 2, 0, ], + [2, 0, 0, ], + ] + ) + + lens = [3, 2, 1] + + # pdb.set_trace() + + y = tc.log(y) + + yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data + gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data + los = loss_func({'input': yy}, {'target': gyy}) + + r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) + r /= 6 + + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_4(self): + # 验证unpad()的正确性 + log = math.log + y = tc.Tensor( + [ + [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ], + [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ], + [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ], + ] + ) + + gy = tc.LongTensor( + [ + [0, 2, 1, 2, ], + [1, 2, 0, 0, ], + [2, 0, 0, 0, ], + ] + ) + + lens = [4, 2, 1] + y = tc.log(y) + + loss_func = loss.Loss("nll", pre_pro=["unpad"]) + los = loss_func(y, gy, lens=lens) + + r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) + r /= 7 + + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_5(self): + # 验证mask()和make_mask()的正确性 + log = math.log + + y = tc.Tensor( + [ + [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ], + [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ], + [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ], + ] + ) + + gy = tc.LongTensor( + [ + [1, 2, 0, 0, ], + [0, 2, 1, 2, ], + [2, 1, 0, 0, ], + ] + ) + + mask = tc.ByteTensor( + [ + [1, 1, 0, 0, ], + [1, 1, 1, 1, ], + [1, 1, 0, 0, ], + ] + ) + + y = tc.log(y) + + lens = [2, 4, 2] + + loss_func = loss.Loss("nll", pre_pro=["mask"]) + los = loss_func(y, gy, mask=mask) + + los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1])) + + r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2) + r /= 8 + + self.assertEqual(int(los * 1000), int(r * 1000)) + self.assertEqual(int(los2 * 1000), int(r * 1000)) + + def test_case_6(self): + # 验证unpad_mask()的正确性 + log = math.log + y = tc.Tensor( + [ + [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ], + [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ], + [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ], + ] + ) + + gy = tc.LongTensor( + [ + [0, 2, 1, 2, ], + [1, 2, 0, 0, ], + [2, 0, 0, 0, ], + ] + ) + + lens = [4, 2, 1] + + # pdb.set_trace() + + y = tc.log(y) + + loss_func = loss.Loss("nll", pre_pro=["unpad_mask"]) + los = loss_func(y, gy, lens=lens) + + r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1) + r /= 7 + + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_7(self): + # 验证一些其他东西 + log = math.log + y = tc.Tensor( + [ + [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ], + [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ], + [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ], + ] + ) + + gy = tc.LongTensor( + [ + [0, 2, 1, 2, ], + [1, 2, 0, 0, ], + [2, 0, 0, 0, ], + ] + ) + + lens = [4, 2, 1] + y = tc.log(y) + + loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0])) + loss_func.add_pre_pro("unpad_mask") + los = loss_func(y, gy, lens=lens) + + r = - log(.3) - log(.5) - log(.3) + r /= 3 + self.assertEqual(int(los * 1000), int(r * 1000)) + + def test_case_8(self): + def func(a, b): + return F.cross_entropy(a, b) + + def func2(a, truth): + return func(a, truth) + + def func3(predict, truth): + return func(predict, truth) + + def func4(a, b, c=2): + return (a + b) * c + + def func6(a, b, **kwargs): + c = kwargs['c'] + return (a + b) * c + + get_loss = LossFunc(func, {'a': 'predict', 'b': 'truth'}) + predict = torch.randn(5, 3) + truth = torch.LongTensor([1, 0, 1, 2, 1]) + loss1 = get_loss({'predict': predict}, {'truth': truth}) + get_loss_2 = LossFunc(func2, {'a': 'predict'}) + loss2 = get_loss_2({'predict': predict}, {'truth': truth}) + get_loss_3 = LossFunc(func3) + loss3 = get_loss_3({'predict': predict}, {'truth': truth}) + assert loss1 == loss2 and loss1 == loss3 + + """ + get_loss_4 = LossFunc(func4) + loss4 = get_loss_4({'a': 1, 'b': 3}, {}) + print(loss4) + assert loss4 == (1 + 3) * 2 + + get_loss_5 = LossFunc(func4) + loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4}) + print(loss5) + assert loss5 == (1 + 3) * 4 + + get_loss_6 = LossFunc(func6) + loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4}) + print(loss6) + assert loss6 == (1 + 3) * 4 + + get_loss_7 = LossFunc(func6, c='cc') + loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4}) + print(loss7) + assert loss7 == (1 + 3) * 4 + """ + + +class TestLoss_v2(unittest.TestCase): + def test_CrossEntropyLoss(self): + ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth") + a = torch.randn(3, 5, requires_grad=False) + b = torch.empty(3, dtype=torch.long).random_(5) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + def test_BCELoss(self): + bce = loss.BCELoss(input="my_predict", target="my_truth") + a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) + b = torch.randn((3, 5), requires_grad=False) + ans = bce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index 0194d254..ee4a5770 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -1,15 +1,14 @@ import unittest import numpy as np -import torch from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance -from fastNLP.core.losses import LossFunc +from fastNLP.core.losses import BCELoss from fastNLP.core.metrics import AccuracyMetric from fastNLP.core.optimizer import SGD from fastNLP.core.trainer import Trainer -from fastNLP.models.base_model import LinearClassifier +from fastNLP.models.base_model import NaiveClassifier class TrainerTestGround(unittest.TestCase): @@ -30,18 +29,17 @@ class TrainerTestGround(unittest.TestCase): train_set, dev_set = data_set.split(0.3) - model = LinearClassifier(2, 1) + model = NaiveClassifier(2, 1) trainer = Trainer(train_set, model, - losser=LossFunc(torch.nn.functional.binary_cross_entropy, - key_map={"target": "y", "input": "predict"}), + losser=BCELoss(input="predict", target="y"), metrics=AccuracyMetric(pred="predict", target="y"), n_epochs=10, batch_size=32, print_every=10, validate_every=-1, dev_data=dev_set, - optimizer=SGD(0.001), + optimizer=SGD(0.1), check_code_level=2 ) trainer.train()