@@ -4,6 +4,25 @@ | |||
This release contains contributions from many people at SciSharp as well as the external contributors. | |||
**Release Date 02/06/2021** | |||
### TensorFlow.Binding v0.33.0 | |||
* Improve memory usage | |||
* Fix minor bugs | |||
### TensorFlow.Keras v0.4.0 | |||
* Add Subtract layer | |||
* Add model.load_weights and model.save_weights | |||
* Fix memory leak issue | |||
* Support to build YOLOv3 object detection model | |||
**Release Date 01/09/2021** | |||
### TensorFlow.Binding v0.32.0 | |||
@@ -215,6 +215,9 @@ namespace Tensorflow | |||
public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | |||
=> array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); | |||
public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | |||
=> array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize); | |||
public Tensor one_hot(Tensor indices, int depth, | |||
Tensor on_value = null, | |||
Tensor off_value = null, | |||
@@ -290,6 +293,9 @@ namespace Tensorflow | |||
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | |||
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); | |||
public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | |||
=> array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize); | |||
/// <summary> | |||
/// Stops gradient computation. | |||
/// </summary> | |||
@@ -137,6 +137,8 @@ namespace Tensorflow | |||
{ | |||
switch (a) | |||
{ | |||
case Tensors arr: | |||
return arr.Length; | |||
case Array arr: | |||
return arr.Length; | |||
case IList arr: | |||
@@ -28,6 +28,7 @@ namespace Tensorflow.Contexts | |||
/// </summary> | |||
public sealed partial class Context | |||
{ | |||
// [DebuggerStepThrough] | |||
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args) | |||
{ | |||
if (tf.Context.has_graph_arg(args)) | |||
@@ -388,14 +388,12 @@ namespace Tensorflow | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = tensor1.dtype; | |||
var ret = ones(ones_shape, dtype: dtype, name: name); | |||
ret.shape = tensor1.shape; | |||
return ret; | |||
}); | |||
} | |||
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
return tf_with(ops.name_scope(name, "ones", new { shape }), scope => | |||
{ | |||
name = scope; | |||
@@ -578,11 +576,10 @@ namespace Tensorflow | |||
if (!tf.Context.executing_eagerly()) | |||
{ | |||
var input_tensor = ops.convert_to_tensor(input); | |||
var input_shape = input_tensor.TensorShape; | |||
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | |||
var input_shape = input.TensorShape; | |||
if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) | |||
{ | |||
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||
var nd = np.array(input.shape).astype(out_type.as_numpy_dtype()); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
} | |||
@@ -124,6 +124,9 @@ namespace Tensorflow | |||
x, y).FirstOrDefault(), | |||
x, y); | |||
public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) | |||
=> mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
/// <summary> | |||
/// Computes the mean of elements across dimensions of a tensor. | |||
/// Reduces `input` along the dimensions given in `axis`. Unless | |||
@@ -137,23 +140,30 @@ namespace Tensorflow | |||
/// <param name="keep_dims"> An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.</param> | |||
/// <param name="name"> A name for the operation (optional).</param> | |||
/// <returns> A `Tensor`. Has the same type as `input`.</returns> | |||
public static Tensor mean<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) | |||
=> tf.Context.RunInAutoMode2( | |||
() => tf.OpDefLib._apply_op_helper("Mean", name, new | |||
{ | |||
input, | |||
reduction_indices = axis, | |||
keep_dims = keep_dims | |||
}).output, | |||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Mean", name, | |||
null, | |||
input, axis, | |||
"keep_dims", keep_dims); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | |||
return _op.output; | |||
} | |||
"keep_dims", keep_dims).FirstOrDefault(), | |||
(op) => | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"T", op.get_attr<TF_DataType>("T"), | |||
"Tidx", op.get_attr<TF_DataType>("Tidx"), | |||
"keep_dims", op.get_attr<bool>("keep_dims") | |||
}; | |||
tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs); | |||
}, | |||
new Tensors(input, axis)); | |||
public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null) | |||
{ | |||
@@ -786,20 +796,21 @@ namespace Tensorflow | |||
} | |||
public static Tensor sub(Tensor x, Tensor y, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
=> tf.Context.RunInAutoMode2( | |||
() => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output, | |||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sub", name, | |||
null, | |||
x, y); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y }); | |||
return _op.output; | |||
} | |||
x, y).FirstOrDefault(), | |||
(op) => | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"T", op.get_attr<TF_DataType>("T") | |||
}; | |||
tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs); | |||
}, | |||
new Tensors(x, y)); | |||
public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | |||
{ | |||
@@ -327,31 +327,17 @@ namespace Tensorflow | |||
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | |||
{ | |||
var r = _ReductionDims(input_tensor, axis); | |||
if (axis == null) | |||
{ | |||
var m = gen_math_ops.mean(input_tensor, r, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
else | |||
{ | |||
var m = gen_math_ops.mean(input_tensor, axis, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); | |||
var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis_tensor, m); | |||
} | |||
public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) | |||
{ | |||
if (axis == null) | |||
{ | |||
var r = _ReductionDims(input_tensors, axis); | |||
var m = gen_math_ops.mean(input_tensors, r, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
else | |||
{ | |||
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
var r = _ReductionDims(input_tensors, axis); | |||
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value); | |||
var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
/// <summary> | |||
@@ -90,17 +90,17 @@ namespace Tensorflow | |||
size *= s; | |||
var buffer = new byte[size][]; | |||
var src = c_api.TF_TensorData(_handle); | |||
src += (int)(size * 8); | |||
var data_start = c_api.TF_TensorData(_handle); | |||
data_start += (int)(size * sizeof(ulong)); | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
IntPtr dst = IntPtr.Zero; | |||
ulong dstLen = 0; | |||
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); | |||
var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
buffer[i] = new byte[(int)dstLen]; | |||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
src += (int)read; | |||
data_start += (int)read; | |||
} | |||
return buffer; | |||
@@ -69,13 +69,14 @@ namespace Tensorflow | |||
=> items.Insert(index, tensor); | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
=> GetEnumerator(); | |||
public static implicit operator Tensors(Tensor tensor) | |||
=> new Tensors(tensor); | |||
public static implicit operator Tensors((Tensor, Tensor) tuple) | |||
=> new Tensors(tuple.Item1, tuple.Item2); | |||
public static implicit operator Tensors(NDArray nd) | |||
=> new Tensors(nd); | |||
@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ | |||
1. Build static library | |||
`bazel build --config=opt //tensorflow:tensorflow` | |||
`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` | |||
2. Build pip package | |||
@@ -1,6 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace TensorFlowNET.Keras.UnitTest | |||
@@ -39,8 +40,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||
/// <summary> | |||
/// Custom layer test, used in Dueling DQN | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
public void FunctionalTest() | |||
[TestMethod] | |||
public void TensorFlowOpLayer() | |||
{ | |||
var layers = keras.layers; | |||
var inputs = layers.Input(shape: 24); | |||
@@ -48,58 +49,15 @@ namespace TensorFlowNET.Keras.UnitTest | |||
var value = layers.Dense(24).Apply(x); | |||
var adv = layers.Dense(1).Apply(x); | |||
var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem. | |||
var outputs = layers.Add().Apply(new Tensors(adv_out, value)); | |||
var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); | |||
adv = layers.Subtract().Apply((adv, mean)); | |||
var outputs = layers.Add().Apply((value, adv)); | |||
var model = keras.Model(inputs, outputs); | |||
model.summary(); | |||
model.compile(optimizer: keras.optimizers.RMSprop(0.001f), | |||
loss: keras.losses.MeanSquaredError(), | |||
metrics: new[] { "acc" }); | |||
// Here we consider the adv_out is one layer, which is a little different from py's version | |||
Assert.AreEqual(model.Layers.Count, 6); | |||
// py code: | |||
//from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda | |||
//from tensorflow.keras.models import Model | |||
//from tensorflow.keras.optimizers import RMSprop | |||
//import tensorflow.keras.backend as K | |||
//inputs = Input(24) | |||
//x = Dense(128, activation = "relu")(inputs) | |||
//value = Dense(24)(x) | |||
//adv = Dense(1)(x) | |||
//meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv) | |||
//adv = Subtract()([adv, meam]) | |||
//outputs = Add()([value, adv]) | |||
//model = Model(inputs, outputs) | |||
//model.compile(loss = "mse", optimizer = RMSprop(1e-3)) | |||
//model.summary() | |||
//py output: | |||
//Model: "functional_3" | |||
//__________________________________________________________________________________________________ | |||
//Layer(type) Output Shape Param # Connected to | |||
//================================================================================================== | |||
//input_2 (InputLayer) [(None, 24)] 0 | |||
//__________________________________________________________________________________________________ | |||
//dense_3 (Dense) (None, 128) 3200 input_2[0][0] | |||
//__________________________________________________________________________________________________ | |||
//dense_5 (Dense) (None, 1) 129 dense_3[0][0] | |||
//__________________________________________________________________________________________________ | |||
//lambda_1 (Lambda) (None, 1) 0 dense_5[0][0] | |||
//__________________________________________________________________________________________________ | |||
//dense_4 (Dense) (None, 24) 3096 dense_3[0][0] | |||
//__________________________________________________________________________________________________ | |||
//subtract_1 (Subtract) (None, 1) 0 dense_5[0][0] | |||
// lambda_1[0][0] | |||
//__________________________________________________________________________________________________ | |||
//add_1 (Add) (None, 24) 0 dense_4[0][0] | |||
// subtract_1[0][0] | |||
//================================================================================================== | |||
//Total params: 6,425 | |||
//Trainable params: 6,425 | |||
//Non-trainable params: 0 | |||
//__________________________________________________________________________________________________ | |||
model.summary(); | |||
Assert.AreEqual(model.Layers.Count, 8); | |||
} | |||
/// <summary> | |||
@@ -132,28 +132,25 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
} | |||
#region ones/zeros like | |||
[Ignore] | |||
[TestMethod] | |||
public void TestOnesLike() | |||
{ | |||
#region 2-dimension | |||
var testCase2D = tf.constant(new int[,] | |||
var ones2D = tf.ones_like(new int[,] | |||
{ | |||
{ 1, 2, 3 }, | |||
{ 4, 5, 6 } | |||
}); | |||
var ones2D = tf.ones_like(testCase2D); | |||
Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy()); | |||
Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy()); | |||
#endregion | |||
#region 1-dimension | |||
var testCase1D = tf.constant(new int[,] | |||
var ones1D = tf.ones_like(new int[,] | |||
{ | |||
{ 1, 2, 3 } | |||
}); | |||
var ones1D = tf.ones_like(testCase1D); | |||
Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy()); | |||
#endregion | |||
@@ -163,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
public void TestZerosLike() | |||
{ | |||
#region 2-dimension | |||
var testCase2D = tf.constant(new int[,] | |||
var zeros2D = tf.zeros_like(new int[,] | |||
{ | |||
{ 1, 2, 3 }, | |||
{ 4, 5, 6 } | |||
}); | |||
var zeros2D = tf.zeros_like(testCase2D); | |||
Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy()); | |||
Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy()); | |||
#endregion | |||
#region 1-dimension | |||
var testCase1D = tf.constant(new int[,] | |||
var zeros1D = tf.zeros_like(new int[,] | |||
{ | |||
{ 1, 2, 3 } | |||
}); | |||
var zeros1D = tf.zeros_like(testCase1D); | |||
Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy()); | |||
#endregion | |||
@@ -1,11 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.UnitTest | |||
{ | |||
[TestClass] | |||
public class OptimizerTest | |||
{ | |||
} | |||
} |
@@ -1,25 +0,0 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFramework>netcoreapp3.1</TargetFramework> | |||
<IsPackable>false</IsPackable> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" /> | |||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | |||
<PackageReference Include="coverlet.collector" Version="1.2.1"> | |||
<PrivateAssets>all</PrivateAssets> | |||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | |||
</PackageReference> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||
</ItemGroup> | |||
</Project> |