From 77531839414faed346f5a8480038995c46d5f7f8 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Thu, 21 Mar 2019 17:40:45 -0500 Subject: [PATCH] Shape error for gradients/Sum_grad/Tile #193 --- .../Framework/common_shapes.py.cs | 5 +++++ src/TensorFlowNET.Core/Gradients/math_grad.cs | 3 ++- .../Operations/Operation.Output.cs | 3 ++- .../Operations/Operation.cs | 1 + .../Operations/gen_math_ops.cs | 4 ++-- .../Operations/math_ops.py.cs | 19 +++++++++++-------- src/TensorFlowNET.Core/Tensors/Tensor.cs | 3 ++- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 4 ++-- .../LogisticRegression.cs | 7 ------- 9 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs index 87b083d5..70bea7b0 100644 --- a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs +++ b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs @@ -34,5 +34,10 @@ namespace Tensorflow.Framework { return tensor.rank; } + + public static bool has_fully_defined_shape(Tensor tensor) + { + return tensor.getShape().is_fully_defined(); + } } } diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index d47f9732..29faaa7a 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -1,4 +1,5 @@ -using System; +//using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Linq; using System.Text; diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 3ec16704..5b0b43b3 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -1,4 +1,5 @@ -using System; +//using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 378df9c5..0979c150 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,4 +1,5 @@ using Google.Protobuf.Collections; +//using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 4cfdc7cf..a48d60c4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -207,14 +207,14 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null) + public static Tensor _sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null) { var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); return _op.outputs[0]; } - public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) + public static Tensor _sum(Tensor input, int axis, bool keep_dims = false, string name = null) { var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index c8e3f98f..f9306f40 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -212,26 +212,29 @@ namespace Tensorflow throw new NotImplementedException(); } - public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false) + public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); - var m = gen_math_ops.sum(input_tensor, r); - return _may_reduce_to_scalar(keepdims, m); + var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name); + return _may_reduce_to_scalar(keepdims, axis, m); } public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false) { - var m = gen_math_ops.sum(input_tensor, axis); - return _may_reduce_to_scalar(keepdims, m); + var m = gen_math_ops._sum(input_tensor, axis); + return _may_reduce_to_scalar(keepdims, new int[] { axis }, m); } - private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor output) + private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output) { - output.shape = new long[0]; + if (!common_shapes.has_fully_defined_shape(output) && + !keepdims && + axis == null) + output.shape = new long[0]; return output; } - private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axos, Tensor output) + private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axis, Tensor output) { output.shape = new long[0]; return output; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index bf85284a..cb8d24db 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -1,4 +1,5 @@ -using NumSharp.Core; +//using Newtonsoft.Json; +using NumSharp.Core; using System; using System.Collections.Generic; using System.Linq; diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index bdfaddf7..44b55259 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -302,7 +302,7 @@ namespace Tensorflow default: throw new NotImplementedException("as_shape Not Implemented"); } - dim.Name = $"dim_{i}"; + // dim.Name = $"dim_{i}"; shape.Dim.Add(dim); } @@ -333,7 +333,7 @@ namespace Tensorflow { var dim = new TensorShapeProto.Types.Dim(); dim.Size = tshape.Dimensions[i]; - dim.Name = $"dim_{i}"; + //dim.Name = $"dim_{i}"; shape.Dim.Add(dim); } diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 4885f633..e6168f80 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -49,13 +49,6 @@ namespace TensorFlowNET.Examples // Gradient Descent var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); - //var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin"); - - /*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings - { - Formatting = Formatting.Indented - });*/ - // Initialize the variables (i.e. assign their default value) var init = tf.global_variables_initializer();