Browse Source

Introduce Axis record type.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
da33a8b18d
21 changed files with 124 additions and 120 deletions
  1. +8
    -20
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
  4. +31
    -0
      src/TensorFlowNET.Core/NumPy/Axis.cs
  5. +2
    -2
      src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
  6. +37
    -5
      src/TensorFlowNET.Core/Numpy/Shape.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Numpy/Slice.cs
  8. +11
    -11
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  9. +10
    -18
      src/TensorFlowNET.Core/Operations/math_ops.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  11. +1
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  12. +2
    -0
      src/TensorFlowNET.Core/ops.cs
  13. +1
    -1
      src/TensorFlowNET.Keras/BackendImpl.cs
  14. +2
    -2
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
  15. +2
    -2
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
  16. +2
    -2
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
  17. +2
    -2
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
  18. +2
    -1
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  19. +5
    -0
      test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
  20. +1
    -1
      test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
  21. +1
    -49
      test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs

+ 8
- 20
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -467,12 +467,9 @@ namespace Tensorflow
/// <param name="keepdims">If true, retains reduced dimensions with length 1.</param> /// <param name="keepdims">If true, retains reduced dimensions with length 1.</param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns>The reduced tensor.</returns> /// <returns>The reduced tensor.</returns>
public Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name); => math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name);


public Tensor reduce_any(Tensor input_tensor, int axis = 0, bool keepdims = false, string name = null)
=> math_ops.reduce_any(input_tensor, axis: new[] { axis }, keepdims: keepdims, name: name);

/// <summary> /// <summary>
/// Computes the "logical and" of elements across dimensions of a tensor. /// Computes the "logical and" of elements across dimensions of a tensor.
/// </summary> /// </summary>
@@ -481,7 +478,7 @@ namespace Tensorflow
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns>The reduced tensor.</returns> /// <returns>The reduced tensor.</returns>
public Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name); => math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name);


/// <summary> /// <summary>
@@ -492,7 +489,7 @@ namespace Tensorflow
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name); => math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name);


/// <summary> /// <summary>
@@ -537,19 +534,16 @@ namespace Tensorflow
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_max(input_tensor, axis, keepdims, name);

public Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, string name = null)
public Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_max(input_tensor, axis, keepdims, name); => math_ops.reduce_max(input_tensor, axis, keepdims, name);


public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_min(input_tensor, axis, keepdims, name); => math_ops.reduce_min(input_tensor, axis, keepdims, name);


public Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_std(input_tensor, axis, keepdims, name); => math_ops.reduce_std(input_tensor, axis, keepdims, name);


public Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_variance(input_tensor, axis, keepdims, name); => math_ops.reduce_variance(input_tensor, axis, keepdims, name);


public Tensor sigmoid<T>(T x, string name = null) public Tensor sigmoid<T>(T x, string name = null)
@@ -558,15 +552,9 @@ namespace Tensorflow
public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
=> gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name);


public Tensor reduce_mean(Tensor input_tensors, int axis, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensors, axis: new[] { axis }, keepdims: keepdims, name: name);

public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);


public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);

public Tensor round(Tensor x, string name = null) public Tensor round(Tensor x, string name = null)
=> gen_math_ops.round(x, name: name); => gen_math_ops.round(x, name: name);




+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -89,7 +89,7 @@ namespace Tensorflow
=> gen_nn_ops.elu(features, name: name); => gen_nn_ops.elu(features, name: name);


