diff --git a/src/TensorFlowNET.Core/Graphs/python_api.graph.cs b/src/TensorFlowNET.Core/Graphs/python_api.graph.cs
new file mode 100644
index 00000000..b266580c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/python_api.graph.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Graphs
+{
+ ///
+ /// Lots of other functions required for Operation control flow like AddControlInput, UpdateEdge, RemoveAllControlInputs etc are not exposed via C_API and there is a C implementation of it.
+ /// https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/c/python_api.h
+ /// https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/c/python_api.cc
+ ///
+ ///
+ public class python_api
+ {
+ public static void UpdateEdge(Graph graph, TF_Output new_src, TF_Input dst, Status status)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 763a4bd8..3dee5e9e 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -389,6 +389,23 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Returns the truth value of (x != y) element-wise.
+ ///
+ /// The type of the x.
+ /// The type of the y.
+ /// The x.
+ /// The y.
+ /// The name.
+ ///
+ public static Tensor not_equal(Tx x, Ty y, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("NotEqual", name, args: new { x, y });
+
+ return _op.outputs[0];
+ }
+
+
public static Tensor atan2(Tensor y, Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Atan2", name, args: new { y, x });
@@ -566,5 +583,18 @@ namespace Tensorflow
return _op.outputs[0];
}
+
+ ///
+ /// Returns the fraction of zeros in value.
+ ///
+ /// A tensor of numeric type.
+ /// A name for the operation (optional).
+ /// The fraction of zeros in value, with type float32.
+ public static Tensor zero_fraction(Tensor value, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("zero_fraction", name, new { value, name });
+
+ return _op.outputs[0];
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index ecbb8958..ef20d20e 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -101,9 +101,57 @@ namespace Tensorflow
name);
}
- public static Tensor zero_fraction(Tensor t)
+ ///
+ /// Same as math_ops.count_nonzero.
+ /// The reduction is done in dtype, which can be faster for 32-bit dtypes.
+ ///
+ /// The numeric tensor.
+ /// The reduction dtype.
+ /// number of nonzero values with type dtype
+ private static Tensor _count_nonzero(Tensor input_tensor, TF_DataType dtype = TF_DataType.TF_INT64)
+ {
+ return with(ops.name_scope("count_nonzero", "count_nonzero", new { input_tensor }), scope =>
+ {
+ var zero = array_ops.zeros(new NumSharp.Shape(), dtype: input_tensor.dtype);
+ var nonzero_count = math_ops.reduce_sum(
+ math_ops.cast(gen_math_ops.not_equal(input_tensor, zero), dtype: dtype), name: "nonzero_count");
+ return nonzero_count;
+ });
+ }
+
+ ///
+ /// Returns the fraction of zeros in value.
+ ///
+ /// A tensor of numeric type.
+ /// A name for the operation (optional).
+ /// The fraction of zeros in value, with type float32.
+ public static Tensor zero_fraction(Tensor value, string name = null)
{
- throw new NotImplementedException();
+ return with(ops.name_scope(name, "zero_fraction", new { value }), scope =>
+ {
+
+ value = ops.convert_to_tensor(value, name: "value");
+ Tensor size = array_ops.size(value, out_type: dtypes.int64);
+ Func fu_true = () => math_ops.cast(_count_nonzero(value, dtype: dtypes.int32));
+ Tensor zero_fraction_float32 = null;
+
+ size = gen_math_ops.less_equal(size, dtypes.int32.max());
+ Tensor num_nonzero = control_flow_ops.cond(
+ size,
+ () => math_ops.cast(_count_nonzero(value, dtype: dtypes.int32)),
+ () => _count_nonzero(value, dtype: dtypes.int64)
+ );
+
+ with(ops.name_scope("counts_to_fraction"), count_scope =>
+ {
+ var num_zero = size - num_nonzero;
+ var num_zero_float32 = math_ops.cast(num_zero, dtype: dtypes.float32);
+ var size_float32 = math_ops.cast(size, dtype: dtypes.float32);
+ zero_fraction_float32 = num_zero_float32 / size_float32;
+ });
+
+ return array_ops.identity(zero_fraction_float32, "fraction");
+ });
}
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 5b2fedc5..2b067fe1 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -7,8 +7,11 @@ namespace Tensorflow
public static class dtypes
{
public static TF_DataType int8 = TF_DataType.TF_INT8;
+ public static TF_DataType int32 = TF_DataType.TF_INT32;
+ public static TF_DataType int64 = TF_DataType.TF_INT64;
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
+ public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static Type as_numpy_datatype(this TF_DataType type)
{
@@ -126,12 +129,24 @@ namespace Tensorflow
type;
}
- public static int max(this TF_DataType type)
+ public static long max(this TF_DataType type)
{
switch (type)
{
+ case TF_DataType.TF_INT8:
+ return sbyte.MaxValue;
+ case TF_DataType.TF_INT16:
+ return short.MaxValue;
+ case TF_DataType.TF_INT32:
+ return int.MaxValue;
+ case TF_DataType.TF_INT64:
+ return long.MaxValue;
case TF_DataType.TF_UINT8:
- return 255;
+ return byte.MaxValue;
+ case TF_DataType.TF_UINT16:
+ return ushort.MaxValue;
+ case TF_DataType.TF_UINT32:
+ return uint.MaxValue;
default:
throw new NotImplementedException($"max {type.name()}");
}
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index ddf7814f..a02e1908 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -176,6 +176,12 @@ namespace Tensorflow
else
nparray = Convert.ToInt32(values);
break;
+ case "Int64":
+ if (values.GetType().IsArray)
+ nparray = np.array((int[])values, np_dt);
+ else
+ nparray = Convert.ToInt64(values);
+ break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 48c7909c..cb61a3ba 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -188,6 +188,9 @@ namespace Tensorflow
{
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);
+ //TODO: Implement TF_SetDevice
+ //if node_def.device:
+ // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
// Add inputs
foreach (var op_input in inputs)
{
@@ -195,10 +198,7 @@ namespace Tensorflow
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
else if (op_input is Tensor op_input1)
{
- if (op_input1.op == null)
- c_api.TF_AddInput(op_desc, new TF_Output(op_desc, 0));
- else
- c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
+ c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
}
else
throw new NotImplementedException("_create_c_op");
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 744e52c3..22e9d5a0 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -23,13 +23,12 @@ namespace TensorFlowNET.UnitTest.nn_test
return 1.0 - nonzeros / (double)total_elements;
}
- [Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFraction()
{
var x_shape = new Shape(5, 17);
var x_np = np.random.randint(0, 2, x_shape);
- x_np.astype(np.float32);
+ //x_np.astype(np.float32);
var y_np = this._ZeroFraction(x_np);
var x_tf = constant_op.constant(x_np);
@@ -41,7 +40,6 @@ namespace TensorFlowNET.UnitTest.nn_test
self.assertAllClose(y_tf_np, y_np, eps);
}
- [Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFractionEmpty()
{
@@ -60,7 +58,6 @@ namespace TensorFlowNET.UnitTest.nn_test
self.assertAllClose(1.0, self.evaluate(sparsity));
}
- [Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFraction2_27Ones()
{