Oceania2018 6 years ago
parent
commit
a09b86e1df
9 changed files with 411 additions and 16 deletions
  1. +36
    -3
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +32
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  3. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  4. +30
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  5. +11
    -0
      src/TensorFlowNET.Core/Binding.cs
  6. +179
    -7
      src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs
  7. +43
    -3
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  8. +60
    -0
      test/TensorFlowNET.UnitTest/TensorShapeTest.cs
  9. +18
    -0
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 36
- 3
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System.Collections.Generic;
using System.Linq;
using NumSharp;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
@@ -144,6 +146,20 @@ namespace Tensorflow
return layer.apply(inputs);
}

/// <summary>
/// Densely-connected layer class. aka fully-connected<br></br>
/// `outputs = activation(inputs * kernel + bias)`
/// </summary>
/// <param name="inputs"></param>
/// <param name="units">Python integer, dimensionality of the output space.</param>
/// <param name="activation"></param>
/// <param name="use_bias">Boolean, whether the layer uses a bias.</param>
/// <param name="kernel_initializer"></param>
/// <param name="bias_initializer"></param>
/// <param name="trainable"></param>
/// <param name="name"></param>
/// <param name="reuse"></param>
/// <returns></returns>
public Tensor dense(Tensor inputs,
int units,
IActivation activation = null,
@@ -160,7 +176,8 @@ namespace Tensorflow
var layer = new Dense(units, activation,
use_bias: use_bias,
bias_initializer: bias_initializer,
kernel_initializer: kernel_initializer);
kernel_initializer: kernel_initializer,
trainable: trainable);

return layer.apply(inputs);
}
@@ -182,6 +199,7 @@ namespace Tensorflow
string name = null,
string data_format = "channels_last")
{
var input_shape = inputs.shape;
if (inputs.shape.Length == 0)
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");

@@ -193,9 +211,24 @@ namespace Tensorflow
inputs = array_ops.transpose(inputs, premutation.ToArray());
}

var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
ret.set_shape(new int[] {inputs.shape[0], -1});
var ret = array_ops.reshape(inputs, compute_output_shape(input_shape));
//ret.set_shape(compute_output_shape(ret.shape));
return ret;

int[] compute_output_shape(int[] inputshape)
{
if (inputshape == null || inputshape.Length == 0)
inputshape = new int[] {1};

if (inputshape.Skip(1).All(d => d > 0))
{
int[] output_shape = new int[2];
output_shape[0] = inputshape[0];
output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions
return output_shape;
} else
return new int[] {inputshape[0], -1}; //-1 == Binding.None
}
}
}
}


+ 32
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
@@ -211,6 +213,36 @@ namespace Tensorflow
/// <returns></returns>
public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
=> gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max);
/// <summary>
/// Clips tensor values to a specified min and max.
/// </summary>
/// <param name="t">
/// A <c>Tensor</c>.
/// </param>
/// <param name="clip_value_min">
/// A 0-D (scalar) <c>Tensor</c>, or a <c>Tensor</c> with the same shape
/// as <c>t</c>. The minimum value to clip by.
/// </param>
/// <param name="clip_value_max">
/// A 0-D (scalar) <c>Tensor</c>, or a <c>Tensor</c> with the same shape
/// as <c>t</c>. The maximum value to clip by.
/// </param>
/// <param name="name">
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'.
/// </param>
/// <returns>
/// A clipped <c>Tensor</c> with the same shape as input 't'.
/// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
/// </returns>
/// <remarks>
/// Given a tensor <c>t</c>, this operation returns a tensor of the same type and
/// shape as <c>t</c> with its values clipped to <c>clip_value_min</c> and <c>clip_value_max</c>.
/// Any values less than <c>clip_value_min</c> are set to <c>clip_value_min</c>. Any values
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>.
/// </remarks>
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
=> gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);

public Tensor sub(Tensor a, Tensor b)
=> gen_math_ops.sub(a, b);