public (Tensor, Tensor) moments(Tensor x, public (Tensor, Tensor) moments(Tensor x,
int[] axes,
Axis axes,
string name = null, string name = null,
bool keep_dims = false) => nn_impl.moments(x, bool keep_dims = false) => nn_impl.moments(x,
axes, axes,


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow
public partial class tensorflow public partial class tensorflow
{ {
public Tensor reduce_logsumexp(Tensor input_tensor, public Tensor reduce_logsumexp(Tensor input_tensor,
int[] axis = null,
Axis? axis = null,
bool keepdims = false, bool keepdims = false,
string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name); string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name);




+ 31
- 0
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public record Axis(params int[] axis)
{
public int this[int index] => axis[index];

public static implicit operator int[]?(Axis axis)
=> axis?.axis;

public static implicit operator Axis(int axis)
=> new Axis(axis);

public static implicit operator Axis((int, int) axis)
=> new Axis(axis);

public static implicit operator Axis((int, int, int) axis)
=> new Axis(axis);

public static implicit operator Axis(int[] axis)
=> new Axis(axis);
}
}

namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 2
- 2
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs View File

@@ -12,8 +12,8 @@ namespace Tensorflow.NumPy
public static NDArray log(NDArray x) public static NDArray log(NDArray x)
=> throw new NotImplementedException(""); => throw new NotImplementedException("");


public static NDArray prod(NDArray array, int? axis = null, Type dtype = null, bool keepdims = false)
=> tf.reduce_prod(ops.convert_to_tensor(array));
public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false)
=> tf.reduce_prod(ops.convert_to_tensor(array), axis: axis);


public static NDArray prod<T>(params T[] array) where T : unmanaged public static NDArray prod<T>(params T[] array) where T : unmanaged
=> tf.reduce_prod(ops.convert_to_tensor(array)); => tf.reduce_prod(ops.convert_to_tensor(array));


+ 37
- 5
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow.NumPy
namespace Tensorflow
{ {
public class Shape public class Shape
{ {
@@ -11,6 +11,13 @@ namespace Tensorflow.NumPy
long[] _dims; long[] _dims;
public long[] dims => _dims; public long[] dims => _dims;


public Shape()
{
}

public Shape(params int[] dims)
=> _dims = dims.Select(x => Convert.ToInt64(x)).ToArray();

public Shape(params long[] dims) public Shape(params long[] dims)
=> _dims = dims; => _dims = dims;


@@ -21,14 +28,27 @@ namespace Tensorflow.NumPy
=> new Shape(dims); => new Shape(dims);


public static implicit operator Shape(int[] dims) public static implicit operator Shape(int[] dims)
=> new Shape(dims.Select(x => Convert.ToInt64(x)).ToArray());
=> new Shape(dims);

public static implicit operator Shape((int, int) dims)
=> new Shape(dims.Item1, dims.Item2);


public static implicit operator Shape((long, long) dims) public static implicit operator Shape((long, long) dims)
=> new Shape(dims.Item1, dims.Item2); => new Shape(dims.Item1, dims.Item2);


public bool IsSliced => throw new NotImplementedException("");
public bool IsScalar => throw new NotImplementedException("");
public bool IsBroadcasted => throw new NotImplementedException("");
public static implicit operator Shape((int, int, int) dims)
=> new Shape(dims.Item1, dims.Item2, dims.Item3);

public static implicit operator Shape((long, long, long) dims)
=> new Shape(dims.Item1, dims.Item2, dims.Item3);

public static implicit operator Shape((int, int, int, int) dims)
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

public static implicit operator Shape((long, long, long, long) dims)
=> new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

public bool IsScalar => ndim == 0;


public static Shape Scalar public static Shape Scalar
=> new Shape(new long[0]); => new Shape(new long[0]);
@@ -55,6 +75,18 @@ namespace Tensorflow.NumPy


public bool IsEmpty => throw new NotImplementedException(""); public bool IsEmpty => throw new NotImplementedException("");


public override bool Equals(object obj)
{
if(obj is Shape shape)
{
if (shape.ndim != ndim)
return false;
if (Enumerable.SequenceEqual(dims, shape.dims))
return true;
}
return base.Equals(obj);
}

public override string ToString() public override string ToString()
{ {
return "(" + string.Join(", ", _dims) + ")"; return "(" + string.Join(", ", _dims) + ")";


+ 1
- 1
src/TensorFlowNET.Core/Numpy/Slice.cs View File

@@ -4,7 +4,7 @@ using System.Linq;
using System.Text; using System.Text;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;


namespace Tensorflow.NumPy
namespace Tensorflow
{ {
/// <summary> <br></br> /// <summary> <br></br>
/// NDArray can be indexed using slicing <br></br> /// NDArray can be indexed using slicing <br></br>


+ 11
- 11
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -968,9 +968,9 @@ new_height, new_width");
var num_pixels_ = array_ops.shape(image).dims; var num_pixels_ = array_ops.shape(image).dims;
num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray(); num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray();
Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_)); Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_));
Tensor image_mean = math_ops.reduce_mean(image, axis: new int[] { -1, -2, -3 }, keepdims: true);
Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true);


var stddev = math_ops.reduce_std(image, axis: new int[] { -1, -2, -3 }, keepdims: true);
var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true);
var min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype)); var min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype));
var adjusted_stddev = math_ops.maximum(stddev, min_stddev); var adjusted_stddev = math_ops.maximum(stddev, min_stddev);


