From d3e212488feb48fab5f1ce68952ec0d86237bbce Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 11 Jul 2021 20:21:05 -0500 Subject: [PATCH] fix ndarray index. --- src/TensorFlowNET.Core/NumPy/NDArray.Index.cs | 2 +- .../Operations/Initializers/Constant.cs | 7 +- .../Initializers/InitializerArgs.cs | 4 +- .../Operations/array_ops.cs | 8 +- src/TensorFlowNET.Core/Tensors/constant_op.cs | 132 ++++++++++-------- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 6 + src/TensorFlowNET.Core/Tensors/tf.constant.cs | 8 +- src/TensorFlowNET.Core/ops.cs | 5 +- .../TensorFlowNET.Graph.UnitTest/ImageTest.cs | 5 +- 9 files changed, 97 insertions(+), 80 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 8ab82278..316ee024 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -25,7 +25,7 @@ namespace Tensorflow.NumPy { get { - return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()]; + return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()]; } set diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index cf230978..fdcb5aff 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -34,12 +34,11 @@ namespace Tensorflow.Operations.Initializers if (args.DType == TF_DataType.DtInvalid) args.DType = this.dtype; - if (!args.VerifyShape.HasValue) - args.VerifyShape = _verify_shape; + args.VerifyShape = _verify_shape; - return constant_op._constant_impl(value, args.DType, args.Shape, + return constant_op.constant(value, args.DType, args.Shape, name: "Const", - verify_shape: args.VerifyShape.Value, + verify_shape: args.VerifyShape, allow_broadcast: false); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs index 10702ece..756f33a7 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs @@ -5,11 +5,11 @@ public string Name { get; set; } public TensorShape Shape { get; set; } public TF_DataType DType { get; set; } - public bool? VerifyShape { get; set; } = null; + public bool VerifyShape { get; set; } public InitializerArgs(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, - bool? verify_shape = null, + bool verify_shape = false, string name = null) { Shape = shape; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index be10541e..9e7290ed 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -64,10 +64,10 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const", - bool verify_shape = false) => constant_op._constant_impl(value, - dtype, - shape, - name, + bool verify_shape = false) => constant_op.constant(value, + dtype: dtype, + shape: shape, + name: name, verify_shape: verify_shape, allow_broadcast: false); diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index a8870252..cf6c76a2 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -37,68 +37,14 @@ namespace Tensorflow /// Optional dimensions of resulting tensor. /// Optional name for the tensor. /// - public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") + public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, bool verify_shape = false, + bool allow_broadcast = true, string name = "Const") { - return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); - } - - /// Boolean that enables verification of a shape of values. - public static Tensor _constant_impl(object value, - TF_DataType dtype, - TensorShape shape, - string name, - bool verify_shape, - bool allow_broadcast) - { - if (tf.Context.executing_eagerly()) - { - var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); - if (shape == null) - return t; - - if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) - return t; - - if (verify_shape) - throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); - - var num_t = t.TensorShape.num_elements(); - if (num_t == shape.num_elements()) - return _eager_reshape(t, shape, tf.Context); - if (num_t == 1) - { - if (t.dtype == dtypes.@bool) - throw new NotImplementedException(""); - else - return _eager_fill(shape, t, tf.Context); - } - } - - // graph mode - Graph g = ops.get_default_graph(); - var tensor_value = new AttrValue(); - tensor_value.Tensor = tensor_util.make_tensor_proto(value, - dtype: dtype, - shape: shape, - verify_shape: verify_shape, - allow_broadcast: allow_broadcast); - - var dtype_value = new AttrValue - { - Type = tensor_value.Tensor.Dtype, - }; - - var attrs = new Dictionary(); - attrs["value"] = tensor_value; - attrs["dtype"] = dtype_value; - - var op = g.create_op("Const", - new Tensor[0], - new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, - attrs: attrs, - name: name); - - return op.outputs[0]; + if(tf.executing_eagerly()) + return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); + else + return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); } private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) @@ -189,6 +135,70 @@ namespace Tensorflow } } + static Tensor convert_to_eager_tensor(object value, + TF_DataType dtype, + TensorShape shape, + string name, + bool verify_shape, + bool allow_broadcast) + { + var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); + if (shape == null) + return t; + + if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) + return t; + + if (verify_shape) + throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); + + var num_t = t.TensorShape.num_elements(); + if (num_t == shape.num_elements()) + return _eager_reshape(t, shape, tf.Context); + if (num_t == 1) + { + if (t.dtype == dtypes.@bool) + throw new NotImplementedException(""); + else + return _eager_fill(shape, t, tf.Context); + } + + throw new NotImplementedException(""); + } + + static Tensor convert_to_graph_tensor(object value, + TF_DataType dtype, + TensorShape shape, + string name, + bool verify_shape, + bool allow_broadcast) + { + Graph g = ops.get_default_graph(); + var tensor_value = new AttrValue(); + tensor_value.Tensor = tensor_util.make_tensor_proto(value, + dtype: dtype, + shape: shape, + verify_shape: verify_shape, + allow_broadcast: allow_broadcast); + + var dtype_value = new AttrValue + { + Type = tensor_value.Tensor.Dtype, + }; + + var attrs = new Dictionary(); + attrs["value"] = tensor_value; + attrs["dtype"] = dtype_value; + + var op = g.create_op("Const", + new Tensor[0], + new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, + attrs: attrs, + name: name); + + return op.outputs[0]; + } + /// /// Function to convert TensorShape to Tensor. /// diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index d97ea1da..5ad8bc9b 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -125,6 +125,12 @@ namespace Tensorflow byte[] bytes = nd.ToByteArray(); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); } + else if (values is Tensor tensor && tensor.IsReferencedByNDArray) + { + var len = tensor.itemsize * tensor.size; + byte[] bytes = tensor.BufferToArray(); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } else if (!values.GetType().IsArray) { switch (values) diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 291e8d0c..3bf6614c 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -30,10 +30,10 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, TensorShape shape = null, string name = "Const") - => constant_op._constant_impl(value, - dtype, - shape, - name, + => constant_op.constant(value, + dtype: dtype, + shape: shape, + name: name, verify_shape: false, allow_broadcast: true); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 5e2e8287..07697d5f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -145,7 +145,10 @@ namespace Tensorflow } else if (value is Tensor tensor && tensor.IsReferencedByNDArray) { - return tensor; + if (tf.executing_eagerly()) + return tensor; + else + return constant_op.constant(tensor); } // graph mode diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs index 39a004f0..a53635d4 100644 --- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs @@ -82,15 +82,14 @@ namespace TensorFlowNET.UnitTest var result = sess.run(cropped); // check if cropped to 1x1 center was succesfull - Assert.AreEqual(result.size, 1); + Assert.AreEqual(result.size, 1ul); Assert.AreEqual(result[0, 0, 0, 0], 4f); cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); result = sess.run(cropped); // check if flipped and no cropping occured - Assert.AreEqual(result.size, 16); + Assert.AreEqual(result.size, 16ul); Assert.AreEqual(result[0, 0, 0, 0], 12f); - } } }