diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 2a6b125b..05b01b69 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -116,8 +116,11 @@ namespace Tensorflow
public IActivation relu() => new relu();
public IActivation swish() => new swish();
public IActivation tanh() => new tanh();
+ public Tensor tanh(Tensor x, string name = null)
+ => gen_nn_ops.tanh(x, name);
- public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);
+ public Tensor relu(Tensor features, string name = null)
+ => gen_nn_ops.relu(features, name);
public Tensor[] fused_batch_norm(Tensor x,
VariableV1 scale,
@@ -212,6 +215,14 @@ namespace Tensorflow
public Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
+ ///
+ /// Computes sigmoid of `x` element-wise.
+ /// Specifically, `y = 1 / (1 + exp(-x))`.
+ ///
+ ///
+ ///
+ /// A name for the operation (optional).
+ /// A Tensor with the same type as `x`.
public Tensor sigmoid(T x, string name = null)
=> math_ops.sigmoid(x, name: name);
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs
index 571d57b2..86e979c4 100644
--- a/src/TensorFlowNET.Core/APIs/tf.ops.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs
@@ -33,6 +33,9 @@ namespace Tensorflow
public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
+ public Tensor assign(ResourceVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
+ => state_ops.assign(@ref, value, validate_shape, use_locking, name);
+
public void device(string device_name)
=> get_default_graph().device(device_name);
diff --git a/src/TensorFlowNET.Core/Clustering/KMeans.cs b/src/TensorFlowNET.Core/Clustering/KMeans.cs
index e3f2ab2a..04a0375f 100644
--- a/src/TensorFlowNET.Core/Clustering/KMeans.cs
+++ b/src/TensorFlowNET.Core/Clustering/KMeans.cs
@@ -98,10 +98,10 @@ namespace Tensorflow.Clustering
var cluster_counts = _use_mini_batch ? tf.Variable(ones) : null;
return new RefVariable[]
{
- cluster_centers,
+ /*cluster_centers,
cluster_centers_initialized,
cluster_counts,
- cluster_centers_updated,
+ cluster_centers_updated,*/
update_in_steps
};
}
diff --git a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs
index 25ed69f5..c66aaeef 100644
--- a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Eager
public partial class wrap_tfe_src
{
static int kFastPathExecuteInputStartIndex = 0;
- public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
+ public static EagerTensor TFE_FastPathExecute(Context ctx,
string device_name,
string opName,
string name,
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index 6e2dc745..d7dd1440 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -46,8 +46,7 @@ namespace Tensorflow.Keras.Utils
Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
var variable_dtype = dtype.as_base_dtype();
- var v = tf.VariableV1(init_val,
- use_resource: use_resource,
+ var v = tf.Variable(init_val,
dtype: dtype,
shape: shape,
name: name);
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 3761cdfe..63aeff53 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.Eager;
using static Tensorflow.Binding;
namespace Tensorflow.Operations
@@ -463,50 +464,30 @@ namespace Tensorflow.Operations
/// A `Tensor`. Has the same type as `features`.
public static Tensor relu(Tensor features, string name = null)
{
+ if (tf.context.executing_eagerly())
+ {
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "Relu", name, null,
+ features);
+ return _result;
+ }
- //_ctx = _context._context
- //if _ctx is not None and _ctx._eager_context.is_eager:
- // try:
- // _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
- // _ctx._context_handle, _ctx._eager_context.device_name, "Relu", name,
- // _ctx._post_execution_callbacks, features)
- // return _result
- // except _core._FallbackException:
- // try:
- // return relu_eager_fallback(
- // features, name=name, ctx=_ctx)
- // except _core._SymbolicException:
- // pass # Add nodes to the TensorFlow graph.
- // except (TypeError, ValueError):
- // result = _dispatch.dispatch(
- // relu, features=features, name=name)
- // if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
- // return result
- // raise
- // except _core._NotOkStatusException as e:
- // if name is not None:
- // message = e.message + " name: " + name
- // else:
- // message = e.message
- // _six.raise_from(_core._status_to_exception(e.code, message), None)
- //# Add nodes to the TensorFlow graph.
- //try:
- OpDefLibrary _op_def_lib = new OpDefLibrary();
var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features });
return _op.outputs[0];
- //except (TypeError, ValueError):
- // result = _dispatch.dispatch(
- // relu, features=features, name=name)
- // if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
- // return result
- // raise
- // var _result = _op.outputs.ToArray();
- //_inputs_flat = _op.inputs
- //_attrs = ("T", _op.get_attr("T"))
- //_execute.record_gradient(
- // "Relu", _inputs_flat, _attrs, _result, name)
- //_result, = _result
- // return _result;
+ }
+
+ public static Tensor tanh(Tensor x, string name = null)
+ {
+ if (tf.context.executing_eagerly())
+ {
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "Tanh", name, null,
+ x);
+ return _result;
+ }
+
+ var _op = _op_def_lib._apply_op_helper("Tanh", name: name, args: new { x });
+ return _op.outputs[0];
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index eea1f1a1..38331c0d 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -125,7 +125,7 @@ namespace Tensorflow
{
if(tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis);
return _result;
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 80c27c73..0a053808 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -120,7 +120,7 @@ namespace Tensorflow
{
try
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Mean", name, null, input, axis, "keep_dims", keep_dims);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mean", name, null, input, axis, "keep_dims", keep_dims);
return _result;
}
catch (Exception ex)
@@ -171,7 +171,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, null, x, y);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Add", name, null, x, y);
return _result;
}
@@ -204,6 +204,14 @@ namespace Tensorflow
public static Tensor sin(Tensor x, string name = null)
{
+ if (tf.context.executing_eagerly())
+ {
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "Sin", name, null,
+ x);
+ return _result;
+ }
+
var _op = _op_def_lib._apply_op_helper("Sin", name, args: new { x });
return _op.outputs[0];
@@ -225,6 +233,14 @@ namespace Tensorflow
///
public static Tensor sigmoid(Tensor x, string name = "Sigmoid")
{
+ if (tf.context.executing_eagerly())
+ {
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
+ "Sigmoid", name, null,
+ x);
+ return _result;
+ }
+
var op = _op_def_lib._apply_op_helper("Sigmoid", name: name, new { x });
return op.output;
@@ -493,7 +509,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Cast", name, null, x, "DstT", DstT, "Truncate", Truncate);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Cast", name, null, x, "DstT", DstT, "Truncate", Truncate);
return _result;
}
@@ -520,7 +536,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Sub", name, null, x, y);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Sub", name, null, x, y);
return _result;
}
@@ -571,7 +587,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Mul", name, null, x, y);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Mul", name, null, x, y);
return _result;
}
@@ -591,7 +607,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
return _result;
}
@@ -618,7 +634,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y);
return _result;
}
@@ -640,7 +656,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"MatMul", name, null,
a, b, "transpose_a", transpose_a, "transpose_b", transpose_b);
return _result;
@@ -748,7 +764,7 @@ namespace Tensorflow
{
try
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sum", name, null,
input, axis, "keep_dims", keep_dims);
return _result;
@@ -789,7 +805,7 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
- var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta);
+ var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta);
return _result;
}
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index de9d761c..ce089032 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -278,10 +278,12 @@ namespace Tensorflow
}
public static Tensor sigmoid(T x, string name = null)
- {
- var x_tensor = ops.convert_to_tensor(x, name: "x");
- return gen_math_ops.sigmoid(x_tensor, name: name);
- }
+ => tf_with(ops.name_scope(name, "Sigmoid", x), scope =>
+ {
+ name = scope;
+ var x_tensor = ops.convert_to_tensor(x, name: "x");
+ return gen_math_ops.sigmoid(x_tensor, name: name);
+ });
public static Tensor sign(T x, string name = null)
=> gen_math_ops.sign(x, name: name);
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 0d47b8ba..204baaae 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -91,6 +91,8 @@ namespace Tensorflow
return new EagerTensor(str, ctx.device_name);
case int int32:
return new EagerTensor(int32, ctx.device_name);
+ case float[] float32s:
+ return new EagerTensor(float32s, ctx.device_name);
default:
throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}");
}
diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs
new file mode 100644
index 00000000..112b9610
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs
@@ -0,0 +1,30 @@
+namespace Tensorflow
+{
+ public partial class ResourceVariable
+ {
+ public static implicit operator _VariableScopeStore(ResourceVariable variable)
+ {
+ return null;
+ }
+
+ public static implicit operator ResourceVariable(_VariableScopeStore store)
+ {
+ return null;
+ }
+
+ public static implicit operator Tensor(ResourceVariable var)
+ {
+ return null;
+ }
+
+ public static implicit operator ResourceVariable(Tensor var)
+ {
+ return null;
+ }
+
+ public static implicit operator RefVariable(ResourceVariable var)
+ {
+ return null;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
index 7b887e22..1209e442 100644
--- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
@@ -24,7 +24,7 @@ namespace Tensorflow
///
/// Variable based on resource handles.
///
- public class ResourceVariable : VariableV1
+ public partial class ResourceVariable : VariableV1
{
bool _in_graph_mode;
Tensor _handle;
diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs
index 8f873291..6e1bd8f3 100644
--- a/src/TensorFlowNET.Core/Variables/VariableV1.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System;
using System.Collections.Generic;
namespace Tensorflow
@@ -50,5 +51,18 @@ namespace Tensorflow
{
}
+
+ public virtual Tensor eval()
+ {
+ throw new NotImplementedException("");
+ }
+
+ public virtual ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true)
+ {
+ var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
+ if (read_value)
+ return assign;
+ return assign.op;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
index 7cdea327..64ce28a7 100644
--- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
+++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
@@ -101,6 +101,26 @@ namespace Tensorflow
return _result[0];
}
+ public static Tensor assign(ResourceVariable @ref, object value,
+ bool validate_shape = true,
+ bool use_locking = true,
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
+
+ var _result = _op.outputs;
+ var _inputs_flat = _op.inputs;
+
+ var _attrs = new Dictionary();
+ _attrs["T"] = _op.get_attr("T");
+ _attrs["validate_shape"] = _op.get_attr("validate_shape");
+ _attrs["use_locking"] = _op.get_attr("use_locking");
+
+ _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name);
+
+ return _result[0];
+ }
+
public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,
diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs
index 01a40bee..b87512c3 100644
--- a/src/TensorFlowNET.Core/Variables/state_ops.cs
+++ b/src/TensorFlowNET.Core/Variables/state_ops.cs
@@ -66,6 +66,18 @@ namespace Tensorflow
name: name);
}
+ public static Tensor assign(ResourceVariable @ref, object value,
+ bool validate_shape = true,
+ bool use_locking = true,
+ string name = null)
+ {
+ return gen_state_ops.assign(@ref,
+ value,
+ validate_shape: validate_shape,
+ use_locking: use_locking,
+ name: name);
+ }
+
public static Tensor assign_sub(RefVariable @ref,
Tensor value,
bool use_locking = false,
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index 7f0d30fd..76a614e1 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -43,34 +43,19 @@ namespace Tensorflow
- public RefVariable Variable(T data,
- bool trainable = true,
- bool validate_shape = true,
- string name = null,
- TF_DataType dtype = TF_DataType.DtInvalid)
- {
- return Tensorflow.variable_scope.default_variable_creator(data,
- trainable: trainable,
- validate_shape: validate_shape,
- name: name,
- dtype: dtype) as RefVariable;
- }
-
- public VariableV1 VariableV1(T data,
+ public ResourceVariable Variable(T data,
bool trainable = true,
bool validate_shape = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
- bool use_resource = false,
int[] shape = null)
{
- return Tensorflow.variable_scope.default_variable_creator(data,
- trainable: trainable,
- validate_shape: validate_shape,
- name: name,
- dtype: dtype,
- use_resource: use_resource,
- shape: shape);
+ return new ResourceVariable(data,
+ trainable: trainable,
+ validate_shape: validate_shape,
+ name: name,
+ dtype: dtype,
+ shape: shape);
}
public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
index b97ae1c1..ef09f5e6 100644
--- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
@@ -11,7 +11,7 @@
Open.snk
- latest
+ 8.0
@@ -33,6 +33,7 @@
+
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 5fb65d66..1fd88fe7 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -111,7 +111,7 @@ namespace TensorFlowNET.UnitTest
public void Assign2()
{
var v1 = tf.Variable(10.0f, name: "v1"); //tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
- var inc_v1 = v1.assign(v1 + 1.0f);
+ var inc_v1 = v1.assign((RefVariable)v1 + 1.0f);
// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();
diff --git a/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs b/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs
new file mode 100644
index 00000000..ac729704
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs
@@ -0,0 +1,26 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest.math_test
+{
+ [TestClass]
+ public class MathOperationTest
+ {
+ // A constant vector of size 6
+ Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });
+
+ [TestMethod]
+ public void Sin()
+ {
+ var b = tf.sin(a, name: "sin");
+ var expected = new float[] { 0.84147096f, -0.47942555f, -0.2555412f, -0.8632094f /*python output -0.86320937*/, 0f, -0.21511999f };
+ var actual = b.ToArray();
+ Assert.IsTrue(Enumerable.SequenceEqual(expected, actual));
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs
new file mode 100644
index 00000000..63163bfb
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs
@@ -0,0 +1,46 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest.nn_test
+{
+ [TestClass]
+ public class ActivationFunctionTest
+ {
+ // A constant vector of size 6
+ Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });
+
+ [TestMethod]
+ public void Sigmoid()
+ {
+ var b = tf.nn.sigmoid(a, name: "sigmoid");
+ // from python
+ // [0.7310586f, 0.37754068f, 0.9677046f, 0.10909683f, 0.5f, 0.00150118f]
+ var expected = new float[] { 0.7310586f, 0.377540678f, 0.9677046f, 0.109096833f, 0.5f, 0.00150118221f };
+ var actual = b.ToArray();
+ Assert.IsTrue(Enumerable.SequenceEqual(expected, actual));
+ }
+
+ [TestMethod]
+ public void ReLU()
+ {
+ var b = tf.nn.relu(a, name: "ReLU");
+ var expected = new float[] { 1f, 0f, 3.4f, 0f, 0f, 0f };
+ var actual = b.ToArray();
+ Assert.IsTrue(Enumerable.SequenceEqual(expected, actual));
+ }
+
+ [TestMethod]
+ public void TanH()
+ {
+ var b = tf.nn.tanh(a, name: "TanH");
+ var expected = new float[] { 0.7615942f, -0.46211717f, 0.9977749f , -0.970452f, 0f, -0.99999547f };
+ var actual = b.ToArray();
+ Assert.IsTrue(Enumerable.SequenceEqual(expected, actual));
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 8097070b..95971165 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -3,7 +3,6 @@ using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
-using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.nn_test
{