|
- import oneflow
- import oneflow.nn.functional as F
-
def kldiv(logits, targets, T=1.0):
    """Temperature-scaled KL divergence between teacher and student outputs.

    Computes ``KL(softmax(targets / T) || softmax(logits / T)) * T**2``.
    The result is summed over the class axis (dim 1) and averaged over all
    remaining positions (the batch dim plus any trailing dims).

    Because KL divergence differs from soft-target cross entropy only by the
    teacher's entropy, which does not depend on ``logits``, both give the
    same gradients with respect to ``logits``.

    Parameters:
        - logits (Tensor): student logits (e.g. outputs of an fc layer);
          dim 1 is treated as the class axis.
        - targets (Tensor): teacher logits, i.e. unnormalized scores (they
          are passed through softmax here, so do not pre-normalize them).
          Expected to match the shape of ``logits``.
        - T (float): distillation temperature. Must be non-zero; values
          above 1 soften both distributions.

    Returns:
        Tensor: scalar loss.
    """
    # Soften both distributions with the same temperature, over the class axis.
    p_targets = F.softmax(targets / T, dim=1)
    logp_logits = F.log_softmax(logits / T, dim=1)
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    # reduction="none" keeps the elementwise terms so the reduction below
    # can be chosen explicitly.
    kl_div = oneflow.nn.KLDivLoss(reduction="none")
    # Multiply by T**2 to keep the gradient scale roughly independent of T.
    # This is the conventional distillation scaling.
    kld = kl_div(logp_logits, p_targets) * (T ** 2)
    # Sum over classes to get a per-sample divergence, then average the rest.
    return kld.sum(1).mean()
|