Niklas Gustafsson 4 years ago
parent
commit
a6fb6618a6
41 changed files with 665 additions and 221 deletions
  1. +19
    -0
      docs/RELEASE.md
  2. +19
    -3
      src/TensorFlowNET.Console/MemoryBasicTest.cs
  3. +4
    -0
      src/TensorFlowNET.Console/Program.cs
  4. +6
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  5. +12
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  7. +1
    -0
      src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
  8. +3
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  9. +16
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs
  10. +30
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs
  11. +9
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  13. +34
    -13
      src/TensorFlowNET.Core/Operations/array_ops.cs
  14. +9
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  15. +48
    -27
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  16. +30
    -21
      src/TensorFlowNET.Core/Operations/math_ops.cs
  17. +8
    -6
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  18. +4
    -3
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  19. +3
    -1
      src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs
  20. +3
    -1
      src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs
  21. +0
    -12
      src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs
  22. +8
    -5
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  23. +24
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  24. +23
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs
  25. +84
    -2
      src/TensorFlowNET.Keras/Layers/RNN.cs
  26. +14
    -0
      src/TensorFlowNET.Keras/Layers/SimpleRNN.cs
  27. +125
    -0
      src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs
  28. +73
    -0
      src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs
  29. +3
    -3
      src/TensorFlowNET.Keras/Losses/Huber.cs
  30. +1
    -3
      src/TensorFlowNET.Keras/Losses/LogCosh.cs
  31. +1
    -1
      src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
  32. +1
    -1
      src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs
  33. +3
    -0
      src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  34. +7
    -1
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  35. +4
    -16
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
  36. +1
    -1
      tensorflowlib/README.md
  37. +19
    -53
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
  38. +9
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs
  39. +4
    -9
      test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs
  40. +0
    -11
      test/Tensorflow.Keras.UnitTest/OptimizerTest.cs
  41. +0
    -25
      test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+ 19
- 0
docs/RELEASE.md View File

@@ -4,6 +4,25 @@

This release contains contributions from many people at SciSharp as well as the external contributors.

**Release Date 02/06/2021**

### TensorFlow.Binding v0.33.0

* Improve memory usage
* Fix minor bugs

### TensorFlow.Keras v0.4.0

* Add Subtract layer

* Add model.load_weights and model.save_weights

* Fix memory leak issue

* Support to build YOLOv3 object detection model


**Release Date 01/09/2021**

### TensorFlow.Binding v0.32.0


+ 19
- 3
src/TensorFlowNET.Console/MemoryBasicTest.cs View File

@@ -56,15 +56,31 @@ namespace Tensorflow
{
var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3);
ResourceVariable variable = tf.Variable(nd);
var nd2 = np.arange(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3);
variable.assign(nd2);

for (int i = 0; i< 100; i++)
for (int i = 0; i< 10; i++)
{
var v = variable.numpy();
}
};

public Action<int, int> VariableAssign
=> (epoch, iterate) =>
{
ResourceVariable variable = tf.Variable(3112f);
AssignVariable(variable);
for (int i = 0; i < 100; i++)
{
var v = variable.numpy();
if ((float)v != 1984f)
throw new ValueError("");
}
};

void AssignVariable(IVariableV1 v)
{
using var tensor = tf.constant(1984f);
v.assign(tensor);
}

public Action<int, int> MathAdd
=> (epoch, iterate) =>


+ 4
- 0
src/TensorFlowNET.Console/Program.cs View File

@@ -52,6 +52,10 @@ namespace Tensorflow
// 100K float variable.
mm.Execute(10, batchSize, basic.Variable);

mm.Execute(10, batchSize, basic.VariableRead);

mm.Execute(10, batchSize, basic.VariableAssign);

// 1 million math.
mm.Execute(10, 100 * batchSize, basic.MathAdd);



+ 6
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -215,6 +215,9 @@ namespace Tensorflow
public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize);

public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize);

public Tensor one_hot(Tensor indices, int depth,
Tensor on_value = null,
Tensor off_value = null,
@@ -290,6 +293,9 @@ namespace Tensorflow
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);

public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize);

/// <summary>
/// Stops gradient computation.
/// </summary>


+ 12
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -23,6 +23,15 @@ namespace Tensorflow
{
public Tensor log(Tensor x, string name = null)
=> gen_math_ops.log(x, name);

/// <summary>
/// Computes the Gauss error function of `x` element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name);
}

