|
@@ -13,7 +13,7 @@ from fastNLP.core.utils import get_func_signature |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LossBase(object): |
|
|
class LossBase(object): |
|
|
"""Base class for all losses. |
|
|
|
|
|
|
|
|
"""所有loss的基类. |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
@@ -24,10 +24,10 @@ class LossBase(object): |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
def _init_param_map(self, key_map=None, **kwargs): |
|
|
def _init_param_map(self, key_map=None, **kwargs): |
|
|
"""Check the validity of key_map and other param map. Add these into self.param_map |
|
|
|
|
|
|
|
|
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map |
|
|
|
|
|
|
|
|
:param key_map: dict |
|
|
|
|
|
:param kwargs: |
|
|
|
|
|
|
|
|
:param dict key_map: 表示key的映射关系 |
|
|
|
|
|
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 |
|
|
:return: None |
|
|
:return: None |
|
|
""" |
|
|
""" |
|
|
value_counter = defaultdict(set) |
|
|
value_counter = defaultdict(set) |
|
@@ -87,9 +87,9 @@ class LossBase(object): |
|
|
|
|
|
|
|
|
def __call__(self, pred_dict, target_dict, check=False): |
|
|
def __call__(self, pred_dict, target_dict, check=False): |
|
|
""" |
|
|
""" |
|
|
:param pred_dict: A dict from forward function of the network. |
|
|
|
|
|
:param target_dict: A dict from DataSet.batch_y. |
|
|
|
|
|
:param check: Boolean. Force to check the mapping functions when it is running. |
|
|
|
|
|
|
|
|
:param dict pred_dict: 模型的forward函数返回的dict |
|
|
|
|
|
:param dict target_dict: DataSet.batch_y里的键-值对所组成的dict |
|
|
|
|
|
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 |
|
|
:return: |
|
|
:return: |
|
|
""" |
|
|
""" |
|
|
fast_param = self._fast_param_map(pred_dict, target_dict) |
|
|
fast_param = self._fast_param_map(pred_dict, target_dict) |
|
@@ -162,15 +162,25 @@ class LossBase(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LossFunc(LossBase): |
|
|
class LossFunc(LossBase): |
|
|
"""A wrapper of user-provided loss function. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""提供给用户使用自定义损失函数的类 |
|
|
""" |
|
|
""" |
|
|
def __init__(self, func, key_map=None, **kwargs): |
|
|
def __init__(self, func, key_map=None, **kwargs): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param func: a callable object, such as a function. |
|
|
|
|
|
:param dict key_map: |
|
|
|
|
|
:param kwargs: |
|
|
|
|
|
|
|
|
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的object |
|
|
|
|
|
:param dict key_map: 参数映射表。键为Model/DataSet参数名,值为损失函数参数名。 |
|
|
|
|
|
fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中 |
|
|
|
|
|
找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数 |
|
|
|
|
|
:param kwargs: 除了参数映射表以外可以用key word args的方式设置参数映射关系 |
|
|
|
|
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
|
|
|
|
>>> func = torch.nn.CrossEntropyLoss() |
|
|
|
|
|
>>> loss_func = LossFunc(func, input="pred", target="label") |
|
|
|
|
|
>>> # 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field |
|
|
|
|
|
>>> # 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数 |
|
|
|
|
|
>>> # 传入func作为一个名为`target`的参数 |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
super(LossFunc, self).__init__() |
|
|
super(LossFunc, self).__init__() |
|
|
_check_function_or_method(func) |
|
|
_check_function_or_method(func) |
|
@@ -186,7 +196,17 @@ class LossFunc(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropyLoss(LossBase): |
|
|
class CrossEntropyLoss(LossBase): |
|
|
|
|
|
"""交叉熵损失函数""" |
|
|
def __init__(self, pred=None, target=None, padding_idx=-100): |
|
|
def __init__(self, pred=None, target=None, padding_idx=-100): |
|
|
|
|
|
""" |
|
|
|
|
|
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` |
|
|
|
|
|
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` |
|
|
|
|
|
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 |
|
|
|
|
|
|
|
|
|
|
|
Example:: |
|
|
|
|
|
|
|
|
|
|
|
>>> loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0) |
|
|
|
|
|
""" |
|
|
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要 |
|
|
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要 |
|
|
# TODO (16, 4) |
|
|
# TODO (16, 4) |
|
|
super(CrossEntropyLoss, self).__init__() |
|
|
super(CrossEntropyLoss, self).__init__() |
|
@@ -199,7 +219,12 @@ class CrossEntropyLoss(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class L1Loss(LossBase): |
|
|
class L1Loss(LossBase): |
|
|
|
|
|
"""L1损失函数""" |
|
|
def __init__(self, pred=None, target=None): |
|
|
def __init__(self, pred=None, target=None): |
|
|
|
|
|
""" |
|
|
|
|
|
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` |
|
|
|
|
|
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` |
|
|
|
|
|
""" |
|
|
super(L1Loss, self).__init__() |
|
|
super(L1Loss, self).__init__() |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
|
|
|
|
|
@@ -208,7 +233,12 @@ class L1Loss(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BCELoss(LossBase): |
|
|
class BCELoss(LossBase): |
|
|
|
|
|
"""二分类交叉熵损失函数""" |
|
|
def __init__(self, pred=None, target=None): |
|
|
def __init__(self, pred=None, target=None): |
|
|
|
|
|
""" |
|
|
|
|
|
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` |
|
|
|
|
|
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` |
|
|
|
|
|
""" |
|
|
super(BCELoss, self).__init__() |
|
|
super(BCELoss, self).__init__() |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
|
|
|
|
|
@@ -217,7 +247,12 @@ class BCELoss(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NLLLoss(LossBase): |
|
|
class NLLLoss(LossBase): |
|
|
|
|
|
"""负对数似然损失函数""" |
|
|
def __init__(self, pred=None, target=None): |
|
|
def __init__(self, pred=None, target=None): |
|
|
|
|
|
""" |
|
|
|
|
|
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` |
|
|
|
|
|
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` |
|
|
|
|
|
""" |
|
|
super(NLLLoss, self).__init__() |
|
|
super(NLLLoss, self).__init__() |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
self._init_param_map(pred=pred, target=target) |
|
|
|
|
|
|
|
@@ -226,7 +261,11 @@ class NLLLoss(LossBase): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LossInForward(LossBase): |
|
|
class LossInForward(LossBase): |
|
|
|
|
|
"""Forward函数中计算得到的损失函数结果""" |
|
|
def __init__(self, loss_key='loss'): |
|
|
def __init__(self, loss_key='loss'): |
|
|
|
|
|
""" |
|
|
|
|
|
:param str loss_key: 在forward函数中取得loss的键名,默认为loss |
|
|
|
|
|
""" |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
if not isinstance(loss_key, str): |
|
|
if not isinstance(loss_key, str): |
|
|
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") |
|
|
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") |
|
@@ -234,13 +273,14 @@ class LossInForward(LossBase): |
|
|
|
|
|
|
|
|
def get_loss(self, **kwargs): |
|
|
def get_loss(self, **kwargs): |
|
|
if self.loss_key not in kwargs: |
|
|
if self.loss_key not in kwargs: |
|
|
check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \ |
|
|
|
|
|
f"in `{self.__class__.__name__}`"], |
|
|
|
|
|
unused=[], |
|
|
|
|
|
duplicated=[], |
|
|
|
|
|
required=[], |
|
|
|
|
|
all_needed=[], |
|
|
|
|
|
varargs=[]) |
|
|
|
|
|
|
|
|
check_res = CheckRes( |
|
|
|
|
|
missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"], |
|
|
|
|
|
unused=[], |
|
|
|
|
|
duplicated=[], |
|
|
|
|
|
required=[], |
|
|
|
|
|
all_needed=[], |
|
|
|
|
|
varargs=[] |
|
|
|
|
|
) |
|
|
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) |
|
|
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) |
|
|
return kwargs[self.loss_key] |
|
|
return kwargs[self.loss_key] |
|
|
|
|
|
|
|
|