@@ -1,6 +1,6 @@ | |||
 | |||
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. | |||
[](https://gitter.im/sci-sharp/community) | |||
[](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) | |||
@@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET | |||
### Install tensorflow binary | |||
### For CPU version | |||
PM> Install-Package SciSharp.TensorFlow.Redist | |||
### For GPU version (CUDA and cuDNN are required) | |||
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | |||
``` | |||
Import TF.NET. | |||
```cs | |||
using Tensorflow; | |||
``` | |||
Add two constants: | |||
```cs | |||
// Create a Constant op | |||
var a = tf.constant(4.0f); | |||
var b = tf.constant(5.0f); | |||
var c = tf.add(a, b); | |||
using (var sess = tf.Session()) | |||
{ | |||
var o = sess.run(c); | |||
} | |||
``` | |||
Import TF.NET in your project. | |||
Feed placeholder: | |||
```cs | |||
// Create a placeholder op | |||
var a = tf.placeholder(tf.float32); | |||
var b = tf.placeholder(tf.float32); | |||
var c = tf.add(a, b); | |||
using(var sess = tf.Session()) | |||
{ | |||
var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); | |||
} | |||
using static Tensorflow.Binding; | |||
``` | |||
Linear Regression: | |||
@@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||
var init = tf.global_variables_initializer(); | |||
// Start training | |||
with(tf.Session(), sess => | |||
using(tf.Session()) | |||
{ | |||
// Run the initializer | |||
sess.run(init); | |||
// Fit all training data | |||
for (int epoch = 0; epoch < training_epochs; epoch++) | |||
{ | |||
foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); | |||
sess.run(optimizer, (X, x), (Y, y)); | |||
// Display logs per epoch step | |||
if ((epoch + 1) % display_step == 0) | |||
{ | |||
var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||
var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||
} | |||
Console.WriteLine("Optimization Finished!"); | |||
var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); | |||
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
// Testing example | |||
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); | |||
Console.WriteLine($"Testing cost={testing_cost}"); | |||
var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
} | |||
Console.WriteLine("Optimization Finished!"); | |||
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
// Testing example | |||
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||
(X, test_X), (Y, test_Y)); | |||
Console.WriteLine($"Testing cost={testing_cost}"); | |||
var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
return diff < 0.01; | |||
}); | |||
``` | |||
@@ -14,11 +14,16 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using static Tensorflow.ops; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public graph_util_impl graph_util => new graph_util_impl(); | |||
public GraphKeys GraphKeys { get; } = new GraphKeys(); | |||
public Graph get_default_graph() | |||
{ | |||
return ops.get_default_graph(); | |||
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -22,7 +23,7 @@ namespace Tensorflow | |||
{ | |||
public VariableV1[] global_variables(string scope = null) | |||
{ | |||
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
.ToArray(); | |||
} | |||
@@ -95,7 +95,7 @@ namespace Tensorflow | |||
break; | |||
case KindOneofCase.BytesList: | |||
//var proto_type = ops.get_collection_proto_type(key) | |||
if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||
if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) | |||
{ | |||
foreach (var value in col.Value.BytesList.Value) | |||
{ | |||
@@ -146,7 +146,7 @@ namespace Tensorflow | |||
} | |||
} | |||
var variables = graph.get_collection<VariableV1>(ops.GraphKeys.GLOBAL_VARIABLES, | |||
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, | |||
scope: scope_to_prepend_to_names); | |||
var var_list = new Dictionary<string, VariableV1>(); | |||
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||
@@ -180,7 +180,7 @@ namespace Tensorflow | |||
var graph = ops.get_default_graph(); | |||
var var_list = new Dictionary<string, RefVariable>(); | |||
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||
var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List<RefVariable>; | |||
if (variables != null) | |||
{ | |||
@@ -81,7 +81,7 @@ namespace Tensorflow.Layers | |||
// Update global default collections. | |||
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); | |||
_add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||
return outputs; | |||
} | |||
@@ -152,9 +152,9 @@ namespace Tensorflow.Operations | |||
public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
var original_result = fn(); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
//TODO: port this chunck of missing code: | |||
/* | |||
@@ -191,9 +191,9 @@ namespace Tensorflow.Operations | |||
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
var original_result = fn(); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
switch (original_result) | |||
{ | |||
@@ -195,7 +195,7 @@ namespace Tensorflow.Operations | |||
// their associated TensorArrays for calling the body. | |||
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
var body_result = body(packed_vars_for_body[0]); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
// Store body_result to keep track of TensorArrays returned by body | |||
var original_body_result = new[] { body_result }; | |||
@@ -2,7 +2,7 @@ | |||
{ | |||
public class Util | |||
{ | |||
public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) | |||
public static void add_loss(Tensor loss, string loss_collection = "losses") | |||
{ | |||
if (!string.IsNullOrEmpty(loss_collection)) | |||
ops.add_to_collection(loss_collection, loss); | |||
@@ -22,7 +22,7 @@ namespace Tensorflow | |||
public class LossesImpl | |||
{ | |||
public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | |||
string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
{ | |||
return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | |||
{ | |||
@@ -101,7 +101,7 @@ namespace Tensorflow | |||
Tensor logits, | |||
float weights = 1.0f, | |||
string scope = null, | |||
string loss_collection= ops.GraphKeys.LOSSES, | |||
string loss_collection= "losses", | |||
string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
{ | |||
return tf_with(ops.name_scope(scope, | |||
@@ -431,8 +431,8 @@ namespace Tensorflow | |||
merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
return merges[0]; | |||
}); | |||
@@ -479,8 +479,8 @@ namespace Tensorflow | |||
merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
return merges; | |||
}); | |||
@@ -596,7 +596,7 @@ namespace Tensorflow | |||
swap_memory: swap_memory); | |||
if (loop_context.outer_context == null) | |||
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
return_same_structure); | |||
@@ -33,11 +33,11 @@ namespace Tensorflow.Summaries | |||
{ | |||
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); | |||
var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); | |||
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||
return val; | |||
} | |||
public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) | |||
public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | |||
{ | |||
var summary_ops = ops.get_collection(key, scope: scope); | |||
if (summary_ops == null) | |||
@@ -67,7 +67,7 @@ namespace Tensorflow.Summaries | |||
{ | |||
var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); | |||
var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); | |||
collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); | |||
collect(val, collections?.ToList(), new List<string> { tf.GraphKeys.SUMMARIES }); | |||
return val; | |||
} | |||
@@ -198,7 +198,7 @@ namespace Tensorflow | |||
if (!tf.context.executing_eagerly()) | |||
{ | |||
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||
var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>; | |||
if (train_op != null && train_op.Contains(apply_updates)) | |||
train_op.Add(apply_updates); | |||
} | |||
@@ -359,7 +359,7 @@ namespace Tensorflow | |||
var tmp = variables.trainable_variables(); | |||
var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
switch (tmp) | |||
{ | |||
case List<RefVariable> values: | |||
@@ -370,7 +370,7 @@ namespace Tensorflow | |||
break; | |||
} | |||
var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||
var var_refs = processors.Select(x => x.target()).ToArray(); | |||
@@ -121,7 +121,7 @@ namespace Tensorflow | |||
if(collections == null) | |||
{ | |||
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||
collections = new List<string> { tf.GraphKeys.GLOBAL_VARIABLES }; | |||
} | |||
// Store the graph key so optimizers know how to only retrieve variables from | |||
@@ -129,8 +129,8 @@ namespace Tensorflow | |||
_graph_key = ops.get_default_graph().graph_key; | |||
_trainable = trainable; | |||
if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
ops.init_scope(); | |||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -28,7 +29,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static object trainable_variables() | |||
{ | |||
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
} | |||
/// <summary> | |||
@@ -40,11 +41,11 @@ namespace Tensorflow | |||
{ | |||
var all = new List<VariableV1>(); | |||
var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
if(collection != null) | |||
all.AddRange(collection as List<VariableV1>); | |||
collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
if (collection != null) | |||
all.AddRange(collection as List<VariableV1>); | |||
@@ -64,7 +65,7 @@ namespace Tensorflow | |||
/// <returns>A list of `Variable` objects.</returns> | |||
public static List<VariableV1> global_variables(string scope = null) | |||
{ | |||
var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||
return result == null ? new List<VariableV1>() : result as List<VariableV1>; | |||
} | |||
@@ -27,57 +27,57 @@ namespace Tensorflow | |||
/// specified, but it is also possible to pass an explicit list of | |||
/// variables. | |||
/// </summary> | |||
public static class GraphKeys | |||
public class GraphKeys | |||
{ | |||
/// <summary> | |||
/// the subset of `Variable` objects that will be trained by an optimizer. | |||
/// </summary> | |||
public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||
public string TRAINABLE_VARIABLES = "trainable_variables"; | |||
/// <summary> | |||
/// Trainable resource-style variables. | |||
/// </summary> | |||
public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||
public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||
/// <summary> | |||
/// Key for streaming model ports. | |||
/// </summary> | |||
public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||
public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||
/// <summary> | |||
/// Key to collect losses | |||
/// </summary> | |||
public const string LOSSES = "losses"; | |||
public string LOSSES = "losses"; | |||
/// <summary> | |||
/// Key to collect Variable objects that are global (shared across machines). | |||
/// Default collection for all variables, except local ones. | |||
/// </summary> | |||
public static string GLOBAL_VARIABLES = "variables"; | |||
public string GLOBAL_VARIABLES = "variables"; | |||
public static string TRAIN_OP = "train_op"; | |||
public string TRAIN_OP = "train_op"; | |||
public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; | |||
public string GLOBAL_STEP = "global_step"; | |||
public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
/// <summary> | |||
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
/// </summary> | |||
public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||
public string SAVEABLE_OBJECTS = "saveable_objects"; | |||
/// <summary> | |||
/// Key to collect update_ops | |||
/// </summary> | |||
public static string UPDATE_OPS = "update_ops"; | |||
public string UPDATE_OPS = "update_ops"; | |||
// Key to collect summaries. | |||
public const string SUMMARIES = "summaries"; | |||
public string SUMMARIES = "summaries"; | |||
// Used to store v2 summary names. | |||
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
// Key for control flow context. | |||
public static string COND_CONTEXT = "cond_context"; | |||
public static string WHILE_CONTEXT = "while_context"; | |||
public string COND_CONTEXT = "cond_context"; | |||
public string WHILE_CONTEXT = "while_context"; | |||
} | |||
} | |||
} |
@@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples | |||
for (int epoch = 0; epoch < training_epochs; epoch++) | |||
{ | |||
foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
{ | |||
sess.run(optimizer, | |||
new FeedItem(X, x), | |||
new FeedItem(Y, y)); | |||
} | |||
sess.run(optimizer, (X, x), (Y, y)); | |||
// Display logs per epoch step | |||
if ((epoch + 1) % display_step == 0) | |||
{ | |||
var c = sess.run(cost, | |||
new FeedItem(X, train_X), | |||
new FeedItem(Y, train_Y)); | |||
var c = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||
} | |||
} | |||
Console.WriteLine("Optimization Finished!"); | |||
var training_cost = sess.run(cost, | |||
new FeedItem(X, train_X), | |||
new FeedItem(Y, train_Y)); | |||
var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); | |||
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
// Testing example | |||
@@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples | |||
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); | |||
Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||
new FeedItem(X, test_X), | |||
new FeedItem(Y, test_Y)); | |||
(X, test_X), (Y, test_Y)); | |||
Console.WriteLine($"Testing cost={testing_cost}"); | |||
var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
@@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text | |||
var y_one_hot = tf.one_hot(y, num_class); | |||
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); | |||
var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<object>; | |||
var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List<object>; | |||
tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate | |||
{ | |||
var adam = tf.train.AdamOptimizer(learning_rate); | |||