From 3b93c7b0fbcbd3dea24b69fb58f9bf1f41f92024 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Wed, 20 Feb 2019 17:39:20 -0600 Subject: [PATCH] linear regression test 1 --- src/TensorFlowNET.Core/APIs/tf.math.cs | 2 +- .../Gradients/gradients_impl.py.cs | 2 +- .../Graphs/Graph.Operation.cs | 1 + src/TensorFlowNET.Core/Graphs/Graph.cs | 1 + .../Operations/Operation.cs | 6 +- src/TensorFlowNET.Core/Python.cs | 4 +- .../Sessions/BaseSession.cs | 2 + .../TensorFlowNET.Core.csproj | 2 +- .../Tensors/Tensor.Creation.cs | 37 +++++- .../Tensors/Tensor.Operators.cs | 119 ++++++++---------- .../Tensors/c_api.tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 8 +- .../LinearRegression.cs | 37 +++--- .../TensorFlowNET.Examples.csproj | 2 +- test/TensorFlowNET.UnitTest/GradientTest.cs | 1 + .../TensorFlowNET.UnitTest.csproj | 2 +- test/TensorFlowNET.UnitTest/TensorTest.cs | 2 +- test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 6 - 18 files changed, 125 insertions(+), 111 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 44e04b3e..2b7706ac 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -18,7 +18,7 @@ namespace Tensorflow public static Tensor divide(Tensor x, T[] y, string name = "") where T : struct => x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); - public static Tensor pow(Tensor x, double y) => gen_math_ops.pow(x, y); + public static Tensor pow(T1 x, T2 y) => gen_math_ops.pow(x, y); /// /// Computes the sum of elements across dimensions of a tensor. diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index fbfce698..e4993b5e 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -357,7 +357,7 @@ namespace Tensorflow if (y.dtype.is_complex()) throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); var shape = array_ops.shape(y); - var constant = constant_op.constant(1.0, name: $"grad_ys_{i}"); + var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); var fill = gen_array_ops.fill(shape, constant); new_grad_ys.Add(fill); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index a9ad7a14..2608befc 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -41,6 +41,7 @@ namespace Tensorflow public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) { var ret = new Operation(c_op); + _add_op(ret); var name_key = ret.name.ToLower(); if (!_names_in_use.ContainsKey(name_key)) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 28f86a8a..f85f4eff 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -212,6 +212,7 @@ namespace Tensorflow public void _add_op(Operation op) { + op._id_value = _next_id(); _nodes_by_id[op._id] = op; _nodes_by_name[op.name] = op; _version = Math.Max(_version, op._id); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 45a57286..7e7a56c5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -13,7 +13,7 @@ namespace Tensorflow public Graph graph { get; } public int _id => _id_value; - private int _id_value; + public int _id_value; public string type => OpType; public Operation op => this; @@ -46,8 +46,6 @@ namespace Tensorflow _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) _outputs[i] = new Tensor(this, i, OutputType(i)); - - graph._add_op(this); } public Operation(Graph g, string opType, string oper_name) @@ -100,8 +98,6 @@ namespace Tensorflow } // This will be set by self.inputs. - - _id_value = graph._next_id(); if(op_def == null) op_def = g.GetOpDef(node_def.Op); diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index b077bfc3..7ea80a90 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -88,8 +88,8 @@ namespace Tensorflow public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2) { - int index = 0; - yield return(t1.Data(index), t2.Data(index)); + for (int i = 0; i < t1.size; i++) + yield return (t1.Data(i), t2.Data(i)); } public static IEnumerable<(T1, T2)> zip(IList t1, IList t2) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 9c058413..198cd509 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -59,6 +59,7 @@ namespace Tensorflow { var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); + switch (subfeed_val) { case IntPtr pointer: @@ -86,6 +87,7 @@ namespace Tensorflow Console.WriteLine($"can't handle data type of subfeed_val"); throw new NotImplementedException("_run subfeed"); } + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); } } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index eda472c1..8a28fc73 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -45,7 +45,7 @@ Upgraded to TensorFlow 1.13 RC2. - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index af085dc2..73713d36 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -52,20 +52,45 @@ namespace Tensorflow var dataType = ToTFDataType(nd.dtype); // shape var dims = nd.shape.Select(x => (long)x).ToArray(); - + var nd1 = nd.ravel(); switch (nd.dtype.Name) { case "Int16": - Marshal.Copy(nd.ravel().Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; case "Int32": - Marshal.Copy(nd.ravel().Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; case "Single": - Marshal.Copy(nd.ravel().Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + /*if (nd.size > 1) + { + var bb = nd.Data(); + var bytes = Marshal.AllocHGlobal(bb.Length); + Marshal.Copy(bb, 0, bytes, bb.Length); + ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); + var dataTypeByte = ToTFDataType(nd.dtype); + // shape + var dims2 = nd.shape.Select(x => (long)x).ToArray(); + + var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, + dims2, + nd.ndim, + bytes_len + sizeof(Int64)); + + dotHandle = c_api.TF_TensorData(tfHandle2); + Marshal.WriteInt64(dotHandle, 0); + c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); + return tfHandle2; + } + else + { + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + }*/ + break; case "Double": - Marshal.Copy(nd.ravel().Data(), 0, dotHandle, nd.size); + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; //case "Byte": /*var bb = nd.Data(); @@ -119,7 +144,7 @@ namespace Tensorflow dims, dims.Length, dotHandle, - size, + (UIntPtr)size, deallocator, ref deallocator_called); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 77c5c5c0..b428deee 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -6,83 +6,70 @@ namespace Tensorflow { public partial class Tensor { - public static Tensor operator +(Tensor x, Tensor y) - { - return Python.with(new ops.name_scope("", "add", new Tensor[] { x, y }), scope => - { - return gen_math_ops.add(x, y, scope); - }); - } - - public static Tensor operator +(Tensor x, int y) - { - return Python.with(new ops.name_scope("", "add", new object[] { x, y }), scope => - { - var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); - return gen_math_ops.add(x, y1, scope); - }); - } + public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); + public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); - public static Tensor operator -(Tensor t1, Tensor t2) => gen_math_ops.sub(t1, t2); - public static Tensor operator -(Tensor t1, int t2) => gen_math_ops.sub(t1, t2); - public static Tensor operator -(Tensor t1, double t2) => gen_math_ops.sub(t1, t2); - public static Tensor operator *(double x, Tensor y) - { - return Python.with(new ops.name_scope("", "mul", new { x, y }), - scope => - { - var x1 = ops.convert_to_tensor(x, y.dtype.as_base_dtype(), name: "x"); - return gen_math_ops.mul(x1, y, name: scope); - }); - } + public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); + public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); + public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator *(Tensor x, Tensor y) - { - return Python.with(new ops.name_scope("", "mul", new Tensor[] { x, y }), scope => - { - return gen_math_ops.mul(x, y, name: scope); - }); - } + public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); + public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); + public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y); + public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, int y) - { - return Python.with(new ops.name_scope("", "mul", new object[] { x, y }), scope => - { - var y1 = ops.convert_to_tensor(y, x.dtype.as_base_dtype(), name: "y"); - return gen_math_ops.mul(x, y1, name: scope); - }); - } + public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y); + public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); + public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, Tensor y) - { - return Python.with(new ops.name_scope("truediv/", "truediv", new Tensor[] { x, y }), scope => - { - return gen_math_ops.real_div(x, y, scope); - }); - } + public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); - public static Tensor operator /(Tensor x, double y) - { - return Python.with(new ops.name_scope("truediv/", "truediv", new object[] { x, y }), scope => - { - var y1 = ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"); - return gen_math_ops.real_div(x, y1, scope); - }); - } + public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); + public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); + public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); + public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); - public static Tensor operator %(Tensor x, Tensor y) + private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { - return Python.with(new ops.name_scope("", "mod", new object[] { x, y }), scope => + TF_DataType dtype = TF_DataType.DtInvalid; + if (x is Tensor tl) + dtype = tl.dtype.as_base_dtype(); + if( y is Tensor tr) + dtype = tr.dtype.as_base_dtype(); + + var namescope = new ops.name_scope("", name, new { x, y }); + return Python.with(namescope, scope => { - return gen_math_ops.floor_mod(x, y, scope); + Tensor result = null; + var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); + var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); + + switch (name) + { + case "add": + result = gen_math_ops.add(x1, y1, name: scope); + break; + case "truediv": + result = gen_math_ops.real_div(x1, y1, name: scope); + break; + case "mul": + result = gen_math_ops.mul(x1, y1, name: scope); + break; + case "sub": + result = gen_math_ops.sub(x1, y1, name: scope); + break; + case "mod": + result = gen_math_ops.floor_mod(x1, y1, name: scope); + break; + default: + throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}"); + } + + return result; }); + } - - public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); - public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); - public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); - public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); } } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 94c4683b..875e8a0a 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -55,7 +55,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); + public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); /// /// Return the number of dimensions that the tensor has. diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 9ec4ef7b..6c90faae 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -96,19 +96,19 @@ namespace Tensorflow if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = (int)values; + nparray = Convert.ToInt32(values); break; case "Single": if (values.GetType().IsArray) nparray = np.array((float[])values, np_dt); else - nparray = (float)values; + nparray = Convert.ToSingle(values); break; case "Double": - nparray = (double)values; + nparray = Convert.ToDouble(values); break; case "String": - nparray = values.ToString(); + nparray = Convert.ToString(values); break; default: throw new NotImplementedException("make_tensor_proto Not Implemented"); diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 65175cfb..4008bd72 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -10,41 +10,47 @@ namespace TensorFlowNET.Examples /// A linear regression learning algorithm example using TensorFlow library. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py /// - public class LinearRegression : IExample + public class LinearRegression : Python, IExample { private NumPyRandom rng = np.random; public void Run() { + var graph = tf.Graph().as_default(); + // Parameters - double learning_rate = 0.01; + float learning_rate = 0.01f; int training_epochs = 1000; - int display_step = 50; + int display_step = 1; // Training Data - var train_X = np.array(3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, - 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1); - var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, - 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); + var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, + 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); + var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, + 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); var n_samples = train_X.shape[0]; // tf Graph Input - var X = tf.placeholder(tf.float64); - var Y = tf.placeholder(tf.float64); + var X = tf.placeholder(tf.float32); + var Y = tf.placeholder(tf.float32); // Set model weights - var W = tf.Variable(rng.randn(), name: "weight"); - var b = tf.Variable(rng.randn(), name: "bias"); + //var rnd1 = rng.randn(); + //var rnd2 = rng.randn(); + var W = tf.Variable(-0.06f, name: "weight"); + var b = tf.Variable(-0.73f, name: "bias"); var mul = tf.multiply(X, W); var pred = tf.add(mul, b); // Mean squared error var sub = pred - Y; - var pow = tf.pow(sub, 2); + var pow = tf.pow(sub, 2.0f); var reduce = tf.reduce_sum(pow); - var cost = reduce / (2d * n_samples); + var cost = reduce / (2.0f * n_samples); + + // import graph // radient descent // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default @@ -55,7 +61,7 @@ namespace TensorFlowNET.Examples var init = tf.global_variables_initializer(); // Start training - Python.with(tf.Session(), sess => + Python.with(tf.Session(graph), sess => { // Run the initializer sess.run(init); @@ -63,11 +69,12 @@ namespace TensorFlowNET.Examples // Fit all training data for (int epoch = 0; epoch < training_epochs; epoch++) { - foreach (var (x, y) in Python.zip(train_X, train_Y)) + foreach (var (x, y) in zip(train_X, train_Y)) { sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); + var w = sess.run(W); } // Display logs per epoch step diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 8667d4c2..9dc1bd17 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,7 +6,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index bc887764..f22ca2c7 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -12,6 +12,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Gradients() { + var graph = tf.Graph().as_default(); var a = tf.constant(0.0); var b = 2.0 * a; Assert.AreEqual(b.name, "mul:0"); diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 19bd6b44..2ff58949 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -19,7 +19,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 03116359..a1733002 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); - Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), array)); + Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 4, 2, 5, 3, 6 })); } /// diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 28c8a242..c6023402 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest [TestClass] public class TrainSaverTest : Python { - [TestMethod] public void ExportGraph() { var v = tf.Variable(0, name: "my_variable"); @@ -18,7 +17,6 @@ namespace TensorFlowNET.UnitTest tf.train.write_graph(sess.graph, "/tmp/my-model", "train1.pbtxt"); } - [TestMethod] public void ImportGraph() { with(tf.Session(), sess => @@ -27,7 +25,6 @@ namespace TensorFlowNET.UnitTest }); } - [TestMethod] public void ImportSavedModel() { with(Session.LoadFromSavedModel("mobilenet"), sess => @@ -36,14 +33,12 @@ namespace TensorFlowNET.UnitTest }); } - [TestMethod] public void ImportGraphDefFromPbFile() { var g = new Graph(); var status = g.Import("mobilenet/saved_model.pb"); } - [TestMethod] public void Save1() { var w1 = tf.Variable(0, name: "save1"); @@ -63,7 +58,6 @@ namespace TensorFlowNET.UnitTest }); } - [TestMethod] public void Save2() { var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);