public Tensor abs(Tensor x, string name = null)
@@ -118,6 +127,9 @@ namespace Tensorflow
public Tensor cos(Tensor x, string name = null)
=> gen_math_ops.cos(x, name);

public Tensor cos(float x, string name = null)
=> gen_math_ops.cos(x, name);

/// <summary>
/// Computes hyperbolic cosine of x element-wise.
/// </summary>


+ 2
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -137,6 +137,8 @@ namespace Tensorflow
{
switch (a)
{
case Tensors arr:
return arr.Length;
case Array arr:
return arr.Length;
case IList arr:


+ 1
- 0
src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs View File

@@ -28,6 +28,7 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
// [DebuggerStepThrough]
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
{
if (tf.Context.has_graph_arg(args))


+ 3
- 0
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -138,6 +138,9 @@ namespace Tensorflow.Gradients
[RegisterNoGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("OnesLike")]
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("ZerosLike")]
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;



+ 16
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs View File

@@ -1,6 +1,21 @@
namespace Tensorflow.Keras.ArgsDefinition
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
{
public class RNNArgs : LayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}

public IRnnArgCell Cell { get; set; } = null;
public bool ReturnSequences { get; set; } = false;
public bool ReturnState { get; set; } = false;
public bool GoBackwards { get; set; } = false;
public bool Stateful { get; set; } = false;
public bool Unroll { get; set; } = false;
public bool TimeMajor { get; set; } = false;
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}

+ 30
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs View File

@@ -0,0 +1,30 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNArgs : RNNArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
// units,
// activation='tanh',
// use_bias=True,
// kernel_initializer='glorot_uniform',
// recurrent_initializer='orthogonal',
// bias_initializer='zeros',
// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
}
}

+ 9
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs View File

@@ -0,0 +1,9 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RnnCell : ILayer
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight


+ 34
- 13
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -274,7 +274,7 @@ namespace Tensorflow
{
if (elem is EagerTensor eager_tensor)
{
if(switch_to_graph)
if (switch_to_graph)
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
else
elems_as_tensors.Add(eager_tensor);
@@ -366,8 +366,30 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="optimize"></param>
/// <returns></returns>
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
{
name = scope;
tensor = ops.convert_to_tensor(tensor, name: "tensor");

// is_fully_defined return unexpected value.
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{

}

if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
{
throw new NotImplementedException("ones_like");
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
}
else
{
return gen_array_ops.ones_like(tensor, name: name);
}
});
}

public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);
@@ -388,14 +410,12 @@ namespace Tensorflow
if (dtype == TF_DataType.DtInvalid)
dtype = tensor1.dtype;
var ret = ones(ones_shape, dtype: dtype, name: name);
ret.shape = tensor1.shape;
return ret;
});
}

public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
{
name = scope;
@@ -578,11 +598,10 @@ namespace Tensorflow

if (!tf.Context.executing_eagerly())
{
var input_tensor = ops.convert_to_tensor(input);
var input_shape = input_tensor.TensorShape;
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
var input_shape = input.TensorShape;
if (optimize && input.NDims > -1 && input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
var nd = np.array(input.shape).astype(out_type.as_numpy_dtype());
return constant_op.constant(nd, name: name);
}
}
@@ -891,7 +910,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_tensor = ops.convert_to_tensor(a);
if(perm == null)
if (perm == null)
{
var rank = a_tensor.rank;
perm = range(0, rank).OrderByDescending(x => x).ToArray();
@@ -953,7 +972,9 @@ namespace Tensorflow
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Slice", name, new
{
input, begin, size
input,
begin,
size
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Slice", name,
@@ -969,8 +990,8 @@ namespace Tensorflow
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
},
new Tensors(input, begin, size));
public static Tensor stack(object values, int axis = 0, string name = "stack")
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
if (axis == 0)
// If the input is a constant list, it can be converted to a constant op


+ 9
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -591,6 +591,15 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()


+ 48
- 27
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -124,6 +124,9 @@ namespace Tensorflow
x, y).FirstOrDefault(),
x, y);

public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null)
=> mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);

