diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 5b10813d..53a3b097 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -298,5 +298,8 @@ namespace Tensorflow
         public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
             => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
+
+        public static Tensor square(Tensor x, string name = null)
+            => gen_math_ops.square(x, name: name);
     }
 }
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 8a1b648e..1448d9ae 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -26,7 +26,9 @@ namespace Tensorflow
                 partition_strategy: partition_strategy,
                 name: name);
 
-        public static IActivation relu => new relu();
+        public static IActivation relu() => new relu();
+
+        public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);
 
         public static Tensor[] fused_batch_norm(Tensor x,
             RefVariable scale,
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index 3e273e61..b309fc06 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -29,5 +29,13 @@ namespace Tensorflow
             TF_DataType dtype = TF_DataType.TF_FLOAT,
             int? seed = null,
             string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
+
+        public static Tensor truncated_normal(int[] shape,
+            float mean = 0.0f,
+            float stddev = 1.0f,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            int? seed = null,
+            string name = null)
+            => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name);
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 72ca1c1b..9d0d49b1 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -1,6 +1,8 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
+using Tensorflow.Operations.Activation;
 
 namespace Tensorflow.Operations
 {
@@ -173,5 +175,59 @@ namespace Tensorflow.Operations
 
             return (_op.outputs[0], _op.outputs[1]);
         }
+
+        /// <summary>
+        /// Computes rectified linear: `max(features, 0)`.
+        /// </summary>
+        /// <param name="features">A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`, `qint8`.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>A `Tensor`. Has the same type as `features`.</returns>
+        public static Tensor relu(Tensor features, string name = null)
+        {
+            // Python reference (eager fast path and dispatch fallback) kept below as comments; only the graph-mode path is ported.
+            //_ctx = _context._context
+            //if _ctx is not None and _ctx._eager_context.is_eager:
+            //    try:
+            //        _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
+            //            _ctx._context_handle, _ctx._eager_context.device_name, "Relu", name,
+            //            _ctx._post_execution_callbacks, features)
+            //        return _result
+            //    except _core._FallbackException:
+            //        try:
+            //            return relu_eager_fallback(
+            //                features, name=name, ctx=_ctx)
+            //        except _core._SymbolicException:
+            //            pass  # Add nodes to the TensorFlow graph.
+            //        except (TypeError, ValueError):
+            //            result = _dispatch.dispatch(
+            //                relu, features=features, name=name)
+            //            if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+            //                return result
+            //            raise
+            //    except _core._NotOkStatusException as e:
+            //        if name is not None:
+            //            message = e.message + " name: " + name
+            //        else:
+            //            message = e.message
+            //        _six.raise_from(_core._status_to_exception(e.code, message), None)
+            //# Add nodes to the TensorFlow graph.
+            //try:
+            OpDefLibrary _op_def_lib = new OpDefLibrary();
+            var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features });
+            return _op.outputs[0];
+            //except (TypeError, ValueError):
+            //    result = _dispatch.dispatch(
+            //        relu, features=features, name=name)
+            //    if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
+            //        return result
+            //    raise
+            //    var _result = _op.outputs.ToArray();
+            //_inputs_flat = _op.inputs
+            //_attrs = ("T", _op.get_attr("T"))
+            //_execute.record_gradient(
+            //    "Relu", _inputs_flat, _attrs, _result, name)
+            //_result, = _result
+            //    return _result;
+        }
     }
 }
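For quick orientation, the sketch below shows how the two relu entry points added above are meant to be called: tf.nn.relu() yields the IActivation object that layer-style APIs expect, while the new Tensor overload goes straight to gen_nn_ops.relu and adds a "Relu" node to the graph. The shapes, stddev value, and "hidden_relu" name are illustrative only, and the snippet assumes the usual `using Tensorflow;` plus the static `tf` class used throughout the examples.

    // Illustrative only: both relu forms introduced in tf.nn.cs.
    var x = tf.placeholder(tf.float32, new TensorShape(4, 2));
    var w = tf.Variable(tf.truncated_normal(new int[] { 2, 8 }, stddev: 0.5f));

    // Tensor overload: builds the "Relu" op directly via gen_nn_ops.relu.
    var h = tf.nn.relu(tf.matmul(x, w), name: "hidden_relu");

    // IActivation overload: for APIs that take an activation object,
    // e.g. activation: tf.nn.relu() in tf.layers.dense (see VdCnn.cs below).
    var act = tf.nn.relu();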
diff --git a/test/TensorFlowNET.Examples/NeuralNetXor.cs b/test/TensorFlowNET.Examples/NeuralNetXor.cs
new file mode 100644
index 00000000..3e64d1bd
--- /dev/null
+++ b/test/TensorFlowNET.Examples/NeuralNetXor.cs
@@ -0,0 +1,96 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples
+{
+    /// <summary>
+    /// Simple vanilla neural net solving the famous XOR problem
+    /// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md
+    /// </summary>
+    public class NeuralNetXor : Python, IExample
+    {
+        public int Priority => 2;
+        public bool Enabled { get; set; } = true;
+        public string Name => "NN XOR";
+
+        public int num_steps = 5000;
+
+        private (Operation, Tensor, RefVariable) make_graph(Tensor features, Tensor labels, int num_hidden = 8)
+        {
+            var stddev = 1 / Math.Sqrt(2);
+            var hidden_weights = tf.Variable(tf.truncated_normal(new int[] { 2, num_hidden }, stddev: (float) stddev));
+
+            // Shape [4, num_hidden]
+            var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights));
+
+            var output_weights = tf.Variable(tf.truncated_normal(
+                new[] { num_hidden, 1 },
+                stddev: (float) (1 / Math.Sqrt(num_hidden))
+            ));
+
+            // Shape [4, 1]
+            var logits = tf.matmul(hidden_activations, output_weights);
+
+            // Shape [4]
+            var predictions = tf.sigmoid(tf.squeeze(logits));
+            var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));
+
+            var gs = tf.Variable(0, trainable: false);
+            var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
+
+            return (train_op, loss, gs);
+        }
+
+        public bool Run()
+        {
+            var graph = tf.Graph();
+
+            var init = with(graph.as_default(), g =>
+            {
+                var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
+                var labels = tf.placeholder(tf.int32, new TensorShape(4));
+
+                var (train_op, loss, gs) = make_graph(features, labels);
+                return tf.global_variables_initializer();
+            });
+
+            // Start tf session
+            with(tf.Session(), sess =>
+            {
+                init.run();
+                var step = 0;
+                var xy = np.array(new bool[,]
+                {
+                    {true, false, },
+                    {true, true, },
+                    {false, false, },
+                    {false, true, },
+                }, dtype: np.float32);
+
+                var y_ = np.array(new[] {true, false, false, true}, dtype: np.int32);
+                while (step < num_steps)
+                {
+                    // Original Python:
+                    //_, step, loss_value = sess.run(
+                    //    [train_op, gs, loss],
+                    //    feed_dict={features: xy, labels: y_}
+                    //)
+                    // TODO: port this sess.run call (multiple fetches plus a feed_dict) to C#;
+                    // until it is, `step` never advances and this loop will not terminate.
+                    // var (_, step, loss_value) = sess.run(new object[] { train_op, gs, loss }, feed_dict: new { "features": xy, "labels": y_ });
+                }
+                //tf.logging.info('Final loss is: {}'.format(loss_value))
+                //Console.WriteLine($"Final loss is: {loss_value}");
+            });
+            return true;
+        }
+
+        public void PrepareData()
+        {
+        }
+    }
+}
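The TODO in the training loop above is the one piece with no C# equivalent yet: a single Session.run call that fetches several outputs while feeding the placeholders. A possible shape for that step is sketched below. It assumes a multi-fetch Session.run overload taking placeholder/value pairs; the FeedItem name and the indexable, castable return value are assumptions, not API guaranteed by this change, and features/labels would need to be declared outside the graph.as_default lambda so they stay in scope inside the session block.

    // Hypothetical port of the commented Python step (FeedItem and the multi-fetch
    // Session.run overload are assumed here, not established by this diff):
    while (step < num_steps)
    {
        var results = sess.run(new object[] { train_op, gs, loss },
            new FeedItem(features, xy),
            new FeedItem(labels, y_));

        step = (int) results[1];              // current value of the global step variable gs
        var loss_value = (float) results[2];  // loss at this step
    }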
diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
index d01b458d..da6f49b6 100644
--- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
@@ -66,7 +66,7 @@ namespace TensorFlowNET.Examples.TextClassification
                     filters: num_filters[0],
                     kernel_size: new int[] { filter_sizes[0], embedding_size },
                     kernel_initializer: cnn_initializer,
-                    activation: tf.nn.relu);
+                    activation: tf.nn.relu());
 
                 conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
             });
@@ -99,12 +99,12 @@ namespace TensorFlowNET.Examples.TextClassification
             // ============= Fully Connected Layers =============
             with(tf.name_scope("fc-1"), scope =>
             {
-                fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+                fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
             });
 
             with(tf.name_scope("fc-2"), scope =>
             {
-                fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+                fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
             });
 
             with(tf.name_scope("fc-3"), scope =>
@@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples.TextClassification
                 // batch normalization
                 conv = tf.layers.batch_normalization(conv, training: is_training);
                 // relu
-                conv = tf.nn.relu.Activate(conv);
+                conv = tf.nn.relu(conv);
                 conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
             });
         }
diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
index 4f0f6a07..b0af5125 100644
--- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
@@ -108,6 +108,16 @@ namespace TensorFlowNET.ExamplesTests
             tf.Graph().as_default();
             new TextClassificationWithMovieReviews() { Enabled = true }.Run();
         }
+
+        [Ignore]
+        [TestMethod]
+        public void NeuralNetXor()
+        {
+            tf.Graph().as_default();
+            new NeuralNetXor() { Enabled = true }.Run();
+        }
+
     }
 }
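A side effect of turning tf.nn.relu from a property into a method is that property-style call sites change shape, as the VdCnn.cs hunks above show. Assuming the relu IActivation's Activate() routes through the same gen_nn_ops.relu op (not verified by this diff), the two forms below should add an equivalent "Relu" node; the placeholder is only a stand-in for a real feature map.

    var conv = tf.placeholder(tf.float32, new TensorShape(4, 8));   // stand-in feature map
    var a = tf.nn.relu(conv);                // new Tensor overload, as now used in VdCnn.cs
    var b = tf.nn.relu().Activate(conv);     // IActivation form, as the layer APIs invoke it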