
tf.GraphKeys #359

tags/v0.12
Oceania2018 6 years ago
parent
commit
683aeed693
17 changed files with 83 additions and 109 deletions
  1. README.md  +25 -49
  2. src/TensorFlowNET.Core/APIs/tf.graph.cs  +5 -0
  3. src/TensorFlowNET.Core/APIs/tf.variable.cs  +2 -1
  4. src/TensorFlowNET.Core/Framework/meta_graph.py.cs  +3 -3
  5. src/TensorFlowNET.Core/Layers/Layer.cs  +1 -1
  6. src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs  +4 -4
  7. src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs  +1 -1
  8. src/TensorFlowNET.Core/Operations/Losses/Util.cs  +1 -1
  9. src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs  +2 -2
  10. src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs  +5 -5
  11. src/TensorFlowNET.Core/Summaries/Summary.cs  +3 -3
  12. src/TensorFlowNET.Core/Train/Optimizer.cs  +3 -3
  13. src/TensorFlowNET.Core/Variables/RefVariable.cs  +3 -3
  14. src/TensorFlowNET.Core/Variables/variables.py.cs  +5 -4
  15. src/TensorFlowNET.Core/ops.GraphKeys.cs  +15 -15
  16. test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs  +4 -13
  17. test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs  +1 -1
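
In short, this commit replaces the static `ops.GraphKeys` class with an instance class exposed as `tf.GraphKeys`, updates every call site accordingly, and switches the examples to the shorter tuple feed syntax. A minimal usage sketch, drawn from the hunks below (illustrative only; it assumes a project that already references TensorFlow.NET as described in the README):

```cs
using System.Collections.Generic;
using Tensorflow;
using static Tensorflow.Binding;

class GraphKeysUsage
{
    static void Main()
    {
        // Before this commit the keys were static members: ops.GraphKeys.UPDATE_OPS.
        // After it, GraphKeys is an instance exposed on the tf entry point,
        // as in the updated VdCnn.cs at the bottom of this diff:
        var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;
    }
}
```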

README.md  +25 -49

@@ -1,6 +1,6 @@
![logo](docs/assets/tf.net.logo.png)

**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.

[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET
### Install tensorflow binary
### For CPU version
PM> Install-Package SciSharp.TensorFlow.Redist

### For GPU version (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
```

Import TF.NET.

```cs
using Tensorflow;
```

Add two constants:
```cs
// Create a Constant op
var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var c = tf.add(a, b);

using (var sess = tf.Session())
{
var o = sess.run(c);
}
```
Import TF.NET in your project.

Feed placeholder:
```cs
// Create a placeholder op
var a = tf.placeholder(tf.float32);
var b = tf.placeholder(tf.float32);
var c = tf.add(a, b);

using(var sess = tf.Session())
{
var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f));
}
using static Tensorflow.Binding;
```

Linear Regression:
@@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
var init = tf.global_variables_initializer();

// Start training
with(tf.Session(), sess =>
using(tf.Session())
{
// Run the initializer
sess.run(init);
// Fit all training data
for (int epoch = 0; epoch < training_epochs; epoch++)
{
foreach (var (x, y) in zip<float>(train_X, train_Y))
sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y));
sess.run(optimizer, (X, x), (Y, y));
// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
{
var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y));
var c = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
}
Console.WriteLine("Optimization Finished!");
var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y));
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");
// Testing example
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
Console.WriteLine("Testing... (Mean square loss Comparison)");
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");
var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");
}

Console.WriteLine("Optimization Finished!");
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");

// Testing example
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
Console.WriteLine("Testing... (Mean square loss Comparison)");
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
(X, test_X), (Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");
var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");

return diff < 0.01;
});
```
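
The other notable change in this hunk is the feed syntax. A side-by-side sketch, assuming the `X`, `Y` placeholders and the `optimizer` and `sess` objects from the example above:

```cs
// Previous README style: wrap each feed value explicitly.
sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y));

// Style shown after this commit: pass (placeholder, value) tuples directly.
sess.run(optimizer, (X, x), (Y, y));
```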



src/TensorFlowNET.Core/APIs/tf.graph.cs  +5 -0

@@ -14,11 +14,16 @@
limitations under the License.
******************************************************************************/

using static Tensorflow.ops;

namespace Tensorflow
{
public partial class tensorflow
{
public graph_util_impl graph_util => new graph_util_impl();

public GraphKeys GraphKeys { get; } = new GraphKeys();

public Graph get_default_graph()
{
return ops.get_default_graph();


src/TensorFlowNET.Core/APIs/tf.variable.cs  +2 -1

@@ -15,6 +15,7 @@
******************************************************************************/

using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -22,7 +23,7 @@ namespace Tensorflow
{
public VariableV1[] global_variables(string scope = null)
{
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
.ToArray();
}



src/TensorFlowNET.Core/Framework/meta_graph.py.cs  +3 -3

@@ -95,7 +95,7 @@ namespace Tensorflow
break;
case KindOneofCase.BytesList:
//var proto_type = ops.get_collection_proto_type(key)
if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
{
foreach (var value in col.Value.BytesList.Value)
{
@@ -146,7 +146,7 @@ namespace Tensorflow
}
}

var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES,
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names);
var var_list = new Dictionary<string, VariableV1>();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
@@ -180,7 +180,7 @@ namespace Tensorflow
var graph = ops.get_default_graph();

var var_list = new Dictionary<string, RefVariable>();
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;
var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>;

if (variables != null)
{


src/TensorFlowNET.Core/Layers/Layer.cs  +1 -1

@@ -81,7 +81,7 @@ namespace Tensorflow.Layers


// Update global default collections.
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
_add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

return outputs;
}


src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs  +4 -4

@@ -152,9 +152,9 @@ namespace Tensorflow.Operations
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

//TODO: port this chunck of missing code:
/*
@@ -191,9 +191,9 @@ namespace Tensorflow.Operations
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

switch (original_result)
{


src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs  +1 -1

@@ -195,7 +195,7 @@ namespace Tensorflow.Operations
// their associated TensorArrays for calling the body.
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
var body_result = body(packed_vars_for_body[0]);
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

// Store body_result to keep track of TensorArrays returned by body
var original_body_result = new[] { body_result };


src/TensorFlowNET.Core/Operations/Losses/Util.cs  +1 -1

@@ -2,7 +2,7 @@
{
public class Util
{
public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES)
public static void add_loss(Tensor loss, string loss_collection = "losses")
{
if (!string.IsNullOrEmpty(loss_collection))
ops.add_to_collection(loss_collection, loss);


src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs  +2 -2

@@ -22,7 +22,7 @@ namespace Tensorflow
public class LossesImpl
{
public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null,
string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
{
return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate
{
@@ -101,7 +101,7 @@ namespace Tensorflow
Tensor logits,
float weights = 1.0f,
string scope = null,
string loss_collection= ops.GraphKeys.LOSSES,
string loss_collection= "losses",
string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
{
return tf_with(ops.name_scope(scope,


src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs  +5 -5

@@ -431,8 +431,8 @@ namespace Tensorflow

merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

return merges[0];
});
@@ -479,8 +479,8 @@ namespace Tensorflow

merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

return merges;
});
@@ -596,7 +596,7 @@ namespace Tensorflow
swap_memory: swap_memory);

if (loop_context.outer_context == null)
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context);
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);

var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
return_same_structure);


src/TensorFlowNET.Core/Summaries/Summary.cs  +3 -3

@@ -33,11 +33,11 @@ namespace Tensorflow.Summaries
{
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary");
var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope);
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES });
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES });
return val;
}

public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null)
public Tensor merge_all(string key = "summaries", string scope= null, string name= null)
{
var summary_ops = ops.get_collection(key, scope: scope);
if (summary_ops == null)
@@ -67,7 +67,7 @@ namespace Tensorflow.Summaries
{
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor });
var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope);
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES });
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES });
return val;
}



src/TensorFlowNET.Core/Train/Optimizer.cs  +3 -3

@@ -198,7 +198,7 @@ namespace Tensorflow

if (!tf.context.executing_eagerly())
{
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
if (train_op != null && train_op.Contains(apply_updates))
train_op.Add(apply_updates);
}
@@ -359,7 +359,7 @@ namespace Tensorflow


var tmp = variables.trainable_variables();
var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
switch (tmp)
{
case List<RefVariable> values:
@@ -370,7 +370,7 @@ namespace Tensorflow
break;
}

var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
var var_refs = processors.Select(x => x.target()).ToArray();



src/TensorFlowNET.Core/Variables/RefVariable.cs  +3 -3

@@ -121,7 +121,7 @@ namespace Tensorflow

if(collections == null)
{
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES };
}

// Store the graph key so optimizers know how to only retrieve variables from
@@ -129,8 +129,8 @@ namespace Tensorflow
_graph_key = ops.get_default_graph().graph_key;

_trainable = trainable;
if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
var values = init_from_fn ? new object[0] : new object[] { initial_value };


src/TensorFlowNET.Core/Variables/variables.py.cs  +5 -4

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -28,7 +29,7 @@ namespace Tensorflow
/// <returns></returns>
public static object trainable_variables()
{
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES);
}

/// <summary>
@@ -40,11 +41,11 @@ namespace Tensorflow
{
var all = new List<VariableV1>();

var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);
if(collection != null)
all.AddRange(collection as List<VariableV1>);

collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope);
if (collection != null)
all.AddRange(collection as List<VariableV1>);

@@ -64,7 +65,7 @@ namespace Tensorflow
/// <returns>A list of `Variable` objects.</returns>
public static List<VariableV1> global_variables(string scope = null)
{
var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);

return result == null ? new List<VariableV1>() : result as List<VariableV1>;
}


src/TensorFlowNET.Core/ops.GraphKeys.cs  +15 -15

@@ -27,57 +27,57 @@ namespace Tensorflow
/// specified, but it is also possible to pass an explicit list of
/// variables.
/// </summary>
public static class GraphKeys
public class GraphKeys
{
/// <summary>
/// the subset of `Variable` objects that will be trained by an optimizer.
/// </summary>
public static string TRAINABLE_VARIABLES = "trainable_variables";
public string TRAINABLE_VARIABLES = "trainable_variables";

/// <summary>
/// Trainable resource-style variables.
/// </summary>
public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";

/// <summary>
/// Key for streaming model ports.
/// </summary>
public static string _STREAMING_MODEL_PORTS = "streaming_model_ports";
public string _STREAMING_MODEL_PORTS = "streaming_model_ports";

/// <summary>
/// Key to collect losses
/// </summary>
public const string LOSSES = "losses";
public string LOSSES = "losses";

/// <summary>
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
/// </summary>
public static string GLOBAL_VARIABLES = "variables";
public string GLOBAL_VARIABLES = "variables";

public static string TRAIN_OP = "train_op";
public string TRAIN_OP = "train_op";

public static string GLOBAL_STEP = GLOBAL_STEP = "global_step";
public string GLOBAL_STEP = "global_step";

public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>
public static string SAVEABLE_OBJECTS = "saveable_objects";
public string SAVEABLE_OBJECTS = "saveable_objects";
/// <summary>
/// Key to collect update_ops
/// </summary>
public static string UPDATE_OPS = "update_ops";
public string UPDATE_OPS = "update_ops";

// Key to collect summaries.
public const string SUMMARIES = "summaries";
public string SUMMARIES = "summaries";

// Used to store v2 summary names.
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2";
public string _SUMMARY_COLLECTION = "_SUMMARY_V2";

// Key for control flow context.
public static string COND_CONTEXT = "cond_context";
public static string WHILE_CONTEXT = "while_context";
public string COND_CONTEXT = "cond_context";
public string WHILE_CONTEXT = "while_context";
}
}
}
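
Because the keys are now instance fields rather than constants, they can no longer appear as default parameter values (C# requires compile-time constants there). That is why `Util.cs`, `losses_impl.py.cs` and `Summary.cs` above fall back to the literal strings `"losses"` and `"summaries"`. A sketch of the pattern, mirroring the library's own `add_loss` (it assumes `ops.add_to_collection` is reachable as it is in `Util.cs`):

```cs
using Tensorflow;

public class LossUtil
{
    // Not valid once GraphKeys is an instance class:
    //   public static void add_loss(Tensor loss, string loss_collection = tf.GraphKeys.LOSSES)
    // because a default parameter value must be a compile-time constant.
    public static void add_loss(Tensor loss, string loss_collection = "losses")
    {
        if (!string.IsNullOrEmpty(loss_collection))
            ops.add_to_collection(loss_collection, loss);
    }
}
```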

test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs  +4 -13

@@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples
for (int epoch = 0; epoch < training_epochs; epoch++)
{
foreach (var (x, y) in zip<float>(train_X, train_Y))
{
sess.run(optimizer,
new FeedItem(X, x),
new FeedItem(Y, y));
}
sess.run(optimizer, (X, x), (Y, y));

// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
{
var c = sess.run(cost,
new FeedItem(X, train_X),
new FeedItem(Y, train_Y));
var c = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
}
}

Console.WriteLine("Optimization Finished!");
var training_cost = sess.run(cost,
new FeedItem(X, train_X),
new FeedItem(Y, train_Y));
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y));
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");

// Testing example
@@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
Console.WriteLine("Testing... (Mean square loss Comparison)");
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
new FeedItem(X, test_X),
new FeedItem(Y, test_Y));
(X, test_X), (Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");
var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");


test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs  +1 -1

@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text
var y_one_hot = tf.one_hot(y, num_class);
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));

var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<object>;
var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>;
tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
{
var adam = tf.train.AdamOptimizer(learning_rate);

