diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 98925ddf..62a1be23 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -4,6 +4,25 @@ This release contains contributions from many people at SciSharp as well as the external contributors. +**Release Date 02/06/2021** + +### TensorFlow.Binding v0.33.0 + +* Improve memory usage +* Fix minor bugs + +### TensorFlow.Keras v0.4.0 + +* Add Subtract layer + +* Add model.load_weights and model.save_weights + +* Fix memory leak issue + +* Support to build YOLOv3 object detection model + + + **Release Date 01/09/2021** ### TensorFlow.Binding v0.32.0 diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 8452b81a..390942d2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -215,6 +215,9 @@ namespace Tensorflow public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize); + public Tensor one_hot(Tensor indices, int depth, Tensor on_value = null, Tensor off_value = null, @@ -290,6 +293,9 @@ namespace Tensorflow public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize); + /// /// Stops gradient computation. /// diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 62ba0bbd..535bbca4 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -137,6 +137,8 @@ namespace Tensorflow { switch (a) { + case Tensors arr: + return arr.Length; case Array arr: return arr.Length; case IList arr: diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 7db178b3..b076c90f 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -28,6 +28,7 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { + // [DebuggerStepThrough] public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) { if (tf.Context.has_graph_arg(args)) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index bf5324dd..1801d69a 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -388,14 +388,12 @@ namespace Tensorflow if (dtype == TF_DataType.DtInvalid) dtype = tensor1.dtype; var ret = ones(ones_shape, dtype: dtype, name: name); - ret.shape = tensor1.shape; return ret; }); } public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { - dtype = dtype.as_base_dtype(); return tf_with(ops.name_scope(name, "ones", new { shape }), scope => { name = scope; @@ -578,11 +576,10 @@ namespace Tensorflow if (!tf.Context.executing_eagerly()) { - var input_tensor = ops.convert_to_tensor(input); - var input_shape = input_tensor.TensorShape; - if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) + var input_shape = input.TensorShape; + if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); + var nd = np.array(input.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index bebb24b8..5d585e77 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -124,6 +124,9 @@ namespace Tensorflow x, y).FirstOrDefault(), x, y); + public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) + => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + /// /// Computes the mean of elements across dimensions of a tensor. /// Reduces `input` along the dimensions given in `axis`. Unless @@ -137,23 +140,30 @@ namespace Tensorflow /// An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1. /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. - public static Tensor mean(T1 input, T2 axis, bool keep_dims = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Mean", name, new + { + input, + reduction_indices = axis, + keep_dims = keep_dims + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Mean", name, null, input, axis, - "keep_dims", keep_dims); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); - - return _op.output; - } + "keep_dims", keep_dims).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "Tidx", op.get_attr("Tidx"), + "keep_dims", op.get_attr("keep_dims") + }; + tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs); + }, + new Tensors(input, axis)); public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null) { @@ -786,20 +796,21 @@ namespace Tensorflow } public static Tensor sub(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Sub", name, null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y }); - - return _op.output; - } + x, y).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs); + }, + new Tensors(x, y)); public static Tensor sub(Tx x, Ty y, string name = null) { diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 2c051992..391ad9d5 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -327,31 +327,17 @@ namespace Tensorflow public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) { var r = _ReductionDims(input_tensor, axis); - if (axis == null) - { - var m = gen_math_ops.mean(input_tensor, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensor, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); + var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis_tensor, m); } public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) { - if (axis == null) - { - var r = _ReductionDims(input_tensors, axis); - var m = gen_math_ops.mean(input_tensors, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var r = _ReductionDims(input_tensors, axis); + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value); + var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); } /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index e9d8efdc..e331dc1a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -90,17 +90,17 @@ namespace Tensorflow size *= s; var buffer = new byte[size][]; - var src = c_api.TF_TensorData(_handle); - src += (int)(size * 8); + var data_start = c_api.TF_TensorData(_handle); + data_start += (int)(size * sizeof(ulong)); for (int i = 0; i < buffer.Length; i++) { IntPtr dst = IntPtr.Zero; ulong dstLen = 0; - var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); + var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); tf.Status.Check(true); buffer[i] = new byte[(int)dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int)read; + data_start += (int)read; } return buffer; diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 1c8d939a..3c334ea5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -69,13 +69,14 @@ namespace Tensorflow => items.Insert(index, tensor); IEnumerator IEnumerable.GetEnumerator() - { - throw new NotImplementedException(); - } + => GetEnumerator(); public static implicit operator Tensors(Tensor tensor) => new Tensors(tensor); + public static implicit operator Tensors((Tensor, Tensor) tuple) + => new Tensors(tuple.Item1, tuple.Item2); + public static implicit operator Tensors(NDArray nd) => new Tensors(nd); diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 20d30f6f..a08959a7 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ 1. Build static library -`bazel build --config=opt //tensorflow:tensorflow` +`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` 2. Build pip package diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index f7e6155c..62d9fa5c 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using Tensorflow; +using static Tensorflow.Binding; using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest @@ -39,8 +40,8 @@ namespace TensorFlowNET.Keras.UnitTest /// /// Custom layer test, used in Dueling DQN /// - [TestMethod, Ignore] - public void FunctionalTest() + [TestMethod] + public void TensorFlowOpLayer() { var layers = keras.layers; var inputs = layers.Input(shape: 24); @@ -48,58 +49,15 @@ namespace TensorFlowNET.Keras.UnitTest var value = layers.Dense(24).Apply(x); var adv = layers.Dense(1).Apply(x); - var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem. - var outputs = layers.Add().Apply(new Tensors(adv_out, value)); + var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); + adv = layers.Subtract().Apply((adv, mean)); + var outputs = layers.Add().Apply((value, adv)); var model = keras.Model(inputs, outputs); - model.summary(); model.compile(optimizer: keras.optimizers.RMSprop(0.001f), loss: keras.losses.MeanSquaredError(), metrics: new[] { "acc" }); - // Here we consider the adv_out is one layer, which is a little different from py's version - Assert.AreEqual(model.Layers.Count, 6); - - // py code: - //from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda - //from tensorflow.keras.models import Model - //from tensorflow.keras.optimizers import RMSprop - //import tensorflow.keras.backend as K - - //inputs = Input(24) - //x = Dense(128, activation = "relu")(inputs) - //value = Dense(24)(x) - //adv = Dense(1)(x) - //meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv) - //adv = Subtract()([adv, meam]) - //outputs = Add()([value, adv]) - //model = Model(inputs, outputs) - //model.compile(loss = "mse", optimizer = RMSprop(1e-3)) - //model.summary() - - //py output: - //Model: "functional_3" - //__________________________________________________________________________________________________ - //Layer(type) Output Shape Param # Connected to - //================================================================================================== - //input_2 (InputLayer) [(None, 24)] 0 - //__________________________________________________________________________________________________ - //dense_3 (Dense) (None, 128) 3200 input_2[0][0] - //__________________________________________________________________________________________________ - //dense_5 (Dense) (None, 1) 129 dense_3[0][0] - //__________________________________________________________________________________________________ - //lambda_1 (Lambda) (None, 1) 0 dense_5[0][0] - //__________________________________________________________________________________________________ - //dense_4 (Dense) (None, 24) 3096 dense_3[0][0] - //__________________________________________________________________________________________________ - //subtract_1 (Subtract) (None, 1) 0 dense_5[0][0] - // lambda_1[0][0] - //__________________________________________________________________________________________________ - //add_1 (Add) (None, 24) 0 dense_4[0][0] - // subtract_1[0][0] - //================================================================================================== - //Total params: 6,425 - //Trainable params: 6,425 - //Non-trainable params: 0 - //__________________________________________________________________________________________________ + model.summary(); + Assert.AreEqual(model.Layers.Count, 8); } /// diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs index 9966c12e..c57c98df 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -132,28 +132,25 @@ namespace TensorFlowNET.UnitTest.ManagedAPI } #region ones/zeros like - [Ignore] [TestMethod] public void TestOnesLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var ones2D = tf.ones_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var ones2D = tf.ones_like(testCase2D); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy()); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var ones1D = tf.ones_like(new int[,] { { 1, 2, 3 } }); - var ones1D = tf.ones_like(testCase1D); Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy()); #endregion @@ -163,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI public void TestZerosLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var zeros2D = tf.zeros_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var zeros2D = tf.zeros_like(testCase2D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy()); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var zeros1D = tf.zeros_like(new int[,] { { 1, 2, 3 } }); - var zeros1D = tf.zeros_like(testCase1D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy()); #endregion diff --git a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs b/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs deleted file mode 100644 index 6647ca59..00000000 --- a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Collections.Generic; - -namespace Tensorflow.Keras.UnitTest -{ - [TestClass] - public class OptimizerTest - { - - } -} diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj deleted file mode 100644 index 5f5ab347..00000000 --- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ /dev/null @@ -1,25 +0,0 @@ - - - - netcoreapp3.1 - - false - - AnyCPU;x64 - - - - - - - - all - runtime; build; native; contentfiles; analyzers; buildtransitive - - - - - - - -