diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index f438f870..83f63c86 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -467,12 +467,9 @@ namespace Tensorflow /// If true, retains reduced dimensions with length 1. /// /// The reduced tensor. - public Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name); - public Tensor reduce_any(Tensor input_tensor, int axis = 0, bool keepdims = false, string name = null) - => math_ops.reduce_any(input_tensor, axis: new[] { axis }, keepdims: keepdims, name: name); - /// /// Computes the "logical and" of elements across dimensions of a tensor. /// @@ -481,7 +478,7 @@ namespace Tensorflow /// /// /// The reduced tensor. - public Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name); /// @@ -492,7 +489,7 @@ namespace Tensorflow /// /// /// - public Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name); /// @@ -537,19 +534,16 @@ namespace Tensorflow /// /// /// - public Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) - => math_ops.reduce_max(input_tensor, axis, keepdims, name); - - public Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, string name = null) + public Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_max(input_tensor, axis, keepdims, name); - public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_min(input_tensor, axis, keepdims, name); - public Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_std(input_tensor, axis, keepdims, name); - public Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public Tensor reduce_variance(Tensor input_tensor, Axis? 
axis = null, bool keepdims = false, string name = null) => math_ops.reduce_variance(input_tensor, axis, keepdims, name); public Tensor sigmoid(T x, string name = null) @@ -558,15 +552,9 @@ namespace Tensorflow public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); - public Tensor reduce_mean(Tensor input_tensors, int axis, bool keepdims = false, string name = null) - => math_ops.reduce_mean(input_tensors, axis: new[] { axis }, keepdims: keepdims, name: name); - - public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) + public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); - public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) - => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name); - public Tensor round(Tensor x, string name = null) => gen_math_ops.round(x, name: name); diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index cd8ee879..6b144534 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -89,7 +89,7 @@ namespace Tensorflow => gen_nn_ops.elu(features, name: name); public (Tensor, Tensor) moments(Tensor x, - int[] axes, + Axis axes, string name = null, bool keep_dims = false) => nn_impl.moments(x, axes, diff --git a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs index 325f0633..41f0ec45 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs @@ -19,7 +19,7 @@ namespace Tensorflow public partial class tensorflow { public Tensor reduce_logsumexp(Tensor input_tensor, - int[] axis = null, + Axis? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name); diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs new file mode 100644 index 00000000..97629062 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public record Axis(params int[] axis) + { + public int this[int index] => axis[index]; + + public static implicit operator int[]?(Axis axis) + => axis?.axis; + + public static implicit operator Axis(int axis) + => new Axis(axis); + + public static implicit operator Axis((int, int) axis) + => new Axis(axis); + + public static implicit operator Axis((int, int, int) axis) + => new Axis(axis); + + public static implicit operator Axis(int[] axis) + => new Axis(axis); + } +} + +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs index cfbd2260..89c871a3 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs @@ -12,8 +12,8 @@ namespace Tensorflow.NumPy public static NDArray log(NDArray x) => throw new NotImplementedException(""); - public static NDArray prod(NDArray array, int? 
axis = null, Type dtype = null, bool keepdims = false) - => tf.reduce_prod(ops.convert_to_tensor(array)); + public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false) + => tf.reduce_prod(ops.convert_to_tensor(array), axis: axis); public static NDArray prod(params T[] array) where T : unmanaged => tf.reduce_prod(ops.convert_to_tensor(array)); diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index 7142201c..6c6d68e4 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; -namespace Tensorflow.NumPy +namespace Tensorflow { public class Shape { @@ -11,6 +11,13 @@ namespace Tensorflow.NumPy long[] _dims; public long[] dims => _dims; + public Shape() + { + } + + public Shape(params int[] dims) + => _dims = dims.Select(x => Convert.ToInt64(x)).ToArray(); + public Shape(params long[] dims) => _dims = dims; @@ -21,14 +28,27 @@ namespace Tensorflow.NumPy => new Shape(dims); public static implicit operator Shape(int[] dims) - => new Shape(dims.Select(x => Convert.ToInt64(x)).ToArray()); + => new Shape(dims); + + public static implicit operator Shape((int, int) dims) + => new Shape(dims.Item1, dims.Item2); public static implicit operator Shape((long, long) dims) => new Shape(dims.Item1, dims.Item2); - public bool IsSliced => throw new NotImplementedException(""); - public bool IsScalar => throw new NotImplementedException(""); - public bool IsBroadcasted => throw new NotImplementedException(""); + public static implicit operator Shape((int, int, int) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3); + + public static implicit operator Shape((long, long, long) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3); + + public static implicit operator Shape((int, int, int, int) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static implicit operator Shape((long, long, long, long) dims) + => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public bool IsScalar => ndim == 0; public static Shape Scalar => new Shape(new long[0]); @@ -55,6 +75,18 @@ namespace Tensorflow.NumPy public bool IsEmpty => throw new NotImplementedException(""); + public override bool Equals(object obj) + { + if(obj is Shape shape) + { + if (shape.ndim != ndim) + return false; + if (Enumerable.SequenceEqual(dims, shape.dims)) + return true; + } + return base.Equals(obj); + } + public override string ToString() { return "(" + string.Join(", ", _dims) + ")"; diff --git a/src/TensorFlowNET.Core/Numpy/Slice.cs b/src/TensorFlowNET.Core/Numpy/Slice.cs index 64d16e47..2bb73fe8 100644 --- a/src/TensorFlowNET.Core/Numpy/Slice.cs +++ b/src/TensorFlowNET.Core/Numpy/Slice.cs @@ -4,7 +4,7 @@ using System.Linq; using System.Text; using System.Text.RegularExpressions; -namespace Tensorflow.NumPy +namespace Tensorflow { ///

/// <summary>
/// NDArray can be indexed using slicing
/// </summary>
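
Illustrative usage (not part of this diff): the implicit operators declared on the new Axis record let call sites pass a single int, a tuple, or an int[] wherever an Axis? parameter is expected. A minimal C# 9 sketch, assuming the Tensorflow binding is referenced and with made-up tensor values:

    using Tensorflow;
    using static Tensorflow.Binding;

    var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });

    // Single axis: int converts to Axis through the implicit operator.
    var rowMeans = tf.reduce_mean(x, axis: 1);

    // Multiple axes: a tuple or an int[] converts the same way.
    var total = tf.reduce_mean(x, axis: (0, 1));
    var colMax = tf.reduce_max(x, axis: new[] { 0 });

    // Omitting the axis still reduces over every dimension.
    var grandMax = tf.reduce_max(x);

This is the same pattern the GlobalAveragePooling and GlobalMaxPooling changes below rely on when they replace new int[] { 1, 2 } with 1 or (1, 2).
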
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index aaa9e1ee..60b9b25c 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -968,9 +968,9 @@ new_height, new_width"); var num_pixels_ = array_ops.shape(image).dims; num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray(); Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_)); - Tensor image_mean = math_ops.reduce_mean(image, axis: new int[] { -1, -2, -3 }, keepdims: true); + Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true); - var stddev = math_ops.reduce_std(image, axis: new int[] { -1, -2, -3 }, keepdims: true); + var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true); var min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype)); var adjusted_stddev = math_ops.maximum(stddev, min_stddev); @@ -1408,7 +1408,7 @@ new_height, new_width"); max_val = convert_image_dtype(max_val, dtypes.float32); a = convert_image_dtype(a, dtypes.float32); b = convert_image_dtype(b, dtypes.float32); - Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new int[] { -3, -2, -1 }); + Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new(-3, -2, -1)); var psnr_val = math_ops.subtract( (20 * math_ops.log(max_val)) / math_ops.log(ops.convert_to_tensor(10.0)), math_ops.cast(10 / math_ops.log(ops.convert_to_tensor(10)), dtypes.float32) * math_ops.log(mse), @@ -1503,8 +1503,8 @@ new_height, new_width"); (Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2); var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32); - var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims); - cs = math_ops.reduce_mean(cs, axes.dims); + var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims)); + cs = math_ops.reduce_mean(cs, new(axes.dims)); return (ssim_val, cs); } @@ -1527,7 +1527,7 @@ new_height, new_width"); (Tensor ssim_per_channel, Tensor ___) = _ssim_per_channel(img1, img2, max_val, filter_size, filter_sigma, k1, k2); - return math_ops.reduce_mean(ssim_per_channel, new int[] { -1 }); + return math_ops.reduce_mean(ssim_per_channel, new(-1)); }); } @@ -1645,9 +1645,9 @@ new_height, new_width"); var mcs_and_ssim = array_ops.stack( math_ops.add(mcs, new[] { gen_nn_ops.relu(ssim_per_channel) }), axis: -1); var ms_ssim = math_ops.reduce_prod( - math_ops.pow(mcs_and_ssim, power_factors), new int[] { -1 }); + math_ops.pow(mcs_and_ssim, power_factors), new(-1)); - return math_ops.reduce_mean(ms_ssim, new int[] { -1 }); + return math_ops.reduce_mean(ms_ssim, new(-1)); }); } @@ -1830,7 +1830,7 @@ new_height, new_width"); new object[] { batch_size, tile_size, 4 }); var iou = _bbox_overlap(new_slice, box_slice); var box_slice_after_suppression = array_ops.expand_dims( - math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new int[] { 1 }), + math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new(1)), box_slice.dtype), 2) * box_slice; return (boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1); @@ -1913,7 +1913,7 @@ new_height, new_width"); output_size = output_size + math_ops.reduce_sum( math_ops.cast( - math_ops.reduce_any(box_slice > 0, new int[] { 2 }), dtypes.int32), new int[] { 1 }); + math_ops.reduce_any(box_slice > 0, new(2)), dtypes.int32), new int[] { 1 }); } return 
(boxes, iou_threshold, output_size, idx + 1); } @@ -2074,7 +2074,7 @@ new_height, new_width"); (Tensor values, Tensor indices) = gen_ops.top_k_v2( math_ops.cast(math_ops.reduce_any( - (Tensor)selboxes__output_size_[0] > 0, new int[] { 2 }), dtypes.int32) * + (Tensor)selboxes__output_size_[0] > 0, new(2)), dtypes.int32) * array_ops.expand_dims( math_ops.range(num_boxes_after_padding, 0, -1), 0), max_output_size); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 0d5d23f6..7db11573 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -305,7 +305,7 @@ namespace Tensorflow /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. /// If true, retains reduced dimensions with length 1. /// A name for the operation (optional). - public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) + public static Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) { var r = _ReductionDims(input_tensor, axis); var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); @@ -313,14 +313,6 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis_tensor, m); } - public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) - { - var r = _ReductionDims(input_tensors, axis); - var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value); - var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - /// /// Computes the product of elements across dimensions of a tensor. /// @@ -329,7 +321,7 @@ namespace Tensorflow /// /// /// - public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); if (axis == null) @@ -344,7 +336,7 @@ namespace Tensorflow } } - public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { if (name == null) name = "reduce_std"; @@ -357,7 +349,7 @@ namespace Tensorflow }); } - public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { if (name == null) name = "reduce_variance"; @@ -513,7 +505,7 @@ namespace Tensorflow /// /// /// - public static Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { var all = gen_math_ops._all(input_tensor, _ReductionDims(input_tensor, axis), @@ -545,7 +537,7 @@ namespace Tensorflow /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. /// /// The reduced tensor. - public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_logsumexp(Tensor input_tensor, Axis? 
axis = null, bool keepdims = false, string name = null) { return tf_with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => { @@ -565,7 +557,7 @@ namespace Tensorflow }); } - public static Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) : @@ -573,7 +565,7 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, max); } - public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) : @@ -588,7 +580,7 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, max); } - public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); var min = gen_math_ops._min(input_tensor, r, keepdims, name); @@ -711,7 +703,7 @@ namespace Tensorflow return range(0, array_ops.rank(x)); } - private static Tensor _ReductionDims(Tensor x, int[] axis) + private static Tensor _ReductionDims(Tensor x, Axis? axis) { if (axis != null) { diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index c4027924..5704d881 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -79,7 +79,7 @@ namespace Tensorflow /// Produce moments with the same dimensionality as the input. /// Two `Tensor` objects: `mean` and `variance`. public static (Tensor, Tensor) moments(Tensor x, - int[] axes, + Axis axes, string name = null, bool keep_dims = false) { diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 634d8194..1fee17f3 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -7,6 +7,7 @@ 2.2.0 0.60.0 9.0 + enable Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index e5f008cd..6fc79028 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -156,6 +156,8 @@ namespace Tensorflow Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? 
"packed" : name), RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), + Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), + Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), string str => constant_op.constant(str, dtype: tf.@string, name: name), string[] str => constant_op.constant(str, dtype: tf.@string, name: name), diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 7ce07809..bedca2c1 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -142,7 +142,7 @@ namespace Tensorflow.Keras { if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL) x = math_ops.cast(x, TF_DataType.TF_FLOAT); - return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false); + return math_ops.reduce_mean(x, axis: axis, keepdims: false); } public GraphLearningPhase learning_phase() diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs index d2442bec..d62fb63a 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs @@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (data_format == "channels_last") - return math_ops.reduce_mean(inputs, new int[] { 1 }, false); + return math_ops.reduce_mean(inputs, 1, false); else - return math_ops.reduce_mean(inputs, new int[] { 2 }, false); + return math_ops.reduce_mean(inputs, 2, false); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs index b35d7832..000e4b8b 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (data_format == "channels_last") - return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false); + return math_ops.reduce_mean(inputs, (1, 2), false); else - return math_ops.reduce_mean(inputs, new int[] { 2, 3 }, false); + return math_ops.reduce_mean(inputs, (2, 3), false); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs index c0d0d831..2de4671c 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs @@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) { if (data_format == "channels_last") - return math_ops.reduce_max(inputs, new int[] { 1 }, false); + return math_ops.reduce_max(inputs, 1, false); else - return math_ops.reduce_max(inputs, new int[] { 2 }, false); + return math_ops.reduce_max(inputs, 2, false); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs index 6ab6b501..b7e2c945 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs @@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (data_format == "channels_last") - return math_ops.reduce_max(inputs, new int[] { 1, 2 }, false); + return math_ops.reduce_max(inputs, (1, 2), false); else - return math_ops.reduce_max(inputs, new int[] { 2, 3 }, false); + return math_ops.reduce_max(inputs, (2, 3), false); } } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index d0dd51fe..79dc542b 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -3,7 +3,8 @@ netstandard2.1 Tensorflow.Keras - 8.0 + 9.0 + enable Tensorflow.Keras AnyCPU;x64 0.6.0 diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs index 40776e9b..ea4930fb 100644 --- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs +++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow; using Tensorflow.NumPy; namespace TensorFlowNET.UnitTest.Numpy @@ -21,6 +22,10 @@ namespace TensorFlowNET.UnitTest.Numpy p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); Assert.AreEqual(p, 24.0); + + p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }, axis: 1); + Assert.AreEqual(p.shape, 2); + Assert.IsTrue(Equal(p.Data(), new[] { 2.0, 12.0 })); } } } diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 0ac43614..2aef756a 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -11,7 +11,7 @@ Open.snk - 8.0 + 9.0 AnyCPU;x64 diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs index 154ab714..39e72880 100644 --- a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs +++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs @@ -6,6 +6,7 @@ using System; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; +using Tensorflow; namespace TensorFlowNET.UnitTest { @@ -108,43 +109,18 @@ namespace TensorFlowNET.UnitTest return new AndConstraint(this); } - public AndConstraint BeSliced() - { - Subject.IsSliced.Should().BeTrue(); - return new AndConstraint(this); - } - public AndConstraint BeScalar() { Subject.IsScalar.Should().BeTrue(); return new AndConstraint(this); } - public AndConstraint BeBroadcasted() - { - Subject.IsBroadcasted.Should().BeTrue(); - return new AndConstraint(this); - } - - - public AndConstraint NotBeSliced() - { - Subject.IsSliced.Should().BeFalse(); - return new AndConstraint(this); - } - public AndConstraint NotBeScalar() { Subject.IsScalar.Should().BeFalse(); return new 
AndConstraint(this); } - public AndConstraint NotBeBroadcasted() - { - Subject.IsBroadcasted.Should().BeFalse(); - return new AndConstraint(this); - } - public AndConstraint BeNDim(int ndim) { Subject.dims.Length.Should().Be(ndim); @@ -215,24 +191,6 @@ namespace TensorFlowNET.UnitTest return new AndConstraint(this); } - public AndConstraint BeBroadcasted() - { - Subject.shape.IsBroadcasted.Should().BeTrue(); - return new AndConstraint(this); - } - - public AndConstraint NotBeBroadcasted() - { - Subject.shape.IsBroadcasted.Should().BeFalse(); - return new AndConstraint(this); - } - - public AndConstraint BeSliced() - { - Subject.shape.IsSliced.Should().BeTrue(); - return new AndConstraint(this); - } - public AndConstraint BeScalar() { Subject.shape.IsScalar.Should().BeTrue(); @@ -264,12 +222,6 @@ namespace TensorFlowNET.UnitTest return new AndConstraint(this); } - public AndConstraint NotBeSliced() - { - Subject.shape.IsSliced.Should().BeFalse(); - return new AndConstraint(this); - } - public AndConstraint NotBeScalar() { Subject.shape.IsScalar.Should().BeFalse();
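
For reference, a hedged sketch (not part of this diff) of how the new Shape conversions, the value-based Equals, and the ops.convert_to_tensor cases added above behave; the concrete dims are made up for illustration:

    using System;
    using Tensorflow;

    Shape a = (2, 3);                 // (int, int) tuple -> Shape
    Shape b = new[] { 2, 3 };         // int[] -> Shape, widened to long[] by the new ctor
    Console.WriteLine(a.Equals(b));   // True: same ndim, dims compared element-wise
    Console.WriteLine(Shape.Scalar.IsScalar);   // True: IsScalar is now ndim == 0

    // Axis and Shape both hit the new cases in ops.convert_to_tensor,
    // which lower them to constant tensors via constant_op.constant.
    var axisTensor  = ops.convert_to_tensor(new Axis(1, 2));
    var shapeTensor = ops.convert_to_tensor(new Shape(2, 3));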