From 8784c31cb34753bac98ff62a2a3025ed5023d835 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 11 Jul 2021 23:25:39 -0500 Subject: [PATCH] change tensor shape to Shape. --- src/TensorFlowNET.Core/Binding.Util.cs | 2 +- .../Framework/tensor_shape.cs | 4 +-- .../Gradients/image_grad.cs | 4 +-- src/TensorFlowNET.Core/Gradients/math_grad.cs | 4 +-- src/TensorFlowNET.Core/NumPy/Axis.cs | 7 +++++ src/TensorFlowNET.Core/Numpy/NDArray.cs | 5 ++-- src/TensorFlowNET.Core/Numpy/Shape.cs | 27 +++++++++++++++++++ .../Operations/Distributions/normal.py.cs | 2 +- .../Operations/NnOps/rnn_cell_impl.cs | 4 +-- .../Operations/array_ops.cs | 18 ++++++------- .../Operations/functional_ops.cs | 2 +- .../Operations/image_ops_impl.cs | 22 +++++++-------- src/TensorFlowNET.Core/Operations/nn_ops.cs | 2 +- .../Tensors/Ragged/RowPartition.cs | 2 +- .../Tensors/Tensor.Creation.cs | 8 ++++++ .../Tensors/Tensor.Value.cs | 4 +-- src/TensorFlowNET.Core/Tensors/Tensor.cs | 18 ++++++------- .../Tensors/c_api.tensor.cs | 3 ++- src/TensorFlowNET.Core/Tensors/constant_op.cs | 4 ++- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 2 +- .../Training/Saving/SaveableObject.cs | 2 +- .../Variables/RefVariable.cs | 2 +- src/TensorFlowNET.Keras/BackendImpl.cs | 6 ++--- .../DataAdapters/TensorLikeDataAdapter.cs | 2 +- .../Engine/MetricsContainer.cs | 4 +-- src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 2 +- src/TensorFlowNET.Keras/tf.layers.cs | 6 ++--- .../Sessions/SessionTest.cs | 4 +-- .../Tensors/TensorTest.cs | 2 +- .../Basics/VariableTest.cs | 2 +- .../ManagedAPI/ConstantTest.cs | 2 +- .../ManagedAPI/GradientTest.cs | 2 +- .../ManagedAPI/TensorOperate.cs | 10 +++---- 33 files changed, 118 insertions(+), 72 deletions(-) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 5c7641e0..91888e4b 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -155,7 +155,7 @@ namespace Tensorflow switch (a) { case Tensor tensor: - return tensor.shape[0]; + return (int)tensor.shape[0]; case Tensors arr: return arr.Length; case Array arr: diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs index 73cd7daf..c88fb876 100644 --- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Framework { public static void assert_is_compatible_with(this Tensor self, Tensor other) { - if (!self.is_compatible_with(other)) + /*if (!self.is_compatible_with(other)) { var selfDim = self.shape .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) @@ -21,7 +21,7 @@ namespace Tensorflow.Framework .Replace(", }", "}"); throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible"); - } + }*/ } public static bool is_compatible_with(this Tensor self, Tensor other) diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs index 08636298..fd7f098f 100644 --- a/src/TensorFlowNET.Core/Gradients/image_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -27,10 +27,10 @@ namespace Tensorflow.Gradients { var grad = grads[0]; var image = op.inputs[0]; - var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); + var shape = new TensorShape(image.shape.dims.Skip(1).Take(2).ToArray()); Tensor image_shape = null; if (shape.is_fully_defined()) - image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray()); + image_shape = constant_op.constant(image.shape.dims.Skip(1).Take(2).ToArray()); else image_shape = array_ops.shape(image)["1:3"]; diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 34710f70..4eb1087e 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -195,7 +195,7 @@ namespace Tensorflow.Gradients if (op is EagerOperation op_eager && op_eager.SkipInputIndices.Contains(1) && - y.NDims == 0) + y.ndim == 0) { return new Tensor[] { @@ -759,7 +759,7 @@ namespace Tensorflow.Gradients if (op is EagerOperation op_eager && op_eager.SkipInputIndices.Contains(1) && - y.NDims == 0) + y.ndim == 0) { x = math_ops.conj(x); y = math_ops.conj(y); diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 97629062..b170d90b 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -22,6 +23,12 @@ namespace Tensorflow public static implicit operator Axis(int[] axis) => new Axis(axis); + + public static implicit operator Axis(long[] shape) + => new Axis(shape.Select(x => (int)x).ToArray()); + + public static implicit operator Axis(Shape shape) + => new Axis(shape.dims.Select(x => (int)x).ToArray()); } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 1cfc9b4e..719dba77 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -11,8 +11,9 @@ namespace Tensorflow.NumPy Tensor _tensor; public TF_DataType dtype => _tensor.dtype; public ulong size => _tensor.size; - public ulong dtypesize => _tensor.itemsize; - public int ndim => _tensor.NDims; + public ulong dtypesize => _tensor.dtypesize; + public ulong bytesize => _tensor.bytesize; + public int ndim => _tensor.ndim; public long[] dims => _tensor.dims.Select(x => Convert.ToInt64(x)).ToArray(); public Shape shape => _tensor.shape; public IntPtr data => _tensor.TensorDataPointer; diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index 961955dd..c0b6048d 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -48,6 +48,12 @@ namespace Tensorflow public static implicit operator Shape((long, long, long, long) dims) => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + public static implicit operator int[](Shape shape) + => shape.dims.Select(x => (int)x).ToArray(); + + public static implicit operator long[](Shape shape) + => shape.dims; + public bool IsEmpty => size == 0; public bool IsScalar => ndim == 0; @@ -55,6 +61,8 @@ namespace Tensorflow public static Shape Scalar => new Shape(new long[0]); + public long this[int n] => dims[n]; + /// /// Returns the size this shape represents. /// @@ -81,6 +89,25 @@ namespace Tensorflow } } + public bool is_fully_defined() + { + return ndim > -1 && dims != null && dims.Count(x => x < 1) == 0; + } + + public bool is_compatible_with(TensorShape shape2) + { + if (dims != null && shape2.dims != null) + { + if (dims.Contains(-1) || shape2.dims.Contains(-1)) + return true; + + if (size != (ulong)shape2.size) + return false; + } + + return true; + } + public override bool Equals(object obj) { if(obj is Shape shape) diff --git a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs index 3e185c49..a73bbcc0 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs @@ -92,7 +92,7 @@ namespace Tensorflow public Tensor _batch_shape() { - return array_ops.broadcast_static_shape(new Tensor(_loc.shape), new Tensor(_scale.shape)); + return array_ops.broadcast_static_shape(new Tensor(_loc.shape.dims), new Tensor(_scale.shape.dims)); } protected override Tensor _log_prob(Tensor x) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs index cf5f1ce0..c76d768d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -27,9 +27,9 @@ namespace Tensorflow.Operations { var p = prefix; var p_static = tensor_util.constant_value(prefix); - if (p.NDims == 0) + if (p.ndim == 0) p = array_ops.expand_dims(p, 0); - else if (p.NDims != 1) + else if (p.ndim != 1) throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}"); var s_tensor_shape = new TensorShape(suffix); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 9e7290ed..13db8194 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -186,7 +186,7 @@ namespace Tensorflow private static Tensor _constant_if_small(int value, Tensor shape) { - return shape < 1000L; + return shape < 1000UL; } private static Tensor _constant_if_small(T value, TensorShape shape, TF_DataType dtype, string name) @@ -330,7 +330,7 @@ namespace Tensorflow { name = scope; var input_tensor = ops.convert_to_tensor(inputs); - return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name); + return constant_op.constant(input_tensor.ndim, dtype: tf.int32, name: name); }); } @@ -340,7 +340,7 @@ namespace Tensorflow { name = scope; var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); + var input_shape = input_tensor.shape; if (optimize && input_shape.ndim > 0) return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name); else @@ -364,7 +364,7 @@ namespace Tensorflow tensor = ops.convert_to_tensor(tensor, name: "tensor"); // is_fully_defined return unexpected value. - if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) + if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT) { } @@ -589,9 +589,9 @@ namespace Tensorflow if (!tf.Context.executing_eagerly()) { var input_shape = input.TensorShape; - if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) + if (optimize && input.ndim > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input.shape).astype(out_type.as_system_dtype()); + var nd = np.array(input.shape.dims).astype(out_type.as_system_dtype()); return constant_op.constant(nd, name: name); } } @@ -607,7 +607,7 @@ namespace Tensorflow name = scope; var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); + var input_shape = input_tensor.shape; if (optimize) { if (input_shape.is_fully_defined()) @@ -633,7 +633,7 @@ namespace Tensorflow tensor = ops.convert_to_tensor(tensor, name: "tensor"); // is_fully_defined return unexpected value. - if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) + if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT) { } @@ -933,7 +933,7 @@ namespace Tensorflow string name = "split") { if (num == -1) - num = size_splits.shape[0]; + num = (int)size_splits.shape[0]; return gen_array_ops.split_v(value, size_splits, axis, num, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 67450c74..003b93da 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -91,7 +91,7 @@ namespace Tensorflow elem.dtype, size: tf.constant(n), dynamic_size: false, - element_shape: elem.shape.Skip(1).ToArray(), + element_shape: elem.shape.dims.Skip(1).ToArray(), infer_shape: true)).ToList(); for (int index = 0; index < elems_ta.Count; index++) diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 849a93c8..917dbd6b 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -341,14 +341,14 @@ or rank = 4. Had rank = {0}", rank)); { h = _get_dim(image, 0); // img_h == h[0], dynamic_h == h[1] w = _get_dim(image, 1); - d = image.shape[3]; + d = (int)image.shape[3]; } else { - bs = image.shape[0]; + bs = (int)image.shape[0]; h = _get_dim(image, 1); w = _get_dim(image, 2); - d = image.shape[3]; + d = (int)image.shape[3]; } object hd, bbox_h_start; @@ -1115,7 +1115,7 @@ new_height, new_width"); array_ops.expand_dims(tf.constant(3), 0)); var multiples = array_ops.concat(new Tensor[] { shape_list }, 0); var rgb = array_ops.tile(images, multiples, name: name); - int[] rgb_temp = images.shape.Take(images.shape.Length - 1).ToArray(); + int[] rgb_temp = images.shape.dims.Take(images.shape.ndim - 1).Select(x => (int)x).ToArray(); rgb.set_shape(array_ops.concat(new Tensor[] { ops.convert_to_tensor(rgb_temp) }, 3)); return rgb; }); @@ -1459,7 +1459,7 @@ new_height, new_width"); // shape takes an int, python code passes size, a Tensor. NDims is the only int type // i could think of a Tensor having. it might be incorrect tho, so keep that in mind. - return array_ops.reshape(g, shape: new int[] { size.NDims, size.NDims, 1, 1 }); + return array_ops.reshape(g, shape: new int[] { size.ndim, size.ndim, 1, 1 }); } internal static (Tensor, Tensor) _ssim_per_channel(Tensor img1, Tensor img2, float max_val = 1f, @@ -1487,7 +1487,7 @@ new_height, new_width"); img1 = array_ops.identity(img1); var kernel = _fspecial_gauss(filter_size_tensor, filter_sigma_tensor); - kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 })); + kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, (int)shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 })); float compensation = 1.0f; @@ -1503,8 +1503,8 @@ new_height, new_width"); (Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2); var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32); - var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims)); - cs = math_ops.reduce_mean(cs, new(axes.dims)); + var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims); + cs = math_ops.reduce_mean(cs, axes.dims); return (ssim_val, cs); } @@ -1685,7 +1685,7 @@ new_height, new_width"); var kernels_tf = constant_op.constant(kernels, dtype: image.dtype); kernels_tf = array_ops.tile( - kernels_tf, new Tensor(new int[] { 1, 1, image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters"); + kernels_tf, new Tensor(new int[] { 1, 1, (int)image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters"); var pad_sizes = new int[,] { { 0, 0 }, { 1, 1 }, { 1, 1 }, { 0, 0 } }; var padded = array_ops.pad(image, new Tensor(pad_sizes), mode: "reflect"); @@ -1966,8 +1966,8 @@ new_height, new_width"); Tensor index_offsets, indices, sorted_scores, sorted_boxes, sorted_scores_indices; using (ops.name_scope("sort_scores_and_boxes")) { - batch_size = array_ops.shape(boxes).dims[0]; - num_boxes = array_ops.shape(boxes).dims[1]; + batch_size = (int)array_ops.shape(boxes).dims[0]; + num_boxes = (int)array_ops.shape(boxes).dims[1]; sorted_scores_indices = null; /*sort_ops.argsort( scores, axis: 1, direction: "DESCENDING); */ index_offsets = math_ops.range(batch_size) * num_boxes; diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index ef50f69f..6d69a55f 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -178,7 +178,7 @@ namespace Tensorflow logits = ops.convert_to_tensor(logits); var shape = logits.shape; - bool is_last_dim = dim == -1 || dim == shape.Length - 1; + bool is_last_dim = dim == -1 || dim == shape.ndim - 1; if (is_last_dim) return compute_op(logits, name); diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs index 6a52397a..b1dbf586 100644 --- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs @@ -37,7 +37,7 @@ namespace Tensorflow { get { - return _row_splits.shape[0] - 1; + return (int)_row_splits.shape[0] - 1; } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 1f839ee7..d4419073 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -145,6 +145,10 @@ namespace Tensorflow byte[,] val => InitTensor(val, shape, dtype), byte[,,] val => InitTensor(val, shape, dtype), byte[,,,] val => InitTensor(val, shape, dtype), + short[] val => InitTensor(val, shape, dtype), + short[,] val => InitTensor(val, shape, dtype), + short[,,] val => InitTensor(val, shape, dtype), + short[,,,] val => InitTensor(val, shape, dtype), int[] val => InitTensor(val, shape, dtype), int[,] val => InitTensor(val, shape, dtype), int[,,] val => InitTensor(val, shape, dtype), @@ -153,6 +157,10 @@ namespace Tensorflow long[,] val => InitTensor(val, shape, dtype), long[,,] val => InitTensor(val, shape, dtype), long[,,,] val => InitTensor(val, shape, dtype), + ulong[] val => InitTensor(val, shape, dtype), + ulong[,] val => InitTensor(val, shape, dtype), + ulong[,,] val => InitTensor(val, shape, dtype), + ulong[,,,] val => InitTensor(val, shape, dtype), float[] val => InitTensor(val, shape, dtype), float[,] val => InitTensor(val, shape, dtype), float[,,] val => InitTensor(val, shape, dtype), diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index ed72d9aa..dd7b8ad6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -18,7 +18,7 @@ namespace Tensorflow if (typeof(T).as_tf_dtype() != dtype) throw new ArrayTypeMismatchException($"dtype {dtype} mismatch."); - if (NDims == 0 && size == 1) //is it a scalar? + if (ndim == 0 && size == 1) //is it a scalar? { unsafe { @@ -28,7 +28,7 @@ namespace Tensorflow //types match, no need to perform cast var ret = new T[size]; - var len = (long)(size * itemsize); + var len = (long)(size * dtypesize); var src = (T*)buffer; fixed (T* dst = ret) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 5166cf81..bf8089de 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -72,17 +72,17 @@ namespace Tensorflow /// public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); - public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); - public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; + public ulong dtypesize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); + public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dtypesize; public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); - public int NDims => rank; + public int ndim => rank; /// /// The name of the device on which this tensor will be produced, or null. /// public virtual string Device => op.Device; - public int[] dims => shape; + public long[] dims => shape.dims; /// /// Used for keep other pointer when do implicit operating @@ -107,7 +107,7 @@ namespace Tensorflow /// Returns the shape of a tensor. /// /// https://www.tensorflow.org/api_docs/python/tf/shape - public int[] shape + public Shape shape { get { @@ -123,7 +123,7 @@ namespace Tensorflow dims[i] = c_api.TF_Dim(_handle, i); } - return dims.Select(x => ((IConvertible)x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); + return dims; } set @@ -131,7 +131,7 @@ namespace Tensorflow if (value == null) c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); else - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.Status.Handle); + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); tf.Status.Check(true); } @@ -139,10 +139,10 @@ namespace Tensorflow public int[] _shape_tuple() { - return rank < 0 ? null : shape; + return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray(); } - public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); + public TensorShape TensorShape => rank < 0 ? new TensorShape() : shape; /// /// Keras History: (Layer, (node_index, tensor_index)) diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 66b5fd3b..5917439e 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -109,7 +109,8 @@ namespace Tensorflow var length = shape.size * (ulong)dtype.get_datatype_size(); var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length); var tensor = TF_TensorData(handle); - System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); + if (tensor != IntPtr.Zero) + System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length); return handle; } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index cf6c76a2..b69c4477 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -124,6 +124,8 @@ namespace Tensorflow return new EagerTensor(new[] { val }, Shape.Scalar); case long val: return new EagerTensor(new[] { val }, Shape.Scalar); + case ulong val: + return new EagerTensor(new[] { val }, Shape.Scalar); case float val: return new EagerTensor(new[] { val }, Shape.Scalar); case double val: @@ -146,7 +148,7 @@ namespace Tensorflow if (shape == null) return t; - if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) + if (t.shape.dims.SequenceEqual(shape.dims)) return t; if (verify_shape) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 5ad8bc9b..5a007695 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -127,7 +127,7 @@ namespace Tensorflow } else if (values is Tensor tensor && tensor.IsReferencedByNDArray) { - var len = tensor.itemsize * tensor.size; + var len = tensor.dtypesize * tensor.size; byte[] bytes = tensor.BufferToArray(); tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 960bf656..60de456f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -45,7 +45,7 @@ namespace Tensorflow var restored_tensor = restored_tensors[0]; return gen_state_ops.assign(op, restored_tensor, - validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined()); + validate_shape: restored_shapes == null && op.shape.is_fully_defined()); } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 3bf4f784..36fdfed2 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -50,7 +50,7 @@ namespace Tensorflow public Operation Op => _variable.op; public TF_DataType dtype => _variable.dtype; - public TensorShape shape => tensor_util.to_shape(_variable.shape); + public TensorShape shape => _variable.shape; public string Device => ""; public string Name => _variable.name; diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index bedca2c1..e1563055 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -297,8 +297,8 @@ namespace Tensorflow.Keras // x = permute_dimensions(x, [0, 3, 1, 2]); throw new NotImplementedException(""); - int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor; - int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor; + int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor; + int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor; TensorShape output_shape = data_format == "channels_first" ? (-1, -1, new_height, new_width) : (-1, new_height, new_width, -1); @@ -316,7 +316,7 @@ namespace Tensorflow.Keras { if(axis < 0) { - var rank = tensors[0].NDims; + var rank = tensors[0].ndim; if (rank > -1) axis += rank; else diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index d73dc8b1..fc61aa71 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters { this.args = args; _process_tensorlike(); - num_samples = args.X.shape[0]; + num_samples = (int)args.X.shape[0]; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; _batch_size = batch_size; _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 6fed2bf3..037703c8 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -63,8 +63,8 @@ namespace Tensorflow.Keras.Engine { var y_t_rank = y_t.rank; var y_p_rank = y_p.rank; - var y_t_last_dim = y_t.shape[y_t.shape.Length - 1]; - var y_p_last_dim = y_p.shape[y_p.shape.Length - 1]; + var y_t_last_dim = y_t.shape[y_t.shape.ndim - 1]; + var y_p_last_dim = y_p.shape[y_p.shape.ndim - 1]; bool is_binary = y_p_last_dim == 1; bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index 64723a22..592d2568 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Metrics var y_true_rank = y_true.TensorShape.ndim; // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) if (y_true_rank != -1 && y_pred_rank != -1 - && y_true.shape.Length == y_pred.shape.Length) + && y_true.shape.ndim == y_pred.shape.ndim) y_true = array_ops.squeeze(y_true, axis: new[] { -1 }); y_pred = math_ops.argmax(y_pred, -1); diff --git a/src/TensorFlowNET.Keras/tf.layers.cs b/src/TensorFlowNET.Keras/tf.layers.cs index b69bbe95..3f5ed01c 100644 --- a/src/TensorFlowNET.Keras/tf.layers.cs +++ b/src/TensorFlowNET.Keras/tf.layers.cs @@ -212,13 +212,13 @@ namespace Tensorflow.Keras string data_format = "channels_last") { var input_shape = inputs.shape; - if (inputs.shape.Length == 0) + if (inputs.shape.ndim == 0) throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); var premutation = new List() { 0 }; - if (data_format == "channels_first" && inputs.NDims > 1) + if (data_format == "channels_first" && inputs.ndim > 1) { - premutation.AddRange(Binding.range(2, inputs.NDims)); + premutation.AddRange(Binding.range(2, inputs.ndim)); premutation.Add(1); inputs = array_ops.transpose(inputs, premutation.ToArray()); } diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs index d9e4e872..b1fe18b4 100644 --- a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs @@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest.Sessions csession.Run(s); Tensor outTensor = csession.output_tensor(0); EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); + EXPECT_EQ(0, outTensor.ndim); ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); var output_contents = outTensor.ToArray(); EXPECT_EQ(3 + 2, output_contents[0]); @@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest.Sessions outTensor = csession.output_tensor(0); ASSERT_TRUE(outTensor != IntPtr.Zero); EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); // scalar + EXPECT_EQ(0, outTensor.ndim); // scalar ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); output_contents = outTensor.ToArray(); EXPECT_EQ(-(7 + 2), output_contents[0]); diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs index dc588a1a..76ebf209 100644 --- a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs @@ -66,7 +66,7 @@ namespace Tensorflow.Native.UnitTest.Tensors long[] dims = { 2, 3 }; Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); - EXPECT_EQ(2, t.NDims); + EXPECT_EQ(2, t.ndim); EXPECT_EQ((int)dims[0], t.shape[0]); EXPECT_EQ(num_bytes, t.bytesize); t.Dispose(); diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs index 74b1bb03..1b55508b 100644 --- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -126,7 +126,7 @@ namespace TensorFlowNET.UnitTest.Basics { var x = tf.constant(new[,] { { 1, 2 } }); var neg_x = tf.negative(x); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims)); Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray())); } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs index 8b2260a3..2062dbc3 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs @@ -145,7 +145,7 @@ namespace TensorFlowNET.UnitTest.Basics var tensor = tf.constant(nd); var data = tensor.numpy().ToArray(); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, tensor.shape.dims)); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs index 0bf506da..902bcdbf 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs @@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var b = tf.Variable(-0.73f, name: "bias"); using var g = tf.GradientTape(); var pred = W * X + b; - var test = tf.slice(pred, new[] { 0 }, pred.shape); + var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape); var gradients = g.gradient(test, (W, b)); Assert.AreEqual((float)gradients.Item1, 0f); Assert.AreEqual((float)gradients.Item2, 10f); diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs index cdc8b51c..8f38f45c 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -85,14 +85,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } })); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, a.shape.dims)); var b = tf.constant(new[, ,] { { { 1 }, { 2 }, { 3 } }, { { 4 }, { 5 }, { 6 } } }); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, b.shape.dims)); } [TestMethod] @@ -103,7 +103,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); var concatValue = tf.concat(new[] { a, b, c }, axis: 0); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); } [TestMethod] @@ -114,7 +114,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); var concatValue = tf.concat(new[] { a, b, c }, axis: 0); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims)); } [TestMethod] @@ -128,7 +128,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var splitValue = tf.split(value, 3, axis: 0); Assert.AreEqual(3, splitValue.Length); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); + Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, splitValue[0].shape.dims)); } #region ones/zeros like