|
|
@@ -5,7 +5,7 @@ def squash(predict , truth , **kwargs): |
|
|
|
|
|
|
|
:param predict : Tensor, model output |
|
|
|
:param truth : Tensor, truth from dataset |
|
|
|
:param **kwargs : extract arguments |
|
|
|
:param **kwargs : extra arguments |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
|
''' |
|
|
@@ -18,8 +18,8 @@ def unpad(predict , truth , **kwargs): |
|
|
|
|
|
|
|
:param predict : Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth : Tensor, [batch_size , max_len] |
|
|
|
:param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist |
|
|
|
arg["lens"] : list or LongTensor, [batch_size] |
|
|
|
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist |
|
|
|
kwargs["lens"] : list or LongTensor, [batch_size] |
|
|
|
the i-th element is true lengths of i-th sequence |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
@@ -39,8 +39,8 @@ def unpad_mask(predict , truth , **kwargs): |
|
|
|
|
|
|
|
:param predict : Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth : Tensor, [batch_size , max_len] |
|
|
|
:param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist |
|
|
|
arg["lens"] : list or LongTensor, [batch_size] |
|
|
|
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist |
|
|
|
kwargs["lens"] : list or LongTensor, [batch_size] |
|
|
|
the i-th element is true lengths of i-th sequence |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
@@ -56,8 +56,8 @@ def mask(predict , truth , **kwargs): |
|
|
|
|
|
|
|
:param predict : Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth : Tensor, [batch_size , max_len] |
|
|
|
:param **kwargs : extract arguments, kwargs["mask"] is expected to be exsist |
|
|
|
arg["mask"] : ByteTensor, [batch_size , max_len] |
|
|
|
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist |
|
|
|
kwargs["mask"] : ByteTensor, [batch_size , max_len] |
|
|
|
the mask Tensor , the position that is 1 will be selected |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
@@ -112,7 +112,6 @@ loss_function_name = { |
|
|
|
"MarginRankingLoss".lower() : torch.nn.MarginRankingLoss, |
|
|
|
"TripletMarginLoss".lower() : torch.nn.TripletMarginLoss, |
|
|
|
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss, |
|
|
|
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss, |
|
|
|
"CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss, |
|
|
|
"MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss, |
|
|
|
"MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss, |
|
|
@@ -132,7 +131,7 @@ class Loss(object): |
|
|
|
|
|
|
|
pre_pro funcsions should have three arguments: predict, truth, **arg |
|
|
|
predict and truth is the necessary parameters in loss function |
|
|
|
arg is the extra parameters passed-in when calling loss function |
|
|
|
kwargs is the extra parameters passed-in when calling loss function |
|
|
|
pre_pro functions should return two objects, respectively predict and truth that after processed |
|
|
|
|
|
|
|
''' |
|
|
|