
update LossBase class

tags/v0.2.0^2
xuyige committed 6 years ago
commit d8a80ad6c6
2 changed files with 94 additions and 5 deletions
  1. fastNLP/core/losses.py (+61, -5)
  2. fastNLP/core/utils.py (+33, -0)

fastNLP/core/losses.py (+61, -5)

@@ -1,20 +1,76 @@
 import torch
 
+from fastNLP.core.utils import _get_arg_list
+from fastNLP.core.utils import _map_args
+from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _build_args
+
 
 class LossBase(object):
     def __init__(self):
         # key: name in target function; value: name in output function
         self.param_map = {}
 
     def get_loss(self, *args, **kwargs):
         raise NotImplementedError
 
-    def __call__(self, output_dict, predict_dict):
-        pass
+    def __call__(self, output_dict, target_dict):
+        """
+        :param output_dict: A dict from forward function of the network.
+        :param target_dict: A dict from DataSet.batch_y.
+        :return:
+        """
+        args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
+        if varargs is not None:
+            raise RuntimeError(
+                f"The function {get_func_signature(self.get_loss)} should not use Positional Argument."
+            )
+
+        param_map = self.param_map
+        for keys in args:
+            if keys not in param_map:
+                param_map.update({keys: keys})
+        for keys in defaults:
+            if keys not in param_map:
+                param_map.update({keys: keys})
+        # param map: key= name in get_loss function, value= name in param dict
+        reversed_param_map = {val: key for key, val in param_map}
+        # reversed param map: key= name in param dict, value= name in get_loss function
+
+        param_val_dict = {}
+        for keys, val in output_dict.items():
+            if keys not in target_dict.keys():
+                param_val_dict.update({keys: val})
+            else:
+                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+        for keys, val in target_dict.items():
+            if keys not in output_dict.keys():
+                param_val_dict.update({keys: val})
+            else:
+                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+
+        for keys in args:
+            if param_map[keys] not in param_val_dict.keys():
+                raise RuntimeError("missing param {} in function {}".format(keys, self.get_loss))
+
-class Loss(LossBase):
-    def __init__(self):
-        pass
+        param_map_val = _map_args(reversed_param_map, **param_val_dict)
+        param_value = _build_args(**param_map_val)
+
+        loss = self.get_loss(**param_value)
+
+        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
+            if not isinstance(loss, torch.Tensor):
+                raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss)))
+            raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size())))
+
+        return loss
 
+
+class NewLoss(LossBase):
+    def __init__(self, func, key_map=None, **kwargs):
+        super(NewLoss).__init__()
+        if not callable(func):
+            raise RuntimeError("")
 
 
 def squash(predict, truth, **kwargs):
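
The flow of the new __call__ can be illustrated with a small self-contained sketch. Nothing below comes from fastNLP itself: the loss function nll, the dict keys log_probs and label, and the sample tensors are invented for the example; the sketch only mirrors the merge / rename / scalar-check steps of the method above.

import torch
import torch.nn.functional as F

# Hypothetical loss in the get_loss style: its argument names are what
# the keys of param_map refer to.
def nll(pred, target):
    return F.nll_loss(pred, target)

# key = argument name of get_loss, value = key in the output/target dicts
param_map = {"pred": "log_probs", "target": "label"}

output_dict = {"log_probs": F.log_softmax(torch.randn(4, 3), dim=1)}  # from forward()
target_dict = {"label": torch.tensor([0, 2, 1, 0])}                   # from DataSet.batch_y

# 1. merge the two dicts, rejecting duplicated names (the conflict check above)
merged = {}
for d in (output_dict, target_dict):
    for name, value in d.items():
        if name in merged:
            raise RuntimeError("conflict in output dict and target dict with name {}".format(name))
        merged[name] = value

# 2. rename the merged entries back to get_loss argument names and call it
kwargs = {arg: merged[key] for arg, key in param_map.items()}
loss = nll(**kwargs)

# 3. the result must be a scalar (0-dim) torch.Tensor, as __call__ enforces
assert isinstance(loss, torch.Tensor) and loss.dim() == 0
print(loss.item())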


fastNLP/core/utils.py (+33, -0)

@@ -64,6 +64,39 @@ def _build_args(func, **kwargs):
     return output
 
 
+def _map_args(maps: dict, **kwargs):
+    # maps: key=old name, value= new name
+    output = {}
+    for name, val in kwargs.items():
+        if name in maps:
+            assert isinstance(maps[name], str)
+            output.update({maps[name]: val})
+        else:
+            output.update({name: val})
+    for keys in maps.keys():
+        if keys not in output.keys():
+            # TODO: add UNUSED warning.
+            pass
+    return output
+
+
+def _get_arg_list(func):
+    assert callable(func)
+    spect = inspect.getfullargspec(func)
+    if spect.defaults is not None:
+        args = spect.args[: -len(spect.defaults)]
+        defaults = spect.args[-len(spect.defaults):]
+        defaults_val = spect.defaults
+    else:
+        args = spect.args
+        defaults = None
+        defaults_val = None
+    varargs = spect.varargs
+    kwargs = spect.varkw
+    return args, defaults, defaults_val, varargs, kwargs
+
+
+
 # check args
 def _check_arg_dict_list(func, args):
     if isinstance(args, dict):
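
The two helpers can be checked in isolation with plain inspect. The toy function f and the maps/values dicts below are invented for the example; the last lines only mirror the renaming that _map_args performs, they are not its actual code.

import inspect

def f(a, b, c=1, d=2, *args, **kwargs):
    pass

spec = inspect.getfullargspec(f)
n_def = len(spec.defaults) if spec.defaults else 0

# split positional parameters into those without and with defaults,
# which is the split _get_arg_list returns
args_wo_default = spec.args[:-n_def] if n_def else spec.args
args_w_default = spec.args[-n_def:] if n_def else None

print(args_wo_default)   # ['a', 'b']
print(args_w_default)    # ['c', 'd']
print(spec.defaults)     # (1, 2)
print(spec.varargs)      # 'args'
print(spec.varkw)        # 'kwargs'

# renaming in the style of _map_args: keys of maps are old names,
# values are new names; names not in maps pass through unchanged
maps = {"x": "a", "y": "b"}
values = {"x": 10, "y": 20, "z": 30}
renamed = {maps.get(name, name): val for name, val in values.items()}
print(renamed)           # {'a': 10, 'b': 20, 'z': 30}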

