From 389f7dd2e370463b89ba79bc36f4f61c31f92da8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 6 Sep 2019 23:39:59 -0500 Subject: [PATCH] fix output dtype error for control_flow.ref_switch --- src/TensorFlowNET.Core/Operations/Operation.Output.cs | 2 +- src/TensorFlowNET.Core/Operations/Operation.cs | 7 +++---- src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs | 7 +++---- .../Operations/gen_control_flow_ops.py.cs | 6 ++++++ src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 6 +++--- src/TensorFlowNET.Core/ops.GraphKeys.cs | 2 +- src/TensorFlowNET.Core/ops.cs | 4 ++-- .../ops_test/CreateOpFromTfOperationTest.cs | 4 ++-- 8 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 9701d77a..6844c892 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -24,7 +24,7 @@ namespace Tensorflow public partial class Operation { public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); - public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); + public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); public int OutputListLength(string name) { diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 856e3677..00ba8c78 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -44,7 +44,6 @@ namespace Tensorflow public partial class Operation : ITensorOrOperation { private readonly IntPtr _handle; // _c_op in python - private readonly IntPtr _operDesc; private readonly Graph _graph; private NodeDef _node_def; @@ -91,7 +90,7 @@ namespace Tensorflow { _graph = g; - _operDesc = c_api.TF_NewOperation(g, opType, oper_name); + var _operDesc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); lock (Locks.ProcessWide) using (var status = new Status()) @@ -161,7 +160,7 @@ namespace Tensorflow op_def = g.GetOpDef(node_def.Op); var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); - (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); + _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); // Initialize self._outputs. output_types = new TF_DataType[NumOutputs]; @@ -170,7 +169,7 @@ namespace Tensorflow _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); + _outputs[i] = new Tensor(this, i, output_types[i]); graph._add_op(this); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 04ef54a7..571457b9 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -275,7 +275,7 @@ namespace Tensorflow /// public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") { - data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); + data = ops.convert_to_tensor_or_composite(data, name: "data"); // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below // addresses the following scenario. // @@ -296,9 +296,8 @@ namespace Tensorflow { if (data is Tensor) { - // TODO: ref_switch - //if (data.dtype._is_ref_dtype) - // return control_flow_ops.ref_switch(data, pred, name = name); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_switch(data, pred, name: name); } return @switch(data, pred, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index bfbf3413..163a50e4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -114,6 +114,12 @@ namespace Tensorflow return _op; } + public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred }); + return _op.outputs; + } + /// /// Forwards `data` to the output port determined by `pred`. /// diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 9eae0cd9..7b3d29c3 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.2 + 0.11.3 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.2.0 + 0.11.3.0 Changes since v0.10.0: 1. Upgrade NumSharp to v0.20. 2. Add DisposableObject class to manage object lifetime. @@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io 9. MultiThread is safe. 10. Support n-dim indexing for tensor. 7.3 - 0.11.2.0 + 0.11.3.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index dad81af9..453b9d43 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -117,7 +117,7 @@ namespace Tensorflow /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. /// - public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_; + public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; public string TRAIN_OP => TRAIN_OP_; diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index aadf3b08..9aa334db 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -206,7 +206,7 @@ namespace Tensorflow /// /// A list of `Operation`s to set as control dependencies. /// A wrapped TF_Operation*. - public static (IntPtr, IntPtr) _create_c_op(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) + public static IntPtr _create_c_op(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) { lock (Locks.ProcessWide) { @@ -249,7 +249,7 @@ namespace Tensorflow status.Check(true); - return (c_op, op_desc); + return c_op; } } diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index 310ac634..08c8da2a 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test using (var g = tf.Graph().as_default()) { var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); var op = g._create_op_from_tf_operation(c_op); Assert.AreEqual("myop", op.name); @@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test var true_fn = new Func(() => { - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); + var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); var new_ops = g._add_new_tf_operations(); self.assertEqual(len(new_ops), 1); return x;