diff --git a/docs/RELEASE.md b/docs/RELEASE.md
index 98925ddf..62a1be23 100644
--- a/docs/RELEASE.md
+++ b/docs/RELEASE.md
@@ -4,6 +4,25 @@
This release contains contributions from many people at SciSharp as well as the external contributors.
+**Release Date 02/06/2021**
+
+### TensorFlow.Binding v0.33.0
+
+* Improve memory usage
+* Fix minor bugs
+
+### TensorFlow.Keras v0.4.0
+
+* Add Subtract layer
+
+* Add model.load_weights and model.save_weights
+
+* Fix memory leak issue
+
+* Support to build YOLOv3 object detection model
+
+
+
**Release Date 01/09/2021**
### TensorFlow.Binding v0.32.0
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 8452b81a..390942d2 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -215,6 +215,9 @@ namespace Tensorflow
public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize);
+ public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
+ => array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize);
+
public Tensor one_hot(Tensor indices, int depth,
Tensor on_value = null,
Tensor off_value = null,
@@ -290,6 +293,9 @@ namespace Tensorflow
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);
+ public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
+ => array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize);
+
///
/// Stops gradient computation.
///
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 62ba0bbd..535bbca4 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -137,6 +137,8 @@ namespace Tensorflow
{
switch (a)
{
+ case Tensors arr:
+ return arr.Length;
case Array arr:
return arr.Length;
case IList arr:
diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
index 7db178b3..b076c90f 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
@@ -28,6 +28,7 @@ namespace Tensorflow.Contexts
///
public sealed partial class Context
{
+ // [DebuggerStepThrough]
public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args)
{
if (tf.Context.has_graph_arg(args))
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index bf5324dd..1801d69a 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -388,14 +388,12 @@ namespace Tensorflow
if (dtype == TF_DataType.DtInvalid)
dtype = tensor1.dtype;
var ret = ones(ones_shape, dtype: dtype, name: name);
- ret.shape = tensor1.shape;
return ret;
});
}
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
- dtype = dtype.as_base_dtype();
return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
{
name = scope;
@@ -578,11 +576,10 @@ namespace Tensorflow
if (!tf.Context.executing_eagerly())
{
- var input_tensor = ops.convert_to_tensor(input);
- var input_shape = input_tensor.TensorShape;
- if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
+ var input_shape = input.TensorShape;
+ if (optimize && input.NDims > -1 && input_shape.is_fully_defined())
{
- var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
+ var nd = np.array(input.shape).astype(out_type.as_numpy_dtype());
return constant_op.constant(nd, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index bebb24b8..5d585e77 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -124,6 +124,9 @@ namespace Tensorflow
x, y).FirstOrDefault(),
x, y);
+ public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null)
+ => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
+
///
/// Computes the mean of elements across dimensions of a tensor.
/// Reduces `input` along the dimensions given in `axis`. Unless
@@ -137,23 +140,30 @@ namespace Tensorflow
/// An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.
/// A name for the operation (optional).
/// A `Tensor`. Has the same type as `input`.
- public static Tensor mean(T1 input, T2 axis, bool keep_dims = false, string name = null)
- {
- if (tf.Context.executing_eagerly())
- {
- var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
+ => tf.Context.RunInAutoMode2(
+ () => tf.OpDefLib._apply_op_helper("Mean", name, new
+ {
+ input,
+ reduction_indices = axis,
+ keep_dims = keep_dims
+ }).output,
+ () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Mean", name,
null,
input, axis,
- "keep_dims", keep_dims);
-
- return results[0];
- }
-
- var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
-
- return _op.output;
- }
+ "keep_dims", keep_dims).FirstOrDefault(),
+ (op) =>
+ {
+ var attrs = new object[]
+ {
+ "T", op.get_attr("T"),
+ "Tidx", op.get_attr("Tidx"),
+ "keep_dims", op.get_attr("keep_dims")
+ };
+ tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs);
+ },
+ new Tensors(input, axis));
public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null)
{
@@ -786,20 +796,21 @@ namespace Tensorflow
}
public static Tensor sub(Tensor x, Tensor y, string name = null)
- {
- if (tf.Context.executing_eagerly())
- {
- var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ => tf.Context.RunInAutoMode2(
+ () => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output,
+ () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Sub", name,
null,
- x, y);
- return results[0];
- }
-
- var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y });
-
- return _op.output;
- }
+ x, y).FirstOrDefault(),
+ (op) =>
+ {
+ var attrs = new object[]
+ {
+ "T", op.get_attr("T")
+ };
+ tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs);
+ },
+ new Tensors(x, y));
public static Tensor sub(Tx x, Ty y, string name = null)
{
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 2c051992..391ad9d5 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -327,31 +327,17 @@ namespace Tensorflow
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
{
var r = _ReductionDims(input_tensor, axis);
- if (axis == null)
- {
- var m = gen_math_ops.mean(input_tensor, r, keepdims, name);
- return _may_reduce_to_scalar(keepdims, axis, m);
- }
- else
- {
- var m = gen_math_ops.mean(input_tensor, axis, keepdims, name);
- return _may_reduce_to_scalar(keepdims, axis, m);
- }
+ var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis);
+ var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name);
+ return _may_reduce_to_scalar(keepdims, axis_tensor, m);
}
public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
{
- if (axis == null)
- {
- var r = _ReductionDims(input_tensors, axis);
- var m = gen_math_ops.mean(input_tensors, r, keepdims, name);
- return _may_reduce_to_scalar(keepdims, axis, m);
- }
- else
- {
- var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
- return _may_reduce_to_scalar(keepdims, axis, m);
- }
+ var r = _ReductionDims(input_tensors, axis);
+ var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value);
+ var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name);
+ return _may_reduce_to_scalar(keepdims, axis, m);
}
///
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs
index e9d8efdc..e331dc1a 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs
@@ -90,17 +90,17 @@ namespace Tensorflow
size *= s;
var buffer = new byte[size][];
- var src = c_api.TF_TensorData(_handle);
- src += (int)(size * 8);
+ var data_start = c_api.TF_TensorData(_handle);
+ data_start += (int)(size * sizeof(ulong));
for (int i = 0; i < buffer.Length; i++)
{
IntPtr dst = IntPtr.Zero;
ulong dstLen = 0;
- var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
+ var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
tf.Status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
- src += (int)read;
+ data_start += (int)read;
}
return buffer;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index 1c8d939a..3c334ea5 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -69,13 +69,14 @@ namespace Tensorflow
=> items.Insert(index, tensor);
IEnumerator IEnumerable.GetEnumerator()
- {
- throw new NotImplementedException();
- }
+ => GetEnumerator();
public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);
+ public static implicit operator Tensors((Tensor, Tensor) tuple)
+ => new Tensors(tuple.Item1, tuple.Item2);
+
public static implicit operator Tensors(NDArray nd)
=> new Tensors(nd);
diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md
index 20d30f6f..a08959a7 100644
--- a/tensorflowlib/README.md
+++ b/tensorflowlib/README.md
@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\
1. Build static library
-`bazel build --config=opt //tensorflow:tensorflow`
+`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow`
2. Build pip package
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index f7e6155c..62d9fa5c 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
+using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace TensorFlowNET.Keras.UnitTest
@@ -39,8 +40,8 @@ namespace TensorFlowNET.Keras.UnitTest
///
/// Custom layer test, used in Dueling DQN
///
- [TestMethod, Ignore]
- public void FunctionalTest()
+ [TestMethod]
+ public void TensorFlowOpLayer()
{
var layers = keras.layers;
var inputs = layers.Input(shape: 24);
@@ -48,58 +49,15 @@ namespace TensorFlowNET.Keras.UnitTest
var value = layers.Dense(24).Apply(x);
var adv = layers.Dense(1).Apply(x);
- var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem.
- var outputs = layers.Add().Apply(new Tensors(adv_out, value));
+ var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true);
+ adv = layers.Subtract().Apply((adv, mean));
+ var outputs = layers.Add().Apply((value, adv));
var model = keras.Model(inputs, outputs);
- model.summary();
model.compile(optimizer: keras.optimizers.RMSprop(0.001f),
loss: keras.losses.MeanSquaredError(),
metrics: new[] { "acc" });
- // Here we consider the adv_out is one layer, which is a little different from py's version
- Assert.AreEqual(model.Layers.Count, 6);
-
- // py code:
- //from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda
- //from tensorflow.keras.models import Model
- //from tensorflow.keras.optimizers import RMSprop
- //import tensorflow.keras.backend as K
-
- //inputs = Input(24)
- //x = Dense(128, activation = "relu")(inputs)
- //value = Dense(24)(x)
- //adv = Dense(1)(x)
- //meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv)
- //adv = Subtract()([adv, meam])
- //outputs = Add()([value, adv])
- //model = Model(inputs, outputs)
- //model.compile(loss = "mse", optimizer = RMSprop(1e-3))
- //model.summary()
-
- //py output:
- //Model: "functional_3"
- //__________________________________________________________________________________________________
- //Layer(type) Output Shape Param # Connected to
- //==================================================================================================
- //input_2 (InputLayer) [(None, 24)] 0
- //__________________________________________________________________________________________________
- //dense_3 (Dense) (None, 128) 3200 input_2[0][0]
- //__________________________________________________________________________________________________
- //dense_5 (Dense) (None, 1) 129 dense_3[0][0]
- //__________________________________________________________________________________________________
- //lambda_1 (Lambda) (None, 1) 0 dense_5[0][0]
- //__________________________________________________________________________________________________
- //dense_4 (Dense) (None, 24) 3096 dense_3[0][0]
- //__________________________________________________________________________________________________
- //subtract_1 (Subtract) (None, 1) 0 dense_5[0][0]
- // lambda_1[0][0]
- //__________________________________________________________________________________________________
- //add_1 (Add) (None, 24) 0 dense_4[0][0]
- // subtract_1[0][0]
- //==================================================================================================
- //Total params: 6,425
- //Trainable params: 6,425
- //Non-trainable params: 0
- //__________________________________________________________________________________________________
+ model.summary();
+ Assert.AreEqual(model.Layers.Count, 8);
}
///
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
index 9966c12e..c57c98df 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
@@ -132,28 +132,25 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
}
#region ones/zeros like
- [Ignore]
[TestMethod]
public void TestOnesLike()
{
#region 2-dimension
- var testCase2D = tf.constant(new int[,]
+ var ones2D = tf.ones_like(new int[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 }
});
- var ones2D = tf.ones_like(testCase2D);
Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy());
Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy());
#endregion
#region 1-dimension
- var testCase1D = tf.constant(new int[,]
+ var ones1D = tf.ones_like(new int[,]
{
{ 1, 2, 3 }
});
- var ones1D = tf.ones_like(testCase1D);
Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy());
#endregion
@@ -163,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
public void TestZerosLike()
{
#region 2-dimension
- var testCase2D = tf.constant(new int[,]
+ var zeros2D = tf.zeros_like(new int[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 }
});
- var zeros2D = tf.zeros_like(testCase2D);
Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy());
Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy());
#endregion
#region 1-dimension
- var testCase1D = tf.constant(new int[,]
+ var zeros1D = tf.zeros_like(new int[,]
{
{ 1, 2, 3 }
});
- var zeros1D = tf.zeros_like(testCase1D);
Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy());
#endregion
diff --git a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs b/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs
deleted file mode 100644
index 6647ca59..00000000
--- a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-using System.Collections.Generic;
-
-namespace Tensorflow.Keras.UnitTest
-{
- [TestClass]
- public class OptimizerTest
- {
-
- }
-}
diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
deleted file mode 100644
index 5f5ab347..00000000
--- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
+++ /dev/null
@@ -1,25 +0,0 @@
-
-
-
- netcoreapp3.1
-
- false
-
- AnyCPU;x64
-
-
-
-
-
-
-
- all
- runtime; build; native; contentfiles; analyzers; buildtransitive
-
-
-
-
-
-
-
-