/// <summary>
/// Computes the mean of elements across dimensions of a tensor.
/// Reduces `input` along the dimensions given in `axis`. Unless
@@ -137,23 +140,30 @@ namespace Tensorflow
/// <param name="keep_dims"> An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1.</param>
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
public static Tensor mean<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Mean", name, new
{
input,
reduction_indices = axis,
keep_dims = keep_dims
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Mean", name,
null,
input, axis,
"keep_dims", keep_dims);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });

return _op.output;
}
"keep_dims", keep_dims).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T"),
"Tidx", op.get_attr<TF_DataType>("Tidx"),
"keep_dims", op.get_attr<bool>("keep_dims")
};
tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs);
},
new Tensors(input, axis));

public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null)
{
@@ -376,8 +386,18 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor cos(Tensor x, string name = null)
public static Tensor cos<T>(T x, string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Cos", name,
null,
x);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Cos", name, args: new { x });

return _op.outputs[0];
@@ -776,20 +796,21 @@ namespace Tensorflow
}

public static Tensor sub(Tensor x, Tensor y, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Sub", name,
null,
x, y);
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y });

return _op.output;
}
x, y).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T")
};
tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs);
},
new Tensors(x, y));

public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
{


+ 30
- 21
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -265,6 +265,29 @@ namespace Tensorflow
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name);

/// <summary>
/// Computes the Gauss error function of `x` element-wise.
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor erf(Tensor x, string name = null)
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Erf", name, new { x }).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Erf", name,
null,
x).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T")
};
tf.Runner.RecordGradient("Erf", op.inputs, attrs, op.outputs);
},
new Tensors(x));

public static Tensor sqrt(Tensor x, string name = null)
=> gen_math_ops.sqrt(x, name: name);

@@ -327,31 +350,17 @@ namespace Tensorflow
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
{
var r = _ReductionDims(input_tensor, axis);
if (axis == null)
{
var m = gen_math_ops.mean(input_tensor, r, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
else
{
var m = gen_math_ops.mean(input_tensor, axis, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis);
var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis_tensor, m);
}

public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
{
if (axis == null)
{
var r = _ReductionDims(input_tensors, axis);
var m = gen_math_ops.mean(input_tensors, r, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
else
{
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
var r = _ReductionDims(input_tensors, axis);
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value);
var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}

/// <summary>


+ 8
- 6
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -91,14 +91,16 @@ namespace Tensorflow

var buffer = new byte[size][];
var data_start = c_api.TF_TensorData(_handle);
var string_start = data_start + (int)(size * sizeof(ulong));
data_start += (int)(size * sizeof(ulong));
for (int i = 0; i < buffer.Length; i++)
{
var len = *(byte*)string_start;
buffer[i] = new byte[len];
string_start += 1;
Marshal.Copy(string_start, buffer[i], 0, len);
string_start += len;
IntPtr dst = IntPtr.Zero;
ulong dstLen = 0;
var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle);
tf.Status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
data_start += (int)read;
}

return buffer;


+ 4
- 3
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -69,13 +69,14 @@ namespace Tensorflow
=> items.Insert(index, tensor);

IEnumerator IEnumerable.GetEnumerator()
{
throw new NotImplementedException();
}
=> GetEnumerator();

public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);

public static implicit operator Tensors((Tensor, Tensor) tuple)
=> new Tensors(tuple.Item1, tuple.Item2);

public static implicit operator Tensors(NDArray nd)
=> new Tensors(nd);



+ 3
- 1
src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs View File

@@ -17,7 +17,9 @@ namespace Tensorflow.Keras
return results[0];
}

throw new NotImplementedException("");
var _op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, args: new { x = features });

return _op.output;
};
}
}

+ 3
- 1
src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs View File

@@ -17,7 +17,9 @@ namespace Tensorflow.Keras
return results[0];
}

throw new NotImplementedException("");
var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x = features });

return _op.output;
};
}
}

+ 0
- 12
src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs View File

@@ -1,12 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine
{
public interface ITensorFlowOpLayer
{
Layer GetOpLayer(TensorFlowOpLayerArgs args);
}
}

+ 8
- 5
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

FitInternal(epochs);
FitInternal(epochs, verbose);
}

public void fit(IDatasetV2 dataset,
@@ -80,10 +80,10 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

FitInternal(epochs);
FitInternal(epochs, verbose);
}