@@ -1408,7 +1408,7 @@ new_height, new_width");
max_val = convert_image_dtype(max_val, dtypes.float32); max_val = convert_image_dtype(max_val, dtypes.float32);
a = convert_image_dtype(a, dtypes.float32); a = convert_image_dtype(a, dtypes.float32);
b = convert_image_dtype(b, dtypes.float32); b = convert_image_dtype(b, dtypes.float32);
Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new int[] { -3, -2, -1 });
Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new(-3, -2, -1));
var psnr_val = math_ops.subtract( var psnr_val = math_ops.subtract(
(20 * math_ops.log(max_val)) / math_ops.log(ops.convert_to_tensor(10.0)), (20 * math_ops.log(max_val)) / math_ops.log(ops.convert_to_tensor(10.0)),
math_ops.cast(10 / math_ops.log(ops.convert_to_tensor(10)), dtypes.float32) * math_ops.log(mse), math_ops.cast(10 / math_ops.log(ops.convert_to_tensor(10)), dtypes.float32) * math_ops.log(mse),
@@ -1503,8 +1503,8 @@ new_height, new_width");
(Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2); (Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2);


var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32); var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32);
var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims);
cs = math_ops.reduce_mean(cs, axes.dims);
var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims));
cs = math_ops.reduce_mean(cs, new(axes.dims));
return (ssim_val, cs); return (ssim_val, cs);
} }


@@ -1527,7 +1527,7 @@ new_height, new_width");
(Tensor ssim_per_channel, Tensor ___) = _ssim_per_channel(img1, img2, max_val, filter_size, (Tensor ssim_per_channel, Tensor ___) = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2); filter_sigma, k1, k2);


return math_ops.reduce_mean(ssim_per_channel, new int[] { -1 });
return math_ops.reduce_mean(ssim_per_channel, new(-1));
}); });
} }


@@ -1645,9 +1645,9 @@ new_height, new_width");
var mcs_and_ssim = array_ops.stack( var mcs_and_ssim = array_ops.stack(
math_ops.add(mcs, new[] { gen_nn_ops.relu(ssim_per_channel) }), axis: -1); math_ops.add(mcs, new[] { gen_nn_ops.relu(ssim_per_channel) }), axis: -1);
var ms_ssim = math_ops.reduce_prod( var ms_ssim = math_ops.reduce_prod(
math_ops.pow(mcs_and_ssim, power_factors), new int[] { -1 });
math_ops.pow(mcs_and_ssim, power_factors), new(-1));


return math_ops.reduce_mean(ms_ssim, new int[] { -1 });
return math_ops.reduce_mean(ms_ssim, new(-1));
}); });
} }


@@ -1830,7 +1830,7 @@ new_height, new_width");
new object[] { batch_size, tile_size, 4 }); new object[] { batch_size, tile_size, 4 });
var iou = _bbox_overlap(new_slice, box_slice); var iou = _bbox_overlap(new_slice, box_slice);
var box_slice_after_suppression = array_ops.expand_dims( var box_slice_after_suppression = array_ops.expand_dims(
math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new int[] { 1 }),
math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new(1)),
box_slice.dtype), box_slice.dtype),
2) * box_slice; 2) * box_slice;
return (boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1); return (boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1);
@@ -1913,7 +1913,7 @@ new_height, new_width");


