diff --git a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs index 2e033f3e..4625a7f8 100644 --- a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs +++ b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs @@ -157,8 +157,8 @@ namespace Tensorflow private NDArray DenseToOneHot(NDArray labels_dense, int num_classes) { - var num_labels = labels_dense.dims[0]; - var index_offset = np.arange(num_labels) * num_classes; + var num_labels = (int)labels_dense.dims[0]; + // var index_offset = np.arange(num_labels) * num_classes; var labels_one_hot = np.zeros((num_labels, num_classes)); var labels = labels_dense.ToArray(); for (int row = 0; row < num_labels; row++) diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index e4764846..bbe8fda3 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -46,9 +46,6 @@ namespace Tensorflow.NumPy public byte[] ToByteArray() => BufferToArray(); public static string[] AsStringArray(NDArray arr) => throw new NotImplementedException(""); - public override string ToString() - { - return tensor_util.to_numpy_string(this); - } + public override string ToString() => tensor_util.to_numpy_string(this); } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index b93a5982..b86fa4fe 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -470,7 +470,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); } - public static string to_numpy_string(Tensor array) + public static string to_numpy_string(NDArray array) { Shape shape = array.shape; if (shape.ndim == 0) @@ -483,10 +483,66 @@ would not be rank 1.", tensor.op.get_attr("axis"))); return s.ToString(); } - static string Render(Tensor tensor) + static void PrettyPrint(StringBuilder s, NDArray array) + { + var shape = array.shape; + + if (shape.Length == 1) + { + s.Append("["); + s.Append(Render(array)); + s.Append("]"); + return; + } + + var len = shape[0]; + s.Append("["); + + if (len <= 10) + { + for (int i = 0; i < len; i++) + { + PrettyPrint(s, array[i]); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + } + else + { + for (int i = 0; i < 5; i++) + { + PrettyPrint(s, array[i]); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + + s.Append(" ... "); + s.AppendLine(); + + for (int i = (int)len - 5; i < len; i++) + { + PrettyPrint(s, array[i]); + if (i < len - 1) + { + s.Append(", "); + s.AppendLine(); + } + } + } + + s.Append("]"); + } + + static string Render(NDArray tensor) { if (tensor.buffer == IntPtr.Zero) - return "Empty"; + return ""; var dtype = tensor.dtype; var shape = tensor.shape; @@ -508,48 +564,90 @@ would not be rank 1.", tensor.op.get_attr("axis"))); { return ""; } - else if (dtype == TF_DataType.TF_BOOL) - { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); - } - else if (dtype == TF_DataType.TF_INT32) + else { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); + return dtype switch + { + TF_DataType.TF_BOOL => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT8 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT32 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT64 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_FLOAT => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_DOUBLE => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + _ => DisplayArrayAsString(tensor.ToArray(), tensor.shape) + }; } - else if (dtype == TF_DataType.TF_INT64) + } + + public static string to_numpy_string(Tensor array) + { + Shape shape = array.shape; + if (shape.ndim == 0) + return array[0].ToString(); + + var s = new StringBuilder(); + s.Append("array("); + PrettyPrint(s, array); + s.Append(")"); + return s.ToString(); + } + + static string Render(Tensor tensor) + { + if (tensor.buffer == IntPtr.Zero) + return ""; + + var dtype = tensor.dtype; + var shape = tensor.shape; + + if (dtype == TF_DataType.TF_STRING) { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); + if (tensor.rank == 0) + return "'" + string.Join(string.Empty, tensor.StringBytes()[0] + .Take(25) + .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; + else + return $"['{string.Join("', '", tensor.StringData().Take(25))}']"; } - else if (dtype == TF_DataType.TF_FLOAT) + else if (dtype == TF_DataType.TF_VARIANT) { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); + return ""; } - else if (dtype == TF_DataType.TF_DOUBLE) + else if (dtype == TF_DataType.TF_RESOURCE) { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); + return ""; } else { - var array = tensor.ToArray(); - return DisplayArrayAsString(array, tensor.shape); + return dtype switch + { + TF_DataType.TF_BOOL => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT8 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT32 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_INT64 => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_FLOAT => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + TF_DataType.TF_DOUBLE => DisplayArrayAsString(tensor.ToArray(), tensor.shape), + _ => DisplayArrayAsString(tensor.ToArray(), tensor.shape) + }; } } static string DisplayArrayAsString(T[] array, Shape shape) { + if (array == null) + return ""; + + if (array.Length == 0) + return ""; + if (shape.ndim == 0) return array[0].ToString(); var display = ""; - if (array.Length < 10) + if (array.Length <= 10) display += string.Join(", ", array); else - display += string.Join(", ", array.Take(3)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 3)); + display += string.Join(", ", array.Take(5)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 5)); return display; } @@ -597,7 +695,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); s.Append(" ... "); s.AppendLine(); - for (int i = (int)array.size - 5; i < len; i++) + for (int i = (int)len - 5; i < len; i++) { PrettyPrint(s, array[i], flat); if (i < len - 1)