Browse Source

can't compile.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
3c25a5442b
4 changed files with 38 additions and 5 deletions
  1. +19
    -2
      src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Operations/Losses/tf.loss.cs
  3. +5
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  4. +3
    -3
      test/TensorFlowNET.Examples/MetaGraph.cs

+ 19
- 2
src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs View File

@@ -2,9 +2,26 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Losses
namespace Tensorflow
{
class losses_impl
public class LossesImpl : Python
{
public Tensor sparse_softmax_cross_entropy(Tensor labels,
Tensor logits,
float weights = 1.0f,
string scope = "",
string loss_collection= "losses")
{
with<ops.name_scope>(new ops.name_scope(scope,
"sparse_softmax_cross_entropy_loss",
(logits, labels, weights)),
namescope =>
{


});

throw new NotImplementedException("sparse_softmax_cross_entropy");
}
}
}

+ 11
- 0
src/TensorFlowNET.Core/Operations/Losses/tf.loss.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static LossesImpl losses => new LossesImpl();
}
}

+ 5
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -22,6 +22,11 @@ namespace Tensorflow
/// </summary>
public static string TRAINABLE_VARIABLES = "trainable_variables";

/// <summary>
/// Key to collect losses
/// </summary>
public static string LOSSES = "losses";

/// <summary>
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.


+ 3
- 3
test/TensorFlowNET.Examples/MetaGraph.cs View File

@@ -20,9 +20,9 @@ namespace TensorFlowNET.Examples
new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels);
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0];
var loss = tf.losses.sparse_softmax_cross_entropy(labels = labels,
logits = logits);
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0] as Tensor;
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
logits: logits);
});
}
}


Loading…
Cancel
Save