knowledge_distill_util.pred_distill(args, student_logits, teacher_logits):
pred_distill
为teacher和student模型添加软标签损失,使得student模型可以学习教师模型的输出,达到student模型模仿teacher模型在预测层的表现的目的。
采用soft_cross_entropy来计算损失。
参数:
返回: 由teacher模型和student模型组合得到的软标签损失。
knowledge_distill_util.layer_distill(args, student_reps, teacher_reps):
layer_distill
为teacher和student模型添加层与层损失,使得student模型可以学习教师模型的隐藏层特征,达到用teacher模型的暗知识(Dark Knowledge)指导student模型学习的目的,将teacher模型中的知识更好的蒸馏到student模型中。通过MSE来计算student模型和teacher模型中间层的距离。
参数:
返回: 由teacher模型和student模型组合得到的层与层蒸馏损失。
注:该算子仅适用于BERT类的student和teacher模型。
knowledge_distill_util.att_distill(args, student_atts, teacher_atts):
att_distill
为teacher和student模型添加注意力损失,使得student模型可以学习教师模型的attention score矩阵,学习到其中包含语义知识,例如语法和相互关系等。通过MSE来计算损失。
参数:
返回: 由teacher模型和student模型组合得到的注意力蒸馏损失。
注:该算子仅适用于BERT类的student和teacher模型。