
update documentation in losses.py

tags/v0.4.10
xuyige 6 years ago
commit 5b8a62783c
1 changed file with 59 additions and 19 deletions:
  1. fastNLP/core/losses.py (+59, -19)

fastNLP/core/losses.py

@@ -13,7 +13,7 @@ from fastNLP.core.utils import get_func_signature




class LossBase(object):
-     """Base class for all losses.
+     """Base class for all losses.

    """
    def __init__(self):
@@ -24,10 +24,10 @@ class LossBase(object):
        raise NotImplementedError

    def _init_param_map(self, key_map=None, **kwargs):
-         """Check the validity of key_map and other param map. Add these into self.param_map
+         """Check key_map and the other parameter mappings, and add these mappings to self.param_map.

-         :param key_map: dict
-         :param kwargs:
+         :param dict key_map: a dict describing the key mappings
+         :param kwargs: every key-value pair in the keyword arguments is also turned into a mapping
        :return: None
        """
        value_counter = defaultdict(set)
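For context, a minimal sketch of how the mapping set up by _init_param_map is typically consumed by a subclass, following the pattern of the built-in losses further down in this diff; the SmoothL1Loss class itself is illustrative and not part of this commit::

    import torch.nn.functional as F
    from fastNLP.core.losses import LossBase

    class SmoothL1Loss(LossBase):
        def __init__(self, pred=None, target=None):
            super(SmoothL1Loss, self).__init__()
            # None keeps the identity mapping pred->pred, target->target;
            # a string remaps, e.g. pred='output' pulls the field named 'output'
            self._init_param_map(pred=pred, target=target)

        def get_loss(self, pred, target):
            # called with the tensors resolved through the mapping above
            return F.smooth_l1_loss(pred, target)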
@@ -87,9 +87,9 @@ class LossBase(object):


    def __call__(self, pred_dict, target_dict, check=False):
        """
-         :param pred_dict: A dict from forward function of the network.
-         :param target_dict: A dict from DataSet.batch_y.
-         :param check: Boolean. Force to check the mapping functions when it is running.
+         :param dict pred_dict: the dict returned by the model's forward function
+         :param dict target_dict: the dict made up of the key-value pairs in DataSet.batch_y
+         :param Boolean check: whether to verify the mapping table on every call; defaults to no check
        :return:
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
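To make the three parameters concrete, a small sketch of calling a loss object directly on the two dicts described above; the field names 'pred' and 'label' are assumptions made for the example::

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    loss = CrossEntropyLoss(pred='pred', target='label')
    pred_dict = {'pred': torch.randn(4, 10)}              # what the model's forward() returned
    target_dict = {'label': torch.randint(0, 10, (4,))}   # fields collected in DataSet.batch_y
    value = loss(pred_dict, target_dict, check=True)      # check=True forces the mapping check
    print(value)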
@@ -162,15 +162,25 @@ class LossBase(object):




class LossFunc(LossBase):
-     """A wrapper of user-provided loss function.
- 
+     """A class that lets users plug in a custom loss function.
    """
    def __init__(self, func, key_map=None, **kwargs):
        """

-         :param func: a callable object, such as a function.
-         :param dict key_map:
-         :param kwargs:
+         :param func: the user-defined loss function; it should be a function or an object for which callable(func) is True
+         :param dict key_map: the parameter mapping table. Keys are the loss function's parameter names, values are the
+             corresponding Model/DataSet names. At training time, fastNLP's trainer looks for the parameter named value in
+             the model's return values or in the target=True fields of the training DataSet, and passes it to func as the
+             parameter named key
+         :param kwargs: besides the key_map table, mappings can also be given as keyword arguments

+         Example::
+ 
+             >>> func = torch.nn.CrossEntropyLoss()
+             >>> loss_func = LossFunc(func, input="pred", target="label")
+             >>> # This builds a loss object whose value is computed by func: from the model's return values or the
+             >>> # target=True fields of the DataSet, the parameter named `pred` is found and passed to func as the
+             >>> # parameter named `input`, and the parameter named `label` is found and passed to func as the
+             >>> # parameter named `target`.

        """
        super(LossFunc, self).__init__()
        _check_function_or_method(func)
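The Example above uses keyword arguments; the same mapping can presumably be passed through key_map as a dict, with the function's parameter names as keys and the field names as values. A sketch under that assumption::

    import torch
    from fastNLP.core.losses import LossFunc

    func = torch.nn.CrossEntropyLoss()
    # keys are func's parameter names, values are the fields to pull them from
    loss_func = LossFunc(func, key_map={"input": "pred", "target": "label"})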
@@ -186,7 +196,17 @@ class LossFunc(LossBase):




class CrossEntropyLoss(LossBase):
+     """Cross-entropy loss function."""
    def __init__(self, pred=None, target=None, padding_idx=-100):
+         """
+         :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred`->`pred`
+         :param target: the mapping for `target` in the parameter map; None means the mapping is `target`->`target`
+         :param padding_idx: the padding index; entries of target labelled padding_idx are ignored when computing the loss
+ 
+         Example::
+ 
+             >>> loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
+         """
        # TODO some checks are needed: when F.cross_entropy computes, if pred is (16, 10, 4) the target shape should in
        # TODO principle be (16, 10), but it actually has to be (16, 4)
        super(CrossEntropyLoss, self).__init__()
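A rough sketch of what padding_idx means, assuming the class ultimately delegates to torch.nn.functional.cross_entropy (which the TODO above refers to) with padding_idx playing the role of ignore_index::

    import torch
    import torch.nn.functional as F

    pred = torch.randn(3, 5)            # (batch, n_classes)
    target = torch.tensor([1, 0, 4])    # pretend label 0 is padding
    loss_value = F.cross_entropy(pred, target, ignore_index=0)   # padded positions contribute nothing
    print(loss_value)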
@@ -199,7 +219,12 @@ class CrossEntropyLoss(LossBase):




class L1Loss(LossBase):
+     """L1 loss function."""
    def __init__(self, pred=None, target=None):
+         """
+         :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred`->`pred`
+         :param target: the mapping for `target` in the parameter map; None means the mapping is `target`->`target`
+         """
        super(L1Loss, self).__init__()
        self._init_param_map(pred=pred, target=target)
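Illustration of the remapping parameters documented above; 'output' and 'y' are assumed field names, so the loss will look up 'output' in the forward() results and 'y' in the target fields::

    from fastNLP.core.losses import L1Loss

    loss = L1Loss(pred='output', target='y')   # i.e. the mappings pred->output and target->y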


@@ -208,7 +233,12 @@ class L1Loss(LossBase):




class BCELoss(LossBase):
+     """Binary cross-entropy loss function."""
    def __init__(self, pred=None, target=None):
+         """
+         :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred`->`pred`
+         :param target: the mapping for `target` in the parameter map; None means the mapping is `target`->`target`
+         """
        super(BCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target)
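A sketch of using BCELoss, under the assumption that it wraps torch's binary cross entropy and therefore expects probabilities (e.g. sigmoid outputs) rather than raw logits::

    import torch
    from fastNLP.core.losses import BCELoss

    loss = BCELoss()   # identity mappings: pred->pred, target->target
    pred_dict = {'pred': torch.sigmoid(torch.randn(4))}
    target_dict = {'target': torch.rand(4).round()}   # float 0/1 labels
    print(loss(pred_dict, target_dict))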


@@ -217,7 +247,12 @@ class BCELoss(LossBase):




class NLLLoss(LossBase):
+     """Negative log-likelihood loss function."""
    def __init__(self, pred=None, target=None):
+         """
+         :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred`->`pred`
+         :param target: the mapping for `target` in the parameter map; None means the mapping is `target`->`target`
+         """
        super(NLLLoss, self).__init__()
        self._init_param_map(pred=pred, target=target)
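A sketch of using NLLLoss, under the assumption that it wraps torch's nll_loss and therefore expects log-probabilities (e.g. log_softmax outputs) as pred and class indices as target::

    import torch
    import torch.nn.functional as F
    from fastNLP.core.losses import NLLLoss

    loss = NLLLoss()
    pred_dict = {'pred': F.log_softmax(torch.randn(4, 6), dim=-1)}
    target_dict = {'target': torch.randint(0, 6, (4,))}
    print(loss(pred_dict, target_dict))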


@@ -226,7 +261,11 @@ class NLLLoss(LossBase):




class LossInForward(LossBase):
+     """The loss value computed inside the model's forward function."""
    def __init__(self, loss_key='loss'):
+         """
+         :param str loss_key: the key under which the loss is fetched from the forward output; defaults to 'loss'
+         """
        super().__init__()
        if not isinstance(loss_key, str):
            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
@@ -234,13 +273,14 @@ class LossInForward(LossBase):


    def get_loss(self, **kwargs):
        if self.loss_key not in kwargs:
-             check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \
-                                           f"in `{self.__class__.__name__}`"],
-                                  unused=[],
-                                  duplicated=[],
-                                  required=[],
-                                  all_needed=[],
-                                  varargs=[])
+             check_res = CheckRes(
+                 missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"],
+                 unused=[],
+                 duplicated=[],
+                 required=[],
+                 all_needed=[],
+                 varargs=[]
+             )
            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
        return kwargs[self.loss_key]
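For completeness, a sketch of the situation LossInForward is designed for: the model computes its loss inside forward() and returns it under loss_key, and get_loss simply fetches that entry as shown above. The ModelWithLoss class is hypothetical::

    import torch
    from fastNLP.core.losses import LossInForward

    class ModelWithLoss(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, words, target):
            pred = self.fc(words)
            loss = torch.nn.functional.cross_entropy(pred, target)
            return {'pred': pred, 'loss': loss}

    model = ModelWithLoss()
    out = model(torch.randn(4, 8), torch.randint(0, 2, (4,)))
    loss_obj = LossInForward(loss_key='loss')
    print(loss_obj.get_loss(**out))   # returns out['loss']; a missing key raises CheckError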



