diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index e439eb9d..206c331f 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -260,7 +260,19 @@ namespace Tensorflow.Keras if (from_logits) return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis); - throw new NotImplementedException(""); + if (output.op != null && output.op.type == "Softmax") + { + if (output.op.inputs.Length != 1) throw new ApplicationException(); + var o = output = output.op.inputs[0]; + return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: o, axis: axis); + } + + // scale preds so that the class probas of each sample sum to 1 + output = output / math_ops.reduce_sum(output, new Axis(axis), true); + // Compute cross entropy from probabilities. + var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype()); + output = clip_ops.clip_by_value(output, epsilon_, 1.0 - epsilon_); + return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis)); } ///