From 5b8a62783c55c60093b5fbc13f25c12e34944e79 Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 23 Apr 2019 22:03:50 +0800 Subject: [PATCH] update documents on losses.py --- fastNLP/core/losses.py | 78 ++++++++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 6b0b4460..08702034 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -13,7 +13,7 @@ from fastNLP.core.utils import get_func_signature class LossBase(object): - """Base class for all losses. + """所有loss的基类. """ def __init__(self): @@ -24,10 +24,10 @@ class LossBase(object): raise NotImplementedError def _init_param_map(self, key_map=None, **kwargs): - """Check the validity of key_map and other param map. Add these into self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self.param_map - :param key_map: dict - :param kwargs: + :param dict key_map: 表示key的映射关系 + :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 :return: None """ value_counter = defaultdict(set) @@ -87,9 +87,9 @@ class LossBase(object): def __call__(self, pred_dict, target_dict, check=False): """ - :param pred_dict: A dict from forward function of the network. - :param target_dict: A dict from DataSet.batch_y. - :param check: Boolean. Force to check the mapping functions when it is running. + :param dict pred_dict: 模型的forward函数返回的dict + :param dict target_dict: DataSet.batch_y里的键-值对所组成的dict + :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 :return: """ fast_param = self._fast_param_map(pred_dict, target_dict) @@ -162,15 +162,25 @@ class LossBase(object): class LossFunc(LossBase): - """A wrapper of user-provided loss function. - + """提供给用户使用自定义损失函数的类 """ def __init__(self, func, key_map=None, **kwargs): """ - :param func: a callable object, such as a function. 
- :param dict key_map: - :param kwargs: + :param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的object + :param dict key_map: 参数映射表。键为损失函数参数名,值为Model/DataSet参数名。 + fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中 + 找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数 + :param kwargs: 除了参数映射表以外可以用key word args的方式设置参数映射关系 + + Example:: + + >>> func = torch.nn.CrossEntropyLoss() + >>> loss_func = LossFunc(func, input="pred", target="label") + >>> # 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field + >>> # 当中找到一个参数名为`pred`的参数传入func作为一个名为`input`的参数;找到一个参数名为`label`的参数 + >>> # 传入func作为一个名为`target`的参数 + """ super(LossFunc, self).__init__() _check_function_or_method(func) @@ -186,7 +196,17 @@ class LossFunc(LossBase): class CrossEntropyLoss(LossBase): + """交叉熵损失函数""" def __init__(self, pred=None, target=None, padding_idx=-100): + """ + :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` + :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 + + Example:: + + >>> loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0) + """ # TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要 # TODO (16, 4) super(CrossEntropyLoss, self).__init__() @@ -199,7 +219,12 @@ class CrossEntropyLoss(LossBase): class L1Loss(LossBase): + """L1损失函数""" def __init__(self, pred=None, target=None): + """ + :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` + :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + """ super(L1Loss, self).__init__() self._init_param_map(pred=pred, target=target) @@ -208,7 +233,12 @@ class L1Loss(LossBase): class BCELoss(LossBase): + """二分类交叉熵损失函数""" def __init__(self, pred=None, target=None): + """ + :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` + :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + """ super(BCELoss, self).__init__() self._init_param_map(pred=pred, target=target) @@ -217,7 
+247,12 @@ class BCELoss(LossBase): class NLLLoss(LossBase): + """负对数似然损失函数""" def __init__(self, pred=None, target=None): + """ + :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred` + :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target` + """ super(NLLLoss, self).__init__() self._init_param_map(pred=pred, target=target) @@ -226,7 +261,11 @@ class NLLLoss(LossBase): class LossInForward(LossBase): + """Forward函数中计算得到的损失函数结果""" def __init__(self, loss_key='loss'): + """ + :param str loss_key: 在forward函数中取得loss的键名,默认为loss + """ super().__init__() if not isinstance(loss_key, str): raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") @@ -234,13 +273,14 @@ class LossInForward(LossBase): def get_loss(self, **kwargs): if self.loss_key not in kwargs: - check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \ - f"in `{self.__class__.__name__}`"], - unused=[], - duplicated=[], - required=[], - all_needed=[], - varargs=[]) + check_res = CheckRes( + missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"], + unused=[], + duplicated=[], + required=[], + all_needed=[], + varargs=[] + ) raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) return kwargs[self.loss_key]