output_size = output_size + math_ops.reduce_sum( output_size = output_size + math_ops.reduce_sum(
math_ops.cast( math_ops.cast(
math_ops.reduce_any(box_slice > 0, new int[] { 2 }), dtypes.int32), new int[] { 1 });
math_ops.reduce_any(box_slice > 0, new(2)), dtypes.int32), new int[] { 1 });
} }
return (boxes, iou_threshold, output_size, idx + 1); return (boxes, iou_threshold, output_size, idx + 1);
} }
@@ -2074,7 +2074,7 @@ new_height, new_width");


(Tensor values, Tensor indices) = gen_ops.top_k_v2( (Tensor values, Tensor indices) = gen_ops.top_k_v2(
math_ops.cast(math_ops.reduce_any( math_ops.cast(math_ops.reduce_any(
(Tensor)selboxes__output_size_[0] > 0, new int[] { 2 }), dtypes.int32) *
(Tensor)selboxes__output_size_[0] > 0, new(2)), dtypes.int32) *
array_ops.expand_dims( array_ops.expand_dims(
math_ops.range(num_boxes_after_padding, 0, -1), 0), math_ops.range(num_boxes_after_padding, 0, -1), 0),
max_output_size); max_output_size);


+ 10
- 18
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -305,7 +305,7 @@ namespace Tensorflow
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param> /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
/// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param> /// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param>
/// <param name="name"> A name for the operation (optional).</param> /// <param name="name"> A name for the operation (optional).</param>
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
public static Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis);
@@ -313,14 +313,6 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis_tensor, m); return _may_reduce_to_scalar(keepdims, axis_tensor, m);
} }


public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensors, axis);
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value);
var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}

/// <summary> /// <summary>
/// Computes the product of elements across dimensions of a tensor. /// Computes the product of elements across dimensions of a tensor.
/// </summary> /// </summary>
@@ -329,7 +321,7 @@ namespace Tensorflow
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
if (axis == null) if (axis == null)
@@ -344,7 +336,7 @@ namespace Tensorflow
} }
} }


public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
if (name == null) if (name == null)
name = "reduce_std"; name = "reduce_std";
@@ -357,7 +349,7 @@ namespace Tensorflow
}); });
} }


public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
if (name == null) if (name == null)
name = "reduce_variance"; name = "reduce_variance";
@@ -513,7 +505,7 @@ namespace Tensorflow
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public static Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
var all = gen_math_ops._all(input_tensor, var all = gen_math_ops._all(input_tensor,
_ReductionDims(input_tensor, axis), _ReductionDims(input_tensor, axis),
@@ -545,7 +537,7 @@ namespace Tensorflow
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param> /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
/// <param name="keepdims"></param> /// <param name="keepdims"></param>
/// <returns> The reduced tensor.</returns> /// <returns> The reduced tensor.</returns>
public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_logsumexp(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
return tf_with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => return tf_with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope =>
{ {
@@ -565,7 +557,7 @@ namespace Tensorflow
}); });
} }


public static Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) : var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) :
@@ -573,7 +565,7 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, max); return _may_reduce_to_scalar(keepdims, axis, max);
} }


public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) : var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) :
@@ -588,7 +580,7 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, max); return _may_reduce_to_scalar(keepdims, axis, max);
} }


public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{ {
var r = _ReductionDims(input_tensor, axis); var r = _ReductionDims(input_tensor, axis);
var min = gen_math_ops._min(input_tensor, r, keepdims, name); var min = gen_math_ops._min(input_tensor, r, keepdims, name);
@@ -711,7 +703,7 @@ namespace Tensorflow
return range(0, array_ops.rank(x)); return range(0, array_ops.rank(x));
} }