void FitInternal(int epochs)
void FitInternal(int epochs, int verbose)
{
stop_training = false;
_train_counter.assign(0);
@@ -96,8 +96,11 @@ namespace Tensorflow.Keras.Engine
{
// callbacks.on_train_batch_begin(step)
var results = train_step_function(iterator);
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
if (verbose == 1)
{
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
}
}

GC.Collect();


+ 24
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -1,4 +1,5 @@
using NumSharp;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -142,6 +143,7 @@ namespace Tensorflow.Keras.Layers
public Dense Dense(int units,
Activation activation = null,
IInitializer kernel_initializer = null,
bool use_bias = true,
IInitializer bias_initializer = null,
TensorShape input_shape = null)
=> new Dense(new DenseArgs
@@ -149,7 +151,7 @@ namespace Tensorflow.Keras.Layers
Units = units,
Activation = activation ?? keras.activations.Linear,
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
BiasInitializer = bias_initializer ?? tf.zeros_initializer,
BiasInitializer = bias_initializer ?? (use_bias ? tf.zeros_initializer : null),
InputShape = input_shape
});

@@ -332,6 +334,24 @@ namespace Tensorflow.Keras.Layers
Alpha = alpha
});

public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");

public Layer SimpleRNN(int units,
Activation activation = null)
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = activation
});

public Layer SimpleRNN(int units,
string activation = "tanh")
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = GetActivationByName(activation)
});

public Layer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
@@ -381,6 +401,9 @@ namespace Tensorflow.Keras.Layers
public Add Add()
=> new Add(new MergeArgs { });

public Subtract Subtract()
=> new Subtract(new MergeArgs { });

public GlobalAveragePooling2D GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });



+ 23
- 0
src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class Subtract : Merge
{
public Subtract(MergeArgs args) : base(args)
{

}

protected override Tensors _merge_function(Tensors inputs)
{
if (len(inputs) != 2)
throw new ValueError($"A `Subtract` layer should be called on exactly 2 inputs");
return inputs[0] - inputs[1];
}
}
}

+ 84
- 2
src/TensorFlowNET.Keras/Layers/RNN.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

