From 5720dfd68e7a11d8b97f98a1b1692ab161de6456 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 11 Feb 2019 22:30:31 -0600 Subject: [PATCH] fix _constant_if_small for zeros. --- .../Operations/OpDefLibrary.cs | 2 + .../Operations/array_ops.py.cs | 3 +- .../TensorFlowNET.Core.csproj | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 46 ++++++------------- .../Train/Saving/BaseSaverBuilder.cs | 2 - .../TensorFlowNET.Examples.csproj | 2 +- test/TensorFlowNET.UnitTest/ConstantTest.cs | 11 +++++ .../TensorFlowNET.UnitTest.csproj | 2 +- test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 17 ++++--- 9 files changed, 38 insertions(+), 49 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 62b4821a..1aa5f589 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -322,6 +322,8 @@ namespace Tensorflow attr_value.Shape = val1.as_proto(); else if(value is long[] val2) attr_value.Shape = tensor_util.as_shape(val2); + else if (value is int[] val3) + attr_value.Shape = tensor_util.as_shape(val3); break; default: diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 40e9f38f..f68d7747 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -34,10 +34,9 @@ namespace Tensorflow private static Tensor _constant_if_small(T value, Shape shape, TF_DataType dtype, string name) { Tensor tShape = null; - var nd = np.zeros(shape); if (shape.Size < 1000) { - return constant_op.constant(nd, name: name); + return constant_op.constant(value, shape: shape, dtype: dtype, name: name); } else { diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index d398c163..2d390319 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -43,7 +43,7 @@ TensorFlow 1.13 RC. - + diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 8bdac8d2..a7d728d2 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -56,9 +56,9 @@ namespace Tensorflow switch (values) { - /*case bool boolVal: + case bool boolVal: nparray = boolVal; - break;*/ + break; case int intVal: nparray = intVal; break; @@ -74,6 +74,9 @@ namespace Tensorflow case string strVal: nparray = strVal; break; + case string[] strVals: + nparray = strVals; + break; default: throw new Exception("make_tensor_proto Not Implemented"); } @@ -100,7 +103,8 @@ namespace Tensorflow } else { - throw new NotImplementedException("make_tensor_proto shape not implemented"); + shape_size = new TensorShape(shape).Size; + is_same_size = shape_size == nparray.size; } var tensor_proto = new tensor_pb2.TensorProto @@ -111,41 +115,17 @@ namespace Tensorflow if (is_same_size && _TENSOR_CONTENT_TYPES.Contains(numpy_dtype) && shape_size > 1) { - var bytes = new List(); - var nd2 = nparray.ravel(); - switch (nparray.dtype.Name) - { - case "Int32": - nd2.Data().Select(x => - { - bytes.AddRange(BitConverter.GetBytes(x)); - return x; - }).ToArray(); - break; - case "Single": - nd2.Data().Select(x => - { - bytes.AddRange(BitConverter.GetBytes(x)); - return x; - }).ToArray(); - break; - case "Double": - nd2.Data().Select(x => - { - bytes.AddRange(BitConverter.GetBytes(x)); - return x; - }).ToArray(); - break; - default: - throw new Exception("make_tensor_proto Not Implemented"); - } + byte[] bytes = nparray.ToByteArray(); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes.ToArray()); return tensor_proto; } - if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray) && values is string str) + if (numpy_dtype == TF_DataType.TF_STRING && !(values is NDArray)) { - tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); + if (values is string str) + tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str)); + else if (values is string[] str_values) + tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); return tensor_proto; } diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index 0c7875f9..b4cb952e 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -123,8 +123,6 @@ namespace Tensorflow Version = _write_version }; }); - - } public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 982acadb..4602f38e 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,7 +6,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index cbd1578d..8385f42c 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -58,6 +58,7 @@ namespace TensorFlowNET.UnitTest var data = result.Data(); Assert.AreEqual(0, data[0]); + Assert.AreEqual(0, data[500]); Assert.AreEqual(0, data[result.size - 1]); }); } @@ -109,5 +110,15 @@ namespace TensorFlowNET.UnitTest //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); } + + /// + /// tensorflow\c\c_api_test.cc + /// TestEncodeDecode + /// + [TestMethod] + public void EncodeDecode() + { + + } } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 5c2e39bd..f090bdc8 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -19,7 +19,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 53fab4eb..ef00dc91 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -10,11 +10,11 @@ namespace TensorFlowNET.UnitTest public class TrainSaverTest : Python { [TestMethod] - public void WriteGraph() + public void ExportGraph() { var v = tf.Variable(0, name: "my_variable"); var sess = tf.Session(); - tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); + tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); } [TestMethod] @@ -22,14 +22,13 @@ namespace TensorFlowNET.UnitTest { var v = tf.Variable(0, name: "my_variable"); var sess = tf.Session(); - tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); + tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); } [TestMethod] - public void SaveSimple() + public void Save1() { - var w1 = tf.Variable(tf.random_normal(new int[] { 2 }), name: "w1"); - var w2 = tf.Variable(tf.random_normal(new int[] { 5 }), name: "w2"); + var w1 = tf.Variable(0, name: "save1"); var init_op = tf.global_variables_initializer(); @@ -41,13 +40,13 @@ namespace TensorFlowNET.UnitTest sess.run(init_op); // Save the variables to disk. - var save_path = saver.save(sess, "/tmp/model.ckpt"); + var save_path = saver.save(sess, "/tmp/model1.ckpt"); Console.WriteLine($"Model saved in path: {save_path}"); }); } [TestMethod] - public void Save() + public void Save2() { var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); @@ -69,7 +68,7 @@ namespace TensorFlowNET.UnitTest dec_v2.op.run(); // Save the variables to disk. - var save_path = saver.save(sess, "/tmp/model.ckpt"); + var save_path = saver.save(sess, "/tmp/model2.ckpt"); Console.WriteLine($"Model saved in path: {save_path}"); }); }