@@ -2,9 +2,26 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Operations.Losses | |||||
namespace Tensorflow | |||||
{ | { | ||||
class losses_impl | |||||
public class LossesImpl : Python | |||||
{ | { | ||||
public Tensor sparse_softmax_cross_entropy(Tensor labels, | |||||
Tensor logits, | |||||
float weights = 1.0f, | |||||
string scope = "", | |||||
string loss_collection= "losses") | |||||
{ | |||||
with<ops.name_scope>(new ops.name_scope(scope, | |||||
"sparse_softmax_cross_entropy_loss", | |||||
(logits, labels, weights)), | |||||
namescope => | |||||
{ | |||||
}); | |||||
throw new NotImplementedException("sparse_softmax_cross_entropy"); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public static partial class tf | |||||
{ | |||||
public static LossesImpl losses => new LossesImpl(); | |||||
} | |||||
} |
@@ -22,6 +22,11 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public static string TRAINABLE_VARIABLES = "trainable_variables"; | public static string TRAINABLE_VARIABLES = "trainable_variables"; | ||||
/// <summary> | |||||
/// Key to collect losses | |||||
/// </summary> | |||||
public static string LOSSES = "losses"; | |||||
/// <summary> | /// <summary> | ||||
/// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
/// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
@@ -20,9 +20,9 @@ namespace TensorFlowNET.Examples | |||||
new_saver.restore(sess, dir + "my-model-10000"); | new_saver.restore(sess, dir + "my-model-10000"); | ||||
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | ||||
var batch_size = tf.size(labels); | var batch_size = tf.size(labels); | ||||
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0]; | |||||
var loss = tf.losses.sparse_softmax_cross_entropy(labels = labels, | |||||
logits = logits); | |||||
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor; | |||||
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||||
logits: logits); | |||||
}); | }); | ||||
} | } | ||||
} | } | ||||