+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -18,8 +18,8 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor convert_to_tensor(object value,
string name = null) => ops.convert_to_tensor(value, name: name);
public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
=> ops.convert_to_tensor(value, dtype, name, preferred_dtype);

public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null,
int begin_mask = 0,


+ 30
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -21,6 +21,7 @@ using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using NumSharp.Utilities;

namespace Tensorflow
{
@@ -29,9 +30,37 @@ namespace Tensorflow
/// </summary>
public static partial class Binding
{
private static string _tostring(object obj)
{
switch (obj)
{
case NDArray nd:
return nd.ToString(false);
case Array arr:
if (arr.Rank!=1 || arr.GetType().GetElementType()?.IsArray == true)
arr = Arrays.Flatten(arr);
var objs = toObjectArray(arr);
return $"[{string.Join(", ", objs.Select(_tostring))}]";
default:
return obj?.ToString() ?? "null";
}

object[] toObjectArray(Array arr)
{
var len = arr.LongLength;
var ret = new object[len];
for (long i = 0; i < len; i++)
{
ret[i] = arr.GetValue(i);
}

return ret;
}
}

public static void print(object obj)
{
Console.WriteLine(obj.ToString());
Console.WriteLine(_tostring(obj));
}

public static int len(object a)


+ 11
- 0
src/TensorFlowNET.Core/Binding.cs View File

@@ -7,5 +7,16 @@ namespace Tensorflow
public static partial class Binding
{
public static tensorflow tf { get; } = New<tensorflow>();

/// <summary>
/// Alias to null, similar to python's None.
/// For TensorShape, please use Unknown
/// </summary>
public static readonly object None = null;

/// <summary>
/// Used for TensorShape None
/// </summary>
public static readonly int Unknown = -1;
}
}

+ 179
- 7
src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs View File

@@ -14,20 +14,192 @@
limitations under the License.
******************************************************************************/

using System;
using static Tensorflow.Binding;

namespace Tensorflow.Operations.Activation
{
public class sigmoid : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return tf.sigmoid(x);
}
}

public class tanh : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return tf.tanh(x);
}
}

public class leakyrelu : IActivation
{
private readonly float _alpha;

public leakyrelu(float alpha = 0.3f) {
_alpha = alpha;
}

public Tensor Activate(Tensor x, string name = null)
{
return nn_ops.leaky_relu(x, _alpha);
}
}

public class elu : IActivation
{
private readonly float _alpha;

public elu(float alpha = 0.1f)
{
_alpha = alpha;
}

public Tensor Activate(Tensor x, string name = null)
{
var res = gen_ops.elu(x);
if (Math.Abs(_alpha - 0.1f) < 0.00001f)
{
return res;
}

return array_ops.@where(x > 0, res, _alpha * res);
}
}

public class softmax : IActivation
{
private readonly int _axis;

/// <summary>Initializes a new instance of the <see cref="T:System.Object"></see> class.</summary>
public softmax(int axis = -1)
{
_axis = axis;
}

public Tensor Activate(Tensor x, string name = null)
{
return nn_ops.softmax(x, _axis);
}
}

public class softplus : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return gen_ops.softplus(x);
}
}

public class softsign : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return gen_ops.softsign(x);
}
}

public class linear : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return x;
}
}


public class exponential : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
return tf.exp(x, name: name);
}
}