@@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers
{
public class RNN : Layer
{
public RNN(RNNArgs args)
: base(args)
private RNNArgs args;

public RNN(RNNArgs args) : base(PreConstruct(args))
{
this.args = args;
SupportsMasking = true;

// The input shape is unknown yet, it could have nested tensor inputs, and
// the input spec will be the list of specs for nested inputs, the structure
// of the input_spec will be the same as the input.

//self.input_spec = None
//self.state_spec = None
//self._states = None
//self.constants_spec = None
//self._num_constants = 0

//if stateful:
// if ds_context.has_strategy():
// raise ValueError('RNNs with stateful=True not yet supported with '
// 'tf.distribute.Strategy.')
}

private static RNNArgs PreConstruct(RNNArgs args)
{
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}

// If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);

object input_shape;
var propIS = args.Kwargs.Get("input_shape", null);
var propID = args.Kwargs.Get("input_dim", null);
var propIL = args.Kwargs.Get("input_length", null);

if (propIS == null && (propID != null || propIL != null))
{
input_shape = (
propIL ?? new NoneValue(), // maybe null is needed here
propID ?? new NoneValue()); // and here
args.Kwargs["input_shape"] = input_shape;
}

return args;
}

public RNN New(LayerRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public RNN New(IList<RnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});


protected Tensor get_initial_state(Tensor inputs)
{
return _generate_zero_filled_state_for_cell(null, null);


+ 14
- 0
src/TensorFlowNET.Keras/Layers/SimpleRNN.cs View File

@@ -0,0 +1,14 @@
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
public class SimpleRNN : RNN
{

public SimpleRNN(RNNArgs args) : base(args)
{

}

}
}

+ 125
- 0
src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs View File

@@ -0,0 +1,125 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
{
public IList<RnnCell> Cells { get; set; }

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
{
Cells = args.Cells;
//Cells.reverse_state_order = kwargs.pop('reverse_state_order', False);
// self.reverse_state_order = kwargs.pop('reverse_state_order', False)
// if self.reverse_state_order:
// logging.warning('reverse_state_order=True in StackedRNNCells will soon '
// 'be deprecated. Please update the code to work with the '
// 'natural order of states if you rely on the RNN states, '
// 'eg RNN(return_state=True).')
// super(StackedRNNCells, self).__init__(**kwargs)
throw new NotImplementedException("");
}

public object state_size
{
get => throw new NotImplementedException();
}

//@property
//def state_size(self) :
// return tuple(c.state_size for c in
// (self.cells[::- 1] if self.reverse_state_order else self.cells))

// @property
// def output_size(self) :
// if getattr(self.cells[-1], 'output_size', None) is not None:
// return self.cells[-1].output_size
// elif _is_multiple_state(self.cells[-1].state_size) :
// return self.cells[-1].state_size[0]
// else:
// return self.cells[-1].state_size

// def get_initial_state(self, inputs= None, batch_size= None, dtype= None) :
// initial_states = []
// for cell in self.cells[::- 1] if self.reverse_state_order else self.cells:
// get_initial_state_fn = getattr(cell, 'get_initial_state', None)
// if get_initial_state_fn:
// initial_states.append(get_initial_state_fn(
// inputs=inputs, batch_size=batch_size, dtype=dtype))
// else:
// initial_states.append(_generate_zero_filled_state_for_cell(
// cell, inputs, batch_size, dtype))

// return tuple(initial_states)

// def call(self, inputs, states, constants= None, training= None, ** kwargs):
// # Recover per-cell states.
// state_size = (self.state_size[::- 1]
// if self.reverse_state_order else self.state_size)
// nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))

// # Call the cells in order and store the returned states.
// new_nested_states = []
// for cell, states in zip(self.cells, nested_states) :
// states = states if nest.is_nested(states) else [states]
//# TF cell does not wrap the state into list when there is only one state.
// is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
// states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
// if generic_utils.has_arg(cell.call, 'training'):
// kwargs['training'] = training
// else:
// kwargs.pop('training', None)
// # Use the __call__ function for callable objects, eg layers, so that it
// # will have the proper name scopes for the ops, etc.
// cell_call_fn = cell.__call__ if callable(cell) else cell.call
// if generic_utils.has_arg(cell.call, 'constants'):
// inputs, states = cell_call_fn(inputs, states,
// constants= constants, ** kwargs)
// else:
// inputs, states = cell_call_fn(inputs, states, ** kwargs)
// new_nested_states.append(states)

// return inputs, nest.pack_sequence_as(state_size,
// nest.flatten(new_nested_states))

// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
// input_shape = input_shape[0]
// for cell in self.cells:
// if isinstance(cell, Layer) and not cell.built:
// with K.name_scope(cell.name):
// cell.build(input_shape)
// cell.built = True
// if getattr(cell, 'output_size', None) is not None:
// output_dim = cell.output_size
// elif _is_multiple_state(cell.state_size) :
// output_dim = cell.state_size[0]
// else:
// output_dim = cell.state_size
// input_shape = tuple([input_shape[0]] +
// tensor_shape.TensorShape(output_dim).as_list())
// self.built = True

// def get_config(self) :
// cells = []
// for cell in self.cells:
// cells.append(generic_utils.serialize_keras_object(cell))
// config = {'cells': cells
//}
//base_config = super(StackedRNNCells, self).get_config()
// return dict(list(base_config.items()) + list(config.items()))

// @classmethod
// def from_config(cls, config, custom_objects = None):
// from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
// cells = []
// for cell_config in config.pop('cells'):
// cells.append(
// deserialize_layer(cell_config, custom_objects = custom_objects))
// return cls(cells, **config)
}
}

+ 73
- 0
src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs View File

@@ -0,0 +1,73 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class TensorFlowOpLayer : Layer
{
TensorFlowOpLayerArgs args;
Dictionary<int, NDArray> constants => args.Constants;
NodeDef node_def => args.NodeDef;
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_";
public string OpType => node_def.Op;

public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
: base(new LayerArgs
{
Name = TF_OP_LAYER_NAME_PREFIX + args.Name,
Trainable = args.Trainable,
DType = args.DType,
Autocast = false
})
{
this.args = args;
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (tf.Context.executing_eagerly())
return _defun_call(inputs);
return MakOp(inputs);
}

[AutoGraph]
Tensors _defun_call(Tensors inputs)
=> MakOp(inputs);

Tensors MakOp(Tensors inputs)
{
var graph = inputs.graph;
graph.as_default();
foreach (var (index, constant) in enumerate(constants))
{
var value = constant_op.constant(constant, name: node_def.Input[index]);
inputs.Insert(index, value);
}

var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
var op = graph._create_op_from_tf_operation(c_op);
op._control_flow_post_processing();

// Record the gradient because custom-made ops don't go through the
// code-gen'd eager call path
var op_type = op.node_def.Op;

tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);

graph.Exit();
return op.outputs;
}

