
Change the places in losses that used F.cross_entropy directly, because these functions have the signature (input, target)
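
For context, a minimal sketch of the mismatch the message describes (the wrapper below is illustrative, not fastNLP code): fastNLP passes predictions around under the name pred, while the torch.nn.functional losses declare their first parameter as input, so signature introspection never matches without a renaming wrapper.

    import torch
    import torch.nn.functional as F

    # F.cross_entropy declares (input, target, ...), but fastNLP's dicts
    # carry the prediction under the key 'pred'; a thin wrapper renames
    # the argument at the call boundary instead of exposing F.cross_entropy.
    def get_loss(pred, target):
        return F.cross_entropy(input=pred, target=target, ignore_index=-100)

    pred = torch.randn(16, 3)            # (batch, num_classes) logits
    target = torch.randint(0, 3, (16,))  # (batch,) class indices
    print(get_loss(pred, target))        # scalar loss tensor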

tags/v0.2.0^2
yh committed 6 years ago
commit 4dff3ec81f
3 changed files with 76 additions and 73 deletions
  1. fastNLP/core/losses.py  +71 -68
  2. test/core/test_loss.py  +1 -1
  3. test/core/test_trainer.py  +4 -4

fastNLP/core/losses.py  +71 -68

@@ -8,8 +8,7 @@ from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_function_or_method
-from fastNLP.core.utils import _get_arg_list
-from fastNLP.core.utils import _map_args
+from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature


@@ -62,8 +61,7 @@ class LossBase(object):
             if func_param not in func_args:
                 raise NameError(
                     f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
-                    f"initialization parameters, or change the signature of"
-                    f" {get_func_signature(self.get_loss)}.")
+                    f"initialization parameters, or change its signature.")
 
         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -87,71 +85,68 @@ class LossBase(object):
             loss = self.get_loss(*fast_param)
             return loss
 
-        args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
-        if varargs is not None:
-            raise RuntimeError(
-                f"The function {get_func_signature(self.get_loss)} should not use Positional Argument."
-            )
-
-        param_map = self.param_map
-        if args is None:
-            raise RuntimeError(
-                f"There is not any param in function{get_func_signature(self.get_loss)}"
-            )
-
-        self._checked = self._checked and not check
         if not self._checked:
-            for keys in args:
-                if keys not in param_map:
-                    param_map.update({keys: keys})
-            if defaults is not None:
-                for keys in defaults:
-                    if keys not in param_map:
-                        param_map.update({keys: keys})
-            self.param_map = param_map
-        # param map: key= name in get_loss function, value= name in param dict
-        reversed_param_map = {val: key for key, val in param_map.items()}
-        # reversed param map: key= name in param dict, value= name in get_loss function
-
+            # 1. check consistence between signature and param_map
+            func_spect = inspect.getfullargspec(self.get_loss)
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
+            for func_arg, input_arg in self.param_map.items():
+                if func_arg not in func_args:
+                    raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")
+
+            # 2. only part of the param_map are passed, left are not
+            for arg in func_args:
+                if arg not in self.param_map:
+                    self.param_map[arg] = arg  # This param does not need mapping.
+            self._evaluate_args = func_args
+            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
+
+        # need to wrap inputs in dict.
+        mapped_pred_dict = {}
+        mapped_target_dict = {}
         duplicated = []
-        missing = []
-        if not self._checked:
-            for keys, val in pred_dict.items():
-                if keys in target_dict.keys():
-                    duplicated.append(param_map[keys])
-
-        param_val_dict = {}
-        for keys, val in pred_dict.items():
-            param_val_dict.update({keys: val})
-        for keys, val in target_dict.items():
-            param_val_dict.update({keys: val})
-
+        for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
+            not_duplicate_flag = 0
+            if input_arg in self._reverse_param_map:
+                mapped_arg = self._reverse_param_map[input_arg]
+                not_duplicate_flag += 1
+            else:
+                mapped_arg = input_arg
+            if input_arg in pred_dict:
+                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
+                not_duplicate_flag += 1
+            if input_arg in target_dict:
+                mapped_target_dict[mapped_arg] = target_dict[input_arg]
+                not_duplicate_flag += 1
+            if not_duplicate_flag == 3:
+                duplicated.append(input_arg)
+
         # missing
         if not self._checked:
-            for keys in args:
-                if param_map[keys] not in param_val_dict.keys():
-                    missing.append(param_map[keys])
-
-        if len(duplicated) > 0 or len(missing) > 0:
-            raise CheckError(
-                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[],
-                         varargs=varargs),
-                func_signature=get_func_signature(self.get_loss)
-            )
-
+            check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
+            # only check missing.
+            missing = check_res.missing
+            replaced_missing = list(missing)
+            for idx, func_arg in enumerate(missing):
+                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
+                                                                        f"in `{self.__class__.__name__}`)"
+
+            check_res = CheckRes(missing=replaced_missing,
+                                 unused=check_res.unused,
+                                 duplicated=duplicated,
+                                 required=check_res.required,
+                                 all_needed=check_res.all_needed,
+                                 varargs=check_res.varargs)
+
+            if check_res.missing or check_res.duplicated or check_res.varargs:
+                raise CheckError(check_res=check_res,
+                                 func_signature=get_func_signature(self.get_loss))
+        refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
+
+        loss = self.get_loss(**refined_args)
         self._checked = True
 
-        param_map_val = _map_args(reversed_param_map, **param_val_dict)
-        param_value = _build_args(self.get_loss, **param_map_val)
-        loss = self.get_loss(**param_value)
-
-        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
-            if not isinstance(loss, torch.Tensor):
-                raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}")
-
         return loss
 
 
 class LossFunc(LossBase):
     def __init__(self, func, key_map=None, **kwargs):
         super(LossFunc, self).__init__()
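
Taken together, the new __call__ body renames the keys of pred_dict/target_dict through _reverse_param_map, runs the missing/duplicated checks once, and then calls get_loss with only the arguments it declares. A self-contained sketch of that mapping idea, with illustrative names (map_and_call is not fastNLP API):

    import inspect

    def map_and_call(fn, param_map, pred_dict, target_dict):
        # param_map: {arg name in fn's signature: key in the input dicts},
        # playing the role of LossBase.param_map in the diff above.
        reverse = {v: k for k, v in param_map.items()}
        inputs = {**pred_dict, **target_dict}
        mapped = {reverse.get(k, k): v for k, v in inputs.items()}
        # keep only the arguments fn actually declares, as _build_args does
        accepted = set(inspect.getfullargspec(fn).args)
        return fn(**{k: v for k, v in mapped.items() if k in accepted})

    def get_loss(pred, target):
        return (pred - target) ** 2  # stand-in loss

    # the model emitted 'output', the dataset 'label'; both get renamed
    print(map_and_call(get_loss, {'pred': 'output', 'target': 'label'},
                       {'output': 3.0}, {'label': 1.0}))  # prints 4.0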
@@ -168,34 +163,42 @@ class LossFunc(LossBase):
 
 
 class CrossEntropyLoss(LossBase):
-    def __init__(self, pred=None, target=None):
+    def __init__(self, pred=None, target=None, padding_idx=-100):
+        # TODO some checking is still needed: when F.cross_entropy computes with a pred of shape (16, 10, 4),
+        # TODO the target should by rights have shape (16, 10), but it actually requires (16, 4)
         super(CrossEntropyLoss, self).__init__()
-        self.get_loss = F.cross_entropy
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)
+        self.padding_idx = padding_idx
+
+    def get_loss(self, pred, target):
+        return F.cross_entropy(input=pred, target=target,
+                               ignore_index=self.padding_idx)
 
 
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self.get_loss = F.l1_loss
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.l1_loss(input=pred, target=target)
+
 
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self.get_loss = F.binary_cross_entropy
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.binary_cross_entropy(input=pred, target=target)
+
 
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self.get_loss = F.nll_loss
         self._init_param_map(input=pred, target=target)
 
+    def get_loss(self, pred, target):
+        return F.nll_loss(input=pred, target=target)
+
 
 class LossInForward(LossBase):
     def __init__(self, loss_key='loss'):
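
A hedged usage sketch for the class as changed here; the key names 'logits' and 'label' are made up for the example, while the pred_dict/target_dict call convention is the one the tests below exercise:

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    # map the loss's `pred`/`target` slots onto a model output key and a
    # dataset field name; padding_idx forwards to ignore_index internally
    loss_fn = CrossEntropyLoss(pred='logits', target='label', padding_idx=-100)

    pred_dict = {'logits': torch.randn(16, 3)}           # from model.forward
    target_dict = {'label': torch.randint(0, 3, (16,))}  # from the dataset
    print(loss_fn(pred_dict=pred_dict, target_dict=target_dict))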


test/core/test_loss.py  +1 -1

@@ -322,7 +322,7 @@ class TestLosserError(unittest.TestCase):
     def test_losser3(self):
         # (2) with corrupted size
         pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
-        target_dict = {'target': torch.zeros(16, 3).long()}
+        target_dict = {'target': torch.zeros(16).long()}
         los = loss.CrossEntropyLoss()
 
         print(los(pred_dict=pred_dict, target_dict=target_dict))
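
The corrected target shape follows F.cross_entropy's contract: with class-index targets, a pred of shape (batch, num_classes) pairs with a target of shape (batch,). A quick illustration of what the old and new test data do, using plain PyTorch:

    import torch
    import torch.nn.functional as F

    pred = torch.zeros(16, 3)           # (batch, num_classes)
    good = torch.zeros(16).long()       # (batch,) class indices: accepted
    print(F.cross_entropy(pred, good))  # ln(3) ~= tensor(1.0986)

    bad = torch.zeros(16, 3).long()     # (batch, num_classes): rejected
    try:
        F.cross_entropy(pred, bad)
    except RuntimeError as e:
        print("shape mismatch:", e)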


test/core/test_trainer.py  +4 -4

@@ -8,7 +8,7 @@ from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.core.losses import BCELoss
 from fastNLP.core.losses import LossInForward
+from fastNLP.core.losses import CrossEntropyLoss
 from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
@@ -222,7 +222,7 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
-                loss = F.cross_entropy(x, y)
+                # loss = F.cross_entropy(x, y)
                 return {'pred': x}
model = Model()
@@ -231,10 +231,10 @@ class TrainerTestGround(unittest.TestCase):
                           train_data=dataset,
                           model=model,
                           dev_data=dataset,
+                          losser=CrossEntropyLoss(),
                           metrics=AccuracyMetric(),
                           use_tqdm=False,
-                          print_every=2
-                          )
+                          print_every=2)
 
     def test_case2(self):
         # check metrics Wrong
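
With the in-forward loss commented out, the model only returns predictions, and the Trainer computes the loss externally through losser=CrossEntropyLoss(). A minimal sketch of that division of labor, with layer sizes invented for the example:

    import torch
    import torch.nn.functional as F

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 5)

        def forward(self, x1, x2):
            x = self.fc(x1) + self.fc(x2)
            return {'pred': x}  # no 'loss' key: LossInForward no longer applies

    model = Model()
    out = model(torch.randn(2, 4), torch.randn(2, 4))
    target = torch.randint(0, 5, (2,))
    # this is the computation the Trainer now performs through the losser
    print(F.cross_entropy(out['pred'], target))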

