diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index f438f870..83f63c86 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -467,12 +467,9 @@ namespace Tensorflow
/// If true, retains reduced dimensions with length 1.
///
/// The reduced tensor.
- public Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name);
- public Tensor reduce_any(Tensor input_tensor, int axis = 0, bool keepdims = false, string name = null)
- => math_ops.reduce_any(input_tensor, axis: new[] { axis }, keepdims: keepdims, name: name);
-
///
/// Computes the "logical and" of elements across dimensions of a tensor.
///
@@ -481,7 +478,7 @@ namespace Tensorflow
///
///
/// The reduced tensor.
- public Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name);
///
@@ -492,7 +489,7 @@ namespace Tensorflow
///
///
///
- public Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name);
///
@@ -537,19 +534,16 @@ namespace Tensorflow
///
///
///
- public Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
- => math_ops.reduce_max(input_tensor, axis, keepdims, name);
-
- public Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, string name = null)
+ public Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_max(input_tensor, axis, keepdims, name);
- public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_min(input_tensor, axis, keepdims, name);
- public Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_std(input_tensor, axis, keepdims, name);
- public Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_variance(input_tensor, axis, keepdims, name);
public Tensor sigmoid(T x, string name = null)
@@ -558,15 +552,9 @@ namespace Tensorflow
public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
=> gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name);
- public Tensor reduce_mean(Tensor input_tensors, int axis, bool keepdims = false, string name = null)
- => math_ops.reduce_mean(input_tensors, axis: new[] { axis }, keepdims: keepdims, name: name);
-
- public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
+ public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
- public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
- => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);
-
public Tensor round(Tensor x, string name = null)
=> gen_math_ops.round(x, name: name);
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index cd8ee879..6b144534 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -89,7 +89,7 @@ namespace Tensorflow
=> gen_nn_ops.elu(features, name: name);
public (Tensor, Tensor) moments(Tensor x,
- int[] axes,
+ Axis axes,
string name = null,
bool keep_dims = false) => nn_impl.moments(x,
axes,
diff --git a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
index 325f0633..41f0ec45 100644
--- a/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.reduce_logsumexp.cs
@@ -19,7 +19,7 @@ namespace Tensorflow
public partial class tensorflow
{
public Tensor reduce_logsumexp(Tensor input_tensor,
- int[] axis = null,
+ Axis? axis = null,
bool keepdims = false,
string name = null) => math_ops.reduce_logsumexp(input_tensor, axis, keepdims, name);
diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs
new file mode 100644
index 00000000..97629062
--- /dev/null
+++ b/src/TensorFlowNET.Core/NumPy/Axis.cs
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public record Axis(params int[] axis)
+ {
+ public int this[int index] => axis[index];
+
+ public static implicit operator int[]?(Axis axis)
+ => axis?.axis;
+
+ public static implicit operator Axis(int axis)
+ => new Axis(axis);
+
+ public static implicit operator Axis((int, int) axis)
+ => new Axis(axis);
+
+ public static implicit operator Axis((int, int, int) axis)
+ => new Axis(axis);
+
+ public static implicit operator Axis(int[] axis)
+ => new Axis(axis);
+ }
+}
+
+namespace System.Runtime.CompilerServices
+{
+ internal static class IsExternalInit { }
+}
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
index cfbd2260..89c871a3 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
@@ -12,8 +12,8 @@ namespace Tensorflow.NumPy
public static NDArray log(NDArray x)
=> throw new NotImplementedException("");
- public static NDArray prod(NDArray array, int? axis = null, Type dtype = null, bool keepdims = false)
- => tf.reduce_prod(ops.convert_to_tensor(array));
+ public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false)
+ => tf.reduce_prod(ops.convert_to_tensor(array), axis: axis);
public static NDArray prod(params T[] array) where T : unmanaged
=> tf.reduce_prod(ops.convert_to_tensor(array));
diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs
index 7142201c..6c6d68e4 100644
--- a/src/TensorFlowNET.Core/Numpy/Shape.cs
+++ b/src/TensorFlowNET.Core/Numpy/Shape.cs
@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
-namespace Tensorflow.NumPy
+namespace Tensorflow
{
public class Shape
{
@@ -11,6 +11,13 @@ namespace Tensorflow.NumPy
long[] _dims;
public long[] dims => _dims;
+ public Shape()
+ {
+ }
+
+ public Shape(params int[] dims)
+ => _dims = dims.Select(x => Convert.ToInt64(x)).ToArray();
+
public Shape(params long[] dims)
=> _dims = dims;
@@ -21,14 +28,27 @@ namespace Tensorflow.NumPy
=> new Shape(dims);
public static implicit operator Shape(int[] dims)
- => new Shape(dims.Select(x => Convert.ToInt64(x)).ToArray());
+ => new Shape(dims);
+
+ public static implicit operator Shape((int, int) dims)
+ => new Shape(dims.Item1, dims.Item2);
public static implicit operator Shape((long, long) dims)
=> new Shape(dims.Item1, dims.Item2);
- public bool IsSliced => throw new NotImplementedException("");
- public bool IsScalar => throw new NotImplementedException("");
- public bool IsBroadcasted => throw new NotImplementedException("");
+ public static implicit operator Shape((int, int, int) dims)
+ => new Shape(dims.Item1, dims.Item2, dims.Item3);
+
+ public static implicit operator Shape((long, long, long) dims)
+ => new Shape(dims.Item1, dims.Item2, dims.Item3);
+
+ public static implicit operator Shape((int, int, int, int) dims)
+ => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+
+ public static implicit operator Shape((long, long, long, long) dims)
+ => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+
+ public bool IsScalar => ndim == 0;
public static Shape Scalar
=> new Shape(new long[0]);
@@ -55,6 +75,18 @@ namespace Tensorflow.NumPy
public bool IsEmpty => throw new NotImplementedException("");
+ public override bool Equals(object obj)
+ {
+ if(obj is Shape shape)
+ {
+ if (shape.ndim != ndim)
+ return false;
+ if (Enumerable.SequenceEqual(dims, shape.dims))
+ return true;
+ }
+ return base.Equals(obj);
+ }
+
public override string ToString()
{
return "(" + string.Join(", ", _dims) + ")";
diff --git a/src/TensorFlowNET.Core/Numpy/Slice.cs b/src/TensorFlowNET.Core/Numpy/Slice.cs
index 64d16e47..2bb73fe8 100644
--- a/src/TensorFlowNET.Core/Numpy/Slice.cs
+++ b/src/TensorFlowNET.Core/Numpy/Slice.cs
@@ -4,7 +4,7 @@ using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
-namespace Tensorflow.NumPy
+namespace Tensorflow
{
///
/// NDArray can be indexed using slicing
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index aaa9e1ee..60b9b25c 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -968,9 +968,9 @@ new_height, new_width");
var num_pixels_ = array_ops.shape(image).dims;
num_pixels_ = num_pixels_.Skip(num_pixels_.Length - 3).Take(num_pixels_.Length - (num_pixels_.Length - 3)).ToArray();
Tensor num_pixels = math_ops.reduce_prod(new Tensor(num_pixels_));
- Tensor image_mean = math_ops.reduce_mean(image, axis: new int[] { -1, -2, -3 }, keepdims: true);
+ Tensor image_mean = math_ops.reduce_mean(image, axis: new(-1, -2, -3), keepdims: true);
- var stddev = math_ops.reduce_std(image, axis: new int[] { -1, -2, -3 }, keepdims: true);
+ var stddev = math_ops.reduce_std(image, axis: new(-1, -2, -3), keepdims: true);
var min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, image.dtype));
var adjusted_stddev = math_ops.maximum(stddev, min_stddev);
@@ -1408,7 +1408,7 @@ new_height, new_width");
max_val = convert_image_dtype(max_val, dtypes.float32);
a = convert_image_dtype(a, dtypes.float32);
b = convert_image_dtype(b, dtypes.float32);
- Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new int[] { -3, -2, -1 });
+ Tensor mse = math_ops.reduce_mean(gen_math_ops.squared_difference(a, b), new(-3, -2, -1));
var psnr_val = math_ops.subtract(
(20 * math_ops.log(max_val)) / math_ops.log(ops.convert_to_tensor(10.0)),
math_ops.cast(10 / math_ops.log(ops.convert_to_tensor(10)), dtypes.float32) * math_ops.log(mse),
@@ -1503,8 +1503,8 @@ new_height, new_width");
(Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2);
var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32);
- var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims);
- cs = math_ops.reduce_mean(cs, axes.dims);
+ var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims));
+ cs = math_ops.reduce_mean(cs, new(axes.dims));
return (ssim_val, cs);
}
@@ -1527,7 +1527,7 @@ new_height, new_width");
(Tensor ssim_per_channel, Tensor ___) = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2);
- return math_ops.reduce_mean(ssim_per_channel, new int[] { -1 });
+ return math_ops.reduce_mean(ssim_per_channel, new(-1));
});
}
@@ -1645,9 +1645,9 @@ new_height, new_width");
var mcs_and_ssim = array_ops.stack(
math_ops.add(mcs, new[] { gen_nn_ops.relu(ssim_per_channel) }), axis: -1);
var ms_ssim = math_ops.reduce_prod(
- math_ops.pow(mcs_and_ssim, power_factors), new int[] { -1 });
+ math_ops.pow(mcs_and_ssim, power_factors), new(-1));
- return math_ops.reduce_mean(ms_ssim, new int[] { -1 });
+ return math_ops.reduce_mean(ms_ssim, new(-1));
});
}
@@ -1830,7 +1830,7 @@ new_height, new_width");
new object[] { batch_size, tile_size, 4 });
var iou = _bbox_overlap(new_slice, box_slice);
var box_slice_after_suppression = array_ops.expand_dims(
- math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new int[] { 1 }),
+ math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new(1)),
box_slice.dtype),
2) * box_slice;
return (boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1);
@@ -1913,7 +1913,7 @@ new_height, new_width");
output_size = output_size + math_ops.reduce_sum(
math_ops.cast(
- math_ops.reduce_any(box_slice > 0, new int[] { 2 }), dtypes.int32), new int[] { 1 });
+ math_ops.reduce_any(box_slice > 0, new(2)), dtypes.int32), new int[] { 1 });
}
return (boxes, iou_threshold, output_size, idx + 1);
}
@@ -2074,7 +2074,7 @@ new_height, new_width");
(Tensor values, Tensor indices) = gen_ops.top_k_v2(
math_ops.cast(math_ops.reduce_any(
- (Tensor)selboxes__output_size_[0] > 0, new int[] { 2 }), dtypes.int32) *
+ (Tensor)selboxes__output_size_[0] > 0, new(2)), dtypes.int32) *
array_ops.expand_dims(
math_ops.range(num_boxes_after_padding, 0, -1), 0),
max_output_size);
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 0d5d23f6..7db11573 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -305,7 +305,7 @@ namespace Tensorflow
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.
/// If true, retains reduced dimensions with length 1.
/// A name for the operation (optional).
- public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
+ public static Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
{
var r = _ReductionDims(input_tensor, axis);
var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis);
@@ -313,14 +313,6 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis_tensor, m);
}
- public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
- {
- var r = _ReductionDims(input_tensors, axis);
- var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value);
- var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name);
- return _may_reduce_to_scalar(keepdims, axis, m);
- }
-
///
/// Computes the product of elements across dimensions of a tensor.
///
@@ -329,7 +321,7 @@ namespace Tensorflow
///
///
///
- public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensor, axis);
if (axis == null)
@@ -344,7 +336,7 @@ namespace Tensorflow
}
}
- public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_std(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
if (name == null)
name = "reduce_std";
@@ -357,7 +349,7 @@ namespace Tensorflow
});
}
- public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_variance(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
if (name == null)
name = "reduce_variance";
@@ -513,7 +505,7 @@ namespace Tensorflow
///
///
///
- public static Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
var all = gen_math_ops._all(input_tensor,
_ReductionDims(input_tensor, axis),
@@ -545,7 +537,7 @@ namespace Tensorflow
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.
///
/// The reduced tensor.
- public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_logsumexp(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
return tf_with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope =>
{
@@ -565,7 +557,7 @@ namespace Tensorflow
});
}
- public static Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_any(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensor, axis);
var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) :
@@ -573,7 +565,7 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, max);
}
- public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_max(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensor, axis);
var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) :
@@ -588,7 +580,7 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, max);
}
- public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ public static Tensor reduce_min(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
{
var r = _ReductionDims(input_tensor, axis);
var min = gen_math_ops._min(input_tensor, r, keepdims, name);
@@ -711,7 +703,7 @@ namespace Tensorflow
return range(0, array_ops.rank(x));
}
- private static Tensor _ReductionDims(Tensor x, int[] axis)
+ private static Tensor _ReductionDims(Tensor x, Axis? axis)
{
if (axis != null)
{
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index c4027924..5704d881 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -79,7 +79,7 @@ namespace Tensorflow
/// Produce moments with the same dimensionality as the input.
/// Two `Tensor` objects: `mean` and `variance`.
public static (Tensor, Tensor) moments(Tensor x,
- int[] axes,
+ Axis axes,
string name = null,
bool keep_dims = false)
{
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 634d8194..1fee17f3 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -7,6 +7,7 @@
2.2.0
0.60.0
9.0
+ enable
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
true
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index e5f008cd..6fc79028 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -156,6 +156,8 @@ namespace Tensorflow
Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
+ Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
+ Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 7ce07809..bedca2c1 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -142,7 +142,7 @@ namespace Tensorflow.Keras
{
if (x.dtype.as_base_dtype() == TF_DataType.TF_BOOL)
x = math_ops.cast(x, TF_DataType.TF_FLOAT);
- return math_ops.reduce_mean(x, axis: new[] { axis }, keepdims: false);
+ return math_ops.reduce_mean(x, axis: axis, keepdims: false);
}
public GraphLearningPhase learning_phase()
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
index d2442bec..d62fb63a 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (data_format == "channels_last")
- return math_ops.reduce_mean(inputs, new int[] { 1 }, false);
+ return math_ops.reduce_mean(inputs, 1, false);
else
- return math_ops.reduce_mean(inputs, new int[] { 2 }, false);
+ return math_ops.reduce_mean(inputs, 2, false);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
index b35d7832..000e4b8b 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (data_format == "channels_last")
- return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false);
+ return math_ops.reduce_mean(inputs, (1, 2), false);
else
- return math_ops.reduce_mean(inputs, new int[] { 2, 3 }, false);
+ return math_ops.reduce_mean(inputs, (2, 3), false);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
index c0d0d831..2de4671c 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (data_format == "channels_last")
- return math_ops.reduce_max(inputs, new int[] { 1 }, false);
+ return math_ops.reduce_max(inputs, 1, false);
else
- return math_ops.reduce_max(inputs, new int[] { 2 }, false);
+ return math_ops.reduce_max(inputs, 2, false);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
index 6ab6b501..b7e2c945 100644
--- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
@@ -15,9 +15,9 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (data_format == "channels_last")
- return math_ops.reduce_max(inputs, new int[] { 1, 2 }, false);
+ return math_ops.reduce_max(inputs, (1, 2), false);
else
- return math_ops.reduce_max(inputs, new int[] { 2, 3 }, false);
+ return math_ops.reduce_max(inputs, (2, 3), false);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index d0dd51fe..79dc542b 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -3,7 +3,8 @@
netstandard2.1
Tensorflow.Keras
- 8.0
+ 9.0
+ enable
Tensorflow.Keras
AnyCPU;x64
0.6.0
diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
index 40776e9b..ea4930fb 100644
--- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow;
using Tensorflow.NumPy;
namespace TensorFlowNET.UnitTest.Numpy
@@ -21,6 +22,10 @@ namespace TensorFlowNET.UnitTest.Numpy
p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } });
Assert.AreEqual(p, 24.0);
+
+ p = np.prod(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }, axis: 1);
+ Assert.AreEqual(p.shape, 2);
+ Assert.IsTrue(Equal(p.Data(), new[] { 2.0, 12.0 }));
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
index 0ac43614..2aef756a 100644
--- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
@@ -11,7 +11,7 @@
Open.snk
- 8.0
+ 9.0
AnyCPU;x64
diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs
index 154ab714..39e72880 100644
--- a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs
+++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs
@@ -6,6 +6,7 @@ using System;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
+using Tensorflow;
namespace TensorFlowNET.UnitTest
{
@@ -108,43 +109,18 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint(this);
}
- public AndConstraint BeSliced()
- {
- Subject.IsSliced.Should().BeTrue();
- return new AndConstraint(this);
- }
-
public AndConstraint BeScalar()
{
Subject.IsScalar.Should().BeTrue();
return new AndConstraint(this);
}
- public AndConstraint BeBroadcasted()
- {
- Subject.IsBroadcasted.Should().BeTrue();
- return new AndConstraint(this);
- }
-
-
- public AndConstraint NotBeSliced()
- {
- Subject.IsSliced.Should().BeFalse();
- return new AndConstraint(this);
- }
-
public AndConstraint NotBeScalar()
{
Subject.IsScalar.Should().BeFalse();
return new AndConstraint(this);
}
- public AndConstraint NotBeBroadcasted()
- {
- Subject.IsBroadcasted.Should().BeFalse();
- return new AndConstraint(this);
- }
-
public AndConstraint BeNDim(int ndim)
{
Subject.dims.Length.Should().Be(ndim);
@@ -215,24 +191,6 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint(this);
}
- public AndConstraint BeBroadcasted()
- {
- Subject.shape.IsBroadcasted.Should().BeTrue();
- return new AndConstraint(this);
- }
-
- public AndConstraint NotBeBroadcasted()
- {
- Subject.shape.IsBroadcasted.Should().BeFalse();
- return new AndConstraint(this);
- }
-
- public AndConstraint BeSliced()
- {
- Subject.shape.IsSliced.Should().BeTrue();
- return new AndConstraint(this);
- }
-
public AndConstraint BeScalar()
{
Subject.shape.IsScalar.Should().BeTrue();
@@ -264,12 +222,6 @@ namespace TensorFlowNET.UnitTest
return new AndConstraint(this);
}
- public AndConstraint NotBeSliced()
- {
- Subject.shape.IsSliced.Should().BeFalse();
- return new AndConstraint(this);
- }
-
public AndConstraint NotBeScalar()
{
Subject.shape.IsScalar.Should().BeFalse();