public class relu : IActivation
{
public Tensor Activate(Tensor features, string name = null)
private readonly float _threshold;
private readonly float _alpha;
private readonly float? _maxValue;

public relu(float threshold = 0f, float alpha = 0.2f, float? max_value = null)
{
_threshold = threshold;
_alpha = alpha;
_maxValue = max_value;
}

public Tensor Activate(Tensor x, string name = null)
{
OpDefLibrary _op_def_lib = new OpDefLibrary();
//based on keras/backend.py
if (Math.Abs(_alpha) > 0.000001f)
{
if (!_maxValue.HasValue && Math.Abs(_threshold) < 0.0001)
{
return nn_ops.leaky_relu(x, _alpha);
}
}

Tensor negative_part;
if (Math.Abs(_threshold) > 0.000001f)
{
negative_part = gen_ops.relu(-x + _threshold);
} else
{
negative_part = gen_ops.relu(-x + _threshold);
}

if (Math.Abs(_threshold) > 0.000001f)
{
x = x * math_ops.cast(tf.greater(x, _threshold), TF_DataType.TF_FLOAT);
} else if (Math.Abs(_maxValue.Value - 6f) < 0.0001f)
{
x = gen_ops.relu6(x);
} else
{
x = gen_ops.relu(x);
}

bool clip_max = _maxValue.HasValue;
if (clip_max)
{
Tensor maxval = constant_op.constant(_maxValue, x.dtype.as_base_dtype());
var zero = constant_op.constant(0.0f, x.dtype.as_base_dtype());
x = gen_ops.clip_by_value(x, zero, maxval);
}

var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new
if (Math.Abs(_alpha) > 0.00001)
{
features
});
var a = constant_op.constant(_alpha, x.dtype.as_base_dtype());
x -= a * negative_part;
}

return _op.outputs[0];
return x;
}
}

public class selu : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
const float alpha = 1.6732632423543772848170429916717f;
const float scale = 1.0507009873554804934193349852946f;
return scale * new elu(alpha).Activate(x, name);
}
}

public class hard_sigmoid : IActivation
{
public Tensor Activate(Tensor x, string name = null)
{
x = (0.2 * x) + 0.5;
var zero = tf.convert_to_tensor(0.0f, x.dtype.as_base_dtype());
var one = tf.convert_to_tensor(1.0f, x.dtype.as_base_dtype());
return tf.clip_by_value(x, zero, one);
}
}
}
}

+ 43
- 3
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using NumSharp.Utilities;

namespace Tensorflow
{
@@ -32,7 +33,23 @@ namespace Tensorflow
/// <summary>
/// Returns the size this shape represents.
/// </summary>
public int size => shape.Size;
public int size
{
get
{
var dims = shape.Dimensions;
var computed = 1;
for (int i = 0; i < dims.Length; i++)
{
var val = dims[i];
if (val <= 0)
continue;
computed *= val;
}

return computed;
}
}

public TensorShape(TensorShapeProto proto)
{
@@ -59,12 +76,30 @@ namespace Tensorflow
switch (dims.Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int) dims[0]); break;
case 1: shape = Shape.Vector((int)dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break;
}
}

public TensorShape(int[][] dims)
{
if(dims.Length == 1)
{
switch (dims[0].Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int)dims[0][0]); break;
case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break;
default: shape = new Shape(dims[0]); break;
}
}
else
{
throw new NotImplementedException("TensorShape int[][] dims");
}
}

/// <summary>
///
/// </summary>
@@ -188,6 +223,11 @@ namespace Tensorflow

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);

public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}

+ 60
- 0
test/TensorFlowNET.UnitTest/TensorShapeTest.cs View File

@@ -0,0 +1,60 @@
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TensorShapeTest
{
[TestMethod]
public void Case1()
{
int a = 2;
int b = 3;
var dims = new [] { Unknown, a, b};
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case2()
{
int a = 2;
int b = 3;
var dims = new[] { Unknown, a, b};
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case3()
{
int a = 2;
int b = Unknown;
var dims = new [] { Unknown, a, b};
new TensorShape(new [] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
}

[TestMethod]
public void Case4()
{
TensorShape shape = (Unknown, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
}

[TestMethod]
public void Case5()
{
TensorShape shape = (1, Unknown, 3);
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
}

[TestMethod]
public void Case6()
{
TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}
}
}

+ 18
- 0
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -36,5 +36,23 @@ namespace TensorFlowNET.UnitTest.layers_test
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
}

[TestMethod]
public void Case4()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case5()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}
}

Loading…
Cancel
Save