From 3c7174b702b6545a3f922bababe19bc2690618eb Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 15 Apr 2019 21:10:06 -0500 Subject: [PATCH] add tf.math.sigmoid --- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 +++ .../Operations/gen_math_ops.cs | 20 +++++++++++++++++++ src/TensorFlowNET.Core/Operations/math_ops.cs | 8 +++++++- test/TensorFlowNET.UnitTest/PythonTest.cs | 2 +- 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 5c31fce0..5b10813d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -281,6 +281,9 @@ namespace Tensorflow return math_ops.reduce_sum(input, axis); } + public static Tensor sigmoid(T x, string name = null) + => math_ops.sigmoid(x, name: name); + public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index ee6f3a39..72beeb9f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -108,6 +108,26 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// Computes sigmoid of x element-wise. + /// + /// + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Sigmoid'. + /// + /// + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Specifically, y = 1 / (1 + exp(-x)). + /// + public static Tensor sigmoid(Tensor x, string name = "Sigmoid") + { + var op = _op_def_lib._apply_op_helper("Sigmoid", name: name, new { x }); + + return op.output; + } public static Tensor sinh(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("Sinh", name, args: new { x }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 996f4686..a1538b9c 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -119,7 +119,13 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, m); } } - + + public static Tensor sigmoid(T x, string name = null) + { + var x_tensor = ops.convert_to_tensor(x, name: "x"); + return gen_math_ops.sigmoid(x_tensor, name: name); + } + /// /// Returns (x - y)(x - y) element-wise. /// diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index dcd8e80b..3761455e 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -143,7 +143,7 @@ namespace TensorFlowNET.UnitTest // return self._eval_helper(tensors) // else: { - with(ops.get_default_session(), s => + with(tf.Session(), s => { var ndarray=tensor.eval(); if (typeof(T) == typeof(double))