Browse Source

fix ndarray to string.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
f5cbf89a38
3 changed files with 126 additions and 31 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Data/MnistModelLoader.cs
  2. +1
    -4
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  3. +123
    -25
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 2
- 2
src/TensorFlowNET.Core/Data/MnistModelLoader.cs View File

@@ -157,8 +157,8 @@ namespace Tensorflow

private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
{
var num_labels = labels_dense.dims[0];
var index_offset = np.arange(num_labels) * num_classes;
var num_labels = (int)labels_dense.dims[0];
// var index_offset = np.arange(num_labels) * num_classes;
var labels_one_hot = np.zeros((num_labels, num_classes));
var labels = labels_dense.ToArray<byte>();
for (int row = 0; row < num_labels; row++)


+ 1
- 4
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -46,9 +46,6 @@ namespace Tensorflow.NumPy
public byte[] ToByteArray() => BufferToArray();
public static string[] AsStringArray(NDArray arr) => throw new NotImplementedException("");

public override string ToString()
{
return tensor_util.to_numpy_string(this);
}
public override string ToString() => tensor_util.to_numpy_string(this);
}
}

+ 123
- 25
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -470,7 +470,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
}

public static string to_numpy_string(Tensor array)
public static string to_numpy_string(NDArray array)
{
Shape shape = array.shape;
if (shape.ndim == 0)
@@ -483,10 +483,66 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return s.ToString();
}

static string Render(Tensor tensor)
static void PrettyPrint(StringBuilder s, NDArray array)
{
var shape = array.shape;

if (shape.Length == 1)
{
s.Append("[");
s.Append(Render(array));
s.Append("]");
return;
}

var len = shape[0];
s.Append("[");

if (len <= 10)
{
for (int i = 0; i < len; i++)
{
PrettyPrint(s, array[i]);
if (i < len - 1)
{
s.Append(", ");
s.AppendLine();
}
}
}
else
{
for (int i = 0; i < 5; i++)
{
PrettyPrint(s, array[i]);
if (i < len - 1)
{
s.Append(", ");
s.AppendLine();
}
}

s.Append(" ... ");
s.AppendLine();

for (int i = (int)len - 5; i < len; i++)
{
PrettyPrint(s, array[i]);
if (i < len - 1)
{
s.Append(", ");
s.AppendLine();
}
}
}

s.Append("]");
}

static string Render(NDArray tensor)
{
if (tensor.buffer == IntPtr.Zero)
return "Empty";
return "<null>";

var dtype = tensor.dtype;
var shape = tensor.shape;
@@ -508,48 +564,90 @@ would not be rank 1.", tensor.op.get_attr("axis")));
{
return "<unprintable>";
}
else if (dtype == TF_DataType.TF_BOOL)
{
var array = tensor.ToArray<bool>();
return DisplayArrayAsString(array, tensor.shape);
}
else if (dtype == TF_DataType.TF_INT32)
else
{
var array = tensor.ToArray<int>();
return DisplayArrayAsString(array, tensor.shape);
return dtype switch
{
TF_DataType.TF_BOOL => DisplayArrayAsString(tensor.ToArray<bool>(), tensor.shape),
TF_DataType.TF_INT8 => DisplayArrayAsString(tensor.ToArray<sbyte>(), tensor.shape),
TF_DataType.TF_INT32 => DisplayArrayAsString(tensor.ToArray<int>(), tensor.shape),
TF_DataType.TF_INT64 => DisplayArrayAsString(tensor.ToArray<long>(), tensor.shape),
TF_DataType.TF_FLOAT => DisplayArrayAsString(tensor.ToArray<float>(), tensor.shape),
TF_DataType.TF_DOUBLE => DisplayArrayAsString(tensor.ToArray<double>(), tensor.shape),
_ => DisplayArrayAsString(tensor.ToArray<byte>(), tensor.shape)
};
}
else if (dtype == TF_DataType.TF_INT64)
}

public static string to_numpy_string(Tensor array)
{
Shape shape = array.shape;
if (shape.ndim == 0)
return array[0].ToString();

var s = new StringBuilder();
s.Append("array(");
PrettyPrint(s, array);
s.Append(")");
return s.ToString();
}

static string Render(Tensor tensor)
{
if (tensor.buffer == IntPtr.Zero)
return "<null>";

var dtype = tensor.dtype;
var shape = tensor.shape;

if (dtype == TF_DataType.TF_STRING)
{
var array = tensor.ToArray<long>();
return DisplayArrayAsString(array, tensor.shape);
if (tensor.rank == 0)
return "'" + string.Join(string.Empty, tensor.StringBytes()[0]
.Take(25)
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'";
else
return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
}
else if (dtype == TF_DataType.TF_FLOAT)
else if (dtype == TF_DataType.TF_VARIANT)
{
var array = tensor.ToArray<float>();
return DisplayArrayAsString(array, tensor.shape);
return "<unprintable>";
}
else if (dtype == TF_DataType.TF_DOUBLE)
else if (dtype == TF_DataType.TF_RESOURCE)
{
var array = tensor.ToArray<double>();
return DisplayArrayAsString(array, tensor.shape);
return "<unprintable>";
}
else
{
var array = tensor.ToArray<byte>();
return DisplayArrayAsString(array, tensor.shape);
return dtype switch
{
TF_DataType.TF_BOOL => DisplayArrayAsString(tensor.ToArray<bool>(), tensor.shape),
TF_DataType.TF_INT8 => DisplayArrayAsString(tensor.ToArray<sbyte>(), tensor.shape),
TF_DataType.TF_INT32 => DisplayArrayAsString(tensor.ToArray<int>(), tensor.shape),
TF_DataType.TF_INT64 => DisplayArrayAsString(tensor.ToArray<long>(), tensor.shape),
TF_DataType.TF_FLOAT => DisplayArrayAsString(tensor.ToArray<float>(), tensor.shape),
TF_DataType.TF_DOUBLE => DisplayArrayAsString(tensor.ToArray<double>(), tensor.shape),
_ => DisplayArrayAsString(tensor.ToArray<byte>(), tensor.shape)
};
}
}

static string DisplayArrayAsString<T>(T[] array, Shape shape)
{
if (array == null)
return "<null>";

if (array.Length == 0)
return "<empty>";

if (shape.ndim == 0)
return array[0].ToString();

var display = "";
if (array.Length < 10)
if (array.Length <= 10)
display += string.Join(", ", array);
else
display += string.Join(", ", array.Take(3)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 3));
display += string.Join(", ", array.Take(5)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 5));
return display;
}

@@ -597,7 +695,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
s.Append(" ... ");
s.AppendLine();

for (int i = (int)array.size - 5; i < len; i++)
for (int i = (int)len - 5; i < len; i++)
{
PrettyPrint(s, array[i], flat);
if (i < len - 1)


Loading…
Cancel
Save