public Layer GetOpLayer(TensorFlowOpLayerArgs args)
=> new TensorFlowOpLayer(args);
}
}

+ 3
- 3
src/TensorFlowNET.Keras/Losses/Huber.cs View File

@@ -27,10 +27,10 @@ namespace Tensorflow.Keras.Losses
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
Tensor abs_error = math_ops.abs(error);
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(error, 2),
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(error, 2),
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
axis : -1);
axis: -1);
}
}
}

+ 1
- 3
src/TensorFlowNET.Keras/Losses/LogCosh.cs View File

@@ -19,10 +19,8 @@ namespace Tensorflow.Keras.Losses
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor x = y_pred_dispatch - y_true_cast;
return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),axis: -1);

return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), axis: -1);
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Losses
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) *gen_math_ops.mean(diff, axis: -1);
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1);
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Losses
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
}
}
}

+ 3
- 0
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -26,6 +26,9 @@ namespace Tensorflow.Keras.Optimizers
protected float _initial_decay = 0.0f;
protected bool _use_locking = true;

public IVariableV1 lr
=> _hyper_variables["learning_rate"];

Dictionary<string, Dictionary<string, IVariableV1>> _slots;
List<string> _slot_names;



+ 7
- 1
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -21,7 +21,9 @@
* Support BatchNormalization layer.
* Building keras model in subclass, functional and sequential api
* Implemented backward_function.
* Support model.load_weights.</PackageReleaseNotes>
* Support model.load_weights.
* Add Subtract layer
* Support YOLOv3 model.</PackageReleaseNotes>
<Description>Keras for .NET

Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
@@ -64,4 +66,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
</None>
</ItemGroup>

<ItemGroup>
<Folder Include="Engine\Interfaces\" />
</ItemGroup>

</Project>

+ 4
- 16
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -21,6 +21,7 @@ using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -150,12 +151,13 @@ namespace Tensorflow.Keras.Utils

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
var op_layer = GetLayer<ITensorFlowOpLayer>(new TensorFlowOpLayerArgs
var opLayerArgs = new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
Name = op.name
});
};
var op_layer = new TensorFlowOpLayer(opLayerArgs);
created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
processed_ops.Add(op);
@@ -163,20 +165,6 @@ namespace Tensorflow.Keras.Utils
}
}

static Layer GetLayer<T>(LayerArgs args)
{
Layer layer = default;
var assemble = Assembly.Load("TensorFlow.Keras.Layers");
foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null))
{
layer = (Layer)Activator.CreateInstance(type, new object[] { args });
}

if (layer == null)
throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}");
return layer;
}

// recusive
static bool uses_keras_history(Tensor op_input)
{


+ 1
- 1
tensorflowlib/README.md View File

@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\

1. Build static library

`bazel build --config=opt //tensorflow:tensorflow`
`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow`

2. Build pip package



+ 19
- 53
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest
@@ -35,71 +36,31 @@ namespace TensorFlowNET.Keras.UnitTest
var model = keras.Model(inputs, outputs, name: "mnist_model");
model.summary();
}
/// <summary>
/// Custom layer test, used in Dueling DQN
/// </summary>
[TestMethod, Ignore]
public void FunctionalTest()
public void TensorFlowOpLayer()
{
var layers = keras.layers;
var inputs = layers.Input(shape: 24);
var x = layers.Dense(128, activation:"relu").Apply(inputs);
var x = layers.Dense(128, activation: "relu").Apply(inputs);
var value = layers.Dense(24).Apply(x);
var adv = layers.Dense(1).Apply(x);
var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem.
var outputs = layers.Add().Apply(new Tensors(adv_out, value));

var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true);
adv = layers.Subtract().Apply((adv, mean));
var outputs = layers.Add().Apply((value, adv));
var model = keras.Model(inputs, outputs);
model.summary();
model.compile(optimizer: keras.optimizers.RMSprop(0.001f),
loss: keras.losses.MeanSquaredError(),
metrics: new[] { "acc" });
// Here we consider the adv_out is one layer, which is a little different from py's version
Assert.AreEqual(model.Layers.Count, 6);

// py code:
//from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda
//from tensorflow.keras.models import Model
//from tensorflow.keras.optimizers import RMSprop
//import tensorflow.keras.backend as K

