diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 5b10813d..53a3b097 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -298,5 +298,8 @@ namespace Tensorflow
public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
+
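+ /// <summary>
+ /// Computes the element-wise square of `x`; e.g. the squared-error loss in the XOR example
+ /// below is built as `tf.square(predictions - labels)` (illustrative usage).
+ /// </summary>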
+ public static Tensor square(Tensor x, string name = null)
+ => gen_math_ops.square(x, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 8a1b648e..1448d9ae 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -26,7 +26,9 @@ namespace Tensorflow
partition_strategy: partition_strategy,
name: name);
- public static IActivation relu => new relu();
+ public static IActivation relu() => new relu();
+
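+ /// <summary>
+ /// Computes rectified linear: `max(features, 0)` on a tensor by adding a `Relu` node to the graph,
+ /// e.g. `tf.nn.relu(tf.matmul(features, weights))` as in the XOR example (illustrative usage).
+ /// </summary>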
+ public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);
public static Tensor[] fused_batch_norm(Tensor x,
RefVariable scale,
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index 3e273e61..b309fc06 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -29,5 +29,13 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
+
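+ /// <summary>
+ /// Outputs random values from a truncated normal distribution: samples falling more than two
+ /// standard deviations from the mean are dropped and re-drawn, matching Python's `tf.truncated_normal`.
+ /// </summary>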
+ public static Tensor truncated_normal(int[] shape,
+ float mean = 0.0f,
+ float stddev = 1.0f,
+ TF_DataType dtype = TF_DataType.TF_FLOAT,
+ int? seed = null,
+ string name = null)
+ => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 72ca1c1b..9d0d49b1 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
+using Tensorflow.Operations.Activation;
namespace Tensorflow.Operations
{
@@ -173,5 +175,59 @@ namespace Tensorflow.Operations
return (_op.outputs[0], _op.outputs[1]);
}
+
+ /// <summary>
+ /// Computes rectified linear: `max(features, 0)`.
+ /// </summary>
+ /// <param name="features">A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`, `qint8`.</param>
+ /// <param name="name">A name for the operation (optional).</param>
+ /// <returns>A `Tensor`. Has the same type as `features`.</returns>
+ public static Tensor relu(Tensor features, string name = null)
+ {
+     // Graph mode only: add a `Relu` node via the op definition library. The Python
+     // generated op (gen_nn_ops.py) additionally handles the eager-mode fast path,
+     // dispatch fallback and gradient recording; those branches are not ported yet.
+     OpDefLibrary _op_def_lib = new OpDefLibrary();
+     var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features });
+
+     return _op.outputs[0];
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/NeuralNetXor.cs b/test/TensorFlowNET.Examples/NeuralNetXor.cs
new file mode 100644
index 00000000..3e64d1bd
--- /dev/null
+++ b/test/TensorFlowNET.Examples/NeuralNetXor.cs
@@ -0,0 +1,96 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples
+{
+ /// <summary>
+ /// Simple vanilla neural net solving the famous XOR problem
+ /// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md
+ /// </summary>
+ public class NeuralNetXor : Python, IExample
+ {
+ public int Priority => 2;
+ public bool Enabled { get; set; } = true;
+ public string Name => "NN XOR";
+
+ public int num_steps = 5000;
+
+ private (Operation, Tensor, RefVariable) make_graph(Tensor features, Tensor labels, int num_hidden = 8)
+ {
+ var stddev = 1 / Math.Sqrt(2);
+ var hidden_weights = tf.Variable(tf.truncated_normal(new int[] { 2, num_hidden }, stddev: (float)stddev));
+
+ // Shape [4, num_hidden]
+ var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights));
+
+ var output_weights = tf.Variable(tf.truncated_normal(
+ new[] {num_hidden, 1},
+ stddev: (float) (1 / Math.Sqrt(num_hidden))
+ ));
+
+ // Shape [4, 1]
+ var logits = tf.matmul(hidden_activations, output_weights);
+
+ // Shape [4]
+ var predictions = tf.sigmoid(tf.squeeze(logits));
+ var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));
+
+ var gs = tf.Variable(0, trainable: false);
+ var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
+
+ return (train_op, loss, gs);
+ }
+
+ public bool Run()
+ {
+ var graph = tf.Graph();
+
+ var init = with(graph.as_default(), g =>
+ {
+ var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
+ var labels = tf.placeholder(tf.int32, new TensorShape(4));
+
+ var (train_op, loss, gs) = make_graph(features, labels);
+ return tf.global_variables_initializer();
+ });
+
+ // Start tf session
+ with(tf.Session(), sess =>
+ {
+ init.run();
+ var step = 0;
+ var xy = np.array(new bool[,]
+ {
+ {true, false, },
+ {true, true, },
+ {false, false, },
+ {false, true, },
+ }, dtype: np.float32);
+
+ var y_ = np.array(new[] {true, false, false, true}, dtype: np.int32);
+ while (step < num_steps)
+ {
+ // original python:
+ //_, step, loss_value = sess.run(
+ // [train_op, gs, loss],
+ // feed_dict={features: xy, labels: y_}
+ // )
+ // TODO: port this multi-fetch sess.run call with a feed dict to C#; a hedged
+ // sketch of one possible shape is commented just after this loop.
+ step++; // placeholder so the loop terminates until the run call is actually ported
+ }
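+
+ // Illustrative sketch only (not part of the original change): one way the loop above might
+ // look once sess.run is ported. The multi-fetch overload and the FeedItem type used here are
+ // assumptions about the API, not verified calls, and the sketch also assumes the `features`
+ // and `labels` placeholders are hoisted out of the graph-construction lambda so they are in scope.
+ //
+ //   while (step < num_steps)
+ //   {
+ //       var results = sess.run(new object[] { train_op, gs, loss },
+ //                              new FeedItem(features, xy),
+ //                              new FeedItem(labels, y_));
+ //       step = (int)results[1];         // global step fetched from `gs`
+ //       loss_value = (float)results[2]; // loss fetched from `loss`
+ //   }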
+ //tf.logging.info('Final loss is: {}'.format(loss_value))
+ //Console.WriteLine($"Final loss is: {loss_value}");
+
+ });
+ return true;
+ }
+
+ public void PrepareData()
+ {
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
index d01b458d..da6f49b6 100644
--- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
@@ -66,7 +66,7 @@ namespace TensorFlowNET.Examples.TextClassification
filters: num_filters[0],
kernel_size: new int[] { filter_sizes[0], embedding_size },
kernel_initializer: cnn_initializer,
- activation: tf.nn.relu);
+ activation: tf.nn.relu());
conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
});
@@ -99,12 +99,12 @@ namespace TensorFlowNET.Examples.TextClassification
// ============= Fully Connected Layers =============
with(tf.name_scope("fc-1"), scope =>
{
- fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+ fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
});
with(tf.name_scope("fc-2"), scope =>
{
- fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
+ fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
});
with(tf.name_scope("fc-3"), scope =>
@@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples.TextClassification
// batch normalization
conv = tf.layers.batch_normalization(conv, training: is_training);
// relu
- conv = tf.nn.relu.Activate(conv);
+ conv = tf.nn.relu(conv);
conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
});
}
diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
index 4f0f6a07..b0af5125 100644
--- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs
@@ -108,6 +108,16 @@ namespace TensorFlowNET.ExamplesTests
tf.Graph().as_default();
new TextClassificationWithMovieReviews() { Enabled = true }.Run();
}
+
+ [Ignore]
+ [TestMethod]
+ public void NeuralNetXor()
+ {
+ tf.Graph().as_default();
+ new NeuralNetXor() { Enabled = true }.Run();
+ }
+
}
}