private static Tensor _ReductionDims(Tensor x, int[] axis)
private static Tensor _ReductionDims(Tensor x, Axis? axis)
{ {
if (axis != null) if (axis != null)
{ {


+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -79,7 +79,7 @@ namespace Tensorflow
/// <param name="keep_dims"> Produce moments with the same dimensionality as the input.</param> /// <param name="keep_dims"> Produce moments with the same dimensionality as the input.</param>
/// <returns> Two `Tensor` objects: `mean` and `variance`.</returns> /// <returns> Two `Tensor` objects: `mean` and `variance`.</returns>
public static (Tensor, Tensor) moments(Tensor x, public static (Tensor, Tensor) moments(Tensor x,
int[] axes,
Axis axes,
string name = null, string name = null,
bool keep_dims = false) bool keep_dims = false)
{ {


+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -7,6 +7,7 @@
<TargetTensorFlow>2.2.0</TargetTensorFlow> <TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.60.0</Version> <Version>0.60.0</Version>
<LangVersion>9.0</LangVersion> <LangVersion>9.0</LangVersion>
<Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>


+ 2
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -156,6 +156,8 @@ namespace Tensorflow
Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name), string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name), string[] str => constant_op.constant(str, dtype: tf.@string, name: name),


+ 1
- 1
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -142,7 +142,7 @@ namespace Tensorflow.Keras
{ {
if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL) if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
x = math_ops.cast(x, TF_DataType.TF_FLOAT); x = math_ops.cast(x, TF_DataType.TF_FLOAT);
return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false);
return math_ops.reduce_mean(x, axis: axis, keepdims: false);
} }


public GraphLearningPhase learning_phase() public GraphLearningPhase learning_phase()


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs View File

@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {
if (data_format == "channels_last") if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, new int[] { 1 }, false);
return math_ops.reduce_mean(inputs, 1, false);
else else
return math_ops.reduce_mean(inputs, new int[] { 2 }, false);
return math_ops.reduce_mean(inputs, 2, false);
} }
} }
} }

+ 2
- 2
src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs View File

@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {
if (data_format == "channels_last") if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false);
return math_ops.reduce_mean(inputs, (1, 2), false);
else else
return math_ops.reduce_mean(inputs, new int[] { 2, 3 }, false);
return math_ops.reduce_mean(inputs, (2, 3), false);
} }
} }
} }

+ 2
- 2
src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs View File

@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {
if (data_format == "channels_last") if (data_format == "channels_last")
return math_ops.reduce_max(inputs, new int[] { 1 }, false);
return math_ops.reduce_max(inputs, 1, false);
else else
return math_ops.reduce_max(inputs, new int[] { 2 }, false);
return math_ops.reduce_max(inputs, 2, false);
} }
} }
} }

+ 2
- 2
src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs View File

@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {
if (data_format == "channels_last") if (data_format == "channels_last")
return math_ops.reduce_max(inputs, new int[] { 1, 2 }, false);
return math_ops.reduce_max(inputs, (1, 2), false);
else else
return math_ops.reduce_max(inputs, new int[] { 2, 3 }, false);
return math_ops.reduce_max(inputs, (2, 3), false);
} }
} }
} }

+ 2
- 1
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -3,7 +3,8 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>netstandard2.1</TargetFramework> <TargetFramework>netstandard2.1</TargetFramework>
<AssemblyName>Tensorflow.Keras</AssemblyName> <AssemblyName>Tensorflow.Keras</AssemblyName>
<LangVersion>8.0</LangVersion>
<LangVersion>9.0</LangVersion>
<Nullable>enable</Nullable>
<RootNamespace>Tensorflow.Keras</RootNamespace> <RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>
<Version>0.6.0</Version> <Version>0.6.0</Version>


+ 5
- 0
test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow;
using Tensorflow.NumPy; using Tensorflow.NumPy;


namespace TensorFlowNET.UnitTest.Numpy namespace TensorFlowNET.UnitTest.Numpy
@@ -21,6 +22,10 @@ namespace TensorFlowNET.UnitTest.Numpy


