diff --git a/README.md b/README.md index d10dd86a..21d3748e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ![logo](docs/assets/tf.net.logo.png) -**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. +**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. [![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community) [![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net) @@ -34,40 +34,15 @@ PM> Install-Package TensorFlow.NET ### Install tensorflow binary ### For CPU version PM> Install-Package SciSharp.TensorFlow.Redist + ### For GPU version (CUDA and cuDNN are required) PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` -Import TF.NET. - -```cs -using Tensorflow; -``` - -Add two constants: -```cs -// Create a Constant op -var a = tf.constant(4.0f); -var b = tf.constant(5.0f); -var c = tf.add(a, b); - -using (var sess = tf.Session()) -{ - var o = sess.run(c); -} -``` +Import TF.NET in your project. -Feed placeholder: ```cs -// Create a placeholder op -var a = tf.placeholder(tf.float32); -var b = tf.placeholder(tf.float32); -var c = tf.add(a, b); - -using(var sess = tf.Session()) -{ - var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); -} +using static Tensorflow.Binding; ``` Linear Regression: @@ -91,39 +66,40 @@ var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); var init = tf.global_variables_initializer(); // Start training -with(tf.Session(), sess => +using(tf.Session()) { // Run the initializer sess.run(init); - + // Fit all training data for (int epoch = 0; epoch < training_epochs; epoch++) { foreach (var (x, y) in zip(train_X, train_Y)) - sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); - + sess.run(optimizer, (X, x), (Y, y)); + // Display logs per epoch step if ((epoch + 1) % display_step == 0) { - var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); + var c = sess.run(cost, (X, train_X), (Y, train_Y)); Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); } - - Console.WriteLine("Optimization Finished!"); - var training_cost = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); - Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); - - // Testing example - var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); - var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); - Console.WriteLine("Testing... (Mean square loss Comparison)"); - - var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), new FeedItem(X, test_X), new FeedItem(Y, test_Y)); - Console.WriteLine($"Testing cost={testing_cost}"); - - var diff = Math.Abs((float)training_cost - (float)testing_cost); - Console.WriteLine($"Absolute mean square loss difference: {diff}"); } + + Console.WriteLine("Optimization Finished!"); + var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); + Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); + + // Testing example + var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); + var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); + Console.WriteLine("Testing... (Mean square loss Comparison)"); + var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), + (X, test_X), (Y, test_Y)); + Console.WriteLine($"Testing cost={testing_cost}"); + var diff = Math.Abs((float)training_cost - (float)testing_cost); + Console.WriteLine($"Absolute mean square loss difference: {diff}"); + + return diff < 0.01; }); ``` diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index a60c0413..cee941ed 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -14,11 +14,16 @@ limitations under the License. ******************************************************************************/ +using static Tensorflow.ops; + namespace Tensorflow { public partial class tensorflow { public graph_util_impl graph_util => new graph_util_impl(); + + public GraphKeys GraphKeys { get; } = new GraphKeys(); + public Graph get_default_graph() { return ops.get_default_graph(); diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index 8e905b1d..b3c5bf43 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using static Tensorflow.Binding; namespace Tensorflow { @@ -22,7 +23,7 @@ namespace Tensorflow { public VariableV1[] global_variables(string scope = null) { - return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) as List) + return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List) .ToArray(); } diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index db9dba38..d7d7ef7e 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -95,7 +95,7 @@ namespace Tensorflow break; case KindOneofCase.BytesList: //var proto_type = ops.get_collection_proto_type(key) - if (ops.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) + if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key)) { foreach (var value in col.Value.BytesList.Value) { @@ -146,7 +146,7 @@ namespace Tensorflow } } - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope: scope_to_prepend_to_names); var var_list = new Dictionary(); variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); @@ -180,7 +180,7 @@ namespace Tensorflow var graph = ops.get_default_graph(); var var_list = new Dictionary(); - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) as List; + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List; if (variables != null) { diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 304e7f7b..a3ae3356 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -81,7 +81,7 @@ namespace Tensorflow.Layers // Update global default collections. - _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); + _add_elements_to_collection(_updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); return outputs; } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index b1503567..aa314efb 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -152,9 +152,9 @@ namespace Tensorflow.Operations public (T, Tensor) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); //TODO: port this chunck of missing code: /* @@ -191,9 +191,9 @@ namespace Tensorflow.Operations public (T[], Tensor[]) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. - var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); var original_result = fn(); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); switch (original_result) { diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index ccd88480..1faaa647 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -195,7 +195,7 @@ namespace Tensorflow.Operations // their associated TensorArrays for calling the body. var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); // Store body_result to keep track of TensorArrays returned by body var original_body_result = new[] { body_result }; diff --git a/src/TensorFlowNET.Core/Operations/Losses/Util.cs b/src/TensorFlowNET.Core/Operations/Losses/Util.cs index 71b3ed62..fde5bcb0 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/Util.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/Util.cs @@ -2,7 +2,7 @@ { public class Util { - public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES) + public static void add_loss(Tensor loss, string loss_collection = "losses") { if (!string.IsNullOrEmpty(loss_collection)) ops.add_to_collection(loss_collection, loss); diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index de4bf964..1f4ce2d8 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -22,7 +22,7 @@ namespace Tensorflow public class LossesImpl { public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, - string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) + string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) { return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate { @@ -101,7 +101,7 @@ namespace Tensorflow Tensor logits, float weights = 1.0f, string scope = null, - string loss_collection= ops.GraphKeys.LOSSES, + string loss_collection= "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) { return tf_with(ops.name_scope(scope, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 04595256..04ef54a7 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -431,8 +431,8 @@ namespace Tensorflow merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); return merges[0]; }); @@ -479,8 +479,8 @@ namespace Tensorflow merges = _convert_flows_to_tensorarrays(orig_res_t, merges); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); return merges; }); @@ -596,7 +596,7 @@ namespace Tensorflow swap_memory: swap_memory); if (loop_context.outer_context == null) - ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); + ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, return_same_structure); diff --git a/src/TensorFlowNET.Core/Summaries/Summary.cs b/src/TensorFlowNET.Core/Summaries/Summary.cs index 2bea0ddc..3d157bd9 100644 --- a/src/TensorFlowNET.Core/Summaries/Summary.cs +++ b/src/TensorFlowNET.Core/Summaries/Summary.cs @@ -33,11 +33,11 @@ namespace Tensorflow.Summaries { var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope); - collect(val, collections?.ToList(), new List { ops.GraphKeys.SUMMARIES }); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); return val; } - public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null) + public Tensor merge_all(string key = "summaries", string scope= null, string name= null) { var summary_ops = ops.get_collection(key, scope: scope); if (summary_ops == null) @@ -67,7 +67,7 @@ namespace Tensorflow.Summaries { var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }); var val = gen_logging_ops.scalar_summary(tags: tag, values: tensor, name: scope); - collect(val, collections?.ToList(), new List { ops.GraphKeys.SUMMARIES }); + collect(val, collections?.ToList(), new List { tf.GraphKeys.SUMMARIES }); return val; } diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index c031da54..bb8fcd7a 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -198,7 +198,7 @@ namespace Tensorflow if (!tf.context.executing_eagerly()) { - var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List; + var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List; if (train_op != null && train_op.Contains(apply_updates)) train_op.Add(apply_updates); } @@ -359,7 +359,7 @@ namespace Tensorflow var tmp = variables.trainable_variables(); - var vars = ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); + var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); switch (tmp) { case List values: @@ -370,7 +370,7 @@ namespace Tensorflow break; } - var_list = var_list.Concat(ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); + var_list = var_list.Concat(ops.get_collection(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); var var_refs = processors.Select(x => x.target()).ToArray(); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 9ac7e6ea..3f9e8acf 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -121,7 +121,7 @@ namespace Tensorflow if(collections == null) { - collections = new List { ops.GraphKeys.GLOBAL_VARIABLES }; + collections = new List { tf.GraphKeys.GLOBAL_VARIABLES }; } // Store the graph key so optimizers know how to only retrieve variables from @@ -129,8 +129,8 @@ namespace Tensorflow _graph_key = ops.get_default_graph().graph_key; _trainable = trainable; - if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) - collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); + if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) + collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); var values = init_from_fn ? new object[0] : new object[] { initial_value }; diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 3880bc7f..6e9d0e4c 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow { @@ -28,7 +29,7 @@ namespace Tensorflow /// public static object trainable_variables() { - return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); + return ops.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES); } /// @@ -40,11 +41,11 @@ namespace Tensorflow { var all = new List(); - var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); if(collection != null) all.AddRange(collection as List); - collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); + collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); if (collection != null) all.AddRange(collection as List); @@ -64,7 +65,7 @@ namespace Tensorflow /// A list of `Variable` objects. public static List global_variables(string scope = null) { - var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); return result == null ? new List() : result as List; } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 94e1b8d5..17b095a4 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -27,57 +27,57 @@ namespace Tensorflow /// specified, but it is also possible to pass an explicit list of /// variables. /// - public static class GraphKeys + public class GraphKeys { /// /// the subset of `Variable` objects that will be trained by an optimizer. /// - public static string TRAINABLE_VARIABLES = "trainable_variables"; + public string TRAINABLE_VARIABLES = "trainable_variables"; /// /// Trainable resource-style variables. /// - public static string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; + public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; /// /// Key for streaming model ports. /// - public static string _STREAMING_MODEL_PORTS = "streaming_model_ports"; + public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; /// /// Key to collect losses /// - public const string LOSSES = "losses"; + public string LOSSES = "losses"; /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. /// - public static string GLOBAL_VARIABLES = "variables"; + public string GLOBAL_VARIABLES = "variables"; - public static string TRAIN_OP = "train_op"; + public string TRAIN_OP = "train_op"; - public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; + public string GLOBAL_STEP = "global_step"; - public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; + public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// - public static string SAVEABLE_OBJECTS = "saveable_objects"; + public string SAVEABLE_OBJECTS = "saveable_objects"; /// /// Key to collect update_ops /// - public static string UPDATE_OPS = "update_ops"; + public string UPDATE_OPS = "update_ops"; // Key to collect summaries. - public const string SUMMARIES = "summaries"; + public string SUMMARIES = "summaries"; // Used to store v2 summary names. - public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; + public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; // Key for control flow context. - public static string COND_CONTEXT = "cond_context"; - public static string WHILE_CONTEXT = "while_context"; + public string COND_CONTEXT = "cond_context"; + public string WHILE_CONTEXT = "while_context"; } } } diff --git a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs index a9dfbe7e..9b33b28f 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs @@ -80,26 +80,18 @@ namespace TensorFlowNET.Examples for (int epoch = 0; epoch < training_epochs; epoch++) { foreach (var (x, y) in zip(train_X, train_Y)) - { - sess.run(optimizer, - new FeedItem(X, x), - new FeedItem(Y, y)); - } + sess.run(optimizer, (X, x), (Y, y)); // Display logs per epoch step if ((epoch + 1) % display_step == 0) { - var c = sess.run(cost, - new FeedItem(X, train_X), - new FeedItem(Y, train_Y)); + var c = sess.run(cost, (X, train_X), (Y, train_Y)); Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); } } Console.WriteLine("Optimization Finished!"); - var training_cost = sess.run(cost, - new FeedItem(X, train_X), - new FeedItem(Y, train_Y)); + var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); // Testing example @@ -107,8 +99,7 @@ namespace TensorFlowNET.Examples var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); Console.WriteLine("Testing... (Mean square loss Comparison)"); var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), - new FeedItem(X, test_X), - new FeedItem(Y, test_Y)); + (X, test_X), (Y, test_Y)); Console.WriteLine($"Testing cost={testing_cost}"); var diff = Math.Abs((float)training_cost - (float)testing_cost); Console.WriteLine($"Absolute mean square loss difference: {diff}"); diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs index 9b28fdc0..6150fa90 100644 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples.Text var y_one_hot = tf.one_hot(y, num_class); loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); - var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List; + var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List; tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate { var adam = tf.train.AdamOptimizer(learning_rate);