diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index e057ba01..654d9f10 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -2,9 +2,26 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.Operations.Losses +namespace Tensorflow { - class losses_impl + public class LossesImpl : Python { + public Tensor sparse_softmax_cross_entropy(Tensor labels, + Tensor logits, + float weights = 1.0f, + string scope = "", + string loss_collection= "losses") + { + with(new ops.name_scope(scope, + "sparse_softmax_cross_entropy_loss", + (logits, labels, weights)), + namescope => + { + + + }); + + throw new NotImplementedException("sparse_softmax_cross_entropy"); + } } } diff --git a/src/TensorFlowNET.Core/Operations/Losses/tf.loss.cs b/src/TensorFlowNET.Core/Operations/Losses/tf.loss.cs new file mode 100644 index 00000000..90162603 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Losses/tf.loss.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static LossesImpl losses => new LossesImpl(); + } +} diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index ef107ff3..d828820a 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -22,6 +22,11 @@ namespace Tensorflow /// public static string TRAINABLE_VARIABLES = "trainable_variables"; + /// + /// Key to collect losses + /// + public static string LOSSES = "losses"; + /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. diff --git a/test/TensorFlowNET.Examples/MetaGraph.cs b/test/TensorFlowNET.Examples/MetaGraph.cs index d257e712..bf6cebda 100644 --- a/test/TensorFlowNET.Examples/MetaGraph.cs +++ b/test/TensorFlowNET.Examples/MetaGraph.cs @@ -20,9 +20,9 @@ namespace TensorFlowNET.Examples new_saver.restore(sess, dir + "my-model-10000"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var batch_size = tf.size(labels); - var logits = (tf.get_collection("logits") as List)[0]; - var loss = tf.losses.sparse_softmax_cross_entropy(labels = labels, - logits = logits); + var logits = (tf.get_collection("logits") as List)[0] as Tensor; + var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, + logits: logits); }); } }