p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } });
Assert.AreEqual(p, 24.0); Assert.AreEqual(p, 24.0);

p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }, axis: 1);
Assert.AreEqual(p.shape, 2);
Assert.IsTrue(Equal(p.Data<double>(), new[] { 2.0, 12.0 }));
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj View File

@@ -11,7 +11,7 @@


<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>


<LangVersion>8.0</LangVersion>
<LangVersion>9.0</LangVersion>


<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>
</PropertyGroup> </PropertyGroup>


+ 1
- 49
test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs View File

@@ -6,6 +6,7 @@ using System;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using Tensorflow;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -108,43 +109,18 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint<ShapeAssertions>(this); return new AndConstraint<ShapeAssertions>(this);
} }


public AndConstraint<ShapeAssertions> BeSliced()
{
Subject.IsSliced.Should().BeTrue();
return new AndConstraint<ShapeAssertions>(this);
}

public AndConstraint<ShapeAssertions> BeScalar() public AndConstraint<ShapeAssertions> BeScalar()
{ {
Subject.IsScalar.Should().BeTrue(); Subject.IsScalar.Should().BeTrue();
return new AndConstraint<ShapeAssertions>(this); return new AndConstraint<ShapeAssertions>(this);
} }


public AndConstraint<ShapeAssertions> BeBroadcasted()
{
Subject.IsBroadcasted.Should().BeTrue();
return new AndConstraint<ShapeAssertions>(this);
}


public AndConstraint<ShapeAssertions> NotBeSliced()
{
Subject.IsSliced.Should().BeFalse();
return new AndConstraint<ShapeAssertions>(this);
}

public AndConstraint<ShapeAssertions> NotBeScalar() public AndConstraint<ShapeAssertions> NotBeScalar()
{ {
Subject.IsScalar.Should().BeFalse(); Subject.IsScalar.Should().BeFalse();
return new AndConstraint<ShapeAssertions>(this); return new AndConstraint<ShapeAssertions>(this);
} }


public AndConstraint<ShapeAssertions> NotBeBroadcasted()
{
Subject.IsBroadcasted.Should().BeFalse();
return new AndConstraint<ShapeAssertions>(this);
}

public AndConstraint<ShapeAssertions> BeNDim(int ndim) public AndConstraint<ShapeAssertions> BeNDim(int ndim)
{ {
Subject.dims.Length.Should().Be(ndim); Subject.dims.Length.Should().Be(ndim);
@@ -215,24 +191,6 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint<NDArrayAssertions>(this); return new AndConstraint<NDArrayAssertions>(this);
} }


public AndConstraint<NDArrayAssertions> BeBroadcasted()
{
Subject.shape.IsBroadcasted.Should().BeTrue();
return new AndConstraint<NDArrayAssertions>(this);
}

public AndConstraint<NDArrayAssertions> NotBeBroadcasted()
{
Subject.shape.IsBroadcasted.Should().BeFalse();
return new AndConstraint<NDArrayAssertions>(this);
}

public AndConstraint<NDArrayAssertions> BeSliced()
{
Subject.shape.IsSliced.Should().BeTrue();
return new AndConstraint<NDArrayAssertions>(this);
}

public AndConstraint<NDArrayAssertions> BeScalar() public AndConstraint<NDArrayAssertions> BeScalar()
{ {
Subject.shape.IsScalar.Should().BeTrue(); Subject.shape.IsScalar.Should().BeTrue();
@@ -264,12 +222,6 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint<NDArrayAssertions>(this); return new AndConstraint<NDArrayAssertions>(this);
} }


public AndConstraint<NDArrayAssertions> NotBeSliced()
{
Subject.shape.IsSliced.Should().BeFalse();
return new AndConstraint<NDArrayAssertions>(this);
}

public AndConstraint<NDArrayAssertions> NotBeScalar() public AndConstraint<NDArrayAssertions> NotBeScalar()
{ {
Subject.shape.IsScalar.Should().BeFalse(); Subject.shape.IsScalar.Should().BeFalse();


Loading…
Cancel
Save