Browse Source

fix scalar ndarray to tensor proto.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
a129e61feb
4 changed files with 36 additions and 6 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
  2. +1
    -1
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  3. +33
    -4
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  4. +1
    -0
      test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

+ 1
- 1
src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.NumPy
public static NDArray squeeze(NDArray x1, Axis? axis = null) => new NDArray(array_ops.squeeze(x1, axis));

[AutoNumPy]
public static NDArray stack(NDArray arrays, Axis axis = null) => new NDArray(array_ops.stack(arrays, axis ?? 0));
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));

[AutoNumPy]
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");


+ 1
- 1
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -94,7 +94,7 @@ namespace Tensorflow.NumPy
{
-1 => "<unknown>",
0 => "()",
1 => $"({shape.dims[0]},)",
1 => $"({shape.dims[0].ToString().Replace("-1", "None")},)",
_ => $"({string.Join(", ", shape.dims).Replace("-1", "None")})"
};
}


+ 33
- 4
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -134,12 +134,41 @@ namespace Tensorflow
TensorShape = shape.as_shape_proto()
};

// scalar
if (values is NDArray nd)
{
var len = nd.dtypesize * nd.size;
byte[] bytes = nd.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
// scalar
if (nd.shape.IsScalar)
{
switch (nd.dtype)
{
case TF_DataType.TF_BOOL:
tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
break;
case TF_DataType.TF_UINT8:
tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x).ToArray());
break;
case TF_DataType.TF_INT32:
tensor_proto.IntVal.AddRange(nd.ToArray<int>());
break;
case TF_DataType.TF_INT64:
tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
break;
case TF_DataType.TF_FLOAT:
tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
break;
case TF_DataType.TF_DOUBLE:
tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
}
}
else
{
var len = nd.dtypesize * nd.size;
byte[] bytes = nd.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
}
else if (dtype == TF_DataType.TF_STRING && !(values is NDArray))
{


+ 1
- 0
test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs View File

@@ -21,6 +21,7 @@ namespace TensorFlowNET.UnitTest.NumPy
var zeros = np.zeros((2, 2));
var ones = np.ones((2, 2));
var full = np.full((2, 2), 0.1f);
Assert.AreEqual(np.float32, full.dtype);
}

[TestMethod]


Loading…
Cancel
Save