From 9edada5abc099104d376d4ed8541ec960066920f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 1 Oct 2019 23:19:40 -0500 Subject: [PATCH] fix Tensor set_shape() --- .../Gradients/array_grad.cs | 6 +-- .../Gradients/gradients_util.cs | 4 +- .../Operations/array_ops.py.cs | 2 +- .../Operations/gen_array_ops.cs | 2 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 37 ++++++++----------- 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index baa7145d..f07d2825 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -196,13 +196,13 @@ namespace Tensorflow.Gradients var grad = grads[0]; var x = op.inputs[0]; var a = op.inputs[1]; - var pad_before = array_ops.slice(a, new[] { 0, 0 }, - new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) }); + var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); + var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); // Make it a 1-D tensor. var begin = array_ops.reshape(pad_before, new[] { -1 }); var sizes = array_ops.shape(x); - var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes }); + var x_grad = array_ops.slice(grad, begin, sizes); if (len(op.inputs) == 3) return new Tensor[] { x_grad, null, null }; diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index f10a5fed..9029fb8f 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -108,7 +108,7 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(tf.get_default_graph()._nodes_by_name.Count >= 20611) + if(tf.get_default_graph()._nodes_by_name.Count >= 23868) { } @@ -216,7 +216,7 @@ namespace Tensorflow in_grad.Tag == null && // maybe a IndexedSlice t_in.dtype != TF_DataType.TF_RESOURCE) { - in_grad.shape = t_in.shape; + in_grad.set_shape(t_in.TensorShape); } _SetGrad(grads, t_in, in_grad); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 3e2276c6..12094e41 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -611,7 +611,7 @@ namespace Tensorflow }); } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) => gen_array_ops.slice(input, begin, size, name: name); public static Tensor stack(object values, int axis = 0, string name = "stack") diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 59b43766..36837477 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -475,7 +475,7 @@ namespace Tensorflow return op.output; } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) { var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index f3ad2efd..fb8e2457 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -105,10 +105,13 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { - var status = new Status(); - c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); - status.Check(); - } else + using (var status = new Status()) + { + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); + status.Check(); + } + } + else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); @@ -119,14 +122,15 @@ namespace Tensorflow set { - var status = new Status(); - - if (value == null) - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); - else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); + using (var status = new Status()) + { + if (value == null) + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); + else + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); - status.Check(true); + status.Check(true); + } } } @@ -142,16 +146,7 @@ namespace Tensorflow /// public void set_shape(TensorShape shape) { - this.shape = (int[]) shape.dims.Clone(); - } - - /// - /// Updates the shape of this tensor. - /// - [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] - public void SetShape(TensorShape shape) - { - this.shape = (int[]) shape.dims.Clone(); + this.shape = shape.rank > 0 ? shape.dims : null; } ///