diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 98925ddf..62a1be23 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -4,6 +4,25 @@ This release contains contributions from many people at SciSharp as well as the external contributors. +**Release Date 02/06/2021** + +### TensorFlow.Binding v0.33.0 + +* Improve memory usage +* Fix minor bugs + +### TensorFlow.Keras v0.4.0 + +* Add Subtract layer + +* Add model.load_weights and model.save_weights + +* Fix memory leak issue + +* Support to build YOLOv3 object detection model + + + **Release Date 01/09/2021** ### TensorFlow.Binding v0.32.0 diff --git a/src/TensorFlowNET.Console/MemoryBasicTest.cs b/src/TensorFlowNET.Console/MemoryBasicTest.cs index 199f870c..d61cca69 100644 --- a/src/TensorFlowNET.Console/MemoryBasicTest.cs +++ b/src/TensorFlowNET.Console/MemoryBasicTest.cs @@ -56,15 +56,31 @@ namespace Tensorflow { var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3); ResourceVariable variable = tf.Variable(nd); - var nd2 = np.arange(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3); - variable.assign(nd2); - for (int i = 0; i< 100; i++) + for (int i = 0; i< 10; i++) { var v = variable.numpy(); } }; + public Action VariableAssign + => (epoch, iterate) => + { + ResourceVariable variable = tf.Variable(3112f); + AssignVariable(variable); + for (int i = 0; i < 100; i++) + { + var v = variable.numpy(); + if ((float)v != 1984f) + throw new ValueError(""); + } + }; + + void AssignVariable(IVariableV1 v) + { + using var tensor = tf.constant(1984f); + v.assign(tensor); + } public Action MathAdd => (epoch, iterate) => diff --git a/src/TensorFlowNET.Console/Program.cs b/src/TensorFlowNET.Console/Program.cs index d65e7e6b..38b878af 100644 --- a/src/TensorFlowNET.Console/Program.cs +++ b/src/TensorFlowNET.Console/Program.cs @@ -52,6 +52,10 @@ namespace Tensorflow // 100K float variable. mm.Execute(10, batchSize, basic.Variable); + mm.Execute(10, batchSize, basic.VariableRead); + + mm.Execute(10, batchSize, basic.VariableAssign); + // 1 million math. mm.Execute(10, 100 * batchSize, basic.MathAdd); diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 8452b81a..390942d2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -215,6 +215,9 @@ namespace Tensorflow public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize); + public Tensor one_hot(Tensor indices, int depth, Tensor on_value = null, Tensor off_value = null, @@ -290,6 +293,9 @@ namespace Tensorflow public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize); + /// /// Stops gradient computation. /// diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 2d91be12..ff43c206 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -23,6 +23,15 @@ namespace Tensorflow { public Tensor log(Tensor x, string name = null) => gen_math_ops.log(x, name); + + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public Tensor erf(Tensor x, string name = null) + => math_ops.erf(x, name); } public Tensor abs(Tensor x, string name = null) @@ -118,6 +127,9 @@ namespace Tensorflow public Tensor cos(Tensor x, string name = null) => gen_math_ops.cos(x, name); + public Tensor cos(float x, string name = null) + => gen_math_ops.cos(x, name); + /// /// Computes hyperbolic cosine of x element-wise. /// diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 62ba0bbd..535bbca4 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -137,6 +137,8 @@ namespace Tensorflow { switch (a) { + case Tensors arr: + return arr.Length; case Array arr: return arr.Length; case IList arr: diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 7db178b3..b076c90f 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -28,6 +28,7 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { + // [DebuggerStepThrough] public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) { if (tf.Context.has_graph_arg(args)) diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 9ea40816..2d0d7d28 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -138,6 +138,9 @@ namespace Tensorflow.Gradients [RegisterNoGradient("GreaterEqual")] public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; + [RegisterNoGradient("OnesLike")] + public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null; + [RegisterNoGradient("ZerosLike")] public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs index 623cc68e..3ebcf617 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs @@ -1,6 +1,21 @@ -namespace Tensorflow.Keras.ArgsDefinition +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition { public class RNNArgs : LayerArgs { + public interface IRnnArgCell : ILayer + { + object state_size { get; } + } + + public IRnnArgCell Cell { get; set; } = null; + public bool ReturnSequences { get; set; } = false; + public bool ReturnState { get; set; } = false; + public bool GoBackwards { get; set; } = false; + public bool Stateful { get; set; } = false; + public bool Unroll { get; set; } = false; + public bool TimeMajor { get; set; } = false; + public Dictionary Kwargs { get; set; } = null; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs new file mode 100644 index 00000000..65815587 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs @@ -0,0 +1,30 @@ +namespace Tensorflow.Keras.ArgsDefinition +{ + public class SimpleRNNArgs : RNNArgs + { + public int Units { get; set; } + public Activation Activation { get; set; } + + // units, + // activation='tanh', + // use_bias=True, + // kernel_initializer='glorot_uniform', + // recurrent_initializer='orthogonal', + // bias_initializer='zeros', + // kernel_regularizer=None, + // recurrent_regularizer=None, + // bias_regularizer=None, + // activity_regularizer=None, + // kernel_constraint=None, + // recurrent_constraint=None, + // bias_constraint=None, + // dropout=0., + // recurrent_dropout=0., + // return_sequences=False, + // return_state=False, + // go_backwards=False, + // stateful=False, + // unroll=False, + // **kwargs): + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs new file mode 100644 index 00000000..1c52e47b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class StackedRNNCellsArgs : LayerArgs + { + public IList Cells { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index aaea5cd2..0dd40096 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -46,7 +46,7 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RnnCell : ILayer + public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index bf5324dd..625d76a1 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -274,7 +274,7 @@ namespace Tensorflow { if (elem is EagerTensor eager_tensor) { - if(switch_to_graph) + if (switch_to_graph) elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString())); else elems_as_tensors.Add(eager_tensor); @@ -366,8 +366,30 @@ namespace Tensorflow /// /// /// - public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) - => ones_like_impl(tensor, dtype, name, optimize); + public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope => + { + name = scope; + tensor = ops.convert_to_tensor(tensor, name: "tensor"); + + // is_fully_defined return unexpected value. + if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) + { + + } + + if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) + { + throw new NotImplementedException("ones_like"); + // return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); + } + else + { + return gen_array_ops.ones_like(tensor, name: name); + } + }); + } public static Tensor reshape(Tensor tensor, Tensor shape, string name = null) => gen_array_ops.reshape(tensor, shape, name: name); @@ -388,14 +410,12 @@ namespace Tensorflow if (dtype == TF_DataType.DtInvalid) dtype = tensor1.dtype; var ret = ones(ones_shape, dtype: dtype, name: name); - ret.shape = tensor1.shape; return ret; }); } public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { - dtype = dtype.as_base_dtype(); return tf_with(ops.name_scope(name, "ones", new { shape }), scope => { name = scope; @@ -578,11 +598,10 @@ namespace Tensorflow if (!tf.Context.executing_eagerly()) { - var input_tensor = ops.convert_to_tensor(input); - var input_shape = input_tensor.TensorShape; - if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) + var input_shape = input.TensorShape; + if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); + var nd = np.array(input.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } @@ -891,7 +910,7 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "transpose", new { a }), scope => { var a_tensor = ops.convert_to_tensor(a); - if(perm == null) + if (perm == null) { var rank = a_tensor.rank; perm = range(0, rank).OrderByDescending(x => x).ToArray(); @@ -953,7 +972,9 @@ namespace Tensorflow => tf.Context.RunInAutoMode2( () => tf.OpDefLib._apply_op_helper("Slice", name, new { - input, begin, size + input, + begin, + size }).output, () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Slice", name, @@ -969,8 +990,8 @@ namespace Tensorflow tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs); }, new Tensors(input, begin, size)); - - public static Tensor stack(object values, int axis = 0, string name = "stack") + + public static Tensor stack(object values, int axis = 0, string name = "stack") { if (axis == 0) // If the input is a constant list, it can be converted to a constant op diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index a2db25d9..e29227c4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -591,6 +591,15 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor ones_like(Tensor x, string name = null) + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "OnesLike", name, + null, + x).FirstOrDefault(), + x); + public static Tensor zeros_like(Tensor x, string name = null) => tf.Context.RunInAutoMode(() => tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, () diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3d64e8b9..5d585e77 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -124,6 +124,9 @@ namespace Tensorflow x, y).FirstOrDefault(), x, y); + public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) + => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + /// /// Computes the mean of elements across dimensions of a tensor. /// Reduces `input` along the dimensions given in `axis`. Unless @@ -137,23 +140,30 @@ namespace Tensorflow /// An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1. /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. - public static Tensor mean(T1 input, T2 axis, bool keep_dims = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Mean", name, new + { + input, + reduction_indices = axis, + keep_dims = keep_dims + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Mean", name, null, input, axis, - "keep_dims", keep_dims); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); - - return _op.output; - } + "keep_dims", keep_dims).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "Tidx", op.get_attr("Tidx"), + "keep_dims", op.get_attr("keep_dims") + }; + tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs); + }, + new Tensors(input, axis)); public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null) { @@ -376,8 +386,18 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor cos(Tensor x, string name = null) + public static Tensor cos(T x, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Cos", name, + null, + x); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("Cos", name, args: new { x }); return _op.outputs[0]; @@ -776,20 +796,21 @@ namespace Tensorflow } public static Tensor sub(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Sub", name, null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y }); - - return _op.output; - } + x, y).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs); + }, + new Tensors(x, y)); public static Tensor sub(Tx x, Ty y, string name = null) { diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 2c051992..eabd5cd1 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -265,6 +265,29 @@ namespace Tensorflow public static Tensor equal(Tx x, Ty y, string name = null) => gen_math_ops.equal(x, y, name: name); + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public static Tensor erf(Tensor x, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Erf", name, new { x }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Erf", name, + null, + x).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Erf", op.inputs, attrs, op.outputs); + }, + new Tensors(x)); + public static Tensor sqrt(Tensor x, string name = null) => gen_math_ops.sqrt(x, name: name); @@ -327,31 +350,17 @@ namespace Tensorflow public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) { var r = _ReductionDims(input_tensor, axis); - if (axis == null) - { - var m = gen_math_ops.mean(input_tensor, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensor, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); + var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis_tensor, m); } public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) { - if (axis == null) - { - var r = _ReductionDims(input_tensors, axis); - var m = gen_math_ops.mean(input_tensors, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var r = _ReductionDims(input_tensors, axis); + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value); + var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); } /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index a2ad7530..e331dc1a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -91,14 +91,16 @@ namespace Tensorflow var buffer = new byte[size][]; var data_start = c_api.TF_TensorData(_handle); - var string_start = data_start + (int)(size * sizeof(ulong)); + data_start += (int)(size * sizeof(ulong)); for (int i = 0; i < buffer.Length; i++) { - var len = *(byte*)string_start; - buffer[i] = new byte[len]; - string_start += 1; - Marshal.Copy(string_start, buffer[i], 0, len); - string_start += len; + IntPtr dst = IntPtr.Zero; + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); + tf.Status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + data_start += (int)read; } return buffer; diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 1c8d939a..3c334ea5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -69,13 +69,14 @@ namespace Tensorflow => items.Insert(index, tensor); IEnumerator IEnumerable.GetEnumerator() - { - throw new NotImplementedException(); - } + => GetEnumerator(); public static implicit operator Tensors(Tensor tensor) => new Tensors(tensor); + public static implicit operator Tensors((Tensor, Tensor) tuple) + => new Tensors(tuple.Item1, tuple.Item2); + public static implicit operator Tensors(NDArray nd) => new Tensors(nd); diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs index 67a7fa98..84220f4f 100644 --- a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs +++ b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs @@ -17,7 +17,9 @@ namespace Tensorflow.Keras return results[0]; } - throw new NotImplementedException(""); + var _op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, args: new { x = features }); + + return _op.output; }; } } diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs index fb74c539..30bbdbf4 100644 --- a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs +++ b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs @@ -17,7 +17,9 @@ namespace Tensorflow.Keras return results[0]; } - throw new NotImplementedException(""); + var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x = features }); + + return _op.output; }; } } diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs deleted file mode 100644 index a1d3ecbf..00000000 --- a/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow.Keras.ArgsDefinition; - -namespace Tensorflow.Keras.Engine -{ - public interface ITensorFlowOpLayer - { - Layer GetOpLayer(TensorFlowOpLayerArgs args); - } -} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 8b9176e3..2e83f75d 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - FitInternal(epochs); + FitInternal(epochs, verbose); } public void fit(IDatasetV2 dataset, @@ -80,10 +80,10 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - FitInternal(epochs); + FitInternal(epochs, verbose); } - void FitInternal(int epochs) + void FitInternal(int epochs, int verbose) { stop_training = false; _train_counter.assign(0); @@ -96,8 +96,11 @@ namespace Tensorflow.Keras.Engine { // callbacks.on_train_batch_begin(step) var results = train_step_function(iterator); - var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); - Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); + if (verbose == 1) + { + var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); + Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); + } } GC.Collect(); diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index fbce83c9..e47c6517 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -1,4 +1,5 @@ using NumSharp; +using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; @@ -142,6 +143,7 @@ namespace Tensorflow.Keras.Layers public Dense Dense(int units, Activation activation = null, IInitializer kernel_initializer = null, + bool use_bias = true, IInitializer bias_initializer = null, TensorShape input_shape = null) => new Dense(new DenseArgs @@ -149,7 +151,7 @@ namespace Tensorflow.Keras.Layers Units = units, Activation = activation ?? keras.activations.Linear, KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, - BiasInitializer = bias_initializer ?? tf.zeros_initializer, + BiasInitializer = bias_initializer ?? (use_bias ? tf.zeros_initializer : null), InputShape = input_shape }); @@ -332,6 +334,24 @@ namespace Tensorflow.Keras.Layers Alpha = alpha }); + public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); + + public Layer SimpleRNN(int units, + Activation activation = null) + => new SimpleRNN(new SimpleRNNArgs + { + Units = units, + Activation = activation + }); + + public Layer SimpleRNN(int units, + string activation = "tanh") + => new SimpleRNN(new SimpleRNNArgs + { + Units = units, + Activation = GetActivationByName(activation) + }); + public Layer LSTM(int units, Activation activation = null, Activation recurrent_activation = null, @@ -381,6 +401,9 @@ namespace Tensorflow.Keras.Layers public Add Add() => new Add(new MergeArgs { }); + public Subtract Subtract() + => new Subtract(new MergeArgs { }); + public GlobalAveragePooling2D GlobalAveragePooling2D() => new GlobalAveragePooling2D(new Pooling2DArgs { }); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs new file mode 100644 index 00000000..b6a1039e --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Subtract : Merge + { + public Subtract(MergeArgs args) : base(args) + { + + } + + protected override Tensors _merge_function(Tensors inputs) + { + if (len(inputs) != 2) + throw new ValueError($"A `Subtract` layer should be called on exactly 2 inputs"); + return inputs[0] - inputs[1]; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/RNN.cs b/src/TensorFlowNET.Keras/Layers/RNN.cs index 3d03abb1..0c77d57f 100644 --- a/src/TensorFlowNET.Keras/Layers/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/RNN.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers { public class RNN : Layer { - public RNN(RNNArgs args) - : base(args) + private RNNArgs args; + + public RNN(RNNArgs args) : base(PreConstruct(args)) { + this.args = args; + SupportsMasking = true; + + // The input shape is unknown yet, it could have nested tensor inputs, and + // the input spec will be the list of specs for nested inputs, the structure + // of the input_spec will be the same as the input. + //self.input_spec = None + //self.state_spec = None + //self._states = None + //self.constants_spec = None + //self._num_constants = 0 + + //if stateful: + // if ds_context.has_strategy(): + // raise ValueError('RNNs with stateful=True not yet supported with ' + // 'tf.distribute.Strategy.') } + private static RNNArgs PreConstruct(RNNArgs args) + { + if (args.Kwargs == null) + { + args.Kwargs = new Dictionary(); + } + + // If true, the output for masked timestep will be zeros, whereas in the + // false case, output from previous timestep is returned for masked timestep. + var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); + + object input_shape; + var propIS = args.Kwargs.Get("input_shape", null); + var propID = args.Kwargs.Get("input_dim", null); + var propIL = args.Kwargs.Get("input_length", null); + + if (propIS == null && (propID != null || propIL != null)) + { + input_shape = ( + propIL ?? new NoneValue(), // maybe null is needed here + propID ?? new NoneValue()); // and here + args.Kwargs["input_shape"] = input_shape; + } + + return args; + } + + public RNN New(LayerRnnCell cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(new RNNArgs + { + Cell = cell, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + public RNN New(IList cell, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool unroll = false, + bool time_major = false) + => new RNN(new RNNArgs + { + Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + Unroll = unroll, + TimeMajor = time_major + }); + + protected Tensor get_initial_state(Tensor inputs) { return _generate_zero_filled_state_for_cell(null, null); diff --git a/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs new file mode 100644 index 00000000..c1fc4afd --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs @@ -0,0 +1,14 @@ +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Layers +{ + public class SimpleRNN : RNN + { + + public SimpleRNN(RNNArgs args) : base(args) + { + + } + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs new file mode 100644 index 00000000..c0a2371f --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell + { + public IList Cells { get; set; } + + public StackedRNNCells(StackedRNNCellsArgs args) : base(args) + { + Cells = args.Cells; + //Cells.reverse_state_order = kwargs.pop('reverse_state_order', False); + // self.reverse_state_order = kwargs.pop('reverse_state_order', False) + // if self.reverse_state_order: + // logging.warning('reverse_state_order=True in StackedRNNCells will soon ' + // 'be deprecated. Please update the code to work with the ' + // 'natural order of states if you rely on the RNN states, ' + // 'eg RNN(return_state=True).') + // super(StackedRNNCells, self).__init__(**kwargs) + throw new NotImplementedException(""); + } + + public object state_size + { + get => throw new NotImplementedException(); + } + + //@property + //def state_size(self) : + // return tuple(c.state_size for c in + // (self.cells[::- 1] if self.reverse_state_order else self.cells)) + + // @property + // def output_size(self) : + // if getattr(self.cells[-1], 'output_size', None) is not None: + // return self.cells[-1].output_size + // elif _is_multiple_state(self.cells[-1].state_size) : + // return self.cells[-1].state_size[0] + // else: + // return self.cells[-1].state_size + + // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : + // initial_states = [] + // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: + // get_initial_state_fn = getattr(cell, 'get_initial_state', None) + // if get_initial_state_fn: + // initial_states.append(get_initial_state_fn( + // inputs=inputs, batch_size=batch_size, dtype=dtype)) + // else: + // initial_states.append(_generate_zero_filled_state_for_cell( + // cell, inputs, batch_size, dtype)) + + // return tuple(initial_states) + + // def call(self, inputs, states, constants= None, training= None, ** kwargs): + // # Recover per-cell states. + // state_size = (self.state_size[::- 1] + // if self.reverse_state_order else self.state_size) + // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) + + // # Call the cells in order and store the returned states. + // new_nested_states = [] + // for cell, states in zip(self.cells, nested_states) : + // states = states if nest.is_nested(states) else [states] + //# TF cell does not wrap the state into list when there is only one state. + // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None + // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + // if generic_utils.has_arg(cell.call, 'training'): + // kwargs['training'] = training + // else: + // kwargs.pop('training', None) + // # Use the __call__ function for callable objects, eg layers, so that it + // # will have the proper name scopes for the ops, etc. + // cell_call_fn = cell.__call__ if callable(cell) else cell.call + // if generic_utils.has_arg(cell.call, 'constants'): + // inputs, states = cell_call_fn(inputs, states, + // constants= constants, ** kwargs) + // else: + // inputs, states = cell_call_fn(inputs, states, ** kwargs) + // new_nested_states.append(states) + + // return inputs, nest.pack_sequence_as(state_size, + // nest.flatten(new_nested_states)) + + // @tf_utils.shape_type_conversion + // def build(self, input_shape) : + // if isinstance(input_shape, list) : + // input_shape = input_shape[0] + // for cell in self.cells: + // if isinstance(cell, Layer) and not cell.built: + // with K.name_scope(cell.name): + // cell.build(input_shape) + // cell.built = True + // if getattr(cell, 'output_size', None) is not None: + // output_dim = cell.output_size + // elif _is_multiple_state(cell.state_size) : + // output_dim = cell.state_size[0] + // else: + // output_dim = cell.state_size + // input_shape = tuple([input_shape[0]] + + // tensor_shape.TensorShape(output_dim).as_list()) + // self.built = True + + // def get_config(self) : + // cells = [] + // for cell in self.cells: + // cells.append(generic_utils.serialize_keras_object(cell)) + // config = {'cells': cells + //} + //base_config = super(StackedRNNCells, self).get_config() + // return dict(list(base_config.items()) + list(config.items())) + + // @classmethod + // def from_config(cls, config, custom_objects = None): + // from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + // cells = [] + // for cell_config in config.pop('cells'): + // cells.append( + // deserialize_layer(cell_config, custom_objects = custom_objects)) + // return cls(cells, **config) + } +} diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs new file mode 100644 index 00000000..1c0470fe --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -0,0 +1,73 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.Graphs; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class TensorFlowOpLayer : Layer + { + TensorFlowOpLayerArgs args; + Dictionary constants => args.Constants; + NodeDef node_def => args.NodeDef; + static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; + public string OpType => node_def.Op; + + public TensorFlowOpLayer(TensorFlowOpLayerArgs args) + : base(new LayerArgs + { + Name = TF_OP_LAYER_NAME_PREFIX + args.Name, + Trainable = args.Trainable, + DType = args.DType, + Autocast = false + }) + { + this.args = args; + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + { + if (tf.Context.executing_eagerly()) + return _defun_call(inputs); + return MakOp(inputs); + } + + [AutoGraph] + Tensors _defun_call(Tensors inputs) + => MakOp(inputs); + + Tensors MakOp(Tensors inputs) + { + var graph = inputs.graph; + graph.as_default(); + foreach (var (index, constant) in enumerate(constants)) + { + var value = constant_op.constant(constant, name: node_def.Input[index]); + inputs.Insert(index, value); + } + + var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); + var op = graph._create_op_from_tf_operation(c_op); + op._control_flow_post_processing(); + + // Record the gradient because custom-made ops don't go through the + // code-gen'd eager call path + var op_type = op.node_def.Op; + + tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); + + graph.Exit(); + return op.outputs; + } + + public Layer GetOpLayer(TensorFlowOpLayerArgs args) + => new TensorFlowOpLayer(args); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs index 6098dee3..a256786f 100644 --- a/src/TensorFlowNET.Keras/Losses/Huber.cs +++ b/src/TensorFlowNET.Keras/Losses/Huber.cs @@ -27,10 +27,10 @@ namespace Tensorflow.Keras.Losses Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); Tensor abs_error = math_ops.abs(error); Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); - return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, - half * math_ops.pow(error, 2), + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, + half * math_ops.pow(error, 2), half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), - axis : -1); + axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs index 1c894904..8acbbe9d 100644 --- a/src/TensorFlowNET.Keras/Losses/LogCosh.cs +++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs @@ -19,10 +19,8 @@ namespace Tensorflow.Keras.Losses Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); Tensor x = y_pred_dispatch - y_true_cast; - - return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),axis: -1); - + return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs index 74c95b4a..3295b12b 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Losses Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); - return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) *gen_math_ops.mean(diff, axis: -1); + return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs index 24ef1043..6ae7d86d 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Losses { Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs index 7ad370ae..2383c5d1 100644 --- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs @@ -26,6 +26,9 @@ namespace Tensorflow.Keras.Optimizers protected float _initial_decay = 0.0f; protected bool _use_locking = true; + public IVariableV1 lr + => _hyper_variables["learning_rate"]; + Dictionary> _slots; List _slot_names; diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index e705b3d1..3f5ca2b9 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -21,7 +21,9 @@ * Support BatchNormalization layer. * Building keras model in subclass, functional and sequential api * Implemented backward_function. -* Support model.load_weights. +* Support model.load_weights. +* Add Subtract layer +* Support YOLOv3 model. Keras for .NET Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. @@ -64,4 +66,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac + + + + diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 32a1737a..39c14fa8 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -21,6 +21,7 @@ using System.Linq; using System.Reflection; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -150,12 +151,13 @@ namespace Tensorflow.Keras.Utils // recursively CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - var op_layer = GetLayer(new TensorFlowOpLayerArgs + var opLayerArgs = new TensorFlowOpLayerArgs { NodeDef = op.node_def, Constants = constants, Name = op.name - }); + }; + var op_layer = new TensorFlowOpLayer(opLayerArgs); created_layers.Add(op_layer); op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); processed_ops.Add(op); @@ -163,20 +165,6 @@ namespace Tensorflow.Keras.Utils } } - static Layer GetLayer(LayerArgs args) - { - Layer layer = default; - var assemble = Assembly.Load("TensorFlow.Keras.Layers"); - foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) - { - layer = (Layer)Activator.CreateInstance(type, new object[] { args }); - } - - if (layer == null) - throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); - return layer; - } - // recusive static bool uses_keras_history(Tensor op_input) { diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 20d30f6f..a08959a7 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ 1. Build static library -`bazel build --config=opt //tensorflow:tensorflow` +`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` 2. Build pip package diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index baca83c3..8c99c819 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using Tensorflow; +using static Tensorflow.Binding; using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest @@ -35,71 +36,31 @@ namespace TensorFlowNET.Keras.UnitTest var model = keras.Model(inputs, outputs, name: "mnist_model"); model.summary(); } - + /// /// Custom layer test, used in Dueling DQN /// [TestMethod, Ignore] - public void FunctionalTest() + public void TensorFlowOpLayer() { var layers = keras.layers; var inputs = layers.Input(shape: 24); - var x = layers.Dense(128, activation:"relu").Apply(inputs); + var x = layers.Dense(128, activation: "relu").Apply(inputs); var value = layers.Dense(24).Apply(x); var adv = layers.Dense(1).Apply(x); - - var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem. - var outputs = layers.Add().Apply(new Tensors(adv_out, value)); + + var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); + adv = layers.Subtract().Apply((adv, mean)); + var outputs = layers.Add().Apply((value, adv)); var model = keras.Model(inputs, outputs); - model.summary(); model.compile(optimizer: keras.optimizers.RMSprop(0.001f), loss: keras.losses.MeanSquaredError(), metrics: new[] { "acc" }); - // Here we consider the adv_out is one layer, which is a little different from py's version - Assert.AreEqual(model.Layers.Count, 6); - - // py code: - //from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda - //from tensorflow.keras.models import Model - //from tensorflow.keras.optimizers import RMSprop - //import tensorflow.keras.backend as K - - //inputs = Input(24) - //x = Dense(128, activation = "relu")(inputs) - //value = Dense(24)(x) - //adv = Dense(1)(x) - //meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv) - //adv = Subtract()([adv, meam]) - //outputs = Add()([value, adv]) - //model = Model(inputs, outputs) - //model.compile(loss = "mse", optimizer = RMSprop(1e-3)) - //model.summary() - - //py output: - //Model: "functional_3" - //__________________________________________________________________________________________________ - //Layer(type) Output Shape Param # Connected to - //================================================================================================== - //input_2 (InputLayer) [(None, 24)] 0 - //__________________________________________________________________________________________________ - //dense_3 (Dense) (None, 128) 3200 input_2[0][0] - //__________________________________________________________________________________________________ - //dense_5 (Dense) (None, 1) 129 dense_3[0][0] - //__________________________________________________________________________________________________ - //lambda_1 (Lambda) (None, 1) 0 dense_5[0][0] - //__________________________________________________________________________________________________ - //dense_4 (Dense) (None, 24) 3096 dense_3[0][0] - //__________________________________________________________________________________________________ - //subtract_1 (Subtract) (None, 1) 0 dense_5[0][0] - // lambda_1[0][0] - //__________________________________________________________________________________________________ - //add_1 (Add) (None, 24) 0 dense_4[0][0] - // subtract_1[0][0] - //================================================================================================== - //Total params: 6,425 - //Trainable params: 6,425 - //Non-trainable params: 0 - //__________________________________________________________________________________________________ + model.summary(); + Assert.AreEqual(model.Layers.Count, 8); + var result = model.predict(tf.constant(np.arange(24).astype(np.float32)[np.newaxis, Slice.All])); + Assert.AreEqual(result.shape, new TensorShape(1, 24)); + model.fit(np.arange(24).astype(np.float32)[np.newaxis, Slice.All], np.arange(24).astype(np.float32)[np.newaxis, Slice.All], verbose: 0); } /// @@ -149,9 +110,14 @@ namespace TensorFlowNET.Keras.UnitTest } [TestMethod] + [Ignore] public void SimpleRNN() { - + var inputs = np.random.rand(32, 10, 8).astype(np.float32); + var simple_rnn = keras.layers.SimpleRNN(4); + var output = simple_rnn.Apply(inputs); + Assert.AreEqual((32, 4), output.shape); } + } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs index 26e89404..78f57b20 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs @@ -48,5 +48,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var x5 = tf.reduce_sum(b, (0, 1)); Assert.AreEqual(-4.7f, (float)x5); } + + [TestMethod] + public void Erf() + { + var erf = tf.math.erf(a, name: "erf"); + var expected = new float[] { 0.8427007f, -0.5204999f, 0.99999845f, -0.9970206f, 0f, -1f }; + var actual = erf.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs index 9966c12e..c57c98df 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -132,28 +132,25 @@ namespace TensorFlowNET.UnitTest.ManagedAPI } #region ones/zeros like - [Ignore] [TestMethod] public void TestOnesLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var ones2D = tf.ones_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var ones2D = tf.ones_like(testCase2D); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy()); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var ones1D = tf.ones_like(new int[,] { { 1, 2, 3 } }); - var ones1D = tf.ones_like(testCase1D); Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy()); #endregion @@ -163,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI public void TestZerosLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var zeros2D = tf.zeros_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var zeros2D = tf.zeros_like(testCase2D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy()); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var zeros1D = tf.zeros_like(new int[,] { { 1, 2, 3 } }); - var zeros1D = tf.zeros_like(testCase1D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy()); #endregion diff --git a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs b/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs deleted file mode 100644 index 6647ca59..00000000 --- a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Collections.Generic; - -namespace Tensorflow.Keras.UnitTest -{ - [TestClass] - public class OptimizerTest - { - - } -} diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj deleted file mode 100644 index 5f5ab347..00000000 --- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ /dev/null @@ -1,25 +0,0 @@ - - - - netcoreapp3.1 - - false - - AnyCPU;x64 - - - - - - - - all - runtime; build; native; contentfiles; analyzers; buildtransitive - - - - - - - -