//inputs = Input(24)
//x = Dense(128, activation = "relu")(inputs)
//value = Dense(24)(x)
//adv = Dense(1)(x)
//meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv)
//adv = Subtract()([adv, meam])
//outputs = Add()([value, adv])
//model = Model(inputs, outputs)
//model.compile(loss = "mse", optimizer = RMSprop(1e-3))
//model.summary()

//py output:
//Model: "functional_3"
//__________________________________________________________________________________________________
//Layer(type) Output Shape Param # Connected to
//==================================================================================================
//input_2 (InputLayer) [(None, 24)] 0
//__________________________________________________________________________________________________
//dense_3 (Dense) (None, 128) 3200 input_2[0][0]
//__________________________________________________________________________________________________
//dense_5 (Dense) (None, 1) 129 dense_3[0][0]
//__________________________________________________________________________________________________
//lambda_1 (Lambda) (None, 1) 0 dense_5[0][0]
//__________________________________________________________________________________________________
//dense_4 (Dense) (None, 24) 3096 dense_3[0][0]
//__________________________________________________________________________________________________
//subtract_1 (Subtract) (None, 1) 0 dense_5[0][0]
// lambda_1[0][0]
//__________________________________________________________________________________________________
//add_1 (Add) (None, 24) 0 dense_4[0][0]
// subtract_1[0][0]
//==================================================================================================
//Total params: 6,425
//Trainable params: 6,425
//Non-trainable params: 0
//__________________________________________________________________________________________________
model.summary();
Assert.AreEqual(model.Layers.Count, 8);
var result = model.predict(tf.constant(np.arange(24).astype(np.float32)[np.newaxis, Slice.All]));
Assert.AreEqual(result.shape, new TensorShape(1, 24));
model.fit(np.arange(24).astype(np.float32)[np.newaxis, Slice.All], np.arange(24).astype(np.float32)[np.newaxis, Slice.All], verbose: 0);
}

/// <summary>
@@ -149,9 +110,14 @@ namespace TensorFlowNET.Keras.UnitTest
}

[TestMethod]
[Ignore]
public void SimpleRNN()
{

var inputs = np.random.rand(32, 10, 8).astype(np.float32);
var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);
}

}
}

+ 9
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs View File

@@ -48,5 +48,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var x5 = tf.reduce_sum(b, (0, 1));
Assert.AreEqual(-4.7f, (float)x5);
}

[TestMethod]
public void Erf()
{
var erf = tf.math.erf(a, name: "erf");
var expected = new float[] { 0.8427007f, -0.5204999f, 0.99999845f, -0.9970206f, 0f, -1f };
var actual = erf.ToArray<float>();
Assert.IsTrue(Equal(expected, actual));
}
}
}

+ 4
- 9
test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs View File

@@ -132,28 +132,25 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
}

#region ones/zeros like
[Ignore]
[TestMethod]
public void TestOnesLike()
{
#region 2-dimension
var testCase2D = tf.constant(new int[,]
var ones2D = tf.ones_like(new int[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 }
});
var ones2D = tf.ones_like(testCase2D);

Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy());
Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy());
#endregion

#region 1-dimension
var testCase1D = tf.constant(new int[,]
var ones1D = tf.ones_like(new int[,]
{
{ 1, 2, 3 }
});
var ones1D = tf.ones_like(testCase1D);

Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy());
#endregion
@@ -163,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
public void TestZerosLike()
{
#region 2-dimension
var testCase2D = tf.constant(new int[,]
var zeros2D = tf.zeros_like(new int[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 }
});
var zeros2D = tf.zeros_like(testCase2D);

Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy());
Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy());
#endregion

#region 1-dimension
var testCase1D = tf.constant(new int[,]
var zeros1D = tf.zeros_like(new int[,]
{
{ 1, 2, 3 }
});
var zeros1D = tf.zeros_like(testCase1D);

Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy());
#endregion


+ 0
- 11
test/Tensorflow.Keras.UnitTest/OptimizerTest.cs View File

@@ -1,11 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class OptimizerTest
{

}
}

+ 0
- 25
test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -1,25 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netcoreapp3.1</TargetFramework>

<IsPackable>false</IsPackable>

<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" />
<PackageReference Include="coverlet.collector" Version="1.2.1">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" />
</ItemGroup>

</Project>

Loading…
Cancel
Save