From a129e61feb4a4184100661f8b1f4007ba54258a3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 27 Jul 2021 12:40:01 -0500 Subject: [PATCH] fix scalar ndarray to tensor proto. --- .../NumPy/Numpy.Manipulation.cs | 2 +- src/TensorFlowNET.Core/NumPy/ShapeHelper.cs | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 37 +++++++++++++++++-- .../Numpy/Array.Creation.Test.cs | 1 + 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index a1a3d79b..afecd7b9 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -18,7 +18,7 @@ namespace Tensorflow.NumPy public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis)); [AutoNumPy] - public static NDArray stack(NDArray arrays, Axis axis = null) => new NDArray(array_ops.stack(arrays, axis ?? 0)); + public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); [AutoNumPy] public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index 6e2a7926..dec43e83 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -94,7 +94,7 @@ namespace Tensorflow.NumPy { -1 => "", 0 => "()", - 1 => $"({shape.dims[0]},)", + 1 => $"({shape.dims[0].ToString().Replace("-1", "None")},)", _ => $"({string.Join(", ", shape.dims).Replace("-1", "None")})" }; } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 243b73d3..96d1e7b4 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -134,12 +134,41 @@ namespace Tensorflow TensorShape = shape.as_shape_proto() }; - // scalar if (values is NDArray nd) { - var len = nd.dtypesize * nd.size; - byte[] bytes = nd.ToByteArray(); - tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + // scalar + if (nd.shape.IsScalar) + { + switch (nd.dtype) + { + case TF_DataType.TF_BOOL: + tensor_proto.BoolVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_UINT8: + tensor_proto.IntVal.AddRange(nd.ToArray().Select(x => (int)x).ToArray()); + break; + case TF_DataType.TF_INT32: + tensor_proto.IntVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_INT64: + tensor_proto.Int64Val.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_FLOAT: + tensor_proto.FloatVal.AddRange(nd.ToArray()); + break; + case TF_DataType.TF_DOUBLE: + tensor_proto.DoubleVal.AddRange(nd.ToArray()); + break; + default: + throw new Exception("make_tensor_proto Not Implemented"); + } + } + else + { + var len = nd.dtypesize * nd.size; + byte[] bytes = nd.ToByteArray(); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } } else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) { diff --git a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs index aa0652f2..799d40c4 100644 --- a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs +++ b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs @@ -21,6 +21,7 @@ namespace TensorFlowNET.UnitTest.NumPy var zeros = np.zeros((2, 2)); var ones = np.ones((2, 2)); var full = np.full((2, 2), 0.1f); + Assert.AreEqual(np.float32, full.dtype